In [3]:
# !pip install -r requirements.txt

Initialize 

In [4]:
import torch
import torchaudio, torchvision
import os
import matplotlib.pyplot as plt 
import librosa
import argparse
import numpy as np
import wandb
from pytorch_lightning import LightningModule, Trainer, LightningDataModule, Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy
from torchvision.transforms import ToTensor
from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.datasets.speechcommands import load_speechcommands_item

  from .autonotebook import tqdm as notebook_tqdm


Download Datasets

In [5]:
class SilenceDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(SilenceDataset, self).__init__(root, subset='training')
        self.len = len(self._walker) // 35
        path = os.path.join(self._path, torchaudio.datasets.speechcommands.EXCEPT_FOLDER)
        self.paths = [os.path.join(path, p) for p in os.listdir(path) if p.endswith('.wav')]

    def __getitem__(self, index):
        index = np.random.randint(0, len(self.paths))
        filepath = self.paths[index]
        waveform, sample_rate = torchaudio.load(filepath)
        return waveform, sample_rate, "silence", 0, 0

    def __len__(self):
        return self.len

class UnknownDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(UnknownDataset, self).__init__(root, subset='training')
        self.len = len(self._walker) // 35

    def __getitem__(self, index):
        index = np.random.randint(0, len(self._walker))
        fileid = self._walker[index]
        waveform, sample_rate, _, speaker_id, utterance_number = load_speechcommands_item(fileid, self._path)
        return waveform, sample_rate, "unknown", speaker_id, utterance_number

    def __len__(self):
        return self.len

Lightning Data module basic


In [6]:
class KWSDataModule(LightningDataModule):
    def __init__(self, path, batch_size=128, num_workers=0, n_fft=512, 
                 n_mels=128, win_length=None, hop_length=256, class_dict={}, 
                 **kwargs):
        print('_'*20, 'kws init')
        super().__init__(**kwargs)
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.win_length = win_length
        self.hop_length = hop_length
        self.class_dict = class_dict

    def prepare_data(self):
        print('_'*20, 'kws prep data')
        self.train_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                                download=True,
                                                                subset='training')

        silence_dataset = SilenceDataset(self.path)
        unknown_dataset = UnknownDataset(self.path)
        self.train_dataset = torch.utils.data.ConcatDataset([self.train_dataset, silence_dataset, unknown_dataset])
                                                                
        self.val_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                              download=True,
                                                              subset='validation')
        self.test_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                               download=True,
                                                               subset='testing')                                                    
        _, sample_rate, _, _, _ = self.train_dataset[0]
        self.sample_rate = sample_rate
        self.transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                              n_fft=self.n_fft,
                                                              win_length=self.win_length,
                                                              hop_length=self.hop_length,
                                                              n_mels=self.n_mels,
                                                              power=2.0)

    def setup(self, stage=None):
        print('_'*20, 'setup ')
        self.prepare_data()

    def train_dataloader(self):
        print('_'*20, 'kws train dataloader')
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        print('_'*20, 'kws val dataloader')
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
            collate_fn=self.collate_fn
        )
    
    def test_dataloader(self):
        print('_'*20, 'kws test dataloader')
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )

    def collate_fn(self, batch):
        print('_'*20, 'kws collate fn')
        mels = []
        labels = []
        wavs = []
        for sample in batch:
            # print('\nsample:', sample)
            waveform, sample_rate, label, speaker_id, utterance_number = sample
            # ensure that all waveforms are 1sec in length; if not pad with zeros
            if waveform.shape[-1] < sample_rate:
                waveform = torch.cat([waveform, torch.zeros((1, sample_rate - waveform.shape[-1]))], dim=-1)
            elif waveform.shape[-1] > sample_rate:
                waveform = waveform[:,:sample_rate]

            # mel from power to db
            mel1 = ToTensor()(librosa.power_to_db(self.transform(waveform).squeeze().numpy(), ref=np.max))
            # print('\nMEL 1 Shape:', mel1.shape)
            # print('\nMEL 1:', mel1)

            label1 = torch.tensor(self.class_dict[label])
            # print('\nlabel:', label1.shape)

            # print('\nwavform:', waveform.shape)
            mels.append(mel1)
            labels.append(label1)
            wavs.append(waveform)

        mels = torch.stack(mels)
        labels = torch.stack(labels)
        wavs = torch.stack(wavs)
    
        # print('mels', mels.shape)
        # print('labels', labels.shape)
        # print('wavs', wavs.shape)
        input()

        return mels, labels, wavs



