<a href="https://colab.research.google.com/github/GauravGupta06/CMPM118/blob/feature%2Ftraining_setup/small_snn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- 1. INSTALL LIBRARIES ---
%pip install numpy --quiet
%pip install tonic --quiet
%pip install matplotlib --quiet
%pip install snntorch --quiet
%pip install torch --quiet
%pip install Lempel-Ziv-Complexity --quiet

# --- 2. IMPORTS ---
import numpy as np
import numpy.lib.recfunctions as rf
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate, functional as SF, utils
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from lempel_ziv_complexity import lempel_ziv_complexity

# imports for ROC curve
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

# Check device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hUsing device: cuda


In [None]:
# --- 3. HYPERPARAMETERS ---
BIN_SIZE = 15000  # NOTE(review): not used by any visible cell — confirm before removing
DOWNSAMPLE = 4    # spatial downsampling factor applied to the DVS sensor resolution
TIME_STEPS = 8    # number of frames (time bins) each event stream is binned into
BATCH_SIZE = 32
NUM_CLASSES = 11
# Derive the downsampled frame size from the sensor resolution instead of
# hard-coding it, so DOWNSAMPLE actually controls the model input size.
SENSOR_W, SENSOR_H = tonic.datasets.DVSGesture.sensor_size[:2]
W, H = SENSOR_W // DOWNSAMPLE, SENSOR_H // DOWNSAMPLE  # downsampled spatial size
grad = surrogate.fast_sigmoid(slope=25)  # surrogate gradient for the spike non-linearity
beta = 0.5  # LIF membrane decay factor

# --- 4. DATASET TRANSFORMS ---
# tonic orders spatial sizes as (width, height[, polarity]). Passing (H, W) only
# worked because W == H; it would silently swap axes for a non-square target.
snn_transform = transforms.Compose([
    transforms.Denoise(filter_time=10000),  # drop isolated events (see tonic Denoise docs for units)
    transforms.Downsample(sensor_size=tonic.datasets.DVSGesture.sensor_size,
                          target_size=(W, H)),
    transforms.ToFrame(sensor_size=(W, H, 2), n_time_bins=TIME_STEPS),
    lambda x: torch.from_numpy(x.copy()).float()  # returns [T, C, H, W]
])


# transform=None: frames are built lazily in the DataLoader collate function.
# NOTE: tonic's DVSGesture defaults to train=True, so this is the official
# *training* split only; the "test" set below is a random hold-out carved from
# it, not the official DVSGesture test split.
full_dataset = tonic.datasets.DVSGesture(save_to="./data", transform=None)
dataset_size = len(full_dataset)
print(f"Full dataset size: {dataset_size}")

TEST_SPLIT_RATIO = 0.2
train_size = int(dataset_size * (1 - TEST_SPLIT_RATIO))
test_size = dataset_size - train_size
torch.manual_seed(42)  # seeds the split below; also seeds the global RNG used by later cells
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
print(f"Train: {train_size}, Test: {test_size}")


Full dataset size: 1077
Train: 861, Test: 216


In [None]:
def fast_collate_lazy(batch):
    """Collate ``(events, target)`` pairs into a time-major batch.

    Raw event streams are converted to frame tensors with ``snn_transform`` here,
    at batch time, rather than inside the dataset.

    Returns:
        data: float tensor laid out [T, B, C, H, W] (time-major for the SNN loop).
        targets: int64 tensor of shape [B].
    """
    event_streams, labels = zip(*batch)
    frames = [snn_transform(stream) for stream in event_streams]  # each [T, C, H, W]
    # Stack along a new batch axis, then move time to the front.
    time_major = torch.stack(frames).permute(1, 0, 2, 3, 4)
    return time_major, torch.tensor(labels, dtype=torch.long)

# --- 8. DATALOADERS ---
# Frames are produced lazily by fast_collate_lazy, so transform work happens in
# the loader workers for every batch of every epoch.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          drop_last=True, collate_fn=fast_collate_lazy,
                          num_workers=2, pin_memory=True)

# Evaluation must see every held-out sample. drop_last=True silently discarded
# the final partial batch (24 of the 216 test samples with the split above).
# Accuracy in validate_model is weighted by batch size, so a short batch is safe.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         drop_last=False, collate_fn=fast_collate_lazy,
                         num_workers=2, pin_memory=True)

