# Meta-Training SNNs using MAML

In [2]:
import matplotlib.pyplot as plt
import numpy as np

import torch
from torchvision import datasets, transforms

from yingyang.dataset import YinYangDataset

import random
from tqdm.notebook import tqdm
import pandas as pd
import seaborn as sns
import argparse

In [3]:
%load_ext autoreload
%autoreload 2

## Data and Config

In [4]:
from eventprop.config import get_flat_dict_from_nested

In [5]:
# Data / encoding settings shared by the dataset loaders and the SNN config.
data_config = dict(
    seed=42,  # seeds torch / numpy / random and the YinYangDataset generators
    dataset="ying_yang",  # one of "mnist", "ying_yang", "synthetic"
    deterministic=True,  # request deterministic cuDNN kernels
    batch_size=128,
    encoding="latency",
    T=50,  # number of simulation time steps
    dt=1e-3,  # simulation step size (forwarded to the SNN config)
    t_min=2,
    data_folder="../../../data/",
)

In [6]:
# Seed every RNG used in this notebook (PyTorch, NumPy, Python `random`) so
# data generation and weight initialisation are reproducible, then build the
# dataset selected by data_config["dataset"].
torch.manual_seed(data_config["seed"])
np.random.seed(data_config["seed"])
random.seed(data_config["seed"])

if data_config["deterministic"]:
    # Disable cuDNN autotuning and request deterministic kernels so GPU runs
    # are repeatable (at some speed cost).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

if data_config["dataset"] == "mnist":
    # Both splits share folder, download flag and transform; only `train` differs.
    mnist_kwargs = {"download": True, "transform": transforms.ToTensor()}
    train_dataset = datasets.MNIST(
        data_config["data_folder"], train=True, **mnist_kwargs
    )
    test_dataset = datasets.MNIST(
        data_config["data_folder"], train=False, **mnist_kwargs
    )

elif data_config["dataset"] == "ying_yang":
    # Different generator seeds (seed vs. seed + 2) for the train and test sets.
    train_dataset = YinYangDataset(size=1000, seed=data_config["seed"])
    test_dataset = YinYangDataset(size=1000, seed=data_config["seed"] + 2)

elif data_config["dataset"] == "synthetic":
    # A single sample of random input spikes with shape (T, 1, in_dim).
    # NOTE(review): this branch defines no train_dataset / test_dataset;
    # confirm downstream cells handle that.
    in_dim = 2
    input_spike_times = {
        # Up to 5 spike indices per input neuron. np.random.choice samples with
        # replacement, so duplicate indices collapse into fewer spikes.
        # NOTE(review): the 1e-3 / dt factor is 1 when dt == 1e-3; for smaller
        # dt the scaled index can reach T or beyond and the assignment below
        # would raise IndexError -- confirm the intended units.
        k: (
            np.random.choice(np.arange(int(data_config["T"]) - 1), size=5)
            * (1e-3 / data_config["dt"])
        ).astype(int)
        for k in range(in_dim)
    }
    input_spikes = np.zeros((data_config["T"], 1, in_dim))
    for n, times in input_spike_times.items():
        input_spikes[times, 0, n] = 1

else:
    raise ValueError(f"Invalid dataset name: {data_config['dataset']!r}")


## Rotated Yin-Yang Data for Meta-Learning

In [7]:
# Rotation angles (radians) defining the meta-learning task split, presumably
# applied to the Yin-Yang inputs per the section heading (the rotation itself
# is not visible here). Meta-training covers [0, pi/2) and meta-testing
# covers [pi/2, pi], so the two splits share no task.
n_rotations = 45
# endpoint=False keeps pi/2 out of the training split; with the default
# endpoint=True both linspace calls contain pi/2, leaking one test task
# into meta-training.
train_rotations = np.linspace(0, np.pi / 2, n_rotations, endpoint=False)
test_rotations = np.linspace(np.pi / 2, np.pi, n_rotations)

In [5]:
from torchmeta.utils.data import BatchMetaDataLoader


## Models

In [6]:
from eventprop.models import SNN

In [None]:
# Pick the compute device once; the model config below reuses it instead of
# re-evaluating the same CUDA check.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = {
    "model_type": None,  # will be set later
    "snn": {
        # Time grid comes from the data config so encoding and simulation agree.
        "T": data_config["T"],
        "dt": data_config["dt"],
        # Presumably membrane / synaptic time constants, same units as dt.
        "tau_m": 20e-3,
        "tau_s": 5e-3,
    },
    "weights": {
        "init_mode": "kaiming_both",
        "scale": 5.5,
        "seed": data_config["seed"],
        # "mu": paper_params[data_config["dataset"]]["mu"],
        # "sigma": paper_params[data_config["dataset"]]["sigma"],
        "n_hid": 30,  # int for one hidden layer, or a list/tuple of widths
        "resolve_silent": False,
        "dropout": 0.0,
    },
    "device": device,
    # "device": "cpu",
}

training_config = {
    "n_epochs": 15,
    "loss": "ce_temporal",
    "alpha": 0.0,
    "xi": 1,
    "beta": 6.4,
    "n_tests": 20,
}

optim_config = {
    "lr": 1e-2,
    "weight_decay": 1e-6,
    "optimizer": "adam",
    "gamma": 0.9,
}

config = {
    "data": data_config,
    "model": model_config,
    "training": training_config,
    "optim": optim_config,
}

flat_config = get_flat_dict_from_nested(config)

# Input / output widths per dataset.
n_ins = {
    "mnist": 784,
    # Same input width for every encoding at present; revisit if a
    # non-latency encoding changes the input dimension.
    "ying_yang": 4,
    "synthetic": 2,
}
n_outs = {"mnist": 10, "ying_yang": 3, "synthetic": 2}

# Layer widths: input, then hidden layer(s) from n_hid, then output.
# n_hid may be a list/tuple (several hidden layers), an int (one hidden
# layer), or None (no hidden layer).
hidden = model_config["weights"]["n_hid"]
dims = [n_ins[data_config["dataset"]]]
if isinstance(hidden, (list, tuple)):
    dims.extend(hidden)
elif isinstance(hidden, int):
    dims.append(hidden)
dims.append(n_outs[data_config["dataset"]])