In [None]:
# !pip install tonic wandb snntorch
# !pip install weave
# !wandb login ''

Defaulting to user installation because normal site-packages is not writeable
Collecting tonic
  Using cached tonic-1.6.0-py3-none-any.whl.metadata (5.4 kB)
Collecting snntorch
  Using cached snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Collecting h5py (from tonic)
  Downloading h5py-3.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting importRosbag>=1.0.4 (from tonic)
  Using cached importRosbag-1.0.4-py3-none-any.whl.metadata (4.3 kB)
Collecting librosa (from tonic)
  Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)
Collecting pbr (from tonic)
  Using cached pbr-6.1.1-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting expelliarmus (from tonic)
  Downloading expelliarmus-1.1.12.tar.gz (28 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting audioread>=2.1.9 (from librosa->tonic)
  Using cached audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa->tonic)
  Downloading numba-0.61.2

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import snntorch as snn
import wandb
import tonic
from sklearn.model_selection import train_test_split
import numpy as np


# load SHD
def convert_to_tensor(spike_times, spike_units, num_neurons=700, time_bins=100):

    spike_tensor = np.zeros((time_bins, num_neurons), dtype=np.float32)

    # Normalize time into the `time_bins`
    if len(spike_times) > 0:
        time_idx = (spike_times / np.max(spike_times) * (time_bins - 1)).astype(int)
        spike_tensor[time_idx, spike_units] = 1  # Mark neuron firing

    return torch.tensor(spike_tensor, dtype=torch.float32)

def get_SHD_dataset(SHD_raw, num_neurons=700, time_bins=100):
    spike_trains = []
    labels = []
    for i, (events, label) in enumerate(SHD_raw):
        # events has shape (nb_spikes,), each entry is array([t(spike time), x(unit), p])
        spike_times = events['t']
        spike_units = events['x']
        spike_trains.append(convert_to_tensor(spike_times, spike_units))
        labels.append(label)
    X = torch.tensor(torch.stack(spike_trains), dtype=torch.float32)  # [sample, time_step, unit]
    y = torch.tensor(labels, dtype=torch.long)
    return TensorDataset(X, y)

SHD_train_raw = tonic.datasets.SHD(save_to='../tonic_data', train=True)
SHD_test_raw = tonic.datasets.SHD(save_to='../tonic_data', train=False)

train_dataset = get_SHD_dataset(SHD_train_raw)
test_dataset = get_SHD_dataset(SHD_test_raw)

  X = torch.tensor(torch.stack(spike_trains), dtype=torch.float32)  # [sample, time_step, unit]


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- SNN Architecture with Sparse Connectivity ----
class SNN(nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs, learn_beta=False, beta=0.95, sparsity=0.8):
        super(SNN, self).__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden, bias=False)
        self.fc2 = nn.Linear(num_hidden, num_outputs, bias=False)
        self.lif1 = snn.Leaky(beta=beta, learn_beta=learn_beta)
        self.lif2 = snn.Leaky(beta=beta, learn_beta=learn_beta, reset_mechanism='none')

        # Apply sparsity masks to fc1 and fc2
        with torch.no_grad():
            mask1 = torch.rand_like(self.fc1.weight) > sparsity
            mask2 = torch.rand_like(self.fc2.weight) > sparsity
            self.fc1.weight.data *= mask1
            self.fc2.weight.data *= mask2
            self.register_buffer("mask1", mask1)
            self.register_buffer("mask2", mask2)

    def forward(self, x):
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem2_rec = []
        for t in range(x.size(1)):
            cur1 = self.fc1(x[:, t])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            _, mem2 = self.lif2(cur2, mem2)
            mem2_rec.append(mem2)
        return torch.stack(mem2_rec, dim=1)  # (batch, time, outputs)

