# Packages and Global variables

In [2]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import PercentFormatter
import seaborn as sns
import numpy as np
import itertools
from collections import defaultdict
import time
from torchsummary import summary

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, WeightedRandomSampler
from torchvision import datasets, transforms, utils
import snntorch as snn
from snntorch import surrogate
from snntorch import spikegen
import snntorch.spikeplot as splt
import math

torch.manual_seed(42)
np.random.seed(42)

#print(torch.cuda.is_available())

In [3]:
data_path = '\\Users\\liamh\\OneDrive - University of Strathclyde\\University'
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training Parameters
batch_size=128

# Network Architecture
num_hidden = 350
num_outputs = 10
num_steps = 25
dropout = 0.25

# Temporal Dynamics
time_step = 1e-3
tau_mem = 2e-2
beta = float(np.exp(-time_step/tau_mem))

# Loss Function
loss_fn = nn.MSELoss()

# Functions

In [4]:
def load_in_data(res, ratio = 1):
    transform = transforms.Compose([
        transforms.Resize((res, res)), #Resize images to 28*28
        transforms.Grayscale(), # Make sure image is grayscale
        transforms.ToTensor()]) # change each image array to a tensor which automatically scales inputs to [0,1]

    mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # Download training set and apply transformations. 
    mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform) # same for test set

    train_len = int(len(mnist_train)/ratio)
    dummy_len = len(mnist_train) - train_len
    train_dataset, _ = random_split(mnist_train, (train_len, dummy_len), generator=torch.Generator().manual_seed(42))
    
    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) # Load the data into the DataLoader so it's passed through the model, shuffled in batches. 
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
    
    return train_loader, test_loader

def output_formula(input_size, filter_size, padding, stride):
    formula = math.floor(((((input_size - filter_size + 2*padding)/stride) + 1)))
    
    return formula 

def all_output_sizes(res, conv_filter = 3, conv_padding = 1, conv_stride = 1, mp_filter = 3, mp_padding = 0, mp_stride = 2):
    
    conv1 = output_formula(res, conv_filter, conv_padding, conv_stride)   # Output size from applying conv1 to input 
    mp1 = output_formula(conv1, mp_filter, mp_padding, mp_stride)         # Output size from applying max pooling 1 to conv1 
    
    conv2 = output_formula(mp1, conv_filter, conv_padding, conv_stride)   # Output size from applying conv2 to max pooling 1
    conv3 = output_formula(conv2, conv_filter, conv_padding, conv_stride) # Output size from applying conv3 to conv 2
    mp2 = output_formula(conv3, mp_filter, mp_padding, mp_stride)         # Output size from applying max pooling 2 to conv3
    
    conv4 = output_formula(mp2, conv_filter, conv_padding, conv_stride)   # Output size from applying conv 4 to max pooling 2
    conv5 = output_formula(conv4, conv_filter, conv_padding, conv_stride) # Output size from applying conv5 to conv 4
    mp3 = output_formula(conv5, mp_filter, mp_padding, mp_stride)         # Output size from applying max pooling 3 to conv 5
    
    outputs_I_need = [mp1, conv2, mp2, conv4, mp3]
    
    return outputs_I_need

def plot_training_history(history, res, loss_upper = 1.05, acc_lower = -0.05, acc_higher = 105):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
    
    
    ax1.plot(history['avg_train_loss'], label='train loss',marker = 'o')
    ax1.plot(history['avg_valid_loss'], label='validation loss',marker = 'o')

    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax1.set_ylim([-0.05, loss_upper])
    ax1.legend()
    ax1.set_ylabel('Loss',fontsize = 16)
    ax1.set_xlabel('Epoch',fontsize = 16)
    
    ax2.plot(history['train_accuracy'], label='train accuracy',marker = 'o')
    ax2.plot(history['valid_accuracy'], label='validation accuracy',marker = 'o')

    ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax2.set_ylim([acc_lower, acc_higher])

    ax2.legend()

    ax2.set_ylabel('Accuracy',fontsize = 16)
    ax2.yaxis.set_major_formatter(PercentFormatter(100))
    ax2.set_xlabel('Epoch',fontsize = 16)
    fig.suptitle(f'Training history ({res}*{res})',fontsize = 20)
    plt.show()

