In [None]:
# SNN
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch.functional import quant
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

# Quantization
import brevitas.nn as qnn

# Torch
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Tonic
import tonic
from tonic import DiskCachedDataset
from tonic import MemoryCachedDataset
from tonic.transforms import Compose, ToFrame, Downsample

# Other
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import os
import sys
import pandas as pd
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
import pyfenrir as fenrir

In [None]:
class NetUtils():
    """Static helpers that post-process membrane-potential tensors."""

    @staticmethod
    def beta_clamp(mem, beta):
        """
        Soft-clamping of beta to allow gradients.

        Pulls every element of `mem` towards 0 by |beta| without crossing 0:
          * mem > 0  ->  max(mem - |beta|, 0)
          * mem < 0  ->  min(mem + |beta|, 0)
          * mem == 0 ->  mem (unchanged)
        i.e. a constant-magnitude leak instead of a multiplicative decay.

        Args:
            mem: membrane-potential tensor.
            beta: leak magnitude (tensor); only its absolute value is used.

        Returns:
            The leaked membrane tensor (broadcast of `mem` and `beta`).
        """
        beta_abs = torch.abs(beta)
        # Positive side: approximate clamp(mem - beta_abs, min=0)
        pos_mask = (mem > 0)
        pos_val = F.relu(mem - beta_abs)  # ReLU is differentiable everywhere except 0 (and better than clamp)

        # Negative side: approximate clamp(mem + beta_abs, max=0)
        neg_mask = (mem < 0)
        neg_val = -F.relu(-(mem + beta_abs))  # negative ReLU for negative side

        mem_new = torch.where(pos_mask, pos_val, mem)
        mem_new = torch.where(neg_mask, neg_val, mem_new)

        return mem_new

    @staticmethod
    def mem_clamp(mem, scale, multiplier, bits=12):
        """Clamp `mem` to the range of a signed `bits`-bit integer, rescaled.

        The integer limits [-2**(bits-1), 2**(bits-1) - 1] (presumably the
        hardware's signed membrane register — confirm) are mapped into the
        real-valued domain by multiplying with ``scale / multiplier``.

        The lower bound used to be ``-(2**(bits-1)) - 1`` (e.g. -2049 for
        12 bits), one step below the smallest representable value.

        Args:
            mem: membrane-potential tensor.
            scale: quantization scale applied to the integer limits.
            multiplier: divisor applied to the integer limits.
            bits: register width in bits (default 12).

        Returns:
            ``torch.clamp(mem, min_val, max_val)``.
        """
        max_int = (2 ** (bits - 1)) - 1
        min_int = -(2 ** (bits - 1))
        max_val = max_int * scale / multiplier
        min_val = min_int * scale / multiplier
        return torch.clamp(mem, min=min_val, max=max_val)

In [None]:
def plot_mem_spk(mem, spk, thr_line=False, vline=False, title=False, ylim_max2=1.25,
                 save_path="/mnt/c/home/temp/lif.pdf"):
    """Plot a membrane-potential trace above its output-spike raster.

    Args:
        mem: membrane potential per time step; plotted with ``Axes.plot`` and
            its ``len`` sets the x-range of both panels.
        spk: output spikes per time step, drawn with ``spikeplot.raster``.
        thr_line: if truthy, y-value of a dashed threshold line (top panel).
        vline: if truthy, x-position of a faint vertical marker (bottom panel).
        title: if truthy, title of the top panel.
        ylim_max2: upper y-limit of the membrane panel (lower limit is 0).
        save_path: file the figure is written to; ``None`` skips saving.
            Defaults to the previously hard-coded location.
    """
    # Two vertically stacked panels sharing the time axis.
    fig, (ax_mem, ax_spk) = plt.subplots(2, figsize=(8, 4), sharex=True,
                                         gridspec_kw={'height_ratios': [1, 0.4]})

    # Membrane potential (top); y ticks hidden, only the shape matters.
    ax_mem.plot(mem)
    ax_mem.set_ylim([0, ylim_max2])
    ax_mem.set_ylabel("Membrane Potential ($U_{mem}$)")
    ax_mem.set_yticks([])
    if title:
        ax_mem.set_title(title)
    if thr_line:
        ax_mem.axhline(y=thr_line, alpha=0.25, linestyle="dashed", c="black", linewidth=2)

    # Output spikes (bottom). Explicit Axes methods instead of plt.ylabel /
    # plt.yticks, which silently target whichever axes is "current".
    splt.raster(spk, ax_spk, s=400, c="black", marker="|")
    ax_spk.set_ylabel("Output Spikes")
    ax_spk.set_yticks([])
    if vline:
        ax_spk.axvline(x=vline, ymin=0, ymax=1, alpha=0.15, c="black", linewidth=2, zorder=0, clip_on=False)

    ax_spk.set_xlabel("Time step")
    ax_spk.set_xlim([0, len(mem)])  # x is shared, so this limits both panels
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

