# Quantized SNN for MNIST

## Imports

In [None]:
# SNN
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch.functional import quant
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

# Quantization
import brevitas.nn as qnn

# Torch
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Other
import matplotlib.pyplot as plt
import numpy as np
import os

## 1. Setting up MNIST

In [None]:
# Number of samples per mini-batch handed out by the dataloaders.
batch_size = 128
# Directory the MNIST files are downloaded to / read from.
data_path = "./data/mnist"
# Number of target classes (digits 0-9).
num_class = 10

### 1.1 Download Dataset

In [None]:
# Preprocessing applied to every image: resize to 32x32, force a single
# grayscale channel, convert to a tensor, then normalize with mean 0 / std 1
# (which leaves the tensor values unchanged).
preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
]
transform = transforms.Compose(preprocessing_steps)

# Download (if needed) and wrap the train / test splits.
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Sanity check: show the first three training images with their labels.
fig, axs = plt.subplots(1, 3, figsize=(9, 3))
for i, axis in enumerate(axs):
    image, label = mnist_train[i]
    axis.imshow(image.squeeze(0).numpy(), cmap="gray")
    axis.set_title(f"Label: {label}")

### 1.2 Dataloader and Spike Encoding

In [None]:
# Dataloaders: shuffle the training set, keep the test set in order.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

# --- Everything below only demonstrates how a batch is encoded into spikes ---

# Pull a single batch from the training loader.
data = iter(train_loader)
data_it, targets_it = next(data)

# Latency-encode the batch over 100 time steps.
#   - linear, normalize and clip are IMPORTANT for spikegen.latency;
#     details: https://snntorch.readthedocs.io/en/latest/tutorials/tutorial_1.html
spike_data = spikegen.latency(
    data_it,
    num_steps=100,
    tau=5,
    threshold=0.01,
    linear=True,
    normalize=True,
    clip=True,
)

# Raster plot of the first sample in the batch (one row per input pixel).
fig, ax = plt.subplots(facecolor="w", figsize=(12, 3))
splt.raster(spike_data[:, 0].view(100, -1), ax, s=25, c="black")
ax.set_title("Input Layer")
ax.set_xlabel("Time step")
ax.set_ylabel("Neuron Number")
plt.show()

### 1.3 Network Definition

In [None]:
# Hyperparameters for training and for the network, gathered in one mapping.
config = dict(
    num_epochs=1,            # Number of epochs to train for (per trial)
    batch_size=128,          # Batch size
    seed=0,                  # Random seed

    # Quantization
    num_bits=4,              # Bit resolution

    # Network parameters
    grad_clip=False,         # Whether or not to clip gradients
    weight_clip=False,       # Whether or not to clip weights
    batch_norm=True,         # Whether or not to use batch normalization
    dropout=0.07,            # Dropout rate
    beta=1.00,               # Decay rate parameter (beta)
    threshold=30,            # Threshold parameter (theta)
    lr=3.0e-3,               # Initial learning rate
    slope=5.6,               # Slope value (k)

    # Fixed params
    num_steps=100,           # Number of timesteps to encode input for
    correct_rate=0.8,        # Correct rate
    incorrect_rate=0.2,      # Incorrect rate
    betas=(0.9, 0.999),      # Adam optimizer beta values
    t_0=4690,                # Initial frequency of the cosine annealing scheduler
    eta_min=0,               # Minimum learning rate
)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
    """Single-layer quantized spiking network.

    The network is one quantized fully connected layer (32*32 inputs, 10
    outputs, no bias) followed by a leaky integrate-and-fire layer and a
    dropout layer.

    Args:
        config: Mapping providing "num_bits", "threshold", "slope", "beta",
            "num_steps", "batch_norm" and "dropout".
    """

    def __init__(self, config):
        super().__init__()
        self.num_bits   = config["num_bits"]
        self.thr        = config["threshold"]
        self.slope      = config["slope"]
        self.beta       = config["beta"]
        self.num_steps  = config["num_steps"]
        self.batch_norm = config["batch_norm"]
        self.p1         = config["dropout"]
        # Surrogate gradient used for the spike non-linearity during backprop.
        self.spike_grad = surrogate.fast_sigmoid(self.slope)

        # Initialize Layers
        self.fc1        = qnn.QuantLinear(32*32, 10, bias=False, weight_bit_width=self.num_bits)
        self.lif1       = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)
        self.dropout    = nn.Dropout(self.p1)

    def forward(self, x):
        """Run the network over ``self.num_steps`` time steps.

        Args:
            x: Spike-encoded input indexed by time step along the first
                dimension and by sample along the second; each time step is
                flattened to ``(x.shape[1], -1)`` before the linear layer.

        Returns:
            Tuple ``(spk_rec, mem_rec)``: the output spikes and the membrane
            potentials of every time step, stacked along a new leading
            dimension.
        """
        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()

        # Record the final layer
        spk_rec = []
        mem_rec = []
        for step in range(self.num_steps):
            cur = self.fc1(x[step].view(x.shape[1], -1))
            # snn.Leaky returns (spike, membrane) in that order. Unpacking
            # them the other way round feeds the spikes back in as the
            # membrane state and swaps the two recordings.
            spk1, mem1 = self.lif1(cur, mem1)

            # NOTE(review): dropout is gated by the "batch_norm" flag and no
            # batch-norm layer exists in this network - confirm this is the
            # intended switch.
            if self.batch_norm:
                spk1 = self.dropout(spk1)

            spk_rec.append(spk1)
            mem_rec.append(mem1)

        return torch.stack(spk_rec), torch.stack(mem_rec)

