<a href="https://colab.research.google.com/github/andrewsiyoon/spiking-seRNN/blob/main/controlSNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install snntorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting snntorch
  Downloading snntorch-0.5.3-py2.py3-none-any.whl (95 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.5/95.5 KB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: snntorch
Successfully installed snntorch-0.5.3


In [2]:
#Imports -----

import torch, torch.nn as nn
import snntorch as snn

In [3]:
#Dataloading configuration -----

batch_size = 128                 #samples per mini-batch
data_path = '/data/mnist'        #passed to datasets.MNIST as the dataset location
dtype = torch.float32            #torch.float is an alias of torch.float32

#Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Preprocessing applied to every image: resize to 28x28, collapse to a single
# grayscale channel, convert to a tensor, then Normalize with mean 0 / std 1.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

# download=True lets torchvision fetch MNIST into data_path when it is missing.
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Training batches are reshuffled each epoch. The test loader keeps a fixed
# order so evaluation passes see the same batches every time.
# Note: the last batch of each loader may be smaller than batch_size
# (drop_last is left at its default).
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [10]:
#Define network -----

from snntorch import surrogate
import torch.nn.functional as F

#Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network: Linear -> Leaky -> Linear -> Leaky.

    Each Leaky layer receives its own randomly drawn per-neuron ``beta`` tensor
    and is created with ``learn_beta=True``.

    Args:
        num_inputs: Number of features per sample after flattening (default 784).
        num_hidden: Number of neurons in the hidden layer (default 1000).
        num_outputs: Number of neurons in the output layer (default 10).
    """

    def __init__(self, num_inputs=784, num_hidden=1000, num_outputs=10):
        super().__init__()

        spike_grad = surrogate.fast_sigmoid()

        #Heterogeneous membrane time constants: one beta per neuron, drawn from [0, 1)
        beta1 = torch.rand(num_hidden, dtype=torch.float)
        beta2 = torch.rand(num_outputs, dtype=torch.float)

        #Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta1, spike_grad=spike_grad, learn_beta=True)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta2, spike_grad=spike_grad, learn_beta=True)

    def forward(self, x, steps=None):
        """Present the same input for a number of time steps and record the output layer.

        Args:
            x: Input batch; every dimension after the first is flattened before
                the first linear layer.
            steps: Number of simulation steps. When omitted, the module-level
                ``num_steps`` variable is used, so existing ``net(data)`` calls
                behave as before.

        Returns:
            A tuple ``(spk2_rec, mem2_rec)``: the output-layer spikes and membrane
            potentials of every step, stacked along a new leading time dimension.
        """
        if steps is None:
            #Backward-compatible default: read the notebook-level setting
            steps = num_steps

        #Initialize hidden states and outputs
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        #Record the final layer
        spk2_rec = []
        mem2_rec = []

        #The input is identical at every step, so its current is computed once
        #instead of repeating the same matrix product inside the loop.
        cur1 = self.fc1(x.flatten(1))

        for _ in range(steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

# Build the network, then place it on `device` (CUDA if available, else CPU)
net = Net()
net = net.to(device)

In [11]:
#Optimizer and loss function -----
import snntorch.functional as SF

# Adam over everything net.parameters() yields
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=2e-3,
    betas=(0.9, 0.999),
)

# Spike-count MSE loss from snntorch.functional.
# NOTE(review): the two rates are presumably the target firing fractions for
# the correct / incorrect output neurons - confirm against the snnTorch docs.
loss_fn = SF.mse_count_loss(
    correct_rate=0.75,
    incorrect_rate=0.25,
)

In [12]:
#Training loop -----
num_epochs = 2
num_steps = 25  #simulation steps per forward pass; Net.forward reads this global

loss_hist = []  #loss of every batch
acc_hist = []   #batch accuracy, sampled every 25 iterations

for epoch in range(num_epochs):
    #Training mode only has to be selected once per epoch, not once per batch
    net.train()

    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        #Forward pass; only the output spikes are needed here
        spk_rec, _ = net(data) #spk_rec = "outputs" in other documents

        reg_loss = 1e-5 * torch.sum(spk_rec) #L1 loss on total number of spikes
        loss_val = loss_fn(spk_rec, targets) + reg_loss

        #Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        #Store loss history
        batch_loss = loss_val.item()
        loss_hist.append(batch_loss)

        #Report every 25 iterations. The accuracy reuses the spikes of the
        #training forward pass above, so no mode switch is involved.
        if i % 25 == 0:
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {batch_loss:.2f}")

            #Accuracy on the current training batch; no autograd graph is needed
            with torch.no_grad():
                acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")

Epoch 0, Iteration 0 
Train Loss: 2.57
Accuracy: 3.12%

Epoch 0, Iteration 25 
Train Loss: 0.53
Accuracy: 56.25%

Epoch 0, Iteration 50 
Train Loss: 0.35
Accuracy: 82.03%

Epoch 0, Iteration 75 
Train Loss: 0.28
Accuracy: 85.16%

Epoch 0, Iteration 100 
Train Loss: 0.25
Accuracy: 89.06%

Epoch 0, Iteration 125 
Train Loss: 0.23
Accuracy: 93.75%

Epoch 0, Iteration 150 
Train Loss: 0.25
Accuracy: 89.06%

Epoch 0, Iteration 175 
Train Loss: 0.21
Accuracy: 93.75%

Epoch 0, Iteration 200 
Train Loss: 0.22
Accuracy: 93.75%

Epoch 0, Iteration 225 
Train Loss: 0.24
Accuracy: 88.28%

Epoch 0, Iteration 250 
Train Loss: 0.21
Accuracy: 88.28%

Epoch 0, Iteration 275 
Train Loss: 0.25
Accuracy: 92.19%

Epoch 0, Iteration 300 
Train Loss: 0.21
Accuracy: 92.19%

Epoch 0, Iteration 325 
Train Loss: 0.17
Accuracy: 95.31%

Epoch 0, Iteration 350 
Train Loss: 0.19
Accuracy: 93.75%

Epoch 0, Iteration 375 
Train Loss: 0.19
Accuracy: 91.41%

Epoch 0, Iteration 400 
Train Loss: 0.19
Accuracy: 95.31%

Epo

In [15]:
# spk_rec still holds the output spikes of the last training batch; the
# stacked recording is laid out as (time steps, batch, output neurons).
print(spk_rec.shape)

torch.Size([25, 96, 10])
