# ContextNet implementation
Code inspired by: https://github.com/upskyy/ContextNet

Author: Borghini Alessia

Check which GPU is attached to the notebook runtime

In [None]:
# Print the NVIDIA GPU status reported by `nvidia-smi` for the current runtime.
!nvidia-smi

In [None]:
# Mount Google Drive
from google.colab import drive 

# Make the Drive contents available under /content/drive; later cells read
# project files (requirements, tokenizer model) through this mount.
drive.mount("/content/drive")

# Prerequisites

Install libraries

In [None]:
# Install the Python dependencies listed in the requirements file stored on Drive.
# NOTE(review): this path spells the Drive folder `My Drive`, while PATH (set in the
# import cell) uses `MyDrive` — confirm both resolve on the mounted Drive.
!pip install -r drive/My\ Drive/Speech_Recognition/requirements.txt

Import libraries and set seed

In [None]:
import matplotlib.pyplot as plt

import torch 
from torch import Tensor
from torch.utils.data import DataLoader, Sampler
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

import torchaudio

import numpy as np
import librosa
from tqdm import tqdm
from typing import Tuple, Optional
import math

import warp_rnnt

import jiwer
import sentencepiece

import transformers

import os
import csv

# Project folder on the mounted Drive (relative to the working directory);
# the tokenizer model is loaded from here in the "Build dataset" cell.
PATH = "drive/MyDrive/Speech_Recognition/"
# Seed PyTorch's global RNG so runs are repeatable.
torch.manual_seed(42)

# Librispeech Dataset

In [None]:
# Create the local dataset folder. `exist_ok=True` keeps this cell re-runnable:
# plain `os.mkdir` raises FileExistsError the second time the cell is executed.
os.makedirs("datasets", exist_ok=True)

Downloading LibriSpeech dataset

In [None]:
# Each command below downloads one LibriSpeech subset into `datasets/` and unpacks it.
# Only the dev-clean download is active; uncomment the others as needed.

# Download LibriSpeech train-clean-100 subset
# !cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf train-clean-100.tar.gz

# Download LibriSpeech train-clean-360 subset
# !cd datasets && wget https://www.openslr.org/resources/12/train-clean-360.tar.gz && tar xzf train-clean-360.tar.gz

# Download LibriSpeech dev-clean subset
!cd datasets && wget https://www.openslr.org/resources/12/dev-clean.tar.gz && tar xzf dev-clean.tar.gz

# Download LibriSpeech dev-other subset
# !cd datasets && wget https://www.openslr.org/resources/12/dev-other.tar.gz && tar xzf dev-other.tar.gz

# Download LibriSpeech test-clean subset
# !cd datasets && wget https://www.openslr.org/resources/12/test-clean.tar.gz && tar xzf test-clean.tar.gz

# Download LibriSpeech test-other subset
# !cd datasets && wget https://www.openslr.org/resources/12/test-other.tar.gz && tar xzf test-other.tar.gz

LibriSpeech dataset class

In [None]:
import glob

class LibriSpeechDataset():
    """Map-style dataset over one LibriSpeech split stored under ``datasets/LibriSpeech/``.

    On construction every transcript line of the split is tokenized and three
    sidecar files are written next to each ``<utt>.flac``:

    * ``<utt>.<vocab_type>_<vocab_size>``: tokenized transcript (``LongTensor``)
    * ``<utt>.<vocab_type>_<vocab_size>_len``: number of tokens
    * ``<utt>.flac_len``: number of audio samples

    Args:
        split: name of the split folder, e.g. ``"dev-clean"``.
        tokenizer: object whose ``encode(str)`` returns a list of token ids.
    """

    def __init__(self, split, tokenizer):

        print("Creating dataset..")
        # Sorted so that item indices are stable across runs (glob order is arbitrary).
        self.data_names = sorted(glob.glob("datasets/LibriSpeech/" + split + "*/*/*/*.flac"))
        self.vocab_type = "bpe"
        # Only used as a suffix of the sidecar file names.
        # NOTE(review): the tokenizer loaded in this notebook is named `bpe_256` — confirm 1000 is intended.
        self.vocab_size = 1000

        label_suffix = "." + self.vocab_type + "_" + str(self.vocab_size)

        # Collect (sentence, path-without-extension) for every transcript line.
        utterances = []
        for file_path in glob.glob("datasets/LibriSpeech/" + split + "/*/*/*.txt"):
            folder = os.path.dirname(file_path)
            # `with` closes the transcript file deterministically.
            with open(file_path, "r") as transcript_file:
                for line in transcript_file:
                    parsed = self._parse_transcript_line(line)
                    if parsed is None:
                        continue  # blank line
                    utt_id, sentence = parsed
                    utterances.append((sentence, os.path.join(folder, utt_id)))

        for sentence, stem in tqdm(utterances):
            label_path = stem + label_suffix

            # Tokenize and save label
            label = torch.LongTensor(tokenizer.encode(sentence))
            torch.save(label, label_path)

            # Save audio length (number of samples along dim 1 of the loaded waveform)
            audio_length = torchaudio.load(stem + ".flac")[0].size(1)
            torch.save(audio_length, stem + ".flac_len")

            # Save label length
            torch.save(label.size(0), label_path + "_len")

        print("Done.")

    @staticmethod
    def _parse_transcript_line(line):
        """Split a ``"<utt-id> <TEXT>"`` transcript line.

        Returns:
            ``(utt_id, lower-cased text)``, or ``None`` for a blank line. The text is
            taken up to the end of the line whether or not a trailing newline is present.
        """
        parts = line.split(maxsplit=1)
        if not parts:
            return None
        text = parts[1].strip() if len(parts) > 1 else ""
        return parts[0], text.lower()

    def __getitem__(self, i):
        """Return ``[waveform, label, audio_length, label_length]`` for item ``i``."""
        audio_path = self.data_names[i]
        label_path = audio_path.split(".flac")[0] + "." + self.vocab_type + "_" + str(self.vocab_size)
        return [torchaudio.load(audio_path)[0],
                torch.load(label_path),
                torch.load(audio_path + "_len"),
                torch.load(label_path + "_len")]

    def __len__(self):
        """Number of ``.flac`` files found for the split."""
        return len(self.data_names)

Training the tokenizer (uncomment to create a new tokenizer)

In [None]:
# corpus_path = "datasets/LibriSpeech/train-clean-100_corpus.txt"

# # Create Corpus File
# if not os.path.isfile(corpus_path):
#     print("Create Corpus File")
#     corpus_file = open(corpus_path, "w")
#     for file_path in glob.glob("datasets/LibriSpeech/*/*/*/*.txt"):
#         for line in open(file_path, "r").readlines():
#             corpus_file.write(line[len(line.split()[0]) + 1:-1].lower() + "\n")

# # Train Tokenizer
# print("Training Tokenizer")
# sentencepiece.SentencePieceTrainer.train(input=corpus_path, 
#                                          model_prefix="LibriSpeech_bpe_256", 
#                                          vocab_size=256, 
#                                          character_coverage=1.0, 
#                                          model_type="bpe", 
#                                          bos_id=-1, 
#                                          eos_id=-1, 
#                                          unk_surface="")
# print("Training Done")

Build dataset

In [None]:
# Load the SentencePiece model stored in the project folder on Drive.
tokenizer_path = PATH + "LibriSpeech_bpe_256.model"
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)

