In [1]:
!pip install tonic wandb snntorch
!pip install weave
!wandb login '624747d405596915090d3160f109335907281de4'

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
import randman
from randman import Randman
import numpy as np
import torch
from torch.utils.data import TensorDataset

SEED = 42

def standardize(x, eps=1e-7):
    """Min-max normalise every column of ``x`` independently.

    ``x`` is laid out as [samples, units] (in this file it is the Randman output
    ``y``), so each unit is rescaled over all samples: its minimum maps to 0 and
    its maximum to ~1. ``eps`` keeps constant columns from dividing by zero.
    Expects a torch tensor, whose ``min(0)`` / ``max(0)`` return (values, indices).
    """
    col_min = x.min(0)[0]
    col_span = x.max(0)[0] - col_min
    return (x - col_min) / (col_span + eps)

def make_spiking_dataset(nb_classes=10, nb_units=100, nb_steps=100, step_frac=1.0, dim_manifold=2, nb_spikes=1, nb_samples=1000, alpha=2.0, shuffle=True, classification=True, seed=None):
    """Generate an event-based spiking Randman classification/regression dataset.

    Every unit fires exactly ``nb_spikes`` spikes per sample, so rate-based or
    spike-count-based decoding cannot work: all information is carried by the
    relative timing of spikes. For regression datasets the intrinsic manifold
    coordinates are returned as targets.

    Args:
        nb_classes: Number of classes (one set of Randman mappings per class).
        nb_units: Number of units.
        nb_steps: Number of time steps.
        step_frac: Fraction of the time window, from its start, that may contain
            spikes (default 1.0).
        dim_manifold: Intrinsic dimension of each manifold.
        nb_spikes: Number of spikes per unit; each spike comes from an
            independent Randman mapping.
        nb_samples: Number of samples drawn from the manifold of each class.
        alpha: Randman smoothness parameter.
        shuffle: Whether to shuffle the samples across classes.
        classification: Return class labels (default) instead of manifold
            coordinates.
        seed: Seed for NumPy's global RNG. If None, the module-level ``SEED`` is
            used (unless that is None too, in which case nothing is reseeded).

    Returns:
        ``(data, labels)`` for classification or ``(data, targets)`` for
        regression. ``data`` has shape
        (nb_classes * nb_samples, nb_units * nb_spikes, 2); its last axis holds
        (spike time, unit index). Spike times are per-unit min-max normalised
        values multiplied by ``nb_steps * step_frac``; unit indices are integers
        stored in the float array. ``labels`` is an int array of shape
        (nb_classes * nb_samples,); ``targets`` has shape
        (nb_classes * nb_samples, dim_manifold).
    """

    data = []
    labels = []
    targets = []

    # Honour the `seed` argument (it used to be ignored in favour of the global
    # SEED). Falling back to SEED keeps existing calls, which pass no seed,
    # producing exactly the same data as before. Note: this reseeds NumPy's
    # global RNG.
    effective_seed = SEED if seed is None else seed
    if effective_seed is not None:
        np.random.seed(effective_seed)

    # One independent Randman seed per (class, spike) pair.
    max_value = np.iinfo(int).max
    randman_seeds = np.random.randint(max_value, size=(nb_classes, nb_spikes))

    for k in range(nb_classes):
        # Intrinsic coordinates of this class's samples (also the regression targets).
        x = np.random.rand(nb_samples, dim_manifold)

        # With more than one spike, different spikes of the same unit are generated
        # by independent mappings.
        submans = [randman.Randman(nb_units, dim_manifold, alpha=alpha, seed=randman_seeds[k, i]) for i in range(nb_spikes)]
        units = []
        times = []
        for rm in submans:
            y = rm.eval_manifold(x)
            y = standardize(y)  # per-unit min-max over this class's samples
            units.append(np.repeat(np.arange(nb_units).reshape(1, -1), nb_samples, axis=0))
            times.append(y.numpy())

        units = np.concatenate(units, axis=1)
        times = np.concatenate(times, axis=1)
        events = np.stack([times, units], axis=2)
        data.append(events)
        labels.append(k * np.ones(len(units)))
        targets.append(x)

    data = np.concatenate(data, axis=0)
    labels = np.array(np.concatenate(labels, axis=0), dtype=int)
    targets = np.concatenate(targets, axis=0)

    if shuffle:
        idx = np.arange(len(data))
        np.random.shuffle(idx)
        data = data[idx]
        labels = labels[idx]
        targets = targets[idx]

    # Scale normalised times into time steps; they stay fractional here and are
    # binned by the consumer.
    data[:, :, 0] *= nb_steps * step_frac

    if classification:
        return data, labels
    else:
        return data, targets

