### Transformer for CIFAR10

A configurable transformer model will be used for CIFAR10 image classification. 

The vision transformer model is a modified version of ViT. The changes are:
1) No position embedding.
2) No dropout is used.
3) All encoder features are used for class prediction.

The code below is a simplified version of Timm modules.

Let us import the required packages.

In [1]:
import torch
import torchvision

from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.functional import accuracy
from einops import rearrange
from torch import nn
from torchvision.datasets.cifar import CIFAR10

  from .autonotebook import tqdm as notebook_tqdm


### Attention Module

The `Attention` module is the core of the vision transformer model. It implements the attention mechanism:

1) Multiply QKV by their weights
2) Perform dot product on Q and K. 
3) Normalize the result in 2) by sqrt of `head_dim`  
4) Softmax is applied to the result.
5) Perform dot product on the result of 4) and V and the result is the output.

In [2]:
class Attention(nn.Module):
    """Multi-head self-attention.

    The input is projected to queries, keys and values with a single fused
    linear layer, scaled dot-product attention is computed independently
    per head, and the concatenated heads go through an output projection.

    Args:
        dim: channel size of the input and of the output.
        num_heads: number of attention heads; must divide ``dim``.
        qkv_bias: whether the fused QKV projection has a bias term.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        # 1 / sqrt(head_dim), applied to the raw attention scores
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); returns (B, N, C)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # unbind instead of tuple-unpacking the tensor keeps TorchScript happy
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        # (B, heads, N, head_dim) -> (B, N, C): merge the heads back together
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)

### MLP Module

The MLP module is made of two linear layers. A non-linear activation is applied to the output of the first layer.

In [3]:
class Mlp(nn.Module):
    """Two-layer feed-forward network (linear -> activation -> linear).

    Args:
        in_features: size of the input features.
        hidden_features: width of the hidden layer; falls back to
            ``in_features`` when not given.
        out_features: size of the output; falls back to ``in_features``
            when not given.
        act_layer: activation class instantiated between the two layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Run ``x`` through fc1, the activation and fc2."""
        return self.fc2(self.act(self.fc1(x)))

### The Block Module

The `Block` module represents one encoder transformer block. It consists of two sub-modules:
1) The Attention module
2) The MLP module

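Layer norm is applied before the Attention module and again before the MLP module (pre-norm). The output of each sub-module is added back to its input through a residual connection.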

In [4]:
class Block(nn.Module):
    """One pre-norm transformer encoder block.

    The block is ``x + Attention(norm1(x))`` followed by
    ``x + Mlp(norm2(x))``; the MLP hidden width is ``int(dim * mlp_ratio)``.

    Args:
        dim: channel size of the tokens.
        num_heads: number of attention heads.
        mlp_ratio: hidden-width multiplier for the MLP.
        qkv_bias: forwarded to ``Attention``.
        act_layer: activation class used inside the MLP.
        norm_layer: normalization class applied before each sub-module.
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
            act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer)

    def forward(self, x):
        """Apply the attention and MLP sub-modules, each with a residual connection."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

### The Transformer Module

The feature encoder is made of several transformer blocks. The most important attributes are:
1) `depth` : representing the number of encoder blocks (the `num_blocks` argument of `Transformer`)
2) `num_heads` : representing the number of attention heads

In [5]:
class Transformer(nn.Module):
    """A stack of ``num_blocks`` identical encoder ``Block`` modules.

    All constructor arguments other than ``num_blocks`` are forwarded
    unchanged to every ``Block``.
    """

    def __init__(self, dim, num_heads, num_blocks, mlp_ratio=4., qkv_bias=False,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        layers = []
        for _ in range(num_blocks):
            layers.append(Block(dim, num_heads, mlp_ratio, qkv_bias,
                                act_layer, norm_layer))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through the blocks in order and return the final output."""
        for layer in self.blocks:
            x = layer(x)
        return x

#### The optional parameter initialization as adopted from `timm`

In [6]:
def init_weights_vit_timm(module: nn.Module):
    """ViT weight initialization, original timm impl (for reproducibility).

    ``nn.Linear`` modules get truncated-normal weights (mean 0, std 0.02)
    and a zeroed bias when a bias exists. Any other module that exposes an
    ``init_weights`` method has that method called. Everything else is left
    untouched.
    """
    if not isinstance(module, nn.Linear):
        # delegate to the module's own initializer when it provides one
        if hasattr(module, 'init_weights'):
            module.init_weights()
        return

    nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)

### PyTorch Lightning for CIFAR10 Image Classification

We use the `Transformer` module to build the feature encoder. Before the `Transformer` can be used, we convert the input image into patches. The patches are then embedded into a linear space. The output is then passed to the Transformer.

Another difference between this model and ViT is that we use all output features for the final classification. In ViT, only the first feature is used.



In [7]:
class LitTransformer(LightningModule):
    """Transformer classifier wrapped as a LightningModule.

    Each input of shape (batch, seqlen, patch_dim) is linearly embedded to
    ``embed_dim`` channels, passed through the ``Transformer`` encoder, and
    all ``seqlen * embed_dim`` encoder features are flattened and fed to a
    single linear classification head.

    Args:
        num_classes: number of output classes.
        lr: learning rate for Adam.
        max_epochs: period (``T_max``) of the cosine learning-rate schedule.
        depth: number of encoder blocks.
        embed_dim: channel size inside the encoder.
        head: number of attention heads.
        patch_dim: size of the last input dimension.
        seqlen: number of tokens per sample.
        **kwargs: accepted so extra keyword arguments do not raise; they
            are not used by this constructor.
    """

    def __init__(self, num_classes=10, lr=0.001, max_epochs=30, depth=12, embed_dim=64,
                 head=4, patch_dim=192, seqlen=16, **kwargs):
        super().__init__()
        print('_'*20, 'init')
        self.save_hyperparameters()
        self.encoder = Transformer(dim=embed_dim, num_heads=head, num_blocks=depth, mlp_ratio=4.,
                                   qkv_bias=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm)
        self.embed = torch.nn.Linear(patch_dim, embed_dim)

        self.fc = nn.Linear(seqlen * embed_dim, num_classes)
        self.loss = torch.nn.CrossEntropyLoss()

        self.reset_parameters()

    def reset_parameters(self):
        """Apply the timm ViT initialization to every sub-module."""
        print('_'*20, 'reset param')
        # `apply` visits all sub-modules recursively. Calling the initializer
        # on `self` alone would be a no-op, because this top-level module is
        # neither an nn.Linear nor does it define `init_weights`.
        self.apply(init_weights_vit_timm)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        # Linear projection
        x = self.embed(x)

        # Encoder
        x = self.encoder(x)
        # keep every token's features: (batch, seqlen, embed_dim) -> (batch, seqlen * embed_dim)
        x = x.flatten(start_dim=1)

        # Classification head
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        """Adam with a cosine-annealed learning rate over ``max_epochs``."""
        print('_'*20, 'config optimizers')
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # this decays the learning rate to 0 after max_epochs using cosine annealing
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one training batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute logits, loss and accuracy for one evaluation batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        acc = accuracy(y_hat, y)
        return {"y_hat": y_hat, "test_loss": loss, "test_acc": acc}

    def test_epoch_end(self, outputs):
        """Log the mean of the per-batch losses and accuracies (accuracy in percent)."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", avg_acc*100., on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Validation reuses the test step, so metrics share the ``test_*`` names."""
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        """Validation reuses the test aggregation, logging ``test_loss``/``test_acc``."""
        return self.test_epoch_end(outputs)



In [8]:

import torchaudio
from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.datasets.speechcommands import load_speechcommands_item
from pytorch_lightning import LightningModule, Trainer, LightningDataModule, Callback


from torchvision.transforms import ToTensor
import librosa
import numpy as np

import matplotlib.pyplot as plt 

In [9]:

import torch
import torchaudio, torchvision
import os
from pytorch_lightning import LightningModule, Trainer, LightningDataModule, Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy

In [10]:

class SilenceDataset(SPEECHCOMMANDS):
    """Synthetic "silence" class built on the training subset of SPEECHCOMMANDS.

    Items are ``.wav`` files found in torchaudio's ``EXCEPT_FOLDER``
    directory under the dataset path (presumably the background-noise
    recordings; verify against torchaudio). The reported length is the
    training walker size divided by 35.
    """

    def __init__(self, root):
        super().__init__(root, subset='training')
        self.len = len(self._walker) // 35
        noise_dir = os.path.join(self._path, torchaudio.datasets.speechcommands.EXCEPT_FOLDER)
        self.paths = [
            os.path.join(noise_dir, name)
            for name in os.listdir(noise_dir)
            if name.endswith('.wav')
        ]

    def __getitem__(self, index):
        """Return a randomly chosen clip labelled "silence"; ``index`` is ignored."""
        chosen = self.paths[np.random.randint(0, len(self.paths))]
        waveform, sample_rate = torchaudio.load(chosen)
        # speaker id and utterance number have no meaning here, so both are 0
        return waveform, sample_rate, "silence", 0, 0

    def __len__(self):
        return self.len

class UnknownDataset(SPEECHCOMMANDS):
    """Synthetic "unknown" class built on the training subset of SPEECHCOMMANDS.

    Every item is a randomly drawn training utterance whose real label is
    replaced by "unknown". The reported length is the training walker size
    divided by 35.
    """

    def __init__(self, root):
        super().__init__(root, subset='training')
        self.len = len(self._walker) // 35

    def __getitem__(self, index):
        """Return a random training clip relabelled "unknown"; ``index`` is ignored."""
        picked = self._walker[np.random.randint(0, len(self._walker))]
        item = load_speechcommands_item(picked, self._path)
        waveform, sample_rate, _, speaker_id, utterance_number = item
        return waveform, sample_rate, "unknown", speaker_id, utterance_number

    def __len__(self):
        return self.len

In [11]:

# A Lightning data module for the Speech Commands keyword-spotting dataset.
# NOTE(review): the class name appears to be a leftover from a CIFAR10 version
# of this notebook; it is kept so existing callers keep working.
class LitCifar10(LightningDataModule):
    """Serve Speech Commands clips as batches of log-mel spectrograms.

    Args:
        path: root directory where the dataset is stored / downloaded.
        batch_size: batch size of every dataloader.
        num_workers: number of worker processes of every dataloader.
        patch_num: stored on the instance; not used by this class.
        n_fft, n_mels, win_length, hop_length: forwarded to
            ``torchaudio.transforms.MelSpectrogram``.
        class_dict: mapping from label string to integer class index.
            Defaults to an empty mapping.
        **kwargs: accepted and ignored.
    """

    def __init__(self, path, batch_size=32, num_workers=32, patch_num=4,
                 n_fft=512, n_mels=128, win_length=None, hop_length=256,
                 class_dict=None, **kwargs):
        super().__init__()
        print('_'*20, 'datamodule init')
        self.path = path

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.patch_num = patch_num

        # Window
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.win_length = win_length
        self.hop_length = hop_length
        # a fresh dict per instance instead of one shared mutable default
        self.class_dict = {} if class_dict is None else class_dict

        # populated by prepare_data(); None means "not built yet"
        self.train_set = None

    def prepare_data(self):
        """Download the dataset if needed and build the datasets and mel transform."""
        print('_'*20, 'datamodule prepare data')
        # training set = real training clips + synthetic "silence" and "unknown" items
        silence_dataset = SilenceDataset(self.path)
        unknown_dataset = UnknownDataset(self.path)
        self.train_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path, download=True, subset='training')
        self.train_set = torch.utils.data.ConcatDataset([self.train_dataset, silence_dataset, unknown_dataset])

        self.val_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path, download=True, subset='validation')

        self.test_set = torchaudio.datasets.SPEECHCOMMANDS(self.path, download=True, subset='testing')

        # the sample rate of the first training clip configures the mel transform
        _, sample_rate, _, _, _ = self.train_dataset[0]
        self.sample_rate = sample_rate
        self.transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                              n_fft=self.n_fft,
                                                              win_length=self.win_length,
                                                              hop_length=self.hop_length,
                                                              n_mels=self.n_mels,
                                                              power=2.0)

    def setup(self, stage=None):
        """Make sure the datasets exist without rebuilding them needlessly."""
        print('_'*20, 'setup ')
        # Build the datasets only if prepare_data() has not already done so in
        # this process; rebuilding them here would just repeat that work.
        if self.train_set is None:
            self.prepare_data()

    def train_dataloader(self):
        """Shuffled loader over the training set (including silence/unknown items)."""
        print('_'*20, 'datamodule train dataloader')
        return torch.utils.data.DataLoader(self.train_set,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=True,
                                           collate_fn=self.collate_fn)

    def test_dataloader(self):
        """Unshuffled loader over the testing subset."""
        print('_'*20, 'datamodule test dataload')
        return torch.utils.data.DataLoader(self.test_set,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=False,
                                           collate_fn=self.collate_fn)

    def val_dataloader(self):
        """Unshuffled loader over the validation subset."""
        print('_'*20, 'kws val dataloader')
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           # evaluation data is not shuffled, so runs are comparable
                                           shuffle=False,
                                           pin_memory=True,
                                           collate_fn=self.collate_fn)

    def collate_fn(self, batch):
        """Turn raw dataset samples into ``(mels, labels)`` tensors.

        Every waveform is zero-padded or truncated to ``sample_rate`` samples
        (one second), converted to a mel power spectrogram and then to dB.
        Returns the stacked spectrograms with the singleton channel axis
        removed, and the stacked integer labels looked up in ``class_dict``.

        Raises:
            KeyError: if a sample's label is missing from ``class_dict``.
        """
        to_tensor = ToTensor()
        mels = []
        labels = []
        for waveform, sample_rate, label, _speaker_id, _utterance_number in batch:
            # ensure that all waveforms are 1sec in length; if not pad with zeros
            num_samples = waveform.shape[-1]
            if num_samples < sample_rate:
                waveform = torch.cat([waveform, torch.zeros((1, sample_rate - num_samples))], dim=-1)
            elif num_samples > sample_rate:
                waveform = waveform[:, :sample_rate]

            # mel from power to db
            mel_db = librosa.power_to_db(self.transform(waveform).squeeze().numpy(), ref=np.max)
            mels.append(to_tensor(mel_db))
            labels.append(torch.tensor(self.class_dict[label]))

        mels = torch.stack(mels)
        labels = torch.stack(labels)

        # drop the singleton channel axis added by ToTensor
        x = rearrange(mels, 'b 1 h w -> b h w')

        return x, labels

In [103]:
DEFAULT_BATCH_SIZE = 128
# DEFAULT_BATCH_SIZE = 4 #for debugging

from argparse import ArgumentParser
def get_args(argv=None):
    """Build the experiment configuration from command-line style options.

    Args:
        argv: optional list of argument strings to parse. When omitted, an
            empty list is parsed so every option takes its default; this is
            what a notebook needs, since the kernel's own ``sys.argv`` must
            not be interpreted.

    Returns:
        The ``argparse.Namespace`` holding all options.
    """
    parser = ArgumentParser(description='PyTorch Transformer')

    # model
    parser.add_argument('--depth', type=int, default=12, help='depth')
    parser.add_argument('--embed_dim', type=int, default=64, help='embedding dimension')
    parser.add_argument('--num_heads', type=int, default=4, help='num_heads')

    parser.add_argument('--patch_num', type=int, default=8, help='patch_num')
    parser.add_argument('--kernel_size', type=int, default=3, help='kernel size')

    # optimisation; %(default)s keeps the help text in sync with the real default
    parser.add_argument('--batch-size', type=int, default=DEFAULT_BATCH_SIZE, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--max-epochs', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: %(default)s)')

    # hardware
    parser.add_argument('--accelerator', default='gpu', type=str, metavar='N')
    parser.add_argument('--devices', default=1, type=int, metavar='N')
    parser.add_argument('--num_workers', default=12, type=int, metavar='N')

    # where dataset will be stored
    parser.add_argument("--path", type=str, default="data/speech_commands/")

    # 35 keywords + silence + unknown
    parser.add_argument("--num-classes", type=int, default=37)

    # mel spectrogram parameters
    parser.add_argument("--n-fft", type=int, default=1024)
    parser.add_argument("--n-mels", type=int, default=128)
    parser.add_argument("--win-length", type=int, default=None)
    parser.add_argument("--hop-length", type=int, default=512)

    # 16-bit fp model to reduce the size
    parser.add_argument("--precision", default=16)

    # NOTE(review): with default=True and action='store_true' this flag is
    # always True; kept as-is so the parsed value does not change.
    parser.add_argument("--no-wandb", default=True, action='store_true')

    return parser.parse_args([] if argv is None else argv)
import os

args = get_args()

# Label vocabulary: the synthetic 'silence' and 'unknown' labels come first,
# followed by the keyword labels. A label's position is its class index.
CLASSES = [
    'silence', 'unknown', 'backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow',
    'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no',
    'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three',
    'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero',
]

# label string -> integer class index
CLASS_TO_IDX = dict(zip(CLASSES, range(len(CLASSES))))

# make sure the dataset directory exists
if not os.path.exists(args.path):
    os.makedirs(args.path, exist_ok=True)


In [104]:
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    """Plot every channel of a waveform tensor against time in seconds.

    Args:
        waveform: 2-D tensor laid out as (channels, frames).
        sample_rate: samples per second, used to build the time axis.
        title: figure title.
        xlim: optional x-axis limits applied to every subplot.
        ylim: optional y-axis limits applied to every subplot.
    """
    samples = waveform.numpy()
    num_channels, num_frames = samples.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    # a single subplot comes back as a bare Axes object, not a sequence
    axes = [axes] if num_channels == 1 else axes
    label_channels = num_channels > 1

    for idx, (ax, channel) in enumerate(zip(axes, samples)):
        ax.plot(time_axis, channel, linewidth=1)
        ax.grid(True)
        if label_channels:
            ax.set_ylabel(f'Channel {idx+1}')
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)

    figure.suptitle(title)
    plt.show(block=False)

### Performance on different settings

The following table shows different performances on different settings. Generally, Transformer is better than MLP in terms of accuracy and parameter count. However, the performance is worse compared to CNN models. 

However, the most important thing to note is that Transformers are more general-purpose models than CNNs. They can process different types of data, and they can process multiple types of data at the same time. This is why they are considered to be the backbone of many high-performing models like BERT, GPT3, PaLM and Gato.

| **Depth** | **Head** | **Embed dim** | **Patch size** | **Seq len** | **Params** | **Accuracy** | 
| -: | -: | -: | -: | -: | -: | -: |
| 12 | 4 | 32 | 4x4 | 64 | 173k | 68.2% |
| 12 | 4 | 64 | 4x4 | 64 | 641k | 71.1% |
| 12 | 4 | 128 | 4x4 | 64 | 2.5M | 71.5% |

In [105]:

args = get_args()

# Data module serving log-mel spectrogram batches of the Speech Commands clips.
datamodule = LitCifar10(batch_size=args.batch_size,
                        num_workers=args.num_workers * args.devices,
                        path=args.path,
                        patch_num=args.patch_num,

                        n_fft=args.n_fft,
                        n_mels=args.n_mels,
                        win_length=args.win_length,
                        hop_length=args.hop_length,
                        class_dict=CLASS_TO_IDX
                        )
datamodule.prepare_data()

# Pull one batch to read the input geometry off the data itself. The builtin
# next() works on any iterator, whereas the `.next()` method is a legacy alias
# that is not available on every PyTorch version.
data = next(iter(datamodule.train_dataloader()))
patch_dim = data[0].shape[-1]   # size of each token's feature vector
seqlen = data[0].shape[-2]      # number of tokens per sample
print("Embed dim:", args.embed_dim)
# NOTE(review): this value assumes a 32-pixel image side, which looks like a
# CIFAR10 leftover; it is not derived from the spectrogram batch.
print("Patch size:", 32 // args.patch_num)
print("Sequence length:", seqlen)


____________________ datamodule init
____________________ datamodule prepare data
____________________ datamodule train dataloader
Embed dim: 64
Patch size: 4
Sequence length: 128


In [106]:

# Reverse lookup (class index -> label string), used to print predictions.
idx_to_class = dict(zip(CLASS_TO_IDX.values(), CLASS_TO_IDX.keys()))

# Keep only the single checkpoint with the highest logged 'test_acc'.
model_checkpoint = ModelCheckpoint(
    monitor='test_acc',
    mode='max',
    save_top_k=1,
    verbose=True,
    dirpath=os.path.join(args.path, "checkpoints"),
    filename="trans-kws-best-acc",
)

In [107]:


# The constructor's parameter is `max_epochs`. An `epochs=` keyword would be
# swallowed by **kwargs and leave the cosine schedule at its default length.
model = LitTransformer(num_classes=args.num_classes, lr=args.lr, max_epochs=args.max_epochs,
                        depth=args.depth, embed_dim=args.embed_dim, head=args.num_heads,
                        patch_dim=patch_dim, seqlen=seqlen)



____________________ init
____________________ reset param


In [108]:
# Resume from a previously saved checkpoint. `load_from_checkpoint` is a
# classmethod that builds a new model, so it is called on the class.
model = LitTransformer.load_from_checkpoint(os.path.join(
    args.path, "checkpoints", "trans-kws-best-acc.ckpt"))

____________________ init
____________________ reset param


In [109]:

# Use the configured precision (16 by default) on GPU and full precision
# elsewhere, instead of ignoring the parsed --precision option.
trainer = Trainer(  accelerator=args.accelerator,
                    devices=args.devices,
                    precision=args.precision if args.accelerator == 'gpu' else 32,
                    max_epochs=args.max_epochs,
                    callbacks=[model_checkpoint],
                    )

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [102]:
# Train the model; the datamodule supplies the training and validation loaders.
trainer.fit(model, datamodule=datamodule)

____________________ datamodule prepare data
____________________ setup 
____________________ datamodule prepare data


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Transformer      | 597 K 
1 | embed   | Linear           | 2.1 K 
2 | fc      | Linear           | 303 K 
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
902 K     Trainable params
0         Non-trainable params
902 K     Total params
1.806     Total estimated model params size (MB)


____________________ config optimizers
Sanity Checking: 0it [00:00, ?it/s]____________________ kws val dataloader
____________________ datamodule train dataloader                           
Epoch 0: 100%|██████████| 779/779 [00:42<00:00, 18.17it/s, loss=0.697, v_num=31, test_loss=0.517, test_acc=85.50]

Epoch 0, global step 701: 'test_acc' reached 85.48149 (best 85.48149), saving model to '/home/dl/Desktop/dl/object_detection_model_hw2/hw3/data/speech_commands/checkpoints/trans-kws-best-acc-v6.ckpt' as top 1


Epoch 1:  61%|██████    | 476/779 [00:27<00:17, 17.11it/s, loss=0.55, v_num=31, test_loss=0.517, test_acc=85.50] 

In [110]:
# Train the model; the datamodule supplies the training and validation loaders.
trainer.fit(model, datamodule=datamodule)

____________________ datamodule prepare data
____________________ setup 
____________________ datamodule prepare data


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Transformer      | 597 K 
1 | embed   | Linear           | 2.1 K 
2 | fc      | Linear           | 303 K 
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
902 K     Trainable params
0         Non-trainable params
902 K     Total params
1.806     Total estimated model params size (MB)


____________________ config optimizers
Sanity Checking: 0it [00:00, ?it/s]____________________ kws val dataloader
____________________ datamodule train dataloader                           
Epoch 0: 100%|██████████| 779/779 [00:42<00:00, 18.29it/s, loss=0.6, v_num=32, test_loss=0.538, test_acc=84.50]]

Epoch 0, global step 701: 'test_acc' reached 84.52805 (best 84.52805), saving model to '/home/dl/Desktop/dl/object_detection_model_hw2/hw3/data/speech_commands/checkpoints/trans-kws-best-acc-v7.ckpt' as top 1


Epoch 1: 100%|██████████| 779/779 [00:41<00:00, 18.74it/s, loss=0.474, v_num=32, test_loss=0.460, test_acc=86.90]

Epoch 1, global step 1402: 'test_acc' reached 86.94576 (best 86.94576), saving model to '/home/dl/Desktop/dl/object_detection_model_hw2/hw3/data/speech_commands/checkpoints/trans-kws-best-acc-v7.ckpt' as top 1


Epoch 2: 100%|██████████| 779/779 [00:41<00:00, 18.71it/s, loss=0.461, v_num=32, test_loss=0.488, test_acc=86.40]

Epoch 2, global step 2103: 'test_acc' was not in top 1


Epoch 3: 100%|██████████| 779/779 [00:41<00:00, 18.92it/s, loss=0.375, v_num=32, test_loss=0.482, test_acc=87.10]

Epoch 3, global step 2804: 'test_acc' reached 87.08549 (best 87.08549), saving model to '/home/dl/Desktop/dl/object_detection_model_hw2/hw3/data/speech_commands/checkpoints/trans-kws-best-acc-v7.ckpt' as top 1


Epoch 4: 100%|██████████| 779/779 [00:41<00:00, 18.91it/s, loss=0.43, v_num=32, test_loss=0.571, test_acc=85.40] 

Epoch 4, global step 3505: 'test_acc' was not in top 1


Epoch 5:  85%|████████▍ | 662/779 [00:36<00:06, 18.07it/s, loss=0.454, v_num=32, test_loss=0.571, test_acc=85.40]

In [18]:

# Evaluate the model on the datamodule's test loader.
trainer.test(model, datamodule=datamodule)

____________________ datamodule prepare data
____________________ setup 
____________________ datamodule prepare data


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


____________________ datamodule test dataload
Testing DataLoader 0: 100%|██████████| 172/172 [00:02<00:00, 64.81it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc             83.48890686035156
        test_loss           0.6210438013076782
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.6210438013076782, 'test_acc': 83.48890686035156}]

In [21]:
# Reload the best checkpoint for export. Prefer the path recorded by the
# checkpoint callback during this session: repeated runs are written with a
# version suffix, so the fixed file name may point at an older run. Fall back
# to the fixed name when nothing has been saved in this session.
best_ckpt_path = model_checkpoint.best_model_path or os.path.join(
    args.path, "checkpoints", "trans-kws-best-acc.ckpt")
# `load_from_checkpoint` is a classmethod, so it is called on the class
model = LitTransformer.load_from_checkpoint(best_ckpt_path)
model.eval()

____________________ init
____________________ reset param


LitTransformer(
  (encoder): Transformer(
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=64, out_features=192, bias=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
        )
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_features=256, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=256, out_features=64, bias=True)
        )
      )
      (1): Block(
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=64, out_features=192, bias=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
        )
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_

In [22]:

# Convert the Lightning module to TorchScript so it can be serialized with torch.jit.
script = model.to_torchscript()

In [23]:

# Destination file for the exported TorchScript model, next to the checkpoints
model_path = os.path.join(args.path, "checkpoints",
                          "trans-kws-best-acc.pt")

# save for use in production environment
torch.jit.save(script, model_path)

In [56]:

# Pick a random keyword, skipping the first two entries of CLASSES
# ('silence' and 'unknown'), then a random recording of that keyword.
label = np.random.choice(CLASSES[2:])
path = os.path.join(args.path, "SpeechCommands/speech_commands_v0.02/", label)
wav_files = [os.path.join(path, f)
             for f in os.listdir(path) if f.endswith('.wav')]
# select random wav file
wav_file = np.random.choice(wav_files)
waveform, sample_rate = torchaudio.load(wav_file)

# Bring the clip to exactly one second, as the training collate function does.
# Otherwise a shorter or longer recording yields a spectrogram whose width
# does not match the model's input layer.
num_samples = waveform.shape[-1]
if num_samples < sample_rate:
    waveform = torch.nn.functional.pad(waveform, (0, sample_rate - num_samples))
elif num_samples > sample_rate:
    waveform = waveform[:, :sample_rate]

transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                 n_fft=args.n_fft,
                                                 win_length=args.win_length,
                                                 hop_length=args.hop_length,
                                                 n_mels=args.n_mels,
                                                 power=2.0)

# ToTensor adds a leading axis, which serves as a batch of one here
mel = ToTensor()(librosa.power_to_db(
    transform(waveform).squeeze().numpy(), ref=np.max))

scripted_module = torch.jit.load(model_path)
# no gradients are needed for a single prediction
with torch.no_grad():
    pred = torch.argmax(scripted_module(mel), dim=1)
print(f"Ground Truth: {label}, Prediction: {idx_to_class[pred.item()]}")

Ground Truth: eight, Prediction: eight
