# HW4P2: Attention-based Speech Recognition

<img src="https://cdn.shopify.com/s/files/1/0272/2080/3722/products/SmileBumperSticker_5400x.jpg" alt="A cute cat" width="600">


Welcome to the final assignment in 11-785. In this HW, you will build a speech recognition system with <i>attention</i>. <br> <br>

<center>
<img src="https://popmn.org/wp-content/uploads/2020/03/pay-attention.jpg" alt="A cute cat" height="100">
</center>

HW Writeup: On Piazza/Course Website <br>
Kaggle Competition Link: https://www.kaggle.com/competitions/11-785-s23-hw4p2/ <br>
Kaggle Dataset Link: https://www.kaggle.com/datasets/varunjain3/11-785-s23-hw4p2-dataset
<br>
LAS Paper: https://arxiv.org/pdf/1508.01211.pdf <br>
Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf

# Read this section carefully!

1. By now, we believe that you are already a great deep learning practitioner. Congratulations! 🎉

2. You are allowed to use code from your previous homeworks for this homework. We will only provide the parts that are new and necessary for this homework.

3. This notebook provides many resources that will help you check whether your implementation is working correctly.

In [None]:
!nvidia-smi

In [None]:
# Install some required libraries
# Feel free to add more if you want
!pip install -q python-levenshtein torchsummaryX wandb kaggle pytorch-nlp 

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the decoding and logging libraries used in this HW (some overlap with the cell above)
!pip install wandb --quiet
!pip install python-Levenshtein -q
!git clone --recursive https://github.com/parlance/ctcdecode.git
!pip install wget -q
%cd ctcdecode
!pip install . -q
%cd ..

!pip install torchsummaryX -q

# Imports

In [None]:
# Import Necessary Modules you require for this HW here
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchsummaryX import summary
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torchaudio.transforms as tat
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score
import gc

import zipfile
import pandas as pd
from tqdm import tqdm
import os
import datetime

# imports for decoding and distance calculation
import ctcdecode
import Levenshtein
from ctcdecode import CTCBeamDecoder

import warnings
warnings.filterwarnings('ignore')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", device)

# Kaggle Dataset Download

In [None]:
api_token = '{"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_KAGGLE_KEY"}'  # never share or commit your real key

# set up kaggle.json
# TODO: Use the same Kaggle code from HW1P2, HW2P2, HW3P2
!mkdir -p /root/.kaggle/

with open("/root/.kaggle/kaggle.json", "w+") as f:
    f.write(api_token) # Put your kaggle username & key here

!chmod 600 /root/.kaggle/kaggle.json

In [None]:
# To download the dataset
!kaggle datasets download -d varunjain3/11-785-s23-hw4p2-dataset

In [None]:
# To unzip data quickly and quietly
!unzip -q 11-785-s23-hw4p2-dataset.zip -d ./data

# Dataset and Dataloaders

We have given you two datasets: a toy dataset and the standard LibriSpeech dataset. The toy dataset helps you implement, test, and debug your code easily, and verify that your attention diagonal is produced correctly. Note, however, that its task (phonetic transcription) is drawn from HW3P2; it is meant to be familiar and to help you understand how to transition from phonetic transcription to character transcription with a working attention module.

Please make sure you use the right constants in your code (e.g., SOS_TOKEN vs. SOS_TOKEN_TOY) when working with either dataset. We have defined the LibriSpeech constants below. Before you come to OH or post on Piazza, make sure you aren't misusing the constants for either dataset in your code.

## LibriSpeech

In terms of structure, the HW3P2 and HW4P2 datasets are very similar. Can you spot the differences? What changes will be required?

Hints:

- Check how big the dataset is (do you need memory-efficient loading techniques?)
- How do we load the MFCCs? Do we need to normalize them?
- Does the data have \<SOS> and \<EOS> tokens in each sequence? Do we remove them or keep them? (Read the writeup)
- Would we want a collate function? Ask yourself: why did we need a collate function last time?
- Observe the VOCAB: is it the same as in HW3P2?
- Should you add augmentations? If yes, which ones, and when should you apply them? (Check the bootcamp for the answer)


In [None]:
config = {
  'batch_size': 256,
  'lr':1e-3,
  'epochs': 120,
}

VOCAB = ['<pad>', '<sos>', '<eos>', 
         'A',   'B',    'C',    'D',    
         'E',   'F',    'G',    'H',    
         'I',   'J',    'K',    'L',       
         'M',   'N',    'O',    'P',    
         'Q',   'R',    'S',    'T', 
         'U',   'V',    'W',    'X', 
         'Y',   'Z',    "'",    ' ', 
         ]

VOCAB_MAP = {VOCAB[i]:i for i in range(0, len(VOCAB))}

PAD_TOKEN = VOCAB_MAP["<pad>"]
SOS_TOKEN = VOCAB_MAP["<sos>"]
EOS_TOKEN = VOCAB_MAP["<eos>"]

print(f"Length of vocab: {len(VOCAB)}")
print(f"Vocab: {VOCAB}")
print(f"PAD_TOKEN: {PAD_TOKEN}")
print(f"SOS_TOKEN: {SOS_TOKEN}")
print(f"EOS_TOKEN: {EOS_TOKEN}")

Length of vocab: 31
Vocab: ['<pad>', '<sos>', '<eos>', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', "'", ' ']
PAD_TOKEN: 0
SOS_TOKEN: 1
EOS_TOKEN: 2


In [None]:
class SpeechDataset(torch.utils.data.Dataset):
    def __init__(self, root, partition="train-clean-100"):
        # Load the directory and all files in it
        self.mfcc_dir = os.path.join(root, partition, 'mfcc')
        self.transcript_dir = os.path.join(root, partition, 'transcripts')
        self.mfcc_files = sorted(os.listdir(self.mfcc_dir))
        self.transcript_files = sorted(os.listdir(self.transcript_dir))
        assert len(self.mfcc_files) == len(self.transcript_files)
        self.length = len(self.mfcc_files)
        self.mfccs = []
        self.transcripts = []
        for mfcc_file, transcript_file in zip(self.mfcc_files, self.transcript_files):
            mfcc = np.load(os.path.join(self.mfcc_dir, mfcc_file))
            transcript = np.load(os.path.join(self.transcript_dir, transcript_file))
            # Per-utterance cepstral mean and variance normalization
            # (the small epsilon guards against division by zero for constant features)
            mfcc = (mfcc - np.mean(mfcc, axis=0)) / (np.std(mfcc, axis=0) + 1e-8)
            # Map each character label to its index in VOCAB
            transcript = [VOCAB_MAP[label] for label in transcript]
            self.mfccs.append(mfcc)
            self.transcripts.append(transcript)

    def __len__(self):
        return self.length

    def __getitem__(self, ind):
        mfcc = torch.FloatTensor(self.mfccs[ind])
        transcript = torch.tensor(self.transcripts[ind])
        return mfcc, transcript

    def collate_fn(self, batch):
        # batch is a list of (mfcc, transcript) pairs of varying lengths
        batch_mfcc = [mfcc for mfcc, _ in batch]
        batch_transcript = [transcript for _, transcript in batch]
        # Pad to the longest sequence in the batch; transcripts are padded with PAD_TOKEN (0)
        batch_mfcc_pad = pad_sequence(batch_mfcc, batch_first=True, padding_value=0)
        batch_transcript_pad = pad_sequence(batch_transcript, batch_first=True, padding_value=PAD_TOKEN)
        # Keep the true lengths so padding can be packed or masked out later
        lengths_mfcc = [len(mfcc) for mfcc in batch_mfcc]
        lengths_transcript = [len(transcript) for transcript in batch_transcript]
        return batch_mfcc_pad, batch_transcript_pad, torch.tensor(lengths_mfcc), torch.tensor(lengths_transcript)

In [None]:
# Test Dataloader
class SpeechDatasetTest(torch.utils.data.Dataset):
    def __init__(self, root, partition='test-clean'):
        self.mfcc_dir = os.path.join(root, partition, 'mfcc')
        self.mfcc_files = sorted(os.listdir(self.mfcc_dir))
        self.length = len(self.mfcc_files)
        self.mfccs = []
        for mfcc_file in self.mfcc_files:
            mfcc = np.load(os.path.join(self.mfcc_dir, mfcc_file))
            # Per-utterance cepstral mean and variance normalization, as in SpeechDataset
            mfcc = (mfcc - np.mean(mfcc, axis=0)) / (np.std(mfcc, axis=0) + 1e-8)
            self.mfccs.append(mfcc)

    def __len__(self):
        return self.length

    def __getitem__(self, ind):
        return torch.FloatTensor(self.mfccs[ind])

    def collate_fn(self, batch):
        # No transcripts at test time: pad the features and keep their true lengths
        batch_mfcc_pad = pad_sequence(batch, batch_first=True, padding_value=0)
        lengths_mfcc = [len(mfcc) for mfcc in batch]
        return batch_mfcc_pad, torch.tensor(lengths_mfcc)

In [None]:
root = '/content/data/'
train_dataset = SpeechDataset(root, partition="train-clean-100")
dev_dataset = SpeechDataset(root, partition="dev-clean")
test_dataset = SpeechDatasetTest(root, partition="test-clean")

dev_loader = torch.utils.data.DataLoader(
    dataset     = dev_dataset, 
    num_workers = 2,
    batch_size  = config["batch_size"], 
    pin_memory  = True,
    shuffle     = False,
    collate_fn = dev_dataset.collate_fn
)

train_loader = torch.utils.data.DataLoader(
    dataset     = train_dataset, 
    num_workers = 4,
    batch_size  = config["batch_size"], 
    pin_memory  = True,
    shuffle     = True,
    collate_fn = train_dataset.collate_fn
)

test_loader = torch.utils.data.DataLoader(
    dataset     = test_dataset, 
    num_workers = 2,
    batch_size  = config["batch_size"], 
    pin_memory  = True,
    shuffle     = False,
    collate_fn = test_dataset.collate_fn
)

print("\nChecking the shapes of the data...")
for batch in dev_loader:
    x, y, x_len, y_len = batch
    print(x.shape, y.shape, x_len.shape, y_len.shape)
    break


Checking the shapes of the data...
torch.Size([256, 2936, 27]) torch.Size([256, 364]) torch.Size([256]) torch.Size([256])


Check if you are loading the data correctly with the following:

(Note: These are outputs from loading your data in the dataset class, not your dataloader which will have padded sequences)

- Train Dataset
```
Partition loaded:  train-clean-100
Max mfcc length:  2448
Average mfcc length:  1264.6258453344547
Max transcript:  400
Average transcript length:  186.65321139493324
```

- Dev Dataset
```
Partition loaded:  dev-clean
Max mfcc length:  3260
Average mfcc length:  713.3570107288198
Max transcript:  518
Average transcript length:  108.71698113207547
```

- Test Dataset
```
Partition loaded:  test-clean
Max mfcc length:  3491
Average mfcc length:  738.2206106870229
```

If your values do not match, reread the hints and think about what could have gone wrong before approaching the TAs.
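A quick way to produce the statistics above is to compute them from the lists the dataset classes already hold. The helper below is a hypothetical sketch (the name `print_dataset_stats` is not part of the starter code); it assumes, as in this notebook's `SpeechDataset` and `SpeechDatasetTest`, that the loaded features live in `.mfccs` and the integer-mapped transcripts in `.transcripts`.

```python
import numpy as np

def print_dataset_stats(partition, mfccs, transcripts=None):
    """Hypothetical helper: print (and return) length statistics for a loaded partition.

    mfccs:       list of arrays shaped (timesteps, num_features)
    transcripts: optional list of token sequences (omit it for the test set)
    """
    mfcc_lengths = [len(mfcc) for mfcc in mfccs]
    stats = {"max_mfcc": max(mfcc_lengths), "avg_mfcc": float(np.mean(mfcc_lengths))}
    print("Partition loaded: ", partition)
    print("Max mfcc length: ", stats["max_mfcc"])
    print("Average mfcc length: ", stats["avg_mfcc"])
    if transcripts is not None:
        transcript_lengths = [len(t) for t in transcripts]
        stats["max_transcript"] = max(transcript_lengths)
        stats["avg_transcript"] = float(np.mean(transcript_lengths))
        print("Max transcript: ", stats["max_transcript"])
        print("Average transcript length: ", stats["avg_transcript"])
    return stats
```

For example, `print_dataset_stats("train-clean-100", train_dataset.mfccs, train_dataset.transcripts)` for the training set, or `print_dataset_stats("test-clean", test_dataset.mfccs)` for the test set. If only your transcript numbers differ by a small constant, one thing worth checking is whether you kept or removed the special tokens.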

# THE MODEL 

### Listen, Attend and Spell
Listen, Attend and Spell (LAS) is a neural network model for end-to-end speech recognition: it maps a sequence of acoustic features directly to a sequence of characters.

- The pyramidal BLSTM listener reduces the time resolution of the input, which makes attention over long utterances tractable.
- The speller outputs characters directly, so no pronunciation dictionary or separate acoustic model is needed.
- It consists of a <b>listener, an attender and a speller</b>, which work together to convert an input speech signal into the corresponding output text.

#### The Dataflow:
<center>
<img src="https://github.com/varunjain3/11785_s23_h4p2/raw/main/DataFlow.png" alt="data flow" height="100">
</center>

#### The Listener: 
- Converts the input speech features into a (downsampled) sequence of hidden states.

#### The Attender:
- Decides how the encoder hidden states are weighted and combined into a context vector for the decoder at each output step.

#### The Speller:
- A character-level language model that incorporates the "context of the attender" (the output of the attender) to predict the output sequence one character at a time.






## The Listener:

Pseudocode:
```python
class Listener:
  def init():
    feature_embedder = # A few layers of 1D Conv-BatchNorm-activation (don't overdo it)
    pblstm_encoder = # Cascaded pBLSTM layers (take pBLSTM from HW3P2);
                     # you can add more sequential LSTMs
    dropout = # As per your liking

  def forward(x, lx):
    embedding = feature_embedder(x) # optional
    encoding, encoding_len = pblstm_encoder(embedding, lx) # or x, if no embedder is used
    # Regularization if needed
    return encoding, encoding_len
```
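The pseudocode suggests a convolutional `feature_embedder`, while the `Listener` implemented below uses a BiLSTM instead. For comparison, here is a minimal sketch of the Conv1d-BatchNorm-activation option; the class name `ConvEmbedder`, the GELU activation, and the default sizes are illustrative assumptions, not part of the starter code.

```python
import torch
import torch.nn as nn

class ConvEmbedder(nn.Module):
    """Hypothetical feature embedder: a few Conv1d -> BatchNorm1d -> GELU blocks.

    Input and output are (batch, time, channels). With an odd kernel and
    padding = kernel_size // 2, the time length is unchanged.
    """
    def __init__(self, input_size=27, embed_size=256, num_layers=2, kernel_size=3):
        super().__init__()
        layers = []
        in_channels = input_size
        for _ in range(num_layers):
            layers += [
                nn.Conv1d(in_channels, embed_size, kernel_size, padding=kernel_size // 2),
                nn.BatchNorm1d(embed_size),
                nn.GELU(),
            ]
            in_channels = embed_size
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Conv1d expects (batch, channels, time), so transpose in and back out
        return self.net(x.transpose(1, 2)).transpose(1, 2)
```

Because the sequence length is preserved, the original `lx` can still be passed to `pack_padded_sequence` afterwards. Note that the BatchNorm statistics in this sketch also include padded frames.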



In [None]:
# Reference: https://github.com/salesforce/awd-lstm-lm/blob/dfd3cb0235d2caf2847a4d53e1cbd495b781b5d2/locked_dropout.py#L5
class LockedDropout(nn.Module):
    """Variational dropout for PackedSequences: one mask per sequence, shared across time."""
    def __init__(self):
        super().__init__()

    def forward(self, x, dropout = 0.3):
        if not self.training or not dropout:
            return x
        x_padded, x_length = pad_packed_sequence(x, batch_first=True)
        # Sample a (batch, 1, features) mask and rescale kept units by 1 / (1 - dropout)
        mask = x_padded.new_empty(x_padded.size(0), 1, x_padded.size(2)).bernoulli_(1 - dropout) / (1 - dropout)
        # Broadcasting applies the same mask at every time step
        x_padded = x_padded * mask
        x = pack_padded_sequence(x_padded, x_length, enforce_sorted=False, batch_first=True)
        return x
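`LockedDropout` samples one dropout mask per sequence and feature and reuses it at every time step, instead of drawing a fresh mask per frame. The sketch below shows the same idea on a plain padded `(batch, time, features)` tensor, without the packing and unpacking; the function name `locked_dropout` is hypothetical.

```python
import torch

def locked_dropout(x, p=0.3, training=True):
    """Locked (variational) dropout on a padded (batch, time, features) tensor."""
    if not training or p == 0:
        return x
    # One Bernoulli mask per (sequence, feature); kept units are scaled by 1 / (1 - p)
    mask = x.new_empty(x.size(0), 1, x.size(2)).bernoulli_(1 - p) / (1 - p)
    # The size-1 time dimension broadcasts, so every frame uses the same mask
    return x * mask
```

Sharing the mask over time means a dropped unit stays dropped for the whole utterance, which is the usual motivation for this variant in recurrent networks.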

In [None]:
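class PermuteBlock(torch.nn.Module):
    # Swaps the time and feature axes: (B, T, C) <-> (B, C, T)
    def forward(self, x):
        return x.transpose(1, 2)

class pBLSTM(torch.nn.Module):
    """Pyramidal BLSTM: halves the time resolution, then runs a bidirectional LSTM.

    input_size must be 2 * (feature size of the incoming sequence), because
    trunc_reshape concatenates each pair of adjacent frames.
    """
    def __init__(self, input_size, hidden_size):
        super(pBLSTM, self).__init__()
        self.blstm = nn.Sequential(nn.LSTM(input_size, hidden_size, num_layers=1, bidirectional=True, batch_first=True))

    def forward(self, x_packed):
        x_pad, x_lens = pad_packed_sequence(x_packed, batch_first=True)
        x_pad, x_lens = self.trunc_reshape(x_pad, x_lens)
        x_packed = pack_padded_sequence(x_pad, x_lens, batch_first=True, enforce_sorted=False)
        output, _ = self.blstm(x_packed)
        return output

    def trunc_reshape(self, x, x_lens):
        # Drop the last frame if the padded length is odd
        if x.shape[1] % 2 != 0:
            x = x[:, :-1, :]
        batch, length, dimension = x.shape
        # (B, T, D) -> (B, T // 2, 2 * D): concatenate each pair of adjacent frames
        x = x.reshape(batch, length // 2, dimension * 2)
        x_lens = x_lens // 2
        return x, x_lens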
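To see what `trunc_reshape` does to shapes and lengths, here is a standalone copy of the same logic run on a tiny tensor (two sequences of true lengths 5 and 3, padded to 5 frames with 3 features each). It is a sketch for inspection only.

```python
import torch

def trunc_reshape(x, x_lens):
    # Drop the last frame if the padded length is odd, then fold pairs of adjacent
    # frames into the feature dimension: (B, T, D) -> (B, T // 2, 2 * D)
    if x.shape[1] % 2 != 0:
        x = x[:, :-1, :]
    batch, length, dim = x.shape
    x = x.reshape(batch, length // 2, dim * 2)
    return x, x_lens // 2

x = torch.arange(2 * 5 * 3, dtype=torch.float32).reshape(2, 5, 3)
lens = torch.tensor([5, 3])
y, y_lens = trunc_reshape(x, lens)
print(y.shape, y_lens)  # torch.Size([2, 2, 6]) tensor([2, 1])
```

Each output frame is the concatenation of two consecutive input frames, so `y[0, 0]` holds input frames 0 and 1 of the first sequence. Since the `Listener` below stacks two `pBLSTM`s, its output is roughly a quarter as long as its input.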

In [None]:
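class Listener(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        input_size = 27  # number of MFCC features per frame
        # The 3-layer BiLSTM outputs 2 * hidden_size features; each pBLSTM concatenates
        # two frames, so every pBLSTM sees 2 * (2 * hidden_size) = 4 * hidden_size inputs
        complexity = 4
        self.feature_embedder = torch.nn.LSTM(input_size, hidden_size, 3, batch_first=True, bidirectional=True)
        self.pBLSTMs = nn.Sequential(
            pBLSTM(complexity * hidden_size, hidden_size),
            LockedDropout(),
            pBLSTM(complexity * hidden_size, hidden_size),
        )

    def forward(self, x, lx):
        x = pack_padded_sequence(x, lx, batch_first=True, enforce_sorted=False)
        x, _ = self.feature_embedder(x)
        # Two pBLSTMs reduce the time resolution by about a factor of 4
        x = self.pBLSTMs(x)
        # encoding: (batch_size, ~T / 4, 2 * hidden_size)
        encoding, encoding_len = pad_packed_sequence(x, batch_first=True)
        return encoding, encoding_len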

## Attention

### Different ways to compute Attention

1. Dot-product attention
    * raw_weights = bmm(key, query) 
    * Optional: Scaled dot-product by normalizing with sqrt key dimension 
    * Check "Attention is All You Need" Section 3.2.1
    * The first method is what most TAs are familiar with; if you want to explore, check out the other methods below


2. Cosine attention
    * raw_weights = cosine(query, key) # almost the same as dot-product xD 

3. Bi-linear attention
    * W = Linear transformation (learnable parameter): d_k -> d_q
    * raw_weights = bmm(key @ W, query)

4. Multi-layer perceptron
    * Check "Neural Machine Translation and Sequence-to-sequence Models: A Tutorial" Section 8.4

5. Multi-Head Attention
    * Check "Attention is All You Need" Section 3.2.2
    * h = Number of heads
    * W_Q, W_K, W_V: Weight matrix for Q, K, V (h of them in total)
    * W_O: d_v -> d_v
    * Reshape K: (B, T, d_k) to (B, T, h, d_k // h) and transpose to (B, h, T, d_k // h)
    * Reshape V: (B, T, d_v) to (B, T, h, d_v // h) and transpose to (B, h, T, d_v // h)
    * Reshape Q: (B, d_q) to (B, h, d_q // h)
    * raw_weights = Q @ K^T
    * masked_raw_weights = mask(raw_weights)
    * attention = softmax(masked_raw_weights)
    * multi_head = attention @ V
    * multi_head = multi_head reshaped to (B, d_v)
    * context = multi_head @ W_O
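The multi-head steps listed in item 5 can be sketched for the single-query case used by the speller (one query vector per decoding step). This is a minimal sketch under stated assumptions: the class name is hypothetical, keys and values share one projected size `d_model` that is divisible by the number of heads, raw scores are scaled by the square root of the per-head dimension as in "Attention Is All You Need", and padded encoder positions are marked `True` in the mask.

```python
import math
import torch
import torch.nn as nn

class MultiHeadSingleQueryAttention(nn.Module):
    """Hypothetical sketch: multi-head attention with one query per batch element."""
    def __init__(self, key_dim, query_dim, d_model=128, num_heads=4):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.h = num_heads
        self.d_head = d_model // num_heads
        self.W_K = nn.Linear(key_dim, d_model)
        self.W_V = nn.Linear(key_dim, d_model)
        self.W_Q = nn.Linear(query_dim, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, encoder_outputs, query, mask=None):
        B, T, _ = encoder_outputs.shape
        # (B, T, d_model) -> (B, T, h, d_head) -> (B, h, T, d_head)
        K = self.W_K(encoder_outputs).view(B, T, self.h, self.d_head).transpose(1, 2)
        V = self.W_V(encoder_outputs).view(B, T, self.h, self.d_head).transpose(1, 2)
        # (B, d_model) -> (B, h, d_head)
        Q = self.W_Q(query).view(B, self.h, self.d_head)
        # Scaled dot product for every head and encoder step: (B, h, T)
        raw_weights = torch.einsum("bhd,bhtd->bht", Q, K) / math.sqrt(self.d_head)
        if mask is not None:  # mask: (B, T), True at padded encoder positions
            raw_weights = raw_weights.masked_fill(mask.unsqueeze(1), float("-inf"))
        attention = torch.softmax(raw_weights, dim=-1)
        # Weighted sum of the values per head, then concatenate heads: (B, d_model)
        multi_head = torch.einsum("bht,bhtd->bhd", attention, V).reshape(B, -1)
        return self.W_O(multi_head), attention
```

Filling masked positions with `-inf` drives their weights to exactly zero; this assumes every utterance has at least one unmasked frame, since the softmax of an all-`-inf` row is `NaN`.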

Pseudocode:

```python
class Attention:
    '''
    Attention is calculated using the key, value (from encoder embeddings) and query from decoder.

    After obtaining the raw weights, compute and return the attention weights and context as follows:

    attention_weights   = softmax(raw_weights)
    attention_context   = einsum("thinkwhatwouldbetheequationhere",attention, value) #take hint from raw_weights calculation

    At the end, you can pass context through a linear layer too.
    '''

    def init(listener_hidden_size,
              speller_hidden_size,
              projection_size):

        VW = Linear(listener_hidden_size,projection_size)
        KW = Linear(listener_hidden_size,projection_size)
        QW = Linear(speller_hidden_size,projection_size)

    def set_key_value(encoder_outputs):
        '''
        In this function we take the encoder embeddings and make key and values from it.
        key.shape   = (batch_size, timesteps, projection_size)
        value.shape = (batch_size, timesteps, projection_size)
        '''
        key = KW(encoder_outputs)
        value = VW(encoder_outputs)
      
    def compute_context(decoder_context):
        '''
        In this function, we build the query from the decoder context, multiply
        the query with the keys to get the attention logits, and take a softmax
        to get the attention weights, which are multiplied with the values and
        summed over time.

        key.shape   = (batch_size, timesteps, projection_size)
        value.shape = (batch_size, timesteps, projection_size)
        query.shape = (batch_size, projection_size)

        You are also recommended to check out Abu's Lecture 19 to understand Attention better.
        '''
        query = QW(decoder_context) #(batch_size, projection_size)

        raw_weights = #using bmm or einsum. We need to perform batch matrix multiplication. It is important you do this step correctly.
        #What will be the shape of raw_weights?

        attention_weights = #What makes raw_weights -> attention_weights

        attention_context = #Multiply attention weights to values

        return attention_context, attention_weights 
```
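One way to fill in the pseudocode is to write both products as `torch.einsum` contractions over explicit index letters. The sketch below (function name hypothetical) assumes key, value, and query have already been projected to a common size `P`, and builds the padding mask from the encoder lengths the same way the `Speller` in this notebook does; it should match the `torch.bmm` formulation used in the `Attention` class.

```python
import torch

def dot_product_attention(key, value, query, lengths):
    """Single-query dot-product attention with a padding mask.

    key, value: (B, T, P); query: (B, P); lengths: (B,) number of valid encoder frames.
    """
    # raw_weights[b, t] = sum_p key[b, t, p] * query[b, p]  -> (B, T)
    raw_weights = torch.einsum("btp,bp->bt", key, query)
    # True where t >= length, i.e. padded frames that must get zero weight
    positions = torch.arange(key.size(1), device=key.device).unsqueeze(0)
    mask = positions >= lengths.to(key.device).unsqueeze(1)
    raw_weights = raw_weights.masked_fill(mask, float("-inf"))
    attention_weights = torch.softmax(raw_weights, dim=1)
    # attention_context[b, p] = sum_t attention_weights[b, t] * value[b, t, p]  -> (B, P)
    attention_context = torch.einsum("bt,btp->bp", attention_weights, value)
    return attention_context, attention_weights
```

`raw_weights` has shape `(batch_size, timesteps)`: one logit per encoder frame for each utterance, which the softmax turns into a distribution over encoder frames.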

In [None]:
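class Attention(torch.nn.Module):
  def __init__(self, listener_hidden_size, speller_hidden_size, projection_size):
    super().__init__()
    # The listener is bidirectional, so its outputs have 2 * listener_hidden_size features
    new_lhs = 2 * listener_hidden_size
    self.KW = torch.nn.Linear(new_lhs, projection_size)
    self.VW = torch.nn.Linear(new_lhs, projection_size)
    # No query projection: the speller's last LSTMCell already outputs projection_size
    # features, so speller_hidden_size is not used here

  def set_key_value(self, encoder_outputs):
    # key, value: (batch_size, timesteps, projection_size)
    self.key = self.KW(encoder_outputs)
    self.value = self.VW(encoder_outputs)

  def get_shape(self):
    return self.key.shape

  def compute_context(self, decoder_context, mask = None):
    # query: (batch_size, projection_size)
    self.query = decoder_context
    # (B, T, P) @ (B, P, 1) -> (B, T): one logit per encoder timestep
    raw_weights = torch.bmm(self.key, self.query.unsqueeze(-1)).squeeze(-1)
    if mask is not None:
      # mask: (B, T), True at padded encoder positions. The fill value must stay
      # representable in float16 when running under autocast
      mask = mask.to(raw_weights.device)
      fill_value = -1e+30 if raw_weights.dtype == torch.float32 else -1e+4
      raw_weights = raw_weights.masked_fill(mask, fill_value)
    attention_weights = torch.softmax(raw_weights, dim=1)
    # (B, 1, T) @ (B, T, P) -> (B, P)
    attention_context = torch.bmm(attention_weights.unsqueeze(1), self.value).squeeze(1)
    return attention_context, attention_weights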

## The Speller

Similar to the language model you built for HW4P1, you need to build a language model for HW4P2 as well. This time, however, the decoder also calls the attention module at every step to obtain the attended encoder embeddings (the context).


What you have coded so far:

<center>
<img src="https://github.com/varunjain3/11785_s23_h4p2/raw/main/EncoderAttention.png" alt="data flow" height="400">
</center>

What we have to code for the Speller:


<center>
<img src="https://github.com/varunjain3/11785_s23_h4p2/raw/main/Decoder.png" alt="data flow" height="400">
</center>

In [None]:
class Speller(torch.nn.Module):

  # Refer to your HW4P1 implementation for help with setting up the language model.
  # The only things you need to add on top of your HW4P1 model are the attention module and teacher forcing.

  def __init__(self, attender:Attention, vocab_size, decoder_dim, embed_dim=256, proj_size=128):
    super().__init__()
    self.attend = attender  # Attention object used by the speller
    self.max_timesteps = 550  # Maximum decoding length when no targets are given
    self.embedding = torch.nn.Embedding(vocab_size, embed_dim)  # Maps each token to a latent vector
    self.embedding_dropout = torch.nn.Sequential(torch.nn.Dropout(0.25), torch.nn.Dropout(0.25))  # Note: not applied in forward()
    # Two stacked LSTMCells. The first takes [char embedding; context]; the last outputs
    # proj_size features so its hidden state can serve directly as the attention query
    # (proj_size must therefore equal the attention projection_size)
    self.lstm_cells = torch.nn.ModuleList([
        torch.nn.LSTMCell(proj_size + embed_dim, decoder_dim),
        torch.nn.LSTMCell(decoder_dim, proj_size),
    ])

    # Character distribution network (CDN); feel free to change
    self.output_to_char = torch.nn.Linear(2 * proj_size, embed_dim)  # [LSTM output; context] -> embed_dim, so weights can be tied
    self.activation = torch.nn.Tanh()
    self.char_prob = torch.nn.Linear(embed_dim, vocab_size)  # Hidden space -> logits over the vocabulary
    self.char_prob.weight = self.embedding.weight  # Weight tying with the embedding layer

  def lstm_step(self, input_word, hidden_state):
    # One time step through the stacked LSTMCells; each cell's hidden state feeds the next
    for i in range(len(self.lstm_cells)):
        hidden_state[i] = self.lstm_cells[i](input_word, hidden_state[i])
        input_word = hidden_state[i][0]
    return input_word, hidden_state

  def CDN(self, input):
    # Project [LSTM output; context] to the embedding size, then to vocabulary logits
    output = self.output_to_char(input)
    output = self.activation(output)
    output = self.char_prob(output)
    return output

  def forward(self, encoder_len, y = None, teacher_forcing_ratio=1, isGumble=True):
    # isGumble is currently unused; the name suggests it was meant to toggle Gumbel-noise sampling
    batch, length, value = self.attend.get_shape()
    # True at padded encoder positions, which must receive zero attention
    mask = (torch.arange(length).unsqueeze(0) >= encoder_len.unsqueeze(1)).to(device)
    attn_context = torch.zeros(batch, value, device=device)  # Initial context at t = 0
    output_symbol = torch.full((batch,), SOS_TOKEN, dtype=torch.long, device=device)  # <sos> at t = 0
    raw_outputs = []
    attention_plot = []

    if y is None:
      # At inference there are no ground-truth tokens, so the model must always
      # feed back its own predictions
      timesteps = self.max_timesteps
      teacher_forcing_ratio = 0
      isGumble = False
    else:
      timesteps = y.shape[1]  # Predict one output per target token

    hidden_states_list = [None, None]  # One (h, c) per LSTMCell; None means zero-initialized

    for t in range(timesteps):
      p = np.random.random()  # Uniform in [0, 1)

      # At t == 0 the input is always <sos>. Afterwards, with probability
      # teacher_forcing_ratio, feed the ground-truth previous token instead of the prediction
      if p < teacher_forcing_ratio and t > 0:
        output_symbol = y[:, t - 1].to(device)

      char_embed = self.embedding(output_symbol)  # (B, embed_dim)

      # Concatenate the character embedding and the previous attention context
      lstm_input = torch.cat([char_embed, attn_context], dim = 1)

      # The returned hidden states carry over to the next time step
      lstm_output, hidden_states_list = self.lstm_step(lstm_input, hidden_states_list)

      # Use the top LSTM output as the query to attend over the encoder outputs
      attn_context, attn_weights = self.attend.compute_context(lstm_output, mask)

      # Concatenate the LSTM output with the new context, as shown in the diagram
      cdn_input = torch.cat((lstm_output, attn_context), dim=1)

      raw_pred = self.CDN(cdn_input)

      # Greedy choice; it becomes the next input unless teacher forcing overrides it
      output_symbol = torch.argmax(raw_pred, dim = -1)

      raw_outputs.append(raw_pred)  # for loss calculation
      attention_plot.append(attn_weights)  # for plotting attention

    attention_plot = torch.stack(attention_plot, dim=1)  # (B, timesteps, encoder_timesteps)
    raw_outputs = torch.stack(raw_outputs, dim=1)  # (B, timesteps, vocab_size)

    return raw_outputs, attention_plot

## LAS

Here we finally build the LAS model, combining the listener, attender and speller. We have given a template, but you are free to read the paper and implement it yourself.

In [None]:
class LAS(torch.nn.Module):
  def __init__(self, vocab_size, encoder_size, decoder_size, projection_size): # add parameters
    super().__init__()

    # Pass the right parameters here
    self.augmentations = torch.nn.Sequential(
      PermuteBlock(),
      tat.FrequencyMasking(freq_mask_param = 5),
      tat.TimeMasking(time_mask_param = 250),
      PermuteBlock(),
    )
    self.listener = Listener(encoder_size)
    self.attend = Attention(encoder_size, decoder_size, projection_size)
    self.speller = Speller(self.attend, vocab_size, decoder_size)

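  def forward(self, x, lx, y=None, teacher_forcing_ratio=1):
    # Note: self.augmentations is defined in __init__ but is not applied here.
    # To use it, add `if self.training: x = self.augmentations(x)` before encoding.

    # Encode speech features
    encoder_outputs, encoder_len = self.listener(x,lx)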

    # We want to compute keys and values ahead of the decoding step, as they are constant for all timesteps
    # Set keys and values using the encoder outputs
    self.attend.set_key_value(encoder_outputs)

    # Decode text with the speller using context from the attention
    raw_outputs, attention_plots = self.speller(encoder_len = encoder_len,y=y,teacher_forcing_ratio=teacher_forcing_ratio)

    return raw_outputs, attention_plots

# Model Setup 

In [None]:
# Baseline LAS has the following configuration:
# Encoder bLSTM/pbLSTM Hidden Dimension of 512 (256 per direction)
# Decoder Embedding Layer Dimension of 256
# Decoder Hidden Dimension of 512 
# Attention Projection Size of 128
# Feel Free to Experiment with this 

model = LAS(
    # Initialize your model 
    # Read the paper and think about what dimensions should be used
    # You can experiment with these as well, but they are not required for the early submission
    # Remember that if you are using weight tying, some sizes need to be the same
    vocab_size = len(VOCAB),
    encoder_size = 512,  # per direction; the bidirectional listener outputs 2 * 512 features
    decoder_size = 512,
    projection_size = 128
)

model = model.to(device)
print(model)
summary(model, x.float().to(device), x_len, y)

# Loss Function, Optimizers, Scheduler

In [None]:
optimizer   = torch.optim.Adam(model.parameters(), lr= config['lr']) # Feel free to experiment if needed
criterion   = torch.nn.CrossEntropyLoss(reduction="mean", ignore_index=PAD_TOKEN) # ignore_index excludes padded positions from the loss; see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
scaler      = torch.cuda.amp.GradScaler()
scheduler   = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor = 0.5, patience=4,verbose=True)

# Optional (but recommended): create a custom class for a teacher forcing schedule
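A teacher forcing schedule usually starts by feeding the ground truth at every step and then gradually lowers the ratio, so the speller learns to recover from its own predictions. The class below is a hypothetical sketch with a constant warmup followed by a stepwise linear decay; the class name, the `ratio` method, and every default value are illustrative assumptions to tune, not recommended settings.

```python
class TeacherForcingScheduler:
    """Hypothetical schedule: hold `start` for `warmup_epochs`, then subtract `step`
    every `decay_every` epochs, never going below `floor`."""

    def __init__(self, start=1.0, floor=0.6, step=0.05, warmup_epochs=10, decay_every=2):
        self.start = start
        self.floor = floor
        self.step = step
        self.warmup_epochs = warmup_epochs
        self.decay_every = decay_every

    def ratio(self, epoch):
        if epoch < self.warmup_epochs:
            return self.start
        # The first decay happens at the end of warmup, then every `decay_every` epochs
        num_decays = (epoch - self.warmup_epochs) // self.decay_every + 1
        return round(max(self.floor, self.start - num_decays * self.step), 4)
```

In a training loop, one would call `tf_rate = tf_scheduler.ratio(epoch)` at the start of each epoch and pass it as `teacher_forcing_rate` to `train`.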

# Levenshtein Distance

In [None]:
# We have given you this utility function which takes a sequence of indices and converts them to a list of characters
def indices_to_chars(indices, vocab):
    tokens = []
    for i in indices: # This loops through all the indices
        if int(i) == SOS_TOKEN: # If SOS is encountered, don't add it to the final list
            continue
        elif int(i) == EOS_TOKEN: # If EOS is encountered, stop the decoding process
            break
        else:
            tokens.append(vocab[int(i)])
    return tokens

# To make your life easier, we have given you the Levenshtein distance (edit distance) calculation code
def calc_edit_distance(predictions, y, ly, vocab, print_example= False):

    dist                = 0
    batch_size, seq_len = predictions.shape

    for batch_idx in range(batch_size): 

        y_sliced    = indices_to_chars(y[batch_idx,0:ly[batch_idx]], vocab)
        pred_sliced = indices_to_chars(predictions[batch_idx], vocab)

        # Strings - When you are using characters from the AudioDataset
        y_string    = ''.join(y_sliced)
        pred_string = ''.join(pred_sliced)
        
        dist        += Levenshtein.distance(pred_string, y_string)
        # Comment the above and uncomment below for toy dataset, as the toy dataset has a list of phonemes to compare
        # dist      += Levenshtein.distance(y_sliced, pred_sliced)

    if print_example: 
        # Print y_sliced and pred_sliced if you are using the toy dataset
        print("Ground Truth : ", y_string)
        print("Prediction   : ", pred_string)
        
    dist/=batch_size
    return dist
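`calc_edit_distance` relies on `Levenshtein.distance`. For intuition, or to sanity-check a few predictions by hand, here is a minimal pure-Python version of the classic dynamic-programming edit distance with unit costs for insertion, deletion, and substitution; the name `edit_distance` is hypothetical, and it is not meant to replace the faster library call.

```python
def edit_distance(a, b):
    """Levenshtein distance between two sequences (strings or lists of tokens)."""
    # prev[j] holds the distance between the first i-1 items of a and the first j items of b
    prev = list(range(len(b) + 1))
    for i, item_a in enumerate(a, start=1):
        curr = [i]  # distance from a[:i] to the empty prefix of b
        for j, item_b in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,                         # delete item_a
                curr[j - 1] + 1,                     # insert item_b
                prev[j - 1] + (item_a != item_b),    # substitute (free if equal)
            ))
        prev = curr
    return prev[-1]

print(edit_distance("kitten", "sitting"))  # 3
```

Because it only compares items for equality, the same function works on lists of phonemes, which is the comparison the toy-dataset comment in `calc_edit_distance` refers to.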

# Train and Validation functions 


In [None]:
def train(model, dataloader, criterion, optimizer, teacher_forcing_rate):

    model.train()
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train')

    running_loss        = 0.0
    running_perplexity  = 0.0
    
    for i, (x, y, lx, ly) in enumerate(dataloader):

        optimizer.zero_grad()

        x, y, lx, ly = x.to(device), y.to(device), lx, ly

        with torch.cuda.amp.autocast():

            raw_predictions, attention_plot = model(x, lx, y= y, teacher_forcing_ratio= teacher_forcing_rate)

            # Predictions are of Shape (batch_size, timesteps, vocab_size). 
            # Transcripts are of shape (batch_size, timesteps) Which means that you have batch_size amount of batches with timestep number of tokens.
            # So in total, you have batch_size*timesteps amount of characters.
            # Similarly, in predictions, you have batch_size*timesteps amount of probability distributions.
            # How do you need to modify transcripts and predictions so that you can calculate the CrossEntropyLoss? Hint: Use Reshape/View and read the docs
            # Also we recommend you plot the attention weights, you should get convergence in around 10 epochs, if not, there could be something wrong with 
            # your implementation
            loss        =  criterion(raw_predictions.view(-1, raw_predictions.size(2)), y.flatten() )# TODO: Cross Entropy Loss
            perplexity  = torch.exp(loss) # Perplexity is defined the exponential of the loss

            running_loss        += loss.item()
            running_perplexity  += perplexity.item()
        
        # Backward on the masked loss
        scaler.scale(loss).backward()

        # Optional: use torch.nn.utils.clip_grad_norm_ to clip gradients and prevent them from exploding, if necessary
        # With mixed precision, unscale the optimizer's gradients first so that clipping sees their true magnitudes
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        
        scaler.step(optimizer)
        scaler.update()
        

        batch_bar.set_postfix(
            loss="{:.04f}".format(running_loss/(i+1)),
            perplexity="{:.04f}".format(running_perplexity/(i+1)),
            lr="{:.04f}".format(float(optimizer.param_groups[0]['lr'])),
            tf_rate='{:.02f}'.format(teacher_forcing_rate))
        batch_bar.update()

        del x, y, lx, ly
        torch.cuda.empty_cache()

    running_loss /= len(dataloader)
    running_perplexity /= len(dataloader)
    batch_bar.close()

    return running_loss, running_perplexity, attention_plot
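The comments in `train` ask how to reshape predictions and transcripts for `CrossEntropyLoss`. The toy example below uses random logits rather than real model outputs; it flattens `(batch, timesteps, vocab)` logits to `(batch * timesteps, vocab)` and targets to `(batch * timesteps,)`, and checks that `ignore_index=PAD` gives the same mean as computing the loss over non-padding positions only. The sizes and target values are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
PAD = 0  # same index as PAD_TOKEN in this notebook's VOCAB
B, T, V = 2, 4, 31
logits = torch.randn(B, T, V)            # stand-in for raw_predictions
targets = torch.tensor([[5, 6, 2, 0],    # stand-in for y; trailing 0s are padding
                        [7, 2, 0, 0]])

criterion = nn.CrossEntropyLoss(reduction="mean", ignore_index=PAD)
# Flatten (B, T, V) -> (B*T, V) and (B, T) -> (B*T,): one row per character position
loss = criterion(logits.reshape(-1, V), targets.reshape(-1))
perplexity = torch.exp(loss)

# The same mean computed over the non-padding positions only
keep = targets != PAD
manual = F.cross_entropy(logits[keep], targets[keep])
print(loss.item(), manual.item())
```

`reshape` is used here instead of `view` because it also works on non-contiguous tensors; for the output of `torch.stack` in the speller, either one works.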

In [None]:
def validate(model, dataloader):

    model.eval()

    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc="Val")

    running_lev_dist = 0.0

    for i, (x, y, lx, ly) in enumerate(dataloader):

        x, y, lx, ly = x.to(device), y.to(device), lx, ly

        with torch.inference_mode():
            raw_predictions, attentions = model(x, lx, y = None)

        # Greedy Decoding
        greedy_predictions   =  torch.argmax(raw_predictions,dim=-1)  # TODO: How do you get the most likely character from each distribution in the batch?
        # Calculate Levenshtein Distance
        running_lev_dist    += calc_edit_distance(greedy_predictions, y, ly, VOCAB, print_example = False) # You can use print_example = True for one specific index i in your batches if you want

        batch_bar.set_postfix(
            dist="{:.04f}".format(running_lev_dist/(i+1)))
        batch_bar.update()

        del x, y, lx, ly
        torch.cuda.empty_cache()

    batch_bar.close()
    running_lev_dist /= len(dataloader)

    return running_lev_dist
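`validate` reports the Levenshtein (edit) distance between the greedy transcripts and the targets through `calc_edit_distance`, which is defined elsewhere in this notebook. As a reference for what that number counts, here is a plain dynamic-programming edit distance. The name `levenshtein` is hypothetical, and `calc_edit_distance` may additionally convert indices to strings and average over the batch.

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions and
    substitutions needed to turn string a into string b."""
    # prev[j] holds the distance between the first i-1 chars of a and the first j chars of b
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]  # turning a[:i] into "" takes i deletions
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,               # delete ca
                            curr[j - 1] + 1,           # insert cb
                            prev[j - 1] + (ca != cb))) # substitute (free if equal)
        prev = curr
    return prev[-1]

print(levenshtein("kitten", "sitting"))        # 3
print(levenshtein("HELLO WORLD", "HELO WORD")) # 2
```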

# Experiment

In [None]:
# Log in to Wandb
# Initialize your Wandb run here
# Save your model architecture in a txt file, and save the file to Wandb
import wandb
# Prompts for your API key (found under wandb.ai/settings) or reads it from the WANDB_API_KEY environment variable.
# Avoid hardcoding the key in a notebook you share or submit.
wandb.login()

In [None]:
# Create your wandb run
run = wandb.init(
    name = "submission", ## Wandb creates random run names if you skip this field
    reinit = True, ### Allows reinitializing runs when you re-run this cell
    # id = ### Insert a specific run id here if you want to resume a previous run
    # resume = "must" ### You need this to resume previous runs, but comment out reinit = True when using this
    project = "hw4p2-ablations", ### Project should be created in your wandb account 
    config = config ### Wandb Config for your run
)

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

def plot_attention(attention): 
    plt.clf()
    sns.heatmap(attention, cmap='GnBu')
    plt.show()
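In a converged LAS model the attention plot typically shows a roughly diagonal band, since each output character attends to a slightly later slice of the audio than the one before. If you want a number to log next to the heatmap, the hypothetical `attention_monotonicity` helper below measures how often the peak attention position moves forward. It assumes the attention matrix is shaped `(decoder_steps, encoder_steps)`; check the orientation your implementation actually returns.

```python
import numpy as np

def attention_monotonicity(attention: np.ndarray) -> float:
    """Fraction of consecutive decoder steps whose peak attention does not move backwards.

    attention: (decoder_steps, encoder_steps) array, one row per output character.
    Values near 1.0 suggest a roughly diagonal (monotonic) alignment.
    """
    peaks = attention.argmax(axis=1)  # most-attended encoder frame for each output step
    if len(peaks) < 2:
        return 1.0
    return float(np.mean(np.diff(peaks) >= 0))

# Synthetic example: a clean diagonal alignment vs. uniform random attention
steps, frames = 20, 60
diagonal = np.zeros((steps, frames))
diagonal[np.arange(steps), np.arange(steps) * (frames // steps)] = 1.0
rng = np.random.default_rng(0)
noisy = rng.random((steps, frames))
noisy /= noisy.sum(axis=1, keepdims=True)  # normalize rows like softmax outputs

print(attention_monotonicity(diagonal))  # 1.0
print(attention_monotonicity(noisy))
```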

In [None]:
def save_model(model, optimizer, scheduler, metric, epoch, path):
    torch.save(
        {'model_state_dict'         : model.state_dict(),
         'optimizer_state_dict'     : optimizer.state_dict(),
         'scheduler_state_dict'     : scheduler.state_dict(),
         metric[0]                  : metric[1], 
         'epoch'                    : epoch}, 
         path
    )

# The default metric key matches the 'valid_dist' entry written by save_model in the training loop
def load_model(path, model, metric='valid_dist', optimizer=None, scheduler=None):

    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'], strict = False)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
    epoch   = checkpoint['epoch']
    metric  = checkpoint[metric]

    return [model, optimizer, scheduler, epoch, metric]
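`save_model` stores each object's `state_dict()` together with the metric and the epoch, so resuming means rebuilding fresh objects and restoring their states. The sketch below does a full round trip with toy stand-ins (`nn.Linear`, `AdamW` and `ReduceLROnPlateau` are assumptions, not necessarily your setup) and shows how the saved epoch could set the `start` value used for the epoch range.

```python
import os
import tempfile
import torch
import torch.nn as nn

# Toy stand-ins for the notebook's model, optimizer and scheduler
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
# Same dictionary layout as save_model(): the metric is stored under its own name
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "valid_dist": 12.5,
            "epoch": 7}, path)

# Resume: rebuild fresh objects, then restore their states from the checkpoint
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.AdamW(new_model.parameters(), lr=1e-3)
new_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(new_optimizer, mode="min")
checkpoint = torch.load(path, map_location="cpu")
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
new_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

# The saved epoch already finished, so training should resume at the next one
start = checkpoint["epoch"] + 1
print(start, checkpoint["valid_dist"])
```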

In [None]:
last_epoch_completed = 0
start = last_epoch_completed
end = config["epochs"]
best_lev_dist =float("inf")
epoch_model_path = '/content/drive/MyDrive/hw4_epoch_model.pth'
best_model_path = '/content/drive/MyDrive/hw4model_1.pth'

In [None]:
torch.cuda.empty_cache()
gc.collect()




In [None]:
tf_rate = 1.0

for epoch in range(start, end):

    print("\nEpoch: {}/{}".format(epoch + 1, end))

    curr_lr = float(optimizer.param_groups[0]['lr'])
    train_loss, running_perplexity, attention_plot = train(model, train_loader, criterion, optimizer, tf_rate)
    valid_dist = validate(model, dev_loader)

    # Step the LR scheduler only once the validation distance has dropped below 10
    if valid_dist < 10:
        scheduler.step(valid_dist)

    # Once the validation distance is below 80, lower teacher forcing by 0.05 every 6 epochs (floor of 0.4)
    if valid_dist < 80 and epoch % 6 == 0:
        tf_rate = max(0.4, tf_rate - 0.05)

    print("\tTrain Loss {:.04f}\t Learning Rate {:.07f}".format(train_loss, curr_lr))
    print("\tVal Dist {:.04f}\t Teacher Forcing Rate {:.02f}".format(valid_dist, tf_rate))

    plot_attention(attention_plot[0].cpu().detach().numpy())

    wandb.log({
        'train_loss': train_loss,  
        'valid_dist': valid_dist,
        'lr'        : curr_lr
    })
    
    save_model(model, optimizer, scheduler, ['valid_dist', valid_dist], epoch, epoch_model_path)
    wandb.save(epoch_model_path)

    if valid_dist <= best_lev_dist:
        best_lev_dist = valid_dist
        save_model(model, optimizer, scheduler, ['valid_dist', valid_dist], epoch, best_model_path)
        wandb.save(best_model_path)
        print("Saved best model")
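The teacher forcing schedule in the training loop is easy to misjudge, so it can help to simulate it on its own. `update_tf_rate` is a hypothetical function that reproduces the rule used there: once the validation distance drops below 80, the rate falls by 0.05 on epochs divisible by 6, and never goes below 0.4.

```python
def update_tf_rate(tf_rate: float, valid_dist: float, epoch: int,
                   dist_threshold: float = 80.0, every: int = 6,
                   step: float = 0.05, floor: float = 0.4) -> float:
    """Mirror of the loop's rule: once valid_dist < dist_threshold, lower the
    teacher forcing rate by `step` on epochs divisible by `every`, never below `floor`."""
    if valid_dist < dist_threshold and epoch % every == 0:
        return max(floor, tf_rate - step)
    return tf_rate

# Simulated run where the validation distance stays under the threshold from epoch 0
tf_rate = 1.0
history = []
for epoch in range(30):
    tf_rate = update_tf_rate(tf_rate, valid_dist=50.0, epoch=epoch)
    history.append(round(tf_rate, 2))
print(history)
```

Because the rate only changes every 6 epochs, going from 1.0 to the 0.4 floor takes about 12 decrements (0.6 / 0.05), i.e. roughly 66 more epochs after the first one. With a smaller epoch budget the floor is never reached.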

In [None]:
plot_attention(attention_plot[0].cpu().detach().numpy())

# Testing

In [None]:
model.eval()
final_predict = []
print("Testing")
for data in tqdm(test_loader):

    x, lx   = data
    x       = x.to(device)

    with torch.no_grad():
        raw_predictions, attentions = model(x, lx)

    greedy_predictions   =  torch.argmax(raw_predictions,dim=2)
    for i in range(len(raw_predictions)):
        string = indices_to_chars(greedy_predictions[i], VOCAB)
        pred_str = "".join(string)
        final_predict.append(pred_str)
    
    del x, lx, raw_predictions, attentions
    torch.cuda.empty_cache()

Testing



100%|██████████| 11/11 [00:25<00:00,  2.30s/it]
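The test loop relies on `indices_to_chars` (defined earlier in the notebook) to turn greedy indices into text. If that helper does not stop at `<eos>`, whatever the decoder emits after the end token lands in the submission and inflates the distance. The sketch below converts a whole batch at once under an assumed toy vocabulary with `<pad>`, `<sos>` and `<eos>` tokens; `batch_to_strings` is hypothetical and the real `VOCAB` layout may differ.

```python
import torch

# Hypothetical vocabulary layout; check the notebook's VOCAB for the real special tokens
VOCAB = ["<pad>", "<sos>", "<eos>", " ", "'", "A", "B", "C"]
SOS, EOS, PAD = VOCAB.index("<sos>"), VOCAB.index("<eos>"), VOCAB.index("<pad>")

def batch_to_strings(greedy_predictions: torch.Tensor) -> list:
    """Turn a (batch, time) tensor of argmax indices into strings,
    cutting each sequence at its first <eos> and skipping <sos>/<pad>."""
    results = []
    for seq in greedy_predictions.tolist():
        chars = []
        for idx in seq:
            if idx == EOS:
                break  # everything after the end token is ignored
            if idx in (SOS, PAD):
                continue
            chars.append(VOCAB[idx])
        results.append("".join(chars))
    return results

preds = torch.tensor([[5, 6, 3, 7, 2, 5],    # "AB C" then <eos> and a stray "A"
                      [7, 4, 5, 2, 0, 0]])   # "C'A" then <eos> and padding
print(batch_to_strings(preds))  # ['AB C', "C'A"]
```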


In [None]:
with open("submission.csv", 'w') as f:
    f.write("index,label\n")
    for i, pred in enumerate(final_predict):
        f.write(f"{i},{pred}\n")
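Writing rows with an f-string works as long as no transcript contains a comma or a double quote. The standard `csv` module quotes such fields automatically, so it is a safer drop-in. `write_submission` is a hypothetical helper, shown here with an in-memory buffer instead of `submission.csv`.

```python
import csv
import io

def write_submission(predictions, f):
    """Write an index,label CSV; csv.writer quotes any label containing a comma or quote."""
    writer = csv.writer(f)
    writer.writerow(["index", "label"])
    for i, pred in enumerate(predictions):
        writer.writerow([i, pred])

# In-memory example; in the notebook this would be open("submission.csv", "w", newline="")
buffer = io.StringIO()
write_submission(["HELLO WORLD", "IT'S, FINE"], buffer)
buffer.seek(0)
rows = list(csv.reader(buffer))
print(rows)  # [['index', 'label'], ['0', 'HELLO WORLD'], ['1', "IT'S, FINE"]]
```

When writing to a real file, open it with `newline=""`, as the `csv` module documentation recommends.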

In [None]:
!kaggle competitions submit -c attention-based-speech-recognition-slack -f submission.csv -m "Message"

100% 292k/292k [00:01<00:00, 193kB/s]
Successfully submitted to Attention-Based Speech Recognition (Slack)