In [1]:
import numpy as np
import torch
import torch.nn as nn
import time 
import logging
import torch.optim as optim
import os
from scipy.stats import multivariate_normal as normal
import torch.nn.functional as F
from torch.nn import Parameter
import matplotlib.pyplot as plt
import torchvision

In [2]:
%matplotlib inline
torch.set_printoptions(edgeitems=2, linewidth=75)
torch.manual_seed(123)

<torch._C.Generator at 0x21fb44c3b90>

In [3]:
# Number of output classes; avg_prediction uses it as the width of its
# logit accumulator.
label_dim=10

In [4]:
from torchvision import transforms
from torchvision import datasets

In [5]:
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))

# Module-level numeric settings.
# NOTE(review): MOMENTUM and EPSILON are not referenced anywhere in the
# visible part of the notebook - confirm whether they are still needed.
data_type = torch.float32
MOMENTUM = 0.99
EPSILON = 1e-6

Using cuda device


In [6]:
class Config(object):
    """Static hyper-parameters shared by the data loaders and the model."""

    # Number of images per mini-batch (used for both train and test loaders).
    batch_size = 500

    # Total integration horizon and the number of steps it is split into.
    totalT = 2.0
    Ntime = 4
    # The number of layers is tied to the number of time steps.
    n_layer = Ntime

    # Square root of the step size totalT / Ntime; scales the injected noise.
    sqrt_deltaT = np.sqrt(totalT / Ntime)

    logging_frequency = 100
    verbose = True

    # Channel counts: network input, first projection, second projection.
    input_chanel = 1
    output_chanel_pj1 = 32
    output_chanel_pj2 = 16

    # Flattened feature size fed to the fully connected head
    # (output_chanel_pj2 channels on a 7x7 grid).
    unflatten_shape = output_chanel_pj2 * 7 * 7
    
def get_config(name):
    """Look up a configuration object by its module-level name.

    Args:
        name: name of a global defined in this module (e.g. ``'Config'``).

    Returns:
        Whatever object is bound to ``name`` in the module globals.

    Raises:
        KeyError: if no global with that name exists.
    """
    module_namespace = globals()
    try:
        found = module_namespace[name]
    except KeyError:
        raise KeyError("config not defined.")
    else:
        return found

In [7]:
# Resolve the configuration class by name; cfg is the Config class itself
# (get_config returns the global object, it does not instantiate it).
cfg=get_config('Config')

In [8]:
# Train and test loaders use the same batch size.  ForwardModel.forwardX
# sizes its noise tensor from cfg.batch_size, so every batch fed to the
# model is expected to contain exactly this many images.
batch_size_train=cfg.batch_size
batch_size_test=cfg.batch_size

In [9]:
# One preprocessing pipeline shared by both splits: convert to a tensor and
# standardise with the mean / std constants used throughout this notebook.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/files/', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# The evaluation split is iterated in a fixed order: shuffling it does not
# change the accuracy over the full set, it only makes the "first batch"
# pulled below (and any per-batch debugging) differ between runs.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/files/', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test, shuffle=False)

In [10]:
# Pull the first batch from the test loader for inspection.
# NOTE(review): `examples` is rebound to an empty list in the epsilon sweep
# further down, and example_data / example_targets are not used in the
# visible cells - confirm whether this cell is still needed.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

## The building blocks