PL LigtningMOdule MODEL

In [7]:
class KWSModel(LightningModule):
    def __init__(self, num_classes=37, epochs=30, lr=0.001, **kwargs):
        print('_'*20, 'init')
        super().__init__()
        self.save_hyperparameters()
        self.model = torchvision.models.resnet18(num_classes=num_classes)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        print('_'*20, 'fwd')
        return self.model(x)

    def training_step(self, batch, batch_idx):
        print('_'*20, 'steps')
        mels, labels, _ = batch

        print(mels.size())

        preds = self.model(mels)
        loss = self.hparams.criterion(preds, labels)
        return {'loss': loss}

    # calls to self.log() are recorded in wandb
    def training_epoch_end(self, outputs):
        # print('_'*20, 'train ep end')
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        # print('_'*20, 'valid step')
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        # print('_'*20, 'vald epoc end')
        return self.test_epoch_end(outputs)

    def test_step(self, batch, batch_idx):
        print('_'*20, 'test step')
        mels, labels, wavs = batch
        preds = self.model(mels)
        loss = self.hparams.criterion(preds, labels)
        acc = accuracy(preds, labels) * 100.
        return {"preds": preds, 'test_loss': loss, 'test_acc': acc}

    def test_epoch_end(self, outputs):
        # print('_'*20, 'test ep end')
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", avg_acc, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        print('_'*20, 'optimizer')
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.hparams.lr)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.hparams.epochs)
        return [optimizer], [lr_scheduler]

    def setup(self, stage=None):
        print('_'*20, 'setup')
        self.hparams.criterion = torch.nn.CrossEntropyLoss()
        


In [8]:
# def get_args():
#     parser = argparse.ArgumentParser()
#     # model training hyperparameters
#     parser.add_argument('--batch-size', type=int, default=128, metavar='N',
#                         help='input batch size for training (default: 64)')
#     parser.add_argument('--max-epochs', type=int, default=30, metavar='N',
#                         help='number of epochs to train (default: 30)')
#     parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
#                         help='learning rate (default: 0.001)')

#     # where dataset will be stored
#     parser.add_argument("--path", type=str, default="data/speech_commands/")

#     # 35 keywords + silence + unknown
#     parser.add_argument("--num-classes", type=int, default=37)
   
#     # mel spectrogram parameters
#     parser.add_argument("--n-fft", type=int, default=1024)
#     parser.add_argument("--n-mels", type=int, default=128)
#     parser.add_argument("--win-length", type=int, default=None)
#     parser.add_argument("--hop-length", type=int, default=512)

#     # 16-bit fp model to reduce the size
#     parser.add_argument("--precision", default=16)
#     parser.add_argument("--accelerator", default='gpu')
#     parser.add_argument("--devices", default=1)
#     parser.add_argument("--num-workers", type=int, default=96)

#     parser.add_argument("--no-wandb", default=True, action='store_true')

#     args = parser.parse_args("")
#     return args
# args = get_args()


to ewdit

In [9]:
# import torch
# import torchvision

# from argparse import ArgumentParser
# from pytorch_lightning import LightningModule, Trainer, LightningDataModule
# from torch.optim import Adam
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from torchmetrics.functional import accuracy
# from einops import rearrange
# from torch import nn
# # from torchvision.datasets.cifar import CIFAR10

In [10]:

# class LitTransformer(LightningModule):
#     def __init__(self, num_classes=10, lr=0.001, max_epochs=30, depth=12, embed_dim=64,
#                  head=4, patch_dim=192, seqlen=16, **kwargs):
#         super().__init__()
#         self.save_hyperparameters()
#         self.encoder = Transformer(dim=embed_dim, num_heads=head, num_blocks=depth, mlp_ratio=4.,
#                                    qkv_bias=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm)
#         self.embed = torch.nn.Linear(patch_dim, embed_dim)

#         self.fc = nn.Linear(seqlen * embed_dim, num_classes)
#         self.loss = torch.nn.CrossEntropyLoss()
        
#         self.reset_parameters()


#     def reset_parameters(self):
#         init_weights_vit_timm(self)
    

#     def forward(self, x):
#         # Linear projection
#         x = self.embed(x)
            
#         # Encoder
#         x = self.encoder(x)
#         x = x.flatten(start_dim=1)

