In [1]:
import numpy as np
import torch
import torch.nn as nn
import time
import logging
import torch.optim as optim
import os
from scipy.stats import multivariate_normal as normal
import torch.nn.functional as F
from torch.nn import Parameter
import matplotlib.pyplot as plt
import torchvision

In [2]:
#from google.colab import drive
#drive.mount('/content/gdrive')

In [3]:
%matplotlib inline
torch.set_printoptions(edgeitems=2, linewidth=75)
torch.manual_seed(123)

<torch._C.Generator at 0x7f3d24149c10>

In [4]:
from torchvision import transforms

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

data_type=torch.float32
MOMENTUM = 0.99
EPSILON = 1e-6

Using cuda device


# Handling the data

In [6]:
class Config(object):
    batch_size = 500

    totalT=2.0;

    n_layer=Ntime=4;

    sqrt_deltaT=np.sqrt(totalT/Ntime);

    logging_frequency = 100
    verbose = True

    input_chanel=1
    output_chanel_pj1=32
    output_chanel_pj2=16

    unflatten_shape=output_chanel_pj2*7*7

def get_config(name):
    try:
        return globals()[name]
    except KeyError:
        raise KeyError("config not defined.")

cfg=get_config('Config')

In [7]:
from torchvision import datasets
batch_size_train=cfg.batch_size
batch_size_test=cfg.batch_size

In [8]:
train_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST('/files/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                    ])), batch_size=cfg.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST('/files/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                    ])), batch_size=cfg.batch_size, shuffle=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /files/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 8370954.84it/s] 


Extracting /files/FashionMNIST/raw/train-images-idx3-ubyte.gz to /files/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /files/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 141518.52it/s]


Extracting /files/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /files/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /files/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 2583484.70it/s]


Extracting /files/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /files/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /files/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 8544628.81it/s]

Extracting /files/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /files/FashionMNIST/raw






In [9]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

We have stored both the training and the validation datasets

Defining the dataloader

## Defining the configuration

# Constructing a dense net

## Building the building block

In [10]:
class ProjBlock(nn.Module):
    def __init__(self,input_chanel,output_chanel):
        super(ProjBlock,self).__init__()
        self.input_chanel=input_chanel
        self.output_chanel=output_chanel

        self.conv1=nn.Conv2d(input_chanel,output_chanel,kernel_size=3,padding=1)
        self.act1=nn.Tanh()
        self.pool1=nn.MaxPool2d(2)

      #  self.conv2=nn.Conv2d(2*output_chanel,output_chanel,kernel_size=3,padding=1)
      #  self.act1=nn.Tanh()
      #  self.pool1=nn.MaxPool2d(2)

    def forward(self,x):
        out = self.pool1(self.act1(self.conv1(x)))
      #  out = self.pool2(self.act2(self.conv2(x)))
        return out

In [11]:
class BasicBlock(nn.Module):
    def __init__(self,num_chanel):
        super(BasicBlock,self).__init__()
        self.input_chanel=num_chanel
        self.output_chanel=num_chanel

        self.conv=nn.Conv2d(self.input_chanel,self.output_chanel,kernel_size=3,padding=1)
        self.act=nn.Tanh()
        ## there should not be any MaxPooling layer in the inbetween set

    def forward(self,x):
        out=self.act(self.conv(x))
        return out

In [12]:
# One is responsible for figuring out the unflatten shape
class FullyConnected(nn.Module):
    def __init__(self,unflatten_shape):
        super(FullyConnected,self).__init__()
        self.unflatten_shape=unflatten_shape
        self.fc1=nn.Linear(unflatten_shape,32)
        self.ac1=nn.Tanh()
        self.fc2=nn.Linear(32,10)
        # Let's only tell the airplane from a bird

    def forward(self,x):
        inputx=x.view(-1, self.unflatten_shape)
        out=self.fc2(self.ac1(self.fc1(inputx)))
        return out

## Stacking up the blocks

In [13]:
normal.rvs(size=[2,2],random_state=12345)

array([[-0.20470766,  0.47894334],
       [-0.51943872, -0.5557303 ]])

In [14]:
normal.rvs(size=[2,2],random_state=12345)

array([[-0.20470766,  0.47894334],
       [-0.51943872, -0.5557303 ]])