In [11]:
class ProjBlock(nn.Module):
    """Projection stage: 3x3 convolution -> tanh -> 2x2 max-pooling.

    The convolution uses padding 1, so only the pooling changes the spatial
    size (it halves height and width).

    Args:
        input_chanel: number of channels of the incoming feature map.
        output_chanel: number of channels produced by the convolution.
    """

    def __init__(self, input_chanel, output_chanel):
        super().__init__()
        self.input_chanel = input_chanel
        self.output_chanel = output_chanel

        # Attribute names are part of the checkpoint key names; keep them.
        self.conv1 = nn.Conv2d(in_channels=input_chanel,
                               out_channels=output_chanel,
                               kernel_size=3,
                               padding=1)
        self.act1 = nn.Tanh()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Apply convolution, tanh and max-pooling to ``x`` in that order."""
        features = self.conv1(x)
        activated = self.act1(features)
        return self.pool1(activated)

class BasicBlock(nn.Module):
    """Shape-preserving stage: 3x3 convolution (padding 1) followed by tanh.

    Input and output have the same number of channels and there is no
    pooling, so the output has the same shape as the input.

    Args:
        num_chanel: channel count of both the input and the output.
    """

    def __init__(self, num_chanel):
        super().__init__()
        self.input_chanel = num_chanel
        self.output_chanel = num_chanel

        # Deliberately no max-pooling here: these blocks sit between the
        # projection stages and the head and must not change the shape.
        self.conv = nn.Conv2d(in_channels=self.input_chanel,
                              out_channels=self.output_chanel,
                              kernel_size=3,
                              padding=1)
        self.act = nn.Tanh()

    def forward(self, x):
        """Return ``tanh(conv(x))``."""
        convolved = self.conv(x)
        return self.act(convolved)

class FullyConnected(nn.Module):
    """Classification head: flatten -> Linear(., 32) -> tanh -> Linear(32, 10).

    The caller is responsible for working out the flattened feature size.

    Args:
        unflatten_shape: number of features per sample once the incoming
            tensor is flattened (channels * height * width).
    """

    def __init__(self, unflatten_shape):
        super().__init__()
        self.unflatten_shape = unflatten_shape
        self.fc1 = nn.Linear(in_features=unflatten_shape, out_features=32)
        self.ac1 = nn.Tanh()
        # Ten output scores per sample.
        self.fc2 = nn.Linear(in_features=32, out_features=10)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, unflatten_shape)`` and return the 10 scores."""
        flat = x.view(-1, self.unflatten_shape)
        hidden = self.ac1(self.fc1(flat))
        return self.fc2(hidden)

### Structure of the model

