# DVS Gesture Training

In [None]:
# SNN
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch.functional import quant
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

# Quantization
import brevitas.nn as qnn

# Torch
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Tonic
import tonic
from tonic import DiskCachedDataset
from tonic import MemoryCachedDataset
from tonic.transforms import Compose, ToFrame, Downsample

# Other
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import os
import sys
import pandas as pd
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
import pyfenrir as fnr

In [None]:
batch_size = 32             # Samples per mini-batch
frame_length_us = 16_600.0  # Width of one binned frame in microseconds (~60 frames per second)

# Native sensor geometry tuple that tonic provides for DVSGesture.
sensor_size = tonic.datasets.DVSGesture.sensor_size
target_size = (80, 60)      # Spatial size the events are downsampled to
n_timesteps = 180           # Frames per sample after padding / truncation

def pad_time_dimension(frames, fixed_time_steps=100):
    """
    Pad or truncate the leading (time) dimension of a frame stack to a fixed length.

    Short inputs are zero-padded at the end of the time axis; long inputs are
    cut after `fixed_time_steps` frames. Every other dimension is left as-is,
    so any rank >= 1 is accepted, e.g. [time, channels, height, width] or
    [time, height, width].

    Args:
        frames: numpy array or torch tensor whose first axis is time. A numpy
            array is converted to a float32 tensor; a tensor keeps its dtype.
        fixed_time_steps: desired length of the time axis (must be >= 0).

    Returns:
        torch.Tensor of shape [fixed_time_steps, *frames.shape[1:]].

    Raises:
        ValueError: if `fixed_time_steps` is negative.
    """
    if fixed_time_steps < 0:
        raise ValueError(f"fixed_time_steps must be non-negative, got {fixed_time_steps}")

    # Convert to tensor if input is numpy array
    if isinstance(frames, np.ndarray):
        frames = torch.tensor(frames, dtype=torch.float)

    current_time_steps = frames.shape[0]
    if current_time_steps > fixed_time_steps:
        return frames[:fixed_time_steps]
    elif current_time_steps < fixed_time_steps:
        # F.pad takes (before, after) pairs starting from the LAST dimension:
        # every non-time dim gets (0, 0) and the time dim gets (0, missing).
        missing = fixed_time_steps - current_time_steps
        pad = (0, 0) * (frames.dim() - 1) + (0, missing)
        return torch.nn.functional.pad(frames, pad)
    return frames

# Per-sample preprocessing: downsample the events, bin them into frames of
# `frame_length_us`, force exactly `n_timesteps` frames, keep polarity channel 1.
transform = Compose([
    Downsample(sensor_size=sensor_size, target_size=target_size),
    ToFrame(sensor_size=(target_size[0], target_size[1], sensor_size[2]), time_window=frame_length_us),
    transforms.Lambda(lambda x: pad_time_dimension(x, fixed_time_steps=n_timesteps)),   # Pad/truncate time dimension
    transforms.Lambda(lambda x: x[:, 1:2, :, :]),                                       # Select only ON channel
])

# Load the dataset
trainset = tonic.datasets.DVSGesture(save_to='../data', train=True, transform=transform)
testset = tonic.datasets.DVSGesture(save_to='../data', train=False, transform=transform)

# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# PyTorch just warns that it has no effect, so enable it only when CUDA exists.
use_pin_memory = torch.cuda.is_available()

# NOTE: the lambdas in `transform` are not picklable, so num_workers > 0 would
# fail on platforms that start DataLoader workers with the "spawn" method.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=use_pin_memory)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=use_pin_memory)

In [None]:
# The number of distinct label values in the training split is the class count.
num_classes = np.unique(trainset.targets).size

print(f"The DVSGesture dataset has {num_classes} target classes.")