def events_to_spike_train(data, nb_steps, nb_units):
    """Convert event coordinates from the manifold generator into dense spike trains.

    Args:
        data (array): shape [nb_samples, nb_events, 2]; ``data[..., 0]`` is the
            (possibly fractional) spike time in steps and ``data[..., 1]`` the
            unit index.
        nb_steps: number of time bins in the output.
        nb_units: number of units in the output.

    Returns:
        spike_train: float array of shape [nb_samples, nb_steps, nb_units] holding
        1.0 at every (sample, step, unit) with at least one event, 0.0 elsewhere.
    """
    # astype(int) truncates fractional times to their time bin. A time exactly at
    # the upper edge (a normalised time that rounds to 1.0, times nb_steps) would
    # land in bin nb_steps, one past the end, so clamp into the valid range.
    spike_steps = np.clip(data[:, :, 0].astype(int), 0, nb_steps - 1)
    spike_units = data[:, :, 1].astype(int)

    spike_train = np.zeros((data.shape[0], nb_steps, nb_units))
    # A (nb_samples, 1) column of sample indices broadcasts against the
    # (nb_samples, nb_events) step/unit index arrays.
    sample_indices = np.arange(data.shape[0])[:, np.newaxis]
    spike_train[sample_indices, spike_steps, spike_units] = 1

    return spike_train

def get_randman_dataset(nb_classes = 2, nb_units = 10, nb_steps = 50, nb_samples = 1000):
    """Wrap a single-spike Randman dataset as a TensorDataset of (spike train, label).

    ``nb_samples`` is per class, so the dataset holds nb_classes * nb_samples items.

    Returns:
        TensorDataset of spike trains shaped [nb_classes * nb_samples, nb_steps, nb_units]
        and labels shaped [nb_classes * nb_samples], both in torch's default float dtype.
    """
    events, class_ids = make_spiking_dataset(nb_classes, nb_units, nb_steps, nb_spikes=1, nb_samples = nb_samples)
    float_dtype = torch.get_default_dtype()
    trains = torch.tensor(events_to_spike_train(events, nb_steps, nb_units), dtype=float_dtype)
    targets = torch.tensor(class_ids, dtype=float_dtype)
    return TensorDataset(trains, targets)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import snntorch as snn
