In [1]:

import wandb
from tqdm import tqdm
import torch
import os

from torch.utils.data import DataLoader, ConcatDataset, TensorDataset
import torchaudio

from torch import nn, optim, distributions as dist
import torch.nn.functional as F

import learn2learn as l2l
import copy

# from torch_mate.data.utils import IncrementalFewShot

import sys
sys.path.append("/home3/p306982/Simulations/fscil/algorithms_benchmarks/")

from neurobench.datasets import MSWC
from neurobench.preprocessing.speech2spikes import S2SProcessor
from neurobench.datasets.IncrementalFewShot import IncrementalFewShot
from neurobench.examples.model_data.M5 import M5

from neurobench.benchmarks import Benchmark
from cl_utils import *

import numpy as np
import argparse

import snntorch


In [2]:
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt

  from snntorch import backprop


In [3]:
ROOT = "//scratch/p306982/data/fscil/mswc/"
NUM_WORKERS = 8
BATCH_SIZE = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda')

In [5]:
base_train_set = MSWC(root=ROOT, subset="base", procedure="training")
pre_train_loader = DataLoader(base_train_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
base_test_set = MSWC(root=ROOT, subset="base", procedure="testing")

In [4]:
len(base_train_set)

50000

In [5]:
len(base_test_set)

10000

In [6]:
base_train_set[0]

(tensor([[ 0.0000e+00,  0.0000e+00, -3.0518e-05,  ..., -3.0518e-05,
           0.0000e+00,  0.0000e+00]]),
 48,
 '//scratch/p306982/data/fscil/mswc/en/clips',
 'add/common_voice_en_100260.opus')

In [7]:
base_test_set[0]

(tensor([[ 0.0000e+00,  0.0000e+00, -3.0518e-05,  ..., -3.0518e-05,
           0.0000e+00,  0.0000e+00]]),
 48,
 '//scratch/p306982/data/fscil/mswc/en/clips',
 'add/common_voice_en_100260.opus')

In [8]:
from neurobench.datasets.MSWC import generate_mswc_fscil_splits
generate_mswc_fscil_splits(ROOT, ["en"])

['en']


({'see': 11526,
  'think': 9085,
  'work': 7944,
  'find': 6633,
  'home': 5877,
  'look': 5137,
  'music': 4657,
  'next': 4576,
  'heart': 4317,
  'play': 3850,
  'read': 3597,
  'show': 3574,
  'game': 3569,
  'set': 3291,
  'open': 3042,
  'sound': 2769,
  'story': 2702,
  'understand': 2561,
  'talk': 2503,
  'keep': 2466,
  'call': 2430,
  'stop': 2398,
  'change': 2172,
  'hear': 2161,
  'run': 2139,
  'start': 2132,
  'feel': 2115,
  'remember': 2079,
  'speak': 1978,
  'eat': 1944,
  'present': 1923,
  'center': 1900,
  'phone': 1845,
  'order': 1820,
  'learn': 1818,
  'close': 1809,
  'ask': 1782,
  'style': 1732,
  'season': 1720,
  'news': 1709,
  'research': 1660,
  'turn': 1659,
  'write': 1592,
  'drink': 1574,
  'search': 1567,
  'market': 1561,
  'track': 1493,
  'project': 1489,
  'add': 1482,
  'watch': 1474,
  'forget': 1464,
  'position': 1432,
  'listen': 1406,
  'sort': 1394,
  'design': 1334,
  'bank': 1291,
  'sit': 1271,
  'video': 1248,
  'save': 1241,
  'de

In [None]:
base_train_set = MSWC(root=ROOT, subset="base", procedure="testing")

In [6]:
S2S = S2SProcessor(device, transpose=False)
S2S._default_spec_kwargs["sample_rate"] = 48000
S2S._default_spec_kwargs["hop_length"] = 240
pre_proc = S2S.transform

In [14]:
S2S._default_spec_kwargs

{'sample_rate': 48000,
 'n_mels': 20,
 'n_fft': 512,
 'f_min': 20,
 'f_max': 4000,
 'hop_length': 240}

In [9]:
import torchaudio.transforms as T

n_fft = 2048
win_length = None
hop_length = 240
n_mels = 256
n_mfcc = 256

mfcc_transform = T.MFCC(
    sample_rate=48000,
    n_mfcc=n_mfcc,
    melkwargs={
        "n_fft": n_fft,
        "n_mels": n_mels,
        "hop_length": hop_length,
        "mel_scale": "htk",
    },
)



In [7]:
data = next(iter(pre_train_loader))

In [8]:
transform = S2S.transform.to(device)

In [9]:
transform(data[0].to(device))

tensor([[[[1.1222e-09, 8.1323e-10, 8.3698e-10,  ..., 5.2586e-09,
           1.8456e-09, 1.1892e-09],
          [9.6105e-10, 1.2303e-09, 2.3467e-09,  ..., 3.0912e-09,
           2.3288e-09, 1.6148e-09],
          [2.9100e-09, 5.1502e-09, 7.8932e-09,  ..., 8.1274e-09,
           1.1089e-08, 1.3257e-08],
          ...,
          [8.6771e-08, 6.4785e-08, 3.7651e-08,  ..., 5.8987e-08,
           7.7212e-08, 1.0462e-07],
          [1.1844e-07, 9.7498e-08, 7.5531e-08,  ..., 9.2827e-08,
           1.0752e-07, 1.3343e-07],
          [1.1271e-07, 1.1175e-07, 1.1668e-07,  ..., 1.4001e-07,
           1.8513e-07, 2.2397e-07]]],


        [[[2.0776e-08, 2.2499e-08, 2.9034e-08,  ..., 1.2042e-08,
           4.4046e-09, 2.5336e-09],
          [8.9449e-09, 5.8346e-09, 4.5939e-09,  ..., 4.7325e-09,
           2.4665e-09, 1.7616e-09],
          [3.9666e-09, 3.6041e-09, 3.3655e-09,  ..., 6.1593e-09,
           1.6837e-08, 2.5383e-08],
          ...,
          [1.0682e-07, 1.2322e-07, 1.1922e-07,  ..., 5.44

In [10]:
spike, target = S2S((data[0].to(device), data[1].to(device)))

In [11]:
spike.size()

torch.Size([256, 601, 20])

In [5]:
data = base_train_set[0][0]
mfcc = mfcc_transform(data)
mfcc.size()

torch.Size([1, 256, 201])

In [6]:
mfcc

tensor([[[-5.9695e+02, -5.9704e+02, -5.9712e+02,  ..., -5.9683e+02,
          -5.9659e+02, -5.9648e+02],
         [-7.4492e-01, -6.2521e-01, -5.1344e-01,  ..., -9.3668e-01,
          -1.2713e+00, -1.4302e+00],
         [ 7.0271e-01,  5.8958e-01,  4.8761e-01,  ...,  9.2732e-01,
           1.2569e+00,  1.4129e+00],
         ...,
         [-1.2441e-01, -9.6221e-02, -4.0581e-02,  ..., -1.3851e-02,
          -1.2847e-02, -3.8621e-03],
         [ 8.5055e-02,  6.5515e-02,  2.7330e-02,  ...,  8.0706e-03,
           7.6974e-03,  1.6639e-03],
         [-4.3595e-02, -3.3596e-02, -1.4172e-02,  ..., -4.0992e-03,
          -4.0109e-03, -9.8432e-04]]])

In [7]:
torch.sum(mfcc==0.0)

tensor(0)

In [114]:
zero_count = torch.zeros((1,))
for batch_idx, (data, target) in tqdm(enumerate(pre_train_loader), total=len(base_train_set)//BATCH_SIZE):
    mfcc = mfcc_transform(data)
    zero_count +=torch.sum(mfcc==0.0)

 68%|██████▊   | 133/195 [00:46<00:21,  2.87it/s]


KeyboardInterrupt: 

In [118]:
torch.sum(torch.isinf(mfcc))

tensor(0)

In [45]:
class M5_Feature(nn.Module):
    def __init__(self, n_input=1, stride=16, n_channel=32, drop =False):

        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=8, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.act1 = nn.ReLU()
        self.drop1 = nn.Dropout(p=0.2) if drop else nn.Identity()
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.act2 = nn.ReLU()
        self.drop2 = nn.Dropout(p=0.2) if drop else nn.Identity()
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.act3 = nn.ReLU()
        self.drop3 = nn.Dropout(p=0.2) if drop else nn.Identity()
        self.pool3 = nn.MaxPool1d(2)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.act4 = nn.ReLU()
        self.drop4 = nn.Dropout(p=0.2) if drop else nn.Identity()
        self.pool4 = nn.MaxPool1d(2)

        # self.avg_pool = nn.AvgPool1d(64)



    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(self.bn1(x))
        x = self.drop1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.act2(self.bn2(x))
        x = self.drop2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.act3(self.bn3(x))
        x = self.drop3(x)
        x = self.pool3(x)
        x = self.conv4(x)
        x = self.act4(self.bn4(x))
        x = self.drop4(x)
        x = self.pool4(x)
        # x = self.avg_pool(x)
        # x = x.permute(0, 2, 1)

        return x

In [18]:
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.9
num_steps = 50



import torch.nn as nn
import torch.nn.functional as F
from torch import cat, mul
import collections

def remove_sequential(network, all_layers):

    for name, layer in network.named_children():
        if isinstance(layer, nn.Sequential): # if sequential layer, apply recursively to layers in sequential layer
            #print(layer)
            remove_sequential(layer, all_layers)
        else: # if leaf node, add it to list
            # print(layer)
            all_layers.append((name,layer))

In [29]:
class SCNN_features(nn.Module):
    def __init__(self, n_input=1, stride=16, n_channel=32, input_kernel=80, pool_kernel=4, drop=False):

        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=input_kernel, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.drop1 = nn.Dropout(p=0.2) if drop else nn.Identity()
        self.pool1 = nn.MaxPool1d(pool_kernel)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.drop2 = nn.Dropout(p=0.2) if drop else nn.Identity()
        self.pool2 = nn.MaxPool1d(pool_kernel)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.drop3 = nn.Dropout(p=0.2) if drop else nn.Identity()
        self.pool3 = nn.MaxPool1d(pool_kernel)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.drop4 = nn.Dropout(p=0.2) if drop else nn.Identity()
        self.pool4 = nn.MaxPool1d(2)

        # self.avg_pool = nn.AvgPool1d(64)



    def forward(self, x):



        x = self.conv1(x)
        x, mem1 = self.lif1(self.bn1(x), mem1)
        x = self.drop1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x, mem2 = self.lif2(self.bn2(x), mem2)
        x = self.drop2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x, mem3 = self.lif3(self.bn3(x), mem3)
        x = self.drop3(x)
        x = self.pool3(x)
        x = self.conv4(x)
        x, mem4 = self.lif4(self.bn4(x), mem4)
        x = self.drop4(x)
        x = self.pool4(x)
        # x = self.avg_pool(x)
        # x = x.permute(0, 2, 1)

        return x
    

class SCNN(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, 
                 n_channel=32, load_model=None, output= None, input_kernel=80, pool_kernel=4, drop=False, latent_layer_num=100):
        """Taken from: https://pytorch.org/tutorials/intermediate/speech_command_classification_with_torchaudio_tutorial.html"""
        super().__init__()

        if load_model:
            all_layers = []
            # all_seq = nn.Sequential(load_model.lat_features, load_model.end_features)
            remove_sequential(load_model.lat_features, all_layers)

        else:
            features = SCNN_features(n_input, stride, n_channel, input_kernel, pool_kernel, drop)
            
            all_layers = []
            remove_sequential(features, all_layers)

        lat_list = []
        end_list = []

        for i, layer in enumerate(all_layers):
            if i <= latent_layer_num:
                lat_list.append(layer)
            else:
                end_list.append(layer)

        self.lat_features = nn.Sequential(collections.OrderedDict(lat_list))
        self.end_features = nn.Sequential(collections.OrderedDict(end_list))

        if load_model:
            self.output = load_model.output
        else:
            self.output = nn.Linear(2 * n_channel, n_output, bias=False)

    def forward(self, x, latent_input=None, return_lat_acts=False):

        utils.reset(self.lat_features)
        utils.reset(self.end_features)
        # mem1 = self.lif1.init_leaky()
        # mem2 = self.lif2.init_leaky()
        # mem3 = self.lif3.init_leaky()
        # mem4 = self.lif4.init_leaky()

        num_steps = x.shape[1]

        for step in range(num_steps):

            orig_acts = self.lat_features(x[step])
            if latent_input is not None:
                lat_acts = cat((orig_acts, latent_input), 0)
            else:
                lat_acts = orig_acts

            logits = self.end_features(lat_acts)
            logits = F.avg_pool1d(logits, logits.shape[-1])
            logits = logits.permute(0, 2, 1)

            outputs = self.output(logits)

        # masked_outputs = torch.mul(self.mask, outputs)

        if return_lat_acts:
            return outputs, orig_acts
        else:
            return outputs

In [31]:
class SCNN(nn.Module):
    def __init__(self,load_model=None, drop=False, latent_layer_num=100):
        """Taken from: https://pytorch.org/tutorials/intermediate/speech_command_classification_with_torchaudio_tutorial.html"""
        super().__init__()

        if load_model:
            all_layers = []
            # all_seq = nn.Sequential(load_model.lat_features, load_model.end_features)
            remove_sequential(load_model.lat_features, all_layers)

        else:
            net = nn.Sequential(
                # nn.Flatten(),
                nn.Linear(20, 256),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                nn.Linear(256, 256),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                nn.Linear(256, 256),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                nn.Linear(256, 256),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
            )
            all_layers = []
            remove_sequential(net, all_layers)

        lat_list = []
        end_list = []

        for i, layer in enumerate(all_layers):
            if i <= latent_layer_num:
                lat_list.append(layer)
            else:
                end_list.append(layer)

        self.lat_features = nn.Sequential(collections.OrderedDict(lat_list))
        self.end_features = nn.Sequential(collections.OrderedDict(end_list))

        self.output = nn.Linear(256, 200, bias=False)
        self.output_neurons = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

    def forward(self, x, latent_input=None, return_lat_acts=False):

        utils.reset(self.lat_features)
        utils.reset(self.end_features)
        utils.reset(self.output_neurons)
        softmax_output = torch.zeros(x.size(0), 200).to(device)
        # mem1 = self.lif1.init_leaky()
        # mem2 = self.lif2.init_leaky()
        # mem3 = self.lif3.init_leaky()
        # mem4 = self.lif4.init_leaky()
        # mem = []

        num_steps = x.shape[1]

        for step in range(num_steps):

            orig_acts = self.lat_features(x[:,step])
            if latent_input is not None:
                lat_acts = cat((orig_acts, latent_input), 0)
            else:
                lat_acts = orig_acts

            logits = self.end_features(lat_acts)
            # logits = F.avg_pool1d(logits, logits.shape[-1])
            # logits = logits.permute(0, 2, 1)

            outputs = self.output(logits)
            _, mem_out = self.output_neurons(outputs)

            softmax_output += F.softmax(mem_out, dim=1)

        if return_lat_acts:
            return softmax_output, orig_acts
        else:
            return softmax_output

In [32]:
model = SCNN().to(device)

In [33]:
out = model(spike)
out.size()

torch.Size([256, 200])

In [34]:
out

tensor([[3.0050, 3.0050, 3.0050,  ..., 3.0050, 3.0050, 3.0050],
        [3.0050, 3.0050, 3.0050,  ..., 3.0050, 3.0050, 3.0050],
        [3.0050, 3.0050, 3.0050,  ..., 3.0050, 3.0050, 3.0050],
        ...,
        [3.0050, 3.0050, 3.0050,  ..., 3.0050, 3.0050, 3.0050],
        [3.0050, 3.0050, 3.0050,  ..., 3.0050, 3.0050, 3.0050],
        [3.0050, 3.0050, 3.0050,  ..., 3.0050, 3.0050, 3.0050]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [50]:
model = M5_Feature(256, 2, 128, drop=False)

In [51]:
model = model.train()

In [13]:
load_model = torch.load(os.path.join(ROOT, "model_ref"), map_location=device)
model = M5(n_input=1, n_output=200, load_model= load_model, latent_layer_num=9).to(device)

In [15]:
sum(p.numel() for p in load_model.parameters())

1745408

In [9]:
model

M5(
  (lat_features): Sequential(
    (conv1): Conv1d(256, 256, kernel_size=(4,), stride=(2,))
    (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU()
    (drop1): Dropout(p=0.2, inplace=False)
    (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
    (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU()
    (drop2): Dropout(p=0.2, inplace=False)
    (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (end_features): Sequential(
    (conv3): Conv1d(256, 512, kernel_size=(3,), stride=(1,))
    (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act3): ReLU()
    (drop3): Dropout(p=0.2, inplace=False)
    (pool3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv1d(512

In [59]:
load_model

M5(
  (lat_features): Sequential(
    (conv1): Conv1d(256, 256, kernel_size=(8,), stride=(1,))
    (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU()
    (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
    (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU()
    (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv1d(256, 512, kernel_size=(3,), stride=(1,))
    (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act3): ReLU()
    (pool3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv1d(512, 512, kernel_size=(3,), stride=(1,))
    (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act4): ReLU()
    (pool4): M

In [56]:
for name, param in model.named_parameters():
    print(name)

conv1.weight
conv1.bias
bn1.weight
bn1.bias
conv2.weight
conv2.bias
bn2.weight
bn2.bias
conv3.weight
conv3.bias
bn3.weight
bn3.bias
conv4.weight
conv4.bias
bn4.weight
bn4.bias


In [26]:
data = model(mfcc)

In [102]:
data

tensor([[[0.9530, 0.0000, 0.0000,  ..., 0.4128, 0.0000, 0.0000],
         [0.3260, 0.5191, 0.7440,  ..., 0.5860, 1.0364, 0.8143],
         [0.2114, 0.6727, 0.6269,  ..., 1.1413, 0.7937, 0.9362],
         ...,
         [1.9893, 0.6086, 0.4245,  ..., 0.4697, 0.4444, 0.0318],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0509, 0.5623,  ..., 0.6040, 0.6049, 1.2391]]],
       grad_fn=<SqueezeBackward1>)

In [27]:
data.size()

torch.Size([1, 256, 4])

In [28]:
avg_pool = nn.AvgPool1d(10)
feature = avg_pool(data)
feature.size()

RuntimeError: Given input size: (256x1x4). Calculated output size: (256x1x0). Output size is too small