# Quantized SNN for MNIST

## Imports

In [None]:
# SNN
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch.functional import quant
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

# Quantization
import brevitas.nn as qnn

# Torch
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Other
import matplotlib.pyplot as plt
import numpy as np
import os

## Network Definition

In [None]:
# Hyper-parameters shared by the notebook. Keys are read by Net (num_bits,
# threshold, slope, beta, num_steps); the rest are presumably consumed by
# training code elsewhere.
config = {
    "num_epochs": 1,        # Number of epochs to train for (per trial)
    "batch_size": 128,      # Batch size
    "seed": 0,              # Random seed

    # Quantization
    "num_bits": 4,          # Weight bit width passed to the quantized layers

    # Network parameters
    "grad_clip": False,     # Whether or not to clip gradients
    "weight_clip": False,   # Whether or not to clip weights
    "batch_norm": True,     # Whether or not to use batch normalization
    "dropout": 0.07,        # Dropout rate
    "beta": 1.00,           # Membrane decay rate (beta); 1.0 => no decay between steps
    "threshold": 30,        # Firing threshold (theta)
    "lr": 3.0e-3,           # Initial learning rate
    "slope": 5.6,           # Surrogate-gradient slope (k); NOTE(review): Net stores it but never passes it to snn.Leaky

    # Fixed params
    "num_steps": 100,       # Number of timesteps to encode input for
    "correct_rate": 0.8,    # Target rate for the correct class (presumably for a spike-count loss)
    "incorrect_rate": 0.2,  # Target rate for incorrect classes (presumably for a spike-count loss)
    "betas": (0.9, 0.999),  # Adam optimizer beta values
    # Length (a period, in scheduler steps -- not a frequency) of the first
    # cosine-annealing cycle. Presumably ~10 epochs x 469 MNIST batches of 128.
    "t_0": 4690,
    "eta_min": 0,           # Minimum learning rate
}

In [None]:
# Device the model is placed on; this notebook runs on the CPU only.
device = torch.device(type="cpu")

class Net(nn.Module):
    """One weight-quantized linear layer (9 -> 3, no bias) feeding a Leaky spiking layer.

    The simulation is unrolled for ``config["num_steps"]`` time steps.

    Args:
        config: dict providing "num_bits", "threshold", "slope", "beta" and
            "num_steps".
    """

    def __init__(self, config):
        super().__init__()
        # Hyper-parameters pulled out of the config dict.
        self.num_bits = config["num_bits"]
        self.thr = config["threshold"]
        # NOTE(review): slope is stored but not passed to snn.Leaky (no
        # spike_grad argument), so Leaky's default surrogate gradient is used.
        self.slope = config["slope"]
        self.beta = config["beta"]
        self.num_steps = config["num_steps"]

        # Keep fc1 registered before lif1 so the submodule order (and
        # state_dict layout) is unchanged.
        self.fc1 = qnn.QuantLinear(
            9, 3, bias=False, weight_bit_width=self.num_bits
        )
        self.lif1 = snn.Leaky(self.beta, threshold=self.thr)

    def forward(self, x):
        """Simulate the network step by step.

        Args:
            x: input indexed by time step along its first dimension; the first
               ``self.num_steps`` entries are read, and fewer raises IndexError
               (assumes shape (time, batch, 9) -- TODO confirm).

        Returns:
            Tuple ``(spikes, membrane)`` of lif1 outputs, each stacked along a
            new leading time dimension.
        """
        membrane = self.lif1.init_leaky()
        spike_history, membrane_history = [], []
        for t in range(self.num_steps):
            spikes, membrane = self.lif1(self.fc1(x[t]), membrane)
            spike_history.append(spikes)
            membrane_history.append(membrane)
        return torch.stack(spike_history), torch.stack(membrane_history)

# Instantiate the model from the shared config, then move it to the target device.
net = Net(config)
net = net.to(device)

In [None]:
# Overwrite fc1's float weights with all ones, presumably so the resulting
# quantization scale and integer weights are easy to reason about. no_grad is
# required because a leaf Parameter is modified in place.
with torch.no_grad():
    # ones_like follows fc1's actual shape/dtype/device instead of
    # hard-coding (3, 9), so this cell keeps working if the layer changes.
    manual_weights = torch.ones_like(net.fc1.weight)
    net.fc1.weight.copy_(manual_weights)

In [None]:
# lif1's firing threshold divided by fc1's weight quantization scale, i.e. the
# threshold expressed in units of one quantized-weight step. Left as the cell's
# final expression so the notebook displays the result.
net.lif1.threshold/net.fc1.quant_weight().scale

In [None]:
# Divide the quantized weights by their own scale to display them as integer
# quantization levels (final expression, so the notebook shows it).
# NOTE(review): brevitas QuantTensor presumably also offers .int() for this;
# confirm which form is intended.
net.fc1.quant_weight()/net.fc1.quant_weight().scale

In [None]:
# Synthetic stimulus: input feature 0 is held at 1.0 on every time step and
# the other inputs stay at zero.
timesteps   = net.num_steps  # tie to the model: forward() reads exactly this many steps
batch_size  = 1
channels    = 1              # NOTE(review): not used below
length      = 9              # number of input features of fc1

data = torch.zeros((timesteps, batch_size, length))
data[:, :, 0] = 1.0

# Inference only: no_grad avoids building an autograd graph across every step.
with torch.no_grad():
    spk_rec, mem_rec = net(data.to(device))
spk_rec = spk_rec.detach().cpu().numpy()
mem_rec = mem_rec.detach().cpu().numpy()

In [None]:
# Spike train of each of the three output neurons for batch element 0.
# The x axis follows the recorded length instead of a hard-coded 100, so the
# plot still works if the number of time steps changes.
x = np.arange(spk_rec.shape[0])

fig, ax = plt.subplots(3, 1, figsize=(8, 6))
for neuron, axis in enumerate(ax):
    axis.plot(x, spk_rec[:, 0, neuron])
fig.suptitle("spk_rec")

In [None]:
# Membrane potential of each of the three output neurons for batch element 0.
# The x axis follows the recorded length instead of a hard-coded 100.
x = np.arange(mem_rec.shape[0])

fig, ax = plt.subplots(3, 1, figsize=(8, 6))
for neuron, axis in enumerate(ax):
    axis.plot(x, mem_rec[:, 0, neuron])
# Title previously read "spk_rec" (copy-paste from the spike plot).
fig.suptitle("mem_rec")