In [None]:
%pip install pytorch-lightning
%pip install torchmetrics
%pip install wandb
%pip install einops

In [7]:
import torch
import torchvision

from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, LightningDataModule , Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.optim import Adam

from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.functional import accuracy
from einops import rearrange
from torch import nn

### Attention Module

The `Attention` module is the core of the vision transformer model. It implements the attention mechanism:

1) Multiply QKV by their weights
2) Perform dot product on Q and K. 
3) Normalize the result in 2) by sqrt of `head_dim`  
4) Softmax is applied to the result.
5) Perform dot product on the result of 4) and V and the result is the output.

In [3]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=3, qkv_bias=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)

        return x

### MLP Module

The MLP module is a made of two linear layers. A non-linear activation is applied to the output of the first layer.

In [4]:
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
      
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

### The Block Module

The `Block` module represents one encoder transformer block. It consists of two sub-modules:
1) The Attention module
2) The MLP module

Layer norm is applied before and after the Attention module.

In [5]:
class Block(nn.Module):

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, 
            act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias) 
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) 
   

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

### The Transformer Module

The feature encoder is made of several transformer blocks. The most important attributes are:
1) `depth` : representing the number of encoder blocks
2) `num_heads` : representing the number of attention heads

In [6]:
class Transformer(nn.Module):
    def __init__(self, dim, num_heads, num_blocks, mlp_ratio=4., qkv_bias=False,  
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim, num_heads, mlp_ratio, qkv_bias, 
                                     act_layer, norm_layer) for _ in range(num_blocks)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

#### The optional parameter initialization as adopted from `timm`

In [7]:
def init_weights_vit_timm(module: nn.Module):
    """ ViT weight initialization, original timm impl (for reproducibility) """
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()

### PyTorch Lightning for Key Word Spotting

We use the `Transformer` module to build the feature encoder. Before the `Transformer` can be used, we convert the audio into patches. The patches are then embedded into a linear space. The output is then passed to the Transformer.



In [8]:
import torch
import torchaudio, torchvision
import os
import matplotlib.pyplot as plt 
import librosa
import argparse
import numpy as np
import wandb
from pytorch_lightning import LightningModule, Trainer, LightningDataModule, Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy
from torchvision.transforms import ToTensor
from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.datasets.speechcommands import load_speechcommands_item

In [9]:
class SilenceDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(SilenceDataset, self).__init__(root, subset='training')
        self.len = len(self._walker) // 35
        path = os.path.join(self._path, torchaudio.datasets.speechcommands.EXCEPT_FOLDER)
        self.paths = [os.path.join(path, p) for p in os.listdir(path) if p.endswith('.wav')]

    def __getitem__(self, index):
        index = np.random.randint(0, len(self.paths))
        filepath = self.paths[index]
        waveform, sample_rate = torchaudio.load(filepath)
        return waveform, sample_rate, "silence", 0, 0

    def __len__(self):
        return self.len

class UnknownDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(UnknownDataset, self).__init__(root, subset='training')
        self.len = len(self._walker) // 35

    def __getitem__(self, index):
        index = np.random.randint(0, len(self._walker))
        fileid = self._walker[index]
        waveform, sample_rate, _, speaker_id, utterance_number = load_speechcommands_item(fileid, self._path)
        return waveform, sample_rate, "unknown", speaker_id, utterance_number

    def __len__(self):
        return self.len

In [10]:

