In [None]:
# Mount Google Drive inside the Colab VM; the project directory used below
# (BASE) lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Project root on the mounted Drive. Every relative path used later
# ('datasets/...', 'model_zoo/...', the `models` package) is resolved against it.
# NOTE(review): the recorded output of this cell shows
# '/content/drive/My Drive/adversarial_filter', which does not match BASE --
# confirm which directory name is the current one.
BASE = '/content/drive/My Drive/adversarial-attack-filter'
%cd $BASE

/content/drive/My Drive/adversarial_filter


In [None]:
from models.architecture import IMDN

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.transforms as T

import cv2
from google.colab.patches import cv2_imshow

from torchvision.transforms import functional as func
import torchvision.models as models

import numpy as np
import pandas as pd


from PIL import Image

import os

import matplotlib.pyplot as plt

import math

In [None]:
# Run on the GPU when the runtime has one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Utils

In [None]:
# Load the NIPS-DEV metadata and keep only its first two columns.
# check_accuracy reads `ImageId` and `TrueLabel` from this frame, so those are
# presumably the two columns kept here -- verify against the CSV header.
df = pd.read_csv('datasets/NIPS-DEV/original.csv').iloc[:, :2]

In [None]:
# Evaluation split: rows 900..999 of the metadata frame (at most 100 images).
# NOTE(review): presumably the earlier rows are used elsewhere (e.g. training) -- confirm.
TEST_START, TEST_STOP = 900, 1000
test_df = df.iloc[TEST_START:TEST_STOP]

In [None]:
def print_tensor_image(image) :
  """Show an image tensor with matplotlib.

  Args:
    image: a tensor laid out as (C, H, W), or a batch laid out as
      (N, C, H, W). For a batch only the first image is shown.

  The tensor may live on any device and may require grad: it is detached and
  copied to the CPU before being converted to a NumPy array.
  """
  img = image
  if img.dim() == 4 :
    # Drop the batch axis by indexing. The previous `Tensor.resize(3, H, W)`
    # call is deprecated and silently assumed exactly three channels.
    img = img[0]
  # `.numpy()` raises on CUDA tensors and on tensors that require grad.
  img = img.detach().cpu()
  # matplotlib expects channels-last (H, W, C).
  plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))

In [None]:
# Per-channel mean / std applied to the image tensor right before it is fed to
# the classifier in check_accuracy (the values torchvision documents for its
# ImageNet-pretrained models).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = T.Compose([T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])

def check_accuracy(df, sr_model, cls_model, downsample = 1, attack = None) :
  """Top-1 accuracy (in percent) of `cls_model` on the images listed in `df`.

  Each image is optionally bicubic-downsampled, optionally passed through
  `sr_model`, normalized with `preprocess`, and classified.

  Args:
    df: DataFrame with `ImageId` (file stem) and `TrueLabel` columns.
    sr_model: module applied to the image tensor before classification, or
      None to classify the image directly.
    cls_model: classifier returning per-class scores.
    downsample: when not 1, the image is resized to a square of side
      ceil(299 / downsample) before anything else.
    attack: which image set to read -- None (clean), 'd' (di-2-fgsm) or
      'm' (mdi-2-fgsm).

  Returns:
    Percentage of rows whose `TrueLabel` equals argmax + 1, or 0.0 when `df`
    is empty.

  Raises:
    ValueError: if `attack` is not one of None, 'd', 'm'.
  """
  attack_dirs = {
    None : 'datasets/NIPS-DEV/clean/',
    'd' : 'datasets/NIPS-DEV/di-2-fgsm/',
    'm' : 'datasets/NIPS-DEV/mdi-2-fgsm/',
  }
  # Fail fast: an unknown code used to leave `img` unbound (or silently reuse
  # the previous iteration's image).
  if attack not in attack_dirs :
    raise ValueError("attack must be None, 'd' or 'm', got %r" % (attack,))
  image_dir = attack_dirs[attack]

  if sr_model is not None :
    sr_model.to(device)
    sr_model.eval()

  cls_model.to(device)
  cls_model.eval()

  total = len(df)
  if total == 0 :
    # Nothing to evaluate; avoid ZeroDivisionError.
    return 0.0

  # Loop-invariant resize target. 299 is the side length this notebook
  # assumes for the dataset images.
  side = math.ceil(299 / downsample)

  correct = 0
  with torch.no_grad() :
    for label, image_id in zip(df.TrueLabel, df.ImageId) :
      # Close the file handle promptly and force 3 channels so that
      # palette / RGBA / grayscale PNGs do not break Normalize.
      with Image.open(image_dir + image_id + '.png') as handle :
        img = handle.convert('RGB')

      if downsample != 1 :
        img = img.resize((side, side), Image.BICUBIC)

      img = func.to_tensor(img).unsqueeze(0).to(device)
      if sr_model is not None :
        img = sr_model(img)

      img = preprocess(img)
      # argmax is 0-based; +1 presumably because TrueLabel is 1-based -- confirm.
      pred = torch.argmax(cls_model(img)).item() + 1

      if label == pred :
        correct += 1

  # Multiply before dividing: (58 / 100) * 100 == 57.99999999999999 in floats.
  return 100.0 * correct / total