def store_best_results(history):
    # Want to take the last entry from each output(best results) and store them all in a Dataframe
    placeholder = []
    placeholder.append(history['avg_train_loss'][-1])
    placeholder.append(history['train_accuracy'][-1])
    placeholder.append(history['avg_valid_loss'][-1])
    placeholder.append(history['valid_accuracy'][-1])
    
    return placeholder

def put_results_in_df(output):
    df = pd.DataFrame()
    df['avg_train_loss'] = output['avg_train_loss']
    df['train_accuracy'] = output['train_accuracy']
    df['avg_valid_loss'] = output['avg_valid_loss']
    df['valid_accuracy'] = output['valid_accuracy']
    
    return df

# Spiking CNN and training Loop

In [6]:
# Define a different network
class CSNN(nn.Module):
    def __init__(self,spike_grad):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = conv_kernel_size, padding = conv_padding_size) # Do I change channels to a variable incase I end up with RGB images? ## Padding = 0 as all information is at the centre of image (may change if lower resolution)
        self.mp1 = nn.MaxPool2d(kernel_size = mp_kernel_size, stride = mp_stride_length)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.conv2 = nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = conv_kernel_size, padding = conv_padding_size)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.conv3 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = conv_kernel_size, padding = conv_padding_size)
        self.mp2 = nn.MaxPool2d(kernel_size = mp_kernel_size, stride = mp_stride_length)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.conv4 = nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = conv_kernel_size, padding = conv_padding_size)
        self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.conv5 = nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = conv_kernel_size, padding = conv_padding_size)
        self.maxpool = nn.MaxPool2d(kernel_size = mp_kernel_size, stride = mp_stride_length)
        self.lif5 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.drop1 = nn.Dropout(dropout)
        
        self.fc1 = nn.Linear(256 * output_sizes[-1] * output_sizes[-1], num_hidden)
        self.lif6 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.drop2 = nn.Dropout(dropout)
        
        self.fc2 = nn.Linear(num_hidden,num_hidden)
        self.lif7 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.fc3 = nn.Linear(num_hidden, num_outputs) 
        self.lif8 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        
    def forward(self, x): 
        
        spk1, mem1 = self.lif1.init_leaky(batch_size, 64, output_sizes[0], output_sizes[0])
        spk2, mem2 = self.lif2.init_leaky(batch_size, 128, output_sizes[1], output_sizes[1])
        spk3, mem3 = self.lif3.init_leaky(batch_size, 128, output_sizes[2], output_sizes[2])
        spk4, mem4 = self.lif4.init_leaky(batch_size, 256, output_sizes[3], output_sizes[3])
        
        spk5, mem5 = self.lif5.init_leaky(batch_size, 256, output_sizes[-1], output_sizes[-1])
        
        spk6, mem6 = self.lif6.init_leaky(batch_size, num_hidden)
        spk7, mem7 = self.lif7.init_leaky(batch_size, num_hidden)
        
        spk8, mem8 = self.lif8.init_leaky(batch_size, num_outputs)
        
        spk8_rec = []
        mem8_rec = []
        
        for step in range(num_steps):
            cur1 = self.mp1(self.conv1(x[step]))
            spk1, mem1 = self.lif1(cur1, mem1)
            
            cur2 = self.conv2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.conv3(spk2)
            cur3 = self.mp2(cur3)
            spk3, mem3 = self.lif3(cur3, mem3)
            
            cur4 = self.conv4(spk3)
            spk4, mem4 = self.lif4(cur4, mem4)
            cur5 = self.conv5(spk4)
            cur5 = self.maxpool(cur5)
            spk5, mem5 = self.lif5(cur5, mem5)
            
            spk5 = self.drop1(spk5)
            cur6 = self.fc1(spk5.view(batch_size, -1))
            spk6, mem6 = self.lif6(cur6, mem6)
            
            spk6 = self.drop2(spk6)
            cur7 = self.fc2(spk6)
            spk7, mem7 = self.lif7(cur7, mem7)
            
            cur8 = self.fc3(spk7)
            spk8, mem8 = self.lif8(cur8, mem8)
            
            spk8_rec.append(spk8)
            mem8_rec.append(mem8)
            
        return torch.stack(spk8_rec, dim=0), torch.stack(mem8_rec, dim=0)     

