In [None]:
# Environment switch read by the setup cell below: when False the repository is
# cloned and dependencies are installed; when True only autoreload is enabled.
LOCAL = True

In [None]:
if not LOCAL:
    !git clone https://github.com/ZaraGiraffe/MangoASR.git
    %cd MangoASR
    !pip install --upgrade transformers datasets evaluate huggingface_hub jiwer accelerate
else:
    %load_ext autoreload
    %autoreload 2

In [1]:
import os

from datasets import load_dataset, Audio
import huggingface_hub as hub
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from evaluate import load

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import numpy as np

from utils.loaders import get_common_voice
from utils.collators import WhisperTrainCollator
from utils.trainers import MangoTrainer, TrainerConfig
from utils.wrappers import WhisperAsrWrapperModel, WhisperAsrWrapperConfig
from utils.metrics import ComputeStringSimilarityMetricsFunction

Get and process the dataset

In [3]:
# Never embed credentials in the notebook: read the Hugging Face write token
# from the HF_TOKEN environment variable instead.
# SECURITY: the token that used to be hard-coded in this cell must be treated
# as leaked and revoked in the Hugging Face account settings.
# When HF_TOKEN is unset this is None; huggingface_hub's login() is documented
# to prompt for a token interactively in that case.
write_hf_token = os.environ.get("HF_TOKEN")

In [4]:
# Authenticate against the Hugging Face Hub and also store the token in the
# configured git credential helper.
# NOTE(review): login() is documented to return None, so `access_token` most
# likely holds None rather than a token — confirm before relying on it.
access_token = hub.login(token=write_hf_token, add_to_git_credential=True)

Token is valid (permission: write).
Your token has been saved in your configured git credential helpers (manager).
Your token has been saved to C:\Users\znaum\.cache\huggingface\token
Login successful


In [5]:
# Load the Common Voice splits through the project helper.
# Presumably "uk" selects the Ukrainian subset — see utils.loaders to confirm.
common_voice_uk = get_common_voice("uk")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [6]:
# Re-declare the "audio" column with a 16 kHz sampling rate.
common_voice_uk = common_voice_uk.cast_column(
    "audio",
    Audio(sampling_rate=16_000),
)

Load the Whisper model and processor.  
The model also needs to be wrapped so that the trainer can use it.

In [7]:
# Processor from the "openai/whisper-base" checkpoint, with its tokenizer
# prefix tokens set for Ukrainian transcription.
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
processor.tokenizer.set_prefix_tokens(task="transcribe", language="uk")

# Pretrained model weights from the same checkpoint as the processor.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

In [8]:
# Build the wrapper configuration with its default values.
# NOTE(review): this config is never passed to the wrapper below, and the
# variable name is misspelled ("wraped") — confirm whether
# WhisperAsrWrapperModel is meant to receive it.
wraped_model_config = WhisperAsrWrapperConfig()
# Wrap the Whisper model in the project's wrapper class for use by the trainer.
wrapped_model = WhisperAsrWrapperModel(model)

Initialize the loaders

In [9]:
# Batch collator built on the Whisper processor.
# The device is chosen at run time: the GPU is used when CUDA is available,
# otherwise the CPU, so the notebook also runs on machines without a GPU.
# NOTE(review): confirm that MangoTrainer places the model on the same device.
collator = WhisperTrainCollator(
    processor=processor,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [10]:
# Only the first shard of each split is used (1/200 of train, 1/300 of test),
# which keeps this run small. Batches hold 4 examples, assembled by `collator`.
train_loader = DataLoader(
    dataset=common_voice_uk["train"].shard(num_shards=200, index=0),
    batch_size=4,
    collate_fn=collator,
)
eval_loader = DataLoader(
    dataset=common_voice_uk["test"].shard(num_shards=300, index=0),
    batch_size=4,
    collate_fn=collator,
)

Initialize the optimizer and learning-rate scheduler

In [11]:
# AdamW over the parameters of the underlying Whisper model.
# NOTE(review): parameters come from `model`, not `wrapped_model` — confirm the
# wrapper adds no trainable parameters of its own.
optim = AdamW(params=model.parameters(), lr=1e-4)

# Multiply the base learning rate by 0.95 for every scheduler step taken so far.
# Whether a "step" is an epoch or a batch depends on how MangoTrainer calls
# scheduler.step() — verify in utils.trainers.
scheduler = LambdaLR(optimizer=optim, lr_lambda=lambda step_idx: 0.95 ** step_idx)

Train the model

In [12]:
# Trainer settings: the run name and the checkpoint save strategy.
trainer_config = TrainerConfig(model_name="whisper_asr_1.1", save_strategy="epoch")

# Assemble the trainer from the wrapped model, both data loaders, and the
# optimizer/scheduler pair created in the previous cells.
trainer = MangoTrainer(
    config=trainer_config,
    model=wrapped_model,
    optimizer=optim,
    scheduler=scheduler,
    train_loader=train_loader,
    eval_loader=eval_loader,
)

In [13]:
# Load the "wer" and "cer" metrics through `evaluate.load`.
wer, cer = load("wer"), load("cer")

# Callable handed to the trainer; it is given the processor and both metrics.
compute_metrics = ComputeStringSimilarityMetricsFunction(
    wer=wer,
    cer=cer,
    processor=processor,
)

In [14]:
# Run training with the metric callable defined above.
# NOTE(review): the positional 1 is presumably the number of epochs — confirm
# against MangoTrainer.train in utils.trainers.
trainer.train(1, compute_metrics=compute_metrics)

train:   0%|          | 0/16 [00:00<?, ?it/s]

eval:   0%|          | 0/6 [00:00<?, ?it/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

events.out.tfevents.1707287482.zarawindows.21460.0:   0%|          | 0.00/444 [00:00<?, ?B/s]

events.out.tfevents.1707287482.zarawindows.21460.1:   0%|          | 0.00/476 [00:00<?, ?B/s]

model.pt:   0%|          | 0.00/290M [00:00<?, ?B/s]

Save the model