In [2]:
import numpy as np
import torch
import torch.nn as nn
import time 
import logging
import torch.optim as optim
import os
from scipy.stats import multivariate_normal as normal
import torch.nn.functional as F
from torch.nn import Parameter
import matplotlib.pyplot as plt
import torchvision

In [3]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Shorten tensor reprs: show 2 edge items per dimension and wrap at 75 columns.
torch.set_printoptions(edgeitems=2, linewidth=75)
# Seed torch's global random number generator with a fixed value.
torch.manual_seed(123)

<torch._C.Generator at 0x213a8ecbd30>

In [4]:
from torchvision import transforms
from torchvision import datasets 

In [5]:
# Pick the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))

# Notebook-wide numeric settings.
data_type = torch.float32
MOMENTUM = 0.99
EPSILON = 1e-6

Using cuda device


In [6]:
class Config:
    """Hyper-parameters read by the data loaders and by the network."""

    # Mini-batch size (used for both the train and the test loader).
    batch_size = 500

    # Time horizon and its discretisation; the layer count equals the
    # number of time steps.
    totalT = 2.0
    n_layer = Ntime = 4
    sqrt_deltaT = np.sqrt(totalT / Ntime)

    # Logging options.
    logging_frequency = 100
    verbose = True

    # Channel widths: network input, then the outputs of the two
    # projection blocks.
    input_chanel = 1
    output_chanel_pj1 = 32
    output_chanel_pj2 = 16

    # Number of features per sample handed to the fully connected head
    # (last channel width times a 7x7 feature map).
    unflatten_shape = output_chanel_pj2 * 7 * 7
    
def get_config(name):
    """Return the module-level object registered under ``name``.

    Args:
        name: identifier to look up in this module's global namespace.

    Returns:
        Whatever object is bound to ``name`` (for a config class, the class
        itself rather than an instance).

    Raises:
        KeyError: if no global with that name exists.
    """
    namespace = globals()
    try:
        config = namespace[name]
    except KeyError:
        raise KeyError("config not defined.")
    return config
# Look up the Config class by name; `cfg` is the class object itself, not an instance.
cfg=get_config('Config')

In [7]:
# Train and test loaders use the same mini-batch size from the config.
batch_size_train = batch_size_test = cfg.batch_size

In [8]:
# Both splits share the same preprocessing: convert to a tensor, then
# normalise with mean 0.1307 and standard deviation 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST is downloaded into '/files/' when it is not already there.
mnist_train = torchvision.datasets.MNIST(
    '/files/', train=True, download=True, transform=mnist_transform)
mnist_test = torchvision.datasets.MNIST(
    '/files/', train=False, download=True, transform=mnist_transform)

# NOTE(review): the test loader is shuffled as well, so its batch order
# changes from one pass to the next.
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=batch_size_test, shuffle=True)

In [9]:
# Pull the first batch from the (shuffled) test loader; it serves as the
# example batch for the single-step experiment further down.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

## Plain vanilla CNN

