<a href="https://colab.research.google.com/github/andrewsiyoon/spiking-seRNN/blob/main/snnLeaky_learnable_betas.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install snntorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting snntorch
  Downloading snntorch-0.5.3-py2.py3-none-any.whl (95 kB)
[K     |████████████████████████████████| 95 kB 2.5 MB/s 
Installing collected packages: snntorch
Successfully installed snntorch-0.5.3


In [7]:
# Imports -----

import itertools

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import snntorch as snn
import snntorch.functional as SF
from snntorch import spikegen
from snntorch import spikeplot as splt
from snntorch import surrogate

In [9]:
# Define Network -----

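# This cell defines a spiking network with learnable, per-neuron membrane decay rates (beta), modeling neural heterogeneity.
# beta relates to the membrane time constant tau via beta = exp(-dt/tau). It is initialized uniformly in [0, 1) and kept within [0, 1] during training, as in the Goodman paper (snntorch clamps beta to [0, 1] when updating the membrane potential).
# All neurons in a layer could instead share one value by passing a scalar float (e.g. beta1 = 0.9); torch.rand is used so that each neuron gets its own initial value.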

class Net(nn.Module):
    """Two-layer feed-forward spiking network with learnable per-neuron decay rates.

    Architecture:
        Linear(num_inputs -> num_hidden) -> Leaky -> Linear(num_hidden -> num_outputs) -> Leaky

    Each Leaky layer receives one beta value per neuron (drawn with torch.rand, i.e. in
    [0, 1)) and ``learn_beta=True``, so the decay rates are trained with the weights.

    Args:
        num_inputs: Number of features in each flattened input sample.
        num_hidden: Number of hidden spiking neurons.
        num_outputs: Number of output spiking neurons.
        num_steps: Number of time steps ``forward`` simulates. Stored as
            ``self.num_steps`` and may be reassigned after construction.

    Raises:
        ValueError: If ``num_steps`` is less than 1 (at least one step is needed to
            stack the recordings).
    """

    def __init__(self, num_inputs=784, num_hidden=300, num_outputs=10, num_steps=100):
        super().__init__()

        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        # Owned by the model so forward() does not depend on a global defined in another cell.
        self.num_steps = num_steps

        spike_grad = surrogate.fast_sigmoid()

        # Independent decay rate per neuron, initialized uniformly in [0, 1).
        # Drawn before the Linear layers are built, preserving the original RNG draw order.
        # These become module parameters via learn_beta=True, so net.to(device) moves them.
        beta1 = torch.rand(num_hidden, dtype=torch.float)
        beta2 = torch.rand(num_outputs, dtype=torch.float)

        # fc1 projects the input to hidden currents; lif1 integrates them and spikes at threshold.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta1, spike_grad=spike_grad, learn_beta=True)
        # fc2 projects hidden spikes to output currents; lif2 integrates them over time.
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta2, spike_grad=spike_grad, learn_beta=True)

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps on a static input.

        Args:
            x: Input batch. Everything from dim 1 onward is flattened, so the trailing
                dimensions must multiply to ``num_inputs`` (e.g. (batch, 1, 28, 28)
                for 784 inputs).

        Returns:
            Tuple ``(spk2_rec, mem2_rec)``: output-layer spikes and membrane potentials,
            each stacked along a new leading time dimension of size ``self.num_steps``.
        """
        # Reset membrane potentials at t=0.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # The same input is presented at every step, so its projection is
        # computed once rather than num_steps times.
        cur1 = self.fc1(x.flatten(1))

        # Record the output layer at every step.
        spk2_rec = []
        mem2_rec = []

        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

In [None]:
# Setup: device, data, model, loss function and optimization -----

# Seed torch so the random per-neuron beta initialization and the weight init are reproducible.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): this notebook has no data-loading cell, yet the training loop reads
# `train_loader`. Net's sizes (784 = 28*28 inputs, 10 outputs) and the torchvision
# imports suggest MNIST, so it is loaded here — confirm this is the intended dataset.
batch_size = 128
data_path = "./data/mnist"
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

# Net was defined above but never instantiated; the optimizer needs its parameters.
net = Net().to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
# Spike-count MSE loss: the correct class targets a firing rate of 0.8, the others 0.2
# (see snntorch.functional.mse_count_loss).
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [None]:
# Training paradigm -----

num_epochs = 1  # Increase to train for more than one epoch
num_steps = 100  # Number of simulation time steps per forward pass
log_every = 25  # Print loss/accuracy every `log_every` iterations

# Propagate num_steps to the model so its unroll length matches this cell.
net.num_steps = num_steps

loss_hist = []
acc_hist = []

# Training loop
for epoch in range(num_epochs):
    # Only training-mode forward passes happen below, so set the mode once per epoch.
    net.train()
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        spk_rec, _ = net(data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss = loss_val.item()
        loss_hist.append(loss)

        if i % log_every == 0:
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss:.2f}")

            # Accuracy on the current training batch, from the forward pass above
            acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")