In [3]:
import os

import torch
import torchvision
from torch.utils.data import DataLoader

# from speech2spikes import S2S

from torch_mate.data.utils import FewShot
from torch_mate.utils import get_device

from neurobench.data.datasets.MSWC import MSWC
from neurobench.examples.few_shot_learning.utils import train_using_MAML
from neurobench.models import M5
from neurobench.utils import Dict2Class

In [2]:
# Root directory of the MSWC few-shot data.
# Set the MSWC_ROOT environment variable to point elsewhere instead of editing
# this cell; the default is the original cluster location.
ROOT = os.environ.get("MSWC_ROOT", "/scratch/p306982/data/fscil/mswc/")

In [3]:
# Experiment configuration. It is converted to attribute access (cfg.model.cfg.stride, ...)
# by Dict2Class at the end of this cell; later cells read it that way.
cfg = {
    "criterion": {"name": "CrossEntropyLoss"},
    # NOTE(review): ways/shots/query_shots here are not used by the FewShot call
    # further down, which hardcodes its own episode sizes — confirm which is intended.
    "meta_learning": {
        "fast_lr": 0.4,
        "adaptation_steps": 1,
        "test_adaptation_steps": 1,
        "meta_batch_size": 32,
        "num_iterations": 30000,
        "name": "MAML",
        "first_order": False,
        "ways": 5,
        "shots": 1,
        "query_shots": 100,
    },
    "continual_learning": {
        # Used as the M5 output width (n_output) when the model is built.
        "max_classes_to_learn": 200
    },
    "model": {
        "name": "M5",
        "cfg": {
            "stride": 16,
            "n_channel": 32
        },
    },
    "optimizer": {"name": "Adam", "cfg": {"lr": 0.001, "betas": (0.9, 0.999)}},
    "seed": 4223747124,
    "task": {
        # NOTE(review): "MWSC" looks like a typo for "MSWC" (the dataset class used
        # below); left unchanged in case other code matches on this exact string.
        "name": "MWSC",
        "cfg": {
            "representation": {
                "name": "MFCC",
                "cfg": {
                    "center": True,
                    "hop_length": 160,
                    "n_fft": 400,
                    "n_mels": 96,
                    # Used as the M5 input width (n_input) when the model is built.
                    "n_mfcc": 48
                }
            }
        }
    },
}

# Apply the configured seed so the randomly initialised model is reproducible under
# Restart & Run All. NOTE(review): FewShot episode sampling may draw from a different
# RNG (Python `random` / numpy) — confirm and seed that as well if so.
torch.manual_seed(cfg["seed"])

cfg = Dict2Class(cfg)

In [4]:
# Build the M5 network from the config: input width is the number of MFCC
# coefficients, output width is the total number of classes to be learned.
model_cfg = cfg.model.cfg
model = M5(
    n_input=cfg.task.cfg.representation.cfg.n_mfcc,
    stride=model_cfg.stride,
    n_channel=model_cfg.n_channel,
    n_output=cfg.continual_learning.max_classes_to_learn,
)

In [11]:
# Evaluation split of MSWC, read from ROOT.
eval_dataset = MSWC(ROOT, subset="evaluation")

In [12]:
# Few-shot / class-incremental view over the evaluation split.
# NOTE(review): FewShot is called with 12 positional arguments whose meaning is not
# visible here; check the torch_mate FewShot signature before changing any of them.
fscil_set = FewShot(
    eval_dataset,
    10,
    5,
    100,
    None,
    (100, 100),
    True,
    True,
    None,
    200,  # presumably tied to cfg.continual_learning.max_classes_to_learn (200) — confirm
    torch.nn.Identity(),
    None,
)

In [13]:
# One item of fscil_set per iteration (default collate adds a leading batch dim of 1),
# loaded by 8 worker processes.
eval_data_loader = DataLoader(fscil_set, batch_size=1, num_workers=8)

In [14]:
# Walk the evaluation loader and print each item's index, inputs and labels.
for session, (X, y) in enumerate(eval_data_loader):
    print(f"Session: {session}")
    print(X, y)