# **Fine-Tune MMS For Luganda ASR with 🤗 Transformers**

Adapted from [Fine-tuning XLS-R for Multi-Lingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2)

In [None]:
# Check for GPU
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print("Not connected to a GPU")
else:
    print(gpu_info)

#### Install Packages

In [2]:
%%capture
!pip install datasets==2.17.0
!pip install transformers==4.48.3
!pip install torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
!pip install jiwer # jiwer is used for evaluation using WER and CER
!pip install accelerate -U # Restart runtime after running this cell
!pip install wandb
!pip install librosa
!pip install soundfile
!pip install evaluate
!pip install matplotlib
!pip install soundfile==0.12.0 

#### IMPORT

In [2]:
import numpy as np
import warnings
# Suppress every Python warning for the rest of the session to keep cell output short
# (note that this also hides deprecation notices from the libraries used below)
warnings.filterwarnings('ignore')

#### Huggingface login

In [None]:
from huggingface_hub import notebook_login

# Authenticate this session with the Hugging Face Hub; later cells load a gated dataset
# and push the tokenizer, processor and model to the Hub
notebook_login()

Install Git-LFS to support uploading model weights to huggingface

In [4]:
%%capture
# Install git-lfs through apt; %%capture hides the installer output
!apt install git-lfs

### Prepare Data, Tokenizer, Feature Extractor
MMS uses the Wav2Vec2CTCTokenizer and the Wav2Vec2FeatureExtractor to process the inputs of the model

#### Create Wav2Vec2CTCTokenizer

In [5]:
from datasets import load_dataset, Audio

In [None]:
# The Common Voice dataset is gated: log in on Hugging Face and accept the
# Mozilla Foundation terms and conditions before running this cell.
# The three splits are loaded in the order train, validation, test.
lg_cv_train, lg_cv_valid, lg_cv_test = (
    load_dataset("mozilla-foundation/common_voice_7_0", "lg", split=split_name, trust_remote_code=True)
    for split_name in ("train", "validation", "test")
)

In [None]:
# Show the schema and row count of the train and validation splits, one per line
print(lg_cv_train, lg_cv_valid, sep="\n")

In [24]:
# Keep only the columns used downstream: the audio, its transcript and its duration.
# NOTE(review): confirm that the loaded splits really expose 'Transcriptions' and
# 'duration' columns — verify against the output of the previous cell.
lg_cv_train, lg_cv_valid = (
    split.select_columns(['audio', 'Transcriptions', 'duration'])
    for split in (lg_cv_train, lg_cv_valid)
)

In [None]:
# Inspect the first validation example (audio, transcript, duration)
lg_cv_valid[0]

In [None]:
# Running total of the clip durations, then the number of leading clips whose running total
# stays below 36,000 (10 hours if `duration` is in seconds — the /3600 conversions in the
# neighbouring cells suggest so; TODO confirm)
cumsum = np.cumsum(lg_cv_train['duration'])
cumsum[cumsum<36_000].shape

In [None]:
# Total training audio: the summed `duration` column divided by 3600 (hours, assuming seconds)
sum(lg_cv_train['duration'])/3600

In [None]:
# Shortest and longest clip duration in the training split
min(lg_cv_train['duration']), max(lg_cv_train['duration'])

In [None]:
# Same total as computed two cells above: summed `duration` divided by 3600
sum(lg_cv_train['duration'])/3600

In [30]:
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Histogram (20 bins) of the training clip durations
pd.Series(lg_cv_train['duration']).hist(bins =20, color = 'blue', grid =False)
plt.show()

#### Display

In [32]:
# Show samples from the dataset
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML


def show_random_elements(dataset, num_examples= 10):
    """Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    Args:
        dataset: A sized object that can be indexed with a list of row positions;
            the result of `dataset[picks]` is handed to `pd.DataFrame`.
        num_examples: How many distinct rows to show.

    Raises:
        AssertionError: If `num_examples` is larger than `len(dataset)`.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws without replacement, so no manual retry loop is needed
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [None]:
# Preview 5 random training rows; the audio column is dropped before building the table
show_random_elements(lg_cv_train.remove_columns(['audio']), 5)

In [None]:
# Transcripts are lower-cased because the model is not trained to predict casing
def normalize(batch):
    """Store a lower-cased copy of `batch['Transcriptions']` under `batch['transcription']`.

    The example is modified in place and returned.
    """
    lowered = batch['Transcriptions'].lower()
    batch['transcription'] = lowered
    return batch

# Add the lower-cased `transcription` column to the train and validation splits
lg_cv_train, lg_cv_valid = (
    split.map(normalize) for split in (lg_cv_train, lg_cv_valid)
)

In [38]:
# Collect the character vocabulary of a batch of transcripts
def extract_all_chars(batch):
  """Return the distinct characters and the space-joined text of `batch["transcription"]`.

  Both values are wrapped in a one-element list, so the mapped batch becomes a single row.
  """
  joined_text = " ".join(batch["transcription"])
  unique_chars = list(set(joined_text))
  return {"vocab": [unique_chars], "all_text": [joined_text]}

In [None]:
# Run extract_all_chars over each split in batched mode and drop the original columns,
# leaving only `vocab` and `all_text`.
# (batch_size=-1 is presumably meant to pass the whole split as one batch — the next cell
# reads row 0 only; confirm against the datasets `map` documentation)
luganda_cv_train_vocab   = lg_cv_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=lg_cv_train.column_names)
luganda_cv_valid_vocab   = lg_cv_valid.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=lg_cv_valid.column_names)

In [40]:
# Union of the characters seen in the train and validation transcripts
vocab_list = list(set(luganda_cv_train_vocab["vocab"][0])| set(luganda_cv_valid_vocab["vocab"][0]))

In [None]:
# Map every character to an integer id, assigned in sorted character order
vocab_dict = {char: index for index, char in enumerate(sorted(vocab_list))}
vocab_dict

### Build Model

In [42]:
# Re-key the space character as "|" (it keeps its id); "|" is the word delimiter
# token given to the tokenizer below
vocab_dict["|"] = vocab_dict.pop(" ")

In [None]:
# Append the two special tokens, each taking the next free id:
# [UNK] is the tokenizer's unknown token, and [PAD] is the padding token
# that corresponds to the CTC blank token
vocab_dict['[UNK]'] = len(vocab_dict)
vocab_dict['[PAD]'] = len(vocab_dict)
len(vocab_dict)

In [44]:
# Serialise the vocabulary to lg_cv_mms.json so the tokenizer can be built from the file
import json
with open('lg_cv_mms.json', 'w') as vocab_file:
    vocab_file.write(json.dumps(vocab_dict))

In [45]:
# Load the vocabulary into an instance of the Wav2Vec2CTCTokenizer.
# The special-token strings must match the entries written to the vocabulary file above.
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("./lg_cv_mms.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [None]:
# Upload the tokenizer to the same (private) Hub repository that the processor and
# model use elsewhere in this notebook; replace 'username/repo' with your own repo id
tokenizer.push_to_hub('username/repo', private=True)

In [None]:
from transformers import Wav2Vec2CTCTokenizer

# Reload the tokenizer from the Hub repository (e.g. when resuming in a fresh session)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('username/repo')

#### Create Wav2Vec2FeatureExtractor

In [48]:
from transformers import Wav2Vec2FeatureExtractor

# Feature extractor for mono 16 kHz input, with normalisation and attention masks enabled
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

In [49]:
# Wrap the feature extractor and tokenizer in the Wav2Vec2Processor so that a single
# object handles both the audio inputs and the text labels
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
# Upload the processor to the private Hub repository
processor.push_to_hub('username/repo', private=True)

In [None]:
# Play an audio sample from the dataset
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(lg_cv_train)-1)
# Fetch the row once so the audio is only decoded a single time
sample = lg_cv_train[rand_int]
print(sample["transcription"])
# Play at the clip's own sampling rate: the audio column is only resampled to 16 kHz
# in a later cell, so a hard-coded 16000 could play the clip at the wrong speed
ipd.Audio(data=np.asarray(sample["audio"]["array"]), autoplay=False, rate=sample["audio"]["sampling_rate"])

In [52]:
# Plot a graph of the audio
import matplotlib.pyplot as plt
from itertools import cycle

In [53]:
# Colours of matplotlib's default property cycle, as a list and as an endless iterator
color_pal = plt.rcParams["axes.prop_cycle"].by_key()["color"]
color_cycle = cycle(color_pal)

In [54]:
# Waveform samples of the randomly chosen training clip
y = lg_cv_train[rand_int]["audio"]["array"]

In [None]:
# Line plot of the raw waveform (sample index on the x-axis, amplitude on the y-axis)
pd.Series(y).plot(figsize=(10, 5),
                  lw=1,
                  title='Raw Audio Example',
                 color=color_pal[0])
plt.show()

In [56]:
# Prepare the data for training
def prepare_dataset(batch):
    """Turn one example into model inputs.

    Adds three fields to `batch` and returns it:
      * ``input_values`` – the processed waveform produced by the processor,
      * ``length`` – a copy of ``duration`` (the Trainer's `group_by_length` option
        reads a ``length`` column by default),
      * ``labels`` – the token ids of the normalised transcript.
    """
    audio = batch["audio"]

    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["length"]     = batch["duration"]

    # `text=` sends the transcript to the tokenizer directly; this replaces the
    # deprecated `processor.as_target_processor()` context manager
    batch["labels"] = processor(text=batch["transcription"]).input_ids
    return batch

In [57]:
# Declare the audio column as 16 kHz, the rate the feature extractor is configured for
lg_cv_train, lg_cv_valid = (
    split.cast_column('audio', Audio(16000)) for split in (lg_cv_train, lg_cv_valid)
)

In [None]:
# Apply the data preparation function to the data.
# The raw audio column is dropped afterwards; only the derived model inputs are kept.
lg_cv_train = lg_cv_train.map(prepare_dataset, num_proc=1, remove_columns=['audio'])
lg_cv_valid = lg_cv_valid.map(prepare_dataset, num_proc=1, remove_columns=['audio'])

In [None]:
# Total training audio after preparation: summed `duration` divided by 3600
sum(lg_cv_train['duration'])/3600

### Training

Create a special collate function that pads the input values to the maximum length in the batch because MMS has a very long context length

#### Set-up Trainer

In [60]:
# Apply separate padding to the input and the labels
# Set the padding value to -100 so that the tokens are not taken into account in the loss function

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that dynamically pads the audio inputs and the label ids of a batch.

    Inputs and labels are padded separately because they have different lengths and use
    different padding values; padded label positions are set to -100 so that they are
    not taken into account by the loss.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples (each with ``input_values`` and ``labels``) into one batch of tensors."""
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature['input_values']} for feature in features]
        label_features = [{"input_ids": feature['labels']} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # `labels=` routes the padding to the tokenizer; this replaces the deprecated
        # `as_target_processor()` context manager
        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [61]:
# Collator instance handed to the Trainer; padding=True pads to the longest item per batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [62]:
# Define the wer and cer metrics (word error rate and character error rate)
import evaluate

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

In [63]:
def compute_metrics(pred):
    """Compute WER and CER for one evaluation pass.

    Args:
        pred: Object exposing `predictions` (the model logits) and `label_ids`
            (label ids in which padded positions are -100).

    Returns:
        dict with the keys ``"wer"`` and ``"cer"``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Swap the -100 loss mask back to the pad token id so the labels can be decoded.
    # np.where builds a new array, so the caller's label_ids are left untouched.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}

In [None]:
# Load the MMS-1B checkpoint with a CTC head.
# The dropout/masking values override the checkpoint's configuration, and the pad token id
# and vocabulary size are taken from the tokenizer built above.
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/mms-1b-all",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.1,
    mask_time_prob=0.05,
    layerdrop=0.1,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    # presumably needed because the new vocabulary size differs from the checkpoint's
    # output layer — confirm against the load-time warnings
    ignore_mismatched_sizes=True
)

In [70]:
# Enable the config's zero-infinity option for the CTC loss
# (presumably to guard against infinite losses on clips too short for their label — confirm)
model.config.ctc_zero_infinity = True

In [71]:
# Freeze the feature encoder as it has been sufficiently trained.
# `freeze_feature_encoder` is the current name of the deprecated `freeze_feature_extractor`.
model.freeze_feature_encoder()

In [72]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [73]:
# Move the model's weights to the selected device
model = model.to(device)

### Baseline Test

In [60]:
# Run the not-yet-fine-tuned model on the first validation example
input_dict = lg_cv_valid[0]

# Inference only: switch off dropout/layerdrop and skip gradient tracking.
# The input is placed on whatever device the model lives on, so the CPU fallback works too.
model.eval()
with torch.no_grad():
    logits = model(torch.tensor(input_dict["input_values"]).to(model.device).unsqueeze(0)).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

In [None]:
# Decode the predicted ids into text
processor.decode(pred_ids)

In [None]:
# Decode the reference label ids of the same example for comparison
processor.decode(input_dict["labels"]).lower()

In [None]:
# Evaluation is carried out with a batch size of 1
def map_to_result(batch):
    """Transcribe one prepared example and pair the prediction with its reference.

    Adds ``pred_str`` (the decoded argmax prediction) and ``text`` (the decoded
    labels, without grouping repeated tokens) to `batch` and returns it.
    """
    with torch.no_grad():
        # Place the input on the model's own device instead of assuming CUDA
        input_values = torch.tensor(batch["input_values"], device=model.device).unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

# Disable dropout/layerdrop once before scoring so the baseline is deterministic
model.eval()
results = lg_cv_valid.map(map_to_result, remove_columns=lg_cv_valid.column_names)

In [None]:
# These scores come from the validation split, before fine-tuning
print("Baseline validation WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))

In [None]:
# These scores come from the validation split, before fine-tuning
print("Baseline validation CER: {:.3f}".format(cer_metric.compute(predictions=results["pred_str"], references=results["text"])))

### Wandb for Logging and Monitoring

In [None]:
import wandb

# Authenticate with Weights & Biases; the Trainer below reports to it
wandb.login()

In [None]:
# Wandb arguments: environment variables set for this session before training starts.
# (Presumably read by the Trainer's W&B integration — confirm the accepted values in its docs.)
# Keep each value free of trailing text: everything after '=' becomes part of the value.
%env WANDB_LOG_MODEL=end
%env WANDB_PROJECT=ASR Africa
%env WANDB_WATCH=all
%env WANDB_SILENT=true

In [76]:
# Define the training arguments
# use the group_by_length argument to make training more efficient

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./repo",
    group_by_length=True,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2, # effective batch size per device: 32 * 2 = 64
    per_device_eval_batch_size=16,
    eval_strategy="steps", # current name of the deprecated `evaluation_strategy`
    save_strategy ="steps",
    num_train_epochs=50,
    bf16=True, # mixed precision training
    # save and eval intervals are kept equal so every checkpoint has an eval score,
    # which load_best_model_at_end relies on
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=3e-4,
    warmup_ratio=0.1,
    save_total_limit=2,
    push_to_hub=True,
    gradient_checkpointing=True,
    report_to="wandb",
    run_name="repo",
    load_best_model_at_end=True,
    metric_for_best_model = "wer",
    greater_is_better=False, # lower WER is better
    hub_private_repo = True,
    torch_compile = True,
    dataloader_num_workers=8,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2,
    hub_model_id='username/repo',
    )

In [78]:
# Pass the model, the training arguments and the data collator to the Trainer

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=lg_cv_train,
    eval_dataset=lg_cv_valid,
    # `processing_class` is the current name of the deprecated `tokenizer` argument;
    # the feature extractor passed here is saved alongside each checkpoint
    processing_class=processor.feature_extractor,
)

In [None]:
# Train from scratch (no checkpoint resume), close the W&B run, then upload the result
trainer.train(resume_from_checkpoint=False)
wandb.finish()
# The target repository comes from `hub_model_id` in the training arguments; the first
# positional argument of push_to_hub is the commit message, not a repo id
trainer.push_to_hub()

### Test Model

In [None]:
from transformers import Wav2Vec2ForCTC

# Reload the fine-tuned model from the Hub repository for testing, with the pad token id
# and vocabulary size taken from the tokenizer
model = Wav2Vec2ForCTC.from_pretrained(
    'username/repo',
    ctc_loss_reduction="mean",
    pad_token_id=tokenizer.pad_token_id,
    vocab_size=len(tokenizer),
)

In [87]:
import gc
import torch

# Run a garbage-collection pass and release PyTorch's cached GPU memory before evaluation
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Move the reloaded model to that device
model = model.to(device)

In [89]:
# The test split was loaded at the top of the notebook but never pre-processed, so it has
# no "input_values"/"labels" columns yet. Prepare it with the same steps used for the
# train and validation splits (guarded so that re-running this cell is harmless).
if "input_values" not in lg_cv_test.column_names:
    lg_cv_test = lg_cv_test.select_columns(['audio', 'Transcriptions', 'duration'])
    lg_cv_test = lg_cv_test.map(normalize)
    lg_cv_test = lg_cv_test.cast_column('audio', Audio(16000))
    lg_cv_test = lg_cv_test.map(prepare_dataset, num_proc=1, remove_columns=['audio'])

input_dict = lg_cv_test[0]

# Inference only: disable dropout/layerdrop and gradient tracking, and place the input
# on the model's own device instead of assuming CUDA
model.eval()
with torch.no_grad():
    logits = model(torch.tensor(input_dict["input_values"]).to(model.device).unsqueeze(0)).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

In [None]:
# Decode the predicted ids into text
processor.decode(pred_ids)

In [None]:
# Decode the reference label ids of the same example for comparison
processor.decode(input_dict["labels"]).lower()

In [None]:
# Evaluation is carried out with a batch size of 1
def map_to_result(batch):
    """Transcribe one prepared example and pair the prediction with its reference.

    Adds ``pred_str`` (the decoded argmax prediction) and ``text`` (the decoded
    labels, without grouping repeated tokens) to `batch` and returns it.
    """
    with torch.no_grad():
        # Place the input on the model's own device instead of assuming CUDA
        input_values = torch.tensor(batch["input_values"], device=model.device).unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

# Switch to inference mode once, rather than on every example
model.eval()
results = lg_cv_test.map(map_to_result, remove_columns=lg_cv_test.column_names)

In [None]:
# Corpus-level word error rate over all test predictions
print("Test WER: {:.3f}".format(
    wer_metric.compute(references=results["text"], predictions=results["pred_str"])
))

In [None]:
# Corpus-level character error rate over all test predictions
print("Test CER: {:.3f}".format(
    cer_metric.compute(references=results["text"], predictions=results["pred_str"])
))

In [97]:
def calculate_wer_cer(batch):
    """Score a single example: store its own WER and CER under ``"WER"`` and ``"CER"``.

    Reads ``batch["text"]`` as the reference and ``batch["pred_str"]`` as the prediction,
    then returns the updated `batch`.
    """
    reference, hypothesis = [batch["text"]], [batch["pred_str"]]
    for column, metric in (("WER", wer_metric), ("CER", cer_metric)):
        batch[column] = metric.compute(references=reference, predictions=hypothesis)
    return batch

In [None]:
# Calculate WER and CER for the test set, one score pair per example, using 8 worker processes
results = results.map(calculate_wer_cer, num_proc=8)

In [None]:
# Preview 10 random scored predictions (the function's default count)
show_random_elements(results)

In [100]:
# Convert the scored results to a pandas DataFrame
df = results.to_pandas()

In [102]:
# Write the per-example results to results.csv without the DataFrame index column
df.to_csv('results.csv', index =False)