class LitTransformer(LightningModule):
    def __init__(self, num_classes=37, lr=0.001, max_epochs=30, depth=12, embed_dim=512,
                 head=4, patch_dim=512, seqlen=16, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = Transformer(dim=embed_dim, num_heads=head, num_blocks=depth, mlp_ratio=4.,
                                   qkv_bias=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm)
        self.embed = torch.nn.Linear(patch_dim, embed_dim)

        self.fc = nn.Linear(seqlen * embed_dim, num_classes)
        self.loss = torch.nn.CrossEntropyLoss()
        
        self.reset_parameters()


    def reset_parameters(self):
        init_weights_vit_timm(self)
    

    def forward(self, x):
        # Linear projection
        x = self.embed(x)
            
        # Encoder
        x = self.encoder(x)
        x = x.flatten(start_dim=1)

        # Classification head
        x = self.fc(x)
        return x
    
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # this decays the learning rate to 0 after max_epochs using cosine annealing
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        return loss
    

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        acc = accuracy(y_hat, y)
        return {"y_hat": y_hat, "test_loss": loss, "test_acc": acc}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", avg_acc*100., on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)


# a lightning data module for KWS
class LitKWS(LightningDataModule):
     def __init__(self, path='/content', patch_num=16,batch_size=128, num_workers=2, n_fft=512, 
                 n_mels=40, win_length=None, hop_length=256, class_dict={}, 
                 **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.win_length = win_length
        self.hop_length = hop_length
        self.class_dict = class_dict
        self.patch_num = patch_num

     def prepare_data(self):
            self.train_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                                    download=True,
                                                                    subset='training')

            silence_dataset = SilenceDataset(self.path)
            unknown_dataset = UnknownDataset(self.path)
            self.train_dataset = torch.utils.data.ConcatDataset([self.train_dataset, silence_dataset, unknown_dataset])
                                                                    
            self.val_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                                  download=True,
                                                                  subset='validation')
            self.test_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                                  download=True,
                                                                  subset='testing')                                                    
            _, sample_rate, _, _, _ = self.train_dataset[0]
            self.sample_rate = sample_rate
            self.transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                                  n_fft=self.n_fft,
                                                                  win_length=self.win_length,
                                                                  hop_length=self.hop_length,
                                                                  n_mels=self.n_mels,
                                                                  power=2.0)

     def setup(self, stage=None):
        self.prepare_data()

     def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )

     def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )
    
     def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )

     def collate_fn(self, batch):
        mels = []
        labels = []
        wavs = []
        for sample in batch:
            waveform, sample_rate, label, speaker_id, utterance_number = sample
            # ensure that all waveforms are 1sec in length; if not pad with zeros
            if waveform.shape[-1] < sample_rate:
                waveform = torch.cat([waveform, torch.zeros((1, sample_rate - waveform.shape[-1]))], dim=-1)
            elif waveform.shape[-1] > sample_rate:
                waveform = waveform[:,:sample_rate]

            # mel from power to db
            data = ToTensor()(librosa.power_to_db(self.transform(waveform).squeeze().numpy(), ref=np.max))
            wave = torch.cat([data, torch.zeros(1,40,1)],dim=-1)
            mels.append(wave)
            labels.append(torch.tensor(self.class_dict[label]))
            #mels.append(data)
            #print("mels")
            #print(mels.shape)
            #print(mels)
            #mels2 = torch.zeros(1, 128, 64)
            #mels2 = torch.FloatTensor(mels2)
            #source = torch.ones(30, 35, 49)
            #mels2[:, :, :49] = mels

       
       # mels2 = torch.FloatTensor(mels2)
        labels = torch.stack(labels)
        #wavs = torch.stack(wavs)
        mels = torch.stack(mels)
        #print(mels.size)
        mels = rearrange(mels, 'b c (p1 h) (p2 w) -> b (p1 p2) (c h w)', p1=1,p2=self.patch_num)
        return mels, labels


def get_args():
    parser = ArgumentParser(description='PyTorch Transformer')
    parser.add_argument('--depth', type=int, default=12, help='depth')
    parser.add_argument('--embed_dim', type=int, default=80, help='embedding dimension')
    parser.add_argument('--num_heads', type=int, default=4, help='num_heads')

    parser.add_argument('--patch_num', type=int, default=32, help='patch_num')
    parser.add_argument('--kernel_size', type=int, default=3, help='kernel size')
    parser.add_argument('--batch_size', type=int, default=512, metavar='N',
                        help='input batch size for training (default: )')
    parser.add_argument('--max-epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 0)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.0)')

    parser.add_argument('--accelerator', default='gpu', type=str, metavar='N')
    parser.add_argument('--devices', default=1, type=int, metavar='N')
    parser.add_argument('--dataset', default='cifar10', type=str, metavar='N')
    parser.add_argument('--num_workers', default=2, type=int, metavar='N')

    parser.add_argument("--no-wandb", default=False, action='store_true')

    args = parser.parse_args("")
    return args


