In [1]:
pip install snntorch

Collecting snntorch
  Downloading snntorch-0.7.0-py2.py3-none-any.whl (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.0/109.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting nir (from snntorch)
  Downloading nir-0.2.0-py3-none-any.whl (21 kB)
Collecting nirtorch (from snntorch)
  Downloading nirtorch-0.2.1-py3-none-any.whl (10 kB)
Installing collected packages: nir, nirtorch, snntorch
Successfully installed nir-0.2.0 nirtorch-0.2.1 snntorch-0.7.0


In [2]:
import torch, torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import utils

num_steps = 25 # number of time steps
batch_size = 1
beta = 0.5  # neuron decay rate
spike_grad = surrogate.fast_sigmoid() # surrogate gradient

net = nn.Sequential(
      nn.Conv2d(1, 8, 5),
      nn.MaxPool2d(2),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
      nn.Conv2d(8, 16, 5),
      nn.MaxPool2d(2),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
      nn.Flatten(),
      nn.Linear(16 * 4 * 4, 10),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad, output=True)
      )

data_in = torch.rand(num_steps, batch_size, 1, 28, 28) # random input data
spike_recording = [] # record spikes over time
utils.reset(net) # reset/initialize hidden states for all neurons

for step in range(num_steps): # loop over time
    spike, state = net(data_in[step]) # one time step of forward-pass
    spike_recording.append(spike) # record spikes in list

In [7]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import utils
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define your SNN architecture and parameters as provided in your code.
net = nn.Sequential(
      nn.Conv2d(1, 16, 5),
      nn.MaxPool2d(2),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
      nn.Flatten(),
      nn.Linear(16 * 4 * 4, 10),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad, output=True),
      nn.Flatten(),
nn.Linear(256, 10),  # Ensure the number of output units is 10
snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad, output=True)

      )
# Load and preprocess the Fashion MNIST dataset using PyTorch DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.FashionMNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST('data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# Training loop
num_epochs = 5  # Adjust the number of epochs as needed
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()

        # Reset SNN states
        utils.reset(net)

        # Perform forward pass
        spike_recording = []  # Record spikes over time
        for step in range(num_steps):
            spike, state = net(inputs[step])
            spike_recording.append(spike)

        # Calculate loss
        output = spike_recording[-1]  # Use the output spike recording
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}")

print("Finished Training")

# Evaluation on test dataset
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        utils.reset(net)

        spike_recording = []  # Record spikes over time
        for step in range(num_steps):
            spike, state = net(inputs[step])
            spike_recording.append(spike)

        output = spike_recording[-1]
        _, predicted = torch.max(output, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on test dataset: {(100 * correct / total):.2f}%")


RuntimeError: ignored