# Conformer implementation
Code inspired by: https://github.com/sooftware/conformer and https://github.com/burchim/EfficientConformer

Author: Borghini Alessia

Check the notebook GPU

In [None]:
!nvidia-smi

In [None]:
# Mount Google Drive
from google.colab import drive 

drive.mount("/content/drive")

# Import prerequisites

Install libraries

In [None]:
!pip install -r drive/My\ Drive/Speech_Recognition/requirements.txt

Import libraries and set seed

In [None]:
import matplotlib.pyplot as plt

import torch 
from torch import Tensor
from torch.utils.data import DataLoader, Sampler
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

import torchaudio

import numpy as np
import librosa
from tqdm import tqdm
from typing import Tuple, Optional
import math

import jiwer
import sentencepiece

import warp_rnnt

import transformers

import os
import csv

PATH = "drive/MyDrive/Speech_Recognition/"

torch.manual_seed(42)

# Librispeech Dataset

In [None]:
# Create the download directory; exist_ok keeps the cell re-runnable
# (os.mkdir raises FileExistsError when the directory is already there).
os.makedirs("datasets", exist_ok=True)

In [None]:
# Download LibriSPeech train-clean-100 subset
!cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf train-clean-100.tar.gz

# Download LibriSPeech train-clean-360 subset
# !cd datasets && wget https://www.openslr.org/resources/12/train-clean-360.tar.gz && tar xzf train-clean-360.tar.gz

# Download LibriSPeech dev-clean subset
!cd datasets && wget https://www.openslr.org/resources/12/dev-clean.tar.gz && tar xzf dev-clean.tar.gz

# Download LibriSPeech dev-other subset
# !cd datasets && wget https://www.openslr.org/resources/12/dev-other.tar.gz && tar xzf dev-other.tar.gz

# Download LibriSPeech test-clean subset
# !cd datasets && wget https://www.openslr.org/resources/12/test-clean.tar.gz && tar xzf test-clean.tar.gz

# Download LibriSPeech test-other subset
# !cd datasets && wget https://www.openslr.org/resources/12/test-other.tar.gz && tar xzf test-other.tar.gz

In [None]:
import glob 

class LibriSpeechDataset():
    """LibriSpeech split that caches tokenized labels and lengths beside the audio.

    On construction every transcript file (``*.txt``) of ``split`` is read,
    each sentence is lower-cased and tokenized, and three files are written
    next to the utterance's ``.flac`` file:

    * ``<utt>.<vocab_type>_<vocab_size>``     -- LongTensor of token ids
    * ``<utt>.<vocab_type>_<vocab_size>_len`` -- number of tokens
    * ``<utt>.flac_len``                      -- size of dim 1 of the loaded audio

    Args:
        split: name of the subset directory under ``datasets/LibriSpeech/``,
            e.g. ``"dev-clean"``.
        tokenizer: object exposing ``encode(sentence)`` returning token ids.
    """

    def __init__(self, split, tokenizer):

        print("Creating dataset..")
        self.data_names = glob.glob("datasets/LibriSpeech/" + split + "*/*/*/*.flac")
        self.vocab_type = "bpe"
        self.vocab_size = 256

        label_suffix = "." + self.vocab_type + "_" + str(self.vocab_size)

        # (path without extension, sentence) for every utterance of the split
        utterances = []
        for file_path in glob.glob("datasets/LibriSpeech/" + split + "/*/*/*.txt"):
            # Directory part of the transcript path, trailing "/" included
            directory = file_path[:len(file_path) - len(file_path.split("/")[-1])]
            # Context manager so the transcript handle is always closed
            with open(file_path, "r") as transcript_file:
                for line in transcript_file:
                    parsed = self._parse_transcript_line(line)
                    if parsed is not None:
                        utterances.append((directory + parsed[0], parsed[1]))

        # Iterating a list (instead of a zip) lets tqdm display the total
        for base_path, sentence in tqdm(utterances):
            label_path = base_path + label_suffix

            # Tokenize and Save label
            label = torch.LongTensor(tokenizer.encode(sentence))
            torch.save(label, label_path)

            # Save Audio length
            audio_length = torchaudio.load(base_path + ".flac")[0].size(1)
            torch.save(audio_length, base_path + ".flac_len")

            # Save Label length
            torch.save(label.size(0), label_path + "_len")

        print("Done.")

    @staticmethod
    def _parse_transcript_line(line):
        """Split a transcript line into ``(utterance_id, lower-cased sentence)``.

        Returns ``None`` for blank lines. Surrounding whitespace is stripped,
        so a final line without a trailing newline keeps its last character.
        """
        fields = line.split(maxsplit=1)
        if not fields:
            return None
        sentence = fields[1].strip().lower() if len(fields) > 1 else ""
        return fields[0], sentence

    def __getitem__(self, i):
        """Return ``[audio, label, audio_length, label_length]`` for item ``i``."""
        audio_path = self.data_names[i]
        label_path = audio_path.split(".flac")[0] + "." + self.vocab_type + "_" + str(self.vocab_size)
        return [torchaudio.load(audio_path)[0],
                torch.load(label_path),
                torch.load(audio_path + "_len"),
                torch.load(label_path + "_len")]

    def __len__(self):
        """Number of ``.flac`` files found for the split."""
        return len(self.data_names)

In [None]:
# corpus_path = "datasets/LibriSpeech/train-clean-100_corpus.txt"

# # Create Corpus File
# if not os.path.isfile(corpus_path):
#     print("Create Corpus File")
#     corpus_file = open(corpus_path, "w")
#     for file_path in tqdm(glob.glob("datasets/LibriSpeech/*/*/*/*.txt")):
#         for line in open(file_path, "r").readlines():
#             corpus_file.write(line[len(line.split()[0]) + 1:-1].lower() + "\n")

# # Train Tokenizer
# print("Training Tokenizer")
# sentencepiece.SentencePieceTrainer.train(input=corpus_path, 
#                                          model_prefix="LibriSpeech_bpe_1000", 
#                                          vocab_size=1000, 
#                                          character_coverage=1.0, 
#                                          model_type="bpe", 
#                                          bos_id=-1, 
#                                          eos_id=-1, 
#                                          unk_surface="")
# print("Training Done")

In [None]:
# Load the pretrained SentencePiece BPE model from the shared Drive folder
# (PATH is defined next to the imports) and build both dataset splits.
tokenizer_path = PATH + "LibriSpeech_bpe_256.model"
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)
train_dataset = LibriSpeechDataset("train-clean-100", tokenizer)
dev_dataset = LibriSpeechDataset("dev-clean", tokenizer)

# train_dataset.__getitem__(0)

# Dataloader

