# DVS Gesture Training

In [None]:
# SNN
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch.functional import quant
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

# Quantization
import brevitas.nn as qnn

# Torch
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Tonic
import tonic
from tonic import DiskCachedDataset
from tonic import MemoryCachedDataset
from tonic.transforms import Compose, ToFrame, Downsample
from tonic import transforms as tonic_transforms

# Other
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import os
import sys
import pandas as pd
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
import pyfenrir as fenrir

In [None]:
# --- Data pipeline constants -------------------------------------------------
# Native resolution of the DVS Gesture sensor, as reported by tonic.
sensor_size = tonic.datasets.DVSGesture.sensor_size
target_size = (60, 60)          # spatial size after downsampling

frame_length_us: float = 16.6e3  # time window of one frame (passed to ToFrame)
n_timesteps: int = 180           # frames per sample after pad/truncate
batch_size: int = 64

def pad_time_dimension(frames, fixed_time_steps=100):
    """
    Pad or truncate the leading (time) axis of ``frames`` to ``fixed_time_steps``.

    Args:
        frames: NumPy array or tensor whose first axis is time, e.g.
            [time, channels, height, width]. Any rank >= 1 is accepted.
            NumPy input is converted to a float32 tensor; tensor input keeps its dtype.
        fixed_time_steps: Length of the time axis in the result (default 100).

    Returns:
        Tensor of shape [fixed_time_steps, *frames.shape[1:]]. Surplus steps are
        dropped from the end; missing steps are appended as zeros.
    """
    # Convert to tensor if input is numpy array
    if isinstance(frames, np.ndarray):
        frames = torch.tensor(frames, dtype=torch.float)

    current_time_steps = frames.shape[0]
    if current_time_steps > fixed_time_steps:
        return frames[:fixed_time_steps]
    if current_time_steps < fixed_time_steps:
        # F.pad takes (before, after) pairs starting from the LAST dimension, so
        # every trailing dim gets (0, 0) and the time axis (dim 0) gets (0, missing).
        missing = fixed_time_steps - current_time_steps
        pad = (0, 0) * (frames.dim() - 1) + (0, missing)
        return torch.nn.functional.pad(frames, pad)
    return frames

# Preprocessing: downsample events to target_size, bin them into fixed-length
# frames, pad/truncate to n_timesteps, binarize, and keep only polarity channel 1.
transform = Compose([
    tonic_transforms.Downsample(sensor_size=sensor_size, target_size=target_size),
    # ToFrame's sensor_size is (W, H, P); P is the number of polarity channels
    # (taken from the dataset's sensor size), not a spatial dimension.
    tonic_transforms.ToFrame(sensor_size=(*target_size, sensor_size[2]), time_window=frame_length_us),
    transforms.Lambda(lambda x: pad_time_dimension(x, fixed_time_steps=n_timesteps)),
    transforms.Lambda(lambda x: torch.clamp(x, 0, 1).type(torch.float32)),
    transforms.Lambda(lambda x: x[:, 1:2, :, :])  # Select only ON channel (index 1)
])

# Load the dataset
trainset = tonic.datasets.DVSGesture(save_to='../data', train=True, transform=transform)
testset = tonic.datasets.DVSGesture(save_to='../data', train=False, transform=transform)

# Keep transformed samples in RAM so the (slow) event->frame conversion runs once per sample.
cached_trainset = MemoryCachedDataset(trainset)
cached_testset = MemoryCachedDataset(testset)

