# snntorch - Tutorial 1

## Imports

In [None]:
# Spiking Neural Network
import snntorch as snn
from snntorch import utils
from snntorch import spikegen

# Torch
import torch
from torch.utils.data import DataLoader

# Torchvision
from torchvision import datasets, transforms

# Visualization
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

In [None]:
# --- Training parameters ---
batch_size = 128
data_path = '/tmp/data/mnist'
num_classes = 10  # MNIST has 10 output classes

# --- Torch variables ---
dtype = torch.float

# Preprocessing applied to every image, in order: resize to 28x28, convert to
# grayscale, convert to a tensor, then normalize with the constants (0,) and (1,).
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Load the MNIST training split (downloading into data_path if needed) and
# pass it through utils.data_subset together with `subset`.
subset = 10
mnist_train = utils.data_subset(
    datasets.MNIST(data_path, train=True, download=True, transform=transform),
    subset,
)
print(f"The size of mnist_train is {len(mnist_train)}")

# Shuffled minibatch loader over the reduced training set
train_loader = DataLoader(mnist_train, shuffle=True, batch_size=batch_size)

In [None]:
# Two constant vectors filled with 0.5, differing only in length
raw_vector_10 = 0.5 * torch.ones(10)
raw_vector_100 = 0.5 * torch.ones(100)

# Each element is used as the success probability of one Bernoulli trial
rate_coded_vector_10 = torch.bernoulli(raw_vector_10)
rate_coded_vector_100 = torch.bernoulli(raw_vector_100)

# Report the fraction of ones in each sampled vector as a percentage
for _coded in (rate_coded_vector_10, rate_coded_vector_100):
    print(f"The output is spiking {_coded.sum()*100/len(_coded):.2f}% of the time.")

# Side-by-side bar charts of the two sampled vectors
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
ax[0].bar(range(len(rate_coded_vector_10)), rate_coded_vector_10)
ax[1].bar(range(len(rate_coded_vector_100)), rate_coded_vector_100)
ax[0].set_title("Rate coded vector with 10 elements")
ax[1].set_title("Rate coded vector with 100 elements")

### spikegen

In [None]:
num_steps = 100

# Pull one minibatch (inputs and targets) from the training loader
data = iter(train_loader)
data_it, targets_it = next(data)

# Convert the minibatch to spikes with spikegen.rate over num_steps steps,
# then print the shape of the result
spike_data = spikegen.rate(data_it, num_steps=num_steps)
print(spike_data.shape)

In [None]:
# Pick the entry at index 0 along dims 1 and 2 of the spike tensor
# (presumably the first sample and its only channel - confirm against the
# shape printed above).
spike_data_sample = spike_data[:, 0, 0]
print(spike_data_sample.size())

print(f"The corresponding target is: {targets_it[0]}")

# Build the animation on a fresh figure and wrap it as an HTML5 video
fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample, fig, ax)
_html5 = anim.to_html5_video()
video = HTML(_html5)

# Close the figure; the cell's output is the video object on the last line
plt.close()
video

In [None]:
# We can reduce the spiking frequency by reducing the gain
spike_data = spikegen.rate(data_it, gain=0.25, num_steps=num_steps)

# Same indexing as for the gain = 1 sample: index 0 along dims 1 and 2
spike_data_sample2 = spike_data[:, 0, 0]

print(f"The corresponding target is: {targets_it[0]}")

# Animate on a fresh figure and wrap the result as an HTML5 video
fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample2, fig, ax)
_html5 = anim.to_html5_video()
video = HTML(_html5)

# Close the figure; the cell's output is the video object on the last line
plt.close()
video

In [None]:
# Averaging the spikes over dim 0 gives one image per gain setting;
# show the two images side by side.
plt.figure(facecolor="w")

_panels = (
    (spike_data_sample, 'Gain = 1'),
    (spike_data_sample2, 'Gain = 0.25'),
)
for _position, (_sample, _title) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _position)
    # Mean over dim 0, reshaped to 28 rows, moved to CPU for plotting
    _mean_image = _sample.mean(axis=0).reshape((28, -1)).cpu()
    plt.imshow(_mean_image, cmap='binary')
    plt.axis('off')
    plt.title(_title)

plt.show()

### Raster Plots

In [None]:
# Flatten everything after the first dim so each of the num_steps rows is one
# flat vector
spike_data_sample2 = spike_data_sample2.reshape(num_steps, -1)

