<a href="https://colab.research.google.com/github/andrewsiyoon/spiking-seRNN/blob/main/RSNN_no_modifications.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Imports -----

import torch, torch.nn as nn
import snntorch as snn
import random

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#Random seeds -----

#One seed value shared by every RNG this notebook uses (Python's `random`, then torch)
SEED = 211
for seed_fn in (random.seed, torch.manual_seed):
    seed_fn(SEED)

In [None]:
#MNIST -----

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

#Define variables
batch_size = 128
data_path = './' #Root directory MNIST is downloaded to / loaded from (used for both splits below)
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#Define transformations
#NOTE(review): MNIST images are already 28x28 single-channel, so Resize/Grayscale appear to be
#no-ops here, and Normalize((0,), (1,)) leaves values unchanged (subtract 0, divide by 1).
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

#Define training and test sets
#download=True lets torchvision fetch MNIST into data_path when it is missing, so no separate
#shell download/extract step is needed (a `!wget` step fails wherever wget is not installed).
mnist_train = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

#Create DataLoaders
#drop_last=True removes the final partial batch so every batch has exactly batch_size samples.
#With batch_size=128 this skips 60000 % 128 = 96 training and 10000 % 128 = 16 test samples per pass.
trainloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
testloader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

/bin/bash: wget: command not found
tar: Error opening archive: Failed to open 'MNIST.tar.gz'


In [None]:
#Dimensional analysis -----

#Number of samples in each split (train first, then test)
for split in (mnist_train, mnist_test):
    print(len(split))

#Shapes of one training batch: images, then labels
examples = enumerate(trainloader)
batch_idx, (example_data, example_targets) = next(examples)
for batch_tensor in (example_data, example_targets):
    print(batch_tensor.shape)

#A second draw from a fresh (reshuffled) iterator over the training set
event_tensor, target = next(iter(trainloader))
print(event_tensor.shape)

60000
10000
torch.Size([128, 1, 28, 28])
torch.Size([128])
torch.Size([128, 1, 28, 28])


In [None]:
#Model architecture -----

#Size parameters
num_inputs = 28*28 #Flattened 28x28 image
num_hidden = 1000
num_outputs = 10 #One output neuron per class

#Network parameters
beta = 0.95 #Membrane decay rate, shared by both LIF layers
num_steps = 25 #Default number of simulation time steps per forward pass

