# Surrogate (snntorch) vs Eventprop Comparison

In [1]:
import matplotlib.pyplot as plt
import numpy as np

import torch
from torchvision import datasets, transforms

from yingyang.dataset import YinYangDataset

import random
from tqdm.notebook import tqdm
import pandas as pd
import seaborn as sns
import argparse

In [2]:
%load_ext autoreload
%autoreload 2

## Data

In [3]:
data_config = {
    "seed": 42,
    "dataset": "mnist",
    "deterministic": True,
    "batch_size": 128,
    "encoding": "latency",
    "T": 30,
    "dt": 1e-3,
    "t_min": 2,
    'data_folder' : '../data'
}

In [4]:
torch.manual_seed(data_config["seed"])
np.random.seed(data_config["seed"])
random.seed(data_config["seed"])

data_config["dataset"] = data_config["dataset"]
if data_config["deterministic"]:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

if data_config["dataset"] == "mnist":
    train_dataset = datasets.MNIST(
        data_config["data_folder"],
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    test_dataset = datasets.MNIST(
        data_config["data_folder"],
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
elif data_config["dataset"] == "ying_yang":
    train_dataset = YinYangDataset(size=60000, seed=data_config["seed"])
    test_dataset = YinYangDataset(size=10000, seed=data_config["seed"] + 2)

else:
    raise ValueError("Invalid dataset name")

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=data_config["batch_size"], shuffle=True, drop_last=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=data_config["batch_size"], shuffle=False, drop_last=True
)

## Models

In [5]:
from eventprop.models import SNN, SpikingLinear_ev, SpikingLinear_su

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = {
    "T": data_config["T"],
    "dt": data_config["dt"],
    "tau_m": 20e-3,
    "tau_s": 5e-3,
    "mu": 1,
    "resolve_silent": False,
    "n_hid": 30,
    "device": device,
    "get_first_spikes": False,
}

n_ins = {"mnist": 784, "ying_yang": 5 if data_config["encoding"] == "latency" else 4}
n_outs = {"mnist": 10, "ying_yang": 3}

dims = [n_ins[data_config["dataset"]]]
if model_config["n_hid"] is not None and isinstance(model_config["n_hid"], list):
    dims.extend(model_config["n_hid"])
elif isinstance(model_config["n_hid"], int):
    dims.append(model_config["n_hid"])
dims.append(n_outs[data_config["dataset"]])

In [7]:
snntorch_model = SNN(dims, **dict(model_config, model_type='snntorch')).to(device)
eventprop_model = SNN(dims, **dict(model_config, model_type='eventprop')).to(device)
eventprop_model.layers[0].weight.data = snntorch_model.layers[0].weight.data
models = {"snntorch": snntorch_model, "eventprop": eventprop_model}

### Initialization

In [8]:
from eventprop.initalization import FluctuationDrivenCenteredNormalInitializer

In [9]:
paper_params = {
    "mnist": {
        "mu": [0.078, 0.2],
        "sigma": [0.045, 0.37],
    },
    "ying_yang": {"mu": [1.5, 0.78], "sigma": [0.93, 0.1]},
}
k_aiming_params = {"mu": [0, 0], "sigma": [1 / np.sqrt(d) for d in dims[:-1]]}

dt, T = data_config["dt"], data_config["T"]

xi = 3
sigma_nu, nu = 1/xi, 15

initializer = FluctuationDrivenCenteredNormalInitializer(
    sigma_u=sigma_nu, nu=nu, timestep=dt
)

fluctuation_params = {
    name: {
        k: v
        for k, v in zip(
            ["mu", "sigma"],
            list(
                zip(
                    *[
                        initializer._get_weight_parameters_con(layer)
                        for layer in model.cpu().layers
                    ]
                )
            ),
        )
    }
    for name, model in models.items()
}
paper_params[data_config["dataset"]], fluctuation_params['eventprop'], k_aiming_params    

({'mu': [0.078, 0.2], 'sigma': [0.045, 0.37]},
 {'mu': (0.0, 0.0), 'sigma': (0.12448720624794865, 0.6363882091727956)},
 {'mu': [0, 0], 'sigma': [0.03571428571428571, 0.18257418583505536]})

## Behavior Comparison

In [10]:
from eventprop.training import encode_data

### Voltage plot check

In [11]:
if True : 
        
    data, targets = next(iter(test_loader))
    data = data.to(device)
    spikes = encode_data(data, argparse.Namespace(**data_config))
    outs = {n: model(spikes) for n, model in models.items()}
    fig, axs = plt.subplots(
        2, 5, figsize=(20, 4), sharex=True, sharey=True, constrained_layout=True
    )
    for t, (ax_model, target) in enumerate(zip(axs.T, targets)):
        for ax, (name, out) in zip(ax_model, outs.items()):
            voltages = out[1][-1][1][:, t].cpu().detach().numpy()
            spikes = out[1][-1][0][:, t].cpu().detach().numpy()
            sns.lineplot(voltages, palette="viridis", ax=ax)
            sns.scatterplot(
                x=np.where(spikes)[0],
                y=np.ones_like(np.where(spikes)[0]),
                ax=ax,
                color="black",
            )
            ax.set_title(f"{name} : {target}")

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

In [None]:
voltages.shape

(30, 3)

### Firing rates check

In [None]:
if False : 
    counts = [0, 0]
    frs = [0, 0]
    n_batch = len(train_loader)
    for (data, target), _ in zip(tqdm(train_loader, total=n_batch), range(n_batch)):
        data, target = data.to(device), target.to(device)
        spikes_data = encode_data(data, data_config)
        snn_out = models["snntorch"](spikes_data)
        event_out = models["eventprop"](spikes_data)
        counts = [
            c + out[0].unique(return_counts=True)[1]
            for c, out in zip(counts, [snn_out, event_out])
        ]
        for f, s in enumerate([snn_out, event_out]):
            frs[f] += s[0].float().sum(0).mean(0)
    frs = torch.stack(frs) / n_batch
    frs, counts

In [None]:
raise KeyboardInterrupt

KeyboardInterrupt: 

## Training

In [None]:
from eventprop.training import train_single_model, test
from snntorch.functional.loss import (
    ce_temporal_loss,
    SpikeTime,
    ce_rate_loss,
    ce_count_loss,
)


In [None]:
first_spike_fn = SpikeTime().first_spike_fn
training_config = {
    "n_epochs": 2,
    "loss": "ce_temporal",
    "first_spike_fn": first_spike_fn,
    'alpha' : 0.
}

optim_config = {"lr": 1e-3, "weight_decay": 0, "optimizer": "adam"}

optimizers_type = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}
optimizers = {
    n: optimizers_type[optim_config["optimizer"]](
        model.parameters(),
        lr=optim_config["lr"],
        weight_decay=optim_config["weight_decay"],
    )
    for n, model in models.items()
}