#         # Classification head
#         x = self.fc(x)
#         return x
    
#     def configure_optimizers(self):
#         optimizer = Adam(self.parameters(), lr=self.hparams.lr)
#         # this decays the learning rate to 0 after max_epochs using cosine annealing
#         scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
#         return [optimizer], [scheduler]

#     def training_step(self, batch, batch_idx):
#         x, y = batch
#         y_hat = self(x)
#         loss = self.loss(y_hat, y)
#         return loss
    

#     def test_step(self, batch, batch_idx):
#         x, y = batch
#         y_hat = self(x)
#         loss = self.loss(y_hat, y)
#         acc = accuracy(y_hat, y)
#         return {"y_hat": y_hat, "test_loss": loss, "test_acc": acc}

#     def test_epoch_end(self, outputs):
#         avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
#         avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
#         self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
#         self.log("test_acc", avg_acc*100., on_epoch=True, prog_bar=True)

#     def validation_step(self, batch, batch_idx):
#         return self.test_step(batch, batch_idx)

#     def validation_epoch_end(self, outputs):
#         return self.test_epoch_end(outputs)


# # a lightning data module for cifar 10 dataset
# class LitCifar10(LightningDataModule):
#     def __init__(self, batch_size=32, num_workers=32, patch_num=4, **kwargs):
#         super().__init__()
#         self.batch_size = batch_size
#         self.patch_num = patch_num
#         self.num_workers = num_workers

#     def prepare_data(self):
#         self.train_set = CIFAR10(root='./data', train=True,
#                                  download=True, transform=torchvision.transforms.ToTensor())
#         self.test_set = CIFAR10(root='./data', train=False,
#                                 download=True, transform=torchvision.transforms.ToTensor())

#     def collate_fn(self, batch):
#         x, y = zip(*batch)
#         x = torch.stack(x, dim=0)
#         y = torch.LongTensor(y)
#         x = rearrange(x, 'b c (p1 h) (p2 w) -> b (p1 p2) (c h w)', p1=self.patch_num, p2=self.patch_num)
#         return x, y

#     def train_dataloader(self):
#         return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, 
#                                         shuffle=True, collate_fn=self.collate_fn,
#                                         num_workers=self.num_workers)

#     def test_dataloader(self):
#         return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, 
#                                         shuffle=False, collate_fn=self.collate_fn,
#                                         num_workers=self.num_workers)

#     def val_dataloader(self):
#         return self.test_dataloader()




In [11]:

def get_args():
    parser = argparse.ArgumentParser(description='PyTorch Transformer')
    parser.add_argument('--depth', type=int, default=12, help='depth')
    parser.add_argument('--embed_dim', type=int, default=64, help='embedding dimension')
    parser.add_argument('--num_heads', type=int, default=4, help='num_heads')

    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--max-epochs', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: 30)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--patch_num', type=int, default=8, help='patch_num')
    parser.add_argument('--kernel_size', type=int, default=3, help='kernel size')

    parser.add_argument('--accelerator', default='gpu', type=str, metavar='N')
    parser.add_argument('--devices', default=1, type=int, metavar='N')
    parser.add_argument('--dataset', default='cifar10', type=str, metavar='N')
    parser.add_argument('--num_workers', default=4, type=int, metavar='N')


    # where dataset will be stored
    parser.add_argument("--path", type=str, default="data/speech_commands/")
    # 35 keywords + silence + unknown
    parser.add_argument("--num-classes", type=int, default=37)

    # mel spectrogram parameters
    parser.add_argument("--n-fft", type=int, default=1024)
    parser.add_argument("--n-mels", type=int, default=128)
    parser.add_argument("--win-length", type=int, default=None)
    parser.add_argument("--hop-length", type=int, default=512)


    # 16-bit fp model to reduce the size
    parser.add_argument("--precision", default=16)
    # parser.add_argument("--accelerator", default='gpu')
    # parser.add_argument("--devices", default=1)
    # parser.add_argument("--num-workers", type=int, default=96)

    parser.add_argument("--no-wandb", default=True, action='store_true')

    
    args = parser.parse_args("")
    return args
args = get_args()


CALLBACK

