# Controller Simulation

In [None]:
# SNN
import tonic
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils
import snntorch.spikeplot as splt

# Visualization
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Misc
import numpy as np
import numpy.lib.recfunctions as rf

# Core
import os

In [None]:
# Parameters
num_neurons = 8 * 6  # layer width: 48 neurons (written as 8*6)
beta = 0.5           # membrane decay factor passed to snn.Leaky
threshold = 50       # firing threshold of the LIF layer
num_steps = 100      # number of simulated time steps

# Fully connected layer followed by a leaky integrate-and-fire (LIF) layer.
# Per the snntorch docs, reset_mechanism='zero' sets the membrane potential
# back to zero after a spike.
fc1 = nn.Linear(num_neurons, num_neurons)
lif1 = snn.Leaky(beta=beta, threshold=threshold, reset_mechanism='zero')

# Deterministic weights: all-ones weights and a zero bias mean each output
# neuron's input current is simply the sum of all input activity.
torch.nn.init.zeros_(fc1.bias)
torch.nn.init.ones_(fc1.weight)

# Constant drive: every input neuron is active on every step, so each LIF
# neuron receives a current of num_neurons (= 48) per step, just below
# threshold (= 50). Firing therefore depends on charge accumulating across
# steps (assumes snn.Leaky integrates as mem = beta * mem + input -- confirm
# against the snntorch docs).
spk_in = torch.ones((num_steps, num_neurons))

# Initial membrane potential of the LIF layer.
mem1 = lif1.init_leaky()

# Per-step recordings.
mem1_rec = []
spk1_rec = []

# Pure forward simulation: nothing is trained here, so autograd is disabled.
# Otherwise every recorded tensor would keep the full multi-step graph alive
# (fc1's parameters require grad), and the recorded tensors could not be
# passed to .numpy() without .detach().
with torch.no_grad():
    for step in range(num_steps):
        cur1 = fc1(spk_in[step])       # input current, shape (num_neurons,)
        spk1, mem1 = lif1(cur1, mem1)  # output spikes and updated membrane potential
        mem1_rec.append(mem1)
        spk1_rec.append(spk1)

# Stack the per-step records along a new leading time axis.
mem1_rec = torch.stack(mem1_rec)
spk1_rec = torch.stack(spk1_rec)

In [None]:
# Raster plots of input vs. output spikes over time on a shared time axis.
# reshape((num_steps, -1)) keeps time as the first axis and flattens the
# rest, so rows are time steps and columns are neurons.
spk_in_plot = spk_in.reshape((num_steps, -1))
spk1_rec_plot = spk1_rec.reshape((num_steps, -1))

fig, axs = plt.subplots(2, 1, figsize=(10, 6), facecolor='w', sharex=True)
ax_in, ax_out = axs

# Top panel shows the input spikes, bottom panel the LIF layer's output spikes.
for ax, data, label in (
    (ax_in, spk_in_plot, "Input Spikes"),
    (ax_out, spk1_rec_plot, "Output Spikes"),
):
    splt.raster(data, ax, s=1.5, c="black")
    ax.set_ylabel(label)

ax_in.set_title("Input Layer")
ax_out.set_xlabel("Time step")
plt.show()

In [None]:
# Inspect fc1's weight matrix. It was initialized to all ones, and no cell
# above updates it (forward passes only), so it should still be all ones.
print(fc1.weight)

In [None]:
# Inspect fc1's bias vector. It was initialized to zeros and is not modified
# by the forward-only simulation, so it should still be all zeros.
print(fc1.bias)