# Pull the transformed frames of the first training sample for inspection below.
first_sample = trainset[0]
test_data = first_sample[0]
print(f"Shape: {test_data.shape}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display

# Render the frames of `test_data` (the first training sample) as an inline video.
if isinstance(test_data, torch.Tensor):
    frames_np = test_data.cpu().numpy()
else:
    frames_np = test_data

print(f"Shape of frames for video (flattened): {frames_np.shape}")
print(f"Number of frames: {frames_np.shape[0]}")

fig, ax = plt.subplots(figsize=(4, 4))
ax.set_title("Positive Events Animation")
ax.axis('off')

# One colour scale for all frames. A sample without any events has max 0,
# which would give imshow a degenerate vmin == vmax range, so fall back to 1.
vmax_val = np.max(frames_np) if frames_np.size > 0 else 1
if vmax_val <= 0:
    vmax_val = 1

# Index channel 0 to drop the channel axis for imshow
im = ax.imshow(frames_np[0, 0, :, :], cmap='gray_r', vmin=0, vmax=vmax_val)

def update(frame):
    """Show frame number `frame` (0-based) and update the frame counter in the title."""
    im.set_data(frames_np[frame, 0, :, :])
    ax.set_title(f"Frame: {frame+1}/{frames_np.shape[0]}")
    return [im]

ani = animation.FuncAnimation(
    fig,
    update,
    frames=frames_np.shape[0],
    interval=frame_length_us / 1000,  # milliseconds per frame -> real-time playback
    blit=True
)

# Render the video, then close the figure so Jupyter does not additionally
# show a static image of the first frame under the video.
display(HTML(ani.to_html5_video()))
plt.close(fig)


In [None]:
num_epochs = 10  # Number of epochs to train for; also sizes the LR schedule ("t_0")

config = {
    # Training
    "num_epochs": num_epochs,   # Number of epochs to train for (per trial)
    "batch_size": batch_size,   # Batch size (same value the DataLoaders use)
    "seed": 0,                  # Random seed

    # Data
    # NOTE(review): the dataset transform pads/truncates every sample to
    # `n_timesteps` frames, which differs from num_steps below — confirm which
    # length FenrirNet is meant to process.
    "num_steps": 100,           # Number of timesteps to encode input for
    "num_classes": num_classes, # Number of classes (counted from trainset.targets)
    "width": target_size[0],    # Downsampled sensor width
    "height": target_size[1],   # Downsampled sensor height

    # Quantization
    "fc1_bits": 4,              # Bits per weight for fc1
    "fc2_bits": 4,              # Bits per weight for fc2

    # Conv parameters
    "conv1_out": 12,            # Output channels for conv1
    "conv2_out": 5,             # Output channels for conv2
    "conv3_out": 5,             # Output channels for conv3
    "kernel_size": 3,           # Kernel size for all conv layers

    # Fully-connected parameters
    "fc1_beta": 0.1,            # Initial decay rate for lif1
    "fc2_beta": 1.0,            # Initial decay rate for lif2
    "fc1_thr": 1.0,             # Initial threshold for lif1
    "fc2_thr": 1.0,             # Initial threshold for lif2
    "fc1_multiplier": 10,       # Weight multiplier for fc1
    "fc2_multiplier": 10,       # Weight multiplier for fc2

    # Learning
    "lr": 3.0e-3,               # Initial learning rate
    "slope": 5.6,               # Slope value (k)

    # Fixed params
    "correct_rate": 0.8,        # Correct rate (loss function)
    "incorrect_rate": 0.2,      # Incorrect rate (loss function)
    "betas": (0.9, 0.999),      # Adam optimizer beta values
    # T_max of the cosine annealing scheduler. train() steps the scheduler once
    # per batch, so this must be the total number of optimizer steps for the LR
    # to actually anneal to eta_min by the end of training; a hard-coded
    # constant does not track the dataset size or batch size.
    "t_0": len(trainloader) * num_epochs,
    "eta_min": 0,               # Minimum learning rate
}

In [None]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the FENRIR network from the hyper-parameter dict and place it on `device`.
net = fnr.FenrirNet(config)
net = net.to(device)
print(f"Device: {device}")

In [None]:
# Adam over every network parameter.
optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config["betas"])

# Cosine decay of the learning rate from config["lr"] towards config["eta_min"]
# over config["t_0"] scheduler steps (train() calls scheduler.step() per batch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config["t_0"], eta_min=config["eta_min"], last_epoch=-1,
)

# snntorch spike-count MSE loss; correct/incorrect_rate are the target firing
# rates for the true class and for all other classes.
criterion = SF.mse_count_loss(
    correct_rate=config["correct_rate"], incorrect_rate=config["incorrect_rate"],
)