In [12]:
class WandbCallback(Callback):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # log 10 sample audio predictions from the first batch
        if batch_idx == 0:
            n = 10
            mels, labels, wavs = batch
            preds = outputs["preds"]
            preds = torch.argmax(preds, dim=1)

            labels = labels.cpu().numpy()
            preds = preds.cpu().numpy()
            
            wavs = torch.squeeze(wavs, dim=1)
            wavs = [ (wav.cpu().numpy()*32768.0).astype("int16") for wav in wavs]
            
            sample_rate = pl_module.hparams.sample_rate
            idx_to_class = pl_module.hparams.idx_to_class
            
            # log audio samples and predictions as a W&B Table
            columns = ['audio', 'mel', 'ground truth', 'prediction']
            data = [[wandb.Audio(wav, sample_rate=sample_rate), wandb.Image(mel), idx_to_class[label], idx_to_class[pred]] for wav, mel, label, pred in list(
                zip(wavs[:n], mels[:n], labels[:n], preds[:n]))]
            wandb_logger.log_table(
                key='ResNet18 on KWS using PyTorch Lightning',
                columns=columns,
                data=data)


Arguments and Other Stuff

In [14]:

CLASSES = ['silence', 'unknown', 'backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow',
            'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no',
            'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three',
            'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']


# make a dictionary from CLASSES to integers
CLASS_TO_IDX = {c: i for i, c in enumerate(CLASSES)}

if not os.path.exists(args.path):
    os.makedirs(args.path, exist_ok=True)

In [15]:
model = KWSModel(num_classes=args.num_classes, epochs=args.max_epochs, lr=args.lr)
print(model)