# Building a dataset tokenizes every transcript of the split and writes the
# label/length sidecar files next to the audio (see LibriSpeechDataset.__init__).
# Only the dev-clean split is built here; the training split is commented out.
# train_dataset = LibriSpeechDataset("train-clean-100", tokenizer)
dev_dataset = LibriSpeechDataset("dev-clean", tokenizer)

# train_dataset.__getitem__(0)

Build dataloader

In [None]:
def collate_fn(batch):
    """Pad a list of dataset items into batch tensors.

    Args:
        batch: list of items whose element 0 is a 2-D waveform (its dim 1 is used
            as the length) and whose element 1 is a 1-D label tensor. Any further
            elements are ignored.

    Returns:
        ``(data, target, data_lengths, target_lengths)``: zero-padded waveforms and
        labels (batch first) plus their unpadded lengths as ``long`` tensors. Items
        are ordered by decreasing waveform length. Tensors are placed on the GPU
        when one is available, otherwise they stay on the CPU.
    """
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sorting sequences by lengths
    sorted_batch = sorted(batch, key=lambda x: x[0].shape[1], reverse=True)

    # Pad data sequences. squeeze(0) only drops the leading dim, so a one-sample
    # clip stays 1-D (a bare squeeze() would collapse it to a 0-d tensor).
    data = [item[0].squeeze(0) for item in sorted_batch]
    data_lengths = torch.tensor([d.size(0) for d in data], dtype=torch.long, device=device)
    data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=0).to(device)

    # Pad labels
    target = [item[1] for item in sorted_batch]
    target_lengths = torch.tensor([t.size(0) for t in target], dtype=torch.long, device=device)
    target = torch.nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=0).to(device)

    return data, target, data_lengths, target_lengths

# Training loader (disabled together with train_dataset above).
# train_dataloader = torch.utils.data.DataLoader(train_dataset,
#                                                batch_size=4,
#                                                shuffle=True,
#                                                collate_fn=collate_fn,
#                                                drop_last=True)

# Validation loader: fixed order, batches of 4 padded by collate_fn;
# drop_last=True discards a final batch with fewer than 4 items.
dev_dataloader = torch.utils.data.DataLoader(dev_dataset,
                                             batch_size=4,
                                             shuffle=False,
                                             collate_fn=collate_fn,
                                             drop_last=True)

# Trainer

In [None]:
class Trainer():
    """Training / evaluation loop for a transducer model with mixed precision
    and gradient accumulation.

    Args:
        model: network called as ``model.forward(inputs, input_len, targets, target_len, train=...)``
            and returning ``(pred, pred_len)``.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per optimizer update.
        scaler: gradient scaler exposing ``scale`` / ``step`` / ``update``.
        path: file prefix for checkpoints (``{path}.pth``, ``{path}_scheduler.pth``,
            ``{path}_optimizer.pth``) and for the log (``{path}.tsv``).
        save_best_model: when true, ``save_model`` is called after every epoch.
        logging: when true, ``save_model`` appends a ``lr, wer, loss`` row to the log.
        accumulated_steps: number of mini-batches accumulated per optimizer update.
    """

    def __init__(self,
                 model: nn.Module,
                 optimizer,
                 scheduler,
                 scaler,
                 path,
                 save_best_model = True,
                 logging = True,
                 accumulated_steps = 16):

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scaler = scaler
        self.accumulated_steps = accumulated_steps

        # Lowest WER seen so far; checkpoints are written only when it improves.
        self.best_wer = 1
        self.path = path
        self.save_best_model = save_best_model
        self.logging = logging

    def train(self,
              train_dataset,
              dev_dataset,
              epochs:int=20):
        """Train for ``epochs`` epochs, evaluating on ``dev_dataset`` after each one.

        Args:
            train_dataset: sized iterable of ``(inputs, targets, input_len, target_len)`` batches.
            dev_dataset: sized iterable of batches passed to :meth:`evaluate`.
            epochs: number of passes over ``train_dataset``.

        Returns:
            Mean training loss of the last epoch.
        """
        print("Training...")

        self.optimizer.zero_grad()

        for epoch in range(epochs):
            print(" Epoch {:03d}".format(epoch + 1))

            epoch_loss = 0.0

            # train mode on
            self.model.train()

            for step, batch in enumerate(tqdm(train_dataset)):
                inputs, targets, input_len, target_len = batch

                # Automatic Mixed Precision Casting (model prediction + loss computing)
                with torch.cuda.amp.autocast():
                    pred, pred_len = self.model.forward(inputs, input_len, targets, target_len, train=True)
                    loss_mini  = warp_rnnt.rnnt_loss(
                                    log_probs=torch.nn.functional.log_softmax(pred, dim=-1),
                                    labels=targets.int(),
                                    frames_lengths=pred_len.int(),
                                    labels_lengths=target_len.int(),
                                    average_frames=False,
                                    reduction='mean',
                                    blank=0,
                                    gather=True)
                    # Scale so the accumulated gradient is the mean over the accumulation window.
                    loss = loss_mini / self.accumulated_steps

                # Accumulate gradients (use the scaler owned by this trainer, not a global).
                self.scaler.scale(loss).backward()

                # Update Epoch Variables
                epoch_loss += loss_mini.detach()

                # Step only after a full window of `accumulated_steps` mini-batches;
                # `step % n == 0` would already fire on the very first mini-batch.
                if (step + 1) % self.accumulated_steps == 0:
                    # Update Parameters, Zero Gradients and Update Learning Rate
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()
                    self.scheduler.step()

                    # Step Print
                    print("mean loss {:.4f} - batch loss: {:.4f} - learning rate: {:.6f}".format(epoch_loss / (step + 1), loss_mini, self.optimizer.param_groups[0]['lr']))

            avg_epoch_loss = epoch_loss / len(train_dataset)
            print('\t[E: {:2d}] train loss = {:0.4f}'.format(epoch, avg_epoch_loss))

            wer, speech_true, speech_pred, val_loss = self.evaluate(dev_dataset)
            print('\t[E: {:2d}] valid loss = {:0.4f}, wer = {:0.4f}, loss = {:0.4f}'.format(epoch, val_loss, wer, avg_epoch_loss))

            if self.save_best_model:
                self.save_model(wer, avg_epoch_loss, self.optimizer.param_groups[0]['lr'], self.path)

        print("...Done!")
        return avg_epoch_loss

    def evaluate(self,
              dev_dataset):
        """Decode every batch of ``dev_dataset`` and compute WER and loss.

        Decoding goes through the module-level ``greedy_search_decoding`` function and
        reference texts through the module-level ``tokenizer``.

        Args:
            dev_dataset: sized iterable of ``(inputs, targets, input_len, target_len)`` batches.

        Returns:
            ``(wer, speech_true, speech_pred, loss)``: the WER over all decoded utterances
            (reported as 1 when the mean per-batch WER exceeds 1), the reference and
            predicted transcripts, and the loss averaged over the number of batches.
        """
        self.model.eval()

        speech_true = []
        speech_pred = []
        total_wer = 0.0
        total_loss = 0.0

        # Evaluation Loop
        for step, batch in enumerate(tqdm(dev_dataset)):

            inputs, targets, input_len, target_len = batch

            # Sequence Prediction
            with torch.no_grad():

                outputs_pred = greedy_search_decoding(inputs, input_len)

            # Sequence Truth
            outputs_true = tokenizer.decode(targets.tolist())

            # Compute Batch wer and Update total wer
            batch_wer = jiwer.wer(outputs_true, outputs_pred, standardize=True)
            total_wer += batch_wer

            # Update String lists
            speech_true += outputs_true
            speech_pred += outputs_pred

            # Prediction Verbose
            print("Groundtruths :\n", outputs_true)
            print("Predictions :\n", outputs_pred)

            # Eval Loss
            with torch.no_grad():
                pred, pred_len = self.model.forward(inputs, input_len, targets, target_len, train=False)
                batch_loss = warp_rnnt.rnnt_loss(
                                    log_probs=torch.nn.functional.log_softmax(pred, dim=-1),
                                    labels=targets.int(),
                                    frames_lengths=pred_len.int(),
                                    labels_lengths=target_len.int(),
                                    average_frames=False,
                                    reduction='mean',
                                    blank=0,
                                    gather=True)
                total_loss += batch_loss

            # Step print
            print("mean batch wer {:.2f}% - batch wer: {:.2f}% - mean loss {:.4f} - batch loss: {:.4f}".format(100 * total_wer / (step + 1), 100 * batch_wer, total_loss / (step + 1), batch_loss))

        num_batches = len(dev_dataset)

        # Compute wer
        if total_wer / num_batches > 1:
            wer = 1
        else:
            wer = jiwer.wer(speech_true, speech_pred, standardize=True)

        # Compute loss
        loss = total_loss / num_batches

        return wer, speech_true, speech_pred, loss

    def save_model(self, wer, loss, lr, path):
        """Checkpoint the model when ``wer`` improves and optionally log the epoch.

        Args:
            wer: validation word error rate of the epoch.
            loss: training loss of the epoch (logged only).
            lr: current learning rate (logged only).
            path: file prefix for the checkpoint and log files.
        """
        if wer < self.best_wer:
            # Remember the new best so later, worse epochs do not overwrite it.
            self.best_wer = wer
            torch.save(self.model.state_dict(), f"{path}.pth")
            torch.save(self.scheduler.state_dict(), f"{path}_scheduler.pth")
            torch.save(self.optimizer.state_dict(), f"{path}_optimizer.pth")

        if self.logging:
            with open(f"{path}.tsv", "a") as log:
                log.write(f"{lr}\t{wer}\t{loss}\n")
               