In [None]:
def _total_grad_norm(net):
    """Return the global L2 norm over all parameter gradients of `net` (0.0 if none exist)."""
    norms = [p.grad.detach().norm(2) for p in net.parameters() if p.grad is not None]
    if not norms:
        return 0.0
    # The norm of the per-parameter norms equals sqrt(sum of squares) over all
    # gradients, and needs a single host sync instead of one .item() per parameter.
    return torch.stack(norms).norm(2).item()


def _log_grad_norms(net, grad_log_file, current_epoch, batch_idx):
    """Append per-parameter gradient norms to `grad_log_file`; print a warning on I/O errors."""
    try:
        with open(grad_log_file, 'a') as f_log:  # Open in append mode
            f_log.write(f"\nEpoch {current_epoch}, Batch {batch_idx} - Gradient Norms (Detailed):\n")
            for name, param in net.named_parameters():
                if param.grad is not None:
                    grad_norm_val = param.grad.norm().item()
                    f_log.write(f"  {name:30}: {grad_norm_val:.6f}\n")
                else:
                    f_log.write(f"  {name:30}: No gradient\n")
            f_log.write("-" * 50 + "\n")
    except OSError as e:
        # Logging is best-effort: a failed write must not abort training.
        print(f"Warning: Could not write to gradient log file {grad_log_file}. Error: {e}")


def train(config, net, trainloader, criterion, optimizer, device="cpu", scheduler=None, current_epoch=0, grad_log_file=None):
    """
    Complete one epoch of training.

    For each batch: forward pass, loss, backward pass, optimizer step and, if a
    scheduler is given, one scheduler step. Loss, batch accuracy and the global
    gradient norm are shown on a tqdm progress bar. For the first batch of the
    epoch, per-parameter gradient norms are appended to `grad_log_file` when it
    is not None.

    Args:
        config: hyper-parameter dict; only config["num_steps"] is read here.
        net: model whose output (spk_rec) is fed to `criterion` and SF.accuracy_rate.
        trainloader: iterable of (data, labels) batches.
        criterion: loss function called as criterion(spk_rec, labels).
        optimizer: torch optimizer over net's parameters.
        device: device the batches are moved to.
        scheduler: optional LR scheduler, stepped once per batch.
        current_epoch: epoch index, used for display and logging only.
        grad_log_file: optional path of the gradient-norm log file.

    Returns:
        (loss_accum, lr_accum): per-batch loss divided by config["num_steps"],
        and the learning rate of param group 0 read after each batch's
        optimizer/scheduler step.
    """
    net.train()
    loss_accum = []
    lr_accum = []

    pbar = tqdm(trainloader, leave=False, desc=f"Epoch {current_epoch} Training")

    for batch_idx, (data, labels) in enumerate(pbar):
        data, labels = data.to(device), labels.to(device)

        spk_rec = net(data)
        loss = criterion(spk_rec, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        current_loss_val = loss.item()

        # Calculate batch accuracy
        with torch.no_grad():
            batch_accuracy = SF.accuracy_rate(spk_rec, labels)
        # Normalise to a Python number in case accuracy_rate returns a tensor.
        if isinstance(batch_accuracy, torch.Tensor):
            batch_accuracy = batch_accuracy.item()
        batch_accuracy_percent = batch_accuracy * 100

        # Gradients are still populated here: zero_grad() only runs before the next backward().
        total_grad_norm = _total_grad_norm(net)

        pbar.set_postfix({
            "loss": f"{current_loss_val:.4f}",
            "acc": f"{batch_accuracy_percent:.2f}%",
            "grad_norm": f"{total_grad_norm:.2e}"
        })

        loss_accum.append(current_loss_val / config["num_steps"])
        lr_accum.append(optimizer.param_groups[0]["lr"])

        # === Gradient Logging to File ===
        if batch_idx == 0 and grad_log_file is not None:
            _log_grad_norms(net, grad_log_file, current_epoch, batch_idx)

    return loss_accum, lr_accum

def test(config, net, testloader, device="cpu"):
    """
    Calculate accuracy on full test set.

    Args:
        config: not read here; kept so the signature mirrors train().
        net: model to evaluate; it is switched to eval mode.
        testloader: iterable of (frames, labels) batches.
        device: device the batches are moved to.

    Returns:
        Sample-weighted accuracy in percent, as a Python float.

    Raises:
        ValueError: if `testloader` yields no samples.
    """
    net.eval()
    correct = 0.0
    total = 0
    with torch.no_grad():
        for frames, labels in testloader:
            frames, labels = frames.to(device), labels.to(device)
            outputs = net(frames)
            batch_accuracy = SF.accuracy_rate(outputs, labels)
            # Normalise to a Python number (as train() does) so the result is
            # never a tensor.
            if isinstance(batch_accuracy, torch.Tensor):
                batch_accuracy = batch_accuracy.item()
            # accuracy_rate is a per-batch mean; weight it by the batch size so
            # a smaller final batch does not skew the overall accuracy.
            n_samples = labels.size(0)
            total += n_samples
            correct += float(batch_accuracy) * n_samples

    if total == 0:
        raise ValueError("testloader yielded no samples; cannot compute accuracy")
    return 100 * correct / total

loss_list = []
lr_list = []

GRADIENT_LOG_FILE = "gradient_norms_log.txt"
# Checkpoint of this gesture network (the save cell writes this same file name).
MODEL_PATH = "../models/gesture_fenrir_1.pth"

## Load model instead of training
load_model = False
if load_model:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
else:
    print("=======Training Network=======")
    # Start a fresh gradient log for this run; train() appends to it.
    with open(GRADIENT_LOG_FILE, 'w') as f_log:
        f_log.write("Gradient Norm Log\n")
        f_log.write("====================\n")

    for epoch in range(config['num_epochs']):
        loss, lr = train(config, net, trainloader, criterion, optimizer, device, scheduler,
                         current_epoch=epoch, grad_log_file=GRADIENT_LOG_FILE)
        loss_list.extend(loss)
        lr_list.extend(lr)
        # Test
        test_accuracy = test(config, net, testloader, device)
        print(f"Epoch: {epoch} \tTest Accuracy: {test_accuracy}")

    # Per-iteration loss (left axis) and learning rate (right axis).
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(loss_list, color='tab:orange')
    ax2.plot(lr_list, color='tab:blue')
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss', color='tab:orange')
    ax2.set_ylabel('Learning Rate', color='tab:blue')
    plt.show()

In [None]:
save_model = True
if save_model:
    # `model_dir` rather than `dir`, which would shadow the builtin.
    model_dir = "../models"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_dir, exist_ok=True)
    torch.save(net.state_dict(), os.path.join(model_dir, "gesture_fenrir_1.pth"))

