<a href="https://colab.research.google.com/github/Medulus/CollegeCode/blob/main/SNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the third-party packages this notebook imports (snntorch, tonic);
# --quiet suppresses pip's progress output.
! pip install snntorch --quiet
! pip install tonic --quiet

In [None]:
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import librosa

import matplotlib.pyplot as plt
import numpy as np
import itertools


In [None]:
# Mini-batch size and default floating-point dtype for tensors created below.
batch_size = 128
dtype = torch.float

# Pick the compute device in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Define the ResizeAudio transformation class
class ResizeAudio(object):
    """Resample a framed recording to a fixed length along its time axis.

    The transform operates on the array it is given (axis 0 is treated as
    time), so it can be placed after a framing step in a transform pipeline.
    It never touches the file system.

    Args:
        target_duration: Desired duration of the output. Must be positive.
        frame_duration: Duration covered by one frame, in the same unit as
            ``target_duration``. With the default of ``1.0`` the target
            duration is simply a number of frames.

    Raises:
        ValueError: If ``target_duration`` or ``frame_duration`` is not
            positive.
    """

    def __init__(self, target_duration, frame_duration=1.0):
        if target_duration <= 0:
            raise ValueError("target_duration must be positive")
        if frame_duration <= 0:
            raise ValueError("frame_duration must be positive")
        self.target_duration = target_duration
        self.frame_duration = frame_duration

    def __call__(self, audio):
        """Return ``audio`` resampled to the target number of frames.

        Args:
            audio: Array-like whose first axis is time.

        Returns:
            A NumPy array with ``round(target_duration / frame_duration)``
            entries (at least one) along axis 0; trailing axes are unchanged.

        Raises:
            ValueError: If ``audio`` has no time axis (is 0-dimensional).
        """
        frames = np.asarray(audio)
        if frames.ndim == 0:
            raise ValueError("audio must have at least one (time) axis")

        # The current duration comes from the data itself, not from a file.
        current_frames = frames.shape[0]
        target_frames = max(
            1, int(round(self.target_duration / self.frame_duration))
        )

        if current_frames == 0:
            # Nothing to sample from: emit silence of the requested length.
            return np.zeros(
                (target_frames,) + frames.shape[1:], dtype=frames.dtype
            )
        if current_frames == target_frames:
            return frames

        # Nearest-neighbour resampling: output frame i reads the source frame
        # at the same relative position. Integer arithmetic keeps every index
        # inside [0, current_frames - 1].
        source_index = (np.arange(target_frames) * current_frames) // target_frames
        return frames[source_index]

Downloading the data from tonic

In [None]:
# Directory handed to the Tonic dataset constructors via `save_to`.
DATADIR = "/tmp/data"

sensor_size = tonic.datasets.SSC.sensor_size

# Pre-processing hyper-parameters
batch_size = 32
dt = 4000  # All time units in Tonic are in microseconds (us)
encoding_dim = 100
target_duration = 5


# Define the transformation pipeline
dense_transform = transforms.Compose(
    [
        transforms.Downsample(spatial_factor=encoding_dim / 700),
        transforms.CropTime(max=1e6),
        transforms.ToFrame(
            sensor_size=(encoding_dim, 1, 1),
            time_window=dt,
            include_incomplete=True,
        ),
        ResizeAudio(target_duration),
    ]
)

# Both splits share the same storage directory and transform pipeline.
train_ds, test_ds = (
    tonic.datasets.SSC(save_to=DATADIR, transform=dense_transform, split=split_name)
    for split_name in ("train", "test")
)

In [None]:
# Shuffled loaders over both splits; incomplete final batches are dropped.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Network Architecture
# NOTE(review): 28*28 looks carried over from an image example. fc1 needs the
# flattened per-sample feature count produced by the data pipeline -- confirm
# against the shape of one batch before training.
num_inputs = 28*28
num_hidden = 1000
# The SSC dataset has 35 keyword classes, so the output layer needs one unit
# per class; with fewer outputs CrossEntropyLoss rejects the larger labels.
num_outputs = 35