In [None]:
def collate_fn(batch):
    """Collate dataset items into zero-padded batch tensors.

    Args:
        batch: list of dataset items; only ``item[0]`` (waveform, indexed here
            as (channels, samples)) and ``item[1]`` (1-D label tensor) are used.

    Returns:
        ``(data, target, data_lengths, target_lengths)`` ordered by decreasing
        waveform length. Tensors are placed on the GPU when one is available
        and stay on the CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sorting sequences by lengths
    sorted_batch = sorted(batch, key=lambda x: x[0].shape[1], reverse=True)

    # Pad data sequences (drop only the channel axis so that a one-sample
    # waveform still keeps its time dimension)
    data = [item[0].squeeze(0) for item in sorted_batch]
    data_lengths = torch.tensor([d.size(0) for d in data], dtype=torch.long).to(device)
    data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=0).to(device)

    # Pad labels
    target = [item[1] for item in sorted_batch]
    target_lengths = torch.tensor([t.size(0) for t in target], dtype=torch.long).to(device)
    target = torch.nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=0).to(device)

    return data, target, data_lengths, target_lengths

# Training batches are reshuffled every epoch; a trailing incomplete batch is dropped.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)

# Validation batches keep the dataset order.
# NOTE(review): drop_last=True also discards the last incomplete dev batch.
dev_dataloader = DataLoader(
    dataset=dev_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
    drop_last=True,
)

In [None]:
# Create the download directory; exist_ok keeps the cell re-runnable
# (os.mkdir raises FileExistsError when the directory is already there).
os.makedirs("datasets", exist_ok=True)

Downloading LibriSpeech dataset

In [None]:
# Download LibriSPeech train-clean-100 subset
# !cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf train-clean-100.tar.gz

# Download LibriSPeech train-clean-360 subset
# !cd datasets && wget https://www.openslr.org/resources/12/train-clean-360.tar.gz && tar xzf train-clean-360.tar.gz

# Download LibriSPeech dev-clean subset
!cd datasets && wget https://www.openslr.org/resources/12/dev-clean.tar.gz && tar xzf dev-clean.tar.gz

# Download LibriSPeech dev-other subset
# !cd datasets && wget https://www.openslr.org/resources/12/dev-other.tar.gz && tar xzf dev-other.tar.gz

# Download LibriSPeech test-clean subset
# !cd datasets && wget https://www.openslr.org/resources/12/test-clean.tar.gz && tar xzf test-clean.tar.gz

# Download LibriSPeech test-other subset
# !cd datasets && wget https://www.openslr.org/resources/12/test-other.tar.gz && tar xzf test-other.tar.gz

LibriSpeech dataset class

In [None]:
import glob 

class LibriSpeechDataset():
    """LibriSpeech split that caches tokenized labels and lengths beside the audio.

    On construction every transcript file (``*.txt``) of ``split`` is read,
    each sentence is lower-cased and tokenized, and three files are written
    next to the utterance's ``.flac`` file:

    * ``<utt>.<vocab_type>_<vocab_size>``     -- LongTensor of token ids
    * ``<utt>.<vocab_type>_<vocab_size>_len`` -- number of tokens
    * ``<utt>.flac_len``                      -- size of dim 1 of the loaded audio

    Args:
        split: name of the subset directory under ``datasets/LibriSpeech/``,
            e.g. ``"dev-clean"``.
        tokenizer: object exposing ``encode(sentence)`` returning token ids.
    """

    def __init__(self, split, tokenizer):

        print("Creating dataset..")
        self.data_names = glob.glob("datasets/LibriSpeech/" + split + "*/*/*/*.flac")
        self.vocab_type = "bpe"
        # NOTE(review): only used to name the cached label files; it is not
        # checked against the vocabulary of the tokenizer passed in -- confirm.
        self.vocab_size = 1000

        label_suffix = "." + self.vocab_type + "_" + str(self.vocab_size)

        # (path without extension, sentence) for every utterance of the split
        utterances = []
        for file_path in glob.glob("datasets/LibriSpeech/" + split + "/*/*/*.txt"):
            # Directory part of the transcript path, trailing "/" included
            directory = file_path[:len(file_path) - len(file_path.split("/")[-1])]
            # Context manager so the transcript handle is always closed
            with open(file_path, "r") as transcript_file:
                for line in transcript_file:
                    parsed = self._parse_transcript_line(line)
                    if parsed is not None:
                        utterances.append((directory + parsed[0], parsed[1]))

        # Iterating a list (instead of a zip) lets tqdm display the total
        for base_path, sentence in tqdm(utterances):
            label_path = base_path + label_suffix

            # Tokenize and Save label
            label = torch.LongTensor(tokenizer.encode(sentence))
            torch.save(label, label_path)

            # Save Audio length
            audio_length = torchaudio.load(base_path + ".flac")[0].size(1)
            torch.save(audio_length, base_path + ".flac_len")

            # Save Label length
            torch.save(label.size(0), label_path + "_len")

        print("Done.")

    @staticmethod
    def _parse_transcript_line(line):
        """Split a transcript line into ``(utterance_id, lower-cased sentence)``.

        Returns ``None`` for blank lines. Surrounding whitespace is stripped,
        so a final line without a trailing newline keeps its last character.
        """
        fields = line.split(maxsplit=1)
        if not fields:
            return None
        sentence = fields[1].strip().lower() if len(fields) > 1 else ""
        return fields[0], sentence

    def __getitem__(self, i):
        """Return ``[audio, label, audio_length, label_length]`` for item ``i``."""
        audio_path = self.data_names[i]
        label_path = audio_path.split(".flac")[0] + "." + self.vocab_type + "_" + str(self.vocab_size)
        return [torchaudio.load(audio_path)[0],
                torch.load(label_path),
                torch.load(audio_path + "_len"),
                torch.load(label_path + "_len")]

    def __len__(self):
        """Number of ``.flac`` files found for the split."""
        return len(self.data_names)

Training the tokenizer (uncomment to create a new tokenizer)

In [None]:
# corpus_path = "datasets/LibriSpeech/train-clean-100_corpus.txt"

# # Create Corpus File
# if not os.path.isfile(corpus_path):
#     print("Create Corpus File")
#     corpus_file = open(corpus_path, "w")
#     for file_path in glob.glob("datasets/LibriSpeech/*/*/*/*.txt"):
#         for line in open(file_path, "r").readlines():
#             corpus_file.write(line[len(line.split()[0]) + 1:-1].lower() + "\n")

# # Train Tokenizer
# print("Training Tokenizer")
# sentencepiece.SentencePieceTrainer.train(input=corpus_path, 
#                                          model_prefix="LibriSpeech_bpe_256", 
#                                          vocab_size=256, 
#                                          character_coverage=1.0, 
#                                          model_type="bpe", 
#                                          bos_id=-1, 
#                                          eos_id=-1, 
#                                          unk_surface="")
# print("Training Done")

Build dataset

In [None]:
# Pretrained SentencePiece BPE model stored in the shared Drive folder.
tokenizer_path = "".join([PATH, "LibriSpeech_bpe_256.model"])
tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)

# train_dataset = LibriSpeechDataset("train-clean-100", tokenizer)
# Only the dev split is built in this cell.
dev_dataset = LibriSpeechDataset(split="dev-clean", tokenizer=tokenizer)

# train_dataset.__getitem__(0)

Build dataloader

In [None]:
def collate_fn(batch):
    """Collate dataset items into zero-padded batch tensors.

    Args:
        batch: list of dataset items; only ``item[0]`` (waveform, indexed here
            as (channels, samples)) and ``item[1]`` (1-D label tensor) are used.

    Returns:
        ``(data, target, data_lengths, target_lengths)`` ordered by decreasing
        waveform length. Tensors are placed on the GPU when one is available
        and stay on the CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sorting sequences by lengths
    sorted_batch = sorted(batch, key=lambda x: x[0].shape[1], reverse=True)

    # Pad data sequences (drop only the channel axis so that a one-sample
    # waveform still keeps its time dimension)
    data = [item[0].squeeze(0) for item in sorted_batch]
    data_lengths = torch.tensor([d.size(0) for d in data], dtype=torch.long).to(device)
    data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=0).to(device)

    # Pad labels
    target = [item[1] for item in sorted_batch]
    target_lengths = torch.tensor([t.size(0) for t in target], dtype=torch.long).to(device)
    target = torch.nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=0).to(device)

    return data, target, data_lengths, target_lengths

# train_dataloader = torch.utils.data.DataLoader(train_dataset,
#                                                batch_size=4,
#                                                shuffle=True,
#                                                collate_fn=collate_fn,
#                                                drop_last=True)

# Validation batches keep the dataset order.
# NOTE(review): drop_last=True also discards the last incomplete dev batch.
dev_dataloader = DataLoader(
    dataset=dev_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
    drop_last=True,
)

# Trainer

In [None]:
class Trainer():
    """Mixed-precision RNN-T training and evaluation loop.

    Args:
        model: network called as ``model(inputs, input_len, targets, target_len)``
            and returning ``(pred, pred_len)``.
        optimizer: optimizer updating the model parameters.
        scheduler: learning-rate scheduler, stepped once per optimizer update.
        scaler: gradient scaler used for the backward pass and optimizer step.
        path: prefix of the checkpoint (``.pth``) and log (``.tsv``) files.
        save_best_model: when true, ``save_model`` is called after every epoch.
        logging: when true, ``save_model`` appends a line to ``<path>.tsv``.
        accumulated_steps: number of mini-batches whose gradients are
            accumulated before one optimizer update.

    Note:
        ``evaluate`` relies on the module-level ``tokenizer`` and on
        ``greedy_search_decoding``, which are defined elsewhere in this file.
    """

    def __init__(self,
                 model: nn.Module,
                 optimizer,
                 scheduler,
                 scaler,
                 path,
                 save_best_model = True,
                 logging = True,
                 accumulated_steps = 16):

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scaler = scaler
        self.accumulated_steps = accumulated_steps

        # Start above any value evaluate() can return so the first epoch is saved
        self.best_wer = 2
        self.path = path
        self.save_best_model = save_best_model
        self.logging = logging

    def _rnnt_loss(self, pred, targets, pred_len, target_len):
        """Mean RNN-T loss of joint-network outputs ``pred`` against ``targets``."""
        return warp_rnnt.rnnt_loss(
            log_probs=torch.nn.functional.log_softmax(pred, dim=-1),
            labels=targets.int(),
            frames_lengths=pred_len.int(),
            labels_lengths=target_len.int(),
            average_frames=False,
            reduction='mean',
            blank=0,
            gather=True)

    def train(self,
              train_dataset,
              dev_dataset,
              epochs:int=20):
        """Train for ``epochs`` epochs, evaluating after each one.

        Args:
            train_dataset: sized iterable of training batches
                ``(inputs, targets, input_len, target_len)``.
            dev_dataset: sized iterable of validation batches (same layout).
            epochs: number of passes over ``train_dataset``.

        Returns:
            Mean training loss per batch of the last epoch, or ``None`` when
            ``epochs`` is 0.
        """
        print("Training...")

        avg_epoch_loss = None

        self.optimizer.zero_grad()

        for epoch in range(epochs):
            print(" Epoch {:03d}".format(epoch + 1))

            epoch_loss = 0.0

            # train mode on
            self.model.train()

            for step, batch in enumerate(tqdm(train_dataset)):
                inputs, targets, input_len, target_len = batch

                # Automatic Mixed Precision Casting (model prediction + loss computing)
                with torch.cuda.amp.autocast():
                    pred, pred_len = self.model(inputs, input_len, targets, target_len, training=True)
                    loss_mini = self._rnnt_loss(pred, targets, pred_len, target_len)
                    # Scale so the accumulated gradient is the mean over the group
                    loss = loss_mini / self.accumulated_steps

                # Accumulate gradients
                self.scaler.scale(loss).backward()

                # Update Epoch Variables
                epoch_loss += loss_mini.detach()

                # Update only once a full group of mini-batches has been accumulated
                if (step + 1) % self.accumulated_steps == 0:
                    # Update Parameters, Zero Gradients and Update Learning Rate
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()
                    self.scheduler.step()

                    # Step Print
                    print("mean loss {:.4f} - batch loss: {:.4f} - learning rate: {:.6f}".format(epoch_loss / (step + 1), loss_mini, self.optimizer.param_groups[0]['lr']))

            # Average over the training batches actually seen this epoch
            avg_epoch_loss = epoch_loss / max(len(train_dataset), 1)
            print('\t[E: {:2d}] train loss = {:0.4f}'.format(epoch, avg_epoch_loss))

            wer, speech_true, speech_pred, val_loss = self.evaluate(dev_dataset)
            print('\t[E: {:2d}] valid loss = {:0.4f}, wer = {:0.4f}, loss = {:0.4f}'.format(epoch, val_loss, wer, avg_epoch_loss))

            if self.save_best_model:
                self.save_model(wer, avg_epoch_loss, self.optimizer.param_groups[0]['lr'], self.path)

        print("...Done!")
        return avg_epoch_loss

    def evaluate(self,
              dev_dataset):
        """Decode ``dev_dataset`` greedily and compute WER and mean loss.

        Args:
            dev_dataset: sized iterable of batches
                ``(inputs, targets, input_len, target_len)``.

        Returns:
            ``(wer, speech_true, speech_pred, loss)``: corpus WER (set to 1 when
            the mean batch WER exceeds 1), reference strings, predicted strings
            and the mean loss per batch.
        """
        self.model.eval()

        speech_true = []
        speech_pred = []
        total_wer = 0.0
        total_loss = 0.0

        # Evaluation Loop
        for step, batch in enumerate(tqdm(dev_dataset)):

            inputs, targets, input_len, target_len = batch

            # Sequence Prediction
            with torch.no_grad():
                outputs_pred = greedy_search_decoding(inputs, input_len)

            # Sequence Truth
            outputs_true = tokenizer.decode(targets.tolist())

            # Compute Batch wer and Update total wer
            batch_wer = jiwer.wer(outputs_true, outputs_pred, standardize=True)
            total_wer += batch_wer

            # Update String lists
            speech_true += outputs_true
            speech_pred += outputs_pred

            # Prediction Verbose
            print("Groundtruths :\n", outputs_true)
            print("Predictions :\n", outputs_pred)

            # Eval Loss
            with torch.no_grad():
                pred, pred_len = self.model(inputs, input_len, targets, target_len)
                batch_loss = self._rnnt_loss(pred, targets, pred_len, target_len)
                total_loss += batch_loss

            # Step print
            print("mean batch wer {:.2f}% - batch wer: {:.2f}% - mean loss {:.4f} - batch loss: {:.4f}".format(100 * total_wer / (step + 1), 100 * batch_wer, total_loss / (step + 1), batch_loss))

        num_batches = max(len(dev_dataset), 1)

        # Compute wer
        if total_wer / num_batches > 1:
            wer = 1
        else:
            wer = jiwer.wer(speech_true, speech_pred, standardize=True)

        # Compute loss
        loss = total_loss / num_batches

        return wer, speech_true, speech_pred, loss

    def save_model(self, wer, loss, lr, path):
        """Checkpoint the model when ``wer`` improves and optionally log the epoch.

        Model, scheduler and optimizer state dicts are written to
        ``<path>.pth``, ``<path>_scheduler.pth`` and ``<path>_optimizer.pth``
        only when ``wer`` is lower than the best value seen so far. When
        logging is enabled, ``lr``, ``wer`` and ``loss`` are appended to
        ``<path>.tsv`` on every call.
        """
        if wer < self.best_wer:
            # Remember the new best so worse epochs do not overwrite the checkpoint
            self.best_wer = wer
            torch.save(self.model.state_dict(), f"{path}.pth")
            torch.save(self.scheduler.state_dict(), f"{path}_scheduler.pth")
            torch.save(self.optimizer.state_dict(), f"{path}_optimizer.pth")

        if self.logging:
            with open(f"{path}.tsv", "a") as log:
                log.write(f"{lr}\t{wer}\t{loss}\n")
                

