# Using a pre-trained model for ASR

Wav2Vec 2.0 is a self-supervised speech model developed by Facebook AI (now Meta AI). The `facebook/wav2vec2-base-960h` checkpoint used here has been fine-tuned on 960 hours of LibriSpeech to convert speech to text. Here, we'll demonstrate how further fine-tuning can correct a specific recognition error, even with minimal additional training. We'll use an audio clip of the word "Visual", which the model initially misrecognizes.

### Without Fine-Tuning
Let's see how the pre-trained model performs without any fine-tuning, using the following steps:
- Load the pre-trained Wav2Vec 2.0 model and its processor.
- Read the audio file, resample it to 16 kHz, and convert it into the input format the model expects.
- Use the model to predict text from the audio without any fine-tuning.


```python
import torch
import librosa
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

audio_file = 'data/visual.wav'
audio_input, sample_rate = librosa.load(audio_file, sr=16000)

input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values

with torch.no_grad():
    logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]

print("Transcription:", transcription)
```


```
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Transcription: FISIAL
```
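The `argmax` plus `batch_decode` step above is greedy CTC decoding: the model emits one token per audio frame, and the decoder collapses consecutive repeats, drops the blank (`<pad>`) token, and turns the `|` word delimiter into a space. The sketch below re-implements only that collapse rule in plain Python with a hypothetical toy vocabulary; it is a simplified illustration, not the tokenizer's actual code.

```python
from itertools import groupby

# Hypothetical toy vocabulary for illustration only; the real
# wav2vec2-base-960h vocabulary has more tokens and different ids.
TOY_VOCAB = ["<pad>", "<unk>", "|", "A", "I", "L", "S", "U", "V"]
BLANK_ID = 0      # in this sketch, <pad> doubles as the CTC blank
DELIMITER = "|"   # word delimiter, decoded as a space


def greedy_ctc_decode(frame_ids):
    """Collapse repeated frame ids, drop blanks, and map ids to text."""
    # 1. Merge runs of identical consecutive ids (e.g. V V V -> V).
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2. Remove blanks; a blank between two identical ids keeps both.
    chars = [TOY_VOCAB[i] for i in collapsed if i != BLANK_ID]
    # 3. Turn word delimiters into spaces and trim the ends.
    return "".join(chars).replace(DELIMITER, " ").strip()


# Per-frame argmax ids for "VISUAL", with repeats and blanks
print(greedy_ctc_decode([8, 8, 0, 4, 4, 6, 0, 7, 3, 3, 0, 5, 5]))
```

Note the effect of the blank: frames `[5, 0, 5]` decode to "LL", while `[5, 5]` decode to a single "L", so CTC can only emit a doubled letter when a blank separates the two frames.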


### With Fine-Tuning
Now, let's fine-tune the model on the same audio data to see if we can improve the accuracy of the transcription.

#### Setup and Loading Data
We start by setting up the environment and loading our training data (the misrecognized audio file).

```python
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from torch.optim import AdamW
from torch.nn import CTCLoss

# Load processor and model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.train()

# Load your custom audio file and its corresponding transcript
audio_file_path = 'data/visual.wav'
transcript = "visual"

# Load audio as a (channels, samples) tensor and downmix to mono,
# since the model expects a single channel
speech_array, sampling_rate = torchaudio.load(audio_file_path)
speech_array = speech_array.mean(dim=0)

# Resample to the 16 kHz rate the model was trained on
if sampling_rate != 16000:
    resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
    speech_array = resampler(speech_array)
```




#### Preprocessing Audio and Label Data
Convert the raw audio into the model's input format and encode the correct transcription into label format for training.

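```python
# Preprocess the audio into normalized input values with a batch dimension
input_values = processor(speech_array.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_values

# Encode the transcript as character ids. The wav2vec2-base-960h vocabulary
# contains only uppercase letters, so lowercase input would map to <unk>.
labels = processor.tokenizer(transcript.upper(), return_tensors="pt").input_ids
```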
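The `facebook/wav2vec2-base-960h` tokenizer is character-level and its vocabulary holds uppercase letters only, so each character of the transcript is looked up directly and anything missing falls back to `<unk>`. A lowercase transcript such as "visual" therefore becomes a row of `<unk>` ids that carries no information about the word. The plain-Python sketch below mimics that lookup with a hypothetical toy vocabulary (the ids are illustrative, not the real ones).

```python
# Hypothetical toy uppercase-only vocabulary in the style of a
# character-level CTC vocabulary; ids are illustrative, not the real ones.
TOY_VOCAB = {"<pad>": 0, "<unk>": 1, "|": 2, "A": 3, "I": 4,
             "L": 5, "S": 6, "U": 7, "V": 8}


def encode_transcript(text, vocab=TOY_VOCAB):
    """Map each character to an id; spaces become the '|' word delimiter."""
    unk_id = vocab["<unk>"]
    ids = []
    for ch in text:
        token = "|" if ch == " " else ch
        # Characters missing from the vocabulary fall back to <unk>
        ids.append(vocab.get(token, unk_id))
    return ids


print(encode_transcript("visual"))          # lowercase: every char is <unk>
print(encode_transcript("visual".upper()))  # uppercase: real character ids
```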

#### Training Configuration
Move the model and tensors to a GPU if one is available, then set up the AdamW optimizer that updates the model weights and the CTC loss that scores the predicted character sequence against the label.

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
input_values = input_values.to(device)
labels = labels.to(device)

# Set up the optimizer and the loss function.
# wav2vec2's CTC head uses the <pad> token as the CTC blank;
# zero_infinity=True zeroes infinite losses (e.g. audio too short for the label).
optimizer = AdamW(model.parameters(), lr=1e-5)
loss_func = CTCLoss(blank=processor.tokenizer.pad_token_id, zero_infinity=True)
```
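CTC loss is the negative log of the total probability of every frame-level alignment that collapses to the label, so with properly normalized inputs it can never be negative. The pure-Python sketch below computes it with the standard forward (alpha) recursion; it is written for clarity, not speed, and is not PyTorch's implementation (it returns the per-sequence sum, whereas `CTCLoss` defaults to `reduction="mean"`). It also shows why `CTCLoss` needs log-probabilities: feeding unnormalized scores lets the "loss" go below zero.

```python
import math


def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


def ctc_loss(log_probs, target, blank=0):
    """Negative log-likelihood of `target` under CTC.

    log_probs: T rows, each a list of per-token log-probabilities
               (each row should come from a log-softmax).
    target:    list of label ids, without blanks.
    """
    # Extended target with blanks interleaved: _ y1 _ y2 _ ... yL _
    ext = [blank]
    for y in target:
        ext += [y, blank]
    S = len(ext)

    # alpha[s]: log prob of all partial alignments ending at ext[s]
    alpha = [-math.inf] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, len(log_probs)):
        new = [-math.inf] * S
        for s in range(S):
            acc = alpha[s]                           # stay on the same symbol
            if s >= 1:
                acc = log_add(acc, alpha[s - 1])     # advance by one
            # Skipping a blank is allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                acc = log_add(acc, alpha[s - 2])
            new[s] = acc + log_probs[t][ext[s]]
        alpha = new

    # Valid alignments end on the last label or on the trailing blank
    total = alpha[-1] if S == 1 else log_add(alpha[-1], alpha[-2])
    return -total


# Unnormalized scores (all 5.0) give a negative "loss", a sign of a bug
raw_scores = [[5.0, 5.0, 5.0] for _ in range(4)]
print(ctc_loss(raw_scores, [1, 2]))
```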

#### Training & Evaluation
Run 5 epochs of training. Because the dataset is a single clip, each epoch is one gradient step: a forward pass to compute the logits, the CTC loss against the label, and a backpropagation update. After each step, we switch the model to evaluation mode and transcribe the clip again to track progress.

```python
import torch.nn.functional as F

num_epochs = 5  # Adjust based on experiment observations

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Forward pass: logits have shape (batch, time, vocab)
    logits = model(input_values).logits

    # CTCLoss expects log-probabilities shaped (time, batch, vocab),
    # so normalize with log_softmax before transposing
    log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)

    # Input and label lengths for the CTC loss (one unpadded example)
    input_lengths = torch.full(size=(logits.size(0),), fill_value=logits.size(1), dtype=torch.long)
    label_lengths = torch.full(size=(labels.size(0),), fill_value=labels.size(1), dtype=torch.long)

    # Calculate loss
    loss = loss_func(log_probs, labels, input_lengths, label_lengths)

    # Backward pass and optimize
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

    # Transcribe the clip after each step to track progress
    model.eval()
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_text = processor.batch_decode(predicted_ids)[0]
    print(predicted_text)
```

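The log below was produced by a run that passed raw logits, not log-probabilities, to `CTCLoss`, which is why the loss values are negative: a correctly computed CTC loss is a negative log-likelihood and is never negative. A rerun that applies `log_softmax` first will print non-negative losses, and its intermediate transcriptions may differ from the ones shown here.

```
Epoch 1, Loss: -31.09031867980957
FISUAL
Epoch 2, Loss: -30.60826301574707
FISUAL
Epoch 3, Loss: -29.554391860961914
VISUAL
Epoch 4, Loss: -25.355257034301758
VUSUAL
Epoch 5, Loss: -15.027382850646973
VISUAL
```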
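To track progress with a number rather than by eye, the transcriptions printed above can be scored with character error rate (CER): the Levenshtein edit distance between reference and hypothesis, divided by the reference length. The helpers below are a plain-Python sketch with names of our own choosing, not a library API; word error rate is the same computation applied to lists of words instead of strings of characters.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance: minimum substitutions, insertions, deletions."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution (free if equal)
            ))
        prev = curr
    return prev[-1]


def char_error_rate(ref, hyp):
    """Character error rate: edit distance normalized by reference length."""
    return edit_distance(ref, hyp) / len(ref)


for hyp in ["FISIAL", "FISUAL", "VUSUAL", "VISUAL"]:
    print(hyp, round(char_error_rate("VISUAL", hyp), 3))
```

For the outputs in this notebook, this gives a CER of 0.333 for "FISIAL", 0.167 for "FISUAL" and "VUSUAL", and 0.0 for "VISUAL".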


This script is a deliberately minimal example of fine-tuning a Wav2Vec2 model on a single data point. Within five training steps the transcription moves from "FISIAL" to "VISUAL", although the epoch 4 output ("VUSUAL") shows that the prediction has not yet stabilized. Because the model is trained and evaluated on the same clip, this shows that it can memorize one example, not that it has become more accurate in general; fine-tuning on such a small, specific piece of data is for illustration only.
In a real-world scenario, one would need a larger dataset and a more thorough training routine involving more epochs, hyperparameter tuning, a held-out validation set, and possibly early stopping based on performance metrics such as word error rate.
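Early stopping is commonly implemented as a patience rule on a held-out validation metric: stop once the metric has failed to improve for a fixed number of checks. Below is a minimal, framework-agnostic sketch with a hypothetical `EarlyStopping` class; in a real run the validation losses would come from clips the model is not trained on, and the patience and `min_delta` values are assumptions to tune.

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` checks."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # checks to wait after the last improvement
        self.min_delta = min_delta  # smallest decrease that counts as improvement
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Usage with made-up validation losses, purely for illustration
stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch + 1}")
        break
```

With `patience=2`, the two non-improving losses after 0.8 trigger the stop at epoch 4, so the later 0.7 is never seen; a larger patience trades extra compute for robustness to noisy validation curves.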