In [None]:
def check_accuracy_all_type(df, sr_model, cls_model, downsample ) :
  """Print the accuracy of one defense/classifier pair on all three image sets.

  Runs check_accuracy on the clean, di-2-fgsm ('d') and mdi-2-fgsm ('m')
  images with identical settings, then prints the three percentages.
  Returns nothing.
  """
  shared = dict(sr_model = sr_model, cls_model = cls_model, downsample = downsample)
  scenarios = (('Clean : ', None), ('Di2 : ', 'd'), ('Mdi2 : ', 'm'))

  # Evaluate every scenario first, then report, so nothing is printed if any
  # evaluation fails.
  scores = [(title, check_accuracy(df, attack = code, **shared)) for title, code in scenarios]

  for title, score in scores :
    print(title, score)

# Load Models

## Our model

In [None]:
# Our SR model: an IMDN built with upscale=2, initialised from our checkpoint.
adv_ens_imdn = IMDN(upscale = 2)
# map_location='cpu' lets the checkpoint load on a runtime without a GPU even
# if it was saved from CUDA tensors; check_accuracy moves the model to `device`.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
ptw = torch.load('model_zoo/ours/adv_ens_imdn_v1_best.pt', map_location = 'cpu')
adv_ens_imdn.load_state_dict(ptw)

<All keys matched successfully>

## Original SR model

In [None]:
from collections import OrderedDict

PRETRAINED_IMDN_PATH = 'model_zoo/IMDN/IMDN_x2.pth'
# Key prefix removed below; checkpoints saved from an nn.DataParallel wrapper
# carry it on every key.
DATAPARALLEL_PREFIX = 'module.'

# map_location='cpu' lets the checkpoint load on a runtime without a GPU;
# check_accuracy moves the model to `device` later.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
pretrained_weight = torch.load(PRETRAINED_IMDN_PATH, map_location = 'cpu')

# Strip the prefix only when the key actually starts with it. The previous
# test (`'module' in k`) chopped 7 characters off any key that merely
# contained the word "module".
pretrained_IMDN = OrderedDict(
  (k[len(DATAPARALLEL_PREFIX):] if k.startswith(DATAPARALLEL_PREFIX) else k, v)
  for k, v in pretrained_weight.items()
)

imdn_x2 = IMDN(upscale = 2)
imdn_x2.load_state_dict(pretrained_IMDN, strict = True)

<All keys matched successfully>

# Check Accuracy

## ResNet50

In [None]:
# torchvision ResNet-50 with pretrained weights; used as a classifier in the
# evaluations below.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour
# of `weights=...`; switch once the runtime's torchvision version is confirmed.
resnet50 = models.resnet50(pretrained = True)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

### No defense

In [None]:
# Baseline: sr_model=None and downsample=1, so images reach resnet50 with only
# the normalization step applied.
check_accuracy_all_type(test_df, sr_model = None, cls_model = resnet50, downsample = 1)

