In [None]:
# Main Setup

# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools
import gc

# --- Data pipeline ------------------------------------------------------
# Loader settings and the on-disk cache location for MNIST.
batch_size = 128
data_path = '/tmp/data/mnist'

dtype = torch.float

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

# Image preprocessing. MNIST images are natively 28x28 grayscale, so Resize
# and Grayscale should be no-ops here; Normalize with mean 0 / std 1 leaves
# the values unchanged (ToTensor already scales pixels to [0, 1]).
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Train and test splits, downloaded on first use (train split is built first).
mnist_train, mnist_test = (
    datasets.MNIST(data_path, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Both loaders shuffle and drop the final partial batch, so every batch holds
# exactly `batch_size` samples.
# NOTE(review): drop_last=True on the test loader skips the trailing test
# images whenever the split size is not a multiple of batch_size; kept as-is
# in case downstream code relies on a fixed batch size.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True, drop_last=True)
    for split in (mnist_train, mnist_test)
)

# Network architecture: flattened 28x28 input -> hidden layer -> one output
# neuron per digit class.
num_inputs, num_hidden, num_outputs = 28 * 28, 1000, 10

# Temporal dynamics: simulation steps per forward pass, and the membrane
# decay factor handed to both spiking layers.
num_steps, beta = 25, 0.70

# Candidate self-connection weights for the output layer:
#   V1 -- a single scalar shared by every output neuron;
#   V2 -- an independent uniform-random weight in [0, 1) per output neuron.
V1 = 0.5
V2 = torch.rand(num_outputs)

# Define Network
class Net(nn.Module):
    """Two-layer recurrent spiking network built from snntorch ``RLeaky`` neurons.

    Pipeline per time step:
        fc1 (num_inputs -> num_hidden) -> lif1 (RLeaky, all-to-all recurrence)
        -> fc2 (num_hidden -> num_outputs) -> lif2 (RLeaky, one self-connection
        per neuron scaled by ``V1``).

    Layer sizes, ``beta`` and ``num_steps`` are read from module-level globals.
    """

    def __init__(self):
        super().__init__()

        # initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)

        # Default RLeaky layer: the recurrent connections are an all-to-all
        # linear layer initialized using PyTorch's nn.Linear defaults.
        self.lif1 = snn.RLeaky(beta=beta, linear_features=num_hidden)

        self.fc2 = nn.Linear(num_hidden, num_outputs)

        # With `all_to_all=False` each neuron has a single connection back to
        # itself, where the output spike is scaled by V. V can be shared
        # between neurons (e.g., V1) or unique / unshared per neuron
        # (e.g., V2). V is learnable by default.
        self.lif2 = snn.RLeaky(beta=beta, all_to_all=False, V=V1)

    def forward(self, x):
        """Simulate the network for ``num_steps`` time steps on a static input.

        Args:
            x: Input fed to ``fc1``; the same tensor drives every time step.
                Presumably a batch of flattened images shaped
                ``(batch, num_inputs)`` -- TODO confirm against the caller.

        Returns:
            Tuple ``(spk2_rec, mem2_rec)``: output-layer spikes and membrane
            potentials, each stacked along a new leading time dimension of
            length ``num_steps``.
        """
        # Fresh hidden states at t=0 for every forward call.
        spk1, mem1 = self.lif1.init_rleaky()
        spk2, mem2 = self.lif2.init_rleaky()

        # `x` is identical at every time step, so its projection through fc1 is
        # too: compute it once here instead of `num_steps` times in the loop.
        cur1 = self.fc1(x)

        # Record output layer spikes and membrane
        spk2_rec = []
        mem2_rec = []

        # time-loop
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, spk1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, spk2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        # Stack the per-step recordings along a new leading time dimension.
        return torch.stack(spk2_rec), torch.stack(mem2_rec)