In [12]:
# Shared cross-entropy loss; read as a global by ForwardModel.backwardYZ and
# by the test() evaluation function.
loss_fn=nn.CrossEntropyLoss()
class ForwardModel(nn.Module):
    """Stochastic convolutional network with a hand-written adjoint pass.

    The network is two projection blocks, ``Ntime`` shape-preserving blocks
    and a fully connected head.  Each of the ``Ntime`` middle steps is

        x_{i+1} = x_i + f_i(x_i) * delta + sigma * w_i

    where ``delta = totalT / Ntime`` and ``w_i`` is Gaussian noise scaled by
    ``sqrt_deltaT``.  The noise is injected on every call to ``forwardX``,
    regardless of train / eval mode.

    Args:
        config: object exposing ``batch_size``, ``Ntime``, ``n_layer``,
            ``totalT``, ``sqrt_deltaT``, ``input_chanel``,
            ``output_chanel_pj1``, ``output_chanel_pj2`` and
            ``unflatten_shape``.
    """

    def __init__(self,config):
        super(ForwardModel,self).__init__()

        self.config=config
        self.batch_size=self.config.batch_size
        self.Ntime=self.config.Ntime
        self.sqrt_deltaT=self.config.sqrt_deltaT
        self.n_layer=self.config.n_layer
        self.delta=self.config.totalT/self.Ntime

        # The structure is a plain stack of blocks.  forwardX / backwardYZ
        # index the middle blocks as mList[2 : 2 + Ntime], so exactly Ntime
        # of them are built (four with the notebook's configuration, which
        # keeps the checkpoint keys unchanged).
        blocks=[ProjBlock(self.config.input_chanel,self.config.output_chanel_pj1),
                ProjBlock(self.config.output_chanel_pj1,self.config.output_chanel_pj2)]
        blocks.extend(BasicBlock(self.config.output_chanel_pj2) for _ in range(self.Ntime))
        blocks.append(FullyConnected(self.config.unflatten_shape))
        self.mList=nn.ModuleList(blocks)

        # Diffusion coefficient multiplying the injected noise.
        self.sigma=0.25

    def _module_device(self,fallback):
        """Return the device of the first parameter, or ``fallback.device`` if the module has none."""
        first_param=next(self.parameters(),None)
        return fallback.device if first_param is None else first_param.device

    def _costate_grad(self,layer,x_in,costate):
        """Differentiate ``sum(costate.detach() * layer(x_in))`` with respect to ``x_in``.

        The product is summed over every non-batch dimension to give one
        Hamiltonian value per sample, then differentiated with a vector of
        ones.  The graph is kept (``create_graph=True``) so the result can be
        differentiated again.  May return ``None`` if ``x_in`` is unused
        (``allow_unused=True``).
        """
        layer_out=layer(x_in)
        sample_dims=tuple(range(1,layer_out.dim()))
        hami=torch.sum(costate.detach()*layer_out,dim=sample_dims).view(-1,1)
        return torch.autograd.grad(outputs=[hami],inputs=[x_in],
                                   grad_outputs=torch.ones_like(hami),
                                   allow_unused=True,retain_graph=True,
                                   create_graph=True)[0]

    def forwardX(self,x):
        """Run the noisy forward pass on a batch of images.

        Args:
            x: batch of input images; it is cloned and moved to the model's
                device.

        Returns:
            ``(xMat, wMat)`` where ``xMat`` is the list of ``Ntime + 4``
            states (input, two projections, ``Ntime`` steps, head output) and
            ``wMat`` is the noise tensor of shape
            ``(batch, output_chanel_pj2, 7, 7, Ntime)``.
        """
        dev=self._module_device(x)
        x0=torch.clone(x).to(dev)

        # The noise is sized from the actual batch rather than
        # config.batch_size so that any batch size (e.g. a final partial
        # batch) works.  7x7 is the feature-map size after the two 2x2
        # max-pools for 28x28 inputs.
        noise_shape=(x0.shape[0],self.config.output_chanel_pj2,7,7,self.Ntime)
        # scipy may squeeze singleton dimensions out of the sample, so the
        # intended shape is restored explicitly.
        noise=np.reshape(normal.rvs(size=list(noise_shape)),noise_shape)*self.sqrt_deltaT
        wMat=torch.as_tensor(noise,dtype=torch.float32,device=dev)

        xMat=[x0]
        xMat.append(self.mList[0](x0))
        xMat.append(self.mList[1](xMat[1]))

        for i in range(self.Ntime):
            # i + 2 because two projection layers come first.
            state=xMat[i+2]
            xMat.append(state+self.mList[i+2](state)*self.delta+self.sigma*wMat[...,i])

        xMat.append(self.mList[-1](xMat[-1]))

        return xMat, wMat

    def backwardYZ(self,xMat,wMat,target):
        """Propagate the adjoint (costate) backwards through the network.

        Args:
            xMat: list of states returned by ``forwardX``.
            wMat: noise tensor returned by ``forwardX`` (not used here).
            target: tensor of class labels (a tensor, not a list).

        Returns:
            ``yMat``, a list of ``Ntime + 4`` tensors in *reverse* layer
            order: ``yMat[0]`` is the gradient of the loss with respect to
            the network output and ``yMat[k]`` pairs with ``xMat[-1 - k]``.
        """
        x_terminal=xMat[-1]
        loss_val=loss_fn(x_terminal,target.to(x_terminal.device))

        # Terminal costate: gradient of the loss w.r.t. the head output.
        y_terminal=torch.autograd.grad(outputs=[loss_val],inputs=[x_terminal],
                                       grad_outputs=torch.ones_like(loss_val),
                                       allow_unused=True,retain_graph=True,
                                       create_graph=True)[0]
        yMat=[y_terminal]

        # Fully connected head.
        yMat.append(self._costate_grad(self.mList[-1],xMat[-2],y_terminal))

        # Middle blocks, last to first: y_i = y_{i+1} + dH/dx * delta.
        for i in range(self.Ntime-1,-1,-1):
            hami_x=self._costate_grad(self.mList[i+2],xMat[i+2],yMat[-1])
            yMat.append(yMat[-1]+hami_x*self.delta)

        # Second projection layer.
        yMat.append(self._costate_grad(self.mList[1],xMat[1],yMat[-1]))

        # First projection layer; the input must track gradients for this.
        x_input=xMat[0]
        x_input.requires_grad=True
        yMat.append(self._costate_grad(self.mList[0],x_input,yMat[-1]))

        return yMat  # reversed order, see docstring

    def HamCompute(self,xMat,yMat):
        """Average Hamiltonian over the batch and over the ``Ntime + 3`` layers.

        States and costates are detached, so gradients of the returned value
        only reach the layer parameters.
        """
        totalham=0.0

        for i in range(self.Ntime+3):
            # yMat is stored in reverse order: yMat[Ntime+2-i] is the costate
            # matching the output of layer i.
            totalham+=torch.sum(yMat[self.Ntime+2-i].detach()*self.mList[i](xMat[i].detach()))

        # Normalise by the number of samples actually in the batch.
        n_batch=xMat[0].shape[0]
        return totalham/n_batch/(self.Ntime+3)

In [13]:
# Path (relative to the notebook's working directory) of the checkpoint that
# is loaded into the model below.
pretrained_model = "data/SNN_mnist_model.pth"

In [14]:
# Build the network from the shared configuration and move its parameters to
# the selected device (the trailing ';' suppresses the module repr).
net=ForwardModel(cfg)
net.to(device);