In [None]:
import matplotlib
matplotlib.rcParams['axes.linewidth'] = 3

# Global text style: LaTeX rendering (needs a TeX installation) in Helvetica.
# These rcParams persist for every later cell in the kernel.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "Helvetica"
})

# Left panel: spike function S(U), a Heaviside step at the threshold.
# Right panel: its derivative dS/dU, a Dirac delta drawn as an arrow.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
heaviside_ax, dirac_ax = ax

x_heaviside = np.linspace(-10, 10, 1000)
y_heaviside = (x_heaviside >= 0).astype(int)

# Discrete delta: one unit sample at the grid point nearest 0 (this
# even-length grid does not contain 0 exactly).
x_dirac = x_heaviside.copy()
y_dirac = np.zeros_like(x_dirac)
zero_index = np.abs(x_dirac).argmin()
y_dirac[zero_index] = 1

heaviside_ax.plot(x_heaviside, y_heaviside, linewidth=4, label='Heaviside')
heaviside_ax.set_ylabel('$S$', loc='top', rotation=0, fontsize=45)
heaviside_ax.yaxis.set_label_coords(0.1, 0.75)

dirac_ax.plot(x_dirac, y_dirac, linewidth=4, color='tab:orange', label='Dirac')
dirac_ax.arrow(0, 0, 0, 1, head_width=1.0, head_length=0.15, fc='tab:orange', ec='tab:orange', width = 0.1)
dirac_ax.set_ylabel(r'$\frac{\partial S}{\partial U}$', loc='top', rotation=0, fontsize=45)
dirac_ax.yaxis.set_label_coords(0.14, 0.7)

# Shared "textbook" axes: only x = theta is labelled, top/right spines are
# hidden and arrowheads are drawn at the ends of the remaining two axes.
for a in ax:
    a.set_yticks([])
    a.set_xticks([0])
    a.set_xticklabels([r'$\theta$'], fontsize=45)
    a.set_xlim(-10, 10)
    a.set_ylim(-0.1, 1.2)

    for side in ('right', 'top'):
        a.spines[side].set_color('none')

    # x-axis arrowhead at (axes x=1, data y=-0.1); y-axis arrowhead at (data x=-10, axes y=1).
    a.plot(1, -0.1, ">k", transform=a.get_yaxis_transform(), clip_on=False, markersize=15)
    a.plot(-10, 1, "^k", transform=a.get_xaxis_transform(), clip_on=False, markersize=15)

    a.set_xlabel('$U$', loc='right', fontsize=45)
    a.xaxis.set_label_coords(0.97, -0.04)
    a.legend(loc='lower right', fontsize=18, framealpha=1.0)

fig.tight_layout()
fig.savefig("/mnt/c/home/temp/spike_diff.pdf")

In [None]:
def leaky_integrate_and_fire(mem, cur=0, threshold=1, time_step=1e-3, R=5.1, C=5e-3):
    """Advance a hand-written LIF neuron by one forward-Euler step.

    Args:
        mem: membrane potential before the step.
        cur: input current during the step.
        threshold: firing threshold; also the amount subtracted on a spike.
        time_step: integration step (same time unit as R * C).
        R: membrane resistance.
        C: membrane capacitance (tau_mem = R * C).

    Returns:
        (mem, spk): the updated potential and the spike flag ``mem > threshold``
        evaluated on the potential *before* this step's update.
    """
    tau_mem = R * C
    spk = mem > threshold
    # Euler increment of tau * dU/dt = -U + I * R.
    drive = (time_step / tau_mem) * (cur * R - mem)
    # Reset by subtraction: each spike removes exactly one threshold.
    mem = mem + drive - spk * threshold
    return mem, spk

