In [1]:

import wandb
from tqdm import tqdm
import torch
import os

from torch.utils.data import DataLoader, ConcatDataset, TensorDataset
import torchaudio

from torch import nn, optim, distributions as dist
import torch.nn.functional as F

import learn2learn as l2l
import copy

# from torch_mate.data.utils import IncrementalFewShot

import sys
sys.path.append("/home3/p306982/Simulations/fscil/algorithms_benchmarks/")

from neurobench.datasets import MSWC
from neurobench.preprocessing.speech2spikes import S2SProcessor
from neurobench.datasets.IncrementalFewShot import IncrementalFewShot
from neurobench.examples.model_data.M5 import M5

from neurobench.benchmarks import Benchmark
from cl_utils import *

import numpy as np
import argparse


In [2]:
ROOT = "//scratch/p306982/data/fscil/mswc/"
NUM_WORKERS = 8
BATCH_SIZE = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
base_train_set = MSWC(root=ROOT, subset="base", procedure="training", incremental=True)
pre_train_loader = DataLoader(base_train_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
base_test_set = MSWC(root=ROOT, subset="base", procedure="testing", incremental=True)

In [10]:
len(base_train_set)

50000

In [11]:
len(base_test_set)

10000

In [12]:
base_train_set[0]

(tensor([[ 0.0000e+00,  0.0000e+00, -3.0518e-05,  ..., -3.0518e-05,
           0.0000e+00,  0.0000e+00]]),
 48,
 '//scratch/p306982/data/fscil/mswc/en/clips',
 'add/common_voice_en_100260.opus')

In [13]:
base_test_set[0]

(tensor([[ 0.0000e+00,  0.0000e+00, -3.0518e-05,  ..., -3.0518e-05,
           0.0000e+00,  0.0000e+00]]),
 48,
 '//scratch/p306982/data/fscil/mswc/en/clips',
 'add/common_voice_en_100260.opus')

In [15]:
from neurobench.datasets.MSWC import generate_mswc_fscil_splits
generate_mswc_fscil_splits(ROOT, ["en"])

['en']


({'see': 11526,
  'think': 9085,
  'work': 7944,
  'find': 6633,
  'home': 5877,
  'look': 5137,
  'music': 4657,
  'next': 4576,
  'heart': 4317,
  'play': 3850,
  'read': 3597,
  'show': 3574,
  'game': 3569,
  'set': 3291,
  'open': 3042,
  'sound': 2769,
  'story': 2702,
  'understand': 2561,
  'talk': 2503,
  'keep': 2466,
  'call': 2430,
  'stop': 2398,
  'change': 2172,
  'hear': 2161,
  'run': 2139,
  'start': 2132,
  'feel': 2115,
  'remember': 2079,
  'speak': 1978,
  'eat': 1944,
  'present': 1923,
  'center': 1900,
  'phone': 1845,
  'order': 1820,
  'learn': 1818,
  'close': 1809,
  'ask': 1782,
  'style': 1732,
  'season': 1720,
  'news': 1709,
  'research': 1660,
  'turn': 1659,
  'write': 1592,
  'drink': 1574,
  'search': 1567,
  'market': 1561,
  'track': 1493,
  'project': 1489,
  'add': 1482,
  'watch': 1474,
  'forget': 1464,
  'position': 1432,
  'listen': 1406,
  'sort': 1394,
  'design': 1334,
  'bank': 1291,
  'sit': 1271,
  'video': 1248,
  'save': 1241,
  'de

In [None]:
base_train_set = MSWC(root=ROOT, subset="base", procedure="testing")

In [5]:
S2S = S2SProcessor(device)
S2S._default_spec_kwargs["sample_rate"] = 48000
pre_proc = S2S.transform

In [111]:
import torchaudio.transforms as T

n_fft = 2048
win_length = None
hop_length = 240
n_mels = 256
n_mfcc = 256

mfcc_transform = T.MFCC(
    sample_rate=48000,
    n_mfcc=n_mfcc,
    melkwargs={
        "n_fft": n_fft,
        "n_mels": n_mels,
        "hop_length": hop_length,
        "mel_scale": "htk",
    },
)



In [121]:
data = base_train_set[0][0]
mfcc = mfcc_transform(data)
mfcc.size()

torch.Size([1, 256, 201])

In [95]:
mfcc

tensor([[[3.2433e-09, 2.1147e-09, 7.0820e-10,  ..., 4.3519e-09,
          2.3478e-09, 5.4743e-10],
         [7.2043e-10, 3.6058e-10, 9.0458e-11,  ..., 4.3140e-09,
          3.5827e-09, 1.4136e-09],
         [1.2516e-09, 3.6034e-10, 4.4313e-10,  ..., 3.1022e-09,
          2.3836e-09, 6.7672e-10],
         ...,
         [1.5025e-09, 8.5147e-10, 5.5862e-10,  ..., 6.2881e-10,
          1.3803e-10, 4.6415e-10],
         [3.2382e-10, 1.4809e-09, 1.3067e-09,  ..., 3.3691e-10,
          2.7184e-10, 2.6055e-10],
         [3.2041e-09, 1.3078e-09, 1.9749e-09,  ..., 4.5405e-10,
          1.9087e-09, 9.1985e-10]]])

In [113]:
torch.sum(mfcc==0.0)

tensor(0)

In [114]:
zero_count = torch.zeros((1,))
for batch_idx, (data, target) in tqdm(enumerate(pre_train_loader), total=len(base_train_set)//BATCH_SIZE):
    mfcc = mfcc_transform(data)
    zero_count +=torch.sum(mfcc==0.0)

 68%|██████▊   | 133/195 [00:46<00:21,  2.87it/s]


KeyboardInterrupt: 

In [118]:
torch.sum(torch.isinf(mfcc))

tensor(0)

In [97]:
class M5_Feature(nn.Module):
    def __init__(self, n_input=1, stride=16, n_channel=32):

        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=8, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.act1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.act2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.act3 = nn.ReLU()
        self.pool3 = nn.MaxPool1d(2)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.act4 = nn.ReLU()
        self.pool4 = nn.MaxPool1d(2)
        # self.avg_pool = nn.AvgPool1d(64)



    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.act2(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.act3(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = self.act4(self.bn4(x))
        x = self.pool4(x)
        # x = self.avg_pool(x)
        # x = x.permute(0, 2, 1)

        return x

In [119]:
model = M5_Feature(256, 1, 128)

In [122]:
data = model(mfcc)

In [102]:
data

tensor([[[0.9530, 0.0000, 0.0000,  ..., 0.4128, 0.0000, 0.0000],
         [0.3260, 0.5191, 0.7440,  ..., 0.5860, 1.0364, 0.8143],
         [0.2114, 0.6727, 0.6269,  ..., 1.1413, 0.7937, 0.9362],
         ...,
         [1.9893, 0.6086, 0.4245,  ..., 0.4697, 0.4444, 0.0318],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0509, 0.5623,  ..., 0.6040, 0.6049, 1.2391]]],
       grad_fn=<SqueezeBackward1>)

In [123]:
data.size()

torch.Size([1, 256, 10])

In [124]:
avg_pool = nn.AvgPool1d(10)
feature = avg_pool(data)
feature.size()

torch.Size([1, 256, 1])