# Temporal Dynamics
num_steps = 25  # number of simulation steps per forward pass
beta = 0.95  # `beta` argument passed to each snn.Leaky layer
# Define Network
class Net(nn.Module):
    """Fully connected network: Linear -> snn.Leaky -> Linear -> snn.Leaky.

    Layer sizes, the number of simulation steps and the `beta` value of the
    leaky layers are read from the module-level settings `num_inputs`,
    `num_hidden`, `num_outputs`, `num_steps` and `beta`.
    """

    def __init__(self):
        super().__init__()

        # Input -> hidden projection and its leaky neuron layer
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        # Hidden -> output projection and its leaky neuron layer
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Run the network for `num_steps` steps on the same input `x`.

        Returns:
            A tuple `(spikes, membranes)` holding the output layer's spikes
            and membrane values for every step, stacked along a new leading
            axis.
        """
        # Membrane states start from the layers' own initial values at t=0
        hidden_mem = self.lif1.init_leaky()
        output_mem = self.lif2.init_leaky()

        # Per-step recordings of the final layer
        output_spikes = []
        output_mems = []

        for _ in range(num_steps):
            hidden_spk, hidden_mem = self.lif1(self.fc1(x), hidden_mem)
            out_spk, output_mem = self.lif2(self.fc2(hidden_spk), output_mem)
            output_spikes.append(out_spk)
            output_mems.append(output_mem)

        stacked_spikes = torch.stack(output_spikes, dim=0)
        stacked_mems = torch.stack(output_mems, dim=0)
        return stacked_spikes, stacked_mems

# Instantiate the network and move its parameters to the selected `device`
net = Net().to(device)


Accuracy

In [None]:
# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target

def print_batch_accuracy(data, targets, train=False):
    """Print the classification accuracy of `net` on a single minibatch.

    Args:
        data: Input batch; every dimension after the first is flattened
            before the batch is passed to the network.
        targets: Class indices, one per sample in `data`.
        train: Selects the "Train" or "Test" wording of the printed line.
    """
    # No gradients are needed just to measure accuracy.
    with torch.no_grad():
        # Flatten by the batch's own leading dimension rather than a global
        # batch size, so a batch of any size is handled correctly.
        output, _ = net(data.flatten(start_dim=1))

    # Summing over dim 0 (the step axis) gives per-class spike counts; the
    # class with the most spikes is the prediction.
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    split_name = "Train" if train else "Test"
    print(f"{split_name} set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    """Print a progress report for the current training iteration.

    Reads the training loop's module-level state (`epoch`, `iter_counter`,
    `counter`, the two loss histories and the current train/test batches).
    """
    print(f"Epoch {epoch}, Iteration {iter_counter}")

    # Losses recorded at the current global step
    current_train_loss = loss_hist[counter]
    print(f"Train Set Loss: {current_train_loss:.2f}")
    current_test_loss = test_loss_hist[counter]
    print(f"Test Set Loss: {current_test_loss:.2f}")

    # Accuracy on the batches the loop just processed
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

In [None]:
# Cross-entropy criterion and the Adam optimizer over the network's parameters.
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
)

In [None]:
num_epochs = 1
loss_hist = []  # training loss per iteration
test_loss_hist = []  # loss on one test minibatch per iteration
counter = 0  # global iteration index across epochs (indexes both histories)

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0  # iteration index within the current epoch

    # Minibatch training loop
    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        # Flatten by the batch's own leading dimension instead of the global
        # `batch_size`, which this notebook assigns more than once.
        spk_rec, mem_rec = net(data.flatten(start_dim=1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for mem_step in mem_rec:  # one entry per simulation step
            loss_val += loss(mem_step, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            # A fresh iterator over the shuffled loader yields one test batch.
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.flatten(start_dim=1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for mem_step in test_mem:
                test_loss += loss(mem_step, test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()

        # Bookkeeping needs no gradient context, so it sits outside no_grad.
        counter += 1
        iter_counter += 1

	This alias will be removed in version 1.0.
  current_duration = librosa.get_duration(filename=train_ds)


TypeError: Invalid file: SSC