# snntorch - Tutorial 3

Simulating leaky integrate-and-fire neurons, then a feedforward spiking neural network.

## Imports

In [10]:
# Spiking Neural Network
import snntorch as snn
from snntorch import utils
from snntorch import spikegen

# Torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Torchvision
from torchvision import datasets, transforms

# Visualization
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

## Leaky Integrate-and-Fire Model

A manual implementation in plain PyTorch: one function advances the neuron by a single time step.

In [None]:
def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):
    """Advance a leaky integrate-and-fire neuron by one time step.

    The spike decision is taken from the *incoming* membrane potential, and
    the reset happens in the same update by subtraction: if the neuron fired,
    ``threshold`` is removed from the new potential.

    Args:
        mem: Membrane potential before this step (scalar or tensor).
        x: Input value at this step.
        w: Weight applied to the input.
        beta: Decay factor multiplied onto ``mem``.
        threshold: Firing threshold (default 1).

    Returns:
        ``(spk, mem)``: the spike flag ``mem > threshold`` (a bool, or a bool
        tensor when ``mem`` is a tensor) and the updated membrane potential.
    """
    fired = mem > threshold  # True/False behave as 1/0 in the arithmetic below
    leak = beta * mem
    drive = w * x
    return fired, leak + drive - fired * threshold

# Membrane decay rate from the time step and the membrane time constant:
# beta = exp(-delta_t / tau). This value is used for the simulation below.
delta_t = torch.tensor(1e-3)
tau = torch.tensor(5e-3)
beta = torch.exp(-delta_t/tau)
print(f"The decay rate is: {beta:.3f}")

num_steps = 200
input_onset = 10  # number of leading steps with zero input

# initialize inputs/outputs + small step current input.
# The input length is derived from num_steps so x[step] is valid for every step.
x = torch.cat((torch.zeros(input_onset), torch.ones(num_steps - input_onset)*0.5), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)  # not used in this cell
mem_rec = []
spk_rec = []

# neuron parameters
w = 0.4
# beta is intentionally NOT re-assigned to a hand-rounded constant here: the
# value computed from delta_t and tau above is the single source of truth.

# neuron simulation
for step in range(num_steps):
    spk, mem = leaky_integrate_and_fire(mem, x[step], w=w, beta=beta)
    mem_rec.append(mem)
    spk_rec.append(spk)

# convert lists to tensors
mem_rec = torch.stack(mem_rec)
spk_rec = torch.stack(spk_rec)

fig, axs = plt.subplots(3, 1, sharex=True)
axs[0].plot(x*w)
axs[0].set_title("Weighted input (w * x)")
axs[1].plot(mem_rec)
axs[1].set_title("Membrane potential")
axs[2].plot(spk_rec)
axs[2].set_title("Output spikes")
axs[2].set_xlabel("Time step")
fig.tight_layout()

## Leaky Integrate-and-Fire Model in snntorch

The same kind of neuron, this time using snntorch's built-in `snn.Leaky` layer.

In [None]:
lif1 = snn.Leaky(beta=0.8)

# Small step current input: zero for the first `input_onset` steps, then a
# constant amplitude `w`. The length is derived from num_steps (defined in an
# earlier cell) so cur_in[step] is valid for every step of the loop below.
w = 0.21
input_onset = 10
cur_in = torch.cat((torch.zeros(input_onset), torch.ones(num_steps - input_onset)*w), 0)
mem = torch.zeros(1)
spk = torch.zeros(1)
mem_rec = []
spk_rec = []

# neuron simulation: lif1 returns the spike output and the updated membrane
for step in range(num_steps):
    spk, mem = lif1(cur_in[step], mem)
    mem_rec.append(mem)
    spk_rec.append(spk)

# convert lists to tensors
mem_rec = torch.stack(mem_rec)
spk_rec = torch.stack(spk_rec)

fig, axs = plt.subplots(3, 1, sharex=True)
axs[0].plot(cur_in)
axs[0].set_title("Input current")
axs[1].plot(mem_rec)
axs[1].set_title("Membrane potential")
axs[2].plot(spk_rec)
axs[2].set_title("Output spikes")
axs[2].set_xlabel("Time step")
fig.tight_layout()

