# Fine-tuning a Wav2vec Model for Automatic Speech Recognition

**Name**: 

The ASR system used in the first two labs relied on an acoustic model made up of two main components: a *feature* extractor (based on the [wav2vec](https://arxiv.org/pdf/2006.11477) architecture) and a *classifier* (a linear layer). The model then operates in two stages:

1. The speech waveform is passed to the feature extractor to transform it into a latent representation (= the *feature maps*).
2. The latent representation is passed to a classification layer to compute the probability of each character (= the *emission matrix*).

<center><img src="https://github.com/magronp/magronp.github.io/blob/master/images/wav2vec2asr.png?raw=true" width="800"></center>

To train such a model in practice, a two-stage process is used. First, the wav2vec model alone is pre-trained in a self-supervised manner (using speech data only). Then, the classification layer is added and the whole model \{wav2vec+classifier\} is fine-tuned in a supervised manner (from speech+transcript data).

In the previous labs, we have used a whole model that was already pre-trained and fine-tuned for ASR. In this lab, we start from the pre-trained wav2vec model, and we reproduce the process of fine-tuning it for ASR. 

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torchaudio
import IPython
import os
import fnmatch
import copy
import matplotlib.pyplot as plt
torch.random.manual_seed(0);

MAX_FILES = 100 # lower this number for processing a subset of the dataset

In [None]:
# Main dataset path - If needed, you can change it HERE but NOWHERE ELSE in the notebook!
data_dir = "../dataset/"

In [None]:
# Speech and transcripts sub-directories paths
data_speech_dir = os.path.join(data_dir, 'speech')
data_transc_dir = os.path.join(data_dir, 'transcription')

## Preparation

In [None]:
# Example file
audio_file = '61-70968-0001.wav'
audio_file_path = os.path.join(data_speech_dir, audio_file)
print(f"Audio file path: {audio_file_path}")

waveform, sr = torchaudio.load(audio_file_path, channels_first=True)
IPython.display.Audio(data=waveform, rate=sr)

In [None]:
# Function to load a transcript; the preprocessing here is a bit different than in previous labs (more convenient for training)
def get_true_transcript(transc_file_path):
    with open(transc_file_path, "r") as f:
        true_transcript = f.read()
    true_transcript = true_transcript.lower().replace(' ','|').replace('\n','')
    return true_transcript

# Load and display the true transcription
transc_file_path = os.path.join(data_transc_dir, audio_file.replace('wav', 'txt'))
true_transcript = get_true_transcript(transc_file_path)
print(true_transcript)

In [None]:
# We provide the list of labels (=characters) that can be found in the dataset
labels = ['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm',
          'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']
n_labels = len(labels)

We need to transform the true transcript into a list of integers in order to feed it to a training loss function. To that end, we define a dictionary `dico_labels` that maps each character in the list of possible labels to an integer (for instance, `dico_labels['e']=2` or `dico_labels['a']=4`).

In [None]:
# TODO: using the 'labels' list above, define this dictionary
dico_labels = {}
for index, element in enumerate(labels):
    dico_labels[element] = index

In [None]:
# Apply dico_labels to the true transcript, and build a tensor from it
target_indices = [dico_labels[c] for c in true_transcript]
target_indices = torch.tensor(target_indices, dtype=torch.long)
print(target_indices)
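
As a quick sanity check of the character-to-integer mapping described above, the encoding can be inverted and compared with the original string. The sketch below is self-contained; the helper names `encode` and `decode` and the example string are hypothetical and not part of the lab material.

```python
# Same label list as in the notebook (index 0, '-', is the blank token)
labels = ['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm',
          'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']

# Character -> integer mapping, written as a dict comprehension
dico_labels = {char: index for index, char in enumerate(labels)}

def encode(transcript):
    """Map a preprocessed transcript (lowercase, '|' between words) to integers."""
    return [dico_labels[c] for c in transcript]

def decode(indices):
    """Inverse operation: map integers back to characters."""
    return "".join(labels[i] for i in indices)

example = "hello|world"  # hypothetical transcript, not taken from the dataset
encoded = encode(example)
print(encoded)
print(decode(encoded) == example)
```

Any character missing from `labels` (a digit, for instance) would raise a `KeyError` in `encode`, which is a convenient way to detect transcripts that need extra preprocessing.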

## Acoustic model

Torchaudio comprises many models whose pretrained weights can be loaded directly (the list can be found [here](https://pytorch.org/audio/stable/pipelines.html#id3)). Here we use `WAV2VEC2_BASE`, which is pre-trained on speech data, but not fine-tuned for ASR.

We can apply it to the waveform to compute the feature maps, which is a tensor of size `[1, time steps, feature dim]` (recall that `1` corresponds to the batch size).
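
The number of time steps in the feature maps can be estimated by hand. The sketch below assumes the convolutional configuration described in the wav2vec 2.0 paper (seven unpadded 1D convolutions with the kernel sizes and strides listed below) and audio sampled at 16 kHz; these values are assumptions about `WAV2VEC2_BASE` rather than something stated in this lab, and `num_frames` is a hypothetical helper.

```python
# Kernel sizes and strides of the 7 convolutional layers of the feature encoder
# (taken from the wav2vec 2.0 paper; assumed to match WAV2VEC2_BASE)
KERNELS = [10, 3, 3, 3, 3, 2, 2]
STRIDES = [5, 2, 2, 2, 2, 2, 2]

def num_frames(num_samples):
    """Number of output frames for a waveform with num_samples samples."""
    length = num_samples
    for kernel, stride in zip(KERNELS, STRIDES):
        # Output length of an unpadded 1D convolution
        length = (length - kernel) // stride + 1
    return length

# One second of audio at 16 kHz (assumed sample rate)
print(num_frames(16000))  # → 49
```

Under these assumptions, each frame covers roughly 20 ms of speech, so the emission matrix has far fewer time steps than the waveform has samples.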

In [None]:
# Load the acoustic model: here it's a Wav2vec that is not fine-tuned
model_name = 'WAV2VEC2_BASE'
bundle = getattr(torchaudio.pipelines, model_name)
wav2vec = bundle.get_model()

# Display model architecture
print(wav2vec)

In [None]:
# Compute the feature maps
with torch.inference_mode():
    features, _ = wav2vec(waveform)

print(features.shape)

<span style="color:red"> **Exercise 1**</span>. Complete the `Wav2vecASR` model class below. This model includes a wav2vec feature extractor (so it should be re-instantiated in the `__init__` method) and a linear classification layer that takes the feature maps and outputs log-probabilities per class (thus it has to use a [log softmax](https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html) activation after classification).
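
To make the role of the log softmax concrete, here is a minimal pure-Python sketch for a single time frame. The scores are made-up values, and `log_softmax` is a hypothetical helper that only mimics what the PyTorch activation computes along the class dimension.

```python
import math

def log_softmax(scores):
    """Numerically stable log-softmax of a list of scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]

scores = [2.0, 0.5, -1.0]  # hypothetical classifier outputs for one frame
log_probs = log_softmax(scores)
print(log_probs)

# Exponentiating the outputs gives probabilities that sum to 1 (up to rounding)
print(sum(math.exp(lp) for lp in log_probs))
```

All outputs are non-positive and the ranking of the classes is unchanged, which is why the greedy decoder can take the argmax directly on the emission matrix.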

In [None]:
class Wav2vecASR(nn.Module):
    def __init__(self, output_size):
        super().__init__()

        self.wav2vec = getattr(torchaudio.pipelines, 'WAV2VEC2_BASE').get_model()
        self.feature_dim = self.wav2vec.encoder.feature_projection.projection.out_features
        #self.feature_dim = 768 #the easy way
        self.output_layer = nn.Linear(self.feature_dim, output_size)
    
    def forward(self, x):
        features, _ = self.wav2vec(x)
        emission = self.output_layer(features)
        emission = emission.log_softmax(dim=-1)
        return emission


In [None]:
# Instantiate the model
output_size = n_labels
model = Wav2vecASR(output_size)

In [None]:
# A function to count the number of trainable parameters
def count_tlearnable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print('Number of trainable parameters:', count_tlearnable_params(model))

The model is very large (90M+ parameters), mostly because of the wav2vec part. To speed up training in this lab, we freeze the wav2vec part and only train the classification layer. To that end, we need to set `requires_grad = False` for the wav2vec parameters.

In [None]:
# TODO: Define a function to freeze parameters, apply it to "wav2vec" part of the model,
# and print the new number of trainable parameters
def freeze_params(m):
    # m.parameters() already covers all submodules, so a single call is enough
    for param in m.parameters():
        param.requires_grad = False

freeze_params(model.wav2vec)

print('Number of trainable parameters:', count_tlearnable_params(model))
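
Once the wav2vec part is frozen, only the linear classification layer should remain trainable, and its size can be checked by hand. The short calculation below assumes a feature dimension of 768 (the value in the commented-out line of the model class) and the 29 labels listed earlier.

```python
feature_dim = 768  # assumed wav2vec base feature dimension
n_labels = 29      # number of entries in the 'labels' list

# A linear layer has one weight per (input, output) pair plus one bias per output
n_weights = feature_dim * n_labels
n_biases = n_labels
print(n_weights + n_biases)  # → 22301
```

If the count printed by the notebook differs from this value, some wav2vec parameters are probably still trainable.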

In [None]:
# Initialization function for the network's parameters
def init_params(m, seed=0):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data, generator=torch.manual_seed(seed))
        if m.bias is not None:
            m.bias.data.fill_(0.01)

In [None]:
# TODO: Apply the initialization function to the classification layer of the model
model.output_layer.apply(init_params)

In [None]:
# TODO: apply the model to the waveform to compute the emission matrix, print its shape and plot it
with torch.inference_mode():
    emission = model(waveform)
print(emission.shape)

# Visualize the emission matrix
plt.figure()
plt.imshow(emission[0].cpu().T)
plt.title("Emission matrix")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.tight_layout()
plt.show()

In [None]:
# We provide a function to get a transcript from the emission matrix (similar to the greedy decoder from lab 1)
def transcript_from_emission(emission, labels):
    indices = torch.argmax(emission, dim=-1)  # take the most likely index at each time step
    indices = torch.unique_consecutive(indices, dim=-1) # remove duplicates
    indices = [i for i in indices if i != 0] # remove the blank token
    transcript = "".join([labels[i] for i in indices]) # convert integers back into characters
    transcript = transcript.lower()
    return transcript

# The transcript using a non-trained model should look bad
print(transcript_from_emission(emission[0], labels))
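
The decoding steps above (merge consecutive duplicates, then drop the blank token) can be illustrated without tensors. The sketch below works on a hand-written list of per-frame indices, standing in for the result of the argmax; `greedy_ctc_decode` is a hypothetical helper and the frame sequence is made up.

```python
from itertools import groupby

# Same label list as in the notebook (index 0, '-', is the blank token)
labels = ['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm',
          'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']

def greedy_ctc_decode(frame_indices, labels, blank=0):
    """Collapse repeated indices, remove blanks, and convert to characters."""
    collapsed = [idx for idx, _ in groupby(frame_indices)]  # merge consecutive duplicates
    kept = [idx for idx in collapsed if idx != blank]       # remove the blank token
    return "".join(labels[idx] for idx in kept)

# Hypothetical per-frame predictions: h h - e e l - l l o - -
frames = [8, 8, 0, 2, 2, 12, 0, 12, 12, 5, 0, 0]
print(greedy_ctc_decode(frames, labels))  # → hello
```

Note that the blank between the two runs of `l` is what allows the double letter to survive the duplicate removal.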

## Dataset

We now need to define the `Dataset` class, to efficiently load the speech data and true transcript (in the form of indices / integers).

<span style="color:red"> **Exercise 2**</span>. Complete the `ASRdataset` class below (`__init__`, `__len__`, and `__getitem__` methods). In particular, the `__getitem__` method should return three outputs:
- `waveform`: a tensor containing the speech waveform
- `true_transcript`: the true transcript, loaded with the provided `get_true_transcript` function
- `target_indices`: a tensor containing the integers corresponding to the transcript

The `__init__` method should make use of the `MAX_FILES` variable to limit the size of the dataset (for speed).

In [None]:
# ASR Dataset class
class ASRdataset(Dataset):
    def __init__(self, data_speech_dir, data_transc_dir, dico_labels, MAX_FILES=None):
        self.data_speech_dir = data_speech_dir
        self.data_transc_dir = data_transc_dir
        self.audio_files = self._find_files(data_speech_dir)[:MAX_FILES]
        self.dico_labels = dico_labels

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, index):
        # load the waveform
        audio_file = self.audio_files[index]
        audio_file_path = os.path.join(self.data_speech_dir, audio_file)
        waveform, _ = torchaudio.load(audio_file_path, channels_first=True)
        # load the true transcript
        transc_file_path = os.path.join(self.data_transc_dir, audio_file.replace('wav', 'txt'))
        true_transcript = get_true_transcript(transc_file_path)
        # transform transcript into list of integers
        target_indices = [self.dico_labels[c] for c in true_transcript]
        target_indices = torch.tensor(target_indices, dtype=torch.long)
        return waveform, true_transcript, target_indices

    def _find_files(self, directory, pattern='*.wav'):
        """Recursively finds all files matching the pattern."""
        files = []
        for root, _, filenames in os.walk(directory):
            for filename in fnmatch.filter(filenames, pattern):
                files.append(filename)
        files = sorted(files)
        return files


In [None]:
# Instantiate the ASR dataset and print its length
asrdataset = ASRdataset(data_speech_dir, data_transc_dir, dico_labels, MAX_FILES=MAX_FILES)
print('Dataset length:', len(asrdataset))

# Get the first data sample, and print some information
waveform, true_transcript, target_indices = asrdataset[0]
print(waveform.shape)
print(true_transcript)
print(target_indices)

In real-life applications (and as we usually do in the "Neural Networks" labs), we assemble the data samples into *batches* for efficiency. However, our data points generally have different lengths: two speech waveforms are not guaranteed to have the same duration, nor two transcripts the same number of characters. In such a case, the dataloader must be customized to perform a [padding operation](https://www.codefull.org/2018/11/use-pytorchs-dataloader-with-variable-length-sequences-for-lstm-gru/) so that all samples in a batch have the same length. To keep things simple in this lab, we skip this padding operation: we do not work with batches, but rather iterate over the dataset directly.
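
For reference, the padding idea can be sketched on plain Python lists: every sequence is extended to the length of the longest one, and the original lengths are kept so that later steps can ignore the padded positions. The function `pad_batch` and the example sequences are hypothetical; in PyTorch, `torch.nn.utils.rnn.pad_sequence` performs a similar operation on tensors.

```python
def pad_batch(sequences, pad_value=0):
    """Pad variable-length sequences to a common length and return the true lengths."""
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths

# Three hypothetical target sequences of different lengths
targets = [[8, 7], [3, 8, 2, 1, 16, 4, 3], [5, 6, 2]]
padded, lengths = pad_batch(targets)
print(padded)
print(lengths)
```

The returned `lengths` play the same role as the input and target lengths that a batched loss function needs in order to skip the padding.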

## Training with the CTC loss

For a given data sample in the dataset, we have:
- the emission matrix (log-probability of each character over time frames)
- the true transcript (represented as a list of integers corresponding to each character)

However, to train our network, we need to obtain an estimated transcript so that we can compute a loss between the true and estimated transcripts. This requires some post-processing of the emission matrix, but the good news is that we don't have to do it: the [CTC loss](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) handles that for us!

In lab 2 we used the CTC algorithm to perform inference, but it can also be used as a loss function to train an ASR network. Not only does the CTC loss handle the post-processing of the emission matrix, but its great advantage is that it aligns input/output pairs of different lengths, so we don't have to explicitly align each character in the transcript with a time frame in the emission matrix.

In PyTorch, the [CTC loss](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) is fed with the emission matrix and the tensor containing the target indices (corresponding to the true transcript). We also need to give it the input and target lengths explicitly: even though we do not manipulate batches or apply padding here, this would generally be the case, so the function needs to know the actual input and target lengths (before padding).
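
To illustrate how the CTC loss sums over all alignments between the time frames and the target characters, here is a minimal pure-Python sketch of the CTC forward algorithm for a single, unbatched sample. It is an illustration under simplifying assumptions (blank at index 0, no batching, no gradients), not the PyTorch implementation; `ctc_neg_log_likelihood` and `logsumexp` are hypothetical helpers.

```python
import math

NEG_INF = float("-inf")

def logsumexp(*values):
    """Stable log(sum(exp(v))) that also handles -inf entries."""
    m = max(values)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(v - m) for v in values))

def ctc_neg_log_likelihood(log_probs, target, blank=0):
    """log_probs: T lists of C log-probabilities; target: list of label indices."""
    # Extended target with blanks around every label (length 2L + 1)
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S, T = len(ext), len(log_probs)

    # alpha[s]: log-probability of all partial alignments ending at ext[s]
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, T):
        new_alpha = [NEG_INF] * S
        for s in range(S):
            candidates = [alpha[s]]  # stay on the same symbol
            if s >= 1:
                candidates.append(alpha[s - 1])  # move from the previous symbol
            # Skipping a blank is allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(alpha[s - 2])
            new_alpha[s] = logsumexp(*candidates) + log_probs[t][ext[s]]
        alpha = new_alpha

    # Valid alignments end on the last label or on the final blank
    total = logsumexp(alpha[-1], alpha[-2]) if S > 1 else alpha[-1]
    return -total

# Toy example: 2 frames, 2 classes (blank and one character), uniform probabilities.
# The target [1] has 3 valid alignments ("a-", "-a", "aa"), each with probability 0.25.
uniform = [[math.log(0.5), math.log(0.5)] for _ in range(2)]
loss = ctc_neg_log_likelihood(uniform, [1])
print(loss, -math.log(0.75))
```

With its default `reduction='mean'`, `nn.CTCLoss` additionally divides this quantity by the target length, so the values reported by PyTorch differ from this sketch by that factor.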

In [None]:
# Compute the emission matrix (not in inference mode such that we keep track of the gradients)
emission = model(waveform)
emission = emission[0] #remove the "batch" dimension (since here batch_size=1)

# Define the input and target lengths as tensors
T = emission.shape[0]
L = len(target_indices)
input_length = torch.tensor(T, dtype=torch.long)
target_length = torch.tensor(L, dtype=torch.long)

# Instantiate a loss object
ctc_loss = nn.CTCLoss()

# Compute the loss
loss = ctc_loss(emission, target_indices, input_length, target_length)
print(loss.item())

# Compute the gradients
loss.backward()

<span style="color:red"> **Exercise 3**</span>. Write the training function `training_wav2vecASR` to train the model. It is similar to what we usually do in the Neural Networks labs, although here we do not build batches of data (thus we don't need a `DataLoader` object). Instead, we directly iterate over the `Dataset`. The training function uses an Adam optimizer, and no validation.

In [None]:
# Training function
def training_wav2vecASR(model, train_dataset, num_epochs, loss_fn, learning_rate):

    # Train a copy so that the input model is left untouched
    model_tr = copy.deepcopy(model)
    model_tr.train()
    # Only give the trainable parameters to the optimizer (the frozen ones are skipped)
    trainable_params = [p for p in model_tr.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
    train_losses = []

    for epoch in range(num_epochs):
        tr_loss = 0
        # Iterate over the dataset passed as argument (not over a global variable)
        for data_sample in train_dataset:

            # Get the data
            waveform, true_trans, target_indices = data_sample

            # Apply the model and remove the "batch" dimension (batch_size=1)
            emission = model_tr(waveform)
            emission = emission[0]

            # Get the input and target lengths
            input_length = torch.tensor(emission.shape[0], dtype=torch.long)
            target_length = torch.tensor(len(target_indices), dtype=torch.long)

            # Compute the CTC loss
            loss = loss_fn(emission, target_indices, input_length, target_length)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()

        # Normalize and store loss
        tr_loss = tr_loss / len(train_dataset)
        train_losses.append(tr_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {tr_loss:.4f}")

    return model_tr, train_losses

In [None]:
# Training parameters (only one epoch for speed)
num_epochs = 1
learning_rate = 0.1
loss_fn = nn.CTCLoss()

model_tr, train_losses = training_wav2vecASR(model, asrdataset, num_epochs, loss_fn, learning_rate)

## Overfitting on one sample

We observed above that training is slow: the wav2vec model is rather large, so the forward pass takes some time (even though no gradients are tracked for its frozen parameters), and we did not assemble the data into batches.

Nonetheless, we can check that our training pipeline works properly by training on a single sample and checking the transcript on that same sample. This technique (= overfitting on one sample) is a useful "crash test": it shows whether everything runs and whether our model / training pipeline has any chance of working on a larger dataset.

<span style="color:red"> **Exercise 4**</span>. Instantiate a model from scratch, freeze the wav2vec part, and initialize the classification layer. Build a dataset made up of 1 sample (use the `MAX_FILES` parameter), and train for 50 epochs using this dataset. Once the model is trained, compute the emission matrix from this sample's waveform and the estimated transcript. Display the true and estimated transcripts. Do the same on another sentence to check that the model is not able to generalize properly.

In [None]:
# Instantiate a model, freeze the wav2vec part, and initialize the classifier
model = Wav2vecASR(n_labels)
freeze_params(model.wav2vec)
model.output_layer.apply(init_params)

# Build a very small dataset
asrdataset = ASRdataset(data_speech_dir, data_transc_dir, dico_labels, MAX_FILES=1)

# Training
num_epochs = 50
model_tr, train_losses = training_wav2vecASR(model, asrdataset, num_epochs, loss_fn, learning_rate)

# Save the model's parameters
torch.save(model_tr.state_dict(), 'model_wav2vecASR.pt')

In [None]:
# Load the data sample
waveform, true_transcript, _ = asrdataset[0]

# Instantiate the model, load the trained parameters, and switch to evaluation mode
model = Wav2vecASR(n_labels)
model.load_state_dict(torch.load('model_wav2vecASR.pt', weights_only=True))
model.eval()

# Apply the model and get the estimated transcript
with torch.inference_mode():
    emission = model(waveform)
    est_transcript = transcript_from_emission(emission[0], labels)

# Display the true and estimated transcripts
print('True transcript: ', true_transcript)
print('Estimated transcript: ', est_transcript)

In [None]:
# Try on a different sentence
asrdataset = ASRdataset(data_speech_dir, data_transc_dir, dico_labels, MAX_FILES=2)
waveform, true_transcript, _ = asrdataset[-1]

with torch.inference_mode():
    emission = model(waveform)
    est_transcript = transcript_from_emission(emission[0], labels)

print('True transcript (other sentence): ', true_transcript)
print('Estimated transcript (other sentence): ', est_transcript)