In [1]:
import numpy as np
import torch
import torch.nn as nn
import time
import logging
import torch.optim as optim
import os
from scipy.stats import multivariate_normal as normal
import torch.nn.functional as F
from torch.nn import Parameter
import matplotlib.pyplot as plt
import torchvision

In [2]:
#from google.colab import drive
#drive.mount('/content/gdrive')

In [20]:
%matplotlib inline
torch.set_printoptions(edgeitems=2, linewidth=75)
torch.manual_seed(123)

<torch._C.Generator at 0x7fa8242c11f0>

In [4]:
from torchvision import transforms

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

data_type=torch.float32
MOMENTUM = 0.99
EPSILON = 1e-6

Using cuda device


# Handling the data

In [9]:
class Config(object):
    batch_size = 500

    totalT=2.0;

    n_layer=Ntime=4;

    sqrt_deltaT=np.sqrt(totalT/Ntime);

    logging_frequency = 100
    verbose = True

    input_chanel=1
    output_chanel_pj1=32
    output_chanel_pj2=16

    unflatten_shape=output_chanel_pj2*7*7

def get_config(name):
    try:
        return globals()[name]
    except KeyError:
        raise KeyError("config not defined.")

cfg=get_config('Config')

In [10]:
from torchvision import datasets
batch_size_train=cfg.batch_size
batch_size_test=cfg.batch_size

