<a href="https://colab.research.google.com/github/PrashubhAtri/heterogeneousSNNs/blob/main/Baselines/SNNs/SurrogateGradients/SNNwithVoltage.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Downloading snntorch-0.9.4-py2.py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: snntorch
Successfully installed snntorch-0.9.4


In [2]:
!git clone https://github.com/fzenke/randman

Cloning into 'randman'...
remote: Enumerating objects: 104, done.[K
remote: Counting objects: 100% (45/45), done.[K
remote: Compressing objects: 100% (35/35), done.[K
remote: Total 104 (delta 16), reused 34 (delta 8), pack-reused 59 (from 1)[K
Receiving objects: 100% (104/104), 683.31 KiB | 8.65 MiB/s, done.
Resolving deltas: 100% (31/31), done.


In [3]:
import sys
sys.path.append('/content/randman')

In [4]:
import numpy as np
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
import torch.optim.lr_scheduler as lr_scheduler
import pandas as pd
import randman
from randman import Randman
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Constants for data generation
NB_CLASSES = 2
NB_UNITS = 10 # number of input neurons / embedding dimensions
NB_STEPS = 50
NB_SAMPLES = 2000
SEED = 12

In [7]:
def standardize(x,eps=1e-7):
    # x's (which is actually y in the following code) shape will be [samples, units]
    # Therefore, 0-axis shows that the author standardize across all samples for each units
    mi,_ = x.min(0)
    ma,_ = x.max(0)
    return (x-mi)/(ma-mi+eps)

In [8]:
def make_spiking_dataset(nb_classes=10, nb_units=100, nb_steps=100, step_frac=1.0, dim_manifold=2, nb_spikes=1, nb_samples=1000, alpha=2.0, shuffle=True, classification=True, seed=None):
    """ Generates event-based generalized spiking randman classification/regression dataset.
    In this dataset each unit fires a fixed number of spikes. So ratebased or spike count based decoding won't work.
    All the information is stored in the relative timing between spikes.
    For regression datasets the intrinsic manifold coordinates are returned for each target.
    Args:
        nb_classes: The number of classes to generate
        nb_units: The number of units to assume
        nb_steps: The number of time steps to assume
        step_frac: Fraction of time steps from beginning of each to contain spikes (default 1.0)
        nb_spikes: The number of spikes per unit
        nb_samples: Number of samples from each manifold per class
        alpha: Randman smoothness parameter
        shuffe: Whether to shuffle the dataset
        classification: Whether to generate a classification (default) or regression dataset
        seed: The random seed (default: None)
    Returns:
        A tuple of data,labels. The data is structured as numpy array
        (sample x event x 2 ) where the last dimension contains
        the relative [0,1] (time,unit) coordinates and labels.
    """

    data = []
    labels = []
    targets = []

    if SEED is not None:
        np.random.seed(SEED)

    max_value = np.iinfo(int).max
    randman_seeds = np.random.randint(max_value, size=(nb_classes,nb_spikes) )

    for k in range(nb_classes):
        x = np.random.rand(nb_samples,dim_manifold)

        # The following code shows that if more than one spike, different spikes, even for the same unit, are generated by independent mappings
        submans = [ randman.Randman(nb_units, dim_manifold, alpha=alpha, seed=randman_seeds[k,i]) for i in range(nb_spikes) ]
        units = []
        times = []
        for i,rm in enumerate(submans):
            y = rm.eval_manifold(x)
            y = standardize(y)
            units.append(np.repeat(np.arange(nb_units).reshape(1,-1),nb_samples,axis=0))
            times.append(y.numpy())

        units = np.concatenate(units,axis=1)
        times = np.concatenate(times,axis=1)
        events = np.stack([times,units],axis=2)
        data.append(events)
        labels.append(k*np.ones(len(units)))
        targets.append(x)

    data = np.concatenate(data, axis=0)
    labels = np.array(np.concatenate(labels, axis=0), dtype=int)
    targets = np.concatenate(targets, axis=0)

    if shuffle:
        idx = np.arange(len(data))
        np.random.shuffle(idx)
        data = data[idx]
        labels = labels[idx]
        targets = targets[idx]

    data[:,:,0] *= nb_steps*step_frac
    # data = np.array(data, dtype=int)

    if classification:
        return data, labels
    else:
        return data, targets

In [9]:
def events_to_spike_train(data):
    """convert the data generated from manifold to spike train form

    Args:
        data (array): shape is [samples, nb_events, 2]

    Returns:
        spike_train: shape is [nb_samples, nb_time_steps, units]
    """

    # astyle() will discard the decimal to give integer timestep
    spike_steps = data[:, :, 0].astype(int)
    spike_units = data[:, :, 1].astype(int)
    # These will be the indices to entrices in the spike train to be set to 1

    # Use the index on spike train matrix [samples, steps, units]
    spike_train = np.zeros((data.shape[0], NB_STEPS, NB_UNITS))
    sample_indicies = np.expand_dims(np.arange(data.shape[0]), -1)
    spike_train[sample_indicies, spike_steps, spike_units] = 1

    return spike_train

In [10]:
def get_randman_dataset():
    """generate a TensorDataset encapsulated x and y, where x is spike trains

    Returns:
        TensorDataset: [nb_samples, time_steps, units] and [nb_samples]
    """
    data, label = make_spiking_dataset(NB_CLASSES, NB_UNITS, NB_STEPS, nb_spikes=1, nb_samples=NB_SAMPLES)
    spike_train = events_to_spike_train(data)

    spike_train = torch.Tensor(spike_train).to(device)
    label = torch.Tensor(label).to(device)

    # encapulate using Torch.Dataset
    dataset = TensorDataset(spike_train, label)

    return dataset

In [11]:
# Hyperparameters
NB_HIDDEN_UNITS = int(NB_UNITS * 1.5)
BETA = 0.85 # This can also be obtained using exp(-delta_t / tau)

In [12]:
spike_trains = get_randman_dataset()

In [13]:
sample, label = spike_trains[0]  # Access the first sample
print(f"Sample shape: {sample.shape}")
print(f"Label: {label}")

Sample shape: torch.Size([50, 10])
Label: 0.0


In [14]:
# Split into ttv sets
train, test, validation = 0.8, 0.1, 0.1

all_labels = [spike_trains[i][1] for i in range(len(spike_trains))]

# First split: train (80%) and temp (20%)
train_idx, temp_idx = train_test_split(
    np.arange(len(spike_trains)),
    test_size=test,
    stratify=all_labels,
    random_state=SEED
)

# Second split: val (10%) and test (10%) from temp
val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=0.5,
    stratify=[all_labels[i] for i in temp_idx],
    random_state=SEED
)

from torch.utils.data import Subset
train_dataset = Subset(spike_trains, train_idx)
val_dataset = Subset(spike_trains, val_idx)
test_dataset = Subset(spike_trains, test_idx)

#Batch_size = 32
batch_size = 32

train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

train_labels = [spike_trains[i][1] for i in train_dataset.indices]
test_labels = [spike_trains[i][1] for i in train_dataset.indices]
train_labels = [spike_trains[i][1] for i in train_dataset.indices]

In [15]:
# Example: Iterate through the training DataLoader
for batch_data, batch_labels in train_loader:
    print(f"Batch data shape: {batch_data.shape}")
    print(f"Batch labels: {batch_labels}")
    break

Batch data shape: torch.Size([32, 50, 10])
Batch labels: tensor([1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1.,
        1., 1., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 1., 1.])


In [18]:
class SNN(nn.Module):
    def __init__(self, num_inputs=NB_STEPS, num_hidden=100, num_outputs=10, beta=0.85):
        super(SNN, self).__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, learn_beta=True, threshold=0.5,
                              reset_mechanism="subtract",
                              spike_grad=surrogate.fast_sigmoid(slope=5))

        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, learn_beta=True, threshold=0.5,
                              reset_mechanism="none",
                              spike_grad=surrogate.fast_sigmoid(slope=5))

        # Xavier Uniform Initialization for Stability
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        batch_size, num_neurons, time_steps = x.shape
        x = x.permute(2, 0, 1)  # (time, batch, neurons)

        mem1, mem2 = [torch.zeros(batch_size, layer.out_features, device=x.device)
                             for layer in [self.fc1, self.fc2]]
        # Store output membrane potentials over time
        mem2_rec = []

        for t in range(time_steps):
            spk1, mem1 = self.lif1(self.fc1(x[t]), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            mem2_rec.append(mem2)

        # Aggregate membrane potentials over time (sum)
        # mem2_rec = torch.stack(mem2_rec, dim=0).sum(dim=0)
        return mem2_rec[-1]

In [19]:
# Instantiate Model with voltage aggregation
model = SNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

# Training Loop with Validation
num_epochs = 10
best_val_accuracy = 0.0
torch.autograd.set_detect_anomaly(True)

for epoch in range(num_epochs):
    # Training Phase
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        labels = labels.long()
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)

    # Validation Phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            labels = labels.long()
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    val_accuracy = correct / total

    # Update learning rate
    scheduler.step()

    # Save best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy*100:.2f}%")

# Load best model for testing
model.load_state_dict(torch.load('best_model.pth'))

# Testing Loop
model.eval()
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        labels = labels.long()
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_loss /= len(test_loader)
test_accuracy = correct / total

print(f"\nFinal Test Results:")
print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_accuracy*100:.2f}%")