Greedy search decoding strategy

In [None]:
def greedy_search_decoding(x, x_len, max_consec_dec_steps=5):
    """Greedy transducer decoding of a batch, one sample at a time.

    Uses the module-level ``model`` (its ``encoder``, ``decoder`` and ``joint``
    sub-modules) and the module-level ``tokenizer``.

    Args:
        x: batch of encoder inputs; the first dimension is the batch.
        x_len: lengths passed to the encoder together with ``x``.
        max_consec_dec_steps: maximum number of tokens emitted on a single
            encoder frame before decoding is forced to advance to the next one.

    Returns:
        List with one decoded string per sample of the batch.
    """
    # Predictions String List
    preds = []

    # Forward Encoder (B, Taud) -> (B, T, Denc)
    f, f_len = model.encoder(x, x_len)

    # Batch loop
    for b in range(x.size(0)): # One sample at a time for now, not batch optimized
        # Init y with the start token (id 0) and empty hidden state
        y = x.new_zeros(1, 1, dtype=torch.long)
        hidden = None

        enc_step = 0
        consec_dec_step = 0

        # Decoder loop
        while enc_step < f_len[b]:

            # Forward Decoder (1, 1) -> (1, 1, Ddec)
            g, hidden = model.decoder(y[:, -1:], hidden_states=hidden)

            # Joint Network loop
            while enc_step < f_len[b]:

                # Forward Joint Network (1, 1, Denc) and (1, 1, Ddec) -> (1, V)
                logits = model.joint(f[b:b+1, enc_step], g[:, 0])

                # Token Prediction: log-softmax is monotonic, so the argmax of
                # the raw logits selects the same token
                token = logits.argmax(dim=-1) # (1)

                # Null token or max_consec_dec_steps reached: next encoder frame
                if token == 0 or consec_dec_step == max_consec_dec_steps:
                    consec_dec_step = 0
                    enc_step += 1
                # Token: append it and re-run the decoder on the new last token
                else:
                    consec_dec_step += 1
                    y = torch.cat([y, token.unsqueeze(0)], dim=-1)
                    break

        # Decode Label Sequence (skip the initial start token)
        preds += tokenizer.decode(y[:, 1:].tolist())

    return preds

# Model

## Utils

### Modules

In [None]:
class Linear(nn.Linear):
    """``nn.Linear`` whose weight can be perturbed by scaled Gaussian noise.

    The perturbation ``vn_std * noise`` is added to the weight only in
    training mode and only after ``sample_synaptic_noise`` has drawn a noise
    tensor; ``init_vn`` sets the scale.
    """

    def __init__(self, in_features, out_features, bias = True):
        super().__init__(in_features, out_features, bias)

        # Variational noise state: scale and last sampled noise tensor
        self.vn_std = None
        self.noise = None

    def init_vn(self, vn_std):
        """Set the standard deviation multiplying the sampled noise."""
        self.vn_std = vn_std

    def sample_synaptic_noise(self, distributed):
        """Draw standard-normal noise shaped like the weight.

        When ``distributed`` is true the tensor is broadcast from rank 0.
        """
        w = self.weight
        self.noise = torch.normal(mean=0.0, std=1.0, size=w.size(), device=w.device, dtype=w.dtype)

        if distributed:
            torch.distributed.broadcast(self.noise, 0)

    def forward(self, input):
        """Apply the affine map, with the noisy weight when training and noise is set."""
        perturb = self.noise is not None and self.training
        weight = self.weight + self.vn_std * self.noise if perturb else self.weight
        return F.linear(input, weight, self.bias)


class View(nn.Module):
    """Module form of ``Tensor.view`` so a reshape can sit inside ``nn.Sequential``.

    Args:
        shape: target shape, unpacked into ``view``.
        contiguous: when true, the input is made contiguous before the view.
    """
    def __init__(self, shape: tuple, contiguous: bool = False):
        super().__init__()
        self.contiguous = contiguous
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        source = x.contiguous() if self.contiguous else x
        return source.view(*self.shape)

class Transpose(nn.Module):
    """Module form of ``Tensor.transpose`` swapping dimensions ``dim0`` and ``dim1``."""

    def __init__(self, dim0, dim1):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x):
        return torch.transpose(x, self.dim0, self.dim1)

In [None]:
class Conv2dSubampling(nn.Module):
    """
    Convolutional 2D subsampling: a stack of stride-2 convolutions, each
    followed by batch normalization and Swish (1/4 length for two layers).
    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        num_conv_layers (int): Number of stride-2 convolution layers
        kernel_size (int): Convolution kernel size
    Inputs: inputs, input_lengths
        - **inputs** (batch, time, dim): Tensor containing sequence of inputs
        - **input_lengths** (batch): Tensor of sequence input lengths
    Returns: outputs, output_lengths
        - **outputs** (batch, time, dim): Tensor produced by the convolution,
          with channels and subsampled features merged into the last dimension
        - **output_lengths** (batch): Tensor of sequence output lengths
    """
    def __init__(self, in_channels, out_channels, num_conv_layers, kernel_size=3) -> None:
        super(Conv2dSubampling, self).__init__()

        self.sequential = nn.ModuleList([nn.Sequential(
            nn.Conv2d(in_channels if layer_idx == 0 else out_channels,
                      out_channels,
                      kernel_size,
                      stride=2,
                      padding=(kernel_size - 1) // 2),
            nn.BatchNorm2d(out_channels),
            Swish()
        ) for layer_idx in range(num_conv_layers)])


    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
        # (B, T, D) -> (B, 1, T, D): time is the first spatial axis of the convolution
        outputs = inputs.unsqueeze(1)
        output_lengths = input_lengths

        for layer in self.sequential:
            outputs = layer(outputs)

            # Track the valid length through every layer with the convolution
            # output-size formula, read from the layer's own hyper-parameters,
            # so any number of layers and any kernel size are handled.
            conv = layer[0]
            numerator = output_lengths + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1
            output_lengths = torch.div(numerator, conv.stride[0], rounding_mode='floor') + 1

        batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()

        # (B, C, T', D') -> (B, T', C * D')
        outputs = outputs.permute(0, 2, 1, 3)
        outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * subsampled_dim)

        return outputs, output_lengths


class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with explicit "valid", "same" or "causal" zero padding.

    The padding is applied by a separate pad layer in front of an unpadded
    convolution:

    * ``"valid"``  -- no padding
    * ``"same"``   -- ``dilation * (kernel_size - 1)`` zeros split between both
      sides (the extra one goes to the right when the total is odd), so the
      length is preserved at stride 1
    * ``"causal"`` -- ``dilation * (kernel_size - 1)`` zeros on the left only

    Input and output are laid out as (batch, channels, length).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride = 1,
        padding = "same",
        dilation = 1,
        groups = 1,
        bias = True
    ):
        super(Conv1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode="zeros")

        # Assert
        assert padding in ["valid", "same", "causal"]

        # Receptive-field overhang, taken from the normalized (tuple) attributes
        # stored by nn.Conv1d so that dilation is accounted for
        total_padding = self.dilation[0] * (self.kernel_size[0] - 1)

        # Padding
        if padding == "valid":
            self.pre_padding = None
        elif padding == "same":
            left = total_padding // 2
            self.pre_padding = nn.ConstantPad1d(padding=(left, total_padding - left), value=0)
        elif padding == "causal":
            self.pre_padding = nn.ConstantPad1d(padding=(total_padding, 0), value=0)

    def forward(self, input):

        # Padding
        if self.pre_padding is not None:
            input = self.pre_padding(input)

        # Apply Weight (self.padding is 0: all padding was applied above)
        return F.conv1d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


### Feed Forward

In [None]:
class FeedForwardModule(nn.Module):

    """Transformer Feed Forward Module

    LayerNorm -> Linear (expand) -> activation -> optional dropout ->
    Linear (project back) -> dropout.

    Args:
        dim_model: model feature dimension
        dim_ffn: expanded feature dimension
        Pdrop: dropout probability
        act: inner activation function, "relu" or "swish"
        inner_dropout: whether to apply dropout after the inner activation function
    Input: (batch size, length, dim_model)
    Output: (batch size, length, dim_model)

    """

    def __init__(self, dim_model, dim_ffn, Pdrop, act, inner_dropout):
        super().__init__()

        # Assert
        assert act in ["relu", "swish"]

        # Stages are created in their order of application; the Identity
        # placeholder keeps the Sequential indices fixed without inner dropout.
        stages = [nn.LayerNorm(dim_model, eps=1e-6)]
        stages.append(Linear(dim_model, dim_ffn))
        if act == "swish":
            stages.append(Swish())
        else:
            stages.append(nn.ReLU())
        if inner_dropout:
            stages.append(nn.Dropout(p=Pdrop))
        else:
            stages.append(nn.Identity())
        stages.append(Linear(dim_ffn, dim_model))
        stages.append(nn.Dropout(p=Pdrop))

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        out = self.layers(x)
        return out

### Activation Function