____________________ init
KWSModel(
  (model): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.

In [16]:

datamodule = KWSDataModule(batch_size=args.batch_size, num_workers=args.num_workers,
                            path=args.path, n_fft=args.n_fft, n_mels=args.n_mels,
                            win_length=args.win_length, hop_length=args.hop_length,
                            class_dict=CLASS_TO_IDX)
datamodule.setup()


____________________ kws init
____________________ setup 
____________________ kws prep data


In [17]:
# datamodule = LitCifar10(batch_size=args.batch_size,
#                         patch_num=args.patch_num, 
#                         num_workers=args.num_workers * args.devices)
datamodule.prepare_data()

____________________ kws prep data


In [18]:

data = iter(datamodule.train_dataloader()).next()
# patch_dim = data[0].shape[-1]
# seqlen = data[0].shape[-2]
# print("Embed dim:", args.embed_dim)
# print("Patch size:", 32 // args.patch_num)
# print("Sequence length:", seqlen)

____________________ kws train dataloader
____________________ kws collate fn
________________________________________ ____________________ kws collate fnkws collate fn 

kws collate fn


: 

: 

In [1]:
data

NameError: name 'data' is not defined

In [14]:
t = datamodule.train_dataloader()

____________________ kws train dataloader


In [15]:
t.dataset[100]

(tensor([[-1.5259e-04,  3.0518e-05, -9.1553e-05,  ..., -6.1340e-03,
          -2.8992e-03, -1.2207e-04]]),
 16000,
 'backward',
 '14c7b073',
 4)

In [16]:
t.dataset[1][0].size()


torch.Size([1, 16000])

main

In [17]:

if __name__ == "__main__":

    # wandb is a great way to debug and visualize this model
    wandb_logger = WandbLogger(project="pl-kws")

    model_checkpoint = ModelCheckpoint(
        dirpath=os.path.join(args.path, "checkpoints"),
        filename="resnet18-kws-best-acc",
        save_top_k=1,
        verbose=True,
        monitor='test_acc',
        mode='max',
    )
    idx_to_class = {v: k for k, v in CLASS_TO_IDX.items()}
    trainer = Trainer(accelerator=args.accelerator,
                      devices=args.devices,
                      precision=args.precision,
                      max_epochs=args.max_epochs,
                      logger=wandb_logger if not args.no_wandb else None,
                      callbacks=[model_checkpoint])
    model.hparams.sample_rate = datamodule.sample_rate
    model.hparams.idx_to_class = idx_to_class
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    wandb.finish()
    #trainer.save_checkpoint('../mnist/checkpoint.ckpt')


[34m[1mwandb[0m: Currently logged in as: [33malessandrosantiago[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


____________________ kws prep data
____________________ setup 
____________________ kws prep data


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


____________________ setup


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
22.378    Total estimated model params size (MB)


____________________ optimizer
Sanity Checking: 0it [00:00, ?it/s]____________________ kws val dataloader
________________________________________  kws collate fnkws collate fn


sample:
sample:  (tensor([[ 0.0304,  0.0394,  0.0523,  ..., -0.0290, -0.0356, -0.0411]]), 16000, 'right', 'c4e1f6e0', 0)(tensor([[ 0.0376,  0.0128, -0.0697,  ..., -0.1370, -0.1884, -0.2496]]), 16000, 'right', '8a90cf67', 2)


MEL 1 Shape:
MEL 1 Shape:  torch.Size([1, 128, 32])torch.Size([1, 128, 32])


MEL 1:
MEL 1: ________________________________________   tensor([[[-41.7963, -31.3605, -29.9010,  ..., -34.0468, -30.0449, -43.3809],
         [-35.4294, -24.8841, -24.0556,  ..., -27.1521, -25.7376, -37.2367],
         [-40.2689, -24.0828, -23.7613,  ..., -25.3162, -26.1984, -32.8227],
         ...,
         [-80.0000, -80.0000, -79.4283,  ..., -80.0000, -80.0000, -77.8260],
         [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000, -80.0000],
         [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




torch.Size([1, 128, 32]) (tensor([[-0.0003, -0.0012, -0.0023,  ..., -0.0162, -0.0173, -0.0176]]), 16000, 'right', '56eb74ae', 1)
torch.Size([1, 128, 32])

MEL 1:
MEL 1 Shape:
  
MEL 1:
MEL 1 Shape:torch.Size([1, 128, 32])  
tensor([[[-28.3609, -26.8729, -32.0230,  ..., -29.5404, -32.8370, -22.5717],
         [-17.9930, -29.7993, -31.7823,  ..., -35.4416, -32.3854, -20.3885],
         [-14.7052, -23.8554, -21.5584,  ..., -24.9777, -22.2919, -20.9768],
         ...,
         [-51.4630, -53.5354, -54.0043,  ..., -55.0779, -53.3340, -56.2342],
         [-53.9522, -62.0991, -60.1022,  ..., -60.1272, -60.8192, -58.0092],
         [-57.3937, -67.3234, -69.0707,  ..., -67.4035, -70.1807, -65.7911]]])torch.Size([1, 128, 32])


label:
MEL 1:tensor([[[-17.1624, -11.6806, -15.4780,  ..., -80.0000, -80.0000, -80.0000],
         [-20.2020, -14.4230, -15.5756,  ..., -80.0000, -80.0000, -80.0000],
         [-20.6458, -20.6438, -20.2085,  ..., -80.0000, -80.0000, -80.0000],
         ...,
         [-64

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



: 

: 

#end

In [None]:
# https://pytorch-lightning.readthedocs.io/en/stable/common/production_inference.html
model = model.load_from_checkpoint(os.path.join(
    args.path, "checkpoints", "resnet18-kws-best-acc.ckpt"))
model.eval()
script = model.to_torchscript()

# save for use in production environment
model_path = os.path.join(args.path, "checkpoints",
                          "resnet18-kws-best-acc.pt")
torch.jit.save(script, model_path)


NameError: name 'model' is not defined

In [None]:

# list wav files given a folder
label = CLASSES[2:]
label = np.random.choice(label)
path = os.path.join(args.path, "SpeechCommands/speech_commands_v0.02/")
path = os.path.join(path, label)
wav_files = [os.path.join(path, f)
             for f in os.listdir(path) if f.endswith('.wav')]
# select random wav file
wav_file = np.random.choice(wav_files)
waveform, sample_rate = torchaudio.load(wav_file)
transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                 n_fft=args.n_fft,
                                                 win_length=args.win_length,
                                                 hop_length=args.hop_length,
                                                 n_mels=args.n_mels,
                                                 power=2.0)

mel = ToTensor()(librosa.power_to_db(
    transform(waveform).squeeze().numpy(), ref=np.max))
mel = mel.unsqueeze(0)
# scripted_module = torch.jit.load(model_path)
# pred = torch.argmax(scripted_module(mel), dim=1)
# print(f"Ground Truth: {label}, Prediction: {idx_to_class[pred.item()]}")


In [None]:
mel = ToTensor()(librosa.power_to_db(
    transform(waveform).squeeze().numpy(), ref=np.max))

In [None]:
mel.shape

torch.Size([1, 128, 32])

In [36]:
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f'Channel {c+1}')
        if xlim:
            axes[c].set_xlim(xlim)
        if ylim:
            axes[c].set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)

plot_waveform()