# HW4P2: Attention-based Speech Recognition

<img src="https://cdn.shopify.com/s/files/1/0272/2080/3722/products/SmileBumperSticker_5400x.jpg" alt="A cute cat" width="600">


Welcome to the final assignment in 11785. In this HW, you will work on building a speech recognition system with <i>attention</i>. <br> <br>

<center>
<img src="https://popmn.org/wp-content/uploads/2020/03/pay-attention.jpg" alt="A cute cat" height="100">
</center>

HW Writeup: [TODO] <br>
Kaggle Competition Link: https://www.kaggle.com/competitions/attention-based-speech-recognition <br>
Kaggle Dataset Link: https://www.kaggle.com/competitions/attention-based-speech-recognition/data
<br>
LAS Paper: https://arxiv.org/pdf/1508.01211.pdf <br>
Attention is all you need:https://arxiv.org/pdf/1706.03762.pdf

# Read this section importantly!

1. By now, we believe that you are already a great deep learning practitioner, Congratulations. 🎉

2. You are allowed to use code from your previous homeworks for this homework. We will only provide, aspects that are necessary and new with this homework.

3. There are a lot of resources provided in this notebook, that will help you check if you are running your implementations correctly.

In [1]:
!nvidia-smi

Thu Sep  4 10:17:46 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.153                Driver Version: 573.26         CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5060 ...    On  |   00000000:01:00.0  On |                  N/A |
| N/A   46C    P8             12W /   65W |     286MiB /   8151MiB |     22%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
# Install some required libraries
# Feel free to add more if you want
!pip install -q python-levenshtein torchsummaryX wandb kaggle pytorch-nlp

In [31]:
import torch
import torchaudio
from torch import nn, Tensor
# import torchsummary

import numpy as np
import os

import gc
import time

import pandas as pd
from tqdm.notebook import tqdm as blue_tqdm
import matplotlib.pyplot as plt
import seaborn
import json
from tqdm import tqdm

import math
import random
from typing import Optional, List

#imports for decoding and distance calculation
try:
    import wandb
    import torchsummaryX
    import Levenshtein
except:
    print("Didnt install some/all imports")

import warnings
warnings.filterwarnings('ignore')

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

Device:  cuda


# Config

In [2]:
config = dict (
    train_dataset       = 'train-clean-360', # train-clean-100, train-clean-360, train-clean-460
    batch_size          = 96,
    epochs              = 100,
    learning_rate       = 2e-4,
    weight_decay        = 5e-3,
    cepstral_norm            = True, # Whether to use MFCC features or Spectrogram
)

# Kaggle Dataset Download

In [None]:
!pip install --upgrade --force-reinstall --no-deps kaggle==1.5.8
!mkdir /root/.kaggle

with open("/root/.kaggle/kaggle.json", "w+") as f:
    f.write('') # Put your kaggle username & key here

!chmod 600 /root/.kaggle/kaggle.json

In [2]:
# # to download the dataset
!kaggle competitions download -c attention-based-speech-recognition

# # to unzip data quickly and quietly
!unzip -q attention-based-speech-recognition.zip -d ./data

^C
unzip:  cannot find or open attention-based-speech-recognition.zip, attention-based-speech-recognition.zip.zip or attention-based-speech-recognition.zip.ZIP.


# Character-based LibriSpeech (HW4P2)

In terms of the dataset, the dataset structure for HW3P2 and HW4P2 dataset are very similar. Can you spot out the differences? What all will be required??

Hints:

- Check how big is the dataset (do you require memory efficient loading techniques??)
- How do we load mfccs? Do we need to normalise them?
- Does the data have \<SOS> and \<EOS> tokens in each sequences? Do we remove them or do we not remove them? (Read writeup)
- Would we want a collating function? Ask yourself: Why did we need a collate function last time?
- Observe the VOCAB, is the dataset same as HW3P2?
- Should you add augmentations, if yes which augmentations? When should you add augmentations? (Check bootcamp for answer)


In [3]:
VOCAB = [
    '<pad>', '<sos>', '<eos>',
    'A',   'B',    'C',    'D',
    'E',   'F',    'G',    'H',
    'I',   'J',    'K',    'L',
    'M',   'N',    'O',    'P',
    'Q',   'R',    'S',    'T',
    'U',   'V',    'W',    'X',
    'Y',   'Z',    "'",    ' ',
]

VOCAB_MAP = {VOCAB[i]:i for i in range(0, len(VOCAB))}

PAD_TOKEN = VOCAB_MAP["<pad>"]
SOS_TOKEN = VOCAB_MAP["<sos>"]
EOS_TOKEN = VOCAB_MAP["<eos>"]

print(f"Length of vocab : {len(VOCAB)}")
print(f"Vocab           : {VOCAB}")
print(f"PAD_TOKEN       : {PAD_TOKEN}")
print(f"SOS_TOKEN       : {SOS_TOKEN}")
print(f"EOS_TOKEN       : {EOS_TOKEN}")