In [None]:
class Swish(nn.Module):
    """
    Swish activation: the input multiplied element-wise by its own sigmoid.
    A smooth, non-monotonic alternative to ReLU.
    """
    def __init__(self):
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        gate = torch.sigmoid(inputs)
        return inputs * gate


class GLU(nn.Module):
    """
    Gated Linear Unit, introduced in the paper “Language Modeling with Gated
    Convolutional Networks”: the input is split in two halves along ``dim``
    and the first half is multiplied by the sigmoid of the second.
    """
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        values, gate = torch.chunk(inputs, 2, dim=self.dim)
        return values * torch.sigmoid(gate)

### Convolution Module

In [None]:
class ConvolutionModule(nn.Module):

    """Conformer Convolution Module

    LayerNorm -> pointwise conv (2x expansion) -> GLU -> depthwise conv ->
    BatchNorm -> Swish -> pointwise conv -> dropout. The tensor is transposed
    to channels-first for the convolutions and transposed back afterwards.

    Args:
        dim_model: input feature dimension
        dim_expand: output feature dimension
        kernel_size: 1D depthwise convolution kernel size
        Pdrop: residual dropout probability
        stride: 1D depthwise convolution stride
        padding: "valid", "same" or "causal"
    Input: (batch size, input length, dim_model)
    Output: (batch size, output length, dim_expand)

    """

    def __init__(self, dim_model, dim_expand, kernel_size, Pdrop, stride, padding):
        super().__init__()

        # Stages are created in their order of application
        stages = []
        stages.append(nn.LayerNorm(dim_model, eps=1e-6))
        stages.append(Transpose(1, 2))
        # Pointwise expansion to twice the width; the GLU halves it again
        stages.append(Conv1d(dim_model, 2 * dim_expand, kernel_size=1))
        stages.append(GLU(dim=1))
        # Depthwise convolution: one group per channel
        stages.append(Conv1d(dim_expand, dim_expand, kernel_size, stride=stride, padding=padding, groups=dim_expand))
        stages.append(nn.BatchNorm1d(dim_expand))
        stages.append(Swish())
        stages.append(Conv1d(dim_expand, dim_expand, kernel_size=1))
        stages.append(Transpose(1, 2))
        stages.append(nn.Dropout(p=Pdrop))

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        out = self.layers(x)
        return out

### Attention 

In [None]:
class RelativeSinusoidalPositionalEncoding(nn.Module):

    """
        Relative Sinusoidal Positional Encoding

        Table of sinusoidal encodings for the relative positions
        max_len - 1, ..., 0, ..., -(max_len - 1); even feature columns hold
        the sine and odd columns the cosine of the position angles.
        Total context = 2 * max_len - 1

        Args:
            max_len: maximum sequence length covered by the table
            dim_model: encoding feature dimension
    """

    def __init__(self, max_len, dim_model):
        super(RelativeSinusoidalPositionalEncoding, self).__init__()

        # PE
        pos_encoding = torch.zeros(2 * max_len - 1, dim_model)

        # Positions (max_len - 1, ..., 0, ..., -(max_len - 1)) as a column
        pos = torch.arange(start=max_len - 1, end=-max_len, step=-1, dtype=torch.float).unsqueeze(1)

        # Angles
        angles = pos / 10000**(2 * torch.arange(0, dim_model // 2, dtype=torch.float).unsqueeze(0) / dim_model)

        # Rel Sinusoidal PE
        pos_encoding[:, 0::2] = angles.sin()
        pos_encoding[:, 1::2] = angles.cos()

        pos_encoding = pos_encoding.unsqueeze(0)

        # Non-persistent buffer: follows the module's device, not stored in state_dict
        self.register_buffer('pos_encoding', pos_encoding, persistent=False)
        self.max_len = max_len

    def forward(self, batch_size=1, seq_len=None, hidden_len=0):
        """Return the encodings repeated ``batch_size`` times.

        Args:
            batch_size: number of copies along the first dimension.
            seq_len: current sequence length T; ``None`` returns the full table.
            hidden_len: extra left context Th prepended to the window.

        Returns:
            Tensor of shape (B, Th + 2*T-1, D), or (B, 2*Tmax-1, D) when
            ``seq_len`` is ``None``.

        Raises:
            ValueError: if ``seq_len + hidden_len`` exceeds ``max_len``.
        """

        # (B, Th + 2*T-1, D)
        if seq_len is not None:
            start = self.max_len - seq_len - hidden_len
            # A negative start would wrap around and silently return a
            # wrongly sized window
            if start < 0:
                raise ValueError(
                    "seq_len + hidden_len ({}) exceeds max_len ({})".format(seq_len + hidden_len, self.max_len))
            R = self.pos_encoding[:, start : self.max_len - 1 + seq_len]

        # (B, 2*Tmax-1, D)
        else:
            R = self.pos_encoding

        return R.repeat(batch_size, 1, 1)

In [None]:
class MultiHeadAttention(nn.Module):

    """Multi-Head Attention Layer
    Args:
        dim_model: model feature dimension
        num_heads: number of attention heads
    References: 
        Attention Is All You Need, Vaswani et al.
        https://arxiv.org/abs/1706.03762
    """

    def __init__(self, dim_model, num_heads):
        super(MultiHeadAttention, self).__init__()

        # Attention Params
        self.num_heads = num_heads # H
        self.dim_model = dim_model # D
        self.dim_head = dim_model // num_heads # d

        # Linear Layers
        self.query_layer = Linear(self.dim_model, self.dim_model)
        self.key_layer = Linear(self.dim_model, self.dim_model)
        self.value_layer = Linear(self.dim_model, self.dim_model)
        self.output_layer = Linear(self.dim_model, self.dim_model)

    def forward(self, Q, K, V, mask=None):

        """Scaled Dot-Product Multi-Head Attention
        Args:
            Q: Query of shape (B, T, D)
            K: Key of shape (B, T, D)
            V: Value of shape (B, T, D)
            mask: Optional position mask of shape (1 or B, 1 or H, 1 or T, 1 or T);
                positions where the mask is 1 are suppressed
        
        Return:
            O: Attention output of shape (B, T, D)
            att_w: Attention weights of shape (B, H, T, T), detached from the graph
        """

        # Batch size B
        batch_size = Q.size(0)

        # Linear Layers
        Q = self.query_layer(Q)
        K = self.key_layer(K)
        V = self.value_layer(V)

        # Reshape and Transpose (B, T, D) -> (B, H, T, d)
        Q = Q.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
        K = K.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
        V = V.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)

        # Att scores (B, H, T, T)
        att_scores = Q.matmul(K.transpose(2, 3)) / K.shape[-1]**0.5

        # Apply mask: a large negative score gives a near-zero softmax weight
        if mask is not None:
            att_scores += (mask * -1e9)

        # Att weights (B, H, T, T)
        att_w = att_scores.softmax(dim=-1)

        # Att output (B, H, T, d)
        O = att_w.matmul(V)

        # Transpose and Reshape (B, H, T, d) -> (B, T, D)
        O = O.transpose(1, 2).reshape(batch_size, -1,  self.dim_model)

        # Output linear layer
        O = self.output_layer(O)

        return O, att_w.detach()

    def pad(self, Q, K, V, mask, chunk_size):

        """Zero-pad Q, K and V along time to a multiple of ``chunk_size``
        Args:
            Q: Query of shape (B, Tq, D)
            K: Key of shape (B, Tkv, D)
            V: Value of shape (B, Tkv, D)
            mask: Optional mask of shape (B, 1, 1, Tkv) or (B, 1, Tq, Tkv)
            chunk_size: the padded lengths are multiples of this value
        
        Return:
            Q, K, V: padded tensors
            mask: mask extended with ones over the padded positions; created
                when ``mask`` is None and the keys were padded
            padding_Q: number of positions appended to Q
        """

        # Compute Overflows
        overflow_Q = Q.size(1) % chunk_size
        overflow_KV = K.size(1) % chunk_size
        
        padding_Q = chunk_size - overflow_Q if overflow_Q else 0
        padding_KV = chunk_size - overflow_KV if overflow_KV else 0

        batch_size, seq_len_KV, _ = K.size()

        # Input Padding (B, T, D) -> (B, T + P, D)
        Q = F.pad(Q, (0, 0, 0, padding_Q), value=0)
        K = F.pad(K, (0, 0, 0, padding_KV), value=0)
        V = F.pad(V, (0, 0, 0, padding_KV), value=0)

        # Update Padding Mask
        if mask is not None:

            # (B, 1, 1, T) -> (B, 1, 1, T + P) 
            if mask.size(2) == 1:
                mask = F.pad(mask, pad=(0, padding_KV), value=1)
            # (B, 1, Tq, Tkv) -> (B, 1, Tq + Pq, Tkv + Pkv)
            else:
                # F.pad lists the last dimension first: keys, then queries
                mask = F.pad(mask, pad=(0, padding_KV, 0, padding_Q), value=1)

        elif padding_KV:

            # None -> (B, 1, 1, T + P) 
            mask = F.pad(Q.new_zeros(batch_size, 1, 1, seq_len_KV), pad=(0, padding_KV), value=1)

        return Q, K, V, mask, padding_Q

In [None]:
class RelPosMultiHeadSelfAttention(MultiHeadAttention):

    """Multi-head self-attention whose scores combine a content term and a
    relative sinusoidal position term
    Args:
        dim_model: model feature dimension
        num_heads: number of attention heads
        max_pos_encoding: maximum relative distance between elements
    References:
        Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, Dai et al.
        https://arxiv.org/abs/1901.02860
    """

    def __init__(self, dim_model, num_heads, max_pos_encoding):
        super().__init__(dim_model, num_heads)

        # Projection applied to the relative positional encodings
        self.pos_layer = nn.Linear(self.dim_model, self.dim_model)

        # Learned global biases added to the queries:
        # u feeds the content term, v feeds the positional term
        self.u = nn.Parameter(torch.Tensor(self.dim_model))
        self.v = nn.Parameter(torch.Tensor(self.dim_model))

        # Glorot uniform init, applied in place through an (H, d) reshape of each bias
        for bias in (self.u, self.v):
            torch.nn.init.xavier_uniform_(bias.reshape(self.num_heads, self.dim_head))

        # Generator of the relative sinusoidal encodings
        self.rel_pos_enc = RelativeSinusoidalPositionalEncoding(max_pos_encoding, self.dim_model)

    def rel_to_abs(self, att_scores):

        """Convert relative position indexing to absolute position indexing
        Args:
            att_scores: absolute-by-relative indexed attention scores of shape
            (B, H, T, Th + 2*T-1) for full context
        Return:
            att_scores: absolute-by-absolute indexed attention scores of shape (B, H, T, Th + T)
        References:
            full context:
            Attention Augmented Convolutional Networks, Bello et al.
            https://arxiv.org/abs/1904.09925
        """

        # Input layout (B, H, T, Th + 2*T-1)
        n_batch, n_heads, n_query, n_rel = att_scores.size()

        # One zero column on the right (B, H, T, Th + 2*T)
        shifted = F.pad(att_scores, pad=(0, 1), value=0)

        # Merge the two last axes (B, H, TTh + 2*TT)
        shifted = shifted.reshape(n_batch, n_heads, -1)

        # Zero tail so the flat axis splits into T + 1 rows of Th + 2*T-1 entries
        shifted = F.pad(shifted, pad=(0, n_rel - n_query), value=0)

        # Re-cut (B, H, T + 1, Th + 2*T-1): each row is offset by one more column
        shifted = shifted.reshape(n_batch, n_heads, 1 + n_query, n_rel)

        # Keep the first T rows and drop the first T-1 columns (B, H, T, Th + T)
        return shifted[:, :, :n_query, n_query - 1:]

    def forward(self, Q, K, V, mask=None, hidden=None):

        """Scaled dot-product self-attention with relative sinusoidal position encodings
        Args:
            Q: Query of shape (B, T, D)
            K: Key of shape (B, T, D)
            V: Value of shape (B, T, D)
            mask: Optional position mask of shape (1 or B, 1 or H, 1 or T, 1 or T),
                  non-zero entries are pushed to a large negative score
            hidden: Optional dict with cached "K" and "V" projections that are
                    prepended to the current keys and values

        Return:
            O: Attention output of shape (B, T, D)
            att_w: Detached attention weights of shape (B, H, T, Th + T)
            hidden: dict with the detached "K" and "V" projections, cache included
        """

        batch_size = Q.size(0)

        def split_heads(tensor):
            # (B, L, D) -> (B, H, L, d)
            return tensor.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)

        # Input projections
        Q = self.query_layer(Q)
        K = self.key_layer(K)
        V = self.value_layer(V)

        # Prepend the cached keys / values when a hidden state is given
        if hidden:
            K = torch.cat([hidden["K"], K], dim=1)
            V = torch.cat([hidden["V"], V], dim=1)

        # New hidden state
        hidden = {"K": K.detach(), "V": V.detach()}

        # Biased queries for the content and the positional terms
        query_content = Q + self.u
        query_position = Q + self.v

        # Projected relative encodings (B, Th + 2*T-1, D) / (B, Th + T, D)
        num_query = Q.size(1)
        num_hidden = K.size(1) - num_query
        E = self.pos_layer(self.rel_pos_enc(batch_size, num_query, num_hidden))

        # Per-head views
        query_content = split_heads(query_content)
        query_position = split_heads(query_position)
        K = split_heads(K)
        V = split_heads(V)
        E = split_heads(E)

        # Content and position scores, summed and scaled (B, H, T, Th + T)
        content_scores = query_content.matmul(K.transpose(2, 3))
        position_scores = self.rel_to_abs(query_position.matmul(E.transpose(2, 3)))
        scale = K.shape[-1] ** 0.5
        att_scores = (content_scores + position_scores) / scale

        # Masked positions get a large negative score
        if mask is not None:
            att_scores += (mask * -1e9)

        # Attention weights (B, H, T, Th + T)
        att_w = att_scores.softmax(dim=-1)

        # Weighted values (B, H, T, d), heads merged back to (B, T, D)
        O = att_w.matmul(V)
        O = O.transpose(1, 2).reshape(batch_size, -1, self.dim_model)

        # Output projection
        O = self.output_layer(O)

        return O, att_w.detach(), hidden