#Model definition
class Net(nn.Module):
    """Spiking network: Linear -> recurrent LIF (RLeaky) -> Linear -> LIF (Leaky).

    The same input is presented at every time step; the output layer's spikes and
    membrane potentials are recorded at each step.
    """

    def __init__(self):
        super().__init__()

        #Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        #linear_features gives RLeaky a learnable num_hidden x num_hidden recurrent layer (lif1.recurrent)
        self.lif1 = snn.RLeaky(beta = beta, linear_features = num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta = beta)

    def forward(self, x, n_steps=None):
        """Simulate the network for a number of time steps.

        Args:
            x: input batch whose last dimension is num_inputs (the training loop
               passes images flattened to [batch, num_inputs]).
            n_steps: number of time steps to simulate. Defaults to the module-level
               ``num_steps``, read at call time (the original behaviour).

        Returns:
            (spk_rec, mem_rec): output-layer spikes and membrane potentials, each
            stacked along a new leading time dimension of length n_steps.
        """
        if n_steps is None:
            n_steps = num_steps

        #Initialize states
        spk1, mem1 = self.lif1.init_rleaky() #init_rleaky() returns the initial (spk, mem) pair, unpacked into spk1, mem1
        mem2 = self.lif2.init_leaky() #init_leaky() returns the initial output-layer membrane potential

        #x is identical at every time step, so its projection is computed once, outside the loop
        cur1 = self.fc1(x)

        #Record final layer
        spk_rec = []
        mem_rec = []

        #Forward loop
        for _ in range(n_steps):
            spk1, mem1 = self.lif1(cur1, spk1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk_rec.append(spk2)
            mem_rec.append(mem2)

        #Convert per-step lists to [n_steps, ...] tensors
        return torch.stack(spk_rec), torch.stack(mem_rec)

#Place parameters on the same device the training loop sends each batch to
net = Net().to(device)

In [None]:
#Model visualizations -----

#Every learnable tensor: its name, full size, and its first two entries along dim 0
for layer_name, tensor in net.named_parameters():
    preview = tensor[:2]
    print(f"Layer: {layer_name} | Size: {tensor.size()} | Values : {preview} \n")

#TODO: Dataset dimensions


#TODO: Visual map


Layer: fc1.weight | Size: torch.Size([1000, 784]) | Values : tensor([[-0.0191, -0.0323, -0.0147,  ...,  0.0102, -0.0193,  0.0085],
        [ 0.0055, -0.0354,  0.0019,  ..., -0.0184, -0.0155,  0.0227]],
       grad_fn=<SliceBackward0>) 

Layer: fc1.bias | Size: torch.Size([1000]) | Values : tensor([-0.0323,  0.0102], grad_fn=<SliceBackward0>) 

Layer: lif1.recurrent.weight | Size: torch.Size([1000, 1000]) | Values : tensor([[-0.0070, -0.0091, -0.0068,  ..., -0.0155,  0.0268,  0.0038],
        [ 0.0153,  0.0143,  0.0237,  ...,  0.0118,  0.0283,  0.0158]],
       grad_fn=<SliceBackward0>) 

Layer: lif1.recurrent.bias | Size: torch.Size([1000]) | Values : tensor([-0.0301,  0.0288], grad_fn=<SliceBackward0>) 

Layer: fc2.weight | Size: torch.Size([10, 1000]) | Values : tensor([[ 0.0164, -0.0059, -0.0019,  ..., -0.0150,  0.0082,  0.0209],
        [-0.0230,  0.0164, -0.0230,  ..., -0.0090,  0.0147,  0.0184]],
       grad_fn=<SliceBackward0>) 

Layer: fc2.bias | Size: torch.Size([10]) | Values

In [None]:
#Optimizer and loss function -----

import snntorch.functional as SF

#Adam over every network parameter; the betas are written out explicitly
adam_settings = dict(lr=2e-3, betas=(0.9, 0.999))
optimizer = torch.optim.Adam(net.parameters(), **adam_settings)

#Spike-count MSE loss. Per the snntorch docs, correct_rate / incorrect_rate are the target
#firing frequencies (fraction of time steps) for the correct and incorrect class neurons.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [None]:
#Training paradigm -----

#Training parameters
num_epochs = 2
counter = 0
log_every = 25 #Iterations between loss/accuracy printouts
#NOTE: the simulation length `num_steps` is set once, alongside the model architecture, and is
#intentionally not redefined here so the two values cannot silently disagree.

#Initialize loss and accuracy 
loss_hist = []
acc_hist = []

#Training loop
for epoch in range(num_epochs):
    #Set model to training mode (nothing inside the loop switches it out again)
    net.train()

    for i, (data, targets) in enumerate(trainloader):

        #Load on CUDA (if available)
        data = data.to(device)
        targets = targets.to(device)

        #Flatten each image to a vector, using the batch's actual size rather than the global
        #batch_size so a partial final batch would still reshape correctly
        outputs, _ = net(data.view(data.size(0), -1))

        #Calculate loss
        loss_val = loss_fn(outputs, targets)

        #Gradient calculation and weight updates
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        #Store loss history
        loss_hist.append(loss_val.item())

        #Periodic logging
        if i % log_every == 0:
            #Print training loss of the current batch
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_hist[-1]:.2f}")

            #Accuracy on the current training batch, reusing the outputs computed above
            #(Outputs: [num_steps, batch_size, num_outputs]. Targets: [batch_size])
            acc = SF.accuracy_rate(outputs, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")

Epoch 0, Iteration 0 
Train Loss: 2.41
Accuracy: 7.03%

Epoch 0, Iteration 25 
Train Loss: 0.57
Accuracy: 57.81%

Epoch 0, Iteration 50 
Train Loss: 0.32
Accuracy: 81.25%

Epoch 0, Iteration 75 
Train Loss: 0.21
Accuracy: 89.06%

Epoch 0, Iteration 100 
Train Loss: 0.15
Accuracy: 92.97%

Epoch 0, Iteration 125 
Train Loss: 0.19
Accuracy: 92.19%

Epoch 0, Iteration 150 
Train Loss: 0.19
Accuracy: 86.72%

Epoch 0, Iteration 175 
Train Loss: 0.14
Accuracy: 94.53%

Epoch 0, Iteration 200 
Train Loss: 0.13
Accuracy: 95.31%

Epoch 0, Iteration 225 
Train Loss: 0.17
Accuracy: 88.28%

Epoch 0, Iteration 250 
Train Loss: 0.17
Accuracy: 86.72%

Epoch 0, Iteration 275 
Train Loss: 0.11
Accuracy: 94.53%

Epoch 0, Iteration 300 
Train Loss: 0.10
Accuracy: 96.09%

Epoch 0, Iteration 325 
Train Loss: 0.09
Accuracy: 95.31%

Epoch 0, Iteration 350 
Train Loss: 0.11
Accuracy: 94.53%

Epoch 0, Iteration 375 
Train Loss: 0.11
Accuracy: 95.31%

Epoch 0, Iteration 400 
Train Loss: 0.11
Accuracy: 93.75%

Epo