# Quantized SNN for MNIST

## Imports

In [None]:
# SNN
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch.functional import quant
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

# Quantization
import brevitas.nn as qnn

# Torch
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Tonic
from tonic import DiskCachedDataset
from tonic import MemoryCachedDataset

# Other
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

import pyfenrir as fenrir

## 1. Setting up MNIST

In [None]:
# Global data settings shared by the dataset and dataloader cells below.
num_class  = 10                # number of digit classes
data_path  = './data/mnist'    # root folder handed to datasets.MNIST
batch_size = 128               # samples per mini-batch in both dataloaders

### 1.1 Download Dataset

In [None]:
# Preprocessing pipeline: upscale to 32x32, force a single channel and convert
# to a tensor. Normalising with mean 0 / std 1 leaves the values unchanged.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Train and test splits share the same root folder and preprocessing.
mnist_train, mnist_test = (
    datasets.MNIST(data_path, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Sanity check: show the first three training samples with their labels.
fig, axs = plt.subplots(1, 3, figsize=(9, 3))
for i, axis in enumerate(axs):
    image, label = mnist_train[i]
    axis.imshow(image.squeeze(0).numpy(), cmap="gray")
    axis.set_title(f"Label: {label}")

### 1.2 Dataloader and Spike Encoding

In [None]:
# Wrap both splits in tonic's MemoryCachedDataset before batching them.
cached_mnist_train = MemoryCachedDataset(mnist_train)
cached_mnist_test = MemoryCachedDataset(mnist_test)

# Only the training split is shuffled.
train_loader = DataLoader(
    cached_mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    cached_mnist_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# --- Everything below only demonstrates how a batch is encoded into spikes ---

# Pull a single batch out of the training loader
data = iter(train_loader)
data_it, targets_it = next(data)

# Latency-encode the batch over 100 time steps.
#   - linear, normalize and clip are IMPORTANT for spikegen.latency,
#     details here: https://snntorch.readthedocs.io/en/latest/tutorials/tutorial_1.html
spike_data = spikegen.latency(
    data_it,
    num_steps=100,
    tau=5,
    threshold=0.01,
    linear=True,
    normalize=True,
    clip=True,
)

# Raster plot of the first sample in the batch (pixels flattened per time step)
fig = plt.figure(facecolor="w", figsize=(12, 3))
ax = fig.add_subplot(111)
first_sample = spike_data[:, 0].view(100, -1)
splt.raster(first_sample, ax, s=25, c="black")
ax.set_title("Input Layer")
ax.set_xlabel("Time step")
ax.set_ylabel("Neuron Number")
plt.show()

### 1.3 Network Definition

In [None]:
# Hyper-parameters for the network, optimizer and scheduler.
# NOTE(review): in the visible cells only num_epochs, num_bits, threshold,
# slope, batch_norm, num_steps, grad_clip, weight_clip, lr, betas, t_0,
# eta_min, correct_rate and incorrect_rate are read; batch_size, seed,
# dropout and beta are never looked up (the dataloaders use the global
# `batch_size`, and no seed is applied) -- confirm whether that is intended.
config = {
    "num_epochs": 2,       # Number of passes over the training set
    "batch_size": 128,      # Batch size (duplicates the global `batch_size`)
    "seed": 0,              # Random seed (not applied in the visible code)
    
    # Quantization
    "num_bits": 4,          # Weight bit width passed to QuantLinear
    
    # Network parameters
    "grad_clip": False,     # Whether or not to clip the gradient norm to 1.0
    "weight_clip": False,   # Whether or not to clamp parameters to [-1, 1]
    "batch_norm": True,     # Stored on the network; no batch-norm layer is visible
    "dropout": 0.07,        # Dropout rate (not read in the visible code)
    "beta": 1.0,           # Decay rate parameter (beta) (not read in the visible code)
    "threshold": 10,        # Threshold parameter (theta); stored on the network as `thr`
    "lr": 3.0e-3,           # Initial learning rate
    "slope": 5.6,           # Slope value (k); stored on the network as `slope`
    
    # Fixed params
    "num_steps": 100,       # Number of timesteps the network iterates over
    "correct_rate": 0.8,    # Correct rate for the MSE count loss
    "incorrect_rate": 0.2,  # Incorrect rate for the MSE count loss
    "betas": (0.9, 0.999),  # Adam optimizer beta values
    "t_0": 4690,            # Passed as T_max to CosineAnnealingLR (stepped once per batch)
    "eta_min": 0,           # Minimum learning rate
}

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training recipe: first train with a learnable beta and threshold,
# then round the learned values to integers and train again with those.

class Net(nn.Module):
    """Single-layer quantized spiking network: 32*32 inputs -> 10 leaky neurons.

    The membrane potential returned by the snnTorch neuron is additionally
    pulled towards zero by the learnable amount ``self.beta`` every time step
    (a subtractive leak that never crosses zero).

    ``thr``, ``slope`` and ``batch_norm`` are copied from the config but are
    not used by ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        # Settings copied from the config dictionary
        self.num_bits = config["num_bits"]
        self.thr = config["threshold"]
        self.slope = config["slope"]
        self.num_steps = config["num_steps"]
        self.batch_norm = config["batch_norm"]

        # Learnable size of the per-step subtractive leak
        self.beta = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

        # Quantized fully connected layer followed by the spiking neuron layer
        self.fc1 = qnn.QuantLinear(
            32 * 32, 10, bias=False, weight_bit_width=self.num_bits
        )
        self.lif1 = snn.Leaky(
            beta=1.0,
            threshold=1.0,
            learn_threshold=True,
            reset_mechanism='zero',
            reset_delay=False,
        )

    def _leak_towards_zero(self, mem):
        """Move every membrane value towards zero by ``self.beta``, stopping at zero."""
        # Positive potentials are lowered, but never below zero ...
        lowered = torch.clamp(mem - self.beta, min=0.0)
        mem = torch.where(mem > 0, lowered, mem)
        # ... and negative potentials are raised, but never above zero.
        raised = torch.clamp(mem + self.beta, max=0.0)
        return torch.where(mem < 0, raised, mem)

    def forward(self, x):
        """Run ``self.num_steps`` time steps over ``x``.

        ``x`` is indexed by time step first; each step is flattened to
        ``(x.shape[1], -1)`` before the linear layer.

        Returns the stacked output spikes and the stacked (post-leak)
        membrane potentials, one entry per time step.
        """
        # Membrane state at t=0
        mem1 = self.lif1.init_leaky()

        # Per-step recordings of the output layer
        spk_rec, mem_rec = [], []
        batch = x.shape[1]
        for step in range(self.num_steps):
            cur = self.fc1(x[step].view(batch, -1))
            spk1, mem1 = self.lif1(cur, mem1)
            mem1 = self._leak_towards_zero(mem1)

            spk_rec.append(spk1)
            mem_rec.append(mem1)

        return torch.stack(spk_rec), torch.stack(mem_rec)

# Build the network from the config and move its parameters to the selected device.
net = Net(config).to(device)

### 1.4 Define Optimizer, LR Scheduler and Loss Function

In [None]:
# Adam over every parameter of the network.
optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config["betas"])

# Cosine learning-rate decay from config["lr"] down to eta_min over T_max
# scheduler steps (last_epoch is left at its default of -1, i.e. start fresh).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config["t_0"], eta_min=config["eta_min"]
)

# Spike-count MSE loss with separate target rates for the correct class and
# for all other classes.
criterion = SF.mse_count_loss(
    correct_rate=config["correct_rate"], incorrect_rate=config["incorrect_rate"]
)

### 1.5 Training and Evaluation

In [None]:
def train(config, net, trainloader, criterion, optimizer, device="cpu", scheduler=None):
    """
    Complete one epoch of training.

    Args:
        config: dict providing "num_steps", "grad_clip" and "weight_clip".
        net: network called as ``net(spikes)`` returning ``(spk_rec, mem_rec)``.
        trainloader: iterable yielding ``(data, labels)`` batches.
        criterion: loss called as ``criterion(spk_rec, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        scheduler: optional LR scheduler, stepped once per batch when given.

    Returns:
        ``(loss_accum, lr_accum)``: per-batch loss divided by the number of
        time steps, and the learning rate of the first parameter group after
        each batch.
    """

    net.train()
    loss_accum = []
    lr_accum = []
    # Encode with the same number of time steps the network iterates over
    # (and the loss is normalised by), instead of a hard-coded 100.
    num_steps = config["num_steps"]
    for data, labels in trainloader:
        data, labels = data.to(device), labels.to(device)
        # Encode data into spikes
        data = spikegen.latency(data, num_steps=num_steps, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
        spk_rec, _ = net(data)
        loss = criterion(spk_rec, labels)
        optimizer.zero_grad()
        loss.backward()

        ## Enable gradient clipping
        if config["grad_clip"]:
            nn.utils.clip_grad_norm_(net.parameters(), 1.0)

        ## Enable weight clipping
        # NOTE(review): the clamp runs before optimizer.step(), so the values
        # written by this step are only clamped on the next iteration.
        if config["weight_clip"]:
            with torch.no_grad():
                for param in net.parameters():
                    param.clamp_(-1, 1)

        optimizer.step()
        # The scheduler is optional (default None); only step it when given.
        if scheduler is not None:
            scheduler.step()
        loss_accum.append(loss.item() / num_steps)
        lr_accum.append(optimizer.param_groups[0]["lr"])

    return loss_accum, lr_accum

def test(config, net, testloader, device="cpu"):
    """
    Calculate accuracy on full test set.

    Args:
        config: dict providing "num_steps".
        net: network called as ``net(spikes)`` returning ``(spk_rec, mem_rec)``.
        testloader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent, weighted by batch size; 0.0 when the loader
        yields no samples.
    """

    # Encode with the same number of time steps the network iterates over,
    # instead of a hard-coded 100.
    num_steps = config["num_steps"]
    correct = 0
    total = 0
    with torch.no_grad():
        net.eval()
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            images = spikegen.latency(images, num_steps=num_steps, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
            outputs, _ = net(images)
            accuracy = SF.accuracy_rate(outputs, labels)
            # accuracy is a per-batch fraction; weight it by the batch size
            total += labels.size(0)
            correct += accuracy * labels.size(0)

    # Avoid a ZeroDivisionError on an empty loader
    if total == 0:
        return 0.0
    return 100 * correct / total

# Per-batch loss and learning rate collected over all training epochs
loss_list = []
lr_list = []

## Load model instead of training
load_model = True
if load_model:
    # map_location makes a checkpoint saved on a GPU loadable on a CPU-only
    # machine (and vice versa) by remapping its tensors to `device`.
    net.load_state_dict(torch.load('models/mnist.pth', map_location=device))
else:
    print("=======Training Network=======")
    for epoch in range(config['num_epochs']):
        loss, lr = train(config, net, train_loader, criterion, optimizer, device, scheduler)
        loss_list.extend(loss)
        lr_list.extend(lr)
        # Evaluate on the full test set after every epoch
        test_accuracy = test(config, net, test_loader, device)
        print(f"Epoch: {epoch} \tTest Accuracy: {test_accuracy}")

    # Loss (left axis) and learning rate (right axis) per training iteration
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(loss_list, color='tab:orange')
    ax2.plot(lr_list, color='tab:blue')
    ax1.set_xlabel('Iteration')
    ax1.set_ylabel('Loss', color='tab:orange')
    ax2.set_ylabel('Learning Rate', color='tab:blue')
    plt.show()

In [None]:
# Set to True to write the current network weights to ./models/mnist.pth
save_model = False
if save_model:
    # `model_dir` instead of `dir`, which would shadow the builtin
    model_dir = "./models"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(model_dir, exist_ok=True)
    torch.save(net.state_dict(), f"{model_dir}/mnist.pth")

### 1.6 Network Results and Test

In [None]:
# Classify three randomly chosen test samples and show their output rasters.
fig, ax = plt.subplots(3, 1, figsize=(12, 6))

dataset = test_loader.dataset
num_samples = len(dataset)
ran_idx = torch.randint(0, num_samples, (3,))

# Inference only: no autograd graph is needed for plotting and counting.
with torch.no_grad():
    # .tolist() yields plain ints, so the dataset is indexed with hashable,
    # comparable keys rather than 0-d tensors.
    for i, idx in enumerate(ran_idx.tolist()):
        # Get some data
        data_tmp, target_tmp = dataset[idx]

        # Spike encode it
        spike_data = spikegen.latency(data_tmp, num_steps=100, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
        spike_data = spike_data.to(device)

        # Forward pass
        spk_rec, mem_rec = net(spike_data)

        # Prediction = output neuron with the most spikes over all time steps
        pred = torch.argmax(spk_rec.sum(dim=0).squeeze()).item()

        # Plot
        splt.raster(spk_rec[:, 0].view(100, -1), ax[i], s=25, c="black")
        ax[i].set_yticks(np.arange(0, 10, 1))
        ax[i].set_title(f"Prediction: {pred}, Target: {target_tmp}")

plt.subplots_adjust(hspace=0.5)

# Spike statistics of the LAST sample plotted above
spike_sums = spk_rec.sum(dim=0).squeeze()  # one spike count per output neuron
spike_avg = spike_sums.sum(dim=0) / spike_sums.numel()

print(f"Spike counts per neuron: {spike_sums.tolist()}")
print(f"Average spikes per neuron: {spike_avg.tolist()}")

### 1.7 Extract Quantized Weights

In [None]:
# Dump the quantized weights of the linear layer for FENRIR.
# NOTE(review): 1024 presumably matches the 32*32 inputs; the meaning of 40
# is defined by pyfenrir -- confirm against its documentation.
fenrir.export_weights(net.fc1, 40, 1024, 'weights.txt')

# Express the learned leak in units of the weight quantization scale and
# fetch the threshold via pyfenrir.
quant_scale = net.fc1.quant_weight().scale
beta        = net.beta / quant_scale
thr         = fenrir.get_threshold(net.fc1, net.lif1)

print(f"Quant Scale: {quant_scale}")
print(f"Threshold: {thr}")
print(f"Beta: {beta}")

## Extract Spike Encoded Number for Testbench

In [None]:
## Pick one fixed test sample to drive the testbench
dataset = test_loader.dataset
data_tmp, target_tmp = dataset[6]

## Latency-encode it and move it to the network's device
spike_data = spikegen.latency(
    data_tmp, num_steps=100, tau=5, threshold=0.01, linear=True, normalize=True, clip=True
).to(device)

## Forward pass; the prediction is the output neuron with the most spikes
spk_rec, mem_rec = net(spike_data)
pred = torch.argmax(spk_rec.sum(dim=0).squeeze()).item()

## Input raster on top, output raster below
fig, ax = plt.subplots(2, 1, figsize=(12, 6))
for axis, spikes in zip(ax, (spike_data, spk_rec)):
    splt.raster(spikes[:, 0].view(100, -1), axis, s=25, c="black")
ax[0].set_title(f"Prediction: {pred}, Target: {target_tmp}")

## Convert to NumPy: these recordings are compared with the testbench output
spk_rec = spk_rec.detach().cpu().numpy()
mem_rec = mem_rec.detach().cpu().numpy()

In [None]:
def bin_to_signed_int(bstr, bits=12):
    """Interpret a string of binary digits as a two's-complement integer.

    Values whose unsigned reading is at least ``2**(bits - 1)`` wrap around
    to the negative range by subtracting ``2**bits``.
    """
    unsigned = int(bstr, 2)
    is_negative = unsigned >= 2 ** (bits - 1)
    return unsigned - 2 ** bits if is_negative else unsigned

# Membrane-potential dump of the simulation: time step, neuron address and
# a packed binary word.
names = [
    't',
    'nrn_addr',
    'val'
]
# `val` is sliced as a string of binary digits further down, so force it to
# be read as text: without an explicit dtype pandas is free to parse an
# all-digit column as a number, dropping leading zeros and breaking slicing.
df_mem = pd.read_csv(
    '../vivado/fenrir/fenrir.sim/sim_1/behav/xsim/mem_rec.csv',
    names=names,
    dtype={'val': str},
)

# Spike dump of the simulation: time step and (binary) neuron address.
names = [
    't',
    'nrn_addr',
]
df_spk = pd.read_csv('../vivado/fenrir/fenrir.sim/sim_1/behav/xsim/spk_rec.csv', names=names)
# Addresses are binary digits; go through str() because pandas may have
# parsed them as integers.
df_spk['nrn_addr'] = df_spk['nrn_addr'].apply(lambda b: int(str(b), 2))

# Membrane trace per output neuron, keyed by the neuron index as a string.
nrn_mem = {str(n): [] for n in range(10)}

for t in range(0, 100-1):
    # First four memory words logged for this time step
    subset = df_mem[df_mem['t'] == t][['nrn_addr', 'val']].head(4)
    addr_list = subset['nrn_addr'].tolist()
    val_list = subset['val'].tolist()

    # Each word packs three 12-bit signed values: neuron 3*w + slot sits at
    # characters [24 - 12*slot : 36 - 12*slot] of word w (lowest neuron is
    # rightmost). Neuron 9 is the only value read from word 3.
    for n in range(10):
        word, slot = divmod(n, 3)
        start = 24 - 12 * slot
        field = val_list[word][start:start + 12]
        nrn_mem[str(n)].append(bin_to_signed_int(field, bits=12))

# Time axis shared by the testbench traces
x = np.arange(0, 99, 1)

# One 0/1 spike train per output neuron, all initially silent
spks = [np.zeros_like(x) for _ in range(10)]
spk_0, spk_1, spk_2, spk_3, spk_4, spk_5, spk_6, spk_7, spk_8, spk_9 = spks

# Mark every (time step, neuron) event logged by the testbench; addresses
# outside 0..9 are ignored.
for t, nrn in zip(df_spk['t'], df_spk['nrn_addr']):
    if 0 <= nrn <= 9:
        spks[nrn][t] = 1

# Compare the testbench ("TB") traces of one neuron with the snntorch recording.
nrn = 0
fig, ax = plt.subplots(2, 1, figsize=(12, 7))

# --- Membrane potential ---
ax[0].plot(x, nrn_mem[f"{nrn}"], color='red', linewidth=2, alpha=0.5, label="TB")

# snntorch values are divided by the weight scale and multiplied by 100 to
# be drawn on the same axis as the testbench values; the threshold from
# pyfenrir is multiplied by the same factor of 100.
scale = net.fc1.quant_weight().scale.cpu().detach().numpy()
thr = fenrir.get_threshold(net.fc1, net.lif1)*100
thr_line = np.full_like(x, thr)
zero_line = np.full_like(x, 0)
ax[0].plot(x, thr_line, linestyle='--', color='black', label="Thr")
ax[0].plot(x, zero_line, linestyle=':', color='black', alpha=0.5)
ax[0].plot(mem_rec[:, 0, nrn]*100/scale, color='blue', linewidth=2, alpha=0.5, label="snntorch")
ax[0].legend()

# --- Spikes ---
ax[1].plot(x, spks[nrn], color='red', linewidth=2, alpha=0.5)
ax[1].plot(x, spk_rec[:99, 0, nrn], color='blue', linewidth=2, alpha=0.5)

# Same visible window and tick spacing on both axes
x_start = 4
x_end   = 100
ticks = np.arange(x_start, x_end + 1, 10)
for a in ax:
    a.set_xlim(x_start, x_end)
    a.set_xticks(ticks)

# Spot check of neuron 0 at time step 19: testbench value vs. scaled snntorch value
print(nrn_mem['0'][19])
print(mem_rec[19, 0, 0]/scale*100)

for a, title in zip(ax, ("mem_rec", "spk_rec")):
    a.set_title(title)
fig.suptitle(f"Neuron {nrn}")
fig.tight_layout()

In [None]:
# Inspect the LAST four memory words logged for time step 1.
subset = df_mem[df_mem['t'] == 1][['nrn_addr', 'val']].tail(4)
addr_list = subset['nrn_addr'].tolist()
val_list = subset['val'].tolist()

# Start from empty traces (this replaces the dictionary filled earlier)
nrn_mem = {str(n): [] for n in range(10)}

# Each word packs three 12-bit signed values: neuron 3*w + slot sits at
# characters [24 - 12*slot : 36 - 12*slot] of word w. Neuron 9 is the only
# value read from word 3.
for n in range(10):
    word, slot = divmod(n, 3)
    start = 24 - 12 * slot
    field = val_list[word][start:start + 12]
    nrn_mem[str(n)].append(bin_to_signed_int(field, bits=12))

for nrn_addr, values in nrn_mem.items():
    print(f"Neuron {nrn_addr}: {values}")

In [None]:
# Scratch cell: rebinds `x` and `spk_0` to 98-element arrays and writes a
# zero into the last slot of the (already all-zero) spike train.
x = np.arange(98)
spk_0 = np.zeros_like(x)
spk_0[97] = 0

In [None]:
# Membrane potential (top) and spikes (bottom) of one output neuron as
# recorded by snntorch, zoomed in on the first ten time steps.
neuron = 0

fig, ax = plt.subplots(2, 1, figsize=(12, 6))

ax[0].plot(mem_rec[:, 0, neuron])
ax[1].plot(spk_rec[:, 0, neuron])

# Horizontal reference lines: learned firing threshold and zero
thr = net.lif1.threshold.cpu().detach().item()
x   = np.arange(0, 100, 1)
ax[0].plot(x, np.full(x.shape, thr, dtype=np.float32))
ax[0].plot(x, np.zeros_like(x), linestyle='--')

# Dotted vertical marker at every time step where the neuron fired
spike_t = spk_rec[:, 0, neuron]
fired_at = [t for t, s in enumerate(spike_t) if s > 0]
for t in fired_at:
    ax[0].axvline(x=t, color='black', linestyle=':', alpha=0.5)

for a in ax:
    a.set_xlim(0, 10)

In [None]:
# Output spikes of all ten neurons at a single time step
t = 1

spk_events = [spk_rec[t, 0, nrn] for nrn in range(0, 10)]

spk_events

In [None]:
## Input events per time step: indices of the active (non-zero) input pixels

# Flatten every time step of the encoded sample into one row of pixels
tstep_data = spike_data[:, 0, :, :].view(100, -1).cpu().detach().numpy()

t_events = [
    np.nonzero(tstep_data[tstep, :])[0].tolist()
    for tstep in range(0, 100)
]

print(t_events[1])

In [None]:
## Look up the weights from selected input positions to output neuron 0

# Input positions to inspect (alternative set kept for reference)
# in_nrns = [457, 492, 499, 531]
in_nrns = [236, 297, 523, 525, 783]

quantized = net.fc1.quant_weight()
scale = quantized.scale.cpu().detach().numpy()

# quantized[0] is indexed like a weight matrix: row 0, the selected columns.
# NOTE(review): presumably element 0 holds the de-quantized weight values --
# confirm against the brevitas version in use.
weights = quantized[0][0, in_nrns].cpu().detach().numpy()

# Dividing by the scale expresses the weights in units of the quantization step
q_weights = weights / scale
q_weights_sum = q_weights.sum()

print(q_weights)
print(q_weights_sum)

In [None]:
## Export for FENRIR
# Write the spike-encoded sample to 'test_data.txt' through pyfenrir.
# NOTE(review): the manual export cell further down writes to the same file
# name, so whichever cell runs last determines the file's content.
fenrir.export_spike_data(spike_data, 'test_data.txt')

In [None]:
# Manual export of the spike-encoded sample as one binary word per line:
# the index of every active input, followed by an end-of-time-step marker.
exp_data = spike_data[:, 0, :, :]
exp_data = exp_data.view(exp_data.shape[0], -1)
tsteps = spike_data.shape[0]

events = []            # input indices interleaved with end-of-step markers
tstep_event_idx = []   # length of `events` after each time step

for t in range(tsteps):
    active = (exp_data[t, :] != 0).nonzero(as_tuple=True)[0]
    events.extend(active.tolist())
    # Close the time step with the marker word (only bit 12 set)
    events.append(0b1000000000000)
    tstep_event_idx.append(len(events))

# At least 10 binary digits per entry; the marker comes out 13 digits long
binary_events = [format(idx, '010b') for idx in events]

out_file = 'test_data.txt'
with open(out_file, 'w') as f:
    # Input indices get a "000" prefix; the marker is written as is
    f.writelines(
        (b if b == '1000000000000' else "000" + b) + '\n'
        for b in binary_events
    )

In [None]:
## Raster plot of snntorch versus fenrir
dataset = test_loader.dataset
data_tmp, target_tmp = dataset[6]

spike_data = spikegen.latency(data_tmp, num_steps=100, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
spike_data = spike_data.to(device)
spk_rec, mem_rec = net(spike_data)
pred = torch.argmax(spk_rec.sum(dim=0).squeeze()).item()

nrn_arr = np.stack([nrn_0, nrn_1, nrn_2, nrn_3, nrn_4, nrn_5, nrn_6, nrn_7, nrn_8, nrn_9], axis=1)
pred_nrn = np.argmax(nrn_arr.sum(axis=0))

fig, ax = plt.subplots(2, 1, figsize=(12, 7))
splt.raster(spk_rec[:, 0].view(100, -1), ax[0], s=25, c="black")
splt.raster(torch.from_numpy(nrn_arr).view(100, -1), ax[1], s=25, c="black")
ax[0].set_title(f"snntorch pred: {pred}, Target: {target_tmp}")
ax[1].set_title(f"fenrir pred: {pred_nrn}, Target: {target_tmp}")