In [25]:
loss_fn=nn.CrossEntropyLoss()
class ForwardModel(nn.Module):
    def __init__(self,config):
        super(ForwardModel,self).__init__()

        self.config=config
        self.batch_size=self.config.batch_size
        self.Ntime=self.config.Ntime
        self.sqrt_deltaT=self.config.sqrt_deltaT;
        self.n_layer=self.config.n_layer
        self.delta=self.config.totalT/self.Ntime;

        self.mList=nn.ModuleList([ProjBlock(self.config.input_chanel,self.config.output_chanel_pj1),
                                  ProjBlock(self.config.output_chanel_pj1,self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  FullyConnected(self.config.unflatten_shape)
        ])

        self.mList_diff=nn.ModuleList([
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2),
                                  BasicBlock(self.config.output_chanel_pj2)
        ])

        self.sigma=0.2

    def forwardX(self,x):# here x is the batch collection of images

        # Constructing the noises
        # The number 8 is determined from the number of max-pooling size, kernels & paddings etc.
        xMat=[]
        wMat=self.sigma*torch.FloatTensor(normal.rvs(size=[self.batch_size,        ### The batch_size for each different data point.
                                     self.config.output_chanel_pj2,7,7,
                                     self.Ntime]) * self.sqrt_deltaT).to(device)
        x0=torch.clone(x).to(device);
        xMat.append(x0);

        x_pj1=self.mList[0](x0);
        xMat.append(x_pj1.to(device));
        x_input=self.mList[1](x_pj1)
        xMat.append(x_input.to(device));

        for i in range(self.Ntime):
            # i + 2 because we already have two layers before
            xtemp=xMat[i+2]+self.mList[i+2](xMat[i+2])*self.delta +  self.mList_diff[i](xMat[i+2]) *wMat[:,:,:,:,i] ## torch.sigmoid
            xMat.append(xtemp.to(device))

        x_terminal=self.mList[-1](xMat[-1])
        xMat.append(x_terminal.to(device))

        return xMat, wMat

        # The input of the target must be a tensor not a list


    def backwardYZ(self,xMat,wMat,target):
        yMat=[];
        zMat=[];


        L=len(xMat)
        x_terminal=xMat[-1].to(device)

        loss_val=loss_fn(x_terminal,target.to(device))
        loss_val.to(device);

        y_terminal=torch.autograd.grad(outputs=[loss_val], inputs=[x_terminal], grad_outputs=torch.ones_like(loss_val), allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]
        #Here y_terminal has dim batch_size x output_size (2 x 2)
        yMat.append(y_terminal.to(device));
        xtemp=xMat[L-2].to(device) # 3

        ## Finding Y[T-1]
        hami=torch.sum(y_terminal.detach()*self.mList[-1](xtemp),dim=1,keepdim=True) # keep dim=1 is correct
        hami=hami.view(-1,1);hami.to(device)

        hami_x=torch.autograd.grad(outputs=[hami], inputs=[xtemp], grad_outputs=torch.ones_like(hami),allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]

        yMat.append(hami_x.to(device))

        for i in range(self.Ntime-1,-1,-1):
######### for Z ##
            ztemp=yMat[-1]*wMat[:,:,:,:,i]/self.sqrt_deltaT
            zMat.append(ztemp)

            X=xMat[i+2].to(device);
            hami=torch.sum(yMat[-1].detach()*self.mList[i+2](X) + ztemp.detach()*self.mList_diff[i](X),dim=(1,2,3))
            hami=hami.view(-1,1); hami.to(device);

            hami_x=torch.autograd.grad(outputs=[hami], inputs=[X], grad_outputs=torch.ones_like(hami),allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]
            ytemp=yMat[-1]+hami_x*self.delta

            yMat.append(ytemp.to(device))

    ### Second projection layer
        X=xMat[1].to(device);
       # X.requires_grad
        hami=torch.sum(yMat[-1].detach()*self.mList[1](X),dim=(1,2,3))
        hami=hami.view(-1,1); hami.to(device);

        ytemp=torch.autograd.grad(outputs=[hami], inputs=[X], grad_outputs=torch.ones_like(hami),allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]
        yMat.append(ytemp.to(device))

        X=xMat[0].to(device);
        X.requires_grad=True
        hami=torch.sum(yMat[-1].detach()*self.mList[0](X),dim=(1,2,3))
        hami=hami.view(-1,1); hami.to(device)

        ytemp=torch.autograd.grad(outputs=[hami], inputs=[X], grad_outputs=torch.ones_like(hami),allow_unused=True,
                                 retain_graph=True, create_graph=True)[0]
        yMat.append(ytemp.to(device))

        return yMat,zMat  #yMat the order is reversed


    def HamCompute(self,xMat,yMat,zMat):
        totalham=0.0
       # l2_norm=sum(p.pow(2.0).sum() for p in self.parameters() )
        for i in range(self.Ntime+3):
            ham_temp=torch.sum(yMat[self.Ntime+2-i].detach()*self.mList[i](xMat[i].detach()) )  #inside the bracket =  +\small_value * self.batch_size *self.mList[i]*self.mList[i] (No, this doesn't contain batchsize)
            totalham+=ham_temp
        for i in range(self.Ntime-1, -1 , -1):
          ham_temp=torch.sum(zMat[i].detach()*self.mList_diff[i](xMat[i+2].detach()))
          totalham+=ham_temp
        #totalham+=l2_norm*0.001
        return totalham/self.batch_size/(self.Ntime+3)

In [26]:

def train_accuracy(train_loader):
  correct = 0
  total = 0

  with torch.no_grad():
      for imgs, labels in train_loader:
          outputs = net.forwardX(imgs)
          _, predicted = torch.max(outputs[0][-1], dim=1)
          total += labels.shape[0]
          correct += int((predicted == labels.to(device)).sum())
  res=correct/total

  return res