Epoch [1/10]
  Train Loss: 0.9647 | Val Loss: 0.5637 | Val Acc: 66.00%
Epoch [2/10]
  Train Loss: 0.5463 | Val Loss: 0.5487 | Val Acc: 66.50%
Epoch [3/10]
  Train Loss: 0.5218 | Val Loss: 0.5322 | Val Acc: 69.50%
Epoch [4/10]
  Train Loss: 0.5120 | Val Loss: 0.5584 | Val Acc: 69.00%
Epoch [5/10]
  Train Loss: 0.4970 | Val Loss: 0.5810 | Val Acc: 65.00%
Epoch [6/10]
  Train Loss: 0.4866 | Val Loss: 0.5241 | Val Acc: 71.00%
Epoch [7/10]
  Train Loss: 0.4710 | Val Loss: 0.5327 | Val Acc: 71.00%
Epoch [8/10]
  Train Loss: 0.4556 | Val Loss: 0.4939 | Val Acc: 71.50%
Epoch [9/10]
  Train Loss: 0.4343 | Val Loss: 0.4956 | Val Acc: 73.00%
Epoch [10/10]
  Train Loss: 0.4172 | Val Loss: 0.4760 | Val Acc: 72.50%

Final Test Results:
Test Loss: 0.4245 | Test Accuracy: 77.00%