Greedy search decoding strategy

In [None]:
def greedy_search_decoding(x, x_len, max_consec_dec_steps=5):
    """Greedy transducer decoding, one utterance at a time.

    Uses the module-level ``model`` (its ``encoder``, ``decoder`` and ``joint``
    sub-networks) and the module-level ``tokenizer``.

    Args:
        x: batch of inputs passed to ``model.encoder``.
        x_len: lengths passed to ``model.encoder``.
        max_consec_dec_steps: maximum number of tokens emitted on one encoder frame
            before decoding is forced on to the next frame (default 5).

    Returns:
        List with one decoded string per batch element.
    """
    # Predictions String List
    preds = []

    # Forward Encoder (B, Taud) -> (B, T, Denc)
    f, f_len = model.encoder(x, x_len, train=False)

    # Batch loop
    for b in range(x.size(0)): # One sample at a time for now, not batch optimized

        # Init y with the blank id (0) as start symbol, and an empty hidden state
        y = x.new_zeros(1, 1, dtype=torch.long)
        hidden = None

        enc_step = 0
        consec_dec_step = 0

        # Decoder loop: re-run the decoder after every emitted token
        while enc_step < f_len[b]:

            # Forward Decoder (1, 1) -> (1, 1, Ddec)
            g, hidden = model.decoder(y[:, -1:], hidden_states=hidden)

            # Joint Network loop: advance over frames until a token is emitted
            while enc_step < f_len[b]:

                # Forward Joint Network (1, 1, Denc) and (1, 1, Ddec) -> (1, V)
                logits = model.joint(f[b:b+1, enc_step], g[:, 0])

                # Token Prediction. log-softmax is monotonic, so the argmax of the
                # raw logits selects the same token without the extra work.
                pred = logits.argmax(dim=-1) # (1)

                # Null token or max_consec_dec_steps reached: move to the next frame
                if pred == 0 or consec_dec_step == max_consec_dec_steps:
                    consec_dec_step = 0
                    enc_step += 1
                # Token: append it and go back to the decoder
                else:
                    consec_dec_step += 1
                    y = torch.cat([y, pred.unsqueeze(0)], dim=-1)
                    break

        # Decode Label Sequence (drop the start symbol)
        pred = tokenizer.decode(y[:, 1:].tolist())
        preds += pred

    return preds

# Model

### Activation function

In [None]:
class Swish(nn.Module):
    r"""
    Swish activation: each element is multiplied by its own sigmoid.
    The ContextNet paper reports that it works consistently better than ReLU.
    """
    def __init__(self):
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        # Element-wise gate in (0, 1) applied to the input itself.
        gate = inputs.sigmoid()
        return inputs * gate

### Squeeze-and-excitation