In [11]:
mels2 = torch.zeros(1, 128, 64)
mels2 = torch.FloatTensor(mels2)
mels2.shape

torch.Size([1, 128, 64])

In [12]:
mels2 = rearrange(mels2, 'c h (p1 w) -> p1 (c h w)', p1=16)


In [13]:
mels2.shape

torch.Size([16, 512])

### Training

In [15]:
if __name__ == "__main__":
    args = get_args()
    CLASSES = ['silence', 'unknown', 'backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow',
               'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no',
               'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three',
             'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']
   
    # make a dictionary from CLASSES to integers
    CLASS_TO_IDX = {c: i for i, c in enumerate(CLASSES)}


    model_checkpoint = ModelCheckpoint(
    dirpath=os.path.join("./checkpoints"),
    filename="kws_best_acc",
    save_top_k=1,
    verbose=True,
    monitor='test_acc',
    mode='max',
    )

    if args.no_wandb == False :
     wandb_logger = WandbLogger(project="KWS")

    datamodule = LitKWS(
        class_dict=CLASS_TO_IDX , 
        batch_size=args.batch_size,
                        patch_num=args.patch_num, 
                        num_workers=args.num_workers * args.devices
          )
    datamodule.prepare_data()

    data = iter(datamodule.train_dataloader()).next()
    print(data[0].size())
    patch_dim = data[0].shape[-1]
    print('patch_dim: ',patch_dim)
    seqlen = data[0].shape[-2]
    print("Embed dim:", args.embed_dim)
    print("Patch size:", 64 // args.patch_num)
    print("Sequence length:", seqlen)


    model = LitTransformer(num_classes=37, lr=args.lr, epochs=args.max_epochs, 
                           depth=args.depth, embed_dim=args.embed_dim, head=args.num_heads,
                           patch_dim=patch_dim, seqlen=seqlen,)

    trainer = Trainer(accelerator=args.accelerator, devices=args.devices,
                      max_epochs=args.max_epochs, precision=16 if args.accelerator == 'gpu' else 32,
                       logger=wandb_logger if not args.no_wandb else None,
                       callbacks=[model_checkpoint])
    trainer.fit(model, datamodule=datamodule)

    wandb.finish()

  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"


  0%|          | 0.00/2.26G [00:00<?, ?B/s]

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


torch.Size([512, 32, 80])
patch_dim:  80
Embed dim: 80
Patch size: 2
Sequence length: 32


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Transformer      | 931 K 
1 | embed   | Linear           | 6.5 K 
2 | fc      | Linear           | 94.8 K
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
1.0 M     Trainable params
0         Non-trainable params
1.0 M     Total params
2.065     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 176: 'test_acc' reached 53.51361 (best 53.51361), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 352: 'test_acc' reached 75.70899 (best 75.70899), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 528: 'test_acc' reached 82.58701 (best 82.58701), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 3, global step 704: 'test_acc' reached 86.13425 (best 86.13425), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 880: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 5, global step 1056: 'test_acc' reached 87.25030 (best 87.25030), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 6, global step 1232: 'test_acc' reached 88.22524 (best 88.22524), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 7, global step 1408: 'test_acc' reached 89.27156 (best 89.27156), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 8, global step 1584: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 9, global step 1760: 'test_acc' reached 90.33648 (best 90.33648), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 10, global step 1936: 'test_acc' reached 90.35646 (best 90.35646), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 11, global step 2112: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 12, global step 2288: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 13, global step 2464: 'test_acc' reached 90.95032 (best 90.95032), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 14, global step 2640: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 15, global step 2816: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 16, global step 2992: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 17, global step 3168: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 18, global step 3344: 'test_acc' reached 91.06843 (best 91.06843), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 3520: 'test_acc' reached 91.09097 (best 91.09097), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 20, global step 3696: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 21, global step 3872: 'test_acc' reached 91.40116 (best 91.40116), saving model to '/content/checkpoints/kws_best_acc.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 22, global step 4048: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 23, global step 4224: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 24, global step 4400: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 25, global step 4576: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 26, global step 4752: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 27, global step 4928: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 28, global step 5104: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 5280: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 30, global step 5456: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 31, global step 5632: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 32, global step 5808: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 33, global step 5984: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 34, global step 6160: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 35, global step 6336: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 36, global step 6512: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 37, global step 6688: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 38, global step 6864: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 39, global step 7040: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 40, global step 7216: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 41, global step 7392: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 42, global step 7568: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 43, global step 7744: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 44, global step 7920: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 45, global step 8096: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 46, global step 8272: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 47, global step 8448: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 48, global step 8624: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 49, global step 8800: 'test_acc' was not in top 1


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test_acc,▁▅▆▇▇▇█▇████████████████████████████████
test_loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███

0,1
epoch,49.0
test_acc,90.3746
test_loss,0.41277
trainer/global_step,8799.0


# Test Model

In [None]:
trainer.save_checkpoint("example_1by4.ckpt")

In [None]:
# https://pytorch-lightning.readthedocs.io/en/stable/common/production_inference.html
model = model.load_from_checkpoint(os.path.join(
    '/content', "example.ckpt"))
model.eval()
script = model.to_torchscript()

# save for use in production environment
model_path = os.path.join('/content',
                          "example.pt")
torch.jit.save(script, model_path)


In [None]:
 idx_to_class = {v: k for k, v in CLASS_TO_IDX.items()}

In [None]:

# list wav files given a folder
label = CLASSES[2:]
label = np.random.choice(label)
path = os.path.join('/content', "SpeechCommands/speech_commands_v0.02/")
path = os.path.join(path, label)
wav_files = [os.path.join(path, f)
             for f in os.listdir(path) if f.endswith('.wav')]
# select random wav file
wav_file = np.random.choice(wav_files)
waveform, sample_rate = torchaudio.load(wav_file)
transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                 n_fft=512,
                                                 win_length=None,
                                                 hop_length=256,
                                                 n_mels=128,
                                                 power=2.0)