# Build the network from the hyperparameters in `config` and move it to `device`.
net = Net(config).to(device)

### 1.4 Define Optimizer, LR Scheduler and Loss Function

In [None]:
# Adam over all network parameters, using the learning rate and beta values
# from the configuration.
optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config["betas"])

# Cosine annealing of the learning rate down to eta_min over t_0 scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config["t_0"], eta_min=config["eta_min"]
)

# Spike-count MSE loss with the configured correct / incorrect target rates.
criterion = SF.mse_count_loss(
    correct_rate=config["correct_rate"], incorrect_rate=config["incorrect_rate"]
)

### 1.5 Training and Evaluation

In [None]:
def train(config, net, trainloader, criterion, optimizer, device="cpu", scheduler=None):
    """
    Complete one epoch of training.

    Args:
        config: Mapping providing "num_steps", "grad_clip" and "weight_clip".
        net: Network to train; called with the spike-encoded batch and
            expected to return a ``(spk_rec, mem_rec)`` pair.
        trainloader: Iterable yielding ``(data, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(spk_rec, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before encoding.
        scheduler: Optional learning-rate scheduler, stepped once per batch
            when given.

    Returns:
        Tuple ``(loss_accum, lr_accum)``: per-batch loss divided by the number
        of time steps, and the learning rate in effect after each batch.
    """

    net.train()
    # Use a single source for the number of time steps so the encoding length
    # and the loss normalization cannot drift apart.
    num_steps = config["num_steps"]
    loss_accum = []
    lr_accum = []
    for data, labels in trainloader:
        data, labels = data.to(device), labels.to(device)
        # Encode data into spikes
        data = spikegen.latency(data, num_steps=num_steps, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
        spk_rec, _ = net(data)
        loss = criterion(spk_rec, labels)
        optimizer.zero_grad()
        loss.backward()

        ## Enable gradient clipping
        if config["grad_clip"]:
            nn.utils.clip_grad_norm_(net.parameters(), 1.0)

        ## Enable weight clipping
        # NOTE(review): the clamp runs before optimizer.step(), so the update
        # applied afterwards is not itself clamped - confirm this is intended.
        if config["weight_clip"]:
            with torch.no_grad():
                for param in net.parameters():
                    param.clamp_(-1, 1)

        optimizer.step()
        # The scheduler is optional (the default is None), so only step it
        # when one was actually supplied.
        if scheduler is not None:
            scheduler.step()
        loss_accum.append(loss.item() / num_steps)
        lr_accum.append(optimizer.param_groups[0]["lr"])

    return loss_accum, lr_accum

def test(config, net, testloader, device="cpu"):
    """
    Calculate accuracy on full test set.

    Args:
        config: Mapping providing "num_steps", the number of time steps each
            batch is encoded for.
        net: Network to evaluate; called with the spike-encoded batch and
            expected to return a pair whose first element is the output
            spike record.
        testloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before encoding.

    Returns:
        Accuracy over all samples, in percent.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """

    # Encode with the configured number of time steps rather than a
    # hard-coded value, so evaluation matches the rest of the configuration.
    num_steps = config["num_steps"]
    correct = 0
    total = 0
    with torch.no_grad():
        net.eval()
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            images = spikegen.latency(images, num_steps=num_steps, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
            outputs, _ = net(images)
            accuracy = SF.accuracy_rate(outputs, labels)
            # Weight each batch accuracy by its size so a smaller final batch
            # does not skew the overall figure.
            total += labels.size(0)
            correct += accuracy * labels.size(0)

    return 100 * correct / total

# Per-batch loss and learning-rate histories, accumulated across all epochs.
loss_list = []
lr_list = []

print("=======Training Network=======")
for epoch in range(config['num_epochs']):
    # One full pass over the training data.
    loss, lr = train(config, net, train_loader, criterion, optimizer, device, scheduler)
    loss_list.extend(loss)
    lr_list.extend(lr)

    # Evaluate on the whole test set after every epoch.
    test_accuracy = test(config, net, test_loader, device)
    print(f"Epoch: {epoch} \tTest Accuracy: {test_accuracy}")

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(loss_list, color='tab:orange')
ax2.plot(lr_list, color='tab:blue')
ax1.set_xlabel('Iterationb')
ax1.set_ylabel('Loss', color='tab:orange')
ax2.set_ylabel('Learning Rate', color='tab:blue')
plt.show()

In [None]:
# Quantized fc1 weights divided by their quantization scale.
# Presumably this yields the integer weight codes behind the quantized
# values - TODO confirm against the Brevitas documentation.
net.fc1.quant_weight()/net.fc1.quant_weight().scale

In [None]:
# LIF firing threshold divided by the fc1 weight quantization scale.
# Presumably this expresses the threshold in the same integer units as the
# scale-divided weights - TODO confirm.
net.lif1.threshold/net.fc1.quant_weight().scale

### 1.5 Network Test

In [None]:
# Classify three randomly chosen test samples and draw one output raster plot
# per sample, titled with the prediction and the true label.
fig, ax = plt.subplots(3, 1, figsize=(12, 6))
# NOTE(review): this batch is fetched but not used anywhere in this cell.
iter_test = iter(test_loader)
data_it, targets_it = next(iter_test)

dataset = test_loader.dataset
num_samples = len(dataset)
# Three random sample indices (one per subplot row).
ran_idx = torch.randint(0, num_samples, (3,))

for i, idx in enumerate(ran_idx):
    # Get some data
    data_tmp, target_tmp = dataset[idx]

    # Spike encode it
    spike_data = spikegen.latency(data_tmp, num_steps=100, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
    spike_data = spike_data.to(device)

    # Forward pass
    spk_rec, mem_rec = net(spike_data)

    # Incantation (just summing the spikes): the prediction is the output
    # with the largest sum over all time steps
    pred = torch.argmax(spk_rec.sum(dim=0).squeeze()).item()

    # Plot
    splt.raster(spk_rec[:, 0].view(100, -1), ax[i], s=25, c="black")
    ax[i].set_yticks(np.arange(0, 10, 1))
    ax[i].set_title(f"Prediction: {pred}, Target: {target_tmp}")

plt.subplots_adjust(hspace=0.5)

### 1.6 Save Model

In [None]:
# Save the trained parameters to ./models/mnist.pth.
# `model_dir` replaces the former variable name `dir`, which shadowed the
# builtin of the same name for the rest of the session.
model_dir = "./models"
# exist_ok=True creates the directory only when it is missing, without a
# separate existence check.
os.makedirs(model_dir, exist_ok=True)
torch.save(net.state_dict(), f"{model_dir}/mnist.pth")

### 1.7 Extract Quantized Weights

In [None]:
# Export the fc1 weights as 4-bit two's-complement codes, written to
# weights.txt as text lines of 32 bits (8 weights) each.

# Fetch the quantized weights once and divide by their scale; the values are
# rounded to integers below.
quant_w = net.fc1.quant_weight()
weights = quant_w / quant_w.scale
weights = weights.cpu().detach().numpy()

# One 4-character bit string per synapse, neuron by neuron.
# The -8..7 range, the 0b1111 mask and the '04b' format all assume 4-bit weights.
qsyn_bin = []
for nrn in range(0, 10):
    for syn in weights[nrn]:
        qsyn = int(round(syn))
        qsyn = max(-8, min(7, qsyn))
        # Masking with 0b1111 maps negative values to their two's-complement code.
        qsyn_bin.append(format(qsyn & 0b1111, '04b'))

# Pack the bit strings into 32-character lines. A trailing remainder shorter
# than 32 characters is dropped.
# (The accumulator was previously named `str`, which shadowed the builtin.)
bitstream = "".join(qsyn_bin)
lines = [bitstream[start:start + 32] for start in range(0, len(bitstream) - 31, 32)]

with open("weights.txt", "w") as f:
    f.writelines(line + "\n" for line in lines)

## Extract Spike Encoded Number

In [None]:
# Run a single test sample through the network and plot its input spikes
# (top) next to the network's first output (bottom).
dataset = test_loader.dataset
# `idx` is whatever index the most recently executed sampling loop left behind.
data_tmp, target_tmp = dataset[idx]

# Latency-encode the sample and move it to the compute device in one go.
spike_data = spikegen.latency(
    data_tmp,
    num_steps=100,
    tau=5,
    threshold=0.01,
    linear=True,
    normalize=True,
    clip=True,
).to(device)
spk_rec, mem_rec = net(spike_data)

# Prediction = output index with the largest sum over all time steps.
pred = spk_rec.sum(dim=0).squeeze().argmax().item()

fig, ax = plt.subplots(2, 1, figsize=(12, 6))
for axis, recording in zip(ax, (spike_data, spk_rec)):
    splt.raster(recording[:, 0].view(100, -1), axis, s=25, c="black")
ax[0].set_title(f"Prediction: {pred}, Target: {target_tmp}")

In [None]:
# Plot the values recorded in spk_rec over all time steps for the first
# sample (index 0) and output index 6.
plt.plot(spk_rec[:, 0, 6].cpu().detach().numpy())

In [None]:
# Inspect the recorded membrane tensor: list where it is non-zero and which
# values sit at those positions.
mem_rec.shape

nonzero_mask = mem_rec != 0
non_zero_indices = nonzero_mask.nonzero(as_tuple=True)
non_zero_values = mem_rec[non_zero_indices]

# Turn the per-dimension index tensors into one coordinate tuple per entry.
index_tuples = list(zip(*(dim_idx.tolist() for dim_idx in non_zero_indices)))
print("Non-zero indices:", index_tuples)
print("Non-zero values:", non_zero_values)

In [None]:
# Classify the samples selected in `ran_idx` again and draw one output raster
# plot per sample.
# The number of subplot rows follows len(ran_idx): with a fixed count of 2,
# the third index of ran_idx raised an IndexError on ax[2].
# The unused fetch of a test batch that used to sit here was removed.
fig, ax = plt.subplots(len(ran_idx), 1, figsize=(12, 6))

for i, idx in enumerate(ran_idx):
    # Get some data
    data_tmp, target_tmp = dataset[idx]

    # Spike encode it
    spike_data = spikegen.latency(data_tmp, num_steps=100, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
    spike_data = spike_data.to(device)

    # Forward pass
    spk_rec, _ = net(spike_data)

    # Incantation (just summing the spikes): the prediction is the output
    # with the largest sum over all time steps
    pred = torch.argmax(spk_rec.sum(dim=0).squeeze()).item()

    # Plot
    splt.raster(spk_rec[:, 0].view(100, -1), ax[i], s=25, c="black")
    ax[i].set_yticks(np.arange(0, 10, 1))
    ax[i].set_title(f"Prediction: {pred}, Target: {target_tmp}")

plt.subplots_adjust(hspace=0.5)

In [None]:
# Export one spike-encoded sample as text: every image row of every time step
# becomes two lines (left half, right half), with each pixel written as '01'
# when it holds a spike and '00' otherwise.

# Select the first image plane of every time step. Collapsing all dimensions
# between time and (height, width) first makes this work both for a single
# encoded sample (steps, 1, H, W) and for an encoded batch (steps, B, 1, H, W);
# indexing a 4-D tensor with [:, 0, 0, :, :] raises an IndexError.
exp_data = spike_data.reshape(spike_data.shape[0], -1, *spike_data.shape[-2:])[:, 0]

# Copy to host memory once instead of comparing tensor elements one by one.
frames = exp_data.detach().cpu()

# Total number of spikes in the exported sample (kept for inspection).
count_ones = int((frames == 1).sum().item())

lines = []
for time_step in frames.tolist():
    for row in time_step:
        symbols = ['00' if pixel == 0 else '01' for pixel in row]
        # For the 32-pixel-wide images used here this splits at column 16.
        half = len(symbols) // 2
        lines.append(''.join(symbols[:half]) + '\n')
        lines.append(''.join(symbols[half:]) + '\n')

# Save to file once, after all lines are built (the write used to sit inside
# the time-step loop and rewrote the file on every iteration).
with open('C:/home/university/8-semester/fenrir/src/design_sources/data/spike_data.txt', 'w') as f:
    f.writelines(lines)

In [None]:
# Forward the current spike_data through the network; as the last expression
# of the cell, the returned pair of recordings is displayed.
net(spike_data)