In [None]:
# NOTE(review): the package is imported as `fnr`, not `fenrir`; check this call
# and its size arguments before re-enabling it.
#fenrir.export_weights(net.fc1, 44, 32*24, '../../src/design_sources/data/fc1_gesture.data')

# Layer fc1 parameters: weight quantization scale, beta divided by that scale,
# and the threshold returned by fnr.get_threshold. Computed under no_grad so
# the printed values are detached from autograd.
with torch.no_grad():
    fc1_weight_scale = net.fc1.quant_weight().scale  # evaluated once, reused below
    quant_scale = fc1_weight_scale
    beta        = net.fc1_beta / fc1_weight_scale
    thr         = fnr.get_threshold(net.fc1, net.lif1)
print(f"Quant Scale: {quant_scale}")
print(f"Beta: {beta}")
print(f"Threshold: {thr}")

In [None]:
# NOTE(review): this export line targets fc1 (and fc1_gesture.data) although this
# cell handles fc2, and the package is imported as `fnr`, not `fenrir`; confirm
# the intended layer, sizes and output file before re-enabling it.
#fenrir.export_weights(net.fc1, 44, 32*24, '../../src/design_sources/data/fc1_gesture.data')

# Layer fc2 parameters: weight quantization scale, beta divided by that scale,
# and the threshold returned by fnr.get_threshold. Computed under no_grad so
# the printed values are detached from autograd.
with torch.no_grad():
    fc2_weight_scale = net.fc2.quant_weight().scale  # evaluated once, reused below
    quant_scale = fc2_weight_scale
    beta        = net.fc2_beta / fc2_weight_scale
    thr         = fnr.get_threshold(net.fc2, net.lif2)
print(f"Quant Scale: {quant_scale}")
print(f"Beta: {beta}")
print(f"Threshold: {thr}")