import wandb
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- SNN Architecture with Sparse Connectivity ----
class SNN(nn.Module):
    """Two-layer leaky integrate-and-fire network read out via output membrane potentials.

    Layout: input -> fc1 -> Leaky (spiking hidden layer) -> fc2 -> Leaky with
    reset_mechanism='none' (output layer whose membrane is recorded every step).
    Both linear layers have no bias.

    Sparsity: at construction each fc1/fc2 weight is kept when a uniform draw
    exceeds ``sparsity`` (so roughly a fraction 1 - sparsity survives) and zeroed
    otherwise; the boolean keep-masks are stored as buffers ``mask1``/``mask2``.
    The masks only touch the initial weights — ``forward`` never applies them, so
    weights written into the parameters later are not re-sparsified.
    """

    def __init__(self, num_inputs, num_hidden, num_outputs, learn_beta=False, beta=0.95, sparsity=0.8):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden, bias=False)
        self.fc2 = nn.Linear(num_hidden, num_outputs, bias=False)
        self.lif1 = snn.Leaky(beta=beta, learn_beta=learn_beta)
        self.lif2 = snn.Leaky(beta=beta, learn_beta=learn_beta, reset_mechanism='none')

        # Draw, apply and register one keep-mask per layer (fc1 first, then fc2).
        with torch.no_grad():
            for buffer_name, layer in (("mask1", self.fc1), ("mask2", self.fc2)):
                keep = torch.rand_like(layer.weight) > sparsity
                layer.weight.data *= keep
                self.register_buffer(buffer_name, keep)

    def forward(self, x):
        """Simulate the network over all time steps of ``x``.

        ``x`` is read as x[:, t] for t in range(x.size(1)); callers in this file
        pass spike trains shaped (batch, time, inputs).

        Returns:
            Output-layer membrane potential at every step, shape (batch, time, outputs).
        """
        mem_hidden = self.lif1.init_leaky()
        mem_out = self.lif2.init_leaky()
        trace = []
        for step in range(x.size(1)):
            spikes_hidden, mem_hidden = self.lif1(self.fc1(x[:, step]), mem_hidden)
            _, mem_out = self.lif2(self.fc2(spikes_hidden), mem_out)
            trace.append(mem_out)
        return torch.stack(trace, dim=1)  # (batch, time, outputs)

# ---- Hybrid Parameter Update (PSO-inspired + adaptive Pool with Offspring) ----
def _load_flat_params(model, flat_params):
    """Copy a 1-D parameter vector into ``model``, slice by slice in ``model.parameters()`` order.

    Returns the model for chaining.

    Raises:
        ValueError: if the vector length differs from the model's parameter count.
    """
    expected = sum(p.numel() for p in model.parameters())
    if flat_params.numel() != expected:
        raise ValueError(f"parameter vector has {flat_params.numel()} elements, model expects {expected}")
    offset = 0
    with torch.no_grad():
        for p in model.parameters():
            numel = p.numel()
            p.data.copy_(flat_params[offset:offset + numel].view_as(p))
            offset += numel
    return model