In [35]:
train_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST('/files/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                    ])), batch_size=cfg.batch_size, shuffle=False)

test_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST('/files/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                    ])), batch_size=cfg.batch_size, shuffle=False)

In [42]:
for img, label in test_loader:
  print(img.sum())
  break

tensor(209122.0469)


In [38]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

0

In [13]:
example_data.abs().shape

torch.Size([500, 1, 28, 28])

In [14]:
example_data.abs().max()

tensor(2.8215)

We have stored both the training and the validation datasets

Defining the dataloader

## Defining the configuration

In [15]:
normal.rvs(size=[10])*0.14

array([-0.12158161, -0.11574859,  0.18398867, -0.05220034, -0.14767625,
       -0.07879621,  0.09517863,  0.06066116, -0.35531458,  0.06522647])

# Constructing a dense net

## Building the building block

In [43]:
class ProjBlock(nn.Module):
    def __init__(self,input_chanel,output_chanel):
        super(ProjBlock,self).__init__()
        self.input_chanel=input_chanel
        self.output_chanel=output_chanel

        self.conv1=nn.Conv2d(input_chanel,output_chanel,kernel_size=3,padding=1)
        self.act1=nn.Tanh()
        self.pool1=nn.MaxPool2d(2)

      #  self.conv2=nn.Conv2d(2*output_chanel,output_chanel,kernel_size=3,padding=1)
      #  self.act1=nn.Tanh()
      #  self.pool1=nn.MaxPool2d(2)

    def forward(self,x):
        out = self.pool1(self.act1(self.conv1(x)))
      #  out = self.pool2(self.act2(self.conv2(x)))
        return out

In [51]:
class BasicBlock(nn.Module):
    def __init__(self,num_chanel):
        super(BasicBlock,self).__init__()
        self.input_chanel=num_chanel
        self.output_chanel=num_chanel

        self.conv=nn.Conv2d(self.input_chanel,self.output_chanel,kernel_size=3,padding=1)
        self.act=nn.Tanh()
        ## there should not be any MaxPooling layer in the inbetween set

    def forward(self,x):
        out=self.act(self.conv(x))
        return out

In [52]:
# One is responsible for figuring out the unflatten shape
class FullyConnected(nn.Module):
    def __init__(self,unflatten_shape):
        super(FullyConnected,self).__init__()
        self.unflatten_shape=unflatten_shape
        self.fc1=nn.Linear(unflatten_shape,32)
        self.ac1=nn.Tanh()
        self.fc2=nn.Linear(32,10)
        # Let's only tell the airplane from a bird

    def forward(self,x):
        inputx=x.view(-1, self.unflatten_shape)
        out=self.fc2(self.ac1(self.fc1(inputx)))
        return out

## Stacking up the blocks

In [53]:
normal.rvs(size=[2,2],random_state=12345)

array([[-0.20470766,  0.47894334],
       [-0.51943872, -0.5557303 ]])

In [78]:
normal.rvs(size=[2,2],random_state=12345)

array([[-0.20470766,  0.47894334],
       [-0.51943872, -0.5557303 ]])

In [79]:
loss_fn=nn.CrossEntropyLoss()
class ForwardModel(nn.Module):
    def __init__(self,config):
        super(ForwardModel,self).__init__()

        self.config=config
        self.batch_size=self.config.batch_size
        self.Ntime=self.config.Ntime
        self.sqrt_deltaT=self.config.sqrt_deltaT;
        self.n_layer=self.config.n_layer
        self.delta=self.config.totalT/self.Ntime;

        self.mList=nn.ModuleList([ProjBlock(self.config.input_chanel,self.config.output_chanel_pj1),
                                  ProjBlock(self.config.output_chanel_pj1,self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  FullyConnected(self.config.unflatten_shape)
        ])

        self.mList_diff=nn.ModuleList([
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2)
        ])


        self.sigma=4.0

    def forwardX(self,x,seed_num):# here x is the batch collection of images

        # Constructing the noises
        # The number 8 is determined from the number of max-pooling size, kernels & paddings etc.
        xMat=[]
        wMat=self.sigma*torch.FloatTensor(normal.rvs(size=[self.batch_size,        ### The batch_size for each different data point.
                                     self.config.output_chanel_pj2,7,7,
                                     self.Ntime],random_state=seed_num) * self.sqrt_deltaT).to(device)
        x0=torch.clone(x).to(device);
        xMat.append(x0);

        x_pj1=self.mList[0](x0);
        xMat.append(x_pj1.to(device));
        x_input=self.mList[1](x_pj1)
        xMat.append(x_input.to(device));

        for i in range(self.Ntime):
            # i + 2 because we already have two layers before
            xtemp=xMat[i+2]+self.mList[i+2](xMat[i+2])*self.delta + self.mList_diff[i](xMat[i+2]) *wMat[:,:,:,:,i] ## torch.sigmoid
            xMat.append(xtemp.to(device))

        x_terminal=self.mList[-1](xMat[-1])
        xMat.append(x_terminal.to(device))

        return xMat, wMat

        # The input of the target must be a tensor not a list


    def backwardYZ(self,xMat,wMat,target):
        yMat=[];
        zMat=[];


        L=len(xMat)
        x_terminal=xMat[-1].to(device)

        loss_val=loss_fn(x_terminal,target.to(device))
        loss_val.to(device);

        y_terminal=torch.autograd.grad(outputs=[loss_val], inputs=[x_terminal], grad_outputs=torch.ones_like(loss_val), allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]
        #Here y_terminal has dim batch_size x output_size (2 x 2)
        yMat.append(y_terminal.to(device));
        xtemp=xMat[L-2].to(device) # 3

        ## Finding Y[T-1]
        hami=torch.sum(y_terminal.detach()*self.mList[-1](xtemp),dim=1,keepdim=True) # keep dim=1 is correct
        hami=hami.view(-1,1);hami.to(device)

        hami_x=torch.autograd.grad(outputs=[hami], inputs=[xtemp], grad_outputs=torch.ones_like(hami),allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]

        yMat.append(hami_x.to(device))

        for i in range(self.Ntime-1,-1,-1):
######### for Z ##
            ztemp=yMat[-1]*wMat[:,:,:,:,i]/self.sqrt_deltaT
            zMat.append(ztemp)

            X=xMat[i+2].to(device);
            hami=torch.sum(yMat[-1].detach()*self.mList[i+2](X) + ztemp.detach()*self.mList_diff[i](X),dim=(1,2,3))
            hami=hami.view(-1,1); hami.to(device);

            hami_x=torch.autograd.grad(outputs=[hami], inputs=[X], grad_outputs=torch.ones_like(hami),allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]
            ytemp=yMat[-1]+hami_x*self.delta

            yMat.append(ytemp.to(device))

    ### Second projection layer
        X=xMat[1].to(device);
       # X.requires_grad
        hami=torch.sum(yMat[-1].detach()*self.mList[1](X),dim=(1,2,3))
        hami=hami.view(-1,1); hami.to(device);

        ytemp=torch.autograd.grad(outputs=[hami], inputs=[X], grad_outputs=torch.ones_like(hami),allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]
        yMat.append(ytemp.to(device))

        X=xMat[0].to(device);
        X.requires_grad=True
        hami=torch.sum(yMat[-1].detach()*self.mList[0](X),dim=(1,2,3))
        hami=hami.view(-1,1); hami.to(device)

        ytemp=torch.autograd.grad(outputs=[hami], inputs=[X], grad_outputs=torch.ones_like(hami),allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]
        yMat.append(ytemp.to(device))

        return yMat,zMat  #yMat the order is reversed


    def HamCompute(self,xMat,yMat,zMat):
        totalham=0.0
       # l2_norm=sum(p.pow(2.0).sum() for p in self.parameters() )
        for i in range(self.Ntime+3):
            ham_temp=torch.sum(yMat[self.Ntime+2-i].detach()*self.mList[i](xMat[i].detach()) )  #inside the bracket =  +\small_value * self.batch_size *self.mList[i]*self.mList[i] (No, this doesn't contain batchsize)
            totalham+=ham_temp
        for i in range(self.Ntime-1, -1 , -1):
          ham_temp=torch.sum(zMat[i].detach()*self.mList_diff[i](xMat[i+2].detach()))
          totalham+=ham_temp
        #totalham+=l2_norm*0.001
        return totalham/self.batch_size/(self.Ntime+3)

In [80]:

def train_accuracy(train_loader):
  correct = 0
  total = 0

  with torch.no_grad():
      seed_t=0
      for imgs, labels in train_loader:
          outputs = net.forwardX(imgs,seed_t)
          _, predicted = torch.max(outputs[0][-1], dim=1)
          total += labels.shape[0]
          correct += int((predicted == labels.to(device)).sum())
          seed_t=seed_t+1
  res=correct/total

  return res

def avg_predictionSeed(model,iters,imgs,labels):
    temp=torch.zeros((cfg.batch_size,10))
    temp=temp.to(device)
    for i in range(iters):
        val,_=model.forwardX(imgs,np.random.randint(1000000))
        temp+=val[-1].to(device)
    temp=temp/iters;
    _,predicts=torch.max(temp,dim=1)
    return predicts


def test_accuracy(val_loader):
  correct = 0
  total = 0
### One the other hand, it may be only reasonable to use
  with torch.no_grad():
      for imgs, labels in val_loader:
       #   outputs =net.forwardX(imgs,np.random.randint(1000))
       #   _, predicted = torch.max(outputs[0][-1], dim=1)
          predicted=avg_predictionSeed(net,5,imgs,labels)
          total += labels.shape[0]
          correct += int((predicted == labels.to(device)).sum())
  res=correct/total
  return res

In [81]:
n_epoch=10

net=ForwardModel(cfg)
net.to(device)

optimizer=optim.Adam(net.parameters(), lr=1.5e-3)#it could be a bad idea to add weight decay
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1000,2500,4000], gamma=0.2)