In [None]:
# Step-current input: 10 silent steps, then a constant 0.3 for the rest of the
# run. The input length is derived from num_steps so the two cannot drift apart
# (it used to be a hard-coded 10 + 190 that only matched by coincidence).
num_steps = 200
cur_in = torch.cat((torch.zeros(10), torch.ones(num_steps - 10) * 0.3), 0)
mem = torch.zeros(1)
mem_rec = []
spk_rec = []

# neuron simulation
for step in range(num_steps):
    mem, spk = leaky_integrate_and_fire(mem, cur_in[step])
    mem_rec.append(mem)
    spk_rec.append(spk)

# convert lists to tensors: each record stacks to shape (num_steps, 1)
mem_rec = torch.stack(mem_rec)
spk_rec = torch.stack(spk_rec)

plot_mem_spk(mem_rec, spk_rec, thr_line=1, title="Leaky Integrate-and-Fire")

In [None]:
num_steps = 100
threshold = 2.0
beta = torch.tensor(0.01)  # leak magnitude passed to NetUtils.beta_clamp each step

# snnTorch neuron with beta=1.0 (no built-in decay); the leak is applied by
# NetUtils.beta_clamp after every step, which moves mem towards 0 by |beta|.
lif     = snn.Leaky(beta=1.0, threshold=threshold, learn_threshold=False, reset_mechanism='zero', reset_delay=False)
mem     = lif.init_leaky()

# Seeded random input: `num_spikes` unit spikes at distinct time steps.
seed        = 0
num_spikes  = 20

spk_in = torch.zeros(num_steps)
torch.manual_seed(seed)
all_indices = torch.randperm(num_steps)
spike_indices = all_indices[:num_spikes]
spk_in[spike_indices] = 1.0

mem_rec = []
spk_rec = []
for step in range(num_steps):
    cur_in = spk_in[step]
    spk, mem = lif(cur_in, mem)
    mem = NetUtils.beta_clamp(mem, beta)
    spk_rec.append(spk)
    mem_rec.append(mem)

spk_out = torch.stack(spk_rec)
mem_out = torch.stack(mem_rec)

fig, ax = plt.subplots(3, figsize=(8, 6), sharex=True, gridspec_kw={'height_ratios': [0.4, 1, 0.4]})
ax_in, ax_mem, ax_out = ax

# Top and bottom panels: input and output spike rasters.
for a, spikes, label in ((ax_in, spk_in, "Input Spikes"), (ax_out, spk_out, "Output Spikes")):
    splt.raster(spikes, a, s=400, c="black", marker="|")
    a.set_ylabel(label)
    a.set_yticks([])

# Middle panel: membrane potential with the firing threshold dashed.
ax_mem.plot(mem_out)
ax_mem.set_ylim([0, threshold + 0.25])
ax_mem.set_ylabel("Membrane Potential ($U_{mem}$)")
ax_mem.set_yticks([])
ax_mem.axhline(y=threshold, alpha=0.25, linestyle="dashed", c="black", linewidth=2)

ax_out.set_xlabel("Time step")
ax_out.set_xlim([0, len(mem_out)])

fig.tight_layout()
fig.savefig("/mnt/c/home/temp/simple_lif.pdf")
plt.show()

In [None]:
# Native DVS Gesture sensor geometry, as reported by tonic.
sensor_size = tonic.datasets.DVSGesture.sensor_size
# Frame length in microseconds: 50 * 10e3 = 500 000 us = 0.5 s per frame.
frame_length_us = 50 * 10e3
target_size = (128, 128)   # spatial size passed to Downsample
n_timesteps = 200          # frames per sample after pad/truncate