## A Feedforward Spiking Neural Network

In [12]:
# layer parameters
num_inputs = 784
num_hidden = 1000
num_outputs = 10
beta = 0.99

# initialize layers: two fully connected layers, each followed by a leaky
# integrate-and-fire layer that turns its input current into spikes
fc1 = nn.Linear(num_inputs, num_hidden)
lif1 = snn.Leaky(beta=beta)
fc2 = nn.Linear(num_hidden, num_outputs)
lif2 = snn.Leaky(beta=beta)

# Initialize hidden states
mem1 = lif1.init_leaky()
mem2 = lif2.init_leaky()

# record outputs
mem2_rec = []
spk1_rec = []
spk2_rec = []

# Random values in [0, 1) rate-coded into spikes by spikegen.rate_conv.
# The shape is tied to num_steps / num_inputs (rather than literal 200 / 784)
# so the loop below and fc1 stay consistent if either is changed.
# unsqueeze(1) inserts a batch dimension of size 1, giving
# (num_steps, 1, num_inputs) — assuming rate_conv preserves the input shape.
spk_in = spikegen.rate_conv(torch.rand((num_steps, num_inputs))).unsqueeze(1)

# network simulation
for step in range(num_steps):
    cur1 = fc1(spk_in[step]) # post-synaptic current <-- spk_in x weight
    spk1, mem1 = lif1(cur1, mem1) # mem[t+1] <--post-syn current + decayed membrane
    cur2 = fc2(spk1)
    spk2, mem2 = lif2(cur2, mem2)

    mem2_rec.append(mem2)
    spk1_rec.append(spk1)
    spk2_rec.append(spk2)

# convert lists to tensors (time becomes the leading dimension)
mem2_rec = torch.stack(mem2_rec)
spk1_rec = torch.stack(spk1_rec)
spk2_rec = torch.stack(spk2_rec)

In [None]:
# Flatten everything after the time axis so each row is one time step and
# each column one neuron.
spk_in_plot = spk_in.reshape((num_steps, -1))
spk1_rec_plot = spk1_rec.reshape((num_steps, -1))
spk2_rec_plot = spk2_rec.reshape((num_steps, -1))

fig, axs = plt.subplots(3, 1, figsize=(10, 10), facecolor='w', sharex=True)
splt.raster(spk_in_plot, axs[0], s=1.5, c="black")
splt.raster(spk1_rec_plot, axs[1], s=1.5, c="black")
splt.raster(spk2_rec_plot, axs[2], s=1.5, c="black")

# Label every panel explicitly: plt.xlabel/plt.ylabel only affect the current
# (bottom) axes, which left the upper panels without a y-label.
axs[0].set_title("Input Layer")
axs[1].set_title("Hidden Layer")
axs[2].set_title("Output Layer")
for axis in axs:
    axis.set_ylabel("Neuron Number")
axs[2].set_xlabel("Time step")
plt.show()

In [None]:
fig, ax = plt.subplots()
# One label per output neuron, derived from num_outputs (set in the network
# cell) rather than a hand-written list that must be kept in sync with it.
labels = [str(i) for i in range(num_outputs)]
# Drop the size-1 batch dimension and detach from autograd before plotting.
spk2_rec_anim = spk2_rec.squeeze(1).detach().cpu()

anim = splt.spike_count(spk2_rec_anim, fig, ax, labels=labels, animate=True)
# NOTE(review): to_html5_video presumably needs an ffmpeg writer available to matplotlib — confirm.
video = HTML(anim.to_html5_video())
# Close this figure explicitly, presumably so the static frame is not rendered
# alongside the video.
plt.close(fig)
video

In [None]:
# Membrane-potential traces of the output layer, with the output spikes
# passed in as well (batch dimension of size 1 removed from both).
mem_traces = mem2_rec.squeeze(1)
spk_traces = spk2_rec.squeeze(1)
splt.traces(mem_traces, spk=spk_traces)

# Resize the figure that traces() drew into.
fig = plt.gcf()
fig.set_size_inches(8, 6)