In [None]:
class SELayer(nn.Module):
    r"""
    Squeeze-and-excitation module.
    Args:
        dim (int): Dimension to be used for two fully connected (FC) layers
    Inputs: inputs, input_lengths
        - **inputs**: The output of the last convolution layer. `FloatTensor` of size
            ``(batch, dimension, seq_length)``
        - **input_lengths**: The length of input tensor. ``(batch)``
    Returns: output
        - **output**: Output of SELayer `FloatTensor` of size
            ``(batch, dimension, seq_length)``
    """
    def __init__(self, dim: int) -> None:
        super(SELayer, self).__init__()
        assert dim % 8 == 0, 'Dimension should be divisible by 8.'

        self.dim = dim
        # Bottleneck MLP: dim -> dim / 8 -> dim
        self.sequential = nn.Sequential(
            nn.Linear(dim, dim // 8),
            Swish(),
            nn.Linear(dim // 8, dim),
        )

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Tensor,
    ) -> Tensor:
        """
        Forward propagate a `inputs` for SE Layer.
        Args:
            **inputs** (torch.FloatTensor): The output of the last convolution layer. `FloatTensor` of size
                ``(batch, dimension, seq_length)``
            **input_lengths** (torch.LongTensor): The length of input tensor. ``(batch)``
        Returns:
            **output** (torch.FloatTensor): Output of SELayer `FloatTensor` of size
                ``(batch, dimension, seq_length)``
        """
        seq_length = inputs.size(2)
        lengths = input_lengths.to(inputs.device)

        # Squeeze: average over the valid frames only. Frames at or beyond each
        # sequence's length are padding and must not leak into the channel
        # statistics, otherwise the result depends on how the batch was padded.
        is_padding = torch.arange(seq_length, device=inputs.device).unsqueeze(0) >= lengths.unsqueeze(1)
        pooled = inputs.masked_fill(is_padding.unsqueeze(1), 0).sum(dim=2) / lengths.unsqueeze(1)

        # Excitation: per-channel gate in (0, 1), broadcast over the time axis.
        gate = self.sequential(pooled).sigmoid().unsqueeze(2)

        return gate * inputs

### Convolutional layer

In [None]:
class ConvLayer(nn.Module):
    """
    1-D convolution followed by batch normalization and an optional Swish activation.
    When the stride is 1, the input is padded so the output has the same length as the input.
    For any larger stride the given ``padding`` is used as is (0 by default, i.e. no padding).
    Args:
        in_channels (int): Input channel in convolutional layer
        out_channels (int): Output channel in convolutional layer
        kernel_size (int, optional): Value of convolution kernel size (default : 5)
        stride(int, optional): Value of stride (default : 1)
        padding (int, optional): Value of padding, ignored when stride is 1 (default: 0)
        activation (bool, optional): Flag indication use activation function or not (default : True)
        groups(int, optional): Value of groups (default : 1)
        bias (bool, optional): Flag indication use bias or not (default : True)
    Inputs: inputs, input_lengths
        - **inputs**: Input of convolution layer `FloatTensor` of size ``(batch, dimension, seq_length)``
        - **input_lengths**: The length of input tensor. ``(batch)``
    Returns: output, output_lengths
        - **output**: Output of convolution layer `FloatTensor` of size
                ``(batch, dimension, seq_length)``
        - **output_lengths**: The length of output tensor. ``(batch)``
    Raises:
        ValueError: if ``stride`` is smaller than 1.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 5,
            stride: int = 1,
            padding: int = 0,
            activation: bool = True,
            groups: int = 1,
            bias: bool = True,
    ):
        super(ConvLayer, self).__init__()
        assert kernel_size == 5, "The convolution layer in the ContextNet model has 5 kernels."
        if stride < 1:
            raise ValueError(f"stride must be a positive integer, got {stride}")

        # A single construction covers every stride; previously a stride other than
        # 1 or 2 silently left `self.conv` undefined and only failed in forward().
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=1,
            # "same" padding for stride 1, caller-provided padding otherwise
            padding=(kernel_size - 1) // 2 if stride == 1 else padding,
            groups=groups,
            bias=bias,
        )

        self.batch_norm = nn.BatchNorm1d(num_features=out_channels)
        self.activation = activation

        if self.activation:
            self.swish = Swish()

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate a `inputs` for convolution layer.
        Args:
            **inputs** (torch.FloatTensor): Input of convolution layer `FloatTensor` of size
                ``(batch, dimension, seq_length)``
            **input_lengths** (torch.LongTensor): The length of input tensor. ``(batch)``
        Returns:
            **output** (torch.FloatTensor): Output of convolution layer `FloatTensor` of size
                ``(batch, dimension, seq_length)``
            **output_lengths** (torch.LongTensor): The length of output tensor. ``(batch)``
        """
        outputs, output_lengths = self.conv(inputs), self._get_sequence_lengths(input_lengths)
        outputs = self.batch_norm(outputs)

        if self.activation:
            outputs = self.swish(outputs)

        return outputs, output_lengths

    def _get_sequence_lengths(self, seq_lengths):
        """Lengths after the convolution:
        ``floor((L + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1``."""
        a = seq_lengths + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1
        b = self.conv.stride[0]
        return torch.div(a, b, rounding_mode='floor') + 1

### Preprocessing

In [None]:
class AudioPreprocessing(nn.Module):

    """Audio Preprocessing
    Turns raw waveforms into log mel filter bank features.
    Args:
        sample_rate: Audio sample rate
        n_fft: FFT frame size, creates n_fft // 2 + 1 frequency bins.
        win_length_ms: FFT window length in ms, must be <= n_fft
        hop_length_ms: length of hop between FFT windows in ms
        n_mels: number of mel filter banks
        normalize: whether to normalize mel spectrograms outputs
        mean: training mean
        std: training std
    Shape:
        Input: (batch_size, audio_len)
        Output: (batch_size, audio_len // hop_length + 1, n_mels), i.e. time-major
            features, because the last two axes are swapped before returning.

    """

    def __init__(self,
                 sample_rate,
                 n_fft,
                 win_length_ms,
                 hop_length_ms,
                 n_mels,
                 normalize,
                 mean,
                 std):
        super().__init__()

        # Window and hop sizes converted from milliseconds to samples
        self.win_length = int(sample_rate * win_length_ms) // 1000
        self.hop_length = int(sample_rate * hop_length_ms) // 1000

        self.Spectrogram = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
        )
        self.MelScale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=0,
            f_max=8000,
            n_stft=n_fft // 2 + 1,
        )

        self.normalize = normalize
        self.mean = mean
        self.std = std

    def forward(self, x, x_len):

        # STFT then mel projection:
        # (B, T) -> (B, n_fft // 2 + 1, T // hop_length + 1) -> (B, n_mels, T // hop_length + 1)
        mel = self.MelScale(self.Spectrogram(x))

        # Log energy, evaluated in float32 and cast back to the incoming dtype
        log_mel = torch.log(mel.float() + 1e-9).type(mel.dtype)

        # Number of frames per sequence (lengths are optional)
        frames = x_len
        if frames is not None:
            frames = torch.div(frames, self.hop_length, rounding_mode='floor') + 1

        # Normalize with the provided statistics
        if self.normalize:
            log_mel = (log_mel - self.mean) / self.std

        # Swap the feature and time axes before returning
        return log_mel.transpose(1, 2), frames

class SpecAugment(nn.Module):

    """Spectrogram Augmentation
    Applies random frequency masks to the whole batch and random time masks to
    each sequence, restricted to that sequence's first ``x_len[b]`` frames.
    Args:
        spec_augment: whether to apply spec augment
        mF: number of frequency masks
        F: maximum frequency mask size
        mT: number of time masks
        pS: adaptive maximum time mask size in %
    References:
        SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition, Park et al.
        https://arxiv.org/abs/1904.08779
        SpecAugment on Large Scale Datasets, Park et al.
        https://arxiv.org/abs/1912.05533
    """

    def __init__(self, spec_augment, mF, F, mT, pS):
        super().__init__()
        self.spec_augment = spec_augment
        self.mF, self.F = mF, F
        self.mT, self.pS = mT, pS

    def forward(self, x, x_len):

        # Augmentation switched off: hand the input back untouched
        if not self.spec_augment:
            return x

        # Frequency masking, mF masks over the whole batch
        for _ in range(self.mF):
            x = torchaudio.transforms.FrequencyMasking(freq_mask_param=self.F, iid_masks=False).forward(x)

        # Time masking, per sequence; the maximum mask width is a fraction pS of
        # the sequence length and only the first x_len[b] frames are touched
        for b in range(x.size(0)):
            valid = x_len[b]
            max_width = int(self.pS * valid)
            for _ in range(self.mT):
                masked = torchaudio.transforms.TimeMasking(time_mask_param=max_width).forward(x[b, :, :valid])
                x[b, :, :valid] = masked

        return x

### Convolution Block