def pad_time_dimension(frames, fixed_time_steps=100):
    """
    Pad or truncate the time dimension of frames to a fixed number of time steps.

    Input: frames [time, ...] (numpy or tensor), e.g. [time, channels, height, width].
    Output: tensor [fixed_time_steps, ...] -- the leading axis is truncated or
        zero-padded at the end; all trailing axes are left unchanged.

    NumPy input is converted to a float32 tensor; tensor input keeps its dtype.
    If the length already matches, the (converted) input is returned as-is.
    """
    # Convert to tensor if input is numpy array
    if isinstance(frames, np.ndarray):
        frames = torch.tensor(frames, dtype=torch.float)
    current_time_steps = frames.shape[0]
    if current_time_steps > fixed_time_steps:
        return frames[:fixed_time_steps]
    if current_time_steps < fixed_time_steps:
        # F.pad takes (left, right) pairs starting from the LAST axis: leave
        # every trailing axis untouched and append zeros after the first axis.
        # Built from frames.dim() so any rank >= 1 works (a hard-coded 8-tuple
        # only accepted 4-D input).
        pad = [0, 0] * (frames.dim() - 1) + [0, fixed_time_steps - current_time_steps]
        return torch.nn.functional.pad(frames, pad)
    return frames

# Event-stream -> dense-frame pipeline; runs on every dataset access.
transform = Compose([
    # Rescale event coordinates from the native sensor to `target_size`.
    Downsample(sensor_size=sensor_size, target_size=target_size),
    # Bin events into frames of `frame_length_us`; channel count comes from sensor_size[2].
    ToFrame(sensor_size=(target_size[0], target_size[1], sensor_size[2]), time_window=frame_length_us),
    transforms.Lambda(lambda x: pad_time_dimension(x, fixed_time_steps=n_timesteps)),   # Pad/truncate time dimension
    # Optional: clamp spikes accumulated over time to (0,1)
    # transforms.Lambda(lambda x: torch.clamp(torch.tensor(x), 0, 1).type(torch.float)),
    # Identity slice: keeps ALL polarity channels. (Previously commented as
    # "Select only ON channel", which it does not do. For a single polarity use
    # x[:, 0:1] or x[:, 1:2] -- TODO confirm which index is ON before switching.)
    transforms.Lambda(lambda x: x[:, :, :, :]  ),
])

# Load the dataset
trainset = tonic.datasets.DVSGesture(save_to='../data', train=True, transform=transform)
testset = tonic.datasets.DVSGesture(save_to='../data', train=False, transform=transform)

In [None]:
from collections import Counter

def _dataset_labels(dataset):
    """Return the label of every sample in `dataset`, in index order.

    Indexing a sample runs the full event->frame `transform` (Downsample,
    ToFrame, padding), which is expensive when only the label is needed. If the
    dataset exposes a `targets` sequence of matching length and has no
    `target_transform`, labels are read from it directly; otherwise every
    sample is indexed as before.

    NOTE(review): assumes tonic returns `targets[i]` unchanged as the label of
    sample i when no target_transform is set -- confirm for the installed tonic.
    """
    targets = getattr(dataset, "targets", None)
    if (targets is not None
            and getattr(dataset, "target_transform", None) is None
            and len(targets) == len(dataset)):
        return list(targets)
    return [sample[1] for sample in dataset]

# Labels from the training and testing sets
train_labels = _dataset_labels(trainset)
test_labels = _dataset_labels(testset)

# Combine the labels from both sets
all_labels = train_labels + test_labels

# Count the occurrences of each class
class_counts = Counter(all_labels)

class_names = [
    "hand_clapping",
    "right_hand_wave",
    "left_hand_wave",
    "right_arm_clockwise",
    "right_arm_counter_clockwise",
    "left_arm_clockwise",
    "left_arm_counter_clockwise",
    "arm_roll",
    "air_drums",
    "air_guitar",
    "other_gestures",
]

# Print the total amount of datapoints for each class
print("Total amount of datapoints for each class:")
for class_label, count in sorted(class_counts.items()):
    class_name = class_names[class_label]
    print(f"Class {class_label} ({class_name}): {count} datapoints")

# Also, let's get the total number of datapoints
total_datapoints = len(all_labels)
print(f"\nTotal number of datapoints: {total_datapoints}")

In [None]:
from scipy.ndimage import convolve
from skimage.measure import block_reduce

# One panel per (sample index, time-bin index) pair.
samples     = [1, 3, 6, 10]
timesteps   = [3, 0, 0, 0]

# Channel 0 of the chosen time bin for each sample, plus its label.
frames, labels = [], []
for sample, timestep in zip(samples, timesteps):
    data, label = trainset[sample]
    frames.append(data[timestep, 0].numpy())
    labels.append(label)

fig, ax = plt.subplots(2, 2, figsize=(8, 8))