def test_accuracy(val_loader):
  correct = 0
  total = 0

  with torch.no_grad():
      for imgs, labels in val_loader:
          outputs =net.forwardX(imgs)
          _, predicted = torch.max(outputs[0][-1], dim=1)
          total += labels.shape[0]
          correct += int((predicted == labels.to(device)).sum())
  res=correct/total
  return res

In [27]:
n_epoch=10

net=ForwardModel(cfg)
net.to(device)

optimizer=optim.Adam(net.parameters(), lr=1.5e-3)#it could be a bad idea to add weight decay
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1000,2500,4000], gamma=0.2)

Loss_vec=[]
training_accuracy=[]
testing_accuracy=[]

for epoch in range(n_epoch):
    for imgs, labels in train_loader:


        xmat,wmat=net.forwardX(imgs);
        ymat,zmat=net.backwardYZ(xmat,wmat.to(device),labels)
        loss_temp=net.HamCompute(xmat,ymat,zmat)
        loss_temp.to(device)

        optimizer.zero_grad();
        loss_temp.backward()
        optimizer.step()

    if epoch %1 ==0:
        loss_val=loss_fn(xmat[-1].to(device),labels.to(device))
       # ham_loss=net.HamCompute(xmat,ymat)
       # print(ham_loss.cpu().detach().numpy(), loss_val.cpu().detach().numpy())
        loss_val_np=loss_val.cpu().detach().numpy()
        print(epoch, loss_val_np)
        Loss_vec.append(loss_val_np)

    #if epoch %10 ==0:
        test_temp=test_accuracy(test_loader)
        testing_accuracy.append(test_temp)
        print(test_temp)

        train_temp=train_accuracy(train_loader)
        training_accuracy.append(train_temp)
        print(train_temp)

0 0.47832537
0.8367
0.8479166666666667
1 0.36684072
0.857
0.86975
2 0.288941
0.8747
0.88525
3 0.31472087
0.8846
0.8976333333333333
4 0.30505952
0.8888
0.9027333333333334
5 0.2971817
0.8964
0.9132666666666667
6 0.23081361
0.8983
0.9168
7 0.24328998
0.8979
0.91845
8 0.23173404
0.8997
0.9183666666666667
9 0.29068804
0.9025
0.9263666666666667


In [28]:
import pandas as pd

In [29]:
pd.DataFrame(training_accuracy).to_csv("02_snnTrain.csv")
pd.DataFrame(testing_accuracy).to_csv("02_snnTest.csv")

In [None]:
n_epoch=20

net=ForwardModel(cfg)
net.to(device)

optimizer=optim.Adam(net.parameters(), lr=1.5e-3)#it could be a bad idea to add weight decay
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1000,2500,4000], gamma=0.2)

Loss_vec=[]
training_accuracy=[]
testing_accuracy=[]

for epoch in range(n_epoch):
    for imgs, labels in train_loader:


        xmat,wmat=net.forwardX(imgs);
        ymat=net.backwardYZ(xmat,wmat.to(device),labels)
        loss_temp=net.HamCompute(xmat,ymat)
        loss_temp.to(device)

        optimizer.zero_grad();
        loss_temp.backward()
        optimizer.step()

    if epoch %1 ==0:
        loss_val=loss_fn(xmat[-1].to(device),labels.to(device))
       # ham_loss=net.HamCompute(xmat,ymat)
       # print(ham_loss.cpu().detach().numpy(), loss_val.cpu().detach().numpy())
        loss_val_np=loss_val.cpu().detach().numpy()
        print(epoch, loss_val_np)
        Loss_vec.append(loss_val_np)

    #if epoch %10 ==0:
        test_temp=test_accuracy(test_loader)
        testing_accuracy.append(test_temp)
        print(test_temp)

        train_temp=train_accuracy(train_loader)
        training_accuracy.append(train_temp)
        print(train_temp)

0 0.4600175
0.8342
0.84065
1 0.3930344
0.8601
0.8712166666666666
2 0.3804343
0.8762
0.8894166666666666
3 0.31233355
0.8764
0.8934
4 0.33462223
0.89
0.9055
5 0.2883765
0.8934
0.9103166666666667
6 0.24727301
0.8978
0.9158166666666666
7 0.24122407
0.8946
0.9167
8 0.22976273
0.9036
0.9237166666666666
9 0.25150692
0.908
0.9311
10 0.23556092
0.9021
0.9269833333333334
11 0.20559546
0.9096
0.9346166666666667
12 0.20727155
0.9073
0.9358333333333333
13 0.16154906
0.9059
0.936
14 0.17529136
0.9134
0.9409333333333333
15 0.15335551
0.9096
0.944
16 0.19574562
0.9133
0.948
17 0.14522822
0.9152
0.9506
18 0.12336541
0.911
0.9519
19 0.105183505
0.9147
0.9537166666666667