# Pinned memory only helps host->GPU copies; skip it on CPU-only machines.
pin_memory = torch.cuda.is_available()
trainloader = DataLoader(cached_trainset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
testloader = DataLoader(cached_testset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)

In [None]:
# Number of distinct gesture labels present in the training split.
num_classes = np.unique(trainset.targets).size

print(f"The DVSGesture dataset has {num_classes} target classes.")

In [None]:
class NetUtils():
    """Stateless helpers applied to LIF membrane potentials inside ``Net.forward``."""

    @staticmethod
    def beta_clamp(mem, beta):
        """
        Leak ``mem`` toward zero by ``|beta|`` without letting it cross zero.

        Positive entries become ``max(mem - |beta|, 0)``, negative entries become
        ``min(mem + |beta|, 0)``, and exact zeros are returned unchanged. ``beta``
        is a tensor (a learnable Parameter in Net); it receives gradients through
        the ReLUs.
        """
        beta_abs = torch.abs(beta)

        # Positive side: clamp(mem - |beta|, min=0), written with ReLU.
        pos_mask = (mem > 0)
        pos_val = F.relu(mem - beta_abs)

        # Negative side: clamp(mem + |beta|, max=0), written as a negated ReLU.
        neg_mask = (mem < 0)
        neg_val = -F.relu(-(mem + beta_abs))

        mem_new = torch.where(pos_mask, pos_val, mem)
        mem_new = torch.where(neg_mask, neg_val, mem_new)

        return mem_new

    @staticmethod
    def mem_clamp(mem, scale, multiplier, bits=12):
        """
        Clamp ``mem`` to the range of a signed ``bits``-bit two's-complement integer.

        The integer bounds ``[-2**(bits-1), 2**(bits-1) - 1]`` (``[-2048, 2047]``
        for the default 12 bits) are mapped to the float domain as
        ``bound * scale / multiplier``.

        Args:
            mem: Membrane-potential tensor.
            scale: Weight quantization scale (float or tensor).
            multiplier: Divisor applied together with ``scale``.
            bits: Integer bit width of the membrane register.

        Returns:
            The clamped tensor.
        """
        int_max = 2 ** (bits - 1) - 1
        int_min = -(2 ** (bits - 1))  # two's-complement minimum, e.g. -2048 for 12 bits
        max_val = int_max * scale / multiplier
        min_val = int_min * scale / multiplier
        return torch.clamp(mem, min=min_val, max=max_val)

In [None]:
# Fetch one training sample to sanity-check the transform pipeline; presumably a
# (frames, label) pair, as unpacked in later cells. Note this indexes `trainset`
# directly, not `cached_trainset`, so the transform runs again and nothing is cached.
test_data = trainset[0]

In [None]:
config = {
    "num_epochs": 10,       # Number of epochs to train for (per trial)
    "batch_size": batch_size,  # Batch size (same value the DataLoaders use)
    "seed": 0,              # Random seed

    # NOTE(review): "num_bits", "batch_norm", "dropout", "beta", "threshold" and
    # "slope" are not read by Net (it hard-codes its own values) or by any other
    # visible cell — confirm before relying on them.

    # Quantization
    "num_bits": 4,          # Bit resolution

    # Network parameters
    "grad_clip": False,     # Whether or not to clip gradients (norm 1.0 in train())
    "weight_clip": False,   # Whether or not to clamp parameters to [-1, 1] in train()
    "batch_norm": True,     # Whether or not to use batch normalization
    "dropout": 0.07,        # Dropout rate
    "beta": 1.0,            # Decay rate parameter (beta)
    "threshold": 10,        # Threshold parameter (theta)
    "lr": 3.0e-3,           # Initial learning rate
    "slope": 5.6,           # Slope value (k)

    # Fixed params
    "num_steps": n_timesteps,  # Time steps per sample; tied to the padded input length
    "correct_rate": 0.8,    # Target firing ratio for the correct output neuron
    "incorrect_rate": 0.2,  # Target firing ratio for the incorrect output neurons
    "betas": (0.9, 0.999),  # Adam optimizer beta values
    "t_0": 4690,            # T_max of CosineAnnealingLR, in scheduler steps (one per batch in train())
    "eta_min": 0,           # Minimum learning rate
}

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
    """
    Two-layer fully connected spiking network with 4-bit quantized weights.

    Layout per time step:
        flattened frame (60*60) -> QuantLinear(3600 -> 64) -> Leaky LIF
                                -> QuantLinear(64 -> num_classes) -> Leaky LIF

    Both LIF layers are built with ``beta=1.0``, so snntorch applies no
    multiplicative decay. Leakage is applied after every step by
    ``NetUtils.beta_clamp`` using the learnable ``fc1_beta`` / ``fc2_beta``
    parameters. Before each step the stored membrane is clamped by
    ``NetUtils.mem_clamp`` (default 12-bit range) scaled by the layer's weight
    quantization scale divided by its multiplier.
    """
    def __init__(self):
        super().__init__()

        # Per-layer settings: weight bit width, initial leak (beta), initial
        # firing threshold, and the divisor used by mem_clamp.
        fc1_bits        = 4
        fc2_bits        = 4
        fc1_beta_init   = 0.1
        fc2_beta_init   = 0.1
        fc1_thr_init    = 1.0
        fc2_thr_init    = 1.0
        self.fc1_multiplier  = 10
        self.fc2_multiplier  = 10

        # Layer 1: learnable leak, bias-free quantized linear, LIF with learnable threshold.
        self.fc1_beta   = torch.nn.Parameter(torch.tensor(fc1_beta_init), requires_grad=True)
        self.fc1        = qnn.QuantLinear(60*60, 64, bias=False, weight_bit_width=fc1_bits)
        self.lif1       = snn.Leaky(beta=1.0, threshold=fc1_thr_init, learn_threshold=True, reset_mechanism='zero', reset_delay=False)

        # Layer 2 (output): one LIF neuron per class.
        self.fc2_beta   = torch.nn.Parameter(torch.tensor(fc2_beta_init), requires_grad=True)
        self.fc2        = qnn.QuantLinear(64, num_classes, bias=False, weight_bit_width=fc2_bits)
        self.lif2       = snn.Leaky(beta=1.0, threshold=fc2_thr_init, learn_threshold=True, reset_mechanism='zero', reset_delay=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Simulate the network over every time step of ``x``.

        Args:
            x: Tensor of shape [B, T, C, H, W]; C*H*W must equal 60*60 to match fc1.

        Returns:
            Output-layer spikes stacked over time, shape [T, B, num_classes].
            Only the spikes are returned (no membrane potentials).
        """

        B, T, C, H, W = x.shape

        fc_mem1   = self.lif1.init_leaky()
        fc_mem2   = self.lif2.init_leaky()

        # Record output spikes
        spk_rec = []

        # Weight quantization scales are read once per forward pass and used to
        # express the membrane clamp range in the float domain.
        scale_fc1 = self.fc1.quant_weight().scale
        scale_fc2 = self.fc2.quant_weight().scale

        for t in range(T):

            # Flatten frame t of every sample to [B, C*H*W].
            xt = x[:, t, :, :, :]
            xt = xt.contiguous().view(B, -1)

            # Layer 1: clamp stored membrane -> integrate/fire -> leak toward zero.
            cur1 = self.fc1(xt)
            fc_mem1 = NetUtils.mem_clamp(fc_mem1, scale_fc1, multiplier=self.fc1_multiplier)
            spk1, fc_mem1 = self.lif1(cur1, fc_mem1)
            fc_mem1 = NetUtils.beta_clamp(fc_mem1, self.fc1_beta)

            # Layer 2: same sequence, driven by layer-1 spikes.
            cur2 = self.fc2(spk1)
            fc_mem2 = NetUtils.mem_clamp(fc_mem2, scale_fc2, multiplier=self.fc2_multiplier)
            spk2, fc_mem2 = self.lif2(cur2, fc_mem2)
            fc_mem2 = NetUtils.beta_clamp(fc_mem2, self.fc2_beta)

            spk_rec.append(spk2)

        return torch.stack(spk_rec)

# Instantiate the network and move all of its parameters to the selected device.
net = Net().to(device)

In [None]:
# Adam over every parameter registered on the network.
optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config["betas"])

# Cosine-annealed learning rate starting from a fresh schedule (last_epoch=-1).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config["t_0"], eta_min=config["eta_min"], last_epoch=-1
)

# Spike-count MSE loss with target firing ratios for correct / incorrect classes.
criterion = SF.mse_count_loss(
    correct_rate=config["correct_rate"], incorrect_rate=config["incorrect_rate"]
)

In [None]:
def train(config, net, trainloader, criterion, optimizer, device="cpu", scheduler=None):
    """
    Complete one epoch of training.

    Args:
        config: Dict providing the boolean flags "grad_clip" and "weight_clip".
        net: Network returning output spikes of shape [T, B, num_classes]
            (a ``(spikes, ...)`` tuple is also accepted).
        trainloader: Iterable of ``(data, labels)`` batches.
        criterion: Loss called as ``criterion(spk_rec, labels)``.
        optimizer: Optimizer over ``net``'s parameters.
        device: Device each batch is moved to.
        scheduler: Optional LR scheduler, stepped once per batch.

    Returns:
        ``(loss_accum, lr_accum)``: per-batch loss divided by the number of
        simulated time steps, and the learning rate after each batch.
    """
    net.train()
    loss_accum = []
    lr_accum = []

    for data, labels in tqdm(trainloader, leave=False, desc="Training"):
        data, labels = data.to(device), labels.to(device)

        out = net(data)
        # Net.forward in this notebook returns the spike tensor itself; unpacking
        # it would split along time. Accept tuple-returning networks as well.
        spk_rec = out[0] if isinstance(out, tuple) else out

        loss = criterion(spk_rec, labels)
        optimizer.zero_grad()
        loss.backward()

        ## Enable gradient clipping
        if config["grad_clip"]:
            nn.utils.clip_grad_norm_(net.parameters(), 1.0)

        optimizer.step()

        ## Enable weight clipping — done after the update so the stored
        ## parameters actually stay inside [-1, 1].
        if config["weight_clip"]:
            with torch.no_grad():
                for param in net.parameters():
                    param.clamp_(-1, 1)

        if scheduler is not None:
            scheduler.step()

        # Normalize by the real number of simulated steps (dim 0 of spk_rec).
        loss_accum.append(loss.item() / spk_rec.size(0))
        lr_accum.append(optimizer.param_groups[0]["lr"])

    return loss_accum, lr_accum

def test(config, net, testloader, device="cpu"):
    """
    Calculate accuracy on the full test set.

    Args:
        config: Unused; kept for signature compatibility with ``train``.
        net: Network returning output spikes of shape [T, B, num_classes]
            (a ``(spikes, ...)`` tuple is also accepted). Left in eval mode.
        testloader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to.

    Returns:
        Accuracy in percent, weighted by batch size.
    """
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # Net.forward returns the spike tensor itself, not a tuple.
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            accuracy = SF.accuracy_rate(outputs, labels)
            batch = labels.size(0)
            total += batch
            correct += accuracy * batch

    return 100 * correct / total

# History buffers for loss / learning rate (not filled by the visible cells).
loss_list, lr_list = [], []

## Load model instead of training
load_model = True  # set to False to keep the freshly initialised weights
if load_model:
    # weights_only=False lets torch.load unpickle arbitrary objects: only use trusted checkpoints.
    checkpoint = torch.load(
        '../models/fenrir_dvsgesture_FenrirFC_best.pth',
        map_location=device,
        weights_only=False,
    )
    model_state_dict = checkpoint['model_state_dict']
    net.load_state_dict(model_state_dict)

In [None]:
fig, ax = plt.subplots(3, 1, figsize=(12, 6))
# NOTE(review): data_it / targets_it are not used by the cells visible here.
iter_test = iter(testloader)
data_it, targets_it = next(iter_test)

dataset = testloader.dataset
num_samples = len(dataset)
ran_idx = torch.randint(0, num_samples, (3,))

for i, idx in enumerate(ran_idx):
    # Index with a plain int: MemoryCachedDataset caches by index, and a tensor
    # index hashes by identity, so it would never hit (and keep growing) the cache.
    spike_data, target = dataset[int(idx)]
    spike_data = spike_data.to(device).unsqueeze(0)

    print(spike_data.shape)

    # Forward pass (inference only, no autograd graph needed)
    with torch.no_grad():
        spk_rec = net(spike_data)

    # Predicted class = output neuron with the most spikes over all time steps
    pred = torch.argmax(spk_rec.sum(dim=0).squeeze()).item()

    # Plot: one row per time step, one column per output neuron
    splt.raster(spk_rec[:, 0].view(spk_rec.shape[0], -1), ax[i], s=25, c="black")
    ax[i].set_yticks(np.arange(0, num_classes, 1))
    ax[i].set_title(f"Prediction: {pred}, Target: {target}")

plt.subplots_adjust(hspace=0.5)

# The statistics below refer to the LAST sample plotted above.
spike_sums = spk_rec.sum(dim=0).squeeze()  # shape: [num_classes]
spike_avg = spike_sums.mean()

print(f"Spike counts per neuron: {spike_sums.tolist()}")
print(f"Average spikes per neuron: {spike_avg.tolist()}")

In [None]:
# Export layer-1 quantized weights for the hardware design, then report the
# values needed to configure that layer.
# NOTE(review): the meaning of the 32 argument comes from pyfenrir.export_weights — confirm.
fenrir.export_weights(net.fc1, 32, 60*60, '../../src/design_sources/data/FenrirFC_fc1.data')

quant_scale = net.fc1.quant_weight().scale
beta = net.fc1_beta / quant_scale  # leak expressed in units of the weight quantization scale
thr = fenrir.get_threshold(net.fc1, net.lif1)

print(f"Quant Scale: {quant_scale}")
print(f"Beta: {beta}")
print(f"Threshold: {thr}")

In [None]:
# Export layer-2 (output) quantized weights for the hardware design, then report
# the values needed to configure that layer.
# NOTE(review): the meaning of the 44 argument comes from pyfenrir.export_weights — confirm.
fenrir.export_weights(net.fc2, 44, 64, '../../src/design_sources/data/FenrirFC_fc2.data')

quant_scale = net.fc2.quant_weight().scale
beta = net.fc2_beta / quant_scale  # leak expressed in units of the weight quantization scale
thr = fenrir.get_threshold(net.fc2, net.lif2)

print(f"Quant Scale: {quant_scale}")
print(f"Beta: {beta}")
print(f"Threshold: {thr}")

## Label mapping

Check whether the targets printed by the cells above are 0- or 1-based before matching them to this list.

- 1: hand_clapping
- 2: right_hand_wave
- 3: left_hand_wave
- 4: right_hand_clockwise
- 5: right_hand_counter_clockwise
- 6: left_hand_clockwise
- 7: left_hand_counter_clockwise
- 8: forearm_roll_forward / forearm_roll_backwards (both directions share label 8)
- 9: drums
- 10: guitar
- 11: random_other_gestures

In [None]:
## Select a datapoint from the dataset
dataset = testloader.dataset
spike_data, target_tmp = dataset[1]

## Add a batch dimension and move to the model's device -> [1, T, C, H, W]
spike_data = spike_data.unsqueeze(0).to(device)

## Forward pass (inference only, no autograd graph needed)
with torch.no_grad():
    spk_rec1 = net(spike_data)
pred = torch.argmax(spk_rec1.sum(dim=0).squeeze()).item()

fig, ax = plt.subplots(2, 1, figsize=(12, 6))
# Input raster: one row per time step (taken from the data), pixels flattened.
splt.raster(spike_data[0, :, :].view(spike_data.shape[1], -1), ax[0], s=1, c="black")
# Output raster: one row per time step, one column per output neuron.
splt.raster(spk_rec1[:, 0].view(spk_rec1.shape[0], -1), ax[1], s=25, c="black")
ax[0].set_title(f"Prediction: {pred}, Target: {target_tmp}")

## Detach as we will use these for comparison with the testbench
spk_rec1 = spk_rec1.cpu().detach().numpy()

In [None]:
# Total spikes over all time steps for every output neuron of the sample above.
# spk_rec1 is [T, 1, num_classes]; iterate over all neurons instead of a fixed 10.
for i in range(spk_rec1.shape[2]):
    nrn_sum = spk_rec1[:, 0, i].sum()
    print(f"nrn{i}_sum = {nrn_sum}")

In [None]:
# Inspect the batched input; given the transform above this should be
# [1, n_timesteps, 1, 60, 60] — confirm before exporting events below.
spike_data.shape

In [None]:
# Flatten the sample to [T, H*W]: one row per time step, one column per pixel index.
exp_data = spike_data[0, :, 0, :, :].reshape(spike_data.shape[1], -1)
tsteps = exp_data.shape[0]  # number of time steps comes from the data itself
events = []
tstep_event_idx = []

# End-of-time-step marker: 13-bit word with only the MSB set. Pixel indices must stay below it.
TSTEP_MARKER = 0b1000000000000

for t in range(tsteps):
    active = (exp_data[t] != 0).nonzero(as_tuple=True)[0].tolist()
    if active and max(active) >= TSTEP_MARKER:
        raise ValueError(f"Pixel index {max(active)} collides with the time-step marker")
    events.extend(active)
    events.append(TSTEP_MARKER)  # marker
    tstep_event_idx.append(len(events))  # running event count after time step t

# Create a C header file with the events as an array.
# NOTE(review): the NMNIST_* / nmnist_events identifiers look inherited from an
# N-MNIST export; they are kept as-is because C code including this header may rely on them.
# NOTE(review): the file name says "target_1", but the exported sample is test-set
# index 1 (label `target_tmp`), which need not be class 1 — confirm.
header_file = 'gesture_data_target_1.h'

# One hex value per line, comma after every element except the last.
hex_lines = ',\n'.join(f'    0x{val:X}' for val in events)

with open(header_file, 'w', encoding='utf-8') as f:
    f.write('#ifndef NMNIST_FPGA_DATA_H\n')
    f.write('#define NMNIST_FPGA_DATA_H\n\n')

    # Write array size
    f.write(f'#define NMNIST_EVENTS_SIZE {len(events)}\n\n')

    # Write the array data
    f.write('const unsigned int nmnist_events[NMNIST_EVENTS_SIZE] = {\n')
    if hex_lines:
        f.write(hex_lines + '\n')

    f.write('};\n\n')
    f.write('#endif // NMNIST_FPGA_DATA_H\n')


In [None]:
# Flatten the sample to [T, H*W]: one row per time step, one column per pixel index.
exp_data = spike_data[0, :, 0, :, :].reshape(spike_data.shape[1], -1)
tsteps = exp_data.shape[0]  # number of time steps comes from the data itself
events = []
tstep_event_idx = []

# End-of-time-step marker: 13-bit word with only the MSB set (reserved value).
TSTEP_MARKER = 0b1000000000000

for t in range(tsteps):
    events.extend((exp_data[t] != 0).nonzero(as_tuple=True)[0].tolist())
    # Add special 13-bit marker (MSB = 1)
    events.append(TSTEP_MARKER)
    tstep_event_idx.append(len(events))  # running event count after time step t

# Format with 13 bits: MSB + 12-bit index. For the marker this yields
# '1000000000000'; for a pixel index it yields '0' followed by 12 index bits.
binary_events = []
for idx in events:
    if idx != TSTEP_MARKER and idx >= 4096:
        raise ValueError(f"Index {idx} exceeds 12-bit range")
    binary_events.append(format(idx, '013b'))

# Write to file, one 13-bit word per line
out_file = 'gesture_data.txt'
with open(out_file, 'w', encoding='utf-8') as f:
    f.writelines(b + '\n' for b in binary_events)