# ---- Hybrid Parameter Update (PSO-inspired + adaptive Pool) ----
def hybrid_update(mean, velocity, personal_best, global_best, loss_fn, std, samples, x, y, lr=0.1, acc_threshold=0.95):
    sample_batch = mean + std * torch.randn(samples, *mean.shape).to(mean.device)
    losses = []
    accs = []

    for i in range(samples):
        model = SNN(*x.shape[2:], 100, 20).to(mean.device) # input, hidden=100, output=20
        with torch.no_grad():
            flat_params = sample_batch[i]
            offset = 0
            for p in model.parameters():
                numel = p.numel()
                p.data.copy_(flat_params[offset:offset+numel].view_as(p))
                offset += numel
            output = model(x)
            loss = loss_fn(output.mean(1), y)
            pred = output.mean(1).argmax(1)
            acc = (pred == y).float().mean().item()
            losses.append(loss.item())
            accs.append(acc)

    losses = torch.tensor(losses, device=mean.device)
    accs = torch.tensor(accs, device=mean.device)

    best_idx = torch.argmin(losses)
    if losses[best_idx] < loss_fn(model(x).mean(1), y):
        global_best = sample_batch[best_idx].clone()

    # PSO update
    r1, r2 = torch.rand(2)
    velocity = 0.5 * velocity + 1.5 * r1 * (personal_best - mean) + 1.5 * r2 * (global_best - mean)
    mean = mean + lr * velocity

    # Adaptive Pooling
    if accs[best_idx] < acc_threshold:
        print(f"Adaptive Pooling |", end = ' ')
        topk = sample_batch[torch.argsort(losses)[:samples//4]]
        mean = topk.mean(dim=0)

    # Log batch best performance (optional)
    with torch.no_grad():
        model = SNN(*x.shape[2:], 10, 2).to(mean.device)
        offset = 0
        for p in model.parameters():
            numel = p.numel()
            p.data.copy_(mean[offset:offset+numel].view_as(p))
            offset += numel
        output = model(x)
        pred = output.mean(1).argmax(1)
        acc = (pred == y).float().mean().item()
        print(f"    Batch Accuracy: {acc * 100:.2f}%")
        wandb.log({"train_acc": acc})

    personal_best = mean.clone()
    return mean, velocity, personal_best, global_best

# ---- Annealing Schedulers ----
def get_annealed_param(init, final, current_epoch, total_epochs, mode='exp'):
    if mode == 'linear':
        return final + (init - final) * (1 - current_epoch / total_epochs)
    elif mode == 'exp':
        return final + (init - final) * (0.97 ** current_epoch)
    else:
        return init

# ---- Train Function ----
def train_snn():
    run_name = 'EA_hybrid_SHD'
    config = {
        'nb_input': 700, 'nb_output': 20, 'nb_steps': 100, 'nb_data_samples': 8155,
        'nb_hidden': 100, 'learn_beta': False, 'nb_model_samples': 1000,
        'std': 0.05, 'epochs': 50, 'batch_size': 256,
        'loss': 'cross-entropy', 'optimizer': 'Adam', 'lr': 0.01, 'regularization': 'none'
    }

    wandb.init(entity='DarwinNeuron', project='EA-SHD', name=run_name, config=config)

    with torch.no_grad():
        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
        val_loader = DataLoader(test_dataset, batch_size=256, shuffle=True)

        sample_model = SNN(config['nb_input'], config['nb_hidden'], config['nb_output'])
        param_vector = torch.cat([p.flatten() for p in sample_model.parameters()]).detach()
        mean = param_vector.clone()
        velocity = torch.zeros_like(mean)
        personal_best = mean.clone()
        global_best = mean.clone()

        for epoch in range(config['epochs']):
            print(f"Epoch {epoch}")
            current_std = get_annealed_param(init=0.1, final=0.01, current_epoch=epoch, total_epochs=config['epochs'])
            current_samples = int(get_annealed_param(init=1000, final=100, current_epoch=epoch, total_epochs=config['epochs']))
            acc_thresh = get_annealed_param(init=0.90, final=0.98, current_epoch=epoch, total_epochs=config['epochs'])

            for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
                x_batch, y_batch = x_batch.float(), y_batch.long()
                mean, velocity, personal_best, global_best = hybrid_update(
                    mean, velocity, personal_best, global_best,
                    nn.CrossEntropyLoss(), current_std, current_samples,
                    x_batch, y_batch, lr=config['lr'], acc_threshold=acc_thresh
                )

            # ---- Evaluate validation set ----
            val_accs = []
            with torch.no_grad():
                model = SNN(config['nb_input'], config['nb_hidden'], config['nb_output']).to(mean.device)
                offset = 0
                for p in model.parameters():
                    numel = p.numel()
                    p.data.copy_(mean[offset:offset+numel].view_as(p))
                    offset += numel

                for x_val, y_val in val_loader:
                    x_val, y_val = x_val.float().to(mean.device), y_val.long().to(mean.device)
                    output = model(x_val)
                    pred = output.mean(1).argmax(1)
                    acc = (pred == y_val).float().mean().item()
                    val_accs.append(acc)

                val_accuracy = sum(val_accs) / len(val_accs)
                print(f"Validation Accuracy: {val_accuracy:.4f} | std: {current_std:.4f} | samples: {current_samples} | acc_thresh: {acc_thresh:.4f}")
                wandb.log({
                    "val_accuracy": val_accuracy,
                    "epoch": epoch,
                    "std": current_std,
                    "samples": current_samples,
                    "acc_threshold": acc_thresh
                })

            torch.cuda.empty_cache()

train_snn()


Epoch 0
Adaptive Pooling | Epoch 0
    Batch Accuracy: 3.52%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 7.42%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 6.25%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 4.30%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 2.73%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 5.47%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 3.52%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 5.08%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 4.69%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 5.08%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 4.30%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 2.73%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 6.64%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 5.08%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 6.64%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 5.47%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 5.47%
Adaptive Pooling | Epoch 0
    Batch Accuracy: 4.30%
Adaptive Pooling | Epoch 0
    Batch A