In [15]:
# Restore the trained weights.  map_location='cpu' deserialises the tensors
# on the CPU; load_state_dict then copies them into net's own parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
net.load_state_dict(torch.load(pretrained_model, map_location='cpu'))

<All keys matched successfully>

In [16]:
# Set the model to evaluation mode.
# Note: forwardX does not consult the train/eval flag, so its Gaussian noise
# is still injected on every forward pass after this call.
net.eval();

In [17]:
# Fast Gradient Sign Method (FGSM) perturbation.
# Purely element-wise, so it works on single images and on batches alike.
def fgsm_attack(image, epsilon, data_grad):
    """Shift every pixel of ``image`` by ``epsilon`` along the gradient sign.

    Args:
        image: tensor of input images.
        epsilon: step size of the perturbation.
        data_grad: gradient of the loss with respect to ``image``.

    Returns:
        ``image + epsilon * sign(data_grad)``.  The result is NOT clamped to
        ``[0, 1]``; the clamp is intentionally left disabled here.
    """
    step = torch.sign(data_grad) * epsilon
    return image + step

### Defining the test accuracy function 

In [18]:
def avg_prediction(model,iters,imgs,labels):
    """Predict classes from the mean of several stochastic forward passes.

    ``model.forwardX`` injects fresh noise on every call, so its final output
    is averaged over ``iters`` passes before taking the arg-max.

    Args:
        model: object whose ``forwardX(imgs)`` returns ``(states, noise)``;
            ``states[-1]`` holds the per-class scores.
        iters: number of forward passes to average (must be >= 1).
        imgs: batch of input images.
        labels: unused; kept so existing callers keep working.

    Returns:
        1-D tensor with the predicted class index for each image.

    Raises:
        ValueError: if ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")

    # The accumulator takes its shape and device from the model output, so
    # the function works for any batch size and number of classes.
    score_sum=None
    # Only the arg-max is returned, so no autograd graph is needed.
    with torch.no_grad():
        for _ in range(iters):
            states,_noise=model.forwardX(imgs)
            scores=states[-1]
            score_sum=scores if score_sum is None else score_sum+scores
        mean_scores=score_sum/iters
        _,predicts=torch.max(mean_scores,dim=1)
    return predicts

In [19]:
def test(model,device, epsilon, test_loader):
    """Measure accuracy under an FGSM attack of strength ``epsilon``.

    For each batch the input gradient of the loss is computed from one noisy
    forward pass, the images are perturbed with ``fgsm_attack`` and the
    perturbed batch is classified with ``avg_prediction`` (10 passes).

    Args:
        model: network exposing ``forwardX`` (and ``zero_grad``).
        device: device the images and labels are moved to.
        epsilon: FGSM step size; 0 evaluates the unperturbed images.
        test_loader: iterable yielding ``(imgs, labels)`` batches.

    Returns:
        Fraction of correctly classified perturbed images.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no batches.
    """
    correct_avg=0
    total=0

    for imgs, labels in test_loader:
        imgs, labels=imgs.to(device), labels.to(device)
        imgs.requires_grad=True

        xm,_=model.forwardX(imgs)
        loss=loss_fn(xm[-1].to(device),labels)

        # Gradient of the loss with respect to the input images.
        model.zero_grad()
        loss.backward()
        data_grad=imgs.grad.detach()

        # Detach so the evaluation passes do not extend the attack's graph.
        perturbed_data=fgsm_attack(imgs, epsilon, data_grad).detach()

        # Evaluate the model that was passed in (not the global `net`).
        with torch.no_grad():
            predict_final_avg=avg_prediction(model,10,perturbed_data,labels)

        correct_avg += int((predict_final_avg.to(device) == labels).sum())
        total += imgs.shape[0]

    return correct_avg/total

In [20]:
# Accuracy on the test set under an FGSM attack with epsilon = 0.3.
# The forward pass is stochastic, so the value varies slightly between runs.
test(net,device, 0.3, test_loader)

0.5281

In [21]:
# FGSM step sizes to sweep; 0 is the unperturbed baseline.
epsilons = [0,.05,.1,.15,0.2, 0.3,0.4,0.5] #.1, .15, .2, .25, .3

In [22]:
# Accuracy for every attack strength, in the same order as `epsilons`.
accuracies = []
examples = []

# Evaluate the network once per epsilon and report each accuracy as it is
# computed.
for eps in epsilons:
    accuracies.append(test(net, device, eps, test_loader))
    print(accuracies[-1])

0.9886
0.9583
0.9081
0.8272
0.7275
0.5197
0.3718
0.2678