In [None]:
class MultiHeadSelfAttentionModule(nn.Module):

    """Pre-norm relative-position multi-head self-attention followed by dropout
    Args:
        dim_model: model feature dimension
        num_heads: number of attention heads
        Pdrop: residual dropout probability
        max_pos_encoding: maximum position
        relative_pos_enc: whether to use relative position embedding
                          (stored only, relative attention is always built)
        group_size: Attention group size (not used by this module)
        kernel_size: Attention kernel size (not used by this module)
        stride: Query stride (not used by this module)
        linear_att: whether to use multi-head linear self-attention (stored only)
    """

    def __init__(self,
                 dim_model,
                 num_heads,
                 Pdrop,
                 max_pos_encoding,
                 relative_pos_enc,
                 group_size,
                 kernel_size,
                 stride,
                 linear_att):
        super().__init__()

        # Layer norm applied before the attention
        self.norm = nn.LayerNorm(dim_model, eps=1e-6)

        # Relative positional attention is the only variant built here
        self.mhsa = RelPosMultiHeadSelfAttention(dim_model, num_heads, max_pos_encoding)

        # Dropout applied to the attention output
        self.dropout = nn.Dropout(Pdrop)

        # Configuration flags kept on the module
        self.rel_pos_enc = relative_pos_enc
        self.linear_att = linear_att

    def forward(self, x, mask=None, hidden=None):
        """Return (dropout(attention(norm(x))), attention weights, hidden state)."""

        # The normalised input is used as query, key and value
        normed = self.norm(x)
        attended, attention, hidden = self.mhsa(normed, normed, normed, mask, hidden)

        return self.dropout(attended), attention, hidden

## Conformer

#### Conformer Block

In [None]:
class ConformerBlock(nn.Module):

    """Conformer block: half-step feed forward, self-attention, convolution,
    half-step feed forward, then a final layer norm
    Args:
        dim_model: input feature dimension
        dim_expand: feature dimension after the convolution module
        ff_ratio: feed forward expansion ratio
        num_heads: number of attention heads
        kernel_size: convolution module kernel size
        att_group_size: forwarded to the attention module as group_size
        att_kernel_size: forwarded to the attention module as kernel_size
        linear_att: forwarded to the attention module
        Pdrop: dropout probability shared by all sub-modules
        relative_pos_enc: forwarded to the attention module
        max_pos_encoding: forwarded to the attention module
        conv_stride: stride of the convolution module and of its residual path
        att_stride: stride of the attention residual path
    """

    def __init__(
        self,
        dim_model,
        dim_expand,
        ff_ratio,
        num_heads,
        kernel_size,
        att_group_size,
        att_kernel_size,
        linear_att,
        Pdrop,
        relative_pos_enc,
        max_pos_encoding,
        conv_stride,
        att_stride,
    ):
        super().__init__()

        # First feed forward module, works on dim_model features
        self.feed_forward_module1 = FeedForwardModule(
            dim_model=dim_model,
            dim_ffn=dim_model * ff_ratio,
            Pdrop=Pdrop,
            act="swish",
            inner_dropout=True
        )

        # Self-attention module
        self.multi_head_self_attention_module = MultiHeadSelfAttentionModule(
            dim_model=dim_model,
            num_heads=num_heads,
            Pdrop=Pdrop,
            max_pos_encoding=max_pos_encoding,
            relative_pos_enc=relative_pos_enc,
            group_size=att_group_size,
            kernel_size=att_kernel_size,
            stride=att_stride,
            linear_att=linear_att
        )

        # Convolution module, maps dim_model to dim_expand features
        self.convolution_module = ConvolutionModule(
            dim_model=dim_model,
            dim_expand=dim_expand,
            kernel_size=kernel_size,
            Pdrop=Pdrop,
            stride=conv_stride,
            padding="same"
        )

        # Second feed forward module, works on dim_expand features
        self.feed_forward_module2 = FeedForwardModule(
            dim_model=dim_expand,
            dim_ffn=dim_expand * ff_ratio,
            Pdrop=Pdrop,
            act="swish",
            inner_dropout=True
        )

        # Final layer norm of the block
        self.norm = nn.LayerNorm(dim_expand, eps=1e-6)

        # Residual path around the attention: strided pooling over time when
        # att_stride > 1, plain identity otherwise
        if att_stride > 1:
            att_res = nn.Sequential(
                Transpose(1, 2),
                nn.MaxPool1d(kernel_size=1, stride=att_stride),
                Transpose(1, 2)
            )
        else:
            att_res = nn.Identity()
        self.att_res = att_res

        # Residual path around the convolution: pointwise strided convolution when
        # the feature size changes, strided pooling when only the stride is > 1,
        # identity otherwise
        if dim_model != dim_expand:
            conv_res = nn.Sequential(
                Transpose(1, 2),
                Conv1d(dim_model, dim_expand, kernel_size=1, stride=conv_stride),
                Transpose(1, 2)
            )
        elif conv_stride > 1:
            conv_res = nn.Sequential(
                Transpose(1, 2),
                nn.MaxPool1d(kernel_size=1, stride=conv_stride),
                Transpose(1, 2)
            )
        else:
            conv_res = nn.Identity()
        self.conv_res = conv_res

        # Total stride of the block
        self.stride = conv_stride * att_stride

    def forward(self, x, mask=None, hidden=None):
        """Run the block and return (output, attention weights, hidden state)."""

        # First feed forward, added with a 1/2 weight
        x = x + 1/2 * self.feed_forward_module1(x)

        # Self-attention with its residual path
        att_out, attention, hidden = self.multi_head_self_attention_module(x, mask, hidden)
        x = self.att_res(x) + att_out

        # Convolution with its residual path
        conv_out = self.convolution_module(x)
        x = self.conv_res(x) + conv_out

        # Second feed forward, added with a 1/2 weight
        x = x + 1/2 * self.feed_forward_module2(x)

        # Final layer norm
        return self.norm(x), attention, hidden

#### Preprocessing

