In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
from torchvision import transforms
import time
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from inspect import getfullargspec,signature

import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import Dataset
from torchvision import datasets

import torch as ch
from robustness.datasets import FashionMnist
ds = FashionMnist('/tmp/FashionMNIST')

%matplotlib inline
%config InlineBackend.figure_format = 'svg'


In [2]:
############################
class MnistResNet(nn.Module):
  def __init__(self, in_channels=1):
    super(MnistResNet, self).__init__()

    # Load a pretrained resnet model from torchvision.models in Pytorch
    self.model = models.resnet50(pretrained=False)

    # Change the input layer to take Grayscale image, instead of RGB images.  **************
    # Hence in_channels is set as 1 or 3 respectively
    # original definition of the first layer on the ResNet class
    # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    
    # Change the output layer to output 10 classes instead of 1000 classes
    num_ftrs = self.model.fc.in_features
    self.model.fc = nn.Linear(num_ftrs, 10)

  def forward(self, x):
    return self.model(x)


my_resnet = MnistResNet()

input = ch.randn((16,1,244,244))
output = my_resnet(input)
print(output.shape)

print(my_resnet)

torch.Size([16, 10])
MnistResNet(
  (model): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample):

In [3]:
        
if torch.cuda.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")
    
# load STD model 
model_std = MnistResNet().to(device)

# params you need to specify:
epochs = 150
batch_size = 128
model_std = MnistResNet()
model_state_dict = torch.load("./FashionMnistSTD50Net")
model_std.load_state_dict(model_state_dict)
#model_std.to(device)

<All keys matched successfully>

In [4]:
### Initiate ADV model 
from robustness.model_utils import make_and_restore_model
model_adv, _ = make_and_restore_model(arch='resnet50', dataset=ds,
              resume_path= '/home/sharim.jamal/FashionMnistAdv50/checkpoint.pt.latest') #'/tmp/35daedae-1b39-4941-ad08-8bd6459c1bd8/checkpoint.pt.best')
model_adv.eval()
pass

=> loading checkpoint '/home/sharim.jamal/FashionMnistAdv50/checkpoint.pt.latest'
=> loaded checkpoint '/home/sharim.jamal/FashionMnistAdv50/checkpoint.pt.latest' (epoch 150)


In [5]:
#only load the validation set
ATTACK_EPS = 0.5
ATTACK_STEPSIZE = 0.1
ATTACK_STEPS = 10
NUM_WORKERS = 8
BATCH_SIZE = 128

kwargs = {
    'constraint':'2', # use L2-PGD
    'eps': ATTACK_EPS, # L2 radius around original image
    'step_size': ATTACK_STEPSIZE,
    'iterations': ATTACK_STEPS,
    'do_tqdm': True,
}
#label.size()
from sklearn import metrics

# TensorDataset loaders
_, test_loader = ds.make_loaders(workers=NUM_WORKERS, batch_size=BATCH_SIZE) #make it bigger

test_enum = enumerate(test_loader)
# measure accuracy on the validation set but in batch sizes, since robust model can't memory-handle large image sizes at once
acc_list = []

for _ in range(len(test_loader)):
    _, (im, label) = next(test_enum)  
    im,label = im.to(device),label.to(device)
    
    #the model can generate adv by applying max l-2 loss
    _, im_adv = model_adv(im, label, make_adv=True, **kwargs)
    im_adv_transformed = transforms.Resize((224, 224)).forward(im_adv)
    
    
    im_adv = im_adv.to(device='cpu')
    im_adv_transformed = im_adv_transformed.to(device='cpu')
    
    # Get predicted labels for adversarial examples using 
    pred = model_std(im_adv_transformed)
    label_pred = ch.argmax(pred, dim=1)
    
    im,label = im.to(device='cpu'),label.to(device='cpu')
    im_adv = im_adv.to(device='cpu')
    im_adv_transformed = im_adv_transformed.to(device='cpu')
    
    #Get accuracy
    acc_list.append(metrics.accuracy_score(label ,label_pred.cpu()))
    print("Accuracy:",metrics.accuracy_score(label ,label_pred.cpu()))
    
    #test_list = torch.cat((test_list,im_adv_transformed.to(device='cpu')),dim=0)
    #label_list = torch.cat((label_list,label),dim=0)

==> Preparing dataset fashionmnist..


Current loss: 0.3595261573791504: 100%|█████████| 10/10 [00:02<00:00,  4.28it/s]


Accuracy: 0.6171875


Current loss: 0.6930171251296997: 100%|█████████| 10/10 [00:01<00:00,  6.10it/s]


Accuracy: 0.6640625


Current loss: 0.6583678722381592: 100%|█████████| 10/10 [00:01<00:00,  5.94it/s]


Accuracy: 0.6640625


Current loss: 0.6191776990890503: 100%|█████████| 10/10 [00:01<00:00,  5.95it/s]


Accuracy: 0.671875


Current loss: 0.5967100858688354: 100%|█████████| 10/10 [00:01<00:00,  6.01it/s]


Accuracy: 0.703125


Current loss: 0.3563051223754883: 100%|█████████| 10/10 [00:01<00:00,  5.96it/s]


Accuracy: 0.640625


Current loss: 0.579093873500824: 100%|██████████| 10/10 [00:01<00:00,  6.00it/s]


Accuracy: 0.703125


Current loss: 0.5991559624671936: 100%|█████████| 10/10 [00:01<00:00,  6.01it/s]


Accuracy: 0.671875


Current loss: 0.7653491497039795: 100%|█████████| 10/10 [00:01<00:00,  5.97it/s]


Accuracy: 0.671875


Current loss: 0.842681348323822: 100%|██████████| 10/10 [00:01<00:00,  5.96it/s]


Accuracy: 0.6640625


Current loss: 0.4948638081550598: 100%|█████████| 10/10 [00:01<00:00,  5.97it/s]


Accuracy: 0.6015625


Current loss: 0.4588151276111603: 100%|█████████| 10/10 [00:01<00:00,  5.96it/s]


Accuracy: 0.6484375


Current loss: 0.45319727063179016: 100%|████████| 10/10 [00:01<00:00,  6.00it/s]


Accuracy: 0.6796875


Current loss: 0.6182000041007996: 100%|█████████| 10/10 [00:01<00:00,  5.96it/s]


Accuracy: 0.71875


Current loss: 0.8195836544036865: 100%|█████████| 10/10 [00:01<00:00,  5.98it/s]


Accuracy: 0.703125


Current loss: 0.3358057737350464: 100%|█████████| 10/10 [00:01<00:00,  5.94it/s]


Accuracy: 0.5625


Current loss: 0.6968474388122559: 100%|█████████| 10/10 [00:01<00:00,  5.94it/s]


Accuracy: 0.6171875


Current loss: 0.734904408454895: 100%|██████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.6796875


Current loss: 0.6300793886184692: 100%|█████████| 10/10 [00:01<00:00,  5.92it/s]


Accuracy: 0.6875


Current loss: 0.5517191886901855: 100%|█████████| 10/10 [00:01<00:00,  5.93it/s]


Accuracy: 0.671875


Current loss: 0.3162551522254944: 100%|█████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.6796875


Current loss: 0.36219167709350586: 100%|████████| 10/10 [00:01<00:00,  5.90it/s]


Accuracy: 0.6953125


Current loss: 0.6013925075531006: 100%|█████████| 10/10 [00:01<00:00,  5.96it/s]


Accuracy: 0.640625


Current loss: 0.5076180696487427: 100%|█████████| 10/10 [00:01<00:00,  5.96it/s]


Accuracy: 0.703125


Current loss: 0.46169981360435486: 100%|████████| 10/10 [00:01<00:00,  5.90it/s]


Accuracy: 0.6484375


Current loss: 0.8043509125709534: 100%|█████████| 10/10 [00:01<00:00,  5.97it/s]


Accuracy: 0.59375


Current loss: 0.3381631672382355: 100%|█████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.6875


Current loss: 0.7246584892272949: 100%|█████████| 10/10 [00:01<00:00,  5.62it/s]


Accuracy: 0.734375


Current loss: 0.8035984635353088: 100%|█████████| 10/10 [00:01<00:00,  5.86it/s]


Accuracy: 0.6875


Current loss: 0.4713132083415985: 100%|█████████| 10/10 [00:01<00:00,  5.69it/s]


Accuracy: 0.5625


Current loss: 0.7314740419387817: 100%|█████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.6875


Current loss: 0.4925720989704132: 100%|█████████| 10/10 [00:01<00:00,  5.72it/s]


Accuracy: 0.609375


Current loss: 0.41008260846138: 100%|███████████| 10/10 [00:01<00:00,  5.84it/s]


Accuracy: 0.671875


Current loss: 0.36508846282958984: 100%|████████| 10/10 [00:01<00:00,  5.68it/s]


Accuracy: 0.6484375


Current loss: 0.2275543063879013: 100%|█████████| 10/10 [00:01<00:00,  5.90it/s]


Accuracy: 0.734375


Current loss: 0.655623197555542: 100%|██████████| 10/10 [00:01<00:00,  5.93it/s]


Accuracy: 0.671875


Current loss: 0.5430829524993896: 100%|█████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.6875


Current loss: 0.4759666323661804: 100%|█████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.625


Current loss: 0.4958436191082001: 100%|█████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.6875


Current loss: 0.6562652587890625: 100%|█████████| 10/10 [00:01<00:00,  5.92it/s]


Accuracy: 0.671875


Current loss: 0.43765735626220703: 100%|████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.6484375


Current loss: 0.6319500207901001: 100%|█████████| 10/10 [00:01<00:00,  5.93it/s]


Accuracy: 0.59375


Current loss: 0.49071216583251953: 100%|████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.625


Current loss: 0.659812867641449: 100%|██████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.6796875


Current loss: 0.7482509613037109: 100%|█████████| 10/10 [00:01<00:00,  5.87it/s]


Accuracy: 0.7265625


Current loss: 0.19323651492595673: 100%|████████| 10/10 [00:01<00:00,  5.88it/s]


Accuracy: 0.671875


Current loss: 0.5127256512641907: 100%|█████████| 10/10 [00:01<00:00,  5.88it/s]


Accuracy: 0.65625


Current loss: 0.544008731842041: 100%|██████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.7109375


Current loss: 0.689791738986969: 100%|██████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.6484375


Current loss: 0.5910865068435669: 100%|█████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.6875


Current loss: 0.6517602205276489: 100%|█████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.7109375


Current loss: 0.6766311526298523: 100%|█████████| 10/10 [00:01<00:00,  5.90it/s]


Accuracy: 0.59375


Current loss: 0.775007426738739: 100%|██████████| 10/10 [00:01<00:00,  5.90it/s]


Accuracy: 0.625


Current loss: 0.3718796968460083: 100%|█████████| 10/10 [00:01<00:00,  5.90it/s]


Accuracy: 0.703125


Current loss: 0.9054245948791504: 100%|█████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.59375


Current loss: 0.5490635633468628: 100%|█████████| 10/10 [00:01<00:00,  5.90it/s]


Accuracy: 0.6875


Current loss: 0.5987420082092285: 100%|█████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.671875


Current loss: 0.9499534368515015: 100%|█████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.6328125


Current loss: 0.40624895691871643: 100%|████████| 10/10 [00:01<00:00,  5.92it/s]


Accuracy: 0.6875


Current loss: 0.49399274587631226: 100%|████████| 10/10 [00:01<00:00,  5.90it/s]


Accuracy: 0.6484375


Current loss: 0.5268265008926392: 100%|█████████| 10/10 [00:01<00:00,  5.93it/s]


Accuracy: 0.7421875


Current loss: 0.37110620737075806: 100%|████████| 10/10 [00:01<00:00,  5.88it/s]


Accuracy: 0.6484375


Current loss: 0.6884624361991882: 100%|█████████| 10/10 [00:01<00:00,  5.88it/s]


Accuracy: 0.640625


Current loss: 0.5109096765518188: 100%|█████████| 10/10 [00:01<00:00,  5.92it/s]


Accuracy: 0.7265625


Current loss: 0.402387797832489: 100%|██████████| 10/10 [00:01<00:00,  5.87it/s]


Accuracy: 0.703125


Current loss: 0.3441546559333801: 100%|█████████| 10/10 [00:01<00:00,  5.87it/s]


Accuracy: 0.671875


Current loss: 0.7138121724128723: 100%|█████████| 10/10 [00:01<00:00,  5.89it/s]


Accuracy: 0.640625


Current loss: 0.5556538701057434: 100%|█████████| 10/10 [00:01<00:00,  5.87it/s]


Accuracy: 0.6796875


Current loss: 0.5776450634002686: 100%|█████████| 10/10 [00:01<00:00,  5.86it/s]


Accuracy: 0.625


Current loss: 0.572672963142395: 100%|██████████| 10/10 [00:01<00:00,  5.86it/s]


Accuracy: 0.640625


Current loss: 0.4028715491294861: 100%|█████████| 10/10 [00:01<00:00,  5.88it/s]


Accuracy: 0.7265625


Current loss: 0.5950084328651428: 100%|█████████| 10/10 [00:01<00:00,  5.96it/s]


Accuracy: 0.6796875


Current loss: 0.5427665710449219: 100%|█████████| 10/10 [00:01<00:00,  5.91it/s]


Accuracy: 0.578125


Current loss: 0.7414616942405701: 100%|█████████| 10/10 [00:01<00:00,  5.86it/s]


Accuracy: 0.703125


Current loss: 0.5241565108299255: 100%|█████████| 10/10 [00:01<00:00,  5.94it/s]


Accuracy: 0.6796875


Current loss: 0.8012564182281494: 100%|█████████| 10/10 [00:01<00:00,  5.92it/s]


Accuracy: 0.625


Current loss: 0.6034153699874878: 100%|█████████| 10/10 [00:01<00:00,  5.92it/s]


Accuracy: 0.7265625


Current loss: 0.7046204805374146: 100%|█████████| 10/10 [00:01<00:00,  5.84it/s]


Accuracy: 0.671875


Current loss: 1.042877197265625: 100%|██████████| 10/10 [00:00<00:00, 19.35it/s]


Accuracy: 0.6875


In [7]:
print(torch.cuda.is_available())

True


In [8]:
## Generate adv examples 

print("Average over all batches ",   sum(acc_list) / len(acc_list))   

#im_adv_transformed = transforms.Resize((224, 224)).forward(im_adv)
#print(im_adv[0].size())
#print(im_adv_transformed[0].size())
# predict and measure the accuracy 

# Get predicted labels for adversarial examples using 
#pred = model_std(test_list)
#label_pred = ch.argmax(pred, dim=1)


Average over all batches  0.6654469936708861