# Raster plot of the flattened sample on a single white-background axis
fig = plt.figure(figsize=(10, 5), facecolor="w")
ax = fig.add_subplot(1, 1, 1)
splt.raster(spike_data_sample2, ax, s=1.5, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()

In [None]:
idx = 210  # column (neuron) to plot after flattening

fig = plt.figure(figsize=(8, 1), facecolor="w")
ax = fig.add_subplot(1, 1, 1)

# Flatten each time step, take column idx, and restore a trailing dim of
# size 1 so the raster receives a 2-D (num_steps, 1) tensor
_flattened = spike_data_sample.reshape(num_steps, -1)
_single_neuron = _flattened[:, idx].unsqueeze(1)
splt.raster(_single_neuron, ax, s=100, c="black", marker="|")

plt.title("Input Neuron")
plt.xlabel("Time step")
plt.yticks([])
plt.show()

## Latency Coding

In [None]:
def convert_to_time(data, tau=5, threshold=0.01):
    """Convert input values to spike times.

    Evaluates ``tau * log(data / (data - threshold))`` element-wise. For
    inputs at or below ``threshold`` that expression has no finite positive
    value (it evaluates to NaN or -inf), so those elements are returned as
    ``+inf`` instead.

    Args:
        data: Tensor of input values.
        tau: Factor multiplying the logarithm.
        threshold: Value subtracted from ``data`` in the denominator; inputs
            at or below it map to ``+inf``.

    Returns:
        Floating-point tensor of spike times, one per input element.
    """
    spike_time = tau * torch.log(data / (data - threshold))
    # The ratio above is negative, zero or infinite whenever data <= threshold;
    # overwrite those entries with +inf rather than leaking NaN / -inf.
    never_spikes = data <= threshold
    return spike_time.masked_fill(never_spikes, float("inf"))

# Evenly spaced inputs starting at 0 and stopping before 5, in steps of 0.05
raw_input = torch.arange(start=0, end=5, step=0.05)
spike_times = convert_to_time(raw_input)

# Plot spike time as a function of input value
plt.plot(raw_input, spike_times)
plt.ylabel('Spike Time (s)')
plt.xlabel('Input Value')
plt.show()

In [None]:
# Latency-encode the minibatch. The step count comes from the shared
# `num_steps` variable (not a literal 100) so it always agrees with the
# `.view(num_steps, -1)` used for the raster below.
spike_data = spikegen.latency(data_it, num_steps=num_steps, tau=5, threshold=0.01)

# Raster plot of the entry at index 0 along dim 1, flattened per time step
fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data[:, 0].view(num_steps, -1), ax, s=25, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()

In [None]:
# Due to the lack of midtone/grayscale features there is significant clustering.
# Tau can be increased or the spike times can be linearized.
# The step count comes from the shared `num_steps` variable (not a literal
# 100) so it always agrees with the `.view(num_steps, -1)` below.
spike_data = spikegen.latency(data_it, num_steps=num_steps, tau=5, threshold=0.01, linear=True)

# Raster plot of the entry at index 0 along dim 1, flattened per time step
fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data[:, 0].view(num_steps, -1), ax, s=25, c="black")
plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()

In [None]:
# All the firing occurs in the first 5 time steps, whereas the simulation is 100 time steps.
# We can increase Tau or normalize to the full span of num_steps.
# The step count comes from the shared `num_steps` variable (not a literal
# 100) so it always agrees with the `.view(num_steps, -1)` below.
spike_data = spikegen.latency(data_it, num_steps=num_steps, tau=5, threshold=0.01,
                              normalize=True, linear=True)

# Raster plot of the entry at index 0 along dim 1, flattened per time step
fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data[:, 0].view(num_steps, -1), ax, s=25, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()

In [None]:
# All the neurons firing at the end is the black background of MNIST.
# It contains no useful information and can be removed with clipping.
# The step count comes from the shared `num_steps` variable (not a literal
# 100) so it always agrees with the `.view(num_steps, -1)` below.
spike_data = spikegen.latency(data_it, num_steps=num_steps, tau=5, threshold=0.01,
                              clip=True, normalize=True, linear=True)

# Raster plot of the entry at index 0 along dim 1, flattened per time step
fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data[:, 0].view(num_steps, -1), ax, s=25, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()

In [None]:
# Index 0 along dims 1 and 2 of the most recent spike tensor
spike_data_sample = spike_data[:, 0, 0]

print(f"The corresponding target is: {targets_it[0]}")

# Animate on a fresh figure and wrap the result as an HTML5 video
fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample, fig, ax)
_html5 = anim.to_html5_video()
video = HTML(_html5)

# Close the figure; the cell's output is the video object on the last line
plt.close()
video

## Delta Modulation

In [None]:
# Create a tensor with some fake time-series data.
# torch.tensor with an explicit dtype replaces the legacy torch.Tensor(...)
# constructor; the values are stored as float32 either way.
data = torch.tensor([0, 1, 0, 2, 8, -20, 20, -5, 0, 1, 0], dtype=torch.float)

# Plot the tensor
plt.plot(data)

plt.title("Some fake time-series data")
plt.xlabel("Time step")
plt.ylabel("Voltage (mV)")
plt.show()

In [None]:
# Run the series through spikegen.delta with a threshold of 4
spike_data = spikegen.delta(data, threshold=4)

# Wide, short figure with a single axis
fig = plt.figure(figsize=(8, 1), facecolor="w")
ax = fig.add_subplot(1, 1, 1)

# Raster plot of the delta-converted data
splt.raster(spike_data, ax, c="black")

# Label the plot, hide the y ticks, and span the x axis over the full series
plt.title("Input Neuron")
plt.xlabel("Time step")
plt.yticks([])
plt.xlim(0, len(data))
plt.show()

In [None]:
# Same conversion as before, this time with off_spike enabled
spike_data = spikegen.delta(data, off_spike=True, threshold=4)

# Wide, short figure with a single axis
fig = plt.figure(figsize=(8, 1), facecolor="w")
ax = fig.add_subplot(1, 1, 1)

# Raster plot of the delta-converted data
splt.raster(spike_data, ax, c="black")

# Label the plot, hide the y ticks, and span the x axis over the full series
plt.title("Input Neuron")
plt.xlabel("Time step")
plt.yticks([])
plt.xlim(0, len(data))
plt.show()