# <center>Implementing SVRG in a state-of-the-art CNN model</center>

##### Import Pytorch and other useful librairies

In [1]:
from matplotlib import pyplot as plt
plt.gray()
import math
import numpy as np
import pandas as pd

import torch
import torchvision.datasets as datasets
import torch.nn.functional as F
from torch import nn
from torch import optim
import copy

In [2]:
loss_func = F.cross_entropy

def accuracy(Y_hat, Y):
    preds = torch.argmax(Y_hat, dim=1)
    return (preds == Y).float().mean()

##### Load and preprocess dataset

In [106]:
#import data
mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='../data', train=False, download=True, transform=None)

#load trainset into tensors
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=1, shuffle=True)
X_train = train_loader.dataset.data
Y_train = train_loader.dataset.targets

#load testset into tensors
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10000, shuffle=False)
X_test = test_loader.dataset.data
Y_test = test_loader.dataset.targets

#scale data to [0:1] and convert to float32
X_train = (X_train.to(dtype=torch.float32) / X_train.max().to(dtype=torch.float32))
X_test = (X_test.to(dtype=torch.float32) / X_test.max().to(dtype=torch.float32))

#Flatten train and test data
X_train = X_train.reshape(X_train.shape[0],1,28,28)
X_test = X_test.reshape(X_test.shape[0],1,28,28)

print("Train examples : ",X_train.shape[0])
print("Test examples : ",X_test.shape[0])
print("Nb of features : ",X_train.shape[1])

Train examples :  60000
Test examples :  10000
Nb of features :  1


##### Define the CNN architecture

