# DVS Gesture Training

In [None]:
# SNN
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch.functional import quant
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

# Quantization
import brevitas.nn as qnn

# Torch
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Tonic
import tonic
from tonic import DiskCachedDataset
from tonic import MemoryCachedDataset
from tonic.transforms import Compose, ToFrame, Downsample

# Other
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import os
import sys
import pandas as pd
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
import pyfenrir as fnr

In [None]:
# Dataset / framing parameters shared by the transform and the DataLoaders below.
sensor_size = tonic.datasets.DVSGesture.sensor_size  # native sensor size as reported by tonic
target_size = (160, 120)     # size handed to Downsample as target_size
frame_length_us = 16_600.0   # ToFrame time_window, in microseconds per frame
n_timesteps = 300            # frames per sample after padding / truncation
batch_size = 32              # samples per DataLoader batch

def pad_time_dimension(frames, fixed_time_steps=100):
    """
    Pad or truncate the time dimension (axis 0) of ``frames`` to a fixed length.

    Inputs longer than ``fixed_time_steps`` keep only their first
    ``fixed_time_steps`` frames; shorter inputs are zero-padded at the end of
    axis 0. Any rank >= 1 is accepted (e.g. [time, channels, height, width]
    or an already-flattened [time, features]).

    Args:
        frames: numpy array or torch tensor with time on axis 0. A numpy
            array is converted to a float32 tensor; a tensor keeps its dtype.
        fixed_time_steps: desired length of axis 0.

    Returns:
        torch.Tensor of shape [fixed_time_steps, *frames.shape[1:]]. When no
        padding or truncation is needed, the (converted) input is returned.
    """
    # Convert to tensor if input is numpy array
    if isinstance(frames, np.ndarray):
        frames = torch.tensor(frames, dtype=torch.float)
    current_time_steps = frames.shape[0]
    if current_time_steps > fixed_time_steps:
        return frames[:fixed_time_steps]
    if current_time_steps < fixed_time_steps:
        # F.pad takes (before, after) pairs starting from the LAST axis, so leave
        # every trailing axis untouched and pad only the end of axis 0. Building
        # the list from the rank (instead of a fixed 8-tuple) supports any rank.
        pad = [0, 0] * (frames.dim() - 1) + [0, fixed_time_steps - current_time_steps]
        return torch.nn.functional.pad(frames, pad)
    return frames

# Event stream -> downsampled events -> dense frames [time, polarity, H, W]
# -> fixed number of timesteps -> binary spikes -> ON channel only.
transform = Compose([
    Downsample(sensor_size=sensor_size, target_size=target_size),
    ToFrame(sensor_size=(target_size[0], target_size[1], sensor_size[2]), time_window=frame_length_us),
    transforms.Lambda(lambda x: pad_time_dimension(x, fixed_time_steps=n_timesteps)),   # Pad/truncate time dimension
    # pad_time_dimension already returns a tensor; torch.as_tensor reuses it instead of
    # copying (torch.tensor(tensor) copies and emits a UserWarning for every sample).
    transforms.Lambda(lambda x: torch.clamp(torch.as_tensor(x), 0, 1).type(torch.float)),  # Clamp spikes accumulated over time to [0, 1]
    transforms.Lambda(lambda x: x[:, 1:2, :, :]),                                        # Select only ON channel (keeps a size-1 channel axis)
])

# Load the dataset
trainset = tonic.datasets.DVSGesture(save_to='../data', train=True, transform=transform)
testset = tonic.datasets.DVSGesture(save_to='../data', train=False, transform=transform)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

In [None]:
# Count the distinct labels present in the training split.
unique_labels = np.unique(trainset.targets)
num_classes = len(unique_labels)
print(f"The DVSGesture dataset has {num_classes} target classes.")

# Element 0 of a dataset item is the transformed frame tensor (element 1 is the label).
first_sample = trainset[0]
test_data = first_sample[0]
print(f"Shape: {test_data.shape}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

if isinstance(test_data, torch.Tensor):
    frames_np = test_data.cpu().numpy()
else:
    frames_np = test_data

# The transform keeps a size-1 channel axis ([time, 1, H, W] after the 1:2 slice),
# but imshow only accepts 2-D (or RGB/RGBA) images, so drop that singleton axis.
if frames_np.ndim == 4 and frames_np.shape[1] == 1:
    frames_np = frames_np[:, 0]

# Frame size taken from the data rather than hard-coded (was a stale 24x32).
H, W = frames_np.shape[-2:]

print(f"Shape of frames for video (flattened): {frames_np.shape}")
print(f"Number of frames: {frames_np.shape[0]}")

fig, ax = plt.subplots(figsize=(4, 4))
ax.set_title("Positive Events Animation")
ax.axis('off')

vmax_val = np.max(frames_np) if frames_np.size > 0 else 1

im = ax.imshow(frames_np[0], cmap='gray_r', vmin=0, vmax=vmax_val)

def update(frame):
    """Show 2-D frame number ``frame`` and update the title counter."""
    im.set_data(frames_np[frame])
    ax.set_title(f"Frame: {frame+1}/{frames_np.shape[0]}")
    return [im]

ani = animation.FuncAnimation(
    fig,
    update,
    frames=frames_np.shape[0],
    interval=frame_length_us / 1000,  # FuncAnimation interval is in ms; frame length is in us
    blit=True
)

from IPython.display import HTML
HTML(ani.to_html5_video())


In [None]:
config = {
    # Training
    "num_epochs": 1,            # Number of epochs to train for (per trial)
    "batch_size": batch_size,   # Batch size (same value the DataLoaders were built with)
    "seed": 0,                  # Random seed

    # Data -- tied to the data-cell constants so the model and loaders cannot disagree
    "num_steps": n_timesteps,   # Number of timesteps per sample (after pad/truncate)
    "num_classes": 11,          # Number of classes (DVSGesture has 11)
    "width": target_size[0],    # Sensor width after Downsample
    "height": target_size[1],   # Sensor height after Downsample

    # Quantization
    "fc1_bits": 4,              # Bits per weight for fc1
    "fc2_bits": 4,              # Bits per weight for fc2

    # Conv parameters
    "conv1_out": 12,            # Output channels for conv1
    "conv2_out": 24,            # Output channels for conv2
    "conv3_out": 10,            # Output channels for conv3
    "kernel_size": 3,           # Kernel size for all conv layers

    # Fully-connected parameters
    "fc1_beta": 1.0,            # Initial decay rate for lif1
    "fc2_beta": 1.0,            # Initial decay rate for lif2
    "fc1_thr": 1.0,             # Initial threshold for lif1
    "fc2_thr": 1.0,             # Initial threshold for lif2
    "fc1_multiplier": 10,       # Weight multiplier for fc1

    # Learning
    "lr": 3.0e-3,               # Initial learning rate
    "slope": 5.6,               # Slope value (k)

    # Fixed params
    "correct_rate": 0.8,        # Correct rate (loss function)
    "incorrect_rate": 0.2,      # Incorrect rate (loss function)
    "betas": (0.9, 0.999),      # Adam optimizer beta values
    "t_0": 4690,                # T_max of CosineAnnealingLR, counted in scheduler.step() calls (one per batch in train())
    "eta_min": 0,               # Minimum learning rate
}

In [None]:
# Prefer a CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = fnr.FenrirNet(config).to(device)
print(f"Device: {device}")

In [None]:
# Adam over every network parameter, using the configured learning rate and betas.
optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config["betas"])

# Cosine-annealed learning rate. T_max is measured in scheduler.step() calls;
# last_epoch=-1 starts the schedule from its beginning.
lr_schedule_kwargs = dict(T_max=config["t_0"], eta_min=config["eta_min"], last_epoch=-1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **lr_schedule_kwargs)

# snnTorch spike-count MSE loss: target firing rates for the correct output
# neuron and for all the other output neurons.
criterion = SF.mse_count_loss(
    correct_rate=config["correct_rate"],
    incorrect_rate=config["incorrect_rate"],
)

In [None]:
def train(config, net, trainloader, criterion, optimizer, device="cpu", scheduler=None,
          *, debug_grads=False):
    """
    Complete one epoch of training.

    Args:
        config: configuration dict; only ``config["num_steps"]`` is read here,
            to normalize the logged loss.
        net: model called as ``net(data)``. If it returns a tuple/list (this
            notebook elsewhere unpacks it as ``spk_rec, mem_rec``), the first
            element is used as the output spike record.
        trainloader: iterable of ``(data, labels)`` batches.
        criterion: loss called as ``criterion(spk_rec, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device each batch is moved to.
        scheduler: optional LR scheduler, stepped once per batch (not per epoch).
        debug_grads: if True, print each parameter's gradient norm after every
            backward pass. Off by default: it prints one line per parameter per
            batch and each ``.item()`` forces a device sync.

    Returns:
        ``(loss_accum, lr_accum)``: per-batch lists of the loss divided by
        ``num_steps`` and of the learning rate in effect after the step.
    """
    net.train()
    loss_accum = []
    lr_accum = []

    for batch_idx, (data, labels) in enumerate(tqdm(trainloader, leave=False, desc="Training")):
        data, labels = data.to(device), labels.to(device)
        out = net(data)
        spk_rec = out[0] if isinstance(out, (tuple, list)) else out
        loss = criterion(spk_rec, labels)

        optimizer.zero_grad()
        loss.backward()

        # === Gradient Debugging (opt-in) ===
        if debug_grads:
            print(f"Batch {batch_idx} - Gradient Norms:")
            for name, param in net.named_parameters():
                if param.grad is not None:
                    grad_norm = param.grad.norm().item()
                    print(f"  {name:30}: {grad_norm:.6f}")
                else:
                    print(f"  {name:30}: No gradient")

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        loss_accum.append(loss.item() / config["num_steps"])
        lr_accum.append(optimizer.param_groups[0]["lr"])

    return loss_accum, lr_accum

def test(config, net, testloader, device="cpu"):
    """
    Calculate accuracy on the full test set.

    Args:
        config: configuration dict (unused; kept for a signature parallel to ``train``).
        net: model called as ``net(images)``. If it returns a tuple/list (this
            notebook elsewhere unpacks it as ``spk_rec, mem_rec``), the first
            element is used as the output spike record.
        testloader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to.

    Returns:
        Accuracy in percent, weighting each batch by its size.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            out = net(images)
            outputs = out[0] if isinstance(out, (tuple, list)) else out
            accuracy = SF.accuracy_rate(outputs, labels)
            total += labels.size(0)
            # accuracy_rate is a per-batch fraction; weight by batch size so a
            # smaller final batch does not skew the overall figure.
            correct += accuracy * labels.size(0)

    if total == 0:
        raise ValueError("testloader yielded no samples; cannot compute accuracy")
    return 100 * correct / total

loss_list = []
lr_list = []

## Load model instead of training
load_model = False
# Same file the save cell writes (../models/gesture_1layer.pth); the previous
# path pointed at an NMNIST checkpoint, which does not belong to this network.
model_path = "../models/gesture_1layer.pth"
if load_model:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(model_path, map_location=device))
else:
    print("=======Training Network=======")
    for epoch in range(config['num_epochs']):
        epoch_loss, epoch_lr = train(config, net, trainloader, criterion, optimizer, device, scheduler)
        loss_list += epoch_loss
        lr_list += epoch_lr
        # Test
        test_accuracy = test(config, net, testloader, device)
        print(f"Epoch: {epoch} \tTest Accuracy: {test_accuracy}")

    # Per-iteration loss (left axis) and learning rate (right axis).
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(loss_list, color='tab:orange')
    ax2.plot(lr_list, color='tab:blue')
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss', color='tab:orange')
    ax2.set_ylabel('Learning Rate', color='tab:blue')
    plt.show()

In [None]:
save_model = False
if save_model:
    # Named model_dir so the builtin dir() is not shadowed for later cells.
    model_dir = "../models"
    os.makedirs(model_dir, exist_ok=True)
    torch.save(net.state_dict(), f"{model_dir}/gesture_1layer.pth")

In [None]:
# Raster plots of the output spikes for three random test samples.
fig, ax = plt.subplots(3, 1, figsize=(12, 6))

dataset = testloader.dataset
num_samples = len(dataset)
ran_idx = torch.randint(0, num_samples, (3,))
n_steps = config["num_steps"]
n_outputs = config["num_classes"]

net.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    for i, idx in enumerate(ran_idx.tolist()):
        # Get some data and add a batch dimension
        spike_data, target = dataset[idx]
        spike_data = spike_data.unsqueeze(0)
        spike_data = spike_data.to(device)

        # Forward pass
        print(spike_data.shape)
        spk_rec, mem_rec = net(spike_data)

        # Predicted class = output neuron with the most spikes summed over time
        pred = torch.argmax(spk_rec.sum(dim=0).squeeze()).item()

        # Plot
        splt.raster(spk_rec[:, 0].view(n_steps, -1), ax[i], s=25, c="black")
        ax[i].set_yticks(np.arange(0, n_outputs, 1))
        ax[i].set_title(f"Prediction: {pred}, Target: {target}")

plt.subplots_adjust(hspace=0.5)

# Statistics for the LAST sample plotted above only.
spike_sums = spk_rec.sum(dim=0).squeeze()  # one spike count per output neuron
spike_avg = spike_sums.sum(dim=0) / spike_sums.shape[0]

print(f"Spike counts per neuron: {spike_sums.tolist()}")
print(f"Average spikes per neuron: {spike_avg.tolist()}")

In [None]:
# Export quantized fc1 weights and derived neuron parameters for the hardware.
# The helper module is imported as `fnr` (import pyfenrir as fnr); the bare name
# `fenrir` is never defined, so these calls previously raised NameError.
# NOTE(review): 44 and 32*24 look like layer dimensions from an earlier 32x24-input
# setup, while this notebook now frames 160x120 input -- confirm against
# fnr.export_weights' signature and the actual shape of net.fc1.
fnr.export_weights(net.fc1, 44, 32*24, '../../src/design_sources/data/fc1_gesture.data')
quant_scale = net.fc1.quant_weight().scale
beta        = net.beta1 / quant_scale  # decay expressed in quantized-weight units
thr         = fnr.get_threshold(net.fc1, net.lif1)
print(f"Quant Scale: {quant_scale}")
print(f"Beta: {beta}")
print(f"Threshold: {thr}")

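## Mapping

Class numbers below follow the original DVS128 Gesture dataset's 1-based numbering. The network has 11 output neurons and predictions are taken with `argmax`, so the `Prediction`/`Target` values printed in this notebook are presumably 0-based (class *k* below ↔ label *k*−1).

- 1: hand_clapping
- 2: right_hand_wave
- 3: left_hand_wave
- 4: right_hand_clockwise
- 5: right_hand_counter_clockwise
- 6: left_hand_clockwise
- 7: left_hand_counter_clockwise
- 8: forearm_roll_forward
- 8: forearm_roll_backwards (same class as the forward roll)
- 9: drums
- 10: guitar
- 11: random_other_gestures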

In [None]:
## Select a datapoint from the dataset
dataset = testloader.dataset
spike_data, target_tmp = dataset[1]

## Add a batch dimension and move to the model's device
## (the sample is already a binary frame tensor from the dataset transform)
spike_data = spike_data.unsqueeze(0)
spike_data = spike_data.to(device)

n_steps = config["num_steps"]

## Forward pass (inference only, so no autograd graph is built)
net.eval()
with torch.no_grad():
    spk_rec1, mem_rec1 = net(spike_data)
pred = torch.argmax(spk_rec1.sum(dim=0).squeeze()).item()

# Top: input spikes flattened per timestep; bottom: output spikes per neuron.
fig, ax = plt.subplots(2, 1, figsize=(12, 6))
splt.raster(spike_data[0, :, :].view(n_steps, -1), ax[0], s=1, c="black")
splt.raster(spk_rec1[:, 0].view(n_steps, -1), ax[1], s=25, c="black")
ax[0].set_title(f"Prediction: {pred}, Target: {target_tmp}")

## Convert to numpy as we will use these for comparison with the testbench
spk_rec1 = spk_rec1.cpu().detach().numpy()
mem_rec1 = mem_rec1.cpu().detach().numpy()

In [None]:
# Total spikes per output neuron for the sample above. spk_rec1 is indexed
# [time, batch, neuron]; iterate over every neuron (range(10) skipped the 11th).
for i in range(spk_rec1.shape[2]):
    nrn_sum = spk_rec1[:, 0, i].sum()
    print(f"nrn{i}_sum = {nrn_sum}")

In [None]:
# Word appended after each timestep's events (only bit 12 set). Event words are
# flat input indices and must stay below it, which limits inputs to 4096 pixels.
EVENT_TIMESTEP_MARKER = 0b1000000000000

exp_data = spike_data[0, :, :].cpu()  # [time, ...] for the single sample
tsteps = exp_data.shape[0]
events = []
tstep_event_idx = []  # index into `events` just past each timestep's marker

for t in range(tsteps):
    # Flatten everything after the time axis so each active pixel yields one
    # row-major flat index (same ordering as the .view(300, -1) input raster).
    # Without flattening, nonzero(as_tuple=True)[0] on a [1, H, W] frame returns
    # the channel index (always 0) instead of the pixel index.
    t_data = exp_data[t].flatten()
    non_zero_indices = (t_data != 0).nonzero(as_tuple=True)[0].tolist()

    if non_zero_indices and non_zero_indices[-1] >= EVENT_TIMESTEP_MARKER:
        raise ValueError(
            f"Flat input index {non_zero_indices[-1]} at timestep {t} collides with the "
            f"timestep marker 0x{EVENT_TIMESTEP_MARKER:X}; this event format supports at "
            f"most {EVENT_TIMESTEP_MARKER} input pixels, but frames have {t_data.numel()}."
        )

    events.extend(non_zero_indices)
    events.append(EVENT_TIMESTEP_MARKER)  # marker
    tstep_event_idx.append(len(events))

# Create a C header file with the events as an array.
# NOTE(review): the file name says "target_1" but the sample is dataset[1]; its
# actual label is target_tmp -- confirm the intended naming.
header_file = 'gesture_data_target_1.h'

# The guard, macro and array names keep their NMNIST_/nmnist_ prefixes,
# presumably because existing C code includes this header by those names.
with open(header_file, 'w', encoding='utf-8') as f:
    f.write('#ifndef NMNIST_FPGA_DATA_H\n')
    f.write('#define NMNIST_FPGA_DATA_H\n\n')

    # Write array size
    f.write(f'#define NMNIST_EVENTS_SIZE {len(events)}\n\n')

    # Write the array data
    f.write('const unsigned int nmnist_events[NMNIST_EVENTS_SIZE] = {\n')

    # One hex literal per line; comma after every element except the last
    for i, val in enumerate(events):
        comma = ',' if i < len(events) - 1 else ''
        f.write(f'    0x{val:X}{comma}\n')

    f.write('};\n\n')
    f.write('#endif // NMNIST_FPGA_DATA_H\n')
