In [None]:
import torch
import snntorch as snn
import tonic
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader
from tonic import datasets, transforms
from collections import namedtuple
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

# Device configuration: prefer Apple-silicon MPS, then CUDA, then CPU, so the
# notebook still runs on machines that have no MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyperparameters
SAMPLE_T = 128  # Time bins kept per sample
SHD_TIMESTEP = 1e-6  # Time step of the raw SHD event timestamps (1 us)
SHD_CHANNELS = 700  # Number of input channels in the SHD dataset
NET_CHANNELS = 128  # Number of input channels after spatial downsampling
NET_DT = 1 / SAMPLE_T  # Network time step: SAMPLE_T bins of NET_DT span 1 s
BATCH_SIZE = 256  # Batch size
NUM_EPOCHS = 100  # Number of epochs
NUM_HIDDEN = 128  # Number of units in each hidden layer
NEURONS_PER_CLASS = 5  # Output neurons (population) representing each class
NUM_CLASSES = 20  # Number of output classes

# Set up TensorBoard; logs are written under runs/<run_name>.
run_name = f"SNN_popcoding_hidden_{NUM_HIDDEN}_epochs_{NUM_EPOCHS}"
writer = SummaryWriter(log_dir=f"runs/{run_name}")

# Custom transform to rasterize events into frames
class SHD2Raster:
    """Rasterize an event stream into a binary ``(time, channel)`` array.

    Args:
        encoding_dim: number of channel columns; event ``x`` values are used
            directly as column indices.
        sample_T: maximum number of time bins kept; events with
            ``t >= sample_T`` are discarded.

    Calling the instance with a structured array that has integer fields
    ``"t"`` and ``"x"`` returns an int array of shape
    ``(min(max(t) + 1, sample_T), encoding_dim)``. An entry is 1 where at least
    one event fell in that (time, channel) bin and 0 elsewhere. An empty event
    array yields shape ``(0, encoding_dim)`` instead of raising.
    """

    def __init__(self, encoding_dim, sample_T=100):
        self.encoding_dim = encoding_dim
        self.sample_T = sample_T

    def __call__(self, events):
        t, x = events["t"], events["x"]
        if t.size == 0:
            # max() of an empty array would raise; there is nothing to bin.
            return np.zeros((0, self.encoding_dim), dtype=int)

        num_bins = min(int(t.max()) + 1, self.sample_T)
        # Allocate only the bins that are kept, rather than the whole recording
        # length followed by a crop.
        raster = np.zeros((num_bins, self.encoding_dim), dtype=int)
        keep = t < num_bins
        # Fancy-index assignment writes 1 for every (t, x) pair, duplicates
        # included. This matches counting with np.add.at and clipping to 1.
        raster[t[keep], x[keep]] = 1
        return raster

# Data transformation pipeline:
#   1. Downsample multiplies event timestamps by SHD_TIMESTEP / NET_DT and
#      channel indices by NET_CHANNELS / SHD_CHANNELS.
#   2. SHD2Raster bins the events into a binary (time, channel) raster that
#      keeps at most SAMPLE_T time steps.
downsample = transforms.Downsample(
    time_factor=SHD_TIMESTEP / NET_DT,
    spatial_factor=NET_CHANNELS / SHD_CHANNELS,
)
rasterize = SHD2Raster(NET_CHANNELS, sample_T=SAMPLE_T)
transform = transforms.Compose([downsample, rasterize])

# Load both SHD splits from ../data with the same transform (train first, then test).
train_dataset, test_dataset = (
    datasets.SHD("../data", train=is_train, transform=transform)
    for is_train in (True, False)
)

# Function to shuffle and batch data
State = namedtuple("State", "obs labels")