In [None]:
def get_flat_dict_from_nested(config):
    flat_dict = {}
    for key, value in config.items():
        if isinstance(value, dict):
            flat_dict.update(get_flat_dict_from_nested(value))
        else:
            flat_dict[key] = value
    return flat_dict

In [None]:
config = {
    "data": data_config,
    "model": model_config,
    "training": training_config,
    "optim": optim_config,
}
flat_config = get_flat_dict_from_nested(config)


In [None]:
args = argparse.Namespace(**flat_config)

In [None]:
vars(args)

{'seed': 42,
 'dataset': 'ying_yang',
 'deterministic': True,
 'batch_size': 128,
 'encoding': 'latency',
 'T': 30,
 'dt': 0.001,
 't_min': 2,
 'tau_m': 0.02,
 'tau_s': 0.005,
 'mu': 10,
 'resolve_silent': False,
 'n_hid': 30,
 'device': device(type='cpu'),
 'get_first_spikes': False,
 'n_epochs': 2,
 'loss': 'ce_temporal',
 'first_spike_fn': <bound method Function.apply of <class 'snntorch.functional.loss.SpikeTime.FirstSpike'>>,
 'alpha': 0.0,
 'lr': 0.001,
 'weight_decay': 0,
 'optimizer': 'adam'}

In [None]:
loaders = {"train": train_loader, "test": test_loader}

if args.loss == "ce_temporal":
    criterion = ce_temporal_loss()
elif args.loss == "ce_rate":
    criterion = ce_rate_loss()
elif args.loss == "ce_count":
    criterion = ce_count_loss()
else:
    raise ValueError("Invalid loss type")

In [None]:
train_results = {
    n: train_single_model(
        model, criterion, optimizers[n], loaders, args, first_spike_fn=first_spike_fn, use_wandb=True
    )
    for n, model in models.items()
}

 | :   0%|          | 0/2 [00:00<?, ?it/s]

 | :   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
import wandb

In [None]:
wandb.run.id

'u3rvbc03'

In [None]:
from itertools import product

In [None]:
plot_data = {}
for name, results in train_results.items():
    for trial, metric in product(["train", "test"], ["loss", "acc"]):
        plot_data.setdefault(f"{trial}_{metric}", [])
        plot_data[f"{trial}_{metric}"].extend(results[f"{trial}_{metric}"])
    plot_data.setdefault("epoch", [])
    plot_data["epoch"].extend(np.arange(len(results[f"{trial}_{metric}"])))
    plot_data.setdefault("model", [])
    plot_data["model"].extend([name] * len(results[f"{trial}_{metric}"]))

In [None]:
plot_data

{'train_loss': [739.4349976238022,
  648.0840718848074,
  466.6816376906175,
  343.872798691448],
 'train_acc': [0.3933961004273504,
  0.41855301816239315,
  0.47025240384615385,
  0.5035389957264957],
 'test_loss': [tensor(0.0459),
  tensor(0.0412),
  tensor(0.0325),
  tensor(0.0325),
  tensor(0.0236),
  tensor(0.0168)],
 'test_acc': [0.3968349358974359,
  0.4276842948717949,
  0.4602363782051282,
  0.4602363782051282,
  0.5003004807692307,
  0.5596955128205128],
 'epoch': [0, 1, 2, 0, 1, 2],
 'model': ['snntorch',
  'snntorch',
  'snntorch',
  'eventprop',
  'eventprop',
  'eventprop']}