mel = ToTensor()(librosa.power_to_db(
    transform(waveform).squeeze().numpy(), ref=np.max))
mel = torch.cat([mel, torch.zeros(1,128,1)],dim=-1)
mel = mel.unsqueeze(0)

NameError: ignored

In [None]:
mel.shape

torch.Size([1, 1, 128, 64])

In [None]:
mel = rearrange(mel, 'b c (p1 h) (p2 w) -> b (p1 p2) (c h w)', p1=4, p2=4)
mel.shape

torch.Size([1, 16, 512])

In [None]:
scripted_module = torch.jit.load(model_path)
pred = torch.argmax(scripted_module(mel), dim=1)


In [None]:
pred.item()

14

In [None]:
label

'happy'

In [None]:
print(f"Ground Truth: {label}, Prediction: {idx_to_class[pred.item()]}")

Ground Truth: happy, Prediction: happy


In [None]:
/content/SpeechCommands/speech_commands_v0.02/bed/00176480_nohash_0.wav

In [None]:
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.

# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------

import os

import librosa
import matplotlib.pyplot as plt
import requests
from IPython.display import Audio, display


_SAMPLE_DIR = "_assets"

SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"  # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")

os.makedirs(_SAMPLE_DIR, exist_ok=True)


def _fetch_data():
    uri = [
        (SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
    ]
    for url, path in uri:
        with open(path, "wb") as file_:
            file_.write(requests.get(url).content)


_fetch_data()


def _get_sample(path, resample=None):
    effects = [["remix", "1"]]
    if resample:
        effects.extend(
            [
                ["lowpass", f"{resample // 2}"],
                ["rate", f"{resample}"],
            ]
        )
    return torchaudio.sox_effects.apply_effects_file(path, effects=effects)


def get_speech_sample(*, resample=None):
    return _get_sample("/content/SpeechCommands/speech_commands_v0.02/bed/00176480_nohash_0.wav", resample=resample)


def print_stats(waveform, sample_rate=None, src=None):
    if src:
        print("-" * 10)
        print("Source:", src)
        print("-" * 10)
    if sample_rate:
        print("Sample Rate:", sample_rate)
    print("Shape:", tuple(waveform.shape))
    print("Dtype:", waveform.dtype)
    print(f" - Max:     {waveform.max().item():6.3f}")
    print(f" - Min:     {waveform.min().item():6.3f}")
    print(f" - Mean:    {waveform.mean().item():6.3f}")
    print(f" - Std Dev: {waveform.std().item():6.3f}")
    print()
    print(waveform)
    print()


def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Spectrogram (db)")
    axs.set_ylabel(ylabel)
    axs.set_xlabel("frame")
    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
    if xmax:
        axs.set_xlim((0, xmax))
    fig.colorbar(im, ax=axs)
    plt.show(block=False)


def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
        if ylim:
            axes[c].set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)


