# Packages

In [1]:
import os

import lightning
import pandas as pd
import torch
import torchmetrics
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from load_data import AudioTrainDataset, TargetEncoder, RemoveSampleRate, PaddingZeros

# Simple loading

In [2]:
DATA_PATH = os.path.join("tensorflow-speech-recognition-challenge", "train", "audio")
dataset = AudioTrainDataset(DATA_PATH)

labels_list, labels_dict = dataset.find_classes(DATA_PATH)
labels_dict = {idx: name for name, idx in labels_dict.items()}
labels_dict

{1: 'bed',
 2: 'bird',
 3: 'cat',
 4: 'dog',
 5: 'down',
 6: 'eight',
 7: 'five',
 8: 'four',
 9: 'go',
 10: 'happy',
 11: 'house',
 12: 'left',
 13: 'marvin',
 14: 'nine',
 15: 'no',
 16: 'off',
 17: 'on',
 18: 'one',
 19: 'right',
 20: 'seven',
 21: 'sheila',
 22: 'silence',
 23: 'six',
 24: 'stop',
 25: 'three',
 26: 'tree',
 27: 'two',
 28: 'up',
 29: 'wow',
 30: 'yes',
 31: 'zero'}

In [3]:
NUM_WORKERS = 6
BATCH_SIZE = 256

# Simple Model

In [4]:
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
extractor = bundle.get_model().to(device=torch.device('cuda'))
# freeze weights
for param in extractor.parameters():
    param.requires_grad = False

In [5]:
transforms = Compose([
    PaddingZeros(16000),
    RemoveSampleRate()
])
raw_dataset = AudioTrainDataset(DATA_PATH, target_transform=TargetEncoder(class_dict=labels_dict), transform=transforms)

In [6]:
gen = torch.Generator().manual_seed(42)
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(raw_dataset, [0.7, 0.1, 0.2],
                                                                           generator=gen)
len(train_dataset), len(valid_dataset), len(test_dataset)

(45587, 6512, 13024)

In [7]:
train_dataset_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True,
                                  generator=torch.random.manual_seed(123))
valid_dataset_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
test_dataset_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
len(train_dataset_loader), len(valid_dataset_loader), len(test_dataset_loader)

(179, 26, 51)

In [8]:
class MyLSTM(lightning.LightningModule):
    def __init__(self,
                 extractor: torchaudio.models.Wav2Vec2Model,
                 hidden_size,
                 num_layers,
                 target_size):
        super().__init__()
        self.extractor = extractor
        lstm_input_size = 768 * 12
        self.num_layers= num_layers
        self.hidden_size = hidden_size
        self.lstm = torch.nn.LSTM(lstm_input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.hidden2label = torch.nn.Linear(hidden_size, target_size)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=target_size)
        self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=target_size)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=target_size)
        self.test_conf_mat = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=target_size)

    def forward(self, x):
        x = x.squeeze()
        x, _ = self.extractor.extract_features(x)
        x = torch.cat(x, dim=-1)
        lstm_out, _ = self.lstm(x)
        label_space = self.hidden2label(lstm_out[:, -1])
        return self.softmax(label_space)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.train_acc(y_hat, torch.argmax(y, dim=-1))
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc_step", self.train_acc)
        return loss

    def on_train_epoch_end(self):
        self.log('train_acc', self.train_acc)

    def predict_step(self, batch, batch_idx, dataloader_idx = 0):
        x, y = batch
        return self(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.valid_acc(y_hat, torch.argmax(y, dim=-1))
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', self.valid_acc, on_epoch=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        y_class = torch.argmax(y, dim=-1)
        self.test_acc(y_hat, y_class)
        self.test_conf_mat(y_hat, y_class)
        self.log('test_loss', loss, on_epoch=True)
        self.log('test_acc', self.test_acc, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.3, patience=3),
                "monitor": "val_loss",
            }
        }

In [9]:
for batch_x, batch_y in train_dataset_loader:
    print(batch_x.shape)
    break

torch.Size([256, 16000])


In [10]:
model = MyLSTM(extractor, 32, 1, 12)
for batch_x,  batch_y in train_dataset_loader:
    y_hat = model(batch_x.to(device=torch.device('cuda')))
    break

RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at cpu

