In [1]:
# Execution-mode switch read by the setup cell below:
# True  -> running from a local checkout of the repository,
# False -> running on a fresh remote machine that must clone and install first.
LOCAL = True

In [2]:
if not LOCAL:
    !git clone https://github.com/ZaraGiraffe/MangoASR.git
    %cd MangoASR
    !pip install --upgrade transformers datasets evaluate huggingface_hub jiwer accelerate
else:
    %load_ext autoreload
    %autoreload 2

In [3]:
import os
from getpass import getpass

from datasets import load_dataset, Audio
import huggingface_hub as hub
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizerFast
from evaluate import load

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR
import numpy as np

from utils.loaders import get_common_voice
from utils.collators import WhisperTrainCollator
from utils.trainers import MangoTrainer, TrainerConfig
from utils.wrappers import WhisperAsrWrapperModel, WhisperAsrWrapperConfig
from utils.metrics import ComputeStringSimilarityMetricsFunction

Get and process the dataset

In [4]:
# Never store credentials in the notebook: a token committed here is readable by
# anyone with access to the file or its history. Any token that was previously
# hardcoded in this cell must be treated as leaked and revoked on the Hub.
# The token is taken from the HF_TOKEN environment variable when set; otherwise
# the user is prompted for it without echoing the input.
write_hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face write token: ")

In [5]:
# Authenticate against the Hugging Face Hub and also register the token with
# the configured git credential helper.
# NOTE(review): `hub.login` appears to return None, so `access_token` probably
# does not hold a token — confirm before relying on this name elsewhere.
access_token = hub.login(
    token=write_hf_token,
    add_to_git_credential=True,
)

Token is valid (permission: write).
Your token has been saved in your configured git credential helpers (manager).
Your token has been saved to C:\Users\znaum\.cache\huggingface\token
Login successful


In [6]:
# Load the Common Voice data for the "uk" language code via the project loader.
common_voice_language = "uk"
common_voice_uk = get_common_voice(common_voice_language)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [7]:
# Re-declare the "audio" column with a 16 kHz sampling rate.
# Re-running this cell is safe: casting to the same feature again is a no-op in effect.
audio_feature = Audio(sampling_rate=16000)
common_voice_uk = common_voice_uk.cast_column("audio", audio_feature)

Get the Whisper model and processor  
We also need to wrap the model for the trainer

In [8]:
# Model, processor and tokenizer are all loaded from the same checkpoint.
whisper_checkpoint = "openai/whisper-base"

model = WhisperForConditionalGeneration.from_pretrained(whisper_checkpoint)
processor = WhisperProcessor.from_pretrained(whisper_checkpoint)

# Swap the processor's tokenizer for the fast implementation and configure its
# prefix tokens for the "uk" language and the "transcribe" task.
fast_tokenizer = WhisperTokenizerFast.from_pretrained(whisper_checkpoint)
fast_tokenizer.set_prefix_tokens(language="uk", task="transcribe")
processor.tokenizer = fast_tokenizer

In [9]:
# Wrap the Whisper model in the project's wrapper so it can be passed to the trainer.
# -100 is the same pad_token_id value that is handed to the training collator.
wrapper_settings = {"pad_token_id": -100}
wrapped_model_config = WhisperAsrWrapperConfig(**wrapper_settings)
wrapped_model = WhisperAsrWrapperModel(
    model,
    config=wrapped_model_config,
)

Initialize the loaders

In [10]:
# Build the batch collator used by both data loaders.
# The device is selected at runtime instead of being hard-coded to "cuda", so
# the notebook can still run (slowly) on a machine without a CUDA GPU; on a
# CUDA machine the value is "cuda" exactly as before.
# NOTE(review): this assumes the trainer/wrapper places the model on the same
# device the collator targets — confirm against MangoTrainer.
collator_device = "cuda" if torch.cuda.is_available() else "cpu"
collator = WhisperTrainCollator(
    processor=processor,
    device=collator_device,
    pad_token_id=-100,
)