In [None]:
class ConvBlock(nn.Module):
    """
    A stack of convolution layers (each with batch normalization and activation),
    followed by a squeeze-and-excitation (SE) layer applied to the output of the last
    convolution. When enabled, a projected skip connection computed from the block
    input is added to the SE output before the final activation.
    Args:
        in_channels (int): Input channel in convolutional layer
        out_channels (int): Output channel in convolutional layer
        num_layers (int, optional): The number of convolutional layers (default : 5)
        kernel_size (int, optional): Value of convolution kernel size (default : 5)
        stride(int, optional): Value of stride (default : 1)
        padding (int, optional): Value of padding (default: 0)
        residual (bool, optional): Flag indication residual or not (default : True)
    Inputs: inputs, input_lengths
        - **inputs**: Input of convolution block `FloatTensor` of size ``(batch, dimension, seq_length)``
        - **input_lengths**: The length of input tensor. ``(batch)``
    Returns: output, output_lengths
        - **output**: Output of convolution block `FloatTensor` of size
                ``(batch, dimension, seq_length)``
        - **output_lengths**: The length of output tensor. ``(batch)``
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            num_layers: int = 5,
            kernel_size: int = 5,
            stride: int = 1,
            padding: int = 0,
            residual: bool = True,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.swish = Swish()
        self.se_layer = SELayer(out_channels)

        # Optional skip branch: one convolution without activation, using the block's stride
        self.residual = (
            ConvLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                activation=False,
            )
            if residual
            else None
        )

        if self.num_layers == 1:
            # A single layer is stored directly rather than inside a ModuleList
            self.conv_layers = ConvLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
        else:
            # Only the last layer uses the block's stride; only the first sees in_channels
            layer_strides = [1] * (num_layers - 1) + [stride]
            layer_inputs = [in_channels] + [out_channels] * (num_layers - 1)
            self.conv_layers = nn.ModuleList([
                ConvLayer(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=layer_stride,
                    padding=padding,
                )
                for layer_in, layer_stride in zip(layer_inputs, layer_strides)
            ])

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate a `inputs` for convolution block.
        Args:
            **inputs** (torch.FloatTensor): Input of convolution block `FloatTensor` of size
                ``(batch, dimension, seq_length)``
            **input_lengths** (torch.LongTensor): The length of input tensor. ``(batch)``
        Returns:
            **output** (torch.FloatTensor): Output of convolution block `FloatTensor` of size
                ``(batch, dimension, seq_length)``
            **output_lengths** (torch.LongTensor): The length of output tensor. ``(batch)``
        """
        if self.num_layers == 1:
            output, output_lengths = self.conv_layers(inputs, input_lengths)
        else:
            output, output_lengths = inputs, input_lengths
            for layer in self.conv_layers:
                output, output_lengths = layer(output, output_lengths)

        output = self.se_layer(output, output_lengths)

        if self.residual is not None:
            # The skip branch is computed from the block's original input
            shortcut, _ = self.residual(inputs, input_lengths)
            output += shortcut

        return self.swish(output), output_lengths

    @staticmethod
    def make_conv_blocks(
            input_dim: int = 80,
            num_layers: int = 5,
            kernel_size: int = 5,
            num_channels: int = 256,
            output_dim: int = 640,
    ) -> nn.ModuleList:
        r"""
        Create 23 convolution blocks.
        Args:
            input_dim (int, optional): Dimension of input vector (default : 80)
            num_layers (int, optional): The number of convolutional layers (default : 5)
            kernel_size (int, optional): Value of convolution kernel size (default : 5)
            num_channels (int, optional): The number of channels in the convolution filter (default: 256)
            output_dim (int, optional): Dimension of encoder output vector (default: 640)
        Returns:
            **conv_blocks** (nn.ModuleList): ModuleList with 23 convolution blocks
        """
        doubled = num_channels << 1

        # One row per group of identical consecutive blocks:
        # (in_channels, out_channels, layers, stride, residual, how many blocks)
        layout = [
            (input_dim,    num_channels, 1,          1, False, 1),  # C0
            (num_channels, num_channels, num_layers, 1, True,  2),  # C1-2
            (num_channels, num_channels, num_layers, 2, True,  1),  # C3
            (num_channels, num_channels, num_layers, 1, True,  3),  # C4-6
            (num_channels, num_channels, num_layers, 2, True,  1),  # C7
            (num_channels, num_channels, num_layers, 1, True,  3),  # C8-10
            (num_channels, doubled,      num_layers, 1, True,  1),  # C11
            (doubled,      doubled,      num_layers, 1, True,  2),  # C12-13
            (doubled,      doubled,      num_layers, 2, True,  1),  # C14
            (doubled,      doubled,      num_layers, 1, True,  7),  # C15-21
            (doubled,      output_dim,   1,          1, False, 1),  # C22
        ]

        conv_blocks = nn.ModuleList()
        for block_in, block_out, depth, block_stride, use_residual, count in layout:
            for _ in range(count):
                conv_blocks.append(ConvBlock(block_in,
                                             block_out,
                                             depth,
                                             kernel_size,
                                             block_stride,
                                             0,
                                             use_residual))

        return conv_blocks

## ContextNet

### Audio Encoder