In [7]:
def train_rate_spiking_mse_model(resolution, train_loader, valid_loader, model, epochs ,device = device, verbose = True):
    start_time = time.time()
    print('Starting Training')
    history = defaultdict(list)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) # Just an Adam Optimiser
    
    # Training variables
    train_size = len(train_loader.dataset)
    train_num_batches = len(train_loader)
    
    # validation variables
    valid_size = len(valid_loader.dataset)
    valid_num_batches = len(valid_loader)
    
    
    for t in range(epochs):
        
        avg_train_loss = 0
        correct = 0
        avg_valid_loss, valid_correct = 0, 0
        model.train()
        for batch, (data_it, targets_it) in enumerate(train_loader):
            data_it = data_it.to(device)
            targets_it = targets_it.to(device)

            optimizer.zero_grad()
            
            # Compute prediction and loss
            spike_data = spikegen.rate(data_it, num_steps=num_steps, gain=1, offset=0)
            spk_targets_it = torch.clamp(spikegen.to_one_hot(targets_it, 10) * 1.05, min=0.05)
            
            spk_rec, mem_rec = model(spike_data.view(num_steps,batch_size, 1,resolution,resolution)) 
            
            # Sum loss over time steps: BPTT
            loss = torch.zeros((1), dtype=dtype, device=device)   # creates a 1D tensor to store total loss over time. 
            for step in range(num_steps):
                loss += loss_fn(mem_rec[step], spk_targets_it) # Loss at each time step is added to give total loss.

            avg_train_loss += loss
            
            _, predicted = spk_rec.sum(dim=0).max(1) 
            correct += (predicted == targets_it).type(torch.float).sum().item()

            # Backpropagation
            loss.backward()
            optimizer.step()

        avg_train_loss /= train_num_batches
        accuracy = correct / train_size * 100      
        history['avg_train_loss'].append(avg_train_loss.item())
        history['train_accuracy'].append(accuracy)
        
        if verbose == True: 
            print(f"Epoch {t+1} of {epochs}")
            print('-' * 15)
            print(f"Training Results, Epoch {t+1}:\n Accuracy: {(accuracy):>0.1f}%, Avg loss: {avg_train_loss.item():>8f} \n")

              ###################### VALIDATION LOOP ##############################
        model.eval()
        with torch.no_grad():
            for valid_data_it, valid_targets_it in valid_loader:
                valid_data_it = valid_data_it.to(device)
                valid_targets_it = valid_targets_it.to(device)
                
                valid_spike_data = spikegen.rate(valid_data_it, num_steps=num_steps, gain=1, offset=0)
                valid_spk_targets_it = torch.clamp(spikegen.to_one_hot(targets_it, 10) * 1.05, min=0.05)

                valid_spk_rec, valid_mem_rec = model(valid_spike_data.view(num_steps,batch_size, 1, resolution, resolution)) 
                
                valid_loss = torch.zeros((1),dtype = dtype, device = device)    
                for step in range(num_steps):
                    valid_loss += loss_fn(valid_mem_rec[step], valid_spk_targets_it)
                
                avg_valid_loss += valid_loss
                
                
                _, valid_predicted = valid_spk_rec.sum(dim=0).max(1)
                valid_correct += (valid_predicted == valid_targets_it).type(torch.float).sum().item()
        

        avg_valid_loss /= valid_num_batches
        valid_accuracy = valid_correct / valid_size * 100
              
        history['avg_valid_loss'].append(avg_valid_loss.item())
        history['valid_accuracy'].append(valid_accuracy)
        
        if verbose == True: 
            print(f"Validation Results, Epoch {t+1}: \n Accuracy: {(valid_accuracy):>0.1f}%, Avg loss: {avg_valid_loss.item():>8f} \n")


    print("Done!")
    print(f"Final Train Accuracy: {(accuracy):>0.1f}%, and Avg loss: {avg_train_loss.item():>8f} \n")
    print(f"Final Validation Accuracy: {(valid_accuracy):>0.1f}%, and Avg loss: {avg_valid_loss.item():>8f} \n")
    current_time = time.time()
    total = current_time - start_time
    print(f'Training time: {round(total/60,2)} minutes')
    return history