In [None]:
# --- 10. SMALL SNN MODEL ---
class Small_SNN(nn.Module):
    """Minimal convolutional spiking network.

    Per time step: Conv2d(3x3, padding=1) -> MaxPool2d(2) -> Leaky LIF ->
    flatten -> Linear -> Leaky LIF (output layer). LIF hidden states are reset
    at the start of every forward call.

    Args:
        in_channels: channels per input frame (2 event polarities by default).
        num_filters: number of convolution filters.
        height, width: spatial size of each input frame (defaults: H, W).
        num_classes: number of output neurons (default: NUM_CLASSES).
    """

    def __init__(self, in_channels=2, num_filters=8, height=H, width=W,
                 num_classes=NUM_CLASSES):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_filters, 3, padding=1)  # 3x3, padding=1 keeps H x W
        self.pool = nn.MaxPool2d(2)  # H x W -> H//2 x W//2
        self.lif1 = snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True)
        # Derive the fc input size from the layer geometry. The previous code used a
        # FLATTEN_SIZE global that no visible cell defines (NameError on a fresh kernel).
        self.flatten_size = num_filters * (height // 2) * (width // 2)
        self.fc = nn.Linear(self.flatten_size, num_classes)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)

    def forward(self, data):
        """
        data: [T, B, C, H, W]
        returns: [T, B, NUM_CLASSES] output spikes for every time step
        """
        spk_rec = []
        utils.reset(self)  # reset LIF hidden states so samples don't leak state

        for t in range(data.size(0)):
            x = data[t]               # [B, C, H, W]
            x = self.conv(x)
            x = self.pool(x)

            # LIF layer 1 (guard in case the layer returns a (spk, mem) tuple)
            spk1 = self.lif1(x)
            if isinstance(spk1, tuple):
                spk1 = spk1[0]

            x = spk1.flatten(1)       # flatten for fc layer
            x = self.fc(x)

            # LIF layer 2 (output=True: returns a tuple, keep only the spikes)
            spk2 = self.lif2(x)
            if isinstance(spk2, tuple):
                spk2 = spk2[0]

            spk_rec.append(spk2)

        return torch.stack(spk_rec)   # [T, B, NUM_CLASSES]

small_snn_net = Small_SNN().to(device)

In [None]:
# --- 11. TRAINING SETUP ---
LEARNING_RATE = 0.002
num_epochs = 10
# Rate-coded MSE on spike counts: target firing on 80% of steps for the true
# class and 20% for the others.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
optimizer = torch.optim.Adam(small_snn_net.parameters(), lr=LEARNING_RATE)


In [None]:
# --- 12. VALIDATION ---
def validate_model(loader, net):
    """Return the spike-count accuracy of ``net`` over every batch in ``loader``.

    Batches are ``(data, targets)`` with ``data`` laid out [T, B, C, H, W]
    (as produced by fast_collate_lazy). Per-batch accuracies from
    ``SF.accuracy_rate`` are weighted by batch size, so a short final batch
    counts correctly.

    The network's previous train/eval mode is restored on exit, even if
    evaluation raises.

    Raises:
        ValueError: if ``loader`` yields no samples (e.g. a held-out set smaller
            than the batch size combined with ``drop_last=True``).
    """
    was_training = net.training
    net.eval()
    correct, total = 0.0, 0
    try:
        with torch.no_grad():
            for data, targets in loader:
                data = data.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                spk_rec = net(data)
                batch_size = data.shape[1]  # data is time-major: [T, B, ...]
                correct += SF.accuracy_rate(spk_rec, targets) * batch_size
                total += batch_size
    finally:
        net.train(was_training)
    if total == 0:
        raise ValueError("validate_model: loader yielded no samples")
    return correct / total

In [None]:
# Sanity check: report the device holding the model's first parameter tensor.
model_device = next(small_snn_net.parameters()).device
print(model_device)

cuda:0


In [None]:
# --- 13. TRAINING LOOP ---
# Globals kept for later cells: loss_hist (one loss per batch), test_acc_hist
# (one entry per evaluation) and cnt (global step counter across epochs).
loss_hist, test_acc_hist = [], []
cnt = 0
EVAL_EVERY = 20  # run a held-out evaluation every this many global steps

print("Starting training...")
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Async host->device copies (the loaders use pin_memory=True)
        data, targets = data.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        # One optimisation step; train mode is re-asserted every batch so it
        # holds even after a validation pass.
        small_snn_net.train()
        optimizer.zero_grad()
        spk_rec = small_snn_net(data)
        loss = loss_fn(spk_rec, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_hist.append(batch_loss)

        if not cnt % EVAL_EVERY:
            train_acc = SF.accuracy_rate(spk_rec, targets)
            test_acc = validate_model(test_loader, small_snn_net)
            test_acc_hist.append(test_acc)
            print(f"Epoch {epoch}, Batch {batch_idx}: Loss={batch_loss:.4f}, "
                  f"Train Acc={train_acc*100:.2f}%, Test Acc={test_acc*100:.2f}%")
        cnt += 1

print("Training complete!")

Starting training...
Epoch 0, Batch 0: Loss=0.5217, Train Acc=3.12%, Test Acc=13.02%


KeyboardInterrupt: 