In [None]:
import torch

In [None]:
# Hyperparameters
num_units = 700         # Number of input units per time bin (700 in SHD)
input_size = num_units  # Network input width; derived from num_units so the two can never disagree
num_classes = 20        # Adjust based on the dataset's number of classes
num_epochs = 20
learning_rate = 0.001
time_window = 2         # Number of time bins per sample fed to the network
                        # NOTE(review): spikes binned at or beyond time_window are dropped by
                        # generate_spike_tensor, so only the start of each recording is used — confirm intent.
batch_size = 4          # Samples per DataLoader batch
beta1 = 0.8             # Decay factor (beta) for RLeaky layer 1 (hidden)
beta2 = 0.89            # Decay factor (beta) for RLeaky layer 2 (output)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
import os
import urllib.request
import gzip, shutil

# Cache directory for downloaded datasets; created on the first run and reused afterwards.
cache_dir: str = os.path.expanduser("~/data")
os.makedirs(name=cache_dir, exist_ok=True)

# Download and extract files
def download_and_extract_shd(split="train", data_dir=None):
    """Download and decompress one split of the SHD dataset, reusing a cached copy.

    Parameters
    ----------
    split : str
        Split name used in the file name ``shd_<split>.h5.gz`` (default ``"train"``).
    data_dir : str or None
        Directory to store the files in. ``None`` uses the notebook-level ``cache_dir``.

    Returns
    -------
    str
        Path to the decompressed ``.h5`` file.

    Both the download and the decompression are written to a ``.part`` file
    first and renamed into place only when complete. An interrupted run
    therefore never leaves a truncated ``.h5`` behind, which would otherwise be
    mistaken for a finished file on the next call.
    """
    if data_dir is None:
        data_dir = cache_dir

    gz_name = f"shd_{split}.h5.gz"
    url = f"https://zenkelab.org/datasets/{gz_name}"
    gz_file_path = os.path.join(data_dir, gz_name)
    hdf5_file_path = gz_file_path[:-3]  # strip the ".gz" suffix

    if os.path.exists(hdf5_file_path):
        print(f"{hdf5_file_path} already exists.")
        return hdf5_file_path

    print("Downloading SHD dataset...")
    partial_gz = gz_file_path + ".part"
    urllib.request.urlretrieve(url, partial_gz)
    os.replace(partial_gz, gz_file_path)

    # Decompress .gz file
    print(f"Decompressing {gz_file_path}...")
    partial_h5 = hdf5_file_path + ".part"
    with gzip.open(gz_file_path, 'rb') as f_in, open(partial_h5, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.replace(partial_h5, hdf5_file_path)
    print(f"Decompressed to {hdf5_file_path}")

    return hdf5_file_path

# Fetch the SHD training split (reusing a cached copy if present) and echo its path as the cell output.
(shd_train_file := download_and_extract_shd())

In [None]:
import h5py

def explore_h5_file(file_path):
    """Print the group/dataset hierarchy of an HDF5 file.

    Groups are printed as ``Group: <path>``. Datasets are printed as
    ``  Dataset: <path>, Shape: <shape>, Dtype: <dtype>``. Any other node type
    (e.g. a committed datatype) is printed as ``  Other: <path> (<type>)``.
    """
    def print_node(path, obj):
        # Recursively describe one node and, for groups, all of its children.
        if isinstance(obj, h5py.Group):
            print(f"Group: {path}")
            for key in obj.keys():
                # rstrip avoids a doubled separator ("//key") when path is the root "/".
                print_node(f"{path.rstrip('/')}/{key}", obj[key])
        elif isinstance(obj, h5py.Dataset):
            print(f"  Dataset: {path}, Shape: {obj.shape}, Dtype: {obj.dtype}")
        else:
            print(f"  Other: {path} ({type(obj).__name__})")

    with h5py.File(file_path, 'r') as file:
        print_node('/', file)


explore_h5_file(shd_train_file)


In [None]:
import h5py
import numpy as np

# Function to load spikes and labels from the SHD dataset
def load_spikes_and_labels(hdf5_file_path, num_samples=None):
    """Load per-sample spike events and labels from an SHD-style HDF5 file.

    Parameters
    ----------
    hdf5_file_path : str
        Path to an ``.h5`` file containing ``/labels``, ``/spikes/units`` and
        ``/spikes/times`` datasets.
    num_samples : int or None
        If given, read only the first ``num_samples`` samples. This avoids reading
        the whole file when only a subset is used. ``None`` (default) reads everything.

    Returns
    -------
    spike_units : np.ndarray, dtype=object
        One entry per sample. Each entry is presumably a 1-D array of unit
        (neuron) IDs stored as variable-length data — verify against the file.
    spike_times : np.ndarray, dtype=object
        Same layout as ``spike_units``, holding spike times. Their unit is not
        established here — TODO confirm.
    labels : np.ndarray of int64
        Class label per sample. It is cast to int64 so that a later ``torch.long``
        conversion does not depend on the on-disk integer dtype.
    """
    selection = slice(num_samples)  # slice(None) selects every sample
    with h5py.File(hdf5_file_path, 'r') as f:
        labels = np.asarray(f['/labels'][selection], dtype=np.int64)
        spike_units = np.array(f['/spikes/units'][selection], dtype=object)
        spike_times = np.array(f['/spikes/times'][selection], dtype=object)

    return spike_units, spike_times, labels

In [None]:
import torch
import numpy as np
import torch.nn.utils.rnn as rnn_utils

spike_units, spike_times, labels = load_spikes_and_labels(shd_train_file)

# Limit the number of samples to process (e.g., 500 samples). The limit is
# capped at the dataset size, so later loops over range(num_samples) stay
# within the data.
num_samples = min(500, len(labels))
labels = labels[:num_samples]

# Convert only the retained samples to float32 tensors; discarded samples are
# never converted. Casting through NumPy first means torch never has to accept
# the file's on-disk dtypes directly.
spike_units_list = [torch.tensor(np.asarray(units, dtype=np.float32)) for units in spike_units[:num_samples]]
spike_times_list = [torch.tensor(np.asarray(times, dtype=np.float32)) for times in spike_times[:num_samples]]

In [None]:
import snntorch.spikegen as spikegen

# Function to convert spike times and units into tensor format for SLSTM input
def generate_spike_tensor(spike_units, spike_times, num_units, time_window, dt=1.0):
    """
    Bin one sample's (unit, time) spike events into a dense spike-train tensor.

    Parameters:
    - spike_units: 1-D tensor/array of spike unit (neuron) IDs
    - spike_times: 1-D tensor/array of spike times, paired element-wise with spike_units
      (extra elements in the longer of the two are ignored, as with zip)
    - num_units: number of possible units (e.g., 700 in SHD); IDs outside [0, num_units) are ignored
    - time_window: number of time bins in the output
    - dt: bin width, in the same unit as spike_times. The default of 1.0 reproduces
      binning by the integer part of each time.

    Returns:
    - A float32 tensor of shape [time_window, num_units] with 1 wherever a unit fired in a bin

    NOTE(review): the spike-time unit is not established in this notebook. If the
    times are in seconds, dt=1.0 puts every spike into bin 0 or 1. In that case
    pass e.g. dt=1e-3 for 1 ms bins.
    """
    spike_tensor = torch.zeros((time_window, num_units))

    units = torch.as_tensor(spike_units).reshape(-1).long()
    # float64 keeps bin boundaries exact for float32/float16 inputs. floor (not
    # round) assigns each spike to the bin whose interval [k*dt, (k+1)*dt) holds it.
    bins = torch.floor(torch.as_tensor(spike_times, dtype=torch.float64).reshape(-1) / dt).long()

    # Pair events like zip() would if the two inputs differ in length.
    n_events = min(units.numel(), bins.numel())
    units, bins = units[:n_events], bins[:n_events]

    # Drop events outside the tensor instead of wrapping negative indices or raising.
    in_range = (bins >= 0) & (bins < time_window) & (units >= 0) & (units < num_units)
    spike_tensor[bins[in_range], units[in_range]] = 1

    return spike_tensor

In [None]:
# One dense (time_window, num_units) spike tensor per retained sample. Zipping
# the two lists cannot index past their end, unlike range(num_samples).
spike_list = [
    generate_spike_tensor(units, times, num_units, time_window)
    for units, times in zip(spike_units_list, spike_times_list)
]
# Show a summary instead of echoing every tensor.
len(spike_list), (spike_list[0].shape if spike_list else None)

In [None]:
# One shape-(1,) int64 tensor per sample. The training cell concatenates them
# into the target vector. int() normalises NumPy integer scalars of any width.
labels_tensor = [torch.tensor([int(label)], dtype=torch.long) for label in labels]
labels_tensor[:5]  # preview only; the full list has one entry per sample

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import snntorch as snn
from torch.utils.data import DataLoader, TensorDataset


# Define the Spiking Recurrent Neural Network (SRNN) Model
class SRNN(nn.Module):
    """Two-layer spiking recurrent network: Linear -> RLeaky -> Linear -> RLeaky.

    Parameters
    ----------
    input_size : int
        Number of input features per time step.
    num_classes : int
        Number of output neurons (one per class).
    time_window : int
        Number of leading time steps of the input that ``forward`` simulates.
    hidden_size : int
        Width of the hidden recurrent layer (default 512).
    beta_hidden, beta_out : float or None
        Decay factors for the two RLeaky layers. ``None`` falls back to the
        notebook-level ``beta1`` / ``beta2`` hyperparameters, looked up at construction.
    """

    def __init__(self, input_size, num_classes, time_window,
                 hidden_size=512, beta_hidden=None, beta_out=None):
        super().__init__()
        if beta_hidden is None:
            beta_hidden = beta1
        if beta_out is None:
            beta_out = beta2

        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.time_window = time_window

        # Define layers (same registration order as before, so parameter order is unchanged)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.rlif1 = snn.RLeaky(beta=beta_hidden, all_to_all=True, linear_features=hidden_size)

        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.rlif2 = snn.RLeaky(beta=beta_out, all_to_all=True, linear_features=num_classes)

    def forward(self, x):
        """Simulate the network for ``self.time_window`` steps.

        ``x`` is indexed as ``x[:, step, :]``, i.e. (batch, time, features), and
        must have at least ``self.time_window`` steps along dim 1.
        Returns the output-layer spikes stacked along dim 1:
        (batch, time_window, num_classes).
        """
        # Non-zero initial spike/membrane values are kept from the original design.
        # The states are 1-D. NOTE(review): they presumably broadcast against the
        # (batch, features) layer inputs on the first step — confirm against RLeaky.
        spk1 = torch.full((self.hidden_size,), 0.2, device=x.device)
        mem1 = torch.full((self.hidden_size,), 0.1, device=x.device)
        # Sized from this model's own output width, not the notebook-level num_classes.
        spk2 = torch.full((self.num_classes,), 0.2, device=x.device)
        mem2 = torch.full((self.num_classes,), 0.1, device=x.device)

        spk2_steps = []
        for step in range(self.time_window):
            # Layer 1: fully connected + RLeaky on this time step of the whole batch
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.rlif1(cur1, spk1, mem1)

            # Layer 2: fully connected + RLeaky driven by the hidden spikes
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.rlif2(cur2, spk2, mem2)

            spk2_steps.append(spk2)

        return torch.stack(spk2_steps, dim=1)  # (batch_size, time_window, num_classes)


# Convert to torch tensors and use DataLoader for batching
spike_tensor = torch.stack(spike_list)  # Shape: (num_samples, time_window, num_units)
# Concatenate only while labels_tensor is still the per-sample list, so this
# cell can be re-run without failing on an already-concatenated tensor.
if not torch.is_tensor(labels_tensor):
    labels_tensor = torch.cat(labels_tensor)  # Shape: (num_samples,)

print(spike_tensor.size())
print(labels_tensor.size())

# Seed so weight initialisation and batch shuffling are reproducible across runs.
torch.manual_seed(0)

# Create a TensorDataset and DataLoader for batching
dataset = TensorDataset(spike_tensor, labels_tensor)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the model, loss function, and optimizer
model = SRNN(input_size, num_classes, time_window).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop: one summary line per epoch instead of one per batch
model.train()
for epoch in range(num_epochs):
    epoch_loss, correct, seen = 0.0, 0, 0
    for spike_batch, label_batch in data_loader:
        spike_batch = spike_batch.to(device)  # Shape: (batch_size, time_window, num_units)
        label_batch = label_batch.to(device)

        outputs = model(spike_batch)  # Shape: (batch_size, time_window, num_classes)
        # Mean output spike activity over time is used as the class logits.
        # NOTE(review): if spikes are binary these lie in [0, 1], a narrow logit range for CrossEntropyLoss.
        rates = outputs.mean(dim=1)
        loss = criterion(rates, label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_n = label_batch.size(0)
        epoch_loss += loss.item() * batch_n
        correct += (rates.argmax(dim=1) == label_batch).sum().item()
        seen += batch_n

    seen = max(seen, 1)  # guard against an empty dataset
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / seen:.4f}, Train acc: {correct / seen:.3f}')
