<a href="https://colab.research.google.com/github/GauravGupta06/CMPM118/blob/feature%2Ftraining_setup/small_snn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- 1. INSTALL LIBRARIES ---
%pip install numpy --quiet
%pip install tonic --quiet
%pip install matplotlib --quiet
%pip install snntorch --quiet
%pip install torch --quiet
%pip install Lempel-Ziv-Complexity --quiet

# --- 2. IMPORTS ---
import numpy as np
import numpy.lib.recfunctions as rf
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate, functional as SF, utils
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from lempel_ziv_complexity import lempel_ziv_complexity

# Check device: use the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [None]:
# --- 3. HYPERPARAMETERS ---
BIN_SIZE = 15000  # NOTE(review): currently unused; ToFrame below bins by TIME_STEPS instead
DOWNSAMPLE = 4    # spatial downsampling factor applied to the DVSGesture sensor
TIME_STEPS = 8    # number of frames (time bins) per sample
BATCH_SIZE = 32
NUM_CLASSES = 11
# tonic orders sensor_size as (x, y, polarity), i.e. (width, height, polarity)
SENSOR_W, SENSOR_H = tonic.datasets.DVSGesture.sensor_size[:2]
W, H = SENSOR_W // DOWNSAMPLE, SENSOR_H // DOWNSAMPLE  # downsampled spatial size
grad = surrogate.fast_sigmoid(slope=25)  # surrogate gradient used by the spiking layers
beta = 0.5  # membrane decay rate of the LIF neurons


# --- 4. DATASET TRANSFORMS ---
def frames_to_tensor(frames):
    """Convert one sample's ToFrame output into a float tensor shaped (T, C, H, W).

    tonic's ToFrame already emits (time_bins, polarity, height, width), which is
    exactly the per-time-step (C, H, W) layout nn.Conv2d expects, so no axis
    permutation is applied. Permuting with (0, 3, 1, 2) would yield (T, W, C, H)
    and make the 2-channel conv see W "channels" instead.

    Args:
        frames: numpy array of event counts shaped (T, 2, H, W).

    Returns:
        torch.float32 tensor with the same shape.
    """
    # .copy() hands torch.from_numpy a writable, contiguous buffer
    return torch.from_numpy(frames.copy()).float()


snn_transform = transforms.Compose([
    transforms.Denoise(filter_time=10000),
    transforms.Downsample(sensor_size=tonic.datasets.DVSGesture.sensor_size,
                          target_size=(W, H)),
    transforms.ToFrame(sensor_size=(W, H, 2), n_time_bins=TIME_STEPS),
    frames_to_tensor,
])

# --- 5. LOAD DATASETS (official DVSGesture train/test splits) ---
# tonic's DVSGesture loads the *training* split by default, so a single dataset object is
# not the full dataset. Randomly re-splitting it puts recordings of the same subjects into
# both train and test and inflates test accuracy, so use the official held-out test split.
train_dataset = tonic.datasets.DVSGesture(save_to="./data", train=True, transform=None)
test_dataset = tonic.datasets.DVSGesture(save_to="./data", train=False, transform=None)
train_size, test_size = len(train_dataset), len(test_dataset)

# Seed the global torch RNG so DataLoader shuffling and model weight init are reproducible.
torch.manual_seed(42)

print(f"Train: {train_size}, Test: {test_size}")


Full dataset size: 1077
Train: 861, Test: 216


In [None]:
# --- 7/8. LAZY COLLATE FUNCTION ---
# Events are turned into frames per batch, on the fly. If this becomes a bottleneck,
# pre-transform each dataset once (e.g. [(snn_transform(ev), y) for ev, y in ds]) and
# collate the cached tensors instead.
def fast_collate_lazy(batch, transform=None, target_device=None):
    """Turn a list of ``(events, target)`` pairs into a time-major batch.

    Args:
        batch: sequence of ``(events, target)`` pairs as yielded by the dataset.
        transform: callable mapping one sample's events to a frame tensor shaped
            (T, C, H, W). Defaults to the notebook's ``snn_transform``.
        target_device: device the returned tensors are moved to. Defaults to the
            notebook's ``device``.

    Returns:
        ``(data, targets)`` where ``data`` is shaped (T, B, C, H, W) and ``targets``
        is a ``torch.long`` tensor shaped (B,).
    """
    if transform is None:
        transform = snn_transform
    if target_device is None:
        target_device = device
    # Stacking on dim=1 puts the batch axis second -> (T, B, C, H, W),
    # the layout the model iterates over time with.
    data = torch.stack([transform(events) for events, _target in batch], dim=1)
    targets = torch.tensor([target for _events, target in batch], dtype=torch.long)
    return data.to(target_device), targets.to(target_device)