In [None]:
class AudioPreprocessing(nn.Module):

    """Audio Preprocessing
    Turns raw audio into log mel filter bank features
    Args:
        sample_rate: Audio sample rate
        n_fft: FFT frame size, creates n_fft // 2 + 1 frequency bins.
        win_length_ms: FFT window length in ms, must be <= n_fft
        hop_length_ms: length of hop between FFT windows in ms
        n_mels: number of mel filter banks
        normalize: whether to normalize mel spectrograms outputs
        mean: training mean
        std: training std
    Shape:
        Input: (batch_size, audio_len)
        Output: (batch_size, audio_len // hop_length + 1, n_mels)

    """

    def __init__(self,
                 sample_rate,
                 n_fft,
                 win_length_ms,
                 hop_length_ms,
                 n_mels,
                 normalize,
                 mean,
                 std):
        super().__init__()

        # Window and hop sizes converted from milliseconds to samples
        self.win_length = int(sample_rate * win_length_ms) // 1000
        self.hop_length = int(sample_rate * hop_length_ms) // 1000

        # Spectrogram followed by a mel filter bank limited to 0 - 8000 Hz
        self.Spectrogram = torchaudio.transforms.Spectrogram(n_fft, self.win_length, self.hop_length)
        self.MelScale = torchaudio.transforms.MelScale(n_mels, sample_rate, f_min=0, f_max=8000, n_stft=n_fft // 2 + 1)

        # Normalisation settings
        self.normalize = normalize
        self.mean = mean
        self.std = std

    def forward(self, x, x_len):
        """Return the (B, frames, n_mels) log mel features and the frame lengths.

        x_len may be None, in which case None is returned for the lengths.
        """

        # Short Time Fourier Transform (B, T) -> (B, n_fft // 2 + 1, T // hop_length + 1)
        spectrogram = self.Spectrogram(x)

        # Mel Scale (B, n_fft // 2 + 1, T // hop_length + 1) -> (B, n_mels, T // hop_length + 1)
        mel = self.MelScale(spectrogram)

        # Log energy computed in float32, then cast back to the incoming dtype
        features = torch.log(mel.float() + 1e-9).to(mel.dtype)

        # Lengths in frames: one frame per hop plus one
        if x_len is not None:
            x_len = torch.div(x_len, self.hop_length, rounding_mode='floor') + 1

        # Optional normalisation with the stored statistics
        if self.normalize:
            features = (features - self.mean) / self.std

        # Time-major layout (B, n_mels, frames) -> (B, frames, n_mels)
        return features.transpose(1, 2), x_len

class SpecAugment(nn.Module):

    """Spectrogram Augmentation
    Applies random frequency and time masks to a batch of spectrograms.
    The time masks are restricted to the first x_len[b] positions of the
    last axis of each example, so the last axis is treated as time.
    Args:
        spec_augment: whether to apply spec augment
        mF: number of frequency masks
        F: maximum frequency mask size
        mT: number of time masks
        pS: adaptive maximum time mask size in %
    References:
        SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition, Park et al.
        https://arxiv.org/abs/1904.08779
        SpecAugment on Large Scale Datasets, Park et al.
        https://arxiv.org/abs/1912.05533
    """

    def __init__(self, spec_augment, mF, F, mT, pS):
        super().__init__()
        self.spec_augment = spec_augment
        self.mF, self.F = mF, F
        self.mT, self.pS = mT, pS

    def forward(self, x, x_len):
        """Return x with mF frequency masks and mT time masks per example applied."""

        # Augmentation switched off: hand the input back untouched
        if not self.spec_augment:
            return x

        # Frequency masks, the same mask is shared by the whole batch
        for _ in range(self.mF):
            x = torchaudio.transforms.FrequencyMasking(freq_mask_param=self.F, iid_masks=False).forward(x)

        # Time masks, drawn per example inside its valid length and written in place
        for index in range(x.size(0)):
            length = x_len[index]
            # Maximum mask width grows with the length of the example
            max_width = int(self.pS * length)
            for _ in range(self.mT):
                masked = torchaudio.transforms.TimeMasking(time_mask_param=max_width).forward(x[index, :, :length])
                x[index, :, :length] = masked

        return x

#### Encoder

In [None]:
class ConformerEncoder(nn.Module):
    """
    Conformer encoder: audio preprocessing, optional SpecAugment, convolution
    subsampling, a linear input projection and a stack of conformer blocks.
    Args:
        preprocessing_params (dict): settings read by AudioPreprocessing ("sample_rate", "n_fft",
            "win_length_ms", "hop_length_ms", "n_mels", "normalize", "mean", "std") and by
            SpecAugment ("spec_augment", "mF", "F", "mT", "pS")
        input_dim (int): Dimension of input vector (not used, the feature size comes from "n_mels")
        encoder_dim (int): Dimension of conformer encoder
        num_layers (int, optional): Number of conformer blocks
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Accepted for interface compatibility, not used
        input_dropout_p (float, optional): Probability of the dropout after the input projection
        feed_forward_dropout_p (float, optional): Accepted for interface compatibility, not used
        attention_dropout_p (float, optional): Accepted for interface compatibility, not used
        conv_dropout_p (float, optional): Accepted for interface compatibility, not used
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Accepted for interface compatibility, not used
        decoder_dim (int, optional): Accepted for interface compatibility, not used
    Inputs: inputs, input_lengths
        - **inputs**: Tensor handed to AudioPreprocessing, documented there as (batch, audio_len)
        - **input_lengths** (batch): sequence input lengths
    Returns: outputs, output_lengths
        - **outputs** (batch, time, encoder_dim): Tensor produced by the conformer blocks
        - **output_lengths** (batch): sequence output lengths after subsampling
    """
    def __init__(
            self,
            preprocessing_params,
            input_dim: int,
            encoder_dim: int,
            num_layers: int = 17,
            num_attention_heads: int = 8,
            feed_forward_expansion_factor: int = 4,
            conv_expansion_factor: int = 2,
            input_dropout_p: float = 0.1,
            feed_forward_dropout_p: float = 0.1,
            attention_dropout_p: float = 0.1,
            conv_dropout_p: float = 0.1,
            conv_kernel_size: int = 31,
            half_step_residual: bool = True,
            decoder_dim:int = 320,
    ):
        super(ConformerEncoder, self).__init__()

        # Audio Preprocessing
        self.preprocessing = AudioPreprocessing(
            preprocessing_params["sample_rate"],
            preprocessing_params["n_fft"],
            preprocessing_params["win_length_ms"],
            preprocessing_params["hop_length_ms"],
            preprocessing_params["n_mels"],
            preprocessing_params["normalize"],
            preprocessing_params["mean"],
            preprocessing_params["std"],
        )

        # Spec Augment
        self.augment = SpecAugment(
            preprocessing_params["spec_augment"],
            preprocessing_params["mF"],
            preprocessing_params["F"],
            preprocessing_params["mT"],
            preprocessing_params["pS"],
        )

        self.conv_subsample = Conv2dSubampling(in_channels=1,
                                               out_channels=encoder_dim,
                                               num_conv_layers=2)

        # The projection input size assumes the mel axis is divided by 2**2
        # by the subsampling and then flattened with the channels
        self.input_projection = nn.Sequential(
            Linear(encoder_dim * (preprocessing_params["n_mels"] // 2**2), encoder_dim),
            nn.Dropout(p=input_dropout_p),
        )

        # NOTE(review): the blocks use a fixed Pdrop of 0.1, the per-module
        # dropout arguments of this constructor are not forwarded
        self.layers = nn.ModuleList([ConformerBlock(
            dim_model=encoder_dim,
            dim_expand=encoder_dim,
            ff_ratio=feed_forward_expansion_factor,
            num_heads=num_attention_heads,
            kernel_size=conv_kernel_size,
            att_group_size=1,
            att_kernel_size=None,
            linear_att=False,
            Pdrop=0.1,
            relative_pos_enc=True,
            max_pos_encoding=10000,
            conv_stride=1,
            att_stride=1,
        ) for _ in range(num_layers)])

        # self.linear = nn.Linear(encoder_dim, decoder_dim)

    def count_parameters(self) -> int:
        """ Count parameters of encoder """
        # numel must be called: summing the bound methods raises a TypeError
        return sum(p.numel() for p in self.parameters())

    def update_dropout(self, dropout_p: float) -> None:
        """ Update dropout probability of encoder """
        # The dropout layers are nested inside input_projection and the blocks,
        # so every descendant module is visited, not only the direct children
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_p

    def forward(self, inputs: Tensor, input_lengths: Tensor, training=False) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate a `inputs` for  encoder training.
        Args:
            inputs (torch.FloatTensor): A input sequence passed to the audio preprocessing,
                documented there as a padded `FloatTensor` of size ``(batch, audio_len)``.
            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
            training (bool): whether to run SpecAugment on the features
        Returns:
            (Tensor, Tensor)
            * outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            * output_lengths (torch.LongTensor): The length of output tensor. ``(batch)``
        """

        # Audio Preprocessing, features come back as (B, T, n_mels)
        inputs, input_lengths = self.preprocessing(inputs, input_lengths)

        # Spec Augment
        if training:
            # SpecAugment slices the last axis by the sequence lengths, i.e. it
            # works on (B, n_mels, T): swap the axes for the call and swap back,
            # otherwise the masks land on the wrong axes
            inputs = self.augment(inputs.transpose(1, 2), input_lengths).transpose(1, 2)

        outputs, output_lengths = self.conv_subsample(inputs, input_lengths)
        outputs = self.input_projection(outputs)

        # NOTE(review): no padding mask is passed to the blocks, padded frames
        # take part in the attention
        for layer in self.layers:
            outputs, _, _ = layer(outputs)

        # outputs = self.linear(outputs)
        return outputs, output_lengths

#### Decoder

In [None]:
class DecoderRNNT(nn.Module):
    """
    Decoder (prediction network) of RNN-Transducer: embedding, RNN and output projection.
    Args:
        num_classes (int): number of classification
        hidden_state_dim (int): embedding size and hidden state dimension of the RNN
        output_dim (int): output dimension of the final projection
        num_layers (int): number of decoder layers
        rnn_type (str, optional): type of rnn cell, one of 'lstm', 'gru', 'rnn' (default: lstm)
        sos_id (int, optional): start of sentence identification (accepted, not stored)
        eos_id (int, optional): end of sentence identification (accepted, not stored)
        dropout_p (float, optional): dropout probability of decoder (accepted, not applied)
    Inputs: inputs, input_lengths, hidden_states
        inputs (torch.LongTensor): A target sequence passed to decoder. `IntTensor` of size ``(batch, seq_length)``
        input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
        hidden_states: A previous state of the RNN, in the format of the selected RNN type
    Returns:
        (Tensor, Tensor):
        * decoder_outputs (torch.FloatTensor): A output sequence of decoder. `FloatTensor` of size
            ``(batch, seq_length, output_dim)``
        * hidden_states: The new state returned by the RNN
    """
    supported_rnns = {
        'lstm': nn.LSTM,
        'gru': nn.GRU,
        'rnn': nn.RNN,
    }

    def __init__(
            self,
            num_classes: int,
            hidden_state_dim: int,
            output_dim: int,
            num_layers: int,
            rnn_type: str = 'lstm',
            sos_id: int = 1,
            eos_id: int = 2,
            dropout_p: float = 0.2,
    ):
        super(DecoderRNNT, self).__init__()

        # Index 0 is the padding entry of the embedding table
        self.embedding = nn.Embedding(num_classes, hidden_state_dim, padding_idx=0)

        # Unidirectional RNN without dropout between layers (dropout_p is not forwarded)
        rnn_cell = self.supported_rnns[rnn_type.lower()]
        self.rnn = rnn_cell(
            input_size=hidden_state_dim,
            hidden_size=hidden_state_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.out_proj = Linear(hidden_state_dim, output_dim)

    def count_parameters(self) -> int:
        """ Count parameters of decoder """
        # numel must be called: summing the bound methods raises a TypeError
        return sum(p.numel() for p in self.parameters())

    def update_dropout(self, dropout_p: float) -> None:
        """ Update dropout probability of every nn.Dropout layer of the decoder """
        # Visit all descendants so that nested dropout layers are reached too
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_p

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Tensor = None,
            hidden_states: Tensor = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward propage a `inputs` (targets) for training.
        Args:
            inputs (torch.LongTensor): A target sequence passed to decoder. `IntTensor` of size ``(batch, seq_length)``
            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``. When given,
                the sequences are packed so the RNN skips the padded steps
            hidden_states: A previous state of the RNN, in the format of the selected RNN type
        Returns:
            (Tensor, Tensor):
            * decoder_outputs (torch.FloatTensor): A output sequence of decoder. `FloatTensor` of size
                ``(batch, seq_length, output_dim)``
            * hidden_states: The new state returned by the RNN
        """

        # Sequence Embedding (B, U + 1) -> (B, U + 1, D)
        embedded = self.embedding(inputs)

        if input_lengths is not None:
            # Lengths must live on the CPU for packing, the batch needs no sorting
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            outputs, hidden_states = self.rnn(packed, hidden_states)
            outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        else:
            outputs, hidden_states = self.rnn(embedded, hidden_states)

        outputs = self.out_proj(outputs)

        return outputs, hidden_states

#### Conformer Model

In [None]:
class Conformer(nn.Module):
    """
    Conformer: Convolution-augmented Transformer for Speech Recognition
    Transducer model made of a conformer encoder, an RNN-T decoder and a joint
    network that sums both outputs, applies tanh and projects to the classes.
    Args:
        hparams: HParams
    Inputs: inputs, input_lengths, targets, target_lengths
        - **inputs**: Tensor passed to the encoder
        - **input_lengths** (batch): sequence input lengths
        - **targets** (batch, target_length): target token ids
        - **target_lengths** (batch): target lengths
    Returns: outputs, output_lengths
        - **outputs** (batch, time, target_length + 1, num_classes): joint network outputs
        - **output_lengths** (batch): encoder output lengths
    """
    def __init__(
            self,
            hparams
    ) -> None:
        super(Conformer, self).__init__()
        self.encoder = ConformerEncoder(
            preprocessing_params=hparams.preprocessing_params,
            input_dim=hparams.input_dim,
            encoder_dim=hparams.encoder_dim,
            num_layers=hparams.num_encoder_layers,
            num_attention_heads=hparams.num_attention_heads,
            feed_forward_expansion_factor=hparams.feed_forward_expansion_factor,
            conv_expansion_factor=hparams.conv_expansion_factor,
            input_dropout_p=hparams.input_dropout_p,
            feed_forward_dropout_p=hparams.feed_forward_dropout_p,
            attention_dropout_p=hparams.attention_dropout_p,
            conv_dropout_p=hparams.conv_dropout_p,
            conv_kernel_size=hparams.conv_kernel_size,
            half_step_residual=hparams.half_step_residual,
        )
        # The decoder projects to encoder_dim so both outputs can be summed in joint()
        self.decoder = DecoderRNNT(
            num_classes=hparams.num_classes,
            hidden_state_dim=hparams.decoder_dim,
            output_dim=hparams.encoder_dim,
            # output_dim=hparams.decoder_dim,
            num_layers=hparams.num_decoder_layers,
            rnn_type=hparams.decoder_rnn_type,
            dropout_p=hparams.decoder_dropout_p,
        )

        # Joint input size: hparams.encoder_dim * 2 if the outputs are
        # concatenated, hparams.encoder_dim if they are summed (current choice)
        self.act_fn = nn.Tanh()
        # self.fc = Linear(hparams.encoder_dim * 2, hparams.num_classes, bias=False)
        # self.fc = Linear(hparams.decoder_dim, hparams.num_classes, bias=False)
        self.fc = Linear(hparams.encoder_dim, hparams.num_classes, bias=False)

    def set_encoder(self, encoder):
        """ Setter for encoder """
        self.encoder = encoder

    def set_decoder(self, decoder):
        """ Setter for decoder """
        self.decoder = decoder

    def count_parameters(self) -> int:
        """ Count parameters of the whole model (encoder, decoder and joint layer) """
        # Counted directly on the registered parameters so the result does not
        # depend on the sub-modules providing their own count_parameters
        return sum(p.numel() for p in self.parameters())

    def update_dropout(self, dropout_p) -> None:
        """ Update dropout probability of every nn.Dropout layer of the model """
        # Visit all descendants: the dropout layers are nested inside the
        # encoder blocks and are not direct children of encoder or decoder
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_p

    def joint(self, encoder_outputs: Tensor, decoder_outputs: Tensor) -> Tensor:
        """
        Joint `encoder_outputs` and `decoder_outputs`.
        Args:
            encoder_outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            decoder_outputs (torch.FloatTensor): A output sequence of decoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
        Returns:
            * outputs (torch.FloatTensor): outputs of joint `encoder_outputs` and `decoder_outputs`,
              of size ``(batch, input_length, target_length, num_classes)`` for 3-D inputs.
              Inputs of any other rank are summed as they are.
        """
        if encoder_outputs.dim() == 3 and decoder_outputs.dim() == 3:
            # (B, T, 1, D) + (B, 1, U, D) broadcasts to (B, T, U, D), giving the
            # same values as repeating both tensors without materialising the copies
            encoder_outputs = encoder_outputs.unsqueeze(2)
            decoder_outputs = decoder_outputs.unsqueeze(1)

        # The outputs are summed and passed through tanh, concatenation is the alternative
        # outputs = torch.cat((encoder_outputs, decoder_outputs), dim=-1)
        outputs = encoder_outputs + decoder_outputs
        outputs = self.act_fn(outputs)
        outputs = self.fc(outputs)

        return outputs

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Tensor,
            targets: Tensor,
            target_lengths: Tensor,
            training = False
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate a `inputs` and `targets` pair for training.
        Args:
            inputs (torch.FloatTensor): A input sequence passed to encoder.
            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
            targets (torch.LongTensor): A target sequence passed to decoder. `IntTensor` of size ``(batch, seq_length)``
            target_lengths (torch.LongTensor): The length of target tensor. ``(batch)``
            training (bool): forwarded to the encoder, enables SpecAugment there
        Returns:
            (Tensor, Tensor)
            * predictions (torch.FloatTensor): Result of model predictions.
            * encoder_output_len (torch.LongTensor): The length of the encoder outputs. ``(batch)``
        """
        encoder_outputs, encoder_output_len = self.encoder(inputs, input_lengths, training=training)

        # Prepend token 0 to every target sequence, lengths grow by one accordingly
        targets = torch.nn.functional.pad(targets, pad=(1, 0, 0, 0), value=0)
        target_lengths = target_lengths + 1

        decoder_outputs, _ = self.decoder(targets, target_lengths)

        # Encoder and decoder already project to the same dimension
        outputs = self.joint(encoder_outputs, decoder_outputs)

        return outputs, encoder_output_len

    # @torch.no_grad()
    # def decode(self, encoder_output: Tensor, max_length: int) -> Tensor:
    #     """
    #     Decode `encoder_outputs`.
    #     Args:
    #         encoder_output (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
    #             ``(seq_length, dimension)``
    #         max_length (int): max decoding time step
    #     Returns:
    #         * predicted_log_probs (torch.FloatTensor): Log probability of model predictions.
    #     """
    #     pred_tokens, hidden_state = list(), None
    #     decoder_input = encoder_output.new_tensor([[self.decoder.sos_id]], dtype=torch.long)

    #     for t in range(max_length):
    #         decoder_output, hidden_state = self.decoder(decoder_input, hidden_states=hidden_state)
    #         step_output = self.joint(encoder_output[t].view(-1), decoder_output.view(-1))
    #         step_output = step_output.softmax(dim=0)
    #         pred_token = step_output.argmax(dim=0)
    #         pred_token = int(pred_token.item())
    #         pred_tokens.append(pred_token)
    #         decoder_input = step_output.new_tensor([[pred_token]], dtype=torch.long)

    #     return torch.LongTensor(pred_tokens)

    # @torch.no_grad()
    # def recognize(self, inputs: Tensor, input_lengths: Tensor):
    #     """
    #     Recognize input speech. This method consists of the forward of the encoder and the decode() of the decoder.
    #     Args:
    #         inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
    #             `FloatTensor` of size ``(batch, seq_length, dimension)``.
    #         input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
    #     Returns:
    #         * predictions (torch.FloatTensor): Result of model predictions.
    #     """
    #     outputs = list()

    #     encoder_outputs, output_lengths = self.encoder(inputs, input_lengths)
    #     max_length = encoder_outputs.size(1)

    #     for encoder_output in encoder_outputs:
    #         decoded_seq = self.decode(encoder_output, max_length)
    #         outputs.append(decoded_seq)

    #     outputs = torch.stack(outputs, dim=1).transpose(0, 1)

    #     return outputs

# Hyperparameters and Model initialization

In [None]:
class HParams:
    """
    HParams: Hyperparameters needed to initialize the Conformer model.

    Every setting is a class attribute, so it can be read from the class or
    from an instance. The two dictionaries are class-level objects shared by
    all instances: mutating one through an instance changes it everywhere.

    Attributes:
        num_classes (int): Number of classification classes
        input_dim (int): Dimension of input vector
        encoder_dim (int): Dimension of conformer encoder
        decoder_dim (int): Dimension of conformer decoder
        num_encoder_layers (int): Number of conformer blocks
        num_decoder_layers (int): Number of decoder layers
        decoder_rnn_type (str): type of RNN cell
        num_attention_heads (int): Number of attention heads
        feed_forward_expansion_factor (int): Expansion factor of feed forward module
        conv_expansion_factor (int): Expansion factor of conformer convolution module
        input_dropout_p (float): Probability of input dropout
            (presumably applied to the encoder input; confirm against the model)
        feed_forward_dropout_p (float): Probability of feed forward module dropout
        attention_dropout_p (float): Probability of attention module dropout
        conv_dropout_p (float): Probability of conformer convolution module dropout
        decoder_dropout_p (float): Probability of conformer decoder dropout
        conv_kernel_size (int): Size of the convolving kernel
        half_step_residual (bool): Flag indication whether to use half step residual or not
        preprocessing_params (dict): Audio feature extraction and augmentation settings
        training_params (dict): Optimizer, learning-rate and schedule settings
    """

    # Model architecture
    num_classes: int = 256
    input_dim: int = 80
    encoder_dim: int = 144
    decoder_dim: int = 320
    num_encoder_layers: int = 16
    num_decoder_layers: int = 1
    num_attention_heads: int = 4
    feed_forward_expansion_factor: int = 4
    conv_expansion_factor: int = 2
    input_dropout_p: float = 0.1
    feed_forward_dropout_p: float = 0.1
    attention_dropout_p: float = 0.1
    conv_dropout_p: float = 0.1
    decoder_dropout_p: float = 0.1
    conv_kernel_size: int = 31
    half_step_residual: bool = True
    decoder_rnn_type: str = "lstm"

    preprocessing_params: dict = {
        "sample_rate": 16000,
        "win_length_ms": 25,
        "hop_length_ms": 10,
        "n_fft": 512,
        # NOTE(review): presumably has to match input_dim (both are 80) - confirm
        "n_mels": 80,
        "normalize": False,
        "mean": -5.6501,
        "std": 4.2280,

        # NOTE(review): mF / F / mT / pS look like SpecAugment mask settings;
        # their exact meaning is defined by the preprocessing code - confirm there
        "spec_augment": True,
        "mF": 2,
        "F": 27,
        "mT": 5,
        "pS": 0.05,
    }

    training_params: dict = {
        "epochs": 20,
        # "batch_size": 16,
        "accumulated_steps": 4,

        # Adam settings
        "beta1": 0.9,
        "beta2": 0.98,
        "eps": 1e-9,
        "weight_decay": 1e-6,
        # Learning rate scales with the encoder width
        "lr": 0.05 / math.sqrt(encoder_dim),

        # Derived from encoder_dim instead of repeating the literal 144, so the
        # two cannot drift apart.
        # NOTE(review): assumes the schedule is scaled by the encoder width,
        # like "lr" above - confirm against the code that reads this key
        "schedule_dim": encoder_dim,
        "warmup_steps": 10000,
        "K": 2,
    }


params = HParams()

In [None]:
# Build the Conformer from the `params` hyperparameter object and move it to the GPU.
# NOTE(review): the device is hard-coded to "cuda", so this cell needs a GPU runtime.
model = Conformer(params).to("cuda")

In [None]:
# Restore the model weights from the checkpoint stored under PATH.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load(PATH + "Conformer.pth"))

# Training

In [None]:
# Adam optimizer, configured entirely from params.training_params.
optimizer = torch.optim.Adam(params=model.parameters(), 
                             lr=params.training_params["lr"], 
                             betas=(params.training_params["beta1"], params.training_params["beta2"]), 
                             eps=params.training_params["eps"], 
                             weight_decay=params.training_params["weight_decay"])

# Linear warmup followed by linear decay of the learning rate.
# NOTE(review): num_training_steps (500) is smaller than warmup_steps (10000).
# With this schedule the decay factor is clamped at 0 once warmup is over, so the
# learning rate would drop to zero after the warmup phase - confirm the intended
# total number of scheduler steps before training with this configuration.
scheduler = transformers.get_linear_schedule_with_warmup(optimizer=optimizer, 
                                                         num_warmup_steps=params.training_params["warmup_steps"],
                                                         num_training_steps=500, 
                                                         last_epoch = -1) 

# Optional resume of optimizer / scheduler state (disabled).
# NOTE(review): the file names say "Contextnet" although this notebook trains the
# Conformer - looks like a copy-paste leftover; verify before re-enabling.
#optimizer.load_state_dict(torch.load(PATH + "Contextnet_optimizer.pth"))
#scheduler.load_state_dict(torch.load(PATH + "Contextnet_scheduler.pth"))

In [None]:
# Gradient scaler from torch.cuda.amp, handed to the Trainer below.
scaler = torch.cuda.amp.GradScaler()

# Path prefix handed to the Trainer; the plotting cell later reads path + ".tsv".
path = PATH + "Conformer"

trainer = Trainer(model, optimizer, scheduler, scaler, path)

# Training is currently disabled; only the evaluation on the dev set is run.
#trainer.train(train_dataloader, dev_dataloader, epochs=15)
trainer.evaluate(dev_dataloader)

# Test

In [None]:
def test(test_dataset, path):
    """
    Evaluate the global ``model`` on ``test_dataset`` and log every batch to a CSV file.

    For each batch the function decodes a prediction with
    ``greedy_search_decoding``, decodes the reference transcripts with the
    global ``tokenizer``, computes the batch WER with ``jiwer`` and the RNN-T
    loss with ``warp_rnnt``, prints a progress line and appends one CSV row.

    Args:
        test_dataset: Sized iterable yielding ``(inputs, targets, input_len,
            target_len)`` batches (this notebook passes a ``DataLoader``).
        path (str): File the per-batch log is written to (overwritten).
            Row layout: ground truths, predictions, batch WER in percent,
            batch loss.

    Returns:
        tuple: ``(wer, speech_true, speech_pred, loss)``. ``wer`` is the WER
        over all collected transcripts, or ``1`` when the mean per-batch WER
        exceeds 1. ``speech_true`` / ``speech_pred`` are the accumulated
        reference / predicted transcripts. ``loss`` is the sum of the batch
        losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``test_dataset`` has length 0.
    """

    model.eval()

    speech_true = []
    speech_pred = []
    total_wer = 0.0
    total_loss = 0.0

    # newline="" is what the csv module expects for files it writes to
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)

        # Evaluation Loop
        for step, batch in enumerate(tqdm(test_dataset)):

            inputs, targets, input_len, target_len = batch

            # Sequence Prediction
            with torch.no_grad():
                outputs_pred = greedy_search_decoding(inputs, input_len)

            # Sequence Truth
            outputs_true = tokenizer.decode(targets.tolist())

            # Compute Batch wer and Update total wer
            batch_wer = jiwer.wer(outputs_true, outputs_pred, standardize=True)
            total_wer += batch_wer

            # Update String lists
            speech_true += outputs_true
            speech_pred += outputs_pred

            # Prediction Verbose
            print("Groundtruths :\n", outputs_true)
            print("Predictions :\n", outputs_pred)

            # Eval Loss
            with torch.no_grad():
                pred, pred_len = model.forward(inputs, input_len, targets, target_len)
                batch_loss = warp_rnnt.rnnt_loss(
                    log_probs=torch.nn.functional.log_softmax(pred, dim=-1),
                    labels=targets.int(),
                    frames_lengths=pred_len.int(),
                    labels_lengths=target_len.int(),
                    average_frames=False,
                    reduction='mean',
                    blank=0,
                    gather=True)
                total_loss += batch_loss

            # Plain number for logging: writing the tensor itself would put its
            # repr ("tensor(..., device=...)") into the CSV instead of a value
            batch_loss_value = float(batch_loss)

            # Step print
            print("\nmean batch wer {:.2f}% - batch wer: {:.2f}% - mean loss {:.4f} - batch loss: {:.4f}".format(100 * total_wer / (step + 1), 100 * batch_wer, total_loss / (step + 1), batch_loss_value))

            writer.writerow([outputs_true, outputs_pred, 100 * batch_wer, batch_loss_value])

    num_batches = len(test_dataset)

    # Compute wer: report 1 when the mean per-batch WER is above 100%,
    # otherwise compute the WER over all collected transcripts at once
    if total_wer / num_batches > 1:
        wer = 1
    else:
        wer = jiwer.wer(speech_true, speech_pred, standardize=True)

    # Compute loss
    loss = total_loss / num_batches

    return wer, speech_true, speech_pred, loss

In [None]:
# Test splits, tokenized with the same tokenizer as the rest of the notebook.
test_clean_dataset = LibriSpeechDataset("test-clean", tokenizer)
test_other_dataset = LibriSpeechDataset("test-other", tokenizer)

# drop_last=False: an evaluation has to score every utterance. With
# drop_last=True the final incomplete batch would be silently skipped and the
# reported WER / loss would be computed on a truncated test set.
test_clean_dataloader = torch.utils.data.DataLoader(test_clean_dataset,
                                                    batch_size=4,
                                                    shuffle=False,
                                                    collate_fn=collate_fn,
                                                    drop_last=False)

test_other_dataloader = torch.utils.data.DataLoader(test_other_dataset,
                                                    batch_size=4,
                                                    shuffle=False,
                                                    collate_fn=collate_fn,
                                                    drop_last=False)

In [None]:
# Evaluate on the test-clean dataloader; per-batch results are written to the CSV below.
path_test_clean = PATH + "Conformer_test_clean.csv"
# Only the overall WER and the mean loss are kept; the transcript lists are discarded.
wer_clean, _, _, loss_clean = test(test_clean_dataloader, path_test_clean)
print(wer_clean, loss_clean)

In [None]:
# Evaluate on the test-other dataloader; per-batch results are written to the CSV below.
path_test_other = PATH + "Conformer_test_other.csv"
# Stored under their own names so the test-clean results are not overwritten.
wer_other, _, _, loss_other = test(test_other_dataloader, path_test_other)
print(wer_other, loss_other)

In [None]:
import csv


def _read_tsv_column(tsv_path, column=1, skip_rows=0):
    """Return one column of a tab-separated file as a list of floats.

    Args:
        tsv_path (str): File to read.
        column (int): Index of the column to extract.
        skip_rows (int): Number of leading rows to ignore.

    Returns:
        list: ``float(row[column])`` for every row after the skipped ones.
    """
    with open(tsv_path, "r", newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        return [float(row[column]) for i, row in enumerate(reader) if i >= skip_rows]


# Curve logged under `path` (the Conformer run); the first 4 rows are skipped.
p = _read_tsv_column(path + ".tsv", skip_rows=4)
print(p)
plt.plot(p, label="Conformer")

path2 = PATH + "Contextnet"

# Curve logged for the Contextnet run, all rows.
p = _read_tsv_column(path2 + ".tsv")
print(p)
plt.plot(p, label="Contextnet")

plt.xlabel("Epochs")
plt.ylabel("WER")
# Without a legend the two curves cannot be told apart
plt.legend()
plt.show()