def get_rate_mse_snn_results(resolution,epochs = 20,ratio = 1, slope = 25, verbose = True):
    spike_grad = surrogate.fast_sigmoid(slope = slope)
    train, valid = load_in_data(resolution, ratio)
    model = CSNN(spike_grad).to(device)

    output = train_rate_spiking_mse_model(resolution, train,valid,model,epochs, verbose = verbose)
    
    return output


# Training models

## 56x56

* Note: These models were run on Google Colab due to GPU constraints. 

In [8]:
conv_kernel_size = 3
conv_stride_length = 1 
conv_padding_size = 1
mp_kernel_size = 2 
mp_stride_length = 2 
mp_padding_size = 0

In [None]:
output_sizes = all_output_sizes(56, conv_kernel_size, conv_padding_size, conv_stride_length, mp_kernel_size, mp_padding_size, mp_stride_length)

In [None]:
batch_size = 128
output_56_r1 = get_rate_mse_snn_results(resolution = 56, ratio = 1, epochs = 75, slope = 5, verbose = False)

In [None]:
batch_size = 128
output_56_r4 = get_rate_mse_snn_results(resolution = 56, ratio = 4, epochs = 75, slope = 5, verbose = False)

In [None]:
batch_size = 32
output_56_r10 = get_rate_mse_snn_results(resolution = 56, ratio = 10, epochs = 75, slope = 5, verbose = False)

In [None]:
batch_size = 32
output_56_r100 = get_rate_mse_snn_results(resolution = 56, ratio = 100, epochs = 75, slope = 5, verbose = False)

## 28x28

In [9]:
output_sizes = all_output_sizes(28, conv_kernel_size, conv_padding_size, conv_stride_length, mp_kernel_size, mp_padding_size, mp_stride_length)