Length of vocab : 31
Vocab           : ['<pad>', '<sos>', '<eos>', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', "'", ' ']
PAD_TOKEN       : 0
SOS_TOKEN       : 1
EOS_TOKEN       : 2


In [4]:
class SpeechDataset(torch.utils.data.Dataset): # Memory efficient
    # Loades the data in get item to save RAM

    def __init__(self, root, partition= "train-clean-360", transforms = None, cepstral=True):

        self.VOCAB      = VOCAB
        self.cepstral   = cepstral

        if partition == "train-clean-100" or partition == "train-clean-360":
            mfcc_dir       = os.path.join(root, partition, "mfcc")
            transcript_dir = os.path.join(root, partition, "transcripts")

            mfcc_files          = [os.path.join(mfcc_dir, f) for f in os.listdir(mfcc_dir)]
            transcript_files    = [os.path.join(transcript_dir, f) for f in os.listdir(transcript_dir)]

        else:
            mfcc_dir       = os.path.join(root, "train-clean-100", "mfcc")
            transcript_dir = os.path.join(root, "train-clean-100", "transcripts")

            mfcc_files          = [os.path.join(mfcc_dir, f) for f in os.listdir(mfcc_dir)]
            transcript_files    = [os.path.join(transcript_dir, f) for f in os.listdir(transcript_dir)]

            mfcc_dir       = os.path.join(root, "train-clean-360", "mfcc")
            transcript_dir = os.path.join(root, "train-clean-360", "transcripts")

            # add the list of mfcc and transcript paths from train-clean-360 to the list of paths  from train-clean-100
            mfcc_files.extend([os.path.join(mfcc_dir, f) for f in os.listdir(mfcc_dir)])
            transcript_files.extend([os.path.join(transcript_dir, f) for f in os.listdir(transcript_dir)])

        assert len(mfcc_files) == len(transcript_files)
        length = len(mfcc_files)

        self.mfcc_files         = mfcc_files
        self.transcript_files   = transcript_files
        self.length             = len(transcript_files)
        print("Loaded file paths ME: ", partition)


    def __len__(self):
        return self.length

    def __getitem__(self, ind):

        # Load the mfcc and transcripts from the mfcc and transcript paths created earlier
        mfcc        = np.load(self.mfcc_files[ind])
        transcript  = np.load(self.transcript_files[ind])
        # Normalize the mfccs and map the transcripts to integers
        mfcc                = (mfcc - np.mean(mfcc)) / np.std(mfcc)
        transcript_mapped   = [VOCAB_MAP[c] for c in transcript]

        return torch.FloatTensor(mfcc), torch.LongTensor(transcript_mapped)

    def collate_fn(self,batch):

        batch_x, batch_y, lengths_x, lengths_y = [], [], [], []

        for x, y in batch:
            batch_x.append(x)
            batch_y.append(y)

            # Add the mfcc, transcripts and their lengths to the lists created above
            lengths_x.append(x.size(0))
            lengths_y.append(y.size(0))

        # pack the mfccs and transcripts using the pad_sequence function from pytorch
        batch_x_pad = nn.utils.rnn.pad_sequence(batch_x, batch_first=True)
        batch_y_pad = nn.utils.rnn.pad_sequence(batch_y, batch_first=True)

        return batch_x_pad, batch_y_pad, torch.tensor(lengths_x), torch.tensor(lengths_y)


In [5]:
class SpeechDatasetTest(torch.utils.data.Dataset):

    def __init__(self, root, partition, cepstral=False):
        from tqdm import tqdm
        self.mfcc_dir   = os.path.join(root, partition, "mfcc") # path to the test-clean mfccs
        self.mfcc_files = [os.path.join(self.mfcc_dir, f) for f in os.listdir(self.mfcc_dir)]

        self.mfccs = []
        for i, filename in enumerate(tqdm(self.mfcc_files)):
            mfcc = np.load(filename)
            if cepstral:
                # Normalize the mfccs
                mfcc = (mfcc - np.mean(mfcc)) / np.std(mfcc)
            # append the mfcc to the mfcc list created earlier
            self.mfccs.append(mfcc)


        print("Loaded: ", partition)

    def __len__(self):
        return len(self.mfccs)

    def __getitem__(self, ind):
        mfcc = self.mfccs[ind]
        return torch.FloatTensor(mfcc)

    def collate_fn(self,batch):

        batch_x, lengths_x = [], []
        for x in batch:
            # Append the mfccs and their lengths to the lists created above
            batch_x.append(x)
            lengths_x.append(x.size(0))
        # pack the mfccs using the pad_sequence function from pytorch
        batch_x_pad = nn.utils.rnn.pad_sequence(batch_x, batch_first=True)

        return batch_x_pad, torch.tensor(lengths_x)

In [6]:
DATA_DIR        = 'data'
PARTITION       = config['train_dataset']
CEPSTRAL        = config['cepstral_norm']

train_dataset   = SpeechDataset( # Or AudioDatasetME
    root        = DATA_DIR,
    partition   = PARTITION,
    cepstral    = CEPSTRAL
)
valid_dataset   = SpeechDataset(
    root        = DATA_DIR,
    partition   = 'dev-clean',
    cepstral    = CEPSTRAL
)
test_dataset    = SpeechDatasetTest(
    root        = DATA_DIR,
    partition   = 'test-clean',
    cepstral    = CEPSTRAL,
)

gc.collect()

Loaded file paths ME:  train-clean-360
Loaded file paths ME:  dev-clean


100%|██████████| 2620/2620 [00:01<00:00, 2298.68it/s]

Loaded:  test-clean





253

In [7]:
train_loader    = torch.utils.data.DataLoader(
    dataset     = train_dataset,
    batch_size  = config['batch_size'],
    shuffle     = True,
    num_workers = 0,
    pin_memory  = True,
    collate_fn  = train_dataset.collate_fn
)

valid_loader    = torch.utils.data.DataLoader(
    dataset     = valid_dataset,
    batch_size  = config['batch_size'],
    shuffle     = False,
    num_workers = 0,
    pin_memory  = True,
    collate_fn  = valid_dataset.collate_fn
)

test_loader     = torch.utils.data.DataLoader(
    dataset     = test_dataset,
    batch_size  = config['batch_size'],
    shuffle     = False,
    num_workers = 0,
    pin_memory  = True,
    collate_fn  = test_dataset.collate_fn
)

print("No. of train mfccs   : ", train_dataset.__len__())
print("Batch size           : ", config['batch_size'])
print("Train batches        : ", train_loader.__len__())
print("Valid batches        : ", valid_loader.__len__())
print("Test batches         : ", test_loader.__len__())

No. of train mfccs   :  104013
Batch size           :  96
Train batches        :  1084
Valid batches        :  1381
Test batches         :  28


In [8]:
print("\nChecking the shapes of the data...")
for batch in train_loader:
    x, y, x_len, y_len = batch
    print(x.shape, y.shape, x_len.shape, y_len.shape)
    print(y)
    break


Checking the shapes of the data...
torch.Size([96, 1635, 28]) torch.Size([96, 292]) torch.Size([96]) torch.Size([96])
tensor([[ 1, 16,  7,  ...,  0,  0,  0],
        [ 1, 21,  7,  ...,  0,  0,  0],
        [ 1, 21, 11,  ...,  0,  0,  0],
        ...,
        [ 1, 25,  3,  ...,  0,  0,  0],
        [ 1,  3, 16,  ...,  0,  0,  0],
        [ 1, 21, 22,  ...,  0,  0,  0]])


In [None]:
print("\nChecking the shapes of the data...")
for batch in train_loader:
    x, y, x_len, y_len = batch
    print(x.shape, y.shape, x_len.shape, y_len.shape)
    print(y)
    break


Checking the shapes of the data...
torch.Size([96, 1616, 28]) torch.Size([96, 264]) torch.Size([96]) torch.Size([96])
tensor([[ 1,  3, 16,  ...,  0,  0,  0],
        [ 1, 25, 10,  ...,  0,  0,  0],
        [ 1, 22, 10,  ...,  0,  0,  0],
        ...,
        [ 1, 22, 10,  ...,  0,  0,  0],
        [ 1, 10, 11,  ...,  0,  0,  0],
        [ 1,  3, 30,  ...,  0,  0,  0]])


In [60]:
def verify_dataset(dataset, partition= 'train-clean-100'):
    print("\nPartition loaded     : ", partition)
    if partition != 'test-clean':
        print("Max mfcc length          : ", np.max([data[0].shape[0] for data in dataset]))
        print("Avg mfcc length          : ", np.mean([data[0].shape[0] for data in dataset]))
        print("Max transcript length    : ", np.max([data[1].shape[0] for data in dataset]))
        print("Max transcript length    : ", np.mean([data[1].shape[0] for data in dataset]))
    else:
        print("Max mfcc length          : ", np.max([data.shape[0] for data in dataset]))
        print("Avg mfcc length          : ", np.mean([data.shape[0] for data in dataset]))

verify_dataset(train_dataset, partition= 'train-clean-100')
verify_dataset(valid_dataset, partition= 'dev-clean')
verify_dataset(test_dataset, partition= 'test-clean')
dataset_max_len  = max(
    np.max([data[0].shape[0] for data in train_dataset]),
    np.max([data[0].shape[0] for data in valid_dataset]),
    np.max([data.shape[0] for data in test_dataset])
)
print("\nMax Length: ", dataset_max_len)


Partition loaded     :  train-clean-100


KeyboardInterrupt: 

Check if you are loading the data correctly with the following:

- Train Dataset
```
Partition loaded:  train-clean-100
Max mfcc length:  2448
Average mfcc length:  1264.6258453344547
Max transcript:  400
Average transcript length:  186.65321139493324
```

- Dev Dataset
```
Partition loaded:  dev-clean
Max mfcc length:  3260
Average mfcc length:  713.3570107288198
Max transcript:  518
Average transcript length:  108.71698113207547
```

- Test Dataset
```
Partition loaded:  test-clean
Max mfcc length:  3491
Average mfcc length:  738.2206106870229
```

If your values is not matching, read hints, think what could have gone wrong. Then approach TAs.

# THE MODEL

### Listen, Attend and Spell
Listen, Attend and Spell (LAS) is a neural network model used for speech recognition and synthesis tasks.

- LAS is designed to handle long input sequences and is robust to noisy speech signals.
- LAS is known for its high accuracy and ability to improve over time with additional training data.
- It consists of an <b>listener, an attender and a speller</b>, which work together to convert an input speech signal into a corresponding output text.

#### The Dataflow:
<center>
<img src="https://github.com/varunjain3/11785_s23_h4p2/raw/main/DataFlow.png" alt="data flow" height="100">
</center>

#### The Listener:
- converts the input speech signal into a sequence of hidden states.

#### The Attender:
- Decides how the sequence of Encoder hidden state is propogated to decoder.

#### The Speller:
- A language model, that incorporates the "context of attender"(output of attender) to predict sequence of words.






## Utils


In [9]:
class PermuteBlock(torch.nn.Module):
    def forward(self, x):
        return x.transpose(1, 2)

def plot_attention(attention):
    # Function for plotting attention
    # You need to get a diagonal plot
    plt.clf()
    seaborn.heatmap(attention, cmap='GnBu')
    plt.show()

def save_model(model, optimizer, scheduler, tf_scheduler, metric, epoch, path):
    torch.save(
        {'model_state_dict'         : model.state_dict(),
         'optimizer_state_dict'     : optimizer.state_dict(),
         'scheduler_state_dict'     : scheduler.state_dict(),
         'tf_scheduler'             : tf_scheduler,
         metric[0]                  : metric[1],
         'epoch'                    : epoch},
         path
    )

def load_model(best_path, epoch_path, model, mode= 'best', metric= 'valid_acc', optimizer= None, scheduler= None, tf_scheduler= None):


    if mode == 'best':
        checkpoint  = torch.load(best_path)
        print("Loading best checkpoint: ", checkpoint[metric])
    else:
        checkpoint  = torch.load(epoch_path)
        print("Loading epoch checkpoint: ", checkpoint[metric])

    model.load_state_dict(checkpoint['model_state_dict'], strict= False)

    if optimizer != None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        #optimizer.param_groups[0]['lr'] = 1.5e-3
        optimizer.param_groups[0]['weight_decay'] = 1e-5
    if scheduler != None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    if tf_scheduler != None:
        tf_scheduler    = checkpoint['tf_scheduler']

    epoch   = checkpoint['epoch']
    metric  = torch.load(best_path)[metric]

    return [model, optimizer, scheduler, tf_scheduler, epoch, metric]

class TimeElapsed():
    def __init__(self):
        self.start  = -1

    def time_elapsed(self):
        if self.start == -1:
            self.start = time.time()
        else:
            end = time.time() - self.start
            hrs, rem    = divmod(end, 3600)
            min, sec    = divmod(rem, 60)
            min         = min + 60*hrs
            print("Time Elapsed: {:0>2}:{:02}".format(int(min),int(sec)))
            self.start  = -1

## Modules

# Transformer Encoder

In [10]:
import math

class PositionalEncoding(torch.nn.Module):

    # def __init__(self, projection_size, max_seq_len= 176):
    def __init__(self, projection_size, max_seq_len= 1760):
        super().__init__()
        # Read the Attention Is All You Need paper to learn how to code code the positional encoding
        pe = torch.zeros(max_seq_len, projection_size)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, projection_size, 2).float() * (-math.log(10000.0) / projection_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x shape [96, 108, 128]
        x = x + self.pe[:x.size(1), :]
        return x

class TransformerEncoder(torch.nn.Module):
    def __init__(self, projection_size, num_heads=8, dropout= 0.0):
        super().__init__()

        # create the key, query and value weights
        self.KW         = nn.Linear(projection_size, projection_size)
        self.VW         = nn.Linear(projection_size, projection_size)
        self.QW         = nn.Linear(projection_size, projection_size)

        self.permute    = PermuteBlock()

        # Compute multihead attention. You are free to use the version provided by pytorch
        self.attention  = nn.MultiheadAttention(embed_dim=projection_size, num_heads=num_heads, dropout=dropout, batch_first=True)

        self.bn1        = nn.LayerNorm(projection_size)

        self.bn2        = nn.LayerNorm(projection_size)

        # Feed forward neural network
        self.MLP        = nn.Sequential(
            nn.Linear(projection_size, projection_size * 4),
            nn.ReLU(),
            nn.Linear(projection_size * 4, projection_size)
        )

    def forward(self, x):
        # compute the key, query and value
        key     = self.KW(x)
        value   = self.VW(x)
        query   = self.QW(x)

        # compute the output of the attention module
        out1    = self.attention(query, key, value, need_weights=False)
        # Create a residual connection between the input and the output of the attention module
        out1    = out1[0] + x
        # Apply batch norm to out1
        out1    = self.bn1(out1)

        # Apply the output of the feed forward network
        out2    = self.MLP(out1)
        # Apply a residual connection between the input and output of the  FFN
        out2    = out2 + out1 # could be x
        # Apply batch norm to the output
        out2    = self.bn2(out2)

        return out2

model   = TransformerEncoder(
    projection_size  = 128
).to(DEVICE)


print(model)

x_sample    = torch.rand(32, 176, 128)
output      = model(x_sample.to(DEVICE))
# torchsummaryX.summary(model, x_sample.to(DEVICE))
del x_sample

TransformerEncoder(
  (KW): Linear(in_features=128, out_features=128, bias=True)
  (VW): Linear(in_features=128, out_features=128, bias=True)
  (QW): Linear(in_features=128, out_features=128, bias=True)
  (permute): PermuteBlock()
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (bn1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (bn2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (MLP): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=128, bias=True)
  )
)


In [11]:
class TransformerListener(torch.nn.Module):

    def __init__(self,
                 input_size,
                 base_lstm_layers        = 1,
                 pblstm_layers           = 1,
                 listener_hidden_size    = 256,
                 n_heads                 = 8,
                 tf_blocks               = 1):
        super().__init__()

        # create an lstm layer
        self.base_lstm      = nn.LSTM(input_size, listener_hidden_size//2, num_layers=base_lstm_layers, batch_first=True, bidirectional=True)

        # create a sequence of Conv1d layers
        # mind the paddings, mind the paddings, mind the paddings!!!
        self.embedding      = nn.Sequential(
            PermuteBlock(),
            nn.Conv1d(listener_hidden_size, listener_hidden_size, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(listener_hidden_size, listener_hidden_size, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            PermuteBlock()
        )

        # compute the postion encoding
        self.positional_encoding    = PositionalEncoding(listener_hidden_size)

        # create a sequence of transformer blocks
        self.transformer_encoder    = torch.nn.Sequential()
        for i in range(tf_blocks):
            self.transformer_encoder.append(TransformerEncoder(listener_hidden_size, num_heads=n_heads))
            

    def forward(self, x, x_len):
        # pack the inputs before passing them to the LSTm
        x_packed                = nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=False)
        # Pass the packed sequence through the lstm
        lstm_out, _             = self.base_lstm(x_packed)
        # Unpack the output of the lstm
        output, output_lengths  = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        
        # Pass the output through the embedding
        output                  = self.embedding(output)
        # calculate the new output length
        output_lengths          = (output_lengths + 1) // 2

        # calculate the position encoding
        output  = self.positional_encoding(output)
        # Pass the output of the positional encoding through the transformer encoder
        output  = self.transformer_encoder(output)


        return output, output_lengths

In [12]:
listener_hidden_size = 32
layers = nn.Sequential(
            nn.Conv1d(listener_hidden_size, listener_hidden_size, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(listener_hidden_size, listener_hidden_size, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
        )

layers = layers.to(DEVICE)
x = torch.randn(4, 32, 21).to(DEVICE)
layers(x).shape

torch.Size([4, 32, 11])

In [13]:
listener = TransformerListener(
    input_size = 28,
    base_lstm_layers=1,
    listener_hidden_size=256,
    n_heads = 8,
    tf_blocks = 2,
).to(DEVICE)

for batch in train_loader:
    x, y, x_len, y_len = batch
    output = listener(x.to(DEVICE), x_len)
    break


# Attention

### Different ways to compute Attention

1. Dot-product attention
    * raw_weights = bmm(key, query)
    * Optional: Scaled dot-product by normalizing with sqrt key dimension
    * Check "Attention is All You Need" Section 3.2.1
    * 1st way is what most TAs are comfortable with, but if you want to explore, check out other methods below


2. Cosine attention
    * raw_weights = cosine(query, key) # almost the same as dot-product xD

3. Bi-linear attention
    * W = Linear transformation (learnable parameter): d_k -> d_q
    * raw_weights = bmm(key @ W, query)

4. Multi-layer perceptron
    * Check "Neural Machine Translation and Sequence-to-sequence Models: A Tutorial" Section 8.4

5. Multi-Head Attention
    * Check "Attention is All You Need" Section 3.2.2
    * h = Number of heads
    * W_Q, W_K, W_V: Weight matrix for Q, K, V (h of them in total)
    * W_O: d_v -> d_v
    * Reshape K: (B, T, d_k) to (B, T, h, d_k // h) and transpose to (B, h, T, d_k // h)
    * Reshape V: (B, T, d_v) to (B, T, h, d_v // h) and transpose to (B, h, T, d_v // h)
    * Reshape Q: (B, d_q) to (B, h, d_q // h) `
    * raw_weights = Q @ K^T
    * masked_raw_weights = mask(raw_weights)
    * attention = softmax(masked_raw_weights)
    * multi_head = attention @ V
    * multi_head = multi_head reshaped to (B, d_v)
    * context = multi_head @ W_O

Pseudocode:

```python
class Attention:
    '''
    Attention is calculated using the key, value (from encoder embeddings) and query from decoder.

    After obtaining the raw weights, compute and return attention weights and context as follows.:

    attention_weights   = softmax(raw_weights)
    attention_context   = einsum("thinkwhatwouldbetheequationhere",attention, value) #take hint from raw_weights calculation

    At the end, you can pass context through a linear layer too.
    '''

    def init(listener_hidden_size,
              speller_hidden_size,
              projection_size):

        VW = Linear(listener_hidden_size,projection_size)
        KW = Linear(listener_hidden_size,projection_size)
        QW = Linear(speller_hidden_size,projection_size)

    def set_key_value(encoder_outputs):
        '''
        In this function we take the encoder embeddings and make key and values from it.
        key.shape   = (batch_size, timesteps, projection_size)
        value.shape = (batch_size, timesteps, projection_size)
        '''
        key = KW(encoder_outputs)
        value = VW(encoder_outputs)
      
    def compute_context(decoder_context):
        '''
        In this function from decoder context, we make the query, and then we
         multiply the queries with the keys to find the attention logits,
         finally we take a softmax to calculate attention energy which gets
         multiplied to the generted values and then gets summed.

        key.shape   = (batch_size, timesteps, projection_size)
        value.shape = (batch_size, timesteps, projection_size)
        query.shape = (batch_size, projection_size)

        You are also recomended to check out Abu's Lecture 19 to understand Attention better.
        '''
        query = QW(decoder_context) #(batch_size, projection_size)

        raw_weights = #using bmm or einsum. We need to perform batch matrix multiplication. It is important you do this step correctly.
        #What will be the shape of raw_weights?

        attention_weights = #What makes raw_weights -> attention_weights

        attention_context = #Multiply attention weights to values

        return attention_context, attention_weights
```

In [14]:
from einops import einsum, rearrange

class Attention(torch.nn.Module):
  def __init__(self, listener_hidden_size, speller_hidden_size, projection_size, num_heads=8):
    super().__init__()
    self.QW = nn.Linear(speller_hidden_size, projection_size)
    self.KW = nn.Linear(listener_hidden_size, projection_size)
    self.VW = nn.Linear(listener_hidden_size, projection_size)
    self.n_head = num_heads
    
  def set_key_value(self, encoder_outputs):
    self.key = self.KW(encoder_outputs)
    self.value = self.VW(encoder_outputs)

  def compute_context(self, decoder_context):
    query = self.QW(decoder_context)
    query = rearrange(query, 'b (h d) -> b h d', h=self.n_head)
    key = rearrange(self.key, 'b t (h d) -> b h t d', h=self.n_head)
    value = rearrange(self.value, 'b t (h d) -> b h t d', h=self.n_head)
    scale = math.sqrt(key.size(-1))
    att = einsum(query, key, 'b h d, b h k d -> b h k') / scale
    att = nn.functional.softmax(att, dim=-1)
    context = einsum(att, value, 'b h t, b h t d -> b h d')
    context = rearrange(context, 'b h d -> b (h d)') # re-assemble all head outputs side by side

    return context, att

# The Speller

Similar to the language model that you coded up for HW4P1, you have to code a language model for HW4P2 as well. This time, we will also call the attention context step, within the decoder to get the attended-encoder-embeddings.


What you have coded till now:

<center>
<img src="https://github.com/varunjain3/11785_s23_h4p2/raw/main/EncoderAttention.png" alt="data flow" height="400">
</center>

For the Speller, what we have to code:


<center>
<img src="https://github.com/varunjain3/11785_s23_h4p2/raw/main/Decoder.png" alt="data flow" height="400">
</center>

In [45]:
class Speller(torch.nn.Module):

  # Refer to your HW4P1 implementation for help with setting up the language model.
  # The only thing you need to implement on top of your HW4P1 model is the attention module and teacher forcing.

  def __init__(self, attender:Attention, embedding_dim=128, hidden_dim=256, num_lstm_layers=2, context_dim=128):
    super(). __init__()
    assert hidden_dim == embedding_dim + context_dim, "Hidden dim should be equal to embedding + context dim"
    # config
    vocab_size          = len(VOCAB)
    self.embedding_dim   = embedding_dim
    self.hidden_dim      = hidden_dim
    self.context_dim     = context_dim
    self.num_lstm_layers = num_lstm_layers

    self.attend = attender # Attention object in speller
    self.max_timesteps = 176 # Max timesteps

    self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
    self.lstm_cells = nn.ModuleList([
      nn.LSTMCell(input_size=hidden_dim, hidden_size=hidden_dim) 
      for _ in range(num_lstm_layers)
    ])

    # For CDN (Feel free to change)
    self.output_to_char = nn.Linear(hidden_dim + context_dim, embedding_dim) # Linear module to convert outputs to correct hidden size (Optional: TO make dimensions match)
    self.activation = nn.ReLU() # Check which activation is suggested
    self.char_prob = nn.Linear(embedding_dim, vocab_size)# Linear layer to convert hidden space back to logits for token classification
    self.char_prob.weight = self.embedding.weight # Weight tying (From embedding layer)



  def lstm_step(self, input_word, hidden_state):
    lstm_input = input_word
    for i in range(len(self.lstm_cells)):
      hidden_state[i][0], hidden_state[i][1] = self.lstm_cells[i](lstm_input, (hidden_state[i][0], hidden_state[i][1])) # Feed the input through each LSTM Cell
      lstm_input = hidden_state[i][0]
    return hidden_state # What information does forward() need?

  def CDN(self, x):
    # Make the CDN here, you can add the output-to-char
    x = self.output_to_char(x)
    x = self.activation(x)
    x = self.char_prob(x)
    return x

  def forward (self, y=None, teacher_forcing_ratio=1):

    attn_context = torch.zeros((config["batch_size"], self.context_dim,)).to(DEVICE) # initial context tensor for time t = 0
    output_symbol = torch.LongTensor([SOS_TOKEN]*config["batch_size"]).to(DEVICE) # Set it to SOS for time t = 0
    raw_outputs = []
    attention_plot = []

    if y is None:
      timesteps = self.max_timesteps
      teacher_forcing_ratio = 0 #Why does it become zero?

    else:
      timesteps = y.size(1) # How many timesteps are we predicting for?

    hidden_states_list = [[torch.zeros((config["batch_size"], self.hidden_dim)).to(DEVICE), 
                           torch.zeros((config["batch_size"], self.hidden_dim)).to(DEVICE)
                          ] for _ in range(len(self.lstm_cells))] # Initialize your hidden_states list here similar to HW4P1

    for t in tqdm(range(timesteps)):
      p = random.uniform(0, 1)# generate a probability p between 0 and 1

      if p < teacher_forcing_ratio and t > 0: # Why do we consider cases only when t > 0? What is considered when t == 0? Think.
        output_symbol = y[:, t-1]# Take from y, else draw from probability distribution

      char_embed = self.embedding(output_symbol) # Embed the character symbol

      # Concatenate the character embedding and context from attention, as shown in the diagram
      lstm_input = torch.cat([char_embed, attn_context], dim=-1)

      hidden_states_list = self.lstm_step(lstm_input, hidden_states_list) # Feed the input through LSTM Cells and attention.
      # What should we retrieve from forward_step to prepare for the next timestep?

      attn_context, attn_weights = self.attend.compute_context(hidden_states_list[-1][0]) # Feed the resulting hidden state into attention

      cdn_input = torch.cat([attn_context, hidden_states_list[-1][0]], dim=-1) # Concatenate context and LSTM output

      raw_pred = self.CDN(cdn_input)

      # Generate a prediction for this timestep and collect it in output_symbols
      output_symbol = torch.argmax(raw_pred, dim=-1) # Draw correctly from raw_pred

      raw_outputs.append(raw_pred) # for loss calculation
      attention_plot.append(attn_weights) # for plotting attention plot

      print(t)

    attention_plot = torch.stack(attention_plot, dim=1)
    raw_outputs = torch.stack(raw_outputs, dim=1)

    return raw_outputs, attention_plot

In [46]:
class ASRModel(torch.nn.Module):
  def __init__(self,): # add parameters
    super().__init__()

    # Pass the right parameters here
    self.listener = TransformerListener(
        input_size = 28,
        base_lstm_layers=1,
        listener_hidden_size=128,
        n_heads = 8,
        tf_blocks = 2,
    )
    self.attend = Attention(
      listener_hidden_size=128, 
      speller_hidden_size=128, 
      projection_size=64
    )
    self.speller = Speller(self.attend, 
      embedding_dim=64, 
      hidden_dim=128, 
      num_lstm_layers=2, 
      context_dim=64
    )

  def forward(self, x,lx,y=None,teacher_forcing_ratio=1):
    # Encode speech features
    encoder_outputs, _ = self.listener(x,lx)

    # We want to compute keys and values ahead of the decoding step, as they are constant for all timesteps
    # Set keys and values using the encoder outputs
    self.attend.set_key_value(encoder_outputs)

    # Decode text with the speller using context from the attention
    raw_outputs, attention_plots = self.speller(y=y,teacher_forcing_ratio=teacher_forcing_ratio)

    return raw_outputs, attention_plots

In [47]:
model = ASRModel()
model = model.to(DEVICE)
for batch in train_loader:
    x, y, x_len, y_len = batch
    output = model(x.to(DEVICE), x_len, None)
    break

  0%|          | 0/176 [00:00<?, ?it/s]

0


  1%|          | 2/176 [06:47<8:36:09, 177.99s/it] 

1


  2%|▏         | 3/176 [06:48<4:40:50, 97.40s/it] 

2


  2%|▏         | 4/176 [06:49<2:50:06, 59.34s/it]

3


  3%|▎         | 5/176 [06:52<1:50:16, 38.69s/it]

4


  3%|▎         | 6/176 [19:28<13:21:05, 282.74s/it]

5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


 20%|█▉        | 35/176 [19:34<43:17, 18.42s/it]   

32
33
34
35
36
37
38


 27%|██▋       | 47/176 [19:34<20:14,  9.41s/it]

39
40
41
42
43
44
45
46
47
48
49
50
51
52


 34%|███▍      | 60/176 [19:34<08:56,  4.63s/it]

53
54
55
56
57
58
59
60
61
62
63
64


 41%|████      | 72/176 [19:34<04:05,  2.36s/it]

65
66
67
68
69
70
71
72
73
74
75


 47%|████▋     | 83/176 [19:35<01:55,  1.25s/it]

76
77
78
79
80
81
82
83
84


 53%|█████▎    | 94/176 [19:35<00:51,  1.58it/s]

85
86
87
88
89
90
91
92
93
94
95
96


 61%|██████    | 107/176 [19:35<00:20,  3.31it/s]

97
98
99
100
101
102
103
104
105
106


 64%|██████▎   | 112/176 [19:35<00:14,  4.35it/s]

107
108
109
110
111
112
113
114
115


 76%|███████▌  | 133/176 [19:36<00:03, 12.34it/s]

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140


100%|██████████| 176/176 [19:36<00:00,  6.68s/it]

141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175





# Model Setup

In [37]:
torch.cuda.empty_cache()
gc.collect()

3302

In [None]:
model = ASRModel()

model = model.to(DEVICE)
print(model)

ASRModel(
  (listener): TransformerListener(
    (base_lstm): LSTM(28, 128, batch_first=True, bidirectional=True)
    (embedding): Sequential(
      (0): PermuteBlock()
      (1): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,))
      (2): ReLU()
      (3): Conv1d(256, 256, kernel_size=(5,), stride=(2,), padding=(2,))
      (4): ReLU()
      (5): PermuteBlock()
    )
    (positional_encoding): PositionalEncoding()
    (transformer_encoder): Sequential(
      (0): TransformerEncoder(
        (KW): Linear(in_features=256, out_features=256, bias=True)
        (VW): Linear(in_features=256, out_features=256, bias=True)
        (QW): Linear(in_features=256, out_features=256, bias=True)
        (permute): PermuteBlock()
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (bn1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (bn2): LayerNorm((256,), eps=1e-05, 

# Loss Function, Optimizers, Scheduler

In [None]:
optimizer   = # TODO

criterion   = # TODO

scaler      = # TODO

scheduler   = # TODO

# Levenshtein Distance

In [None]:
# We have given you this utility function which takes a sequence of indices and converts them to a list of characters
def indices_to_chars(indices, vocab):
    tokens = []
    for i in indices: # This loops through all the indices
        if int(i) == SOS_TOKEN: # If SOS is encountered, dont add it to the final list
            continue
        elif int(i) == EOS_TOKEN: # If EOS is encountered, stop the decoding process
            break
        else:
            tokens.append(vocab[i])
    return tokens

# To make your life more easier, we have given the Levenshtein distantce / Edit distance calculation code
def calc_edit_distance(predictions, y, y_len, vocab= VOCAB, print_example= False):

    dist                = 0
    batch_size, seq_len = predictions.shape

    for batch_idx in range(batch_size):

        y_sliced    = indices_to_chars(y[batch_idx,0:y_len[batch_idx]], vocab)
        pred_sliced = indices_to_chars(predictions[batch_idx], vocab)

        # Strings - When you are using characters from the AudioDataset
        y_string    = ''.join(y_sliced)
        pred_string = ''.join(pred_sliced)

        #dist        += Levenshtein.distance(pred_string, y_string)
        # Comment the above abd uncomment below for toy dataset
        dist      += Levenshtein.distance(y_sliced, pred_sliced)

    if print_example:
        # Print y_sliced and pred_sliced if you are using the toy dataset
        print("\nGround Truth : ", y_string)
        print("Prediction   : ", pred_string)

    dist    /= batch_size
    return dist

# Train and Validation functions


In [None]:
def train(model, dataloader, criterion, optimizer, teacher_forcing_rate):

    model.train()
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train')

    running_loss        = 0.0
    running_perplexity  = 0.0

    for i, (x, y, lx, ly) in enumerate(dataloader):

        optimizer.zero_grad()

        x, y, lx, ly = x.to(DEVICE), y.to(DEVICE), lx, ly

        with torch.cuda.amp.autocast():

            raw_predictions, attention_plot = model(x, lx, y= y, tf_rate= teacher_forcing_rate)

            # Predictions are of Shape (batch_size, timesteps, vocab_size).
            # Transcripts are of shape (batch_size, timesteps) Which means that you have batch_size amount of batches with timestep number of tokens.
            # So in total, you have batch_size*timesteps amount of characters.
            # Similarly, in predictions, you have batch_size*timesteps amount of probability distributions.
            # How do you need to modify transcipts and predictions so that you can calculate the CrossEntropyLoss? Hint: Use Reshape/View and read the docs
            # Also we recommend you plot the attention weights, you should get convergence in around 10 epochs, if not, there could be something wrong with
            # your implementation
            loss        =  # TODO: Cross Entropy Loss

            perplexity  = torch.exp(loss) # Perplexity is defined the exponential of the loss

            running_loss        += loss.item()
            running_perplexity  += perplexity.item()

        # Backward on the masked loss
        scaler.scale(loss).backward()

        # Optional: Use torch.nn.utils.clip_grad_norm to clip gradients to prevent them from exploding, if necessary
        # If using with mixed precision, unscale the Optimizer First before doing gradient clipping

        scaler.step(optimizer)
        scaler.update()


        batch_bar.set_postfix(
            loss="{:.04f}".format(running_loss/(i+1)),
            perplexity="{:.04f}".format(running_perplexity/(i+1)),
            lr="{:.04f}".format(float(optimizer.param_groups[0]['lr'])),
            tf_rate='{:.02f}'.format(teacher_forcing_rate))
        batch_bar.update()

        del x, y, lx, ly
        torch.cuda.empty_cache()

    running_loss /= len(dataloader)
    running_perplexity /= len(dataloader)
    batch_bar.close()

    return running_loss, running_perplexity, attention_plot

In [None]:
def validate(model, dataloader):

    model.eval()

    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc="Val")

    running_lev_dist = 0.0

    for i, (x, y, lx, ly) in enumerate(dataloader):

        x, y, lx, ly = x.to(DEVICE), y.to(DEVICE), lx, ly

        with torch.inference_mode():
            raw_predictions, attentions = model(x, lx, y = None)

        # Greedy Decoding
        greedy_predictions   =  # TODO: How do you get the most likely character from each distribution in the batch?

        # Calculate Levenshtein Distance
        running_lev_dist    += calc_edit_distance(greedy_predictions, y, ly, VOCAB, print_example = False) # You can use print_example = True for one specific index i in your batches if you want

        batch_bar.set_postfix(
            dist="{:.04f}".format(running_lev_dist/(i+1)))
        batch_bar.update()

        del x, y, lx, ly
        torch.cuda.empty_cache()

    batch_bar.close()
    running_lev_dist /= len(dataloader)

    return running_lev_dist

# Wandb


In [None]:
# Login to Wandb
# Initialize your Wandb Run Here
# Save your model architecture in a txt file, and save the file to Wandb

In [None]:
def plot_attention(attention):
    # Function for plotting attention
    # You need to get a diagonal plot
    plt.clf()
    sns.heatmap(attention, cmap='GnBu')
    plt.show()

# Experiment

In [None]:
best_lev_dist = float("inf")
tf_rate = 1.0

for epoch in range(0, config['epochs']):

    print("\nEpoch: {}/{}".format(epoch+1, config['epochs']))

    # Call train and validate, get attention weights from training

    # Print your metrics

    # Plot Attention for a single item in the batch
    plot_attention(attention_plot[0].cpu().detach().numpy())

    # Log metrics to Wandb

    # Optional: Scheduler Step / Teacher Force Schedule Step


    if valid_dist <= best_lev_dist:
        best_lev_dist = valid_dist
        # Save your model checkpoint here

# Testing

In [None]:
# Optional: Load your best model Checkpoint here

# TODO: Create a testing function similar to validation
# TODO: Create a file with all predictions
# TODO: Submit to Kaggle