for a, img, lbl in zip(ax.flatten(), frames, labels):
    a.imshow(img, cmap='jet')
    a.set_title(f'{trainset.classes[lbl]}')
    a.set(xticks=[], yticks=[])

fig.tight_layout()
fig.savefig("/mnt/c/home/temp/dvsgesture_example.pdf")
plt.show()

In [None]:
from scipy.ndimage import convolve
from skimage.measure import block_reduce

sample = 4
t1 = 10
t2 = 20
t3 = 30

data, label = trainset[sample]
frame1 = data[t1, 0, :, :].numpy()

# Simulated sensor noise: switch on a random `noise_level` fraction of the
# currently inactive pixels. Seeded so the figure is reproducible across runs.
noise_level = 0.05
noise_rng = np.random.default_rng(0)
noisy_frame = frame1.copy()
noise_mask = noise_rng.random(frame1.shape) < noise_level
noisy_frame[noise_mask & (frame1 == 0)] = 1

# Spatial subsampling: max over non-overlapping factor x factor blocks.
subsample_factor = 2
subsampled_frame = block_reduce(frame1,
                                block_size=(subsample_factor, subsample_factor),
                                func=np.max)

# "Time binning" illustration: saturate accumulated counts to 0/1.
timebin_frame = np.clip(subsampled_frame, 0, 1)

fig, ax = plt.subplots(1, 4, figsize=(16, 4))

ax[0].imshow(frame1, cmap='gray_r')
ax[0].set_title('Original Frame')

ax[1].imshow(noisy_frame, cmap='gray_r')
ax[1].set_title(f'With Added Noise ({int(noise_level*100)}%)')

# Report the real subsampled size (a hard-coded "32x32" did not match the
# frames produced by this pipeline).
sub_h, sub_w = subsampled_frame.shape
ax[2].imshow(subsampled_frame, cmap='gray_r')
ax[2].set_title(f'Subsampled by {subsample_factor}x ({sub_h}x{sub_w})')

ax[3].imshow(timebin_frame, cmap='gray_r')
ax[3].set_title('Timebinned')

# Iterate with a separate name so the `ax` array is not shadowed.
for a in ax.flat:
    a.set_xticks([])
    a.set_yticks([])

fig.tight_layout()
plt.show()

In [None]:
import matplotlib as mpl

# Preprocessing stages in pipeline order; the noisy frame stands in for the
# raw sensor output and the original frame for the denoised one.
frames = [noisy_frame, frame1, subsampled_frame, timebin_frame]
save_names = ['frame_raw', 'frame_denoised', 'frame_subsampled', 'frame_timebinned']

# Thick borders around each exported image.
mpl.rcParams['axes.linewidth'] = 3

# One bordered, tick-less PNG per stage.
for f, name in zip(frames, save_names):
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(f, cmap='gray_r')
    ax.set(xticks=[], yticks=[])
    fig.tight_layout()
    fig.savefig("/mnt/c/home/temp/" + name + ".png")

In [None]:
sample = 4
t1 = 10
t2 = 20
t3 = 30

data, label = trainset[sample]
frame1 = data[t1, 0, :, :].numpy()
frame2 = data[t2, 0, :, :].numpy()
frame3 = data[t3, 0, :, :].numpy()

# --- Plotting ---
fig, ax = plt.subplots(1, 3, figsize=(8, 3))

# Title each panel with the time-bin index actually shown. The old title
# `10*i` produced 0/10/20 while the panels display t1/t2/t3 = 10/20/30.
for a, frame, t in zip(ax, (frame1, frame2, frame3), (t1, t2, t3)):
    a.imshow(frame, cmap='gray_r')
    a.set_title(f"Timestep: {t}")
    a.set_xticks([])
    a.set_yticks([])

fig.suptitle(f'DVS Gesture, Label: {trainset.classes[label]}')
fig.tight_layout()

fig.savefig("/mnt/c/home/temp/dvs_wave.pdf")

In [None]:
# --- 1. Load the per-step 'loss' column for each bit-width ---
try:
    loss_8b = pd.read_csv('../test_data/8b_fixed_loss_lr_data.csv')['loss']
    loss_4b = pd.read_csv('../test_data/4b_fixed_loss_lr_data.csv')['loss']
    loss_2b = pd.read_csv('../test_data/2b_fixed_loss_lr_data.csv')['loss']
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Please make sure the file paths are correct and the CSV files exist.")
    # Re-raise instead of exit(): in a Jupyter kernel exit() shuts the kernel
    # down (losing all state), while an exception only aborts this cell.
    raise

