In [1]:
import torch
import torch.nn as layer
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
from trainer import *

from torchvision.transforms import v2

import snntorch as snn
from snntorch import surrogate, functional, BatchNormTT2d

from line_profiler import LineProfiler, profile


In [2]:


dtype=torch.float
print("VGG-block CNN SNN Trained on cifar10")

# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Training Parameters
batch_size=256
data_path='./tmp/data/cifar10/'
num_classes = 10  # cifar has 10 output classes

# Define a transform
transform1 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,)),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomResizedCrop(size=(32, 32),scale=(0.8,1), antialias=True),
            ])

transform2 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,)),
            ])

cifar_train = datasets.CIFAR10(data_path, train=True, download=True,transform=transform1)
cifar_test = datasets.CIFAR10(data_path, train=False, download=True,transform=transform2)

train_loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True,drop_last=False,pin_memory=True,num_workers=1)
test_loader = DataLoader(cifar_test, batch_size=batch_size, shuffle=True,drop_last=False,pin_memory=True,num_workers=1)

print("Train batches:",len(train_loader))
print("Test batches:",len(test_loader))

VGG-block CNN SNN Trained on cifar10
Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Train batches: 196
Test batches: 40


Useful to know what a convolutional layers output dimension is given by $n_{out}=\frac{n_{in}+2p-k}{s}+1$, 
with default padding=0 stride=1 $n_{out}=n_{in}-k+1$.
For max pool with defaults, $\lfloor\frac{n}{2}\rfloor$.

In [3]:


################ DVS Gesture Model #############################

# layer parameters

lr=1e-4

spike_grad1 = surrogate.atan()

num_steps = 11

class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.loss = functional.ce_count_loss()
        self.accuracy_metric = functional.accuracy_rate

        #initialise neuron connections
        self.layers = nn.ModuleList([
            nn.Conv2d(3,32,3,padding=1),
            BatchNormTT2d(32,num_steps),
            nn.Conv2d(32,32,3,padding=1),
            BatchNormTT2d(32,num_steps),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,3,padding=1),
            BatchNormTT2d(64,num_steps),
            nn.Conv2d(64,64,3,padding=1),
            BatchNormTT2d(64,num_steps),
            nn.MaxPool2d(2),
            nn.Linear(4096,256),
            nn.Dropout(0.5),
            nn.Linear(256,11),      
        ])

        # initialize neurons
        self.neurons = nn.ModuleList(
            [snn.Leaky(beta=0.95,threshold=1,spike_grad=spike_grad1)] * len(self.layers)
        )

        self.to(device) #yes, this is needed twice

        #pytorch creates the tensors to represent the network layout and weights for each layer; snntorch provides the model that operates on the entire tensor (at each layer).

    def forward(self,x): #x is input data
        #events should be treated as spikes i.e. already encoded

        # Initialize hidden states
        mem = []
        for i in range(len(self.layers)):
            mem.append(self.neurons[i].init_leaky())
        
        # record spike outputs
        spk_rec = []

        x = x.unsqueeze(0).repeat(num_steps, 1, 1, 1, 1)
        #x_spk = spikegen.rate(x,num_steps=num_steps) 

        for step in range(num_steps):
            #form inputs
            spk_i = x[step]


            for i in range(len(self.layers)):

                if(i==18):
                    spk_i = self.layers[i](spk_i)
                    continue
                elif i in {1,3,6,8}:
                    spk_i = self.layers[i][step](spk_i)
                    continue

                if(i==10): #need to flatten from pooling to Linear
                    spk_i = torch.flatten(spk_i,start_dim=1)

                cur_i = self.layers[i](spk_i)
                spk_i, mem[i] = self.neurons[i](cur_i,mem[i])
                        

            spk_rec.append(spk_i)
            


        return torch.stack(spk_rec, dim=0)
    
###################################################################################
    






In [4]:
'''
net = Net().to(device)
def p():
    optimiser = torch.optim.Adam(net.parameters(),lr=lr,weight_decay=0.001)
    net.train()
    d,t = next(iter(train_loader))
    d = d.to(device)
    t = t.to(device)
    logits = net(d)
    optimiser.zero_grad()
    loss = net.loss(logits,t)
    loss.backward()
    optimiser.step()


profiler = LineProfiler()
profiler.add_function(p)
profiler.add_function(net.forward)

profiler.run('p()')
profiler.print_stats()
'''

"\nnet = Net().to(device)\ndef p():\n    optimiser = torch.optim.Adam(net.parameters(),lr=lr,weight_decay=0.001)\n    net.train()\n    d,t = next(iter(train_loader))\n    d = d.to(device)\n    t = t.to(device)\n    logits = net(d)\n    optimiser.zero_grad()\n    loss = net.loss(logits,t)\n    loss.backward()\n    optimiser.step()\n\n\nprofiler = LineProfiler()\nprofiler.add_function(p)\nprofiler.add_function(net.forward)\n\nprofiler.run('p()')\nprofiler.print_stats()\n"

In [5]:
model_path = "./models/SNN_VGG_CIFAR10.pt"
net = Net()
optimiser = torch.optim.Adam(net.parameters(),lr=lr,weight_decay=0)
gen_reset() #reset the PRNG generators for the random samplers so we consistently get the same sequence of samples for each experiment run
net = trainer(net,train_loader=train_loader,valid_loader=test_loader,model_path=model_path,optimiser=optimiser,epochs=10,iterations=None,valid_after=250,valid_iterations=1,deepr=False,device=device)
gen_reset()
a = test_stats(net,test_loader=test_loader,iterations=None,device=device)

Training progress::   0%|          | 0/1960 [00:00<?, ?it/s]

Iteration: 0
Training loss: 2.41
Validation loss: 2.40
Validation accuracy: 8.59%
Training accuracy: 11.72%
----------------
Iteration: 250
Training loss: 1.83
Validation loss: 1.67
Validation accuracy: 38.28%
Training accuracy: 32.81%
----------------
Iteration: 500
Training loss: 1.54
Validation loss: 1.44
Validation accuracy: 45.70%
Training accuracy: 41.02%
----------------
Iteration: 750
Training loss: 1.49
Validation loss: 1.39
Validation accuracy: 48.44%
Training accuracy: 44.14%
----------------
Iteration: 1000
Training loss: 1.47
Validation loss: 1.44
Validation accuracy: 47.27%
Training accuracy: 41.41%
----------------


KeyboardInterrupt: 