In [1]:
import numpy as np
import torch
import torch.nn as nn
import time 
import logging
import torch.optim as optim
import os
from scipy.stats import multivariate_normal as normal
import torch.nn.functional as F
from torch.nn import Parameter
import matplotlib.pyplot as plt
import torchvision

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Abbreviate printed tensors: show 2 edge items per dimension and wrap lines at 75 characters.
torch.set_printoptions(edgeitems=2, linewidth=75)
# Seed PyTorch's default random generator. The call returns the generator
# object, which is what this cell displays as its output.
torch.manual_seed(123)

<torch._C.Generator at 0x16fca333b90>

In [3]:
from torchvision import transforms
from torchvision import datasets 

In [4]:
# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))

# Notebook-wide numeric constants.
data_type = torch.float32
MOMENTUM = 0.99
EPSILON = 1e-6

Using cuda device


In [5]:
class Config:
    """Static hyper-parameters shared by the data loaders and the network.

    All values are class attributes, so the class itself is used as the
    configuration object and is never instantiated in this notebook.
    """

    # Mini-batch size.
    batch_size = 500

    # Total time horizon.
    totalT = 2.0

    # Number of layers and number of time steps (kept equal).
    n_layer = 4
    Ntime = n_layer

    # Square root of the step size totalT / Ntime.
    sqrt_deltaT = np.sqrt(totalT / Ntime)

    # Logging options.
    logging_frequency = 100
    verbose = True

    # Channel counts: network input, first projection, second projection.
    input_chanel = 1
    output_chanel_pj1 = 32
    output_chanel_pj2 = 16

    # Flattened feature size fed to the fully connected head:
    # output_chanel_pj2 channels on a 7x7 grid.
    unflatten_shape = output_chanel_pj2 * 7 * 7
    
def get_config(name):
    """Return the module-level object registered under ``name``.

    Args:
        name: Name of a global defined in this module (e.g. a config class).

    Returns:
        The object bound to ``name`` in this module's globals.

    Raises:
        KeyError: If no global with that name exists.
    """
    namespace = globals()
    if name not in namespace:
        raise KeyError("config not defined.")
    return namespace[name]
# Look up the configuration class by name; `cfg` is the class object itself, not an instance.
cfg=get_config('Config')

In [6]:
# The train and test loaders use the same mini-batch size taken from the config.
batch_size_train = batch_size_test = cfg.batch_size

In [7]:
# --- Training split -------------------------------------------------------
# Preprocessing: convert each image to a tensor, then normalize it with
# mean 0.1307 and standard deviation 0.3081.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = torchvision.datasets.MNIST(
    '/files/', train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size_train, shuffle=True)

# --- Test split -----------------------------------------------------------
# Same preprocessing as the training split; batches are shuffled here too.
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
test_dataset = torchvision.datasets.MNIST(
    '/files/', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size_test, shuffle=True)

In [8]:
# Pull the first (index, batch) pair from the test loader and split the batch
# into its image tensor and its label tensor.
examples = enumerate(test_loader)
batch_idx, example_batch = next(examples)
example_data, example_targets = example_batch

## Plain vanilla CNN

In [9]:
class ProjBlock(nn.Module):
    """Projection block: 3x3 convolution -> tanh -> 2x2 max-pooling.

    The convolution (kernel 3, padding 1) changes the channel count from
    ``input_chanel`` to ``output_chanel`` while keeping the spatial size;
    the pooling layer then halves the height and the width.

    Args:
        input_chanel: Number of channels of the incoming feature map.
        output_chanel: Number of channels produced by the convolution.
    """

    def __init__(self, input_chanel, output_chanel):
        super().__init__()
        self.input_chanel = input_chanel
        self.output_chanel = output_chanel

        # Sub-module attribute names determine the state-dict keys.
        self.conv1 = nn.Conv2d(input_chanel, output_chanel, kernel_size=3, padding=1)
        self.act1 = nn.Tanh()
        self.pool1 = nn.MaxPool2d(2)

    def forward(self, x):
        """Apply convolution, tanh and pooling to ``x`` and return the result."""
        features = self.conv1(x)
        activated = self.act1(features)
        return self.pool1(activated)

In [10]:
class BasicBlock(nn.Module):
    """Shape-preserving block: 3x3 convolution followed by tanh.

    Input and output have the same number of channels, and the convolution
    (kernel 3, padding 1) keeps the spatial size. No pooling is applied in
    these intermediate blocks.

    Args:
        num_chanel: Channel count of both the input and the output.
    """

    def __init__(self, num_chanel):
        super().__init__()
        self.input_chanel = num_chanel
        self.output_chanel = num_chanel

        self.conv = nn.Conv2d(
            self.input_chanel, self.output_chanel, kernel_size=3, padding=1
        )
        self.act = nn.Tanh()

    def forward(self, x):
        """Return ``tanh(conv(x))``."""
        convolved = self.conv(x)
        return self.act(convolved)