def hybrid_update(mean, velocity, personal_best, global_best, loss_fn, std, samples, x, y, hidden_dim, num_classes, lr=0.1, acc_threshold=0.95,
                  inertia=0.5, c_personal=1.5, c_global=1.5):
    """One PSO-inspired optimisation step over a flat SNN parameter vector.

    All candidates are scored on the single batch (x, y):
      1. Sample ``samples`` candidates ``mean + std * N(0, I)`` and score each.
      2. Replace ``global_best`` with the best candidate if it beats the
         incumbent global best on this batch.
      3. PSO step: velocity = inertia*velocity + c_personal*r1*(personal_best - mean)
         + c_global*r2*(global_best - mean);  mean += lr * velocity.
      4. Replace ``personal_best`` with the new mean if it beats it on this batch.
      5. If the best candidate's accuracy is below ``acc_threshold``: perturb the
         best quarter of the candidates ("offspring"), set the mean to the average
         of the best half of them, and re-check the personal best.
    Prints the accuracy of the final mean on this batch and logs it to wandb as
    "train_acc".

    Args:
        mean, velocity, personal_best, global_best: 1-D tensors holding the SNN's
            flattened parameters (same length, same device).
        loss_fn: callable(logits, targets) -> scalar tensor, e.g. nn.CrossEntropyLoss().
        std: standard deviation of the Gaussian perturbations.
        samples: population size (must be >= 1).
        x: input batch shaped (batch, time, inputs); y: integer class targets.
        hidden_dim, num_classes: SNN layer sizes (the input size is x.shape[2]).
        lr: step size applied to the velocity.
        acc_threshold: best-candidate accuracy below which offspring pooling runs.
        inertia, c_personal, c_global: PSO coefficients; the defaults equal the
            previously hard-coded 0.5, 1.5, 1.5.

    Returns:
        (mean, velocity, personal_best, global_best) after the update.

    Raises:
        ValueError: if ``samples`` < 1.
    """
    if samples < 1:
        raise ValueError("samples must be >= 1")
    device = mean.device
    # A single model is reused for every evaluation: _load_flat_params overwrites
    # all of its parameters, so building (and randomly initialising) a fresh SNN
    # per candidate only wasted time.
    model = SNN(x.shape[2], hidden_dim, num_classes).to(device)

    def evaluate(flat_params):
        """Return (loss, accuracy) as Python floats for ``flat_params`` on (x, y)."""
        with torch.no_grad():
            logits = _load_flat_params(model, flat_params)(x).mean(1)
            loss = loss_fn(logits, y).item()
            acc = (logits.argmax(1) == y).float().mean().item()
        return loss, acc

    # 1. Population around the current mean.
    sample_batch = mean + std * torch.randn(samples, *mean.shape, device=device)
    scores = [evaluate(candidate) for candidate in sample_batch]
    losses = torch.tensor([loss for loss, _ in scores], device=device)
    best_idx = int(torch.argmin(losses))
    best_acc = scores[best_idx][1]

    # 2. Global best: compare against the incumbent global best (previously it was
    #    compared against the current mean, so a worse candidate could replace it).
    gbest_loss, _ = evaluate(global_best)
    if scores[best_idx][0] < gbest_loss:
        global_best = sample_batch[best_idx].clone()

    # 3. PSO update.
    r1, r2 = torch.rand(2).tolist()
    velocity = (inertia * velocity
                + c_personal * r1 * (personal_best - mean)
                + c_global * r2 * (global_best - mean))
    mean = mean + lr * velocity

    # 4. Personal best. Track its loss so later comparisons use the current value.
    pbest_loss, _ = evaluate(personal_best)
    loss_now, acc_now = evaluate(mean)
    if loss_now < pbest_loss:
        personal_best, pbest_loss = mean.clone(), loss_now

    # 5. Adaptive pooling with offspring.
    if best_acc < acc_threshold:
        print(f"Adaptive Pooling with Offspring | ")
        # max(1, ...) so a small population never yields an empty parent set
        # (which made the pooled mean NaN).
        n_parents = max(1, samples // 4)
        parents = sample_batch[torch.argsort(losses)[:n_parents]]
        offspring = parents + std * torch.randn_like(parents)
        offspring_losses = torch.tensor([evaluate(child)[0] for child in offspring], device=device)
        n_keep = max(1, n_parents // 2)
        mean = offspring[offspring_losses.argsort()[:n_keep]].mean(0).detach()

        # Re-score the pooled mean; acc_now now describes the mean actually returned.
        loss_now, acc_now = evaluate(mean)
        if loss_now < pbest_loss:
            personal_best = mean.clone()

    # Log accuracy of the returned mean on this batch.
    print(f"    Batch Accuracy: {acc_now * 100:.2f}%")
    wandb.log({"train_acc": acc_now})

    return mean, velocity, personal_best, global_best

# ---- Annealing Schedulers ----
def get_annealed_param(init, final, current_epoch, total_epochs, mode='exp', decay=0.995):
    """Anneal a hyper-parameter from ``init`` towards ``final`` as training progresses.

    Modes:
        'linear': straight line from ``init`` at epoch 0 to ``final`` at
            ``current_epoch == total_epochs``; returns ``final`` when
            ``total_epochs <= 0``.
        'exp': ``final + (init - final) * decay ** current_epoch``. Ignores
            ``total_epochs`` and only approaches ``final`` asymptotically; with the
            default decay of 0.995, 19 epochs move the value only ~9% of the way.
        anything else: returns ``init`` unchanged (no annealing).

    Args:
        decay: per-epoch decay factor for 'exp' mode (default 0.995, the value
            that used to be hard-coded).
    """
    if mode == 'linear':
        if total_epochs <= 0:
            return final
        return final + (init - final) * (1 - current_epoch / total_epochs)
    if mode == 'exp':
        return final + (init - final) * (decay ** current_epoch)
    return init

# # ---- Train Function ----
# def train_snn():
#     run_name = 'EA_10_classes_randman_offspring'
#     config = {
#         'nb_input': 100, 'nb_output': 10, 'nb_steps': 50, 'nb_data_samples': 1000,
#         'nb_hidden': 20, 'learn_beta': False, 'nb_model_samples': 100,
#         'std': 0.05, 'epochs': 20, 'batch_size': 256,
#         'loss': 'cross-entropy', 'optimizer': 'Adam', 'lr': 0.01, 'regularization': 'none'
#     }

#     wandb.init(entity='DarwinNeuron', project='EA-Randman', name=run_name, config=config)

#     with torch.no_grad():
#         dataset = get_randman_dataset(config['nb_output'], config['nb_input'], config['nb_steps'], config['nb_data_samples'])
#         train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, shuffle=False)
#         train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
#         val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

#         sample_model = SNN(config['nb_input'], config['nb_hidden'], config['nb_output'])
#         param_vector = torch.cat([p.flatten() for p in sample_model.parameters()]).detach()
#         mean = param_vector.clone()
#         velocity = torch.zeros_like(mean)
#         personal_best = mean.clone()
#         global_best = mean.clone()

#         for epoch in range(config['epochs']):
#             print(f"Epoch {epoch}")
#             current_std = get_annealed_param(init=0.1, final=0.01, current_epoch=epoch, total_epochs=config['epochs'])
#             current_samples = int(get_annealed_param(init=1000, final=100, current_epoch=epoch, total_epochs=config['epochs']))
#             acc_thresh = get_annealed_param(init=0.90, final=0.98, current_epoch=epoch, total_epochs=config['epochs'])

#             for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
#                 x_batch, y_batch = x_batch.float(), y_batch.long()
#                 mean, velocity, personal_best, global_best = hybrid_update(
#                     mean, velocity, personal_best, global_best,
#                     nn.CrossEntropyLoss(), current_std, current_samples,
#                     x_batch, y_batch, 20, 10, lr=config['lr'], acc_threshold=acc_thresh
#                 )

#             # ---- Evaluate validation set ----
#             val_accs = []
#             with torch.no_grad():
#                 model = SNN(config['nb_input'], config['nb_hidden'], config['nb_output']).to(mean.device)
#                 offset = 0
#                 for p in model.parameters():
#                     numel = p.numel()
#                     p.data.copy_(mean[offset:offset+numel].view_as(p))
#                     offset += numel

#                 for x_val, y_val in val_loader:
#                     x_val, y_val = x_val.float().to(mean.device), y_val.long().to(mean.device)
#                     output = model(x_val)
#                     pred = output.mean(1).argmax(1)
#                     acc = (pred == y_val).float().mean().item()
#                     val_accs.append(acc)

#                 val_accuracy = sum(val_accs) / len(val_accs)
#                 print(f"Validation Accuracy: {val_accuracy:.4f} | std: {current_std:.4f} | samples: {current_samples} | acc_thresh: {acc_thresh:.4f}")
#                 wandb.log({
#                     "val_accuracy": val_accuracy,
#                     "epoch": epoch,
#                     "std": current_std,
#                     "samples": current_samples,
#                     "acc_threshold": acc_thresh
#                 })

#             torch.cuda.empty_cache()

# train_snn()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# ---- Train Function ----
def _snn_from_vector(flat_params, nb_input, nb_hidden, nb_output):
    """Build an SNN on ``flat_params.device`` and copy the flat vector into its parameters.

    Slices are consumed in ``model.parameters()`` order, the same order used to
    build the vector with ``torch.cat([p.flatten() for p in model.parameters()])``.
    """
    model = SNN(nb_input, nb_hidden, nb_output).to(flat_params.device)
    offset = 0
    with torch.no_grad():
        for p in model.parameters():
            numel = p.numel()
            p.data.copy_(flat_params[offset:offset + numel].view_as(p))
            offset += numel
    return model


def _forward_with_hidden_counts(model, x):
    """Run ``model`` once over ``x`` while also counting hidden-layer spikes.

    Performs the same per-step computation as ``SNN.forward``, so the first
    return value matches ``model(x)`` (output membrane trace, (batch, time,
    outputs)). The second is the number of spikes emitted by each hidden neuron,
    summed over the batch and all time steps, shape (nb_hidden,).
    """
    mem1 = model.lif1.init_leaky()
    mem2 = model.lif2.init_leaky()
    mem2_rec = []
    hidden_counts = torch.zeros(model.fc1.out_features, device=x.device)
    for t in range(x.size(1)):
        spk1, mem1 = model.lif1(model.fc1(x[:, t]), mem1)
        _, mem2 = model.lif2(model.fc2(spk1), mem2)
        mem2_rec.append(mem2)
        hidden_counts += spk1.sum(0)
    return torch.stack(mem2_rec, dim=1), hidden_counts


def _firing_heatmap_image(hidden_firing_counts, epoch):
    """Render per-neuron firing counts (1-D numpy array) as a one-row heatmap wrapped in a wandb.Image."""
    fig, ax = plt.subplots(figsize=(10, 3))
    sns.heatmap(hidden_firing_counts[np.newaxis, :], cmap='viridis', cbar=True,
                xticklabels=False, yticklabels=False, ax=ax)
    ax.set_title(f"Hidden Neuron Firing Counts (Epoch {epoch})")
    ax.set_xlabel("hidden neuron")
    image = wandb.Image(fig)
    plt.close(fig)  # avoid keeping one open figure per epoch
    return image


def train_snn():
    """Train the SNN on a 2-class Randman task with the PSO/offspring hybrid optimiser.

    Each epoch it does the following:
      - Anneals the sampling std, the population size and the pooling accuracy threshold.
      - Runs one ``hybrid_update`` per training batch.
      - Checkpoints the mean parameter vector to disk and as a W&B artifact.
      - Evaluates on the validation split and logs train/val loss, val accuracy
        and hidden-layer firing statistics to W&B.

    The W&B run is always finished, even if training raises.

    Returns:
        dict with per-epoch lists 'train_loss', 'val_loss', 'val_accuracy' and
        the final mean parameter vector (on CPU) under 'params'.
    """
    run_name = 'EA_2_classes_randman_offspring_save'
    config = {
        'nb_input': 100, 'nb_output': 2, 'nb_steps': 50, 'nb_data_samples': 1000,
        'nb_hidden': 10, 'learn_beta': False, 'nb_model_samples': 100,
        'std': 0.05, 'epochs': 20, 'batch_size': 256,
        'loss': 'cross-entropy', 'optimizer': 'Adam', 'lr': 0.01, 'regularization': 'none',
        # Annealing schedules actually used by the loop below. The legacy keys
        # 'std', 'nb_model_samples' and 'optimizer' above are kept for run-history
        # compatibility but do not drive training (the optimiser is hybrid_update).
        'std_init': 0.1, 'std_final': 0.01,
        'samples_init': 1000, 'samples_final': 100,
        'acc_thresh_init': 0.90, 'acc_thresh_final': 0.98,
    }
    loss_fn = nn.CrossEntropyLoss()
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

    wandb.init(entity='DarwinNeuron', project='EA-Randman', name=run_name, config=config)
    try:
        with torch.no_grad():
            dataset = get_randman_dataset(config['nb_output'], config['nb_input'], config['nb_steps'], config['nb_data_samples'])
            train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, shuffle=False)
            train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

            sample_model = SNN(config['nb_input'], config['nb_hidden'], config['nb_output'])
            mean = torch.cat([p.flatten() for p in sample_model.parameters()]).detach().clone().to(device)
            velocity = torch.zeros_like(mean)
            personal_best = mean.clone()
            global_best = mean.clone()

            for epoch in range(config['epochs']):
                print(f"Epoch {epoch}")
                schedule = dict(current_epoch=epoch, total_epochs=config['epochs'])
                current_std = get_annealed_param(init=config['std_init'], final=config['std_final'], **schedule)
                current_samples = int(get_annealed_param(init=config['samples_init'], final=config['samples_final'], **schedule))
                acc_thresh = get_annealed_param(init=config['acc_thresh_init'], final=config['acc_thresh_final'], **schedule)

                # ---- Optimise over the training batches ----
                # Sums are weighted by batch size so the smaller last batch does
                # not skew the epoch average.
                train_loss_sum, train_count = 0.0, 0
                train_firing = torch.zeros(config['nb_hidden'], device=device)
                for x_batch, y_batch in train_loader:
                    x_batch, y_batch = x_batch.float().to(device), y_batch.long().to(device)
                    mean, velocity, personal_best, global_best = hybrid_update(
                        mean, velocity, personal_best, global_best,
                        loss_fn, current_std, current_samples,
                        x_batch, y_batch, config['nb_hidden'], config['nb_output'],
                        lr=config['lr'], acc_threshold=acc_thresh
                    )

                    # Score the updated mean on this batch (one pass gives loss and firing).
                    model = _snn_from_vector(mean, config['nb_input'], config['nb_hidden'], config['nb_output'])
                    out, counts = _forward_with_hidden_counts(model, x_batch)
                    train_loss_sum += loss_fn(out.mean(1), y_batch).item() * y_batch.size(0)
                    train_count += y_batch.size(0)
                    train_firing += counts

                avg_train_loss = train_loss_sum / max(train_count, 1)

                # ---- Save model checkpoint ----
                checkpoint_path = f"checkpoint_epoch_{epoch}.pth"
                torch.save(mean.cpu(), checkpoint_path)
                artifact = wandb.Artifact(f'checkpoint-epoch-{epoch}', type='model')
                artifact.add_file(checkpoint_path)
                wandb.log_artifact(artifact)

                # ---- Evaluate validation set and firing counts (single pass per batch) ----
                model = _snn_from_vector(mean, config['nb_input'], config['nb_hidden'], config['nb_output'])
                val_loss_sum, val_correct, val_count = 0.0, 0, 0
                val_firing = torch.zeros(config['nb_hidden'], device=device)
                for x_val, y_val in val_loader:
                    x_val, y_val = x_val.float().to(device), y_val.long().to(device)
                    output, counts = _forward_with_hidden_counts(model, x_val)
                    logits = output.mean(1)
                    val_loss_sum += loss_fn(logits, y_val).item() * y_val.size(0)
                    val_correct += (logits.argmax(1) == y_val).sum().item()
                    val_count += y_val.size(0)
                    val_firing += counts

                val_accuracy = val_correct / max(val_count, 1)
                val_loss = val_loss_sum / max(val_count, 1)
                hidden_firing_counts = val_firing.cpu().numpy()

                history['train_loss'].append(avg_train_loss)
                history['val_loss'].append(val_loss)
                history['val_accuracy'].append(val_accuracy)
                print(f"Train Loss: {avg_train_loss:.4f} | Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")

                # One log call per epoch so the heatmap and scalars share a step.
                wandb.log({
                    "hidden_firing_heatmap": _firing_heatmap_image(hidden_firing_counts, epoch),
                    "train_loss": avg_train_loss,
                    "val_accuracy": val_accuracy,
                    "val_loss": val_loss,
                    "epoch": epoch,
                    "std": current_std,
                    "samples": current_samples,
                    "acc_threshold": acc_thresh,
                    "hidden_firing_mean": float(hidden_firing_counts.mean()),
                    "train_hidden_firing_mean": train_firing.mean().item(),
                })

                torch.cuda.empty_cache()
    finally:
        wandb.finish()

    history['params'] = mean.cpu()
    return history


history = train_snn()


0,1
train_acc,▁

0,1
train_acc,0.08594


Epoch 0
Adaptive Pooling with Offspring | 
    Batch Accuracy: 8.59%
Adaptive Pooling with Offspring | 
    Batch Accuracy: 11.72%
