In [None]:
# Imports 
import os
# Important: insert the path to synaptic_plasticity/src
# The working directory is switched before `import utils` / `import network`
# below — presumably so these project-local modules resolve from src (verify).
# The relative paths used later in this notebook ('./data',
# '../checkpoints/experiment_1') are resolved against this directory as well.
os.chdir('/Users/path_to_synaptic_plasticity/synaptic_plasticity/src')
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import random
import spikingjelly
import utils
import network
from spikingjelly.activation_based import functional

## Loading a random MNIST image

In [None]:
# Raw (non-encoded) MNIST, used only to show the original digit.
# train=False selects the test split; the files are downloaded to ./data on first use.
transform = transforms.Compose([transforms.ToTensor()])
mnist = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Single-image, shuffled loader over the same dataset
data_loader = torch.utils.data.DataLoader(
    mnist,
    batch_size=1,
    shuffle=True,
)

In [None]:
# Draw one dataset index uniformly at random and fetch that (image, label) pair
sample_idx = random.randrange(len(mnist))
sample, label = mnist[sample_idx]

# Show the digit without axis ticks, titled with its ground-truth label
fig, ax = plt.subplots()
ax.imshow(sample.squeeze(), cmap='gray')
ax.axis('off')
ax.set_title(f"True Label: {label}")
plt.show()

## Encoding static image into spike trains

In [None]:
# Only the second loader returned by utils.load_MNIST() is needed here
_, testing_loader = utils.load_MNIST()

# Fetch the entry at the same index as the image displayed above
encoded_pair = testing_loader.dataset[sample_idx]
encoded, encoded_label = encoded_pair

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch
from IPython.display import HTML

# Spike-encoded sample selected above. It is indexed below as
# [time step, channel, row, column] (e.g. [15, 6, 28, 28]).
data = encoded

# Convert to a NumPy array for plotting
data_np = data.numpy()

# Take the number of frames and panels from the data instead of hard-coding them
n_steps, n_channels = data_np.shape[:2]

# One panel per channel; squeeze=False keeps `axs` an array even for one channel
fig, axs = plt.subplots(1, n_channels, figsize=(3 * n_channels, 3), squeeze=False)
axs = axs.ravel()
fig.patch.set_facecolor('lightsalmon')

# Hide the axes for a cleaner display
for ax in axs:
    ax.axis('off')


def update(t):
    """Redraw every channel panel for time step ``t``.

    Args:
        t: Index along the first axis of ``data_np`` (the frame to show).

    Returns:
        The array of axes that were redrawn.
    """
    for i, ax in enumerate(axs):
        # Drop the previous frame's image so artists do not pile up on the axis
        ax.clear()
        ax.axis('off')
        # 'gray' is registered in every Matplotlib release; the 'grey' spelling
        # is not available as a colormap name in older versions.
        ax.imshow(data_np[t, i], cmap='gray')
        ax.set_title(f'Channel = {i}', color='white', fontsize=10, fontweight='bold')
    return axs


# Create the animation: one frame per time step, 200 ms apart
ani = animation.FuncAnimation(fig, update, frames=n_steps, interval=200, blit=False)

# Close the figure to prevent the initial static plot from showing
plt.close(fig)

# Display the animation (last expression of the cell, so it is rendered)
HTML(ani.to_jshtml())


## Insert the encoded MNIST sample into the model

In [None]:
# Build the network with one output class per MNIST digit
NUMBER_OF_CLASSES = 10
net = network.Network(number_of_classes=NUMBER_OF_CLASSES)

In [None]:
# Kernel at index [0, 0] of conv1's weight tensor, before any checkpoint is loaded
kernel_before_load = net.conv1.weight.detach().numpy()[0, 0]
plt.imshow(kernel_before_load)
plt.axis('off')

In [None]:
# Look up the most recent checkpoint of experiment_1 (path is relative to src/)
latest_checkpoint_path = utils.get_latest_checkpoint('../checkpoints/experiment_1')
if latest_checkpoint_path:
    # Load the checkpoint into `net`; the two returned values are kept as
    # start_epoch and training_layer.
    start_epoch, training_layer = utils.load_checkpoint(net, latest_checkpoint_path)
else:
    # Say so explicitly: otherwise the rest of the notebook silently runs
    # with whatever weights `net` currently has.
    print("No checkpoint found in '../checkpoints/experiment_1'; "
          "continuing with the network's current weights.")

In [None]:
# Same conv1 kernel ([0, 0]) as plotted earlier, shown again after the checkpoint cell
kernel_after_load = net.conv1.weight.detach().numpy()[0, 0]
plt.imshow(kernel_after_load)
plt.axis('off')

In [None]:
# Display the label that testing_loader.dataset returned together with the encoded sample
encoded_label

In [None]:
# Switch the network to SpikingJelly's multi-step mode ('m') before the forward pass
functional.set_step_mode(net, step_mode='m')

# Move the sample to the CPU, insert a singleton axis at position 0, swap it
# with axis 1 and cast to float.
# Presumably this yields [T, 1, C, H, W] for an encoded sample of [T, C, H, W] — confirm.
frame = (
    encoded.to('cpu:0')
    .unsqueeze(dim=0)
    .transpose(0, 1)
    .float()
)

# Run the forward pass; the return value is discarded
_ = net(frame)

In [None]:
# Extract the decision value as an integer
decision = net.get_decision()
decision_value = int(decision.numpy()[0])

# Compare against the label of the sample that was actually fed to the network
# (encoded_label), not the label of the separately loaded torchvision dataset.
# NOTE(review): assumes encoded_label is a scalar class index (int or
# one-element tensor) — confirm against utils.load_MNIST.
true_label = int(encoded_label)

# Determine if the decision is correct
result = "Correct 😀" if decision_value == true_label else "Incorrect 😭"

# Print the result
print(f"The network decided: {decision_value}\nThe true label: {true_label}\n{result}")