Importing needed dependencies

In [None]:
import re
import os
import string

import numpy as np
import librosa
import torch
import IPython.display as ipd

from tqdm import tqdm
from torchsummary import summary
from matplotlib import pyplot as plt

from datasets import load_dataset

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.amp import autocast, GradScaler


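Selecting the CUDA device if available and changing the default multiprocessing start method from fork to spawn, which is the method required to use CUDA safely in worker processes on UNIX systems (on Windows, spawn is already the default and fork is not available, so this line changes nothing there).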

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Start worker processes with 'spawn'; force=True overrides a start method that was already set.
torch.multiprocessing.set_start_method('spawn', force=True)

Preprocessing audio into mel-spectrograms and tokenizing the text according to token_vocabulary

In [None]:
class ASRDataset(Dataset):
    """Wraps a Hugging Face speech dataset and yields model-ready samples.

    Every item of the wrapped dataset is read as ``item['audio']['array']``,
    ``item['audio']['sampling_rate']`` and ``item['text']``. ``__getitem__``
    turns the audio into a normalized log-mel spectrogram and the transcript
    into a tensor of token ids taken from ``token_vocabulary``.
    """

    def __init__(self, dataset_name : str, token_vocabulary : list, dataset_split : str = 'train'):
        """Load ``dataset_split`` of ``dataset_name`` and remember the vocabulary.

        Args:
            dataset_name: name passed to ``load_dataset``.
            token_vocabulary: list of characters; the position of a character
                in this list is its token id.
            dataset_split: split passed to ``load_dataset``.
        """
        super().__init__()
        self.dataset = load_dataset(dataset_name, split=dataset_split)
        self.token_vocabulary = token_vocabulary

    def play_audio(self, idx : int):
        """Return an auto-playing IPython audio widget for sample ``idx``."""
        item = self.dataset[idx]
        audio_array = item['audio']['array']
        sampling_rate = item['audio']['sampling_rate']
        return ipd.Audio(audio_array, rate=sampling_rate, autoplay=True)

    def audio_to_mel(self, audio_array : np.ndarray, sample_rate : int, n_mels : int = 80, target_sample_rate : int = 16000) -> np.ndarray:
        """Convert a waveform into a normalized log-mel spectrogram.

        Args:
            audio_array: waveform samples (assumed mono -- TODO confirm against the dataset).
            sample_rate: sampling rate of ``audio_array``.
            n_mels: number of mel bands.
            target_sample_rate: rate the audio is resampled to when it differs.

        Returns:
            The log-mel spectrogram (dB relative to its maximum), normalized
            along axis 0 with ``librosa.util.normalize``.
        """
        if sample_rate != target_sample_rate:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sample_rate)
        mel_spectr = librosa.feature.melspectrogram(y=audio_array, sr=target_sample_rate, n_mels=n_mels)
        log_mel_spectr = librosa.power_to_db(mel_spectr, ref=np.max)
        log_mel_spectr_norm = librosa.util.normalize(log_mel_spectr, axis=0)
        return log_mel_spectr_norm

    def text_to_tokens(self, text : str) -> list[int]:
        """Map ``text`` to token ids, keeping only lowercase letters and spaces.

        The text is lowercased, every whitespace character (tab, newline, ...)
        is turned into a plain space, and everything that is not ``a-z`` or a
        space is dropped. Characters missing from ``token_vocabulary`` are
        skipped instead of raising ``ValueError``.
        """
        # Previously `\s` let tabs/newlines survive the cleanup and
        # `list.index` then raised ValueError for them.
        text_clean = re.sub(r"\s", ' ', text.lower())
        text_clean = re.sub(r"[^a-z ]", '', text_clean)

        # First occurrence wins, matching the semantics of `list.index`.
        token_index = {}
        for idx, token in enumerate(self.token_vocabulary):
            token_index.setdefault(token, idx)

        return [token_index[letter] for letter in text_clean if letter in token_index]

    def __getitem__(self, index):
        """Return ``{'input_values': mel spectrogram, 'text_tokens': token id tensor}``."""
        item = self.dataset[index]
        audio_array = item['audio']['array']
        sampling_rate = item['audio']['sampling_rate']
        mel_spectrogram = self.audio_to_mel(audio_array, sampling_rate)

        text = item['text']
        text_tokens = self.text_to_tokens(text)

        output = {
            'input_values' : mel_spectrogram,
            # Token ids are class indices, so store them as integers (long)
            # rather than as a float32 tensor.
            'text_tokens' : torch.tensor(text_tokens, dtype=torch.long)
        }
        return output

    def __len__(self):
        """Number of samples in the wrapped dataset split."""
        return len(self.dataset)

Using the English alphabet and whitespace as the model vocabulary.

In [None]:
# Index 0 is reserved for an empty-string "blank" entry: 0 is the value
# pad_sequence pads labels with and the blank index given to nn.CTCLoss(blank=0),
# and CTC targets must not contain the blank index. With the space at index 0
# (as before) every space in a transcript was indistinguishable from padding/blank.
# The empty string decodes to nothing, so padded/blank positions vanish in decode_text.
VOCABULARY = [''] + list(' ' + string.ascii_lowercase)
VOCAB_SIZE = len(VOCABULARY)

Downloading the dataset takes a few hours as it consists of 399 GB of data; after that, loading it from the cache takes a few minutes.

In [None]:
# Build one ASRDataset per split (same order as before: train, test, validation).
train_ds, test_ds, val_ds = (
    ASRDataset("MLCommons/peoples_speech", VOCABULARY, split_name)
    for split_name in ('train', 'test', 'validation')
)

Encode and decode functions to test the solution

In [None]:
def encode_text(text):
    """Encode ``text`` as a list of VOCABULARY indices.

    The text is lowercased, every whitespace character (tab, newline, ...) is
    turned into a plain space and all characters other than ``a-z`` and space
    are removed. Characters that are not present in VOCABULARY are skipped.
    """
    # `[^a-z\s]` used to keep tabs/newlines, for which VOCABULARY.index raised
    # ValueError; normalise them to a space first.
    text_clean = re.sub(r"\s", ' ', text.lower())
    text_clean = re.sub(r"[^a-z ]", '', text_clean)

    # First occurrence wins, matching the semantics of `list.index`.
    token_index = {}
    for idx, token in enumerate(VOCABULARY):
        token_index.setdefault(token, idx)

    return [token_index[letter] for letter in text_clean if letter in token_index]

def decode_text(text_tokens):
    """Turn an iterable of token ids back into text.

    Each id is converted with ``int`` and looked up in VOCABULARY; leading and
    trailing whitespace of the joined string is stripped.
    """
    characters = [VOCABULARY[int(token)] for token in text_tokens]
    return "".join(characters).strip()

Dataset sample

In [None]:
# Show the transcript of the first training sample, then play its audio.
first_sample = train_ds[0]
print(decode_text(first_sample['text_tokens']))
train_ds.play_audio(0)

In [None]:
# Display the normalized log-mel spectrogram of the first training sample.
first_mel_spectrogram = train_ds[0]['input_values']
plt.imshow(first_mel_spectrogram)

Preparing the dataloaders for the training process. Input values are padded to the same length within each batch. The unpadded lengths are needed for the CTC loss.

In [None]:
def collate_fn(batch):
    """Collate dataset samples into zero-padded batch tensors plus true lengths.

    Args:
        batch: list of dicts with ``'input_values'`` (2-D array indexed here as
            (feature, time)) and ``'text_tokens'`` (1-D sequence of token ids).

    Returns:
        features: float32 tensor of shape (batch, feature, max_time), zero
            padded along the time axis.
        labels: long tensor of shape (batch, max_label_len), zero padded.
        features_lengths: list with the unpadded time length of each sample.
        labels_lengths: list with the unpadded label length of each sample.
    """
    features = []
    labels = []
    for sample in batch:
        # (feature, time) -> (time, feature) so that pad_sequence pads along time.
        features.append(torch.as_tensor(sample['input_values'], dtype=torch.float32).permute(1, 0))
        labels.append(torch.as_tensor(sample['text_tokens'], dtype=torch.long))

    features_lengths = [feature.shape[0] for feature in features]
    # Use the real sequence length. Counting `tokens != 0` silently dropped
    # every token whose id is 0 (the space in the original vocabulary), which
    # made the CTC target length shorter than the actual transcript.
    labels_lengths = [label.shape[0] for label in labels]

    features = pad_sequence(features, batch_first=True).permute(0, 2, 1)
    labels = pad_sequence(labels, batch_first=True)
    return features, labels, features_lengths, labels_lengths

In [None]:
BATCH_SIZE = 8
# Pinned host memory only helps host-to-GPU copies; requesting it without CUDA
# just produces a warning.
PIN_MEMORY = torch.cuda.is_available()

# Only the training data is shuffled. Validation and test batches keep a fixed
# order so that the windowed losses reported for them are comparable between runs.
dataloader_train = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, pin_memory=PIN_MEMORY)
dataloader_val = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, pin_memory=PIN_MEMORY)
dataloader_test = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, pin_memory=PIN_MEMORY)

Data after preprocessing

In [None]:
# Inspect the first batch only: spectrogram and transcript of the SAME sample
# (previously the plot showed sample 0 while the text came from sample 1).
SAMPLE_IN_BATCH = 0
for sample_feature, sample_target, sample_lengths, target_lengths in dataloader_train:
    plt.imshow(sample_feature[SAMPLE_IN_BATCH])
    print(decode_text(sample_target[SAMPLE_IN_BATCH]))
    print(sample_lengths)
    print(target_lengths)
    break

Specifying the speech-to-text, Conformer-inspired model

In [None]:
class FeedForwardModule(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear.

    The hidden layer is four times as wide as ``output_dim``. The layers act on
    the last dimension of the input.
    """

    def __init__(self, input_dim : int, output_dim : int):
        super().__init__()
        hidden_dim = output_dim * 4
        self.linear_relu = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x`` from ``input_dim`` to ``output_dim`` features."""
        return self.linear(self.linear_relu(x))

class ConformerModule(nn.Module):
    """Conformer-inspired block: self-attention, feed-forward, convolution, feed-forward.

    Every sub-layer is wrapped in a residual connection with dropout. The input
    is expected as (time, batch, embedding_dim), the layout used by
    ``nn.MultiheadAttention`` with its default ``batch_first=False``.

    Args:
        embedding_dim: feature size of the input.
        output_dim: output size of the last feed-forward layer. Its result is
            added to the block input, so it has to be shape-compatible with
            ``embedding_dim`` (the model in this notebook passes the same value).
        num_heads: number of attention heads.
        conv_kernel_size: kernel size of the second (temporal) convolution.
        dropout_rate: dropout probability applied to every sub-layer output.
    """

    def __init__(self, embedding_dim : int, output_dim : int, num_heads : int, conv_kernel_size : int = 5, dropout_rate : float = 0.1):
        super(ConformerModule, self).__init__()
        self.attention = nn.MultiheadAttention(embedding_dim, num_heads=num_heads)
        self.feed_forward1 = FeedForwardModule(embedding_dim, embedding_dim)
        self.conv_block = nn.Sequential(
            nn.Conv1d(embedding_dim, embedding_dim*2, 1),
            nn.ReLU(),
            # padding='same' keeps the time length unchanged for ANY kernel
            # size. The previous (k-1)//2 padding shortened the sequence for
            # even kernels, which broke the residual addition in forward().
            # For odd kernels both paddings are identical.
            nn.Conv1d(embedding_dim*2, embedding_dim, conv_kernel_size, padding='same'),
            nn.ReLU()
        )
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.feed_forward2 = FeedForwardModule(embedding_dim, output_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the block to ``x`` of shape (time, batch, embedding_dim)."""
        # Self-attention over the time axis (no padding mask is applied).
        attention_output, _ = self.attention(x, x, x)
        x = x + self.dropout(attention_output)

        ff_output1 = self.feed_forward1(x)
        x = x + self.dropout(ff_output1)

        # Conv1d works on (batch, channels, time).
        x = x.permute(1, 2, 0)
        conv_output = self.conv_block(x)
        x = x + self.dropout(conv_output)
        # Back to (time, batch, channels) for LayerNorm / feed-forward.
        x = x.permute(2, 0, 1)
        x = self.layer_norm(x)

        ff_output2 = self.feed_forward2(x)
        x = x + self.dropout(ff_output2)
        return x

class ASRModel(nn.Module):
    """Acoustic model: pointwise convolution, a stack of Conformer blocks, linear head.

    ``forward`` takes features shaped (batch, n_mels, time) and returns
    per-frame scores shaped (time, batch, vocab_size).
    """

    def __init__(self, n_mels : int, hidden_dim : int, vocab_size : int, dropout_rate : float = 0.1):
        super().__init__()
        num_blocks = 5
        num_heads = 8
        self.conv1 = nn.Conv1d(n_mels, hidden_dim, 1)
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.conformer_block = nn.Sequential(*(
            ConformerModule(hidden_dim, hidden_dim, num_heads, dropout_rate=dropout_rate)
            for _ in range(num_blocks)
        ))
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map (batch, n_mels, time) features to (time, batch, vocab_size) scores."""
        projected = self.conv1(x)
        # (batch, hidden, time) -> (time, batch, hidden)
        sequence = projected.permute(2, 0, 1)
        sequence = self.layer_norm1(sequence)
        sequence = self.conformer_block(sequence)
        sequence = self.layer_norm2(sequence)
        return self.linear(sequence)

In [None]:
# 80 mel bands in, 512 hidden features, one output score per vocabulary entry.
model = ASRModel(n_mels=80, hidden_dim=512, vocab_size=VOCAB_SIZE)
model = model.to(device)

# Layer-by-layer summary for a dummy input of 80 mel bands x 1 frame.
summary(model, input_size=(80, 1), batch_size=4)

Declaring CTC loss

In [None]:
# Index 0 is the CTC blank symbol.
# NOTE(review): CTC targets must not contain the blank index -- make sure no
# real token of the vocabulary is encoded as 0.
ctc_loss = nn.CTCLoss(blank=0)

def ctc_loss_fn(y_true, y_pred, target_lengths, pred_lengths):
    """Compute the CTC loss from raw model scores.

    Args:
        y_true: padded target token ids, shape (batch, max_target_len).
        y_pred: raw (unnormalized) scores, shape (time, batch, classes).
        target_lengths: unpadded length of every target sequence.
        pred_lengths: unpadded number of frames of every prediction.

    Returns:
        Scalar loss tensor.
    """
    # The scores are logits, so they must NOT be clamped: the former
    # `torch.clamp(y_pred, min=1e-7)` replaced every negative logit by 1e-7,
    # which prevented the model from assigning a low probability to any class.
    # log_softmax is numerically stable on its own. It is computed in float32
    # because CTC is sensitive to reduced precision.
    y_pred_log_softmax = nn.functional.log_softmax(y_pred.float(), dim=2)
    loss = ctc_loss(y_pred_log_softmax, y_true, pred_lengths, target_lengths)
    return loss

In [None]:
# Running loss logs. Every list starts with +inf as a placeholder so that
# `history[key][-1]` is always defined.
history = {
    metric_name: [float('inf')]
    for metric_name in ("loss", "val_loss", "test_loss")
}

As the dataset and the model are considerably large, we may run out of VRAM on some batches. To avoid interrupting the training process, we skip the batches that cause out-of-memory errors. We also use a gradient scaler and autocast (mixed precision) in order to reduce the computational cost.

In [None]:
# Adam over all model parameters with a small learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)
# Loss scaler for mixed-precision training on CUDA.
scaler = GradScaler(device='cuda')

def train(model : nn.Module, dataloader : DataLoader, display_each : int = 20):
    """Train ``model`` for one pass over ``dataloader``.

    Uses mixed precision (autocast + the module-level ``scaler``) and gradient
    accumulation with the module-level ``optimizer``. Batches that raise a CUDA
    out-of-memory error are skipped. Every ``display_each`` batches the mean
    loss of the batches processed in that window is appended to
    ``history['loss']`` and shown in the progress bar.

    Args:
        model: network to optimise.
        dataloader: yields (features, targets, feature_lengths, target_lengths).
        display_each: size of the logging window, in batches.
    """
    global history
    out_of_mem = not_lost = 0
    accumulation_steps = 4
    pending_batches = 0   # batches whose gradients are accumulated but not applied yet
    window_loss_sum = 0.0
    window_batches = 0    # batches that actually contributed to window_loss_sum
    model.train()
    # Start from clean gradients (a previous call may have been interrupted).
    optimizer.zero_grad()
    pbar = tqdm(dataloader, desc=f"loss: {history['loss'][-1]:.2f}, out_of_mem: {out_of_mem/(not_lost + out_of_mem + 1e-6):.2f}")
    for batch_idx, (features_batch, targets_batch, feature_lengths, target_lengths) in enumerate(pbar):
        try:
            features_batch = features_batch.to(device, non_blocking=True)
            targets_batch = targets_batch.to(device, non_blocking=True)

            with autocast('cuda'):
                pred = model(features_batch)
                loss = ctc_loss_fn(targets_batch, pred, target_lengths, feature_lengths)

            # Divide by accumulation_steps so the accumulated gradient is the
            # mean (not the sum) over the accumulated batches.
            scaler.scale(loss / accumulation_steps).backward()

            window_loss_sum += loss.item()
            window_batches += 1
            pending_batches += 1
            not_lost += 1

            # Count successfully processed batches instead of batch_idx, so
            # that skipped (out-of-memory) batches do not shift or suppress
            # the optimizer step.
            if pending_batches == accumulation_steps:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                pending_batches = 0

        except RuntimeError as e:
            if 'out of memory' in str(e):
                out_of_mem += 1
                torch.cuda.empty_cache()
            else:
                raise

        if (batch_idx+1) % display_each == 0 and window_batches > 0:
            # Average over the batches that were really processed; dividing by
            # display_each under-reported the loss whenever batches were skipped.
            window_loss = window_loss_sum / window_batches
            history['loss'].append(window_loss)
            pbar.set_description(f"loss: {window_loss:.2f}, out_of_mem_ratio: {out_of_mem/(not_lost + out_of_mem):.4f}")
            window_loss_sum = 0.0
            window_batches = 0

        del features_batch, targets_batch, feature_lengths, target_lengths

    # Apply gradients left over from an incomplete accumulation group so they
    # are neither lost nor carried over into the next call.
    if pending_batches > 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

def validate(model : nn.Module, dataloader : DataLoader, display_each : int = 20):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Batches that raise a CUDA out-of-memory error are skipped. Every
    ``display_each`` batches the mean loss of the processed batches in that
    window is appended to ``history['val_loss']``. After the last batch the
    mean loss over ALL processed batches is appended as the final entry, so
    ``history['val_loss'][-1]`` describes the whole validation pass (and is
    recorded even when the loader has fewer than ``display_each`` batches).

    Returns:
        The mean loss over all processed batches, or ``None`` when no batch
        could be processed.
    """
    global history
    out_of_mem = not_lost = 0
    model.eval()
    window_loss_sum = 0.0
    window_batches = 0
    total_loss_sum = 0.0
    pbar = tqdm(dataloader, desc=f"val_loss: {history['val_loss'][-1]}")
    with torch.no_grad():
        for batch_idx, (features_batch, targets_batch, feature_lengths, target_lengths) in enumerate(pbar):
            try:
                features_batch = features_batch.to(device, non_blocking=True)
                targets_batch = targets_batch.to(device, non_blocking=True)

                with autocast('cuda'):
                    pred = model(features_batch)
                    loss = ctc_loss_fn(targets_batch, pred, target_lengths, feature_lengths)

                batch_loss = loss.item()
                window_loss_sum += batch_loss
                total_loss_sum += batch_loss
                window_batches += 1
                not_lost += 1
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    out_of_mem += 1
                    torch.cuda.empty_cache()
                else:
                    raise

            if (batch_idx+1) % display_each == 0 and window_batches > 0:
                # Divide by the number of batches that contributed, not by
                # display_each (which is wrong when batches were skipped).
                window_loss = window_loss_sum / window_batches
                history['val_loss'].append(window_loss)
                pbar.set_description(f"val_loss: {window_loss:.2f}, out_of_mem_ratio: {out_of_mem/(not_lost + out_of_mem):.4f}")
                window_loss_sum = 0.0
                window_batches = 0

            del features_batch, targets_batch, feature_lengths, target_lengths

    if not_lost == 0:
        return None

    epoch_loss = total_loss_sum / not_lost
    history['val_loss'].append(epoch_loss)
    return epoch_loss

def test(model : nn.Module, dataloader : DataLoader, display_each : int = 20):
    """Evaluate ``model`` on the test ``dataloader`` without computing gradients.

    Every ``display_each`` batches the mean loss of that window is appended to
    ``history['test_loss']`` and shown in the progress bar. After the last
    batch the mean loss over the whole pass is appended as the final entry, so
    trailing batches and loaders shorter than ``display_each`` are reported too.

    Returns:
        The mean loss over all batches, or ``None`` for an empty dataloader.
    """
    global history
    model.eval()
    window_loss_sum = 0.0
    total_loss_sum = 0.0
    total_batches = 0
    pbar = tqdm(dataloader, desc=f"test_loss: {history['test_loss'][-1]:.2f}")
    with torch.no_grad():
        for batch_idx, (features_batch, targets_batch, feature_lengths, target_lengths) in enumerate(pbar):
            features_batch = features_batch.to(device, non_blocking=True)
            targets_batch = targets_batch.to(device, non_blocking=True)

            with autocast('cuda'):
                pred = model(features_batch)
                loss = ctc_loss_fn(targets_batch, pred, target_lengths, feature_lengths)

            batch_loss = loss.item()
            window_loss_sum += batch_loss
            total_loss_sum += batch_loss
            total_batches += 1

            if (batch_idx+1) % display_each == 0:
                window_loss = window_loss_sum / display_each
                pbar.set_description(f"test_loss: {window_loss:.4f}")
                history['test_loss'].append(window_loss)
                window_loss_sum = 0.0

            del features_batch, targets_batch, feature_lengths, target_lengths

    if total_batches == 0:
        return None

    overall_loss = total_loss_sum / total_batches
    history['test_loss'].append(overall_loss)
    pbar.set_description(f"test_loss: {overall_loss:.4f}")
    return overall_loss


Starting the training process. After each epoch of training we save the model if its validation loss is lower than the best one seen so far.

In [None]:
checkpoint_path = './checkpoint'
EPOCHS = 10
best_val_loss = float('inf')

# Creates missing parent directories too and does nothing if the directory exists.
os.makedirs(checkpoint_path, exist_ok=True)

for epoch in range(EPOCHS):
    print(f"Epoch : {epoch+1}/{EPOCHS}")
    train(model, dataloader_train)

    # Remember how many validation entries exist so a stale value from an
    # earlier epoch is never mistaken for this epoch's result.
    val_entries_before = len(history['val_loss'])
    validate(model, dataloader_val)
    if len(history['val_loss']) == val_entries_before:
        print("Validation recorded no loss for this epoch; skipping checkpoint comparison")
        continue

    current_val_loss = history['val_loss'][-1]
    if current_val_loss < best_val_loss:
        best_val_loss = current_val_loss
        checkpoint_file = os.path.join(checkpoint_path, 'asr_conformer_best.pth')
        torch.save(model.state_dict(), checkpoint_file)
        print(f"New best model saved with val_loss: {best_val_loss:.4f} at {checkpoint_file}")
    else:
        print(f"No improvement in validation loss ({current_val_loss:.4f})")

Training loss history

In [None]:
# Plot the windowed training losses logged by train(); the first entry of the
# list is the +inf placeholder the history was initialised with.
plt.plot(history['loss'])

Testing model

In [None]:
# Evaluate the current in-memory model weights on the test dataloader.
test(model, dataloader_test)

Some experiments

In [None]:
def transcribe(mel_spec):
    """Predict the most likely token for every frame of one spectrogram.

    Args:
        mel_spec: 2-D array shaped like a dataset item's ``'input_values'``.

    Returns:
        Tensor of token ids with shape (time, 1): the per-frame argmax of the
        model output. Repeated tokens and blanks are NOT collapsed here.
    """
    model.eval()
    # Convert the array directly and add the batch dimension afterwards;
    # building a tensor from a *list* of ndarrays goes through a slow
    # element-wise path.
    features = torch.as_tensor(mel_spec, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        scores = model(features)
    # Softmax is monotonic, so the argmax of the raw scores selects the same token.
    chosen_tokens = torch.argmax(scores, dim=2)
    return chosen_tokens

In [None]:
# Transcribe one test sample and show its spectrogram.
i = 200
print(decode_text(transcribe(test_ds[i]['input_values'])))
experiment_mel = test_ds[i]['input_values']
plt.imshow(experiment_mel)


In [None]:
# Play the audio of the test sample selected by `i` in the previous cell.
test_ds.play_audio(i)

CUDA memory summary

In [None]:
# The notebook can also run on the CPU (see the device selection cell), and
# the CUDA memory summary can only be queried when a CUDA device is present.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())
else:
    print("CUDA is not available; no GPU memory summary to report.")