In [11]:
model = MyLSTM(extractor, 32, 1, 12)
trainer = lightning.Trainer(max_epochs=2, logger=True)
torch.set_float32_matmul_precision('medium')
trainer.fit(model, train_dataloaders=train_dataset_loader, val_dataloaders=valid_dataset_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                      | Params
------------------------------------------------------------
0 | extractor     | Wav2Vec2Model             | 94.4 M
1 | lstm          | LSTM                      | 1.2 M 
2 | hidden2label  | Linear                    | 396   
3 | softmax       | Softmax                   | 0     
4 | train_acc     | MulticlassAccuracy        | 0     
5 | valid_acc     | MulticlassAccuracy        | 0     
6 | test_acc      | MulticlassAccuracy        | 0     
7 | test_conf_mat | MulticlassConfusionMatrix | 0     
------------------------------------------------------------
1.2 M     Trainable params
94.4 M    Non-trainable params
95.6 M    Total params
382.311   Total estimated model params size (MB)


Epoch 0: 100%|██████████| 179/179 [00:56<00:00,  3.16it/s, v_num=28, train_loss_step=1.890]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   4%|▍         | 1/26 [00:00<00:10,  2.33it/s][A
Validation DataLoader 0:   8%|▊         | 2/26 [00:00<00:08,  2.92it/s][A
Validation DataLoader 0:  12%|█▏        | 3/26 [00:00<00:07,  3.16it/s][A
Validation DataLoader 0:  15%|█▌        | 4/26 [00:01<00:06,  3.31it/s][A
Validation DataLoader 0:  19%|█▉        | 5/26 [00:01<00:06,  3.48it/s][A
Validation DataLoader 0:  23%|██▎       | 6/26 [00:01<00:05,  3.58it/s][A
Validation DataLoader 0:  27%|██▋       | 7/26 [00:01<00:05,  3.67it/s][A
Validation DataLoader 0:  31%|███       | 8/26 [00:02<00:04,  3.71it/s][A
Validation DataLoader 0:  35%|███▍      | 9/26 [00:02<00:04,  3.77it/s][A
Validation DataLoader 0:  38%|███▊      | 10/26 [00:02<00:04,  3.80it/s][A
Va

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 179/179 [01:19<00:00,  2.25it/s, v_num=28, train_loss_step=2.040, train_loss_epoch=1.990]


In [None]:
results = []
predictions = []
for i in range(5):
    lightning.pytorch.seed_everything(i)
    model = MyLSTM(extractor, 32, 1, 12)
    early_stopping = lightning.pytorch.callbacks.EarlyStopping('val_loss', verbose=True)
    logger = lightning.pytorch.loggers.tensorboard.TensorBoardLogger(save_dir="cnn_lstm", version=i)
    trainer = lightning.Trainer(max_epochs=200, callbacks=[early_stopping], logger=logger)
    trainer.fit(model, train_dataloaders=train_dataset_loader, val_dataloaders=valid_dataset_loader)
    res = trainer.test(dataloaders=test_dataset_loader, ckpt_path='best')
    test_pred_tensor = torch.cat(trainer.predict(dataloaders=test_dataset_loader, ckpt_path='best'))
    results.append(res[0])
    predictions.append(test_pred_tensor)
torch.save(torch.stack(predictions), "wav2vec_lstm.ts")
pd.DataFrame(results).to_csv("wav2vec_lstm.csv")

Global seed set to 0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                      | Params
------------------------------------------------------------
0 | extractor     | Wav2Vec2Model             | 94.4 M
1 | lstm          | LSTM                      | 1.2 M 
2 | hidden2label  | Linear                    | 396   
3 | softmax       | Softmax                   | 0     
4 | train_acc     | MulticlassAccuracy        | 0     
5 | valid_acc     | MulticlassAccuracy        | 0     
6 | test_acc      | MulticlassAccuracy        | 0     
7 | test_conf_mat | MulticlassConfusionMatrix | 0     
------------------------------------------------------------
1.2 M     Trainable params
94.4 M    Non-trainable params
95.6 M    Total params
382.311  

Epoch 0: 100%|██████████| 179/179 [00:57<00:00,  3.12it/s, v_num=0, train_loss_step=2.040]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   4%|▍         | 1/26 [00:00<00:13,  1.87it/s][A
Validation DataLoader 0:   8%|▊         | 2/26 [00:00<00:09,  2.58it/s][A
Validation DataLoader 0:  12%|█▏        | 3/26 [00:01<00:07,  2.96it/s][A
Validation DataLoader 0:  15%|█▌        | 4/26 [00:01<00:07,  3.14it/s][A
Validation DataLoader 0:  19%|█▉        | 5/26 [00:01<00:06,  3.34it/s][A
Validation DataLoader 0:  23%|██▎       | 6/26 [00:01<00:05,  3.48it/s][A
Validation DataLoader 0:  27%|██▋       | 7/26 [00:01<00:05,  3.60it/s][A
Validation DataLoader 0:  31%|███       | 8/26 [00:02<00:04,  3.69it/s][A
Validation DataLoader 0:  35%|███▍      | 9/26 [00:02<00:04,  3.77it/s][A
Validation DataLoader 0:  38%|███▊      | 10/26 [00:02<00:04,  3.83it/s][A
Val

Metric val_loss improved. New best score: 1.980


Epoch 1: 100%|██████████| 179/179 [00:57<00:00,  3.12it/s, v_num=0, train_loss_step=1.830, train_loss_epoch=2.040]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   4%|▍         | 1/26 [00:00<00:11,  2.19it/s][A
Validation DataLoader 0:   8%|▊         | 2/26 [00:00<00:08,  2.91it/s][A
Validation DataLoader 0:  12%|█▏        | 3/26 [00:00<00:07,  3.22it/s][A
Validation DataLoader 0:  15%|█▌        | 4/26 [00:01<00:06,  3.49it/s][A
Validation DataLoader 0:  19%|█▉        | 5/26 [00:01<00:05,  3.68it/s][A
Validation DataLoader 0:  23%|██▎       | 6/26 [00:01<00:05,  3.81it/s][A
Validation DataLoader 0:  27%|██▋       | 7/26 [00:01<00:04,  3.92it/s][A
Validation DataLoader 0:  31%|███       | 8/26 [00:01<00:04,  4.00it/s][A
Validation DataLoader 0:  35%|███▍      | 9/26 [00:02<00:04,  4.07it/s][A
Validation DataLoader 0:  38%|███▊      | 10/26 [00:02<

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 1.978


Epoch 2: 100%|██████████| 179/179 [00:57<00:00,  3.12it/s, v_num=0, train_loss_step=2.200, train_loss_epoch=1.990]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   4%|▍         | 1/26 [00:00<00:11,  2.27it/s][A
Validation DataLoader 0:   8%|▊         | 2/26 [00:00<00:07,  3.03it/s][A
Validation DataLoader 0:  12%|█▏        | 3/26 [00:00<00:06,  3.40it/s][A
Validation DataLoader 0:  15%|█▌        | 4/26 [00:01<00:06,  3.62it/s][A
Validation DataLoader 0:  19%|█▉        | 5/26 [00:01<00:05,  3.78it/s][A
Validation DataLoader 0:  23%|██▎       | 6/26 [00:01<00:05,  3.84it/s][A
Validation DataLoader 0:  27%|██▋       | 7/26 [00:01<00:04,  3.94it/s][A
Validation DataLoader 0:  31%|███       | 8/26 [00:01<00:04,  4.02it/s][A
Validation DataLoader 0:  35%|███▍      | 9/26 [00:02<00:04,  4.09it/s][A
Validation DataLoader 0:  38%|███▊      | 10/26 [00:02<

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 1.978


Epoch 3: 100%|██████████| 179/179 [00:56<00:00,  3.19it/s, v_num=0, train_loss_step=1.880, train_loss_epoch=1.990]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/26 [00:00<?, ?it/s][A
Validation DataLoader 0:   4%|▍         | 1/26 [00:00<00:11,  2.12it/s][A
Validation DataLoader 0:   8%|▊         | 2/26 [00:00<00:08,  2.87it/s][A
Validation DataLoader 0:  12%|█▏        | 3/26 [00:00<00:07,  3.17it/s][A
Validation DataLoader 0:  15%|█▌        | 4/26 [00:01<00:06,  3.41it/s][A
Validation DataLoader 0:  19%|█▉        | 5/26 [00:01<00:05,  3.53it/s][A
Validation DataLoader 0:  23%|██▎       | 6/26 [00:01<00:05,  3.59it/s][A
Validation DataLoader 0:  27%|██▋       | 7/26 [00:01<00:05,  3.68it/s][A
Validation DataLoader 0:  31%|███       | 8/26 [00:02<00:04,  3.67it/s][A
Validation DataLoader 0:  35%|███▍      | 9/26 [00:02<00:04,  3.74it/s][A
Validation DataLoader 0:  38%|███▊      | 10/26 [00:02<

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 1.978


Epoch 4:  19%|█▉        | 34/179 [00:23<01:39,  1.46it/s, v_num=0, train_loss_step=2.000, train_loss_epoch=1.990] 

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
Restoring states from the checkpoint path at cnn_lstm\lightning_logs\version_0\checkpoints\epoch=3-step=716.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at cnn_lstm\lightning_logs\version_0\checkpoints\epoch=3-step=716.ckpt