def shuffle(dataset, batch_size=None):
    """Randomly permute a dataset and split it into equal-size batches.

    Args:
        dataset: ``(x, y)`` pair of tensors whose first dimension indexes samples.
        batch_size: samples per batch. Defaults to the module-level
            ``BATCH_SIZE``, looked up at call time.

    Returns:
        ``State(obs, labels)``. ``obs`` has shape
        ``(num_batches, batch_size, *x.shape[1:])`` and ``labels`` has shape
        ``(num_batches, batch_size)``. The ``len % batch_size`` samples left
        over after permuting are dropped, so each call drops a different random
        subset. Fewer samples than one batch yields zero batches.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    x, y = dataset
    num_samples = y.shape[0]
    num_batches = num_samples // batch_size
    # Slice by the number of samples to keep. Slicing with [:-remainder]
    # would return nothing when the remainder is 0.
    indices = torch.randperm(num_samples)[: num_batches * batch_size]
    x, y = x[indices], y[indices]
    # Explicit batch count (not -1) so reshaping zero samples is well-defined.
    x = torch.reshape(x, (num_batches, batch_size) + tuple(x.shape[1:]))
    y = torch.reshape(y, (num_batches, batch_size))
    return State(x, y)

# Prepare the training and testing data. Each split is collated into one padded
# batch and moved to `device` once, so training never touches the DataLoader.
def _load_full_split(dataset):
    """Return ``(x, y)`` holding every sample of *dataset*, padded and on ``device``.

    Inputs are cast to uint8 to keep the whole split small in memory. Labels
    are cast to int64 (``torch.long``), the dtype ``CrossEntropyLoss``
    documents for class-index targets.
    """
    loader = DataLoader(dataset, batch_size=len(dataset),
                        collate_fn=tonic.collation.PadTensors(batch_first=True),
                        drop_last=True, shuffle=False)
    x, y = next(iter(loader))
    return x.to(device=device, dtype=torch.uint8), y.to(device=device, dtype=torch.long)


x_train, y_train = _load_full_split(train_dataset)

# Split the test data in order into BATCH_SIZE chunks and keep the final partial
# chunk. shuffle() would drop len % BATCH_SIZE samples; here every test sample
# is evaluated exactly once. Evaluation does not need shuffling.
x_test, y_test = _load_full_split(test_dataset)
x_test = torch.split(x_test, BATCH_SIZE)
y_test = torch.split(y_test, BATCH_SIZE)

# Spiking network whose output layer uses population coding.
class SNNModel(torch.nn.Module):
    """Four fully connected layers, each followed by a leaky integrate-and-fire layer.

    The output layer has ``num_classes * neurons_per_class`` neurons, i.e. a
    population of ``neurons_per_class`` neurons per class. Every LIF layer
    starts with a per-neuron decay (``beta``) of 0.5 and learns it. The output
    LIF layer uses ``reset_mechanism="none"``.
    """

    def __init__(self, num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, neurons_per_class=NEURONS_PER_CLASS):
        super().__init__()
        self.num_classes = num_classes
        self.neurons_per_class = neurons_per_class
        self.output_size = num_classes * neurons_per_class

        def learnable_lif(width, **extra):
            # One trainable decay per neuron, initialised to 0.5.
            return snn.Leaky(beta=torch.full((width,), 0.5), learn_beta=True, **extra)

        # Creation order matches the original layer order, so parameter
        # initialisation consumes the RNG identically.
        self.fc1 = torch.nn.Linear(NET_CHANNELS, num_hidden)
        self.lif1 = learnable_lif(num_hidden)
        self.fc2 = torch.nn.Linear(num_hidden, num_hidden)
        self.lif2 = learnable_lif(num_hidden)
        self.fc3 = torch.nn.Linear(num_hidden, num_hidden)
        self.lif3 = learnable_lif(num_hidden)
        self.fc4 = torch.nn.Linear(num_hidden, self.output_size)
        self.lif4 = learnable_lif(self.output_size, reset_mechanism="none")

    def forward(self, x):
        """Step the network over dimension 1 of a 3-D input and return the output spikes.

        The result has shape ``(x.shape[0], x.shape[1], output_size)``. Callers
        pass ``(batch, time, channels)`` tensors (assumed; TODO confirm).
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()

        out_spikes = []
        for frame in x.float().permute(1, 0, 2):
            spk1, mem1 = self.lif1(self.fc1(frame), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spk3, mem3 = self.lif3(self.fc3(spk2), mem3)
            spk4, mem4 = self.lif4(self.fc4(spk3), mem4)
            out_spikes.append(spk4)

        # Stack along dim 1 so the per-step outputs come back in input order.
        return torch.stack(out_spikes, dim=1)

# Build the population-coded SNN and move its parameters to the chosen device.
model = SNNModel(
    num_hidden=NUM_HIDDEN,
    num_classes=NUM_CLASSES,
    neurons_per_class=NEURONS_PER_CLASS,
).to(device)

# Cross-entropy with label smoothing of 0.3.
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
# Adam at lr 5e-4; StepLR multiplies the learning rate by 0.9 every 10 scheduler steps.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

# Classification accuracy helper.
def accuracy(predictions, targets):
    """Return the fraction of positions where ``predictions`` equals ``targets``.

    Intended for 1-D label tensors: the denominator is ``len(targets)``.
    The result is a Python float.
    """
    num_correct = (predictions == targets).sum().item()
    num_total = len(targets)
    return num_correct / num_total

# Function to aggregate neuron activities per class
def aggregate_class_activity(spikes, num_classes=None, neurons_per_class=None):
    """Return each class population's total spike count over all time steps.

    Args:
        spikes: tensor of shape ``(batch, time, num_classes * neurons_per_class)``.
            The last dimension is class-major: neurons
            ``[c * neurons_per_class, (c + 1) * neurons_per_class)`` belong to class ``c``.
        num_classes: defaults to ``NUM_CLASSES``, looked up at call time.
        neurons_per_class: defaults to ``NEURONS_PER_CLASS``, looked up at call time.

    Returns:
        Tensor of shape ``(batch, num_classes)``.

    Raises:
        ValueError: if the last dimension of ``spikes`` is not
            ``num_classes * neurons_per_class``.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if neurons_per_class is None:
        neurons_per_class = NEURONS_PER_CLASS

    batch_size, time_steps, total_neurons = spikes.shape
    if total_neurons != num_classes * neurons_per_class:
        raise ValueError(
            f"expected {num_classes * neurons_per_class} output neurons "
            f"({num_classes} classes x {neurons_per_class}), got {total_neurons}"
        )

    per_neuron = spikes.view(batch_size, time_steps, num_classes, neurons_per_class)
    # Sum over the neurons of each class population, then over time.
    return per_neuron.sum(dim=-1).sum(dim=1)

# Training loop with population coding.
# The logged and printed loss/accuracy are sample-weighted means over the
# whole epoch, not just the final batch's values.
for epoch in range(NUM_EPOCHS):
    model.train()
    train_data, train_targets = shuffle((x_train, y_train))

    # Running sums kept on-device to avoid a host sync on every batch.
    loss_sum = torch.zeros((), device=device)
    correct = torch.zeros((), dtype=torch.long, device=device)
    seen = 0

    for data, targets in zip(train_data, train_targets):
        # CrossEntropyLoss expects int64 class indices; labels may be stored narrower.
        targets = targets.long()
        optimizer.zero_grad()
        spikes = model(data)  # Output spike trains
        class_activity = aggregate_class_activity(spikes)  # Spike count per class
        loss_val = loss_fn(class_activity, targets)  # Loss on class activity
        loss_val.backward()
        optimizer.step()

        predicted_classes = torch.argmax(class_activity, dim=-1)  # Class with max activity
        batch_n = targets.shape[0]
        loss_sum += loss_val.detach() * batch_n
        correct += (predicted_classes == targets).sum()
        seen += batch_n

    if seen == 0:
        raise RuntimeError(
            "shuffle() produced no training batches; is the training set smaller than BATCH_SIZE?"
        )

    scheduler.step()
    train_loss = loss_sum.item() / seen
    train_acc = correct.item() / seen
    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("Accuracy/train", train_acc, epoch)

    if epoch % 10 == 0:
        print(f"Epoch {epoch + 1} | Loss: {train_loss} | Training Accuracy: {train_acc}")

# Evaluation loop
def evaluate(model, x_test, y_test):
    """Evaluate *model* on pre-batched test data.

    Args:
        model: network whose output spike trains are accepted by
            ``aggregate_class_activity``.
        x_test: iterable of input batches, either a tensor whose first dimension
            indexes batches or a sequence of batch tensors.
        y_test: iterable of label batches matching ``x_test``. Batches may
            differ in size.

    Returns:
        ``(accuracy, cm_display)``. ``accuracy`` is the fraction of all test
        samples classified correctly, as a float. ``cm_display`` is a
        ``ConfusionMatrixDisplay`` over all ``NUM_CLASSES`` labels.

    Raises:
        ValueError: if no test batches are provided.
    """
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for test_data, test_targets in zip(x_test, y_test):
            spikes = model(test_data)  # Output spike trains
            class_activity = aggregate_class_activity(spikes)  # Spike count per class
            all_preds.append(torch.argmax(class_activity, dim=-1))
            all_targets.append(test_targets)

    if not all_preds:
        raise ValueError("evaluate() received no test batches")

    all_preds = torch.cat(all_preds).cpu()
    all_targets = torch.cat(all_targets).cpu()
    # Sample-weighted accuracy. A mean of per-batch accuracies would over-weight
    # a smaller final batch.
    avg_test_acc = accuracy(all_preds, all_targets)

    # Pin the label set so the matrix is always NUM_CLASSES x NUM_CLASSES and
    # agrees with display_labels, even if a class never occurs.
    class_labels = list(range(NUM_CLASSES))
    cm = confusion_matrix(all_targets.numpy(), all_preds.numpy(), labels=class_labels)
    cm_display = ConfusionMatrixDisplay(cm, display_labels=[str(i) for i in class_labels])

    return avg_test_acc, cm_display

# Evaluate the trained model on the test split and plot its confusion matrix.
test_acc, cm_display = evaluate(model, x_test, y_test)
print(f"Test Accuracy: {test_acc * 100:.2f}%")

cm_display.plot(cmap="Blues")
plt.title("Confusion Matrix")
plt.show()

# Flush and close the TensorBoard writer.
writer.close()

Epoch 1 | Loss: 2.995732307434082 | Training Accuracy: 0.046875
Epoch 11 | Loss: 3.4240996837615967 | Training Accuracy: 0.234375
Epoch 21 | Loss: 3.5148568153381348 | Training Accuracy: 0.33984375
