In [1]:
import dataloader
trainloader= dataloader.load_filtered_shd_dataloader()

Downloading https://zenkelab.org/datasets/shd_train.h5.zip to ./data/SHD/shd_train.h5.zip


  0%|          | 0/130863613 [00:00<?, ?it/s]

Extracting ./data/SHD/shd_train.h5.zip to ./data/SHD


In [2]:
events, target = next(iter(trainloader))

In [3]:
events[4].shape

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [4]:
events[4].shape

torch.Size([250, 1, 700])

In [5]:
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")


In [6]:
# Network Architecture
num_inputs = 700
num_hidden = 2000
num_outputs = 10

# Temporal Dynamics
num_steps = 1000
beta = 0.95

In [7]:
# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):  # x: [batch, time, input]
        batch_size, time_steps, _ = x.shape

        # Initialize membrane potentials as zeros with the correct shape
        mem1 = torch.zeros(batch_size, self.fc1.out_features, device=x.device)
        mem2 = torch.zeros(batch_size, self.fc2.out_features, device=x.device)

        spk2_rec = []
        mem2_rec = []

        for step in range(time_steps):
            xt = x[:, step, :]  # [batch, 700]
            cur1 = self.fc1(xt)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load the network onto CUDA if available
net = Net().to(device)

In [8]:
# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target

def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data)
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer(
    data, targets, epoch,
    counter, iter_counter,
        loss_hist, test_loss_hist, test_data, test_targets):
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

In [9]:
loss = nn.CrossEntropyLoss()

In [10]:
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

In [11]:
data, targets = next(iter(trainloader))
data = data.to(device).squeeze(2).squeeze(2)  # [128, 1000, 700]
targets = targets.to(device)
spk_rec, mem_rec = net(data)
print(mem_rec.size())

torch.Size([250, 32, 10])


In [12]:
import torch
torch.cuda.empty_cache()


In [13]:
# initialize the total loss value
dtype = torch.float

loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
for step in range(num_steps):
  loss_val += loss(mem_rec[step], targets)

print(f"Training loss: {loss_val.item():.3f}")

IndexError: index 250 is out of bounds for dimension 0 with size 250

In [29]:
print_batch_accuracy(data, targets, train=True)

Train set accuracy for a single minibatch: 18.75%


In [30]:
# clear previously stored gradients
optimizer.zero_grad()

# calculate the gradients
loss_val.backward()

# weight update
optimizer.step()

In [31]:
# calculate new network outputs using the same data
spk_rec, mem_rec = net(data)

# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
for step in range(num_steps):
  loss_val += loss(mem_rec[step], targets)

print(f"Training loss: {loss_val.item():.3f}")
print_batch_accuracy(data, targets, train=True)

Training loss: 2231.914
Train set accuracy for a single minibatch: 25.00%


In [33]:
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(trainloader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.squeeze(2).squeeze(2)
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data)

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        print_batch_accuracy(data, targets, train=True)


Train set accuracy for a single minibatch: 21.88%
Train set accuracy for a single minibatch: 12.50%
Train set accuracy for a single minibatch: 12.50%
Train set accuracy for a single minibatch: 9.38%
Train set accuracy for a single minibatch: 9.38%
Train set accuracy for a single minibatch: 21.88%
Train set accuracy for a single minibatch: 15.62%
Train set accuracy for a single minibatch: 12.50%
Train set accuracy for a single minibatch: 28.12%
Train set accuracy for a single minibatch: 18.75%
Train set accuracy for a single minibatch: 9.38%
Train set accuracy for a single minibatch: 25.00%
Train set accuracy for a single minibatch: 9.38%
Train set accuracy for a single minibatch: 21.88%


KeyboardInterrupt: 