# Using a pre-trained model for ASR

Wave2Vec 2.0 is a powerful model developed by Facebook for converting speech to text. Here, we'll demonstrate fine-tuning this modelwith minimal additional training

In [16]:
import jsonlines
import torchaudio
from datasets import Dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Trainer, TrainingArguments
from pathlib import Path
import torch
import librosa
import IPython.display as ipd

### Using Wav2Vec2 to transcribe an audio file

In [17]:
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

audio_file = 'data/audio_1.wav'
audio_input, sample_rate = librosa.load(audio_file, sr=16000)

input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values

with torch.no_grad():
    logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]

print("Transcription:", transcription)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Transcription: HEADING HIS TWO STICK FERO TARGATIVE BLACK WHITE AND YELLOW COMMERCIAL AIR CRAFT TOOLED TO DEPOY IN CIRCUS AIR MISSILF


In [18]:
# Play the loaded audio file
audio_data, sampling_rate = librosa.load(audio_file, sr=None)
waveform, sample_rate = torchaudio.load(audio_file)
ipd.Audio(waveform, rate=sampling_rate)

## Fine-tuning

#### Setup and Loading Data
We start by setting up the environment and loading our training data.

In [19]:
# Define the path to the directory
data_dir = Path("data")

# Read data from a jsonl file and reformat it
data = {'key': [], 'audio': [], 'transcript': []}
with jsonlines.open(data_dir / "asr.jsonl") as reader:
    for obj in reader:
        if len(data['key']) < 3:  # Only keep the first 3 entries
            for key, value in obj.items():
                data[key].append(value)

# Convert to a Hugging Face dataset
dataset = Dataset.from_dict(data)
train_dataset = dataset  # Use all entries for training

In [20]:
data

{'key': [0, 1, 2],
 'audio': ['audio_0.wav', 'audio_1.wav', 'audio_2.wav'],
 'transcript': ['Heading is one five zero, target is green commercial aircraft, tool to deploy is electromagnetic pulse.',
  'Heading is two six zero, target is black, white, and yellow commercial aircraft, tool to deploy is surface-to-air missiles.',
  'Heading is one zero five, target is silver, green, and yellow light aircraft, tool to deploy is anti-air artillery.']}

#### Load Pretrained Model
Load the pretrained Wav2Vec2 model and its associated processor from the Hugging Face model hub.

In [21]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### Preprocessing Audio and Label Data
Below, we define a preprocessing function preprocess_data that takes an input dictionary of examples (containing audio paths and transcripts) and preprocesses them for training. It loads audio files, processes them using the Wav2Vec2 processor, creates attention masks, and pads the labels to match the input length.

In [22]:
# Function to load and preprocess audio
def preprocess_data(examples):
    input_values = []
    attention_masks = []
    labels = []

    for audio_path, transcript in zip(examples['audio'], examples['transcript']):
        speech_array, sampling_rate = torchaudio.load(data_dir / audio_path)
        processed = processor(speech_array.squeeze(0), sampling_rate=sampling_rate, return_tensors="pt", padding=True)

        # Process labels with the same processor settings
        with processor.as_target_processor():
            label = processor(transcript, return_tensors="pt", padding=True)

        input_values.append(processed.input_values.squeeze(0))
        # Create attention masks based on the input values
        attention_mask = torch.ones_like(processed.input_values)
        attention_mask[processed.input_values == processor.tokenizer.pad_token_id] = 0  # Set padding tokens to 0
        attention_masks.append(attention_mask.squeeze(0))
        
        # Ensure labels are padded to the same length as inputs if needed
        padded_label = torch.full(processed.input_values.shape[1:], -100, dtype=torch.long)
        actual_length = label.input_ids.shape[1]
        padded_label[:actual_length] = label.input_ids.squeeze(0)
        labels.append(padded_label)

    # Concatenate all batches
    examples['input_values'] = torch.stack(input_values)
    examples['attention_mask'] = torch.stack(attention_masks)
    examples['labels'] = torch.stack(labels)

    return examples


#### Training Configuration
Define the training arguments for the Trainer, including the output directory, evaluation strategy, learning rate, batch size, number of epochs, and other training settings

In [23]:
# Apply preprocessing
train_dataset = train_dataset.map(preprocess_data, batched=True, batch_size=1, remove_columns=train_dataset.column_names)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    learning_rate=1e-4,
    per_device_train_batch_size=1,  # Reduce to one for simplicity
    num_train_epochs=3,
    weight_decay=0.005,
    save_steps=500,
    eval_steps=500,
    logging_steps=10,
    load_best_model_at_end=True
)

Map:   0%|          | 0/3 [00:00<?, ? examples/s]



#### Training & Evaluation
Conduct 5 epochs of training

In [25]:
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=processor.feature_extractor
)

trainer.train()



  0%|          | 0/9 [00:00<?, ?it/s]

{'train_runtime': 16.8182, 'train_samples_per_second': 0.535, 'train_steps_per_second': 0.535, 'train_loss': 1743.306857638889, 'epoch': 3.0}


TrainOutput(global_step=9, training_loss=1743.306857638889, metrics={'train_runtime': 16.8182, 'train_samples_per_second': 0.535, 'train_steps_per_second': 0.535, 'train_loss': 1743.306857638889, 'epoch': 3.0})

This script provides a very basic example of fine-tuning a Wav2Vec2 model on a few data points. 
When fine-tuning a complex model like Wav2Vec 2.0 on an extremely limited dataset the model's performance is likely to be highly unpredictable and generally poor. Given the complexity and the depth of models like Wav2Vec 2.0, they require substantial data to adapt their pre-trained knowledge to new tasks or domains effectively. In a real-world scenario, one would need to manage larger datasets and more sophisticated training routines involving, more epochs, consider freezing some layers, hyper-parameter tuning, validation and possibly early stopping based on performance metrics.