In [None]:
class AudioEncoder(nn.Module):
    r"""
    Convolutional audio encoder of ContextNet.

    The incoming batch is run through ``AudioPreprocessing``, optionally through
    ``SpecAugment``, and finally through the stack of blocks returned by
    ``ConvBlock.make_conv_blocks``.

    Args:
        preprocessing_params (dict): Settings for the front-end. The keys ``sample_rate``,
            ``n_fft``, ``win_length_ms``, ``hop_length_ms``, ``n_mels``, ``normalize``, ``mean``
            and ``std`` are passed (in that order) to ``AudioPreprocessing``; the keys
            ``spec_augment``, ``mF``, ``F``, ``mT`` and ``pS`` are passed to ``SpecAugment``
        input_dim (int, optional): Dimension of input vector (default : 80)
        num_layers (int, optional): The number of convolution layers (default : 5)
        kernel_size (int, optional): Value of convolution kernel size (default : 5)
        num_channels (int, optional): The number of channels in the convolution filter (default: 256)
        output_dim (int, optional): Dimension of encoder output vector (default: 640)
    Inputs: inputs, input_lengths, train
        - **inputs**: Batch handed to the preprocessing module
        - **input_lengths**: Tensor with the length of every batch element ``(batch)``
        - **train**: When true, the augmentation module is applied after preprocessing
    Returns: output, output_lengths
        - **output**: Output of the last block with dimensions 1 and 2 swapped back
        - **output_lengths**: Lengths as reported by the last block ``(batch)``
    """
    # Order in which the dictionary entries are handed to the two front-end modules.
    _PREPROCESSING_KEYS = ("sample_rate", "n_fft", "win_length_ms", "hop_length_ms",
                           "n_mels", "normalize", "mean", "std")
    _AUGMENT_KEYS = ("spec_augment", "mF", "F", "mT", "pS")

    def __init__(
            self,
            preprocessing_params,
            input_dim: int = 80,
            num_layers: int = 5,
            kernel_size: int = 5,
            num_channels: int = 256,
            output_dim: int = 640,
    ) -> None:
        super(AudioEncoder, self).__init__()

        front_end_args = [preprocessing_params[key] for key in self._PREPROCESSING_KEYS]
        self.preprocessing = AudioPreprocessing(*front_end_args)

        augment_args = [preprocessing_params[key] for key in self._AUGMENT_KEYS]
        self.augment = SpecAugment(*augment_args)

        self.blocks = ConvBlock.make_conv_blocks(
            input_dim, num_layers, kernel_size, num_channels, output_dim
        )

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Tensor,
            train=True
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Forward propagate `inputs` through the audio encoder.
        Args:
            **inputs** (torch.FloatTensor): Batch handed to the preprocessing module
            **input_lengths** (torch.LongTensor): Length of every batch element, size ``(batch)``
            **train** (bool): Apply the augmentation module when true (default: True)
        Returns:
            **output** (torch.FloatTensor): Output of the last block, dimensions 1 and 2 swapped back
            **output_lengths** (torch.LongTensor): Lengths reported by the last block, size ``(batch)``
        """
        features, feature_lengths = self.preprocessing(inputs, input_lengths)
        if train:
            features = self.augment(features, feature_lengths)

        # The blocks receive the features with dimensions 1 and 2 swapped; the swap is
        # undone on the way out.
        hidden, hidden_lengths = features.transpose(1, 2), feature_lengths
        for conv_block in self.blocks:
            hidden, hidden_lengths = conv_block(hidden, hidden_lengths)

        return hidden.transpose(1, 2), hidden_lengths

### Label Encoder

In [None]:
class LabelEncoder(nn.Module):
    r"""
    Label encoder (prediction network): embedding -> RNN -> linear projection.
    Args:
        num_vocabs (int, optional): The number of vocabulary (default: 256)
        output_dim (int, optional): Dimension of decoder output vector (default: 640)
        hidden_dim (int, optional): The number of features in the decoder hidden state (default : 2048)
        num_layers (int, optional): The number of rnn layers (default : 1)
        dropout (float, optional): Dropout probability of decoder (default: 0.3)
        rnn_type (str, optional): Type of RNN cell, one of ``supported_rnns`` (default: lstm)
        sos_id (int, optional): Index of the start of sentence (default: 1); stored on the
            instance but not read by this class
    Inputs: inputs, input_lengths, hidden_states
        - **inputs**: Tensor representing the target `LongTensor` of size ``(batch, seq_length)``
        - **input_lengths**: Tensor representing the target length `LongTensor` of size ``(batch)``
        - **hidden_states**: Previous state of the RNN, or ``None`` to start from zeros
    Returns: outputs, hidden_states
        - **outputs**: A output sequence of decoder `FloatTensor` of size
                ``(batch, seq_length, output_dim)``
        - **hidden_states**: State of the RNN after consuming `inputs`, as returned by the
                underlying ``torch.nn`` RNN (a ``(h_n, c_n)`` pair for lstm)
    """
    supported_rnns = {
        'rnn': nn.RNN,
        'lstm': nn.LSTM,
        'gru': nn.GRU
    }

    def __init__(
            self,
            num_vocabs: int = 256,
            output_dim: int = 640,
            hidden_dim: int = 2048,
            num_layers: int = 1,
            dropout: float = 0.3,
            rnn_type: str = 'lstm',
            sos_id: int = 1,
    ) -> None:
        super(LabelEncoder, self).__init__()
        self.sos_id = sos_id
        self.embedding = nn.Embedding(num_vocabs, hidden_dim)
        # Keyword arguments keep the meaning of every option explicit and identical for all
        # three cell types (nn.RNN documents `nonlinearity` as its fourth parameter, so a
        # purely positional call is not portable across the supported classes).
        self.rnn = self.supported_rnns[rnn_type](
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Optional[Tensor] = None,
            hidden_states: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Forward propagate a `inputs` for label encoder.
        Args:
            **inputs** (torch.LongTensor): Tensor representing the target `LongTensor` of size
                ``(batch, seq_length)``
            **input_lengths** (torch.LongTensor, optional): Tensor representing the target length
                `LongTensor` of size ``(batch)``. When given, the batch is packed so padded
                positions are not fed to the RNN
            **hidden_states** (optional): Previous state of the RNN; ``None`` starts from zeros
        Returns:
            **outputs** (torch.FloatTensor): A output sequence of decoder `FloatTensor` of size
                ``(batch, seq_length, output_dim)``
            **hidden_states**: State of the RNN after the last valid step of every sequence
        """
        embedded = self.embedding(inputs)

        if input_lengths is not None:
            # pack_padded_sequence needs the lengths on the CPU.
            packed = torch.nn.utils.rnn.pack_padded_sequence(
                embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            # Rebind `hidden_states` so the *updated* state is returned (previously the
            # caller got its own argument back in this branch).
            packed_output, hidden_states = self.rnn(packed, hidden_states)
            rnn_output, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        else:
            rnn_output, hidden_states = self.rnn(embedded, hidden_states)

        output = self.fc(rnn_output)

        return output, hidden_states

### ContextNet

In [None]:
class ContextNet(nn.Module):
    """
    ContextNet has CNN-RNN-transducer architecture and features a fully convolutional encoder that incorporates
    global context information into convolution layers by adding squeeze-and-excitation modules.
    Also, ContextNet supports three size models: small, medium, and large.
    ContextNet uses the global parameter alpha to control the scaling of the model
    by changing the number of channels in the convolution filter.
    Args:
        preprocessing_params (dict): Front-end settings, forwarded unchanged to ``AudioEncoder``
        num_vocabs (int, optional): The number of vocabulary (default : 1000)
        model_size (str, optional): Size of the model['small', 'medium', 'large'] (default : 'small')
        input_dim (int, optional): Dimension of input vector (default : 80)
        encoder_num_layers (int, optional): The number of convolutional layers (default : 5)
        decoder_num_layers (int, optional): The number of rnn layers (default : 1)
        kernel_size (int, optional): Value of convolution kernel size (default : 5)
        num_channels (int, optional): The number of channels in the convolution filter, before
            scaling by alpha (default: 256)
        hidden_dim (int, optional): The number of features in the decoder hidden state (default : 2048)
        encoder_output_dim (int, optional): Dimension of encoder output vector, before scaling
            by alpha (default: 640)
        decoder_output_dim (int, optional): Dimension of decoder output vector (default: 640)
        dropout (float, optional): Dropout probability of decoder (default: 0.3)
        rnn_type (str, optional): Type of RNN cell (default: lstm)
        sos_id (int, optional): Index of the start of sentence (default: 1)
    Inputs: inputs, input_lengths, targets, target_lengths, train
        - **inputs** (torch.FloatTensor): Audio batch handed to the encoder
        - **input_lengths** (torch.LongTensor): Tensor representing the sequence length of the input `LongTensor` of size
            ``(batch)``
        - **targets** (torch.LongTensor): Tensor representing the target `LongTensor` of size
            ``(batch, seq_length)``
        - **target_lengths** (torch.LongTensor): Tensor representing the target length `LongTensor` of size
            ``(batch)``
        - **train** (bool): Forwarded to the encoder (default: True)
    Returns: output, encoder_output_lengths
        - **output** (torch.FloatTensor): Output of the joint network
        - **encoder_output_lengths** (torch.LongTensor): Lengths reported by the encoder
    """
    # Channel multiplier (alpha) for every supported model size.
    supported_models = {
        'small': 0.5,
        'medium': 1,
        'large': 2,
    }

    def __init__(
            self,
            preprocessing_params,
            num_vocabs: int = 1000,
            model_size: str = 'small',
            input_dim: int = 80,
            encoder_num_layers: int = 5,
            decoder_num_layers: int = 1,
            kernel_size: int = 5,
            num_channels: int = 256,
            hidden_dim: int = 2048,
            encoder_output_dim: int = 640,
            decoder_output_dim: int = 640,
            dropout: float = 0.3,
            rnn_type: str = 'lstm',
            sos_id: int = 1
        ) -> None:
        super(ContextNet, self).__init__()
        # Validate against the table itself so the two can never drift apart.
        assert model_size in self.supported_models, f'{model_size} is not supported.'

        alpha = self.supported_models[model_size]

        # Only the encoder is scaled by alpha; the decoder dimensions are used as given.
        num_channels = int(num_channels * alpha)
        encoder_output_dim = int(encoder_output_dim * alpha)

        self.encoder = AudioEncoder(
            input_dim=input_dim,
            num_layers=encoder_num_layers,
            kernel_size=kernel_size,
            num_channels=num_channels,
            output_dim=encoder_output_dim,
            preprocessing_params=preprocessing_params,
        )
        self.decoder = LabelEncoder(
            num_vocabs=num_vocabs,
            output_dim=decoder_output_dim,
            hidden_dim=hidden_dim,
            num_layers=decoder_num_layers,
            dropout=dropout,
            rnn_type=rnn_type,
            sos_id=sos_id,
        )
        # The joint network consumes the concatenation of encoder and decoder outputs.
        self.joint = JointNet(num_vocabs, encoder_output_dim + decoder_output_dim)

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Tensor,
            targets: Tensor,
            target_lengths: Tensor,
            train: bool = True
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Forward propagate a batch through encoder, label encoder and joint network.
        Args:
            **inputs** (torch.FloatTensor): Audio batch handed to the encoder
            **input_lengths** (torch.LongTensor): Tensor representing the sequence length of the input
                `LongTensor` of size ``(batch)``
            **targets** (torch.LongTensor): Tensor representing the target `LongTensor` of size
                ``(batch, seq_length)``
            **target_lengths** (torch.LongTensor): Tensor representing the target length `LongTensor` of size
                ``(batch)``
            **train** (bool): Forwarded to the encoder (default: True)
        Returns:
            **output** (torch.FloatTensor): Output of the joint network
            **encoder_output_lengths** (torch.LongTensor): Lengths reported by the encoder
        """
        encoder_output, encoder_output_lengths = self.encoder(inputs, input_lengths, train=train)

        self.decoder.rnn.flatten_parameters()
        # Prepend one token (id 0) to every target sequence and account for it in the lengths.
        # NOTE(review): the prepended id is the literal 0, not `sos_id` - confirm this is intended.
        targets = torch.nn.functional.pad(targets, pad=(1, 0, 0, 0), value=0)
        target_lengths = target_lengths + 1
        decoder_output, _ = self.decoder(targets, target_lengths)

        output = self.joint(encoder_output, decoder_output)

        return output, encoder_output_lengths

class JointNet(nn.Module):
    """
    Joint network: combines `encoder_output` and `decoder_output` into log-probabilities.

    Both inputs are concatenated along the last dimension, projected to `num_vocabs`
    with one linear layer and normalised with a log-softmax.
    Args:
        num_vocabs (int): The number of vocabulary
        output_dim (int): Encoder output dimension plus Decoder output dimension
    Inputs: encoder_output, decoder_output
        - **encoder_output** (torch.FloatTensor): A output sequence of encoder
        - **decoder_output** (torch.FloatTensor): A output sequence of decoder, with the same
            number of dimensions as `encoder_output`
    Returns: output
        - **output** (torch.FloatTensor): Log-softmax over the vocabulary. For 3-D inputs
            ``(batch, T, enc_dim)`` / ``(batch, U, dec_dim)`` the result has size
            ``(batch, T, U, num_vocabs)``
    """
    def __init__(
            self,
            num_vocabs: int,
            output_dim: int,
    ) -> None:
        super(JointNet, self).__init__()
        self.fc = nn.Linear(output_dim, num_vocabs)

    def forward(
            self,
            encoder_output: Tensor,
            decoder_output: Tensor,
    ) -> Tensor:
        assert encoder_output.dim() == decoder_output.dim()

        if (encoder_output.dim(), decoder_output.dim()) == (3, 3):  # Train
            num_frames, num_labels = encoder_output.size(1), decoder_output.size(1)

            # Tile both tensors onto a common (batch, frames, labels, dim) grid so every
            # encoder frame is paired with every decoder step.
            encoder_output = encoder_output.unsqueeze(2).repeat(1, 1, num_labels, 1)
            decoder_output = decoder_output.unsqueeze(1).repeat(1, num_frames, 1, 1)

        joint_input = torch.cat((encoder_output, decoder_output), dim=-1)
        logits = self.fc(joint_input)

        return logits.log_softmax(dim=-1)

# Hyperparameters and Model initialization

In [None]:
class HParams():
    """
    HParams: hyperparameter container for the ContextNet model and its training setup.

    Model attributes (names mirror the ``ContextNet`` constructor arguments):
        num_vocabs (int): The number of vocabulary (256)
        model_size (str): Size of the model, one of 'small', 'medium', 'large' ("small")
        input_dim (int): Dimension of input vector (80)
        encoder_num_layers (int): The number of convolutional layers (5)
        decoder_num_layers (int): The number of rnn layers (1)
        kernel_size (int): Value of convolution kernel size (5)
        num_channels (int): The number of channels in the convolution filter (256)
        hidden_dim (int): The number of features in the decoder hidden state (2048)
        encoder_output_dim (int): Dimension of encoder output vector (640)
        decoder_output_dim (int): Dimension of decoder output vector (640)
        dropout (float): Dropout probability of decoder (0.3)
        rnn_type (str): Type of RNN cell ("lstm")
        sos_id (int): Index of the start of sentence (1)

    Dictionaries:
        preprocessing_params (dict): Audio front-end and SpecAugment settings
        training_params (dict): Optimizer and learning-rate schedule settings
    """
    # --- model architecture ---
    num_vocabs = 256
    model_size = "small"  # ['small', 'medium', 'large']
    input_dim = 80
    encoder_num_layers = 5
    decoder_num_layers = 1
    kernel_size = 5
    num_channels = 256
    hidden_dim = 2048
    encoder_output_dim = 640
    decoder_output_dim = 640
    dropout = 0.3
    rnn_type = "lstm"
    sos_id = 1

    # --- audio front-end and SpecAugment ---
    preprocessing_params = dict(
        sample_rate=16000,
        win_length_ms=25,
        hop_length_ms=10,
        n_fft=512,
        n_mels=80,
        normalize=False,
        mean=-5.6501,
        std=4.2280,
        spec_augment=True,
        mF=2,
        F=27,
        mT=5,
        pS=0.05,
    )

    # --- optimizer and schedule ---
    # Not set here (left disabled by the author): "epochs": 250, "batch_size": 16,
    # "accumulated_steps": 4, "mixed_precision": True.
    training_params = dict(
        beta1=0.9,
        beta2=0.98,
        eps=1e-9,
        weight_decay=1e-6,
        lr=0.0025,
        schedule_dim=144,
        warmup_steps=10000,
        K=2,
    )


params = HParams()

In [None]:
# Build the network and move it to the GPU. Only the preprocessing settings are taken
# from HParams here; every other constructor argument uses ContextNet's own default.
# NOTE(review): HParams.num_vocabs is 256 while ContextNet defaults to num_vocabs=1000 -
# confirm which vocabulary size the saved checkpoint expects before wiring HParams in.
model = ContextNet(params.preprocessing_params).cuda()

In [None]:
# Restore the weights saved at PATH + "Contextnet.pth" into the freshly built model.
model.load_state_dict(torch.load(PATH + "Contextnet.pth"))

# Training

In [None]:
# Adam configured entirely from HParams.training_params.
optimizer = torch.optim.Adam(params=model.parameters(), 
                             lr=params.training_params["lr"], 
                             betas=(params.training_params["beta1"], params.training_params["beta2"]), 
                             eps=params.training_params["eps"], 
                             weight_decay=params.training_params["weight_decay"])

# Linear warm-up followed by linear decay of the learning rate.
# NOTE(review): num_training_steps (85) is far smaller than warmup_steps (10000); with this
# schedule the decay phase would already be over when warm-up ends - confirm the intended
# total number of scheduler steps before re-enabling training.
scheduler = transformers.get_linear_schedule_with_warmup(optimizer=optimizer, 
                                                         num_warmup_steps=params.training_params["warmup_steps"],
                                                         num_training_steps=85, 
                                                         last_epoch = -1) 

# optimizer.load_state_dict(torch.load(PATH + "Contextnet_256_optimizer.pth"))
# scheduler.load_state_dict(torch.load(PATH + "Contextnet_256_scheduler.pth"))

In [None]:
# Gradient scaler for mixed-precision training; handed to the Trainer together with the
# model, optimizer and scheduler.
scaler = torch.cuda.amp.GradScaler()
# Path prefix passed to the Trainer (presumably used for checkpoints/logs - see Trainer).
path = PATH + "Contextnet"

trainer = Trainer(model, optimizer, scheduler, scaler, path)
# Training is currently disabled; only the evaluation pass on the dev loader runs.
#trainer.train(train_dataloader, dev_dataloader)
trainer.evaluate(dev_dataloader)

# Test

In [None]:
def test(test_dataset, path):
    """
    Evaluate the global `model` on a test loader and log per-batch results to a CSV file.

    For every batch the audio is decoded with `greedy_search_decoding`, compared with the
    reference transcription (WER via `jiwer`) and scored with the RNN-T loss.

    Args:
        test_dataset: Iterable of ``(inputs, targets, input_len, target_len)`` batches that
            also supports ``len()`` (the callers pass a ``DataLoader``)
        path (str): Destination of the CSV log; one row per batch with the columns
            reference, prediction, batch WER in percent, batch loss
    Returns:
        Tuple ``(wer, speech_true, speech_pred, loss)``:
            - **wer**: WER over all collected sentences, or 1 if the mean batch WER exceeds 1
            - **speech_true** (list): Reference transcriptions
            - **speech_pred** (list): Predicted transcriptions
            - **loss** (float): Mean of the per-batch losses
    """
    model.eval()

    speech_true = []
    speech_pred = []
    total_wer = 0.0
    total_loss = 0.0

    # newline="" is what the csv module requires so rows are not separated by blank lines
    # on platforms that translate line endings.
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        # Evaluation Loop
        for step, batch in enumerate(tqdm(test_dataset)):

            inputs, targets, input_len, target_len = batch

            # Sequence Prediction
            with torch.no_grad():
                outputs_pred = greedy_search_decoding(inputs, input_len)

            # Sequence Truth
            outputs_true = tokenizer.decode(targets.tolist())

            # Compute Batch wer and Update total wer
            batch_wer = jiwer.wer(outputs_true, outputs_pred, standardize=True)
            total_wer += batch_wer

            # Update String lists
            speech_true += outputs_true
            speech_pred += outputs_pred

            # Prediction Verbose
            print("Groundtruths :\n", outputs_true)
            print("Predictions :\n", outputs_pred)

            # Eval Loss
            with torch.no_grad():
                # train=False: the encoder must not apply SpecAugment while evaluating.
                pred, pred_len = model(inputs, input_len, targets, target_len, train=False)
                # The joint network already emits log-probabilities; the extra log_softmax
                # is kept from the original code and leaves them numerically unchanged.
                # .item() turns the scalar loss into a plain float so it can be accumulated
                # and written to the CSV as a number instead of a tensor repr.
                batch_loss = warp_rnnt.rnnt_loss(
                    log_probs=torch.nn.functional.log_softmax(pred, dim=-1),
                    labels=targets.int(),
                    frames_lengths=pred_len.int(),
                    labels_lengths=target_len.int(),
                    average_frames=False,
                    reduction='mean',
                    blank=0,
                    gather=True).item()
                total_loss += batch_loss

            # Step print
            print("\nmean batch wer {:.2f}% - batch wer: {:.2f}% - mean loss {:.4f} - batch loss: {:.4f}".format(100 * total_wer / (step + 1), 100 * batch_wer, total_loss / (step + 1), batch_loss))

            writer.writerow([outputs_true, outputs_pred, 100 * batch_wer, batch_loss])

    num_batches = len(test_dataset)

    # Compute wer: the result is capped at 1 when the mean batch WER is above 100%.
    if total_wer / num_batches > 1:
        wer = 1
    else:
        wer = jiwer.wer(speech_true, speech_pred, standardize=True)

    # Compute loss
    loss = total_loss / num_batches

    return wer, speech_true, speech_pred, loss

In [None]:
# Test sets and their loaders.
test_clean_dataset = LibriSpeechDataset("test-clean", tokenizer)
test_other_dataset = LibriSpeechDataset("test-other", tokenizer)

# drop_last=False: evaluation has to cover every utterance; dropping the final partial
# batch would silently exclude up to batch_size - 1 samples from the reported metrics.
test_clean_dataloader = torch.utils.data.DataLoader(test_clean_dataset,
                                                    batch_size=4,
                                                    shuffle=False,
                                                    collate_fn=collate_fn,
                                                    drop_last=False)

test_other_dataloader = torch.utils.data.DataLoader(test_other_dataset,
                                                    batch_size=4,
                                                    shuffle=False,
                                                    collate_fn=collate_fn,
                                                    drop_last=False)

In [None]:
# Evaluate on test-clean; per-batch results are logged to the CSV at path_test_clean.
# The two list outputs of test() (references and predictions) are discarded here.
path_test_clean = PATH + "Contextnet_test_clean.csv"
wer_clean, _, _, loss_clean = test(test_clean_dataloader, path_test_clean)
print(wer_clean, loss_clean)

In [None]:
# Evaluate on test-other; per-batch results are logged to the CSV at path_test_other.
# The metrics get their own names so the test-clean results are not overwritten.
path_test_other = PATH + "Contextnet_test_other.csv"
wer_other, _, _, loss_other = test(test_other_dataloader, path_test_other)
print(wer_other, loss_other)

In [None]:
def _read_second_column(tsv_path):
    """Return the second field of every row of a tab-separated file as a list of floats."""
    with open(tsv_path, "r") as handle:
        return [float(fields[1]) for fields in csv.reader(handle, delimiter="\t")]


# Compare the per-row values logged for the two models (plotted as WER over epochs).
path2 = PATH + "Conformer"
path = PATH + "Contextnet"

fig = plt.figure()

curves = []
for log_stem, legend_name in ((path2, "Conformer"), (path, "ContextNet")):
    wer_values = _read_second_column(log_stem + ".tsv")
    print(wer_values)
    curve, = plt.plot(wer_values, label=legend_name)
    curves.append(curve)

plt.xlabel("Epochs")
plt.ylabel("WER")
plt.legend(handles=curves, loc=1, fontsize='small', fancybox=True)
plt.show()
fig.savefig(PATH + "wer_plot.png")