In [10]:
batch_size = 128
output_28_r1 = get_rate_mse_snn_results(resolution = 28, ratio = 1, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 99.8%, and Avg loss: 0.850419 

Final Validation Accuracy: 99.3%, and Avg loss: 3.886898 

Training time: 436.13 minutes


In [11]:
batch_size = 128
output_28_r4 = get_rate_mse_snn_results(resolution = 28, ratio = 4, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 99.8%, and Avg loss: 0.846046 

Final Validation Accuracy: 99.0%, and Avg loss: 3.835765 

Training time: 157.47 minutes


In [12]:
batch_size = 32
output_28_r10 = get_rate_mse_snn_results(resolution = 28, ratio = 10, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 99.7%, and Avg loss: 0.870913 

Final Validation Accuracy: 98.7%, and Avg loss: 3.864564 

Training time: 89.04 minutes


In [13]:
batch_size = 32
output_28_r100 = get_rate_mse_snn_results(resolution = 28, ratio = 100, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 96.0%, and Avg loss: 1.014356 

Final Validation Accuracy: 96.4%, and Avg loss: 3.751583 

Training time: 42.8 minutes


# 14x14

In [15]:
conv_kernel_size = 3
conv_stride_length = 1 
conv_padding_size = 1
mp_kernel_size = 3
mp_stride_length = 1
mp_padding_size = 0

In [16]:
output_sizes = all_output_sizes(14, conv_kernel_size, conv_padding_size, conv_stride_length, mp_kernel_size, mp_padding_size, mp_stride_length)

In [17]:
batch_size = 128
output_14_r1 = get_rate_mse_snn_results(resolution = 14, ratio = 1, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 98.2%, and Avg loss: 1.008348 

Final Validation Accuracy: 98.1%, and Avg loss: 3.716639 

Training time: 670.86 minutes


In [18]:
batch_size = 128
output_14_r4 = get_rate_mse_snn_results(resolution = 14, ratio = 4, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 98.8%, and Avg loss: 0.987170 

Final Validation Accuracy: 97.7%, and Avg loss: 3.735347 

Training time: 159.39 minutes


In [19]:
batch_size = 32
output_14_r10 = get_rate_mse_snn_results(resolution = 14, ratio = 10, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 98.2%, and Avg loss: 1.019498 

Final Validation Accuracy: 97.4%, and Avg loss: 3.701958 

Training time: 111.47 minutes


In [20]:
batch_size = 32
output_14_r100 = get_rate_mse_snn_results(resolution = 14, ratio = 100, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 94.8%, and Avg loss: 1.151304 

Final Validation Accuracy: 92.6%, and Avg loss: 3.596613 

Training time: 49.8 minutes


## 7x7

In [22]:
conv_kernel_size = 3
conv_stride_length = 1 
conv_padding_size = 1
mp_kernel_size = 2 
mp_stride_length = 1
mp_padding_size = 0

In [23]:
output_sizes = all_output_sizes(7, conv_kernel_size, conv_padding_size, conv_stride_length, mp_kernel_size, mp_padding_size, mp_stride_length)

In [24]:
output_sizes

[6, 6, 5, 5, 4]

In [25]:
batch_size = 128
output_7_r1 = get_rate_mse_snn_results(resolution = 7, ratio = 1, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 85.7%, and Avg loss: 1.518860 

Final Validation Accuracy: 86.7%, and Avg loss: 3.206643 

Training time: 236.97 minutes


In [27]:
batch_size = 128
output_7_r4 = get_rate_mse_snn_results(resolution = 7, ratio = 4, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 84.5%, and Avg loss: 1.475704 

Final Validation Accuracy: 83.9%, and Avg loss: 3.258993 

Training time: 104.75 minutes


In [28]:
batch_size = 32
output_7_r10 = get_rate_mse_snn_results(resolution = 7, ratio = 10, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 81.2%, and Avg loss: 1.507865 

Final Validation Accuracy: 81.0%, and Avg loss: 3.231734 

Training time: 62.57 minutes


In [29]:
batch_size = 32
output_7_r100 = get_rate_mse_snn_results(resolution = 7, ratio = 100, epochs = 75, slope = 5, verbose = False)

Starting Training
Done!
Final Train Accuracy: 70.7%, and Avg loss: 1.716163 

Final Validation Accuracy: 69.6%, and Avg loss: 3.029347 

Training time: 23.58 minutes


# Saving Results

In [33]:
output_res = ['output_28', 'output_14','output_7']
output_ratio = ['_r1','_r4','_r10','_r100']
index = ['avg_train_loss', 'train_accuracy', 'avg_valid_loss', 'valid_accuracy']

In [34]:
all_columns = []
all_models = []
for name in output_res:
    for ratio in output_ratio: 
        model_name = name + ratio
        all_models.append(model_name)
        for indice in index: 
            column_name = name + ratio + '_' + indice
            all_columns.append(column_name)

In [35]:
df = pd.DataFrame()
for entry in all_models:
    for key in index:
        string = entry + '_' + key
        
        df[string] = locals()[entry][key]

In [37]:
df.to_csv('all_scnn_training_histories_updated.csv')