# --- 2. Define the Rolling Window Boundaries ---
window_size = 50

def get_rolling_stats(loss_series, window):
    """Rolling mean, min and max of `loss_series` over `window` consecutive entries.

    The first ``window - 1`` values of each returned Series are NaN because
    the window is not yet full.
    """
    rolling = loss_series.rolling(window=window)
    return rolling.mean(), rolling.min(), rolling.max()

# (mean, min, max, title) per panel, in plotting order.
data_to_plot = [
    (*get_rolling_stats(loss, window_size), title)
    for loss, title in (
        (loss_8b, '8-bit Quantization (256 kB)'),
        (loss_4b, '4-bit Quantization (128 kB)'),
        (loss_2b, '2-bit Quantization (64 kB)'),
    )
]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

fig, axes = plt.subplots(1, 3, figsize=(11, 4), sharey=True)

for ax, color, (mean, min_val, max_val, title) in zip(axes, colors, data_to_plot):
    # Shaded band = rolling min..max, solid line = rolling mean.
    ax.fill_between(mean.index, min_val, max_val, color=color, alpha=0.25)
    ax.plot(mean, color=color, lw=2)
    ax.set_title(title, fontsize=12)
    ax.xaxis.set_visible(False)
    ax.grid(True, linestyle='--', alpha=0.6)
    # Start at the first complete window, index window_size - 1 with the
    # default RangeIndex from read_csv (window_size hid that first value).
    ax.set_xlim(left=window_size - 1)

axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_ylim(0, 0.06)

fig.tight_layout()
# Save before showing so the written file never depends on backend show/close behaviour.
fig.savefig("/mnt/c/home/temp/quant_comp.pdf")
plt.show()

In [None]:
# --- 1. Load the per-step 'loss' column for each bit-width ---
try:
    loss_8b = pd.read_csv('../test_data/8b_loss_lr_data.csv')['loss']
    loss_4b = pd.read_csv('../test_data/4b_loss_lr_data.csv')['loss']
    loss_2b = pd.read_csv('../test_data/2b_loss_lr_data.csv')['loss']
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Please make sure the file paths are correct and the CSV files exist.")
    # Re-raise instead of exit(): in a Jupyter kernel exit() shuts the kernel
    # down (losing all state), while an exception only aborts this cell.
    raise

# --- 2. Define the Rolling Window Boundaries ---
window_size = 50

def get_rolling_stats(loss_series, window):
    """Rolling mean, min and max of `loss_series` over `window` consecutive entries.

    The first ``window - 1`` values of each returned Series are NaN because
    the window is not yet full.
    """
    rolling = loss_series.rolling(window=window)
    return rolling.mean(), rolling.min(), rolling.max()

# (mean, min, max, title) per panel, in plotting order.
data_to_plot = [
    (*get_rolling_stats(loss, window_size), title)
    for loss, title in (
        (loss_8b, '8-bit Quantization (128 kB)'),
        (loss_4b, '4-bit Quantization (128 kB)'),
        (loss_2b, '2-bit Quantization (128 kB)'),
    )
]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

fig, axes = plt.subplots(1, 3, figsize=(11, 4), sharey=True)

for ax, color, (mean, min_val, max_val, title) in zip(axes, colors, data_to_plot):
    # Shaded band = rolling min..max, solid line = rolling mean.
    ax.fill_between(mean.index, min_val, max_val, color=color, alpha=0.25)
    ax.plot(mean, color=color, lw=2)
    ax.set_title(title, fontsize=12)
    ax.xaxis.set_visible(False)
    ax.grid(True, linestyle='--', alpha=0.6)
    # Start at the first complete window, index window_size - 1 with the
    # default RangeIndex from read_csv (window_size hid that first value).
    ax.set_xlim(left=window_size - 1)

axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_ylim(0, 0.06)

fig.tight_layout()
# Save before showing so the written file never depends on backend show/close behaviour.
fig.savefig("/mnt/c/home/temp/quant_comp_same_mem.pdf")
plt.show()