def play_audio(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    if num_channels == 1:
        display(Audio(waveform[0], rate=sample_rate))
    elif num_channels == 2:
        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")


def plot_mel_fbank(fbank, title=None):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Filter bank")
    axs.imshow(fbank, aspect="auto")
    axs.set_ylabel("frequency bin")
    axs.set_xlabel("mel bin")
    plt.show(block=False)


def plot_pitch(waveform, sample_rate, pitch):
    figure, axis = plt.subplots(1, 1)
    axis.set_title("Pitch Feature")
    axis.grid(True)

    end_time = waveform.shape[1] / sample_rate
    time_axis = torch.linspace(0, end_time, waveform.shape[1])
    axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)

    axis2 = axis.twinx()
    time_axis = torch.linspace(0, end_time, pitch.shape[1])
    axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")

    axis2.legend(loc=0)
    plt.show(block=False)


def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
    figure, axis = plt.subplots(1, 1)
    axis.set_title("Kaldi Pitch Feature")
    axis.grid(True)

    end_time = waveform.shape[1] / sample_rate
    time_axis = torch.linspace(0, end_time, waveform.shape[1])
    axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)

    time_axis = torch.linspace(0, end_time, pitch.shape[1])
    ln1 = axis.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
    axis.set_ylim((-1.3, 1.3))

    axis2 = axis.twinx()
    time_axis = torch.linspace(0, end_time, nfcc.shape[1])
    ln2 = axis2.plot(time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--")

    lns = ln1 + ln2
    labels = [l.get_label() for l in lns]
    axis.legend(lns, labels, loc=0)
    plt.show(block=False)

In [None]:
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

print(torch.__version__)
print(torchaudio.__version__)

1.11.0+cu113
0.11.0+cu113


In [None]:
waveform, sample_rate = get_speech_sample()

n_fft = 1024
win_length = None
hop_length = 512

# define transformation
spectrogram = T.Spectrogram(
    n_fft=512,
    #win_length=win_length,
    hop_length=256,
    center=True,
    pad_mode="reflect",
    power=2.0,
)
# Perform transformation
spec = spectrogram(waveform)
spec =ToTensor()(librosa.power_to_db(spec.squeeze().numpy(), ref=np.max))
waveform = torch.cat([spec, torch.zeros(1,257,1)],dim=-1)
print_stats(waveform)
#plot_spectrogram(spec[0], title="torchaudio")

 #waveform = torch.cat([waveform, torch.zeros((1, sample_rate - waveform.shape[-1]))], dim=-1)

Shape: (1, 257, 64)
Dtype: torch.float32
 - Max:      0.000
 - Min:     -80.000
 - Mean:    -70.195
 - Std Dev: 17.314

tensor([[[-80.0000, -72.4388, -76.4144,  ..., -53.8912, -65.9771,   0.0000],
         [-80.0000, -76.9620, -76.2810,  ..., -55.4343, -68.3305,   0.0000],
         [-80.0000, -80.0000, -76.5729,  ..., -59.9448, -72.0771,   0.0000],
         ...,
         [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000,   0.0000],
         [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000,   0.0000],
         [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000,   0.0000]]])



In [None]:
/content/SpeechCommands/speech_commands_v0.02/bed/00176480_nohash_0.wav

In [None]:
torch.zeros(257,64)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [None]:

!pip install PySimpleGUI

Collecting PySimpleGUI
  Downloading PySimpleGUI-4.60.0-py3-none-any.whl (509 kB)
[K     |████████████████████████████████| 509 kB 5.1 MB/s 
[?25hInstalling collected packages: PySimpleGUI
Successfully installed PySimpleGUI-4.60.0


In [None]:
import PySimpleGUI as sg


In [None]:
form = sg.FlexForm('mygui')
layout = [sg.Text('enter')]
form.LayoutAndRead(layout)

TclError: ignored

In [None]:
# all imports
from IPython.display import Javascript
from google.colab import output
from base64 import b64decode
from io import BytesIO
!pip -q install pydub
from pydub import AudioSegment

RECORD = """
const sleep  = time => new Promise(resolve => setTimeout(resolve, time))
const b2text = blob => new Promise(resolve => {
  const reader = new FileReader()
  reader.onloadend = e => resolve(e.srcElement.result)
  reader.readAsDataURL(blob)
})
var record = time => new Promise(async resolve => {
  stream = await navigator.mediaDevices.getUserMedia({ audio: true })
  recorder = new MediaRecorder(stream)
  chunks = []
  recorder.ondataavailable = e => chunks.push(e.data)
  recorder.start()
  await sleep(time)
  recorder.onstop = async ()=>{
    blob = new Blob(chunks)
    text = await b2text(blob)
    resolve(text)
  }
  recorder.stop()
})
"""

def record(sec=3):
  display(Javascript(RECORD))
  s = output.eval_js('record(%d)' % (sec*1000))
  b = b64decode(s.split(',')[1])
  audio = AudioSegment.from_file(BytesIO(b))
  return audio
record_save.py
# all imports
from IPython.display import Javascript
from google.colab import output
from base64 import b64decode

RECORD = """
const sleep  = time => new Promise(resolve => setTimeout(resolve, time))
const b2text = blob => new Promise(resolve => {
  const reader = new FileReader()
  reader.onloadend = e => resolve(e.srcElement.result)
  reader.readAsDataURL(blob)
})
var record = time => new Promise(async resolve => {
  stream = await navigator.mediaDevices.getUserMedia({ audio: true })
  recorder = new MediaRecorder(stream)
  chunks = []
  recorder.ondataavailable = e => chunks.push(e.data)
  recorder.start()
  await sleep(time)
  recorder.onstop = async ()=>{
    blob = new Blob(chunks)
    text = await b2text(blob)
    resolve(text)
  }
  recorder.stop()
})
"""

def record(sec=3):
  display(Javascript(RECORD))
  s = output.eval_js('record(%d)' % (sec*1000))
  b = b64decode(s.split(',')[1])
  with open('audio.wav','wb') as f:
    f.write(b)
  return 'audio.wav'  # or webm ?

NameError: ignored

# New Section

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
cd /content/drive/MyDrive/Colab Notebooks/DL/Building_Blocks/KWS

/content/drive/MyDrive/Colab Notebooks/DL/Building_Blocks/KWS


In [9]:
!python3 train.py

[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice: 3
[34m[1mwandb[0m: You chose 'Don't visualize my results'
[34m[1mwandb[0m: Tracking run with wandb version 0.12.16
[34m[1mwandb[0m: W&B syncing is set to [1m`offline`[0m in this directory.  
[34m[1mwandb[0m: Run [1m`wandb online`[0m or set [1mWANDB_MODE=online[0m to enable cloud syncing.
100% 2.26G/2.26G [00:28<00:00, 84.0MB/s]
Traceback (most recent call last):
  File "train.py", line 89, in <module>
    datamodule.prepare_data()
  File "/content/drive/MyDrive/Colab Notebooks/DL/Building_Blocks/KWS/dataloader_module.py", line 106, in prepare_data
    subset='training')
  File "/usr/local/lib/python3.7/dist-packages/torchaudio/datasets/speechcommands.py", line 113, in __init__
    extract_archive(archive, self._path)
  File "/usr/local/lib/python3.7/dist-packages/torchaudio/datasets/u

In [None]:
3