# num_workers=0: fast_collate_lazy moves tensors to `device` itself, and keeping collation
# in the main process avoids initialising CUDA inside DataLoader worker processes.
# drop_last=True on the training side keeps every optimisation step at BATCH_SIZE.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          drop_last=True, collate_fn=fast_collate_lazy,
                          num_workers=0)

# Keep every test sample: accuracy is weighted by each batch's real size, so a short
# final batch is fine, whereas drop_last=True would silently skip test samples.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         drop_last=False, collate_fn=fast_collate_lazy,
                         num_workers=0)


In [None]:
# --- 10. SMALL SNN MODEL ---
FLATTEN_SIZE = 8 * (H // 2) * (W // 2)  # 8 conv channels after one 2x2 max-pool


class Small_SNN(nn.Module):
    """One conv + LIF layer feeding a fully connected LIF output layer.

    ``forward`` expects ``data`` shaped (T, B, 2, H, W) and returns the output
    layer's spikes stacked over time, shaped (T, B, NUM_CLASSES).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 8, 3, padding=1)  # 2 input channels -> 8 filters
        self.pool = nn.MaxPool2d(2)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True)
        self.fc = nn.Linear(FLATTEN_SIZE, NUM_CLASSES)
        # output=True makes this layer return (spk, mem) even though init_hidden=True
        self.lif2 = snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)

    def forward(self, data):
        # Clear hidden membrane state carried over from the previous forward call.
        utils.reset(self)
        spk_rec = []
        for x_t in data:  # iterate over the leading (time) dimension
            x = self.pool(self.conv(x_t))
            spk1 = self.lif1(x)  # init_hidden without output=True -> spikes only
            x = self.fc(spk1.flatten(1))
            spk2, _mem2 = self.lif2(x)  # unpack: stacking the (spk, mem) tuple would fail
            spk_rec.append(spk2)
        return torch.stack(spk_rec)


small_snn_net = Small_SNN().to(device)
print(small_snn_net)


Small_SNN(
  (conv): Conv2d(2, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (lif1): Leaky()
  (fc): Linear(in_features=2048, out_features=11, bias=True)
  (lif2): Leaky()
)


In [None]:
# --- 11. TRAINING SETUP ---
num_epochs = 10  # you can increase later

# Spike-count MSE loss: the correct class is pushed towards firing on 80% of time
# steps and every other class towards 20%.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

optimizer = torch.optim.Adam(
    small_snn_net.parameters(),
    lr=0.002,
    betas=(0.9, 0.999),
)


In [None]:
# --- 12. VALIDATION ---
def validate_model(loader, net):
    """Return the classification accuracy of ``net`` over every batch in ``loader``.

    A sample counts as correct when its most active output neuron (largest spike
    count summed over time) equals its target. This is the same argmax-of-spike-count
    rule as ``snntorch.functional.accuracy_rate``. Correct predictions are counted
    exactly, so a short final batch is weighted by its true size.

    Args:
        loader: iterable of ``(data, targets)`` pairs, with ``data`` shaped (T, B, ...)
            and ``targets`` shaped (B,).
        net: module mapping ``data`` to a spike record shaped (T, B, num_classes).

    Returns:
        Accuracy in [0, 1] as a Python float.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = net.training
    net.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for data, targets in loader:
                spk_rec = net(data)
                predicted = spk_rec.sum(dim=0).argmax(dim=1)
                correct += (predicted == targets).sum().item()
                total += targets.numel()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        net.train(was_training)
    if total == 0:
        raise ValueError("validate_model: loader yielded no samples")
    return correct / total


In [None]:
# --- 13. TRAINING LOOP ---
# loss_hist: training loss of every iteration; test_acc_hist: test accuracy every 20 iterations.
loss_hist, test_acc_hist = [], []
cnt = 0  # global iteration counter across all epochs

print("Starting training...")

for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # validate_model toggles the module mode, so re-enter train mode every step.
        small_snn_net.train()

        spk_rec = small_snn_net(data)
        loss = loss_fn(spk_rec, targets)

        # Backpropagate through every time step of the unrolled forward pass.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_hist.append(batch_loss)

        if not cnt % 20:
            train_acc = SF.accuracy_rate(spk_rec, targets)
            test_acc = validate_model(test_loader, small_snn_net)
            test_acc_hist.append(test_acc)
            print(
                f"Epoch {epoch}, Batch {batch_idx}: Loss={batch_loss:.2f}, "
                f"Train Acc={train_acc*100:.2f}%, Test Acc={test_acc*100:.2f}%"
            )
        cnt += 1

print("Training complete!")


Starting training...


RuntimeError: Given groups=1, weight of size [8, 2, 3, 3], expected input[32, 32, 2, 32] to have 2 channels, but got 32 channels instead