# Controller Simulation

In [6]:
# SNN
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# SNN
import tonic
from tonic.transforms import ToFrame
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils
import snntorch.spikeplot as splt

# Visualization
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Misc
import numpy as np
import numpy.lib.recfunctions as rf

# Core
import os

## Single Layer SNN

In [None]:
# Parameters
num_inputs  = 8*6
num_outputs = 10
beta        = 0.5

# Initialize layers
fc1  = nn.Linear(num_inputs, num_outputs)
lif1 = snn.Leaky(beta=beta, threshold=50, reset_mechanism='zero')

# Initialize FC
torch.nn.init.zeros_(fc1.bias)  # Initialize bias to zero
torch.nn.init.ones_(fc1.weight) # Initialize weights to one

# Initialize hidden states
mem1 = lif1.init_leaky()

# Outputs
mem1_rec = []
spk1_rec = []

num_steps = 100
spk_in = torch.ones((num_steps, num_inputs))
for step in range(num_steps):
    cur1 = fc1(spk_in[step])
    spk1, mem1 = lif1(cur1, mem1)
    mem1_rec.append(mem1)
    spk1_rec.append(spk1)

mem1_rec = torch.stack(mem1_rec)
spk1_rec = torch.stack(spk1_rec)

# Plot
spk_in_plot = spk_in.reshape((num_steps, -1))
spk1_rec_plot = spk1_rec.reshape((num_steps, -1))

fig, axs = plt.subplots(2, 1, figsize=(10, 6), facecolor='w', sharex=True)
splt.raster(spk_in_plot, axs[0], s=1.5, c="black")
splt.raster(spk1_rec_plot, axs[1], s=1.5, c="black")

axs[0].set_title("Input Layer")
axs[0].set_ylabel("Input Spikes")
axs[1].set_xlabel("Time step")
axs[1].set_ylabel("Output Spikes")
plt.show()

## MNIST Preprocessing

In [7]:
dataset_transform = tonic.transforms.Compose([
    ToFrame(sensor_size=(28, 28)),  # Convert event-based data to frame
    tonic.transforms.Resize((8, 6))  # Resize to 8x6
])


# Load MNIST SNN dataset using Tonic
mnist_dataset = tonic.datasets.MNIST(root="./data", transform=dataset_transform, download=True)
data_loader = DataLoader(mnist_dataset, batch_size=1, shuffle=True)

# Define a simple convolutional kernel
conv_kernel = torch.tensor([[[-1, -1, -1],
                             [-1,  8, -1],
                             [-1, -1, -1]]], dtype=torch.float32).unsqueeze(0)  # Shape (1,1,3,3)

# Process a single batch (one event frame)
data_iter = iter(data_loader)
image, label = next(data_iter)

# Ensure image has correct shape
image = image.unsqueeze(0)  # Shape: (1, 1, 8, 6)
conv_result = F.conv2d(image, conv_kernel, padding=1)  # Output shape: (1, 1, 8, 6)

# Remove batch dimension and convert back to numpy for visualization
final_image = conv_result.squeeze(0).squeeze(0).detach().numpy()

print(f"Processed Image Shape: {final_image.shape}")


AttributeError: module 'tonic.transforms' has no attribute 'Resize'