In [11]:
class FullyConnected(nn.Module):
    """Classification head: flatten -> Linear(unflatten_shape, 32) -> tanh -> Linear(32, 10).

    Args:
        unflatten_shape: Number of features per sample after flattening
            (channels * height * width of the incoming feature map).
    """

    def __init__(self, unflatten_shape):
        super().__init__()
        self.unflatten_shape = unflatten_shape
        self.fc1 = nn.Linear(unflatten_shape, 32)
        self.ac1 = nn.Tanh()
        # 10 output scores per sample.
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, unflatten_shape)`` and return the 10 output scores per row."""
        # reshape (unlike view) also accepts non-contiguous inputs, e.g.
        # channels_last or transposed feature maps, copying only when needed.
        flat = x.reshape(-1, self.unflatten_shape)
        hidden = self.ac1(self.fc1(flat))
        return self.fc2(hidden)

## Stacking up the blocks

In [12]:
# Cross-entropy loss used by the training loop; it is applied to the network's raw outputs and the labels.
loss_fn=nn.CrossEntropyLoss()

In [13]:
class ForwardModel(nn.Module):
    """Plain CNN built as a sequential stack of blocks.

    Layout of ``mList`` (in order):
      1. ``ProjBlock(input_chanel -> output_chanel_pj1)``
      2. ``ProjBlock(output_chanel_pj1 -> output_chanel_pj2)``
      3. ``config.n_layer`` copies of ``BasicBlock(output_chanel_pj2)``
      4. ``FullyConnected(unflatten_shape)``

    Args:
        config: Object exposing ``batch_size``, ``Ntime``, ``sqrt_deltaT``,
            ``n_layer``, ``input_chanel``, ``output_chanel_pj1``,
            ``output_chanel_pj2`` and ``unflatten_shape``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.batch_size = config.batch_size
        self.Ntime = config.Ntime
        self.sqrt_deltaT = config.sqrt_deltaT
        self.n_layer = config.n_layer

        # Blocks are created in stack order, so the default parameter
        # initialization consumes the RNG in the same order as the layers.
        blocks = [
            ProjBlock(config.input_chanel, config.output_chanel_pj1),
            ProjBlock(config.output_chanel_pj1, config.output_chanel_pj2),
        ]
        # The depth of the shape-preserving part follows the config instead of
        # being hard-coded; n_layer == 4 yields four BasicBlocks at indices 2..5.
        blocks.extend(BasicBlock(config.output_chanel_pj2) for _ in range(self.n_layer))
        blocks.append(FullyConnected(config.unflatten_shape))
        self.mList = nn.ModuleList(blocks)

    def forward(self, data):
        """Feed ``data`` through every block in order and return the final output."""
        # No block modifies its input in place, so the input does not need to be cloned.
        out = data
        for block in self.mList:
            out = block(out)
        return out

In [23]:
# Build the model from the shared config and move it to the selected device
# (Module.to returns the module itself, so the two steps can be chained).
net = ForwardModel(cfg).to(device)

In [24]:
# Adam over all parameters of `net` with learning rate 1.5e-3.
# No weight decay is configured: adding it could be a bad idea here.
optimizer=optim.Adam(net.parameters(), lr=1.5e-3)

In [16]:
def train_accuracy(model, train_loader, device=None):
    """Return the fraction of samples in ``train_loader`` that ``model`` classifies correctly.

    Despite the name, this works for any loader (it is also used on the test set).
    The model is evaluated in ``eval()`` mode with gradient tracking disabled,
    and its previous training/eval mode is restored afterwards.

    Args:
        model: Module mapping a batch of inputs to per-class scores of shape
            ``(batch, num_classes)``.
        train_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``correct / total`` as a Python float.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()  # disable train-only behaviour (dropout, batch-norm updates)
    total = 0
    correct = 0
    try:
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for imgs, labels in train_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                predicted = model(imgs).argmax(dim=1)

                total += imgs.shape[0]
                correct += int((predicted == labels).sum())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return correct / total

In [17]:
# NOTE(review): `total` and `correct` are not used by the loop below
# (train_accuracy keeps its own counters); they are kept only in case other
# cells reference them — TODO confirm and remove.
total = 0
correct = 0
n_epoch = 10

for epoch in range(n_epoch):
    # One pass of mini-batch optimizer updates over the training set.
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        output = net(imgs)
        # The loss is computed from tensors already on `device`, so it needs no
        # extra `.to(device)` (that call returns a tensor rather than working in place).
        loss_temp = loss_fn(output, labels)

        optimizer.zero_grad()
        loss_temp.backward()
        optimizer.step()

    # Every second epoch (0, 2, 4, ...) print the train accuracy, then the test accuracy.
    if epoch % 2 == 0:
        train_acc = train_accuracy(net, train_loader)
        print(train_acc)
        test_acc = train_accuracy(net, test_loader)
        print(test_acc)
       
        

0.9519833333333333
0.953
0.9788666666666667
0.9762
0.9879166666666667
0.9844
0.9893
0.9849
0.994
0.988


### We now save the coefficients (state dict) of the plain vanilla CNN

At the last evaluation shown above (epoch index 8, i.e. after 9 training epochs), the train accuracy reaches 99.4% and the test accuracy reaches 98.8%.

In [20]:
# Destination file for the trained model's state dict (path is relative to the working directory).
PATH='data/VanillaCNN_mnist_model.pth'

In [21]:
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
save_dir = os.path.dirname(PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
# Persist only the learned parameters (the state dict), not the module object.
torch.save(net.state_dict(), PATH)