In [11]:
# Work on small slices of the corpus: shard 0 out of 200 for the train split
# and shard 0 out of 300 for the test split.
train_subset = common_voice_uk["train"].shard(num_shards=200, index=0)
eval_subset = common_voice_uk["test"].shard(num_shards=300, index=0)

# Both loaders share the same batch size and collator.
# NOTE(review): no `shuffle` argument is passed, so the train loader iterates
# in dataset order — confirm this is intended.
loader_options = {"batch_size": 4, "collate_fn": collator}
train_loader = DataLoader(train_subset, **loader_options)
eval_loader = DataLoader(eval_subset, **loader_options)

Initialize the optimizer and learning-rate scheduler

In [12]:
# Learning-rate bounds: the optimizer starts at the lower bound and the cyclic
# scheduler oscillates between the two.
lower_lr = 0.0001
upper_lr = 0.01

# Note that the optimizer is built from the raw `model` parameters, not from
# `wrapped_model`.
optim = AdamW(model.parameters(), lr=lower_lr)

# NOTE(review): "exp_range" is used without a `gamma` argument; if the default
# gamma is 1.0 the cycle amplitude never decays — confirm this is intended.
# NOTE(review): an upper bound of 0.01 looks high for fine-tuning a pretrained
# model — verify against the training curves.
scheduler = CyclicLR(
    optim,
    base_lr=lower_lr,
    max_lr=upper_lr,
    mode="exp_range",
)

Train the model

In [13]:
# Trainer settings: run name, checkpoint-saving strategy, Hub user and the
# number of gradient-accumulation steps.
trainer_settings = {
    "model_name": "whisper_asr_1.1",
    "save_strategy": "epoch",
    "hf_user": "Zarakun",
    "gradient_accumulation_steps": 16,
}
trainer_config = TrainerConfig(**trainer_settings)

# Assemble the trainer from the wrapped model, both loaders and the optimizer/scheduler pair.
trainer_components = {
    "model": wrapped_model,
    "train_loader": train_loader,
    "eval_loader": eval_loader,
    "config": trainer_config,
    "optimizer": optim,
    "scheduler": scheduler,
}
trainer = MangoTrainer(**trainer_components)

In [14]:
# Load the "wer" and "cer" metrics through `evaluate` and bundle them with the
# processor into the project's metric callable.
wer, cer = load("wer"), load("cer")

# NOTE(review): pad_token_id is 50257 here while the collator and the wrapper
# config use -100 — presumably 50257 is the tokenizer's own padding id used
# when decoding; confirm against ComputeStringSimilarityMetricsFunction.
metric_pad_token_id = 50257
compute_metrics = ComputeStringSimilarityMetricsFunction(
    processor=processor,
    wer=wer,
    cer=cer,
    pad_token_id=metric_pad_token_id,
)

In [15]:
# Start training and evaluate with the metric callable built above.
# NOTE(review): the positional 20 is presumably the number of epochs — confirm
# against MangoTrainer.train.
train_length = 20
trainer.train(train_length, compute_metrics=compute_metrics)

train:   0%|          | 0/80 [00:00<?, ?it/s]

eval:   0%|          | 0/30 [00:00<?, ?it/s]

events.out.tfevents.1707301782.zarawindows.7264.1:   0%|          | 0.00/525 [00:00<?, ?B/s]

events.out.tfevents.1707301782.zarawindows.7264.0:   0%|          | 0.00/444 [00:00<?, ?B/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

model.pt:   0%|          | 0.00/290M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model.pt:   0%|          | 0.00/290M [00:00<?, ?B/s]

events.out.tfevents.1707301782.zarawindows.7264.0:   0%|          | 0.00/816 [00:00<?, ?B/s]

model.pt:   0%|          | 0.00/290M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

events.out.tfevents.1707301782.zarawindows.7264.0:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

events.out.tfevents.1707301782.zarawindows.7264.0:   0%|          | 0.00/1.56k [00:00<?, ?B/s]

model.pt:   0%|          | 0.00/290M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model.pt:   0%|          | 0.00/290M [00:00<?, ?B/s]

events.out.tfevents.1707301782.zarawindows.7264.0:   0%|          | 0.00/1.93k [00:00<?, ?B/s]