In [52]:
class NeuralNet(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 24, kernel_size=5, stride=1, padding=2)
        self.max1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.bn1 = nn.BatchNorm2d(24)
        self.conv2 = nn.Conv2d(24, 48, kernel_size=5, stride=1, padding=2)
        self.max2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.bn2 = nn.BatchNorm2d(48)
        self.conv3 = nn.Conv2d(48, 64, kernel_size=5, stride=1, padding=2)
        self.max3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.bn3 = nn.BatchNorm2d(64)
        self.linear4 = nn.Linear(64*3*3,256)
        self.bn4 = nn.BatchNorm1d(256)
        self.linear5 = nn.Linear(256,10)
        
        self.number_params = 18
        
        self.mu = [None] * self.number_params
        
        self.copy_snapshot()
 
    def forward(self, x):
        #print("--------FORWARD---------")
        x = torch.relu(self.conv1(x))
        #print("conv1 :" , x.shape)
        x = self.max1(x)
        x = self.bn1(x)
        #print("max1 :" , x.shape)
        x = torch.relu(self.conv2(x))
        #print("conv2 :" , x.shape)
        x = self.max2(x)
        x = self.bn2(x)
        #print("max2 :" , x.shape)
        x = torch.relu(self.conv3(x))
        #print("conv3 :" , x.shape)
        x = self.max3(x)
        x = self.bn3(x)
        #print("max3 :" , x.shape)
        x = self.linear4(torch.relu(x.reshape(x.shape[0],-1)))
        #print("linear4 :" , x.shape)
        x = self.bn4(x)
        x = self.linear5(torch.softmax(x,1))
        #print("linear5 :" , x.shape)
        return x
    
    def forward_snapshot(self, x):
        #print("--------FORWARD---------")
        x = torch.relu(self.conv1_snapshot(x))
        #print("conv1 :" , x.shape)
        x = self.max1(x)
        x = self.bn1_snapshot(x)
        #print("max1 :" , x.shape)
        x = torch.relu(self.conv2_snapshot(x))
        #print("conv2 :" , x.shape)
        x = self.max2(x)
        x = self.bn2_snapshot(x)
        #print("max2 :" , x.shape)
        x = torch.relu(self.conv3_snapshot(x))
        #print("conv3 :" , x.shape)
        x = self.max3(x)
        x = self.bn3_snapshot(x)
        #print("max3 :" , x.shape)
        x = self.linear4_snapshot(torch.relu(x.reshape(x.shape[0],-1)))
        #print("linear4 :" , x.shape)
        x = self.bn4_snapshot(x)
        x = self.linear5_snapshot(torch.softmax(x,1))
        #print("linear5 :" , x.shape)
        return x
    
    def copy_snapshot(self):
        self.conv1_snapshot = copy.deepcopy(self.conv1)
        self.bn1_snapshot = copy.deepcopy(self.bn1)
        self.conv2_snapshot = copy.deepcopy(self.conv2)
        self.bn2_snapshot = copy.deepcopy(self.bn2)
        self.conv3_snapshot = copy.deepcopy(self.conv3)
        self.bn3_snapshot = copy.deepcopy(self.bn3)
        self.linear4_snapshot = copy.deepcopy(self.linear4)
        self.bn4_snapshot = copy.deepcopy(self.bn4)
        self.linear5_snapshot = copy.deepcopy(self.linear5)

        i=0
        for param in self.parameters():
            if (i < self.number_params) :
                self.mu[i] = torch.zeros(param.shape)
                i+=1

    def update_SGD(self, lr=1):
        params = list(self.parameters())
        for i in range(self.number_params // 2,self.number_params):
            params[i].data.copy_(params[i].data - lr * params[i].grad.data)
        
    def update_SVRG(self,lr):
        params = list(self.parameters())
        k = len(params) // 2
        for i in range(k):
            params[i].data = params[i].data - lr * (params[i].grad.data - params[i+k].grad.data + self.mu[i].data)       

    def update_mu(self,batch_size):
        params = list(self.parameters())
        for i in range(len(self.mu)):
            self.mu[i].data = self.mu[i].data + params[i+self.number_params].grad.data / batch_size
        
        
        
                     
                            
    def fit(self,optimizer,epochs,batch_size,lr):
        n = X_train.shape[0]
        model.train()
        
        #Warm start
        for _ in range(1):
            for i in range((n - 1) // batch_size + 1):
                optimizer.zero_grad()
                X = X_train[ i * batch_size : (i+1) * batch_size ]
                Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                pred = self.forward( X )
                loss = loss_func( pred , Y )
                loss.backward()
                self.update_SGD()
                
            print("0\t",loss.item())

        self.copy_snapshot()
    
        for epoch in range(epochs):
            model.train()
            #update mu
            for i in range((n - 1) // batch_size + 1):
                optimizer.zero_grad()
                X = X_train[ i * batch_size : (i+1) * batch_size ]
                Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                pred = self.forward_snapshot( X )
                loss_snapshot = loss_func( pred , Y )
                loss_snapshot.backward()
                self.update_mu(batch_size)
            
            
            for m in range(5):
                for i in range((n - 1) // batch_size + 1):
                    optimizer.zero_grad()
                    
                    #Snapshot gradient computation
                    X = X_train[ i * batch_size : (i+1) * batch_size ]
                    Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                    pred = self.forward_snapshot( X )
                    loss_snapshot = loss_func( pred , Y )
                    loss_snapshot.backward()
                    
                    #'real' gradient computation
                    X = X_train[ i * batch_size : (i+1) * batch_size ]
                    Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                    pred = self.forward( X )
                    loss = loss_func( pred , Y )
                    loss.backward()
                    self.update_SVRG(lr)
                    
           
                print(epoch * 5 + m+1,"\t",loss.item())
            
            self.copy_snapshot()
            with torch.no_grad():
                model.eval()
                print("Test set \t", round(accuracy( model.forward(X_test) , Y_test).item(),3))

In [107]:
class NeuralNet(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 24, kernel_size=5, stride=1, padding=2)
        self.max1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(24, 48, kernel_size=5, stride=1, padding=2)
        self.max2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Conv2d(48, 64, kernel_size=5, stride=1, padding=2)
        self.max3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.linear4 = nn.Linear(64*3*3,256)
        self.linear5 = nn.Linear(256,10)
        
        self.number_params = 10
        
        self.mu = [None] * self.number_params
        
        self.copy_snapshot()
 
    def forward(self, x):
        #print("--------FORWARD---------")
        x = torch.relu(self.conv1(x))
        #print("conv1 :" , x.shape)
        x = self.max1(x)
        #print("max1 :" , x.shape)
        x = torch.relu(self.conv2(x))
        #print("conv2 :" , x.shape)
        x = self.max2(x)
        #print("max2 :" , x.shape)
        x = torch.relu(self.conv3(x))
        #print("conv3 :" , x.shape)
        x = self.max3(x)
        #print("max3 :" , x.shape)
        x = self.linear4(torch.relu(x.reshape(x.shape[0],-1)))
        #print("linear4 :" , x.shape)
        x = self.linear5(torch.softmax(x,1))
        #print("linear5 :" , x.shape)
        return x
    
    def forward_snapshot(self, x):
        #print("--------FORWARD---------")
        x = torch.relu(self.conv1_snapshot(x))
        #print("conv1 :" , x.shape)
        x = self.max1(x)
        #print("max1 :" , x.shape)
        x = torch.relu(self.conv2_snapshot(x))
        #print("conv2 :" , x.shape)
        x = self.max2(x)
        #print("max2 :" , x.shape)
        x = torch.relu(self.conv3_snapshot(x))
        #print("conv3 :" , x.shape)
        x = self.max3(x)
        #print("max3 :" , x.shape)
        x = self.linear4_snapshot(torch.relu(x.reshape(x.shape[0],-1)))
        #print("linear4 :" , x.shape)
        x = self.linear5_snapshot(torch.softmax(x,1))
        #print("linear5 :" , x.shape)
        return x
    
    def copy_snapshot(self):
        self.conv1_snapshot = copy.deepcopy(self.conv1)
        self.conv2_snapshot = copy.deepcopy(self.conv2)
        self.conv3_snapshot = copy.deepcopy(self.conv3)
        self.linear4_snapshot = copy.deepcopy(self.linear4)
        self.linear5_snapshot = copy.deepcopy(self.linear5)

        i=0
        for param in self.parameters():
            if (i < self.number_params) :
                self.mu[i] = torch.zeros(param.shape)
                i+=1

    def update_SGD(self, lr=1):
        params = list(self.parameters())
        for i in range(self.number_params):
            params[i].data.copy_(params[i].data - lr * params[i].grad.data)

    def update_SVRG(self,lr):
        params = list(self.parameters())
        for i in range(self.number_params):
            params[i].data.copy_(params[i].data - lr * (params[i].grad.data - params[i+self.number_params].grad.data + self.mu[i].data))       

    def update_mu(self,batch_size):
        params = list(self.parameters())
        for i in range(len(self.mu)):
            self.mu[i].data.copy_(self.mu[i].data + params[i+self.number_params].grad.data / batch_size)
                            
    def fit_SVRG(self,optimizer,epochs,batch_size,lr):
        n = X_train.shape[0]
        model.train()
        
        #Warm start
        for _ in range(1):
            for i in range((n - 1) // batch_size + 1):
                optimizer.zero_grad()
                X = X_train[ i * batch_size : (i+1) * batch_size ]
                Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                pred = self.forward( X )
                loss = loss_func( pred , Y )
                loss.backward()
                self.update_SGD()
                
            print("0\t",loss.item())

        self.copy_snapshot()
    
        for epoch in range(epochs):
            model.train()
            #update mu
            for i in range((n - 1) // batch_size + 1):
                optimizer.zero_grad()
                X = X_train[ i * batch_size : (i+1) * batch_size ]
                Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                pred = self.forward_snapshot( X )
                loss_snapshot = loss_func( pred , Y )
                loss_snapshot.backward()
                self.update_mu(batch_size)
            
            
            for m in range(5):
                for i in range((n - 1) // batch_size + 1):
                    optimizer.zero_grad()
                    
                    #Snapshot gradient computation
                    X = X_train[ i * batch_size : (i+1) * batch_size ]
                    Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                    pred = self.forward_snapshot( X )
                    loss_snapshot = loss_func( pred , Y )
                    loss_snapshot.backward()
                    
                    #'real' gradient computation
                    X = X_train[ i * batch_size : (i+1) * batch_size ]
                    Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                    pred = self.forward( X )
                    loss = loss_func( pred , Y )
                    loss.backward()
                    self.update_SVRG(lr)
                    
           
                print(epoch * 5 + m+1,"\t",loss.item())
            
            self.copy_snapshot()
            with torch.no_grad():
                model.eval()
                print("Test set \t", round(accuracy( model.forward(X_test) , Y_test).item(),3))

    def fit_SGD(self,optimizer,epochs,batch_size,lr):
        n = X_train.shape[0]
       
        for epoch in range(epochs):
            model.train()
            for i in range((n - 1) // batch_size + 1):
                optimizer.zero_grad()
                X = X_train[ i * batch_size : (i+1) * batch_size ]
                Y = Y_train[ i * batch_size : (i+1) * batch_size ]
                pred = self.forward( X )
                loss = loss_func( pred , Y )
                loss.backward()
                self.update_SGD()
                
            print(epoch,"\t",loss.item())

            with torch.no_grad():
                model.eval()
                print("Test set \t", round(accuracy( model.forward(X_test) , Y_test).item(),3))

In [108]:
model = NeuralNet()
opt = optim.SGD(model.parameters(), lr=1)
epochs = 20
batch_size = 6000
learning_rate = 0.1

model.fit_SGD(opt,epochs,batch_size,learning_rate)

0 	 2.3019421100616455
Test set 	 0.113
1 	 2.3017959594726562
Test set 	 0.113
2 	 2.301830768585205
Test set 	 0.113
3 	 2.3018176555633545
Test set 	 0.113
4 	 2.3017797470092773
Test set 	 0.113
5 	 2.3017780780792236
Test set 	 0.113
6 	 2.3018150329589844
Test set 	 0.113
7 	 2.3018271923065186
Test set 	 0.113
8 	 2.3018343448638916
Test set 	 0.113
9 	 2.301832437515259
Test set 	 0.113
10 	 2.3018290996551514
Test set 	 0.113
11 	 2.301826238632202
Test set 	 0.113
12 	 2.301823377609253
Test set 	 0.113
13 	 2.3018200397491455
Test set 	 0.113
14 	 2.3018171787261963
Test set 	 0.113
15 	 2.3018147945404053
Test set 	 0.113
16 	 2.301811933517456
Test set 	 0.113
17 	 2.3018107414245605
Test set 	 0.113
18 	 2.3018088340759277
Test set 	 0.113
19 	 2.3018076419830322
Test set 	 0.113


###### Load, Preprocess and predict test set from Kaggle

In [104]:
#Load data from CSV
test = pd.read_csv('../data/MNIST/test.csv')
test_tensor = torch.tensor(test.values)

#Preprocess
test_tensor = (test_tensor.to(dtype=torch.float32) / test_tensor.max().to(dtype=torch.float32))
test_tensor = test_tensor.reshape(test_tensor.shape[0],1,28,28)

#Predict
test_tensor = model.forward(test_tensor)
test_tensor = test_tensor.argmax(1)

##### Save predictions to a csv file

In [19]:
#Convert to a numpy array
arr = test_tensor.numpy()

# write CSV
np.savetxt('../data/MNIST/predictions.csv', arr)

NameError: name 'test_tensor' is not defined

In [78]:
for i in range(len(model.mu)):
    print(model.mu[i].shape)

torch.Size([24, 1, 5, 5])
torch.Size([24])
torch.Size([48, 24, 5, 5])
torch.Size([48])
torch.Size([64, 48, 5, 5])
torch.Size([64])
torch.Size([256, 576])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])
torch.Size([24, 1, 5, 5])
torch.Size([24])
torch.Size([48, 24, 5, 5])
torch.Size([48])
torch.Size([64, 48, 5, 5])
torch.Size([64])
torch.Size([256, 576])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])


In [79]:
for param in model.parameters():
    print(param.shape)

torch.Size([24, 1, 5, 5])
torch.Size([24])
torch.Size([48, 24, 5, 5])
torch.Size([48])
torch.Size([64, 48, 5, 5])
torch.Size([64])
torch.Size([256, 576])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])
torch.Size([24, 1, 5, 5])
torch.Size([24])
torch.Size([48, 24, 5, 5])
torch.Size([48])
torch.Size([64, 48, 5, 5])
torch.Size([64])
torch.Size([256, 576])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])


In [83]:
params = list(model.parameters())
params[10].grad.data     


AttributeError: 'NoneType' object has no attribute 'data'