In [10]:
class ProjBlock(nn.Module):
    """Projection block: 3x3 convolution, tanh, then 2x2 max-pooling.

    The convolution uses padding 1, so only the pooling step changes the
    spatial size (height and width are halved, rounding down).

    Args:
        input_chanel: number of channels of the incoming feature map.
        output_chanel: number of channels produced by the convolution.
    """

    def __init__(self, input_chanel, output_chanel):
        super().__init__()
        self.input_chanel = input_chanel
        self.output_chanel = output_chanel

        self.conv1 = nn.Conv2d(input_chanel, output_chanel, kernel_size=3, padding=1)
        self.act1 = nn.Tanh()
        self.pool1 = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``pool1(act1(conv1(x)))``."""
        features = self.conv1(x)
        activated = self.act1(features)
        return self.pool1(activated)

In [11]:
class BasicBlock(nn.Module):
    """Shape-preserving block: 3x3 convolution (padding 1) followed by tanh.

    Input and output have the same number of channels, and there is
    deliberately no pooling here, so the spatial size is unchanged.

    Args:
        num_chanel: channel count of both the input and the output.
    """

    def __init__(self, num_chanel):
        super().__init__()
        self.input_chanel = num_chanel
        self.output_chanel = num_chanel

        self.conv = nn.Conv2d(
            self.input_chanel, self.output_chanel, kernel_size=3, padding=1
        )
        self.act = nn.Tanh()

    def forward(self, x):
        """Return ``act(conv(x))``."""
        features = self.conv(x)
        return self.act(features)

In [12]:
class FullyConnected(nn.Module):
    """Classifier head: flatten, then Linear -> tanh -> Linear giving 10 logits.

    Args:
        unflatten_shape: number of features per sample after flattening.
    """

    def __init__(self,unflatten_shape):
        super(FullyConnected,self).__init__()
        self.unflatten_shape=unflatten_shape
        self.fc1=nn.Linear(unflatten_shape,32)
        self.ac1=nn.Tanh()
        # One output unit (logit) per class; the layer has 10 of them.
        self.fc2=nn.Linear(32,10)

    def forward(self,x):
        """Flatten ``x`` to ``(-1, unflatten_shape)`` and return the logits."""
        # reshape (not view): view raises a RuntimeError on non-contiguous
        # inputs such as permuted or sliced feature maps, while reshape
        # copies only when it has to and is identical for contiguous input.
        inputx=x.reshape(-1, self.unflatten_shape)
        out=self.fc2(self.ac1(self.fc1(inputx)))
        return out

## Stacking up the blocks

In [13]:
# Cross-entropy loss on the network outputs; used both to craft the attack
# gradient and to train the network afterwards.
loss_fn=nn.CrossEntropyLoss()

In [14]:
class ForwardModel(nn.Module):
    """Plain CNN: two projection blocks, ``n_layer`` basic blocks, then a head.

    Args:
        config: object exposing ``batch_size``, ``Ntime``, ``sqrt_deltaT``,
            ``n_layer``, ``input_chanel``, ``output_chanel_pj1``,
            ``output_chanel_pj2`` and ``unflatten_shape``.
    """

    def __init__(self,config):
        super(ForwardModel,self).__init__()

        self.config=config
        self.batch_size=self.config.batch_size
        self.Ntime=self.config.Ntime
        self.sqrt_deltaT=self.config.sqrt_deltaT
        self.n_layer=self.config.n_layer

        ## The structure is merely a stack-up of the convolutional blocks
        width=self.config.output_chanel_pj2
        blocks=[ProjBlock(self.config.input_chanel,self.config.output_chanel_pj1),
                ProjBlock(self.config.output_chanel_pj1,width)]
        # The depth of the shape-preserving middle section follows the
        # config instead of being hard-coded. With n_layer == 4 this yields
        # the same seven modules (mList.0 ... mList.6) as before, so
        # checkpoints saved with that layout still load.
        blocks.extend(BasicBlock(width) for _ in range(self.n_layer))
        blocks.append(FullyConnected(self.config.unflatten_shape))
        self.mList=nn.ModuleList(blocks)

    def forward(self,data):
        """Run ``data`` through every block in order and return the logits."""
        # No defensive clone: none of the blocks modifies its input in
        # place, so copying the whole batch only cost memory.
        data_temp=data
        for block in self.mList:
            data_temp=block(data_temp)
        return data_temp

In [15]:
def train_accuracy(model,train_loader):
    """Return the fraction of samples in ``train_loader`` classified correctly.

    Args:
        model: network mapping a batch of images to per-class scores.
        train_loader: iterable yielding ``(imgs, labels)`` batches.

    Returns:
        ``correct / total`` as a float.

    Raises:
        ZeroDivisionError: if ``train_loader`` yields no samples.
    """
    # Move batches to wherever the model lives; fall back to the notebook's
    # global `device` only for a model without parameters.
    first_param=next(model.parameters(), None)
    target_device=device if first_param is None else first_param.device

    total=0
    correct=0
    # Pure evaluation: disable autograd so no graph is kept per batch.
    with torch.no_grad():
        for imgs,labels in train_loader:
            imgs, labels=imgs.to(target_device), labels.to(target_device)
            output=model(imgs)
            _, predicted = torch.max(output, dim=1)

            total += imgs.shape[0]
            correct += int((predicted == labels).sum())
    return correct/total

In [16]:
# FGSM attack code
# which also works in the batch case
def fgsm_attack(image, epsilon, data_grad, clip_min=None, clip_max=None):
    """Return ``image`` perturbed by one FGSM step.

    Every element of ``image`` is moved by ``epsilon`` in the direction of
    the sign of the corresponding element of ``data_grad``.

    Args:
        image: input tensor (a single image or a batch).
        epsilon: step size of the perturbation.
        data_grad: gradient of the loss with respect to ``image``.
        clip_min: optional lower bound applied to the result.
        clip_max: optional upper bound applied to the result.

    Returns:
        The perturbed tensor. With both bounds left at ``None`` nothing is
        clipped, which matches the previous behaviour. The loaders in this
        notebook normalise the images, so a plain [0, 1] clamp would not
        describe the valid pixel range; pass bounds expressed in the
        normalised space if clipping is wanted.
    """
    # Element-wise sign of the gradient gives the ascent direction.
    perturbed_image = image + epsilon * data_grad.sign()
    # torch.clamp needs at least one bound, so skip it when none is given.
    if clip_min is not None or clip_max is not None:
        perturbed_image = torch.clamp(perturbed_image, min=clip_min, max=clip_max)
    return perturbed_image

### We now load the weights of the plain vanilla CNN

After 8 epochs of training, the model reaches a training accuracy of 99.2% and a test accuracy of 98.7%.

In [17]:
# Relative path of the checkpoint file that is loaded into the network below.
pretrained_model='data/VanillaCNN_mnist_model.pth'

In [18]:
# Build the network on the selected device, then copy the saved weights in.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
net = ForwardModel(cfg).to(device)
pretrained_state = torch.load(pretrained_model, map_location='cpu')
net.load_state_dict(pretrained_state)

<All keys matched successfully>

In [19]:
# The net will only evaluate samples from now
#net.eval();

## In this part we let the attacker and the model play games

In [20]:
# Put the example batch on the same device as the network.
example_data, example_targets = example_data.to(device), example_targets.to(device)

In [21]:
# One attacker/defender round on the single example batch.
correct = 0
total = 0
correct_post = 0
total_post = 0

optimizer = optim.Adam(net.parameters(), lr=1.5e-3)  # it could be a bad idea to add weight decay

### Attacker: FGSM needs the gradient of the loss w.r.t. the input ###
net.eval()
example_data.requires_grad = True
# Drop any gradient left over from an earlier run of this cell; backward()
# accumulates into .grad, which would otherwise distort the attack direction.
example_data.grad = None
output = net(example_data)
loss = loss_fn(output, example_targets)
_, pred = torch.max(output, dim=1)
net.zero_grad()
loss.backward()

data_grad = example_data.grad.detach()
# Detach right away: the perturbed batch is plain data from here on, so the
# training step below must not backpropagate into example_data.
peturbed_data = fgsm_attack(example_data, 0.2, data_grad).detach()
with torch.no_grad():
    output_attack = net(peturbed_data)
_, pred_attack = torch.max(output_attack, dim=1)

correct += int((pred_attack == example_targets).sum())
total += example_data.shape[0]

### Defender: one optimisation step on the perturbed batch ###
net.train()  # Now we set the net to train mode
output_post = net(peturbed_data)
# The loss is built from the train-mode forward pass on the detached data
# (output_post), not from the attack-evaluation logits.
loss_attack = loss_fn(output_post, example_targets)

optimizer.zero_grad()
loss_attack.backward()
optimizer.step()

with torch.no_grad():
    output_trained = net(peturbed_data)
_, pred_trained = torch.max(output_trained, dim=1)

correct_post += int((pred_trained == example_targets).sum())
total_post += example_data.shape[0]

# Accuracy under attack before and after the update step, each divided by
# its own sample count.
correct / total, correct_post / total_post

### Run this all together

In [22]:
# Accuracy of the loaded network on the unperturbed test set.
net.eval()
correct_0 = 0
total_0 = 0

# Evaluation only: no input gradients are needed, so autograd is switched
# off instead of marking every batch with requires_grad.
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        output = net(imgs)
        _, pred = torch.max(output, dim=1)
        correct_0 += int((pred == labels).sum())
        total_0 += imgs.shape[0]

In [23]:
# Fraction of test samples classified correctly without any attack.
correct_0/total_0

0.9876

In [24]:
# Start the experiment from a fresh copy of the pretrained network.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
net = ForwardModel(cfg).to(device)
net.load_state_dict(torch.load(pretrained_model, map_location='cpu'))

# Running tallies: accuracy under attack before / after each update step.
correct = total = 0
correct_post = total_post = 0

optimizer = optim.Adam(net.parameters(), lr=1.5e-3)  # it could be a bad idea to add weight decay

# FGSM step size used by the attacker in the loop below.
attack_epsilon = 0.5

In [25]:
# Attacker/defender game: for every batch, craft an FGSM attack against the
# current network, score it, then take one optimiser step on the attacked batch.
# NOTE(review): this fine-tuning iterates over `test_loader`, so accuracies
# measured on the test set afterwards are on data the model was updated on.
for _ in range(5):
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        imgs.requires_grad = True

        ### Attacker starts to work ###
        net.eval()
        output = net(imgs)
        loss = loss_fn(output, labels)
        _, pred = torch.max(output, dim=1)
        net.zero_grad()
        loss.backward()

        data_grad = imgs.grad.detach()
        # Detach right away: the perturbed batch is plain data from here on.
        peturbed_data = fgsm_attack(imgs, attack_epsilon, data_grad).detach()
        with torch.no_grad():
            output_attack = net(peturbed_data)
        _, pred_attack = torch.max(output_attack, dim=1)

        correct += int((pred_attack == labels).sum())
        total += imgs.shape[0]

        ### Self-correction made by the model ###
        net.train()  # Now we set the net to train mode
        output_post = net(peturbed_data)
        # Train on the train-mode forward pass of the detached batch rather
        # than on the logits used to score the attack.
        loss_attack = loss_fn(output_post, labels)

        optimizer.zero_grad()
        loss_attack.backward()
        optimizer.step()

        with torch.no_grad():
            output_trained = net(peturbed_data)
        _, pred_trained = torch.max(output_trained, dim=1)

        correct_post += int((pred_trained == labels).sum())
        total_post += imgs.shape[0]

In [26]:
# (accuracy on attacked batches before the update step, accuracy on the same
# batches right after the update step), accumulated over all passes.
correct/total, correct_post/total_post

(0.68816, 0.69938)

As a result, when facing the attack the model reaches an accuracy of about 0.70 (0.699 after the self-correction step in the run shown above).
How about the natural accuracy? 

In [27]:
# Accuracy of the adversarially fine-tuned network on unperturbed test data.
net.eval()
correct_natural = 0
total_natural = 0

# Evaluation only: no input gradients are needed, so autograd is switched
# off instead of marking every batch with requires_grad.
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        output = net(imgs)
        _, pred = torch.max(output, dim=1)
        correct_natural += int((pred == labels).sum())
        total_natural += imgs.shape[0]

In [28]:
# Fraction of unperturbed test samples classified correctly after fine-tuning.
correct_natural/total_natural

0.9663