Loss_vec=[]
training_accuracy=[]
testing_accuracy=[]

for epoch in range(n_epoch):
    seed_n=0
    for imgs, labels in train_loader:

        xmat,wmat=net.forwardX(imgs,seed_n);
        ymat,zmat=net.backwardYZ(xmat,wmat.to(device),labels)
        loss_temp=net.HamCompute(xmat,ymat,zmat)
        loss_temp.to(device)

        optimizer.zero_grad();
        loss_temp.backward()
        optimizer.step()
        seed_n=seed_n+1

    if epoch %1 ==0:
      print(epoch)
      loss_val=loss_fn(xmat[-1].to(device),labels.to(device))
      loss_val_np=loss_val.cpu().detach().numpy()
      print(loss_val_np)
      Loss_vec.append(loss_val_np)

      test_temp=test_accuracy(test_loader)
      testing_accuracy.append(test_temp)
      print(test_temp)

      train_temp=train_accuracy(train_loader)
      training_accuracy.append(train_temp)
      print(train_temp)

0
0.84447956
0.7065
0.6846666666666666
1
0.68417037
0.7415
0.7391833333333333
2
0.62124497
0.7625
0.7634166666666666
3
0.5981893
0.7746
0.7771166666666667
4
0.59521616
0.7767
0.7827666666666667
5
0.593731
0.7731
0.7833333333333333
6
0.5604158
0.7878
0.79515
7
0.5441509
0.7943
0.7971166666666667
8
0.54203355
0.7905
0.79775
9
0.49463117
0.7965
0.80195


In [82]:
import pandas as pd

In [83]:
#pd.DataFrame(training_accuracy).to_csv("02_snnTrain.csv")
#pd.DataFrame(testing_accuracy).to_csv("02_snnTest.csv")

# Test Adv Attack

In [84]:
net.eval();

In [85]:
# FGSM attack code
# which also works in the batch case
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Adding clipping to maintain [0,1] range
#    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

In [86]:
def avg_prediction(model,iters,imgs,labels):
    temp=torch.zeros((cfg.batch_size,10))
    temp=temp.to(device)
    for i in range(iters):
        val,_=model.forwardX(imgs)
        temp+=val[-1].to(device)
    temp=temp/iters;
    _,predicts=torch.max(temp,dim=1)
    return predicts

In [87]:
def test(model,device, epsilon, test_loader):
    correct = 0
    total = 0
    correct_avg=0

    for imgs, labels in test_loader:

        imgs, labels=imgs.to(device), labels.to(device)
        imgs.requires_grad=True
        xm,wm=model.forwardX(imgs,np.random.randint(1000))
        loss=loss_fn(xm[-1].to(device),labels)
        _,predict_init=torch.max(xm[-1],dim=1)


        model.zero_grad()
   #     # Calculate gradients of model in backward pass
        loss.backward()
        data_grad=imgs.grad.data
        perturbed_data = fgsm_attack(imgs, epsilon, data_grad)

   #     output,_ = model.forwardX(perturbed_data)

   #     _,predict_final=torch.max(output[-1],dim=1)

        predict_final_avg=avg_predictionSeed(net,10,perturbed_data,labels)

     #   correct += int((predict_final == labels.to(device)).sum())
        correct_avg += int((predict_final_avg.to(device) == labels.to(device)).sum())
        total += imgs.shape[0]
    return  correct_avg/total # correct/total,

In [88]:
epsilons = [0,.05,.1,.15,0.2, 0.3,0.4,0.5]

In [89]:
accuracies = []
examples = []

# Run test for each epsilon
for eps in epsilons:
    acc = test(net, device, eps, test_loader)
    print(acc)
    accuracies.append(acc)

0.8046
0.7824
0.7551
0.7235
0.6944
0.6206
0.5494
0.4934