Clean :  90.0
Di2 :  60.0
Mdi2 :  36.0


### IMDN_X2

In [None]:
# Stock IMDN x2 placed in front of resnet50.
# NOTE(review): this uses downsample=1 while the adv_ens_imdn evaluation uses
# downsample=2, so the two SR models receive differently sized inputs --
# confirm the comparison is intended.
check_accuracy_all_type(test_df, sr_model = imdn_x2, cls_model = resnet50, downsample = 1)

Clean :  80.0
Di2 :  63.0
Mdi2 :  24.0


### Ours

In [None]:
# Our defense: images are bicubic-downsampled by 2 inside check_accuracy, then
# passed through adv_ens_imdn before resnet50 classifies them.
check_accuracy_all_type(test_df, sr_model = adv_ens_imdn, cls_model = resnet50, downsample = 2)

Clean :  87.0
Di2 :  83.0
Mdi2 :  50.0


## MobileNet_V3_Large

In [None]:
# torchvision MobileNetV3-Large with pretrained weights; used as a classifier
# in the evaluations below.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour
# of `weights=...`; switch once the runtime's torchvision version is confirmed.
mobileNet = models.mobilenet_v3_large(pretrained = True)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth


  0%|          | 0.00/21.1M [00:00<?, ?B/s]

### No defense

In [None]:
# Baseline: sr_model=None and downsample=1, so images reach mobileNet with only
# the normalization step applied.
check_accuracy_all_type(test_df, sr_model = None, cls_model = mobileNet, downsample = 1)

Clean :  86.0
Di2 :  60.0
Mdi2 :  36.0


### IMDN_X2

In [None]:
# Stock IMDN x2 placed in front of mobileNet.
# NOTE(review): this uses downsample=1 while the adv_ens_imdn evaluation uses
# downsample=2, so the two SR models receive differently sized inputs --
# confirm the comparison is intended.
check_accuracy_all_type(test_df, sr_model = imdn_x2, cls_model = mobileNet, downsample = 1)

Clean :  74.0
Di2 :  46.0
Mdi2 :  16.0


### Ours

In [None]:
# Our defense: images are bicubic-downsampled by 2 inside check_accuracy, then
# passed through adv_ens_imdn before mobileNet classifies them.
check_accuracy_all_type(test_df, sr_model = adv_ens_imdn, cls_model = mobileNet, downsample = 2)

Clean :  77.0
Di2 :  76.0
Mdi2 :  47.0


## DenseNet121

In [None]:
# torchvision DenseNet-121 with pretrained weights; used as a classifier in the
# evaluations below.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour
# of `weights=...`; switch once the runtime's torchvision version is confirmed.
denseNet121 = models.densenet121(pretrained = True)

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


  0%|          | 0.00/30.8M [00:00<?, ?B/s]

### No defense

In [None]:
# Baseline: sr_model=None and downsample=1, so images reach denseNet121 with
# only the normalization step applied.
check_accuracy_all_type(test_df, sr_model = None, cls_model = denseNet121, downsample = 1)

Clean :  91.0
Di2 :  57.99999999999999
Mdi2 :  33.0


### IMDN_X2

In [None]:
# Stock IMDN x2 placed in front of denseNet121.
# NOTE(review): this uses downsample=1 while the adv_ens_imdn evaluation uses
# downsample=2, so the two SR models receive differently sized inputs --
# confirm the comparison is intended.
check_accuracy_all_type(test_df, sr_model = imdn_x2, cls_model = denseNet121, downsample = 1)

Clean :  76.0
Di2 :  53.0
Mdi2 :  19.0


### Ours

In [None]:
# Our defense: images are bicubic-downsampled by 2 inside check_accuracy, then
# passed through adv_ens_imdn before denseNet121 classifies them.
check_accuracy_all_type(test_df, sr_model = adv_ens_imdn, cls_model = denseNet121, downsample = 2)

Clean :  86.0
Di2 :  86.0
Mdi2 :  56.99999999999999
