In [None]:
!pip install soundfile
!pip install transformers[torch]
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
import transformers
import pandas as pd
import json
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from zipfile import ZipFile

In [None]:
from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments

In [None]:
# Load the processor (feature extractor + tokenizer) and the model weights from the
# SAME checkpoint. Previously the processor came from "whisper-medium.en" while the
# model (and the dataset's own processor) came from "whisper-small.en".
MODEL_CHECKPOINT = "openai/whisper-small.en"
processor = WhisperProcessor.from_pretrained(MODEL_CHECKPOINT)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)

In [None]:
class ASRDataset(Dataset):
    """Map-style dataset that turns a JSONL manifest into Whisper training examples.

    Every non-blank manifest line is a JSON object that provides at least an
    ``"audio"`` entry (path of the audio file, relative to ``data_path``) and a
    ``"transcript"`` entry (the reference text).

    Args:
        jsonl_path: Path of the JSONL manifest.
        data_path: Directory the manifest's audio paths are relative to.
        processor: Optional, already-loaded Whisper processor to reuse. When it is
            omitted the ``openai/whisper-small.en`` processor is loaded, which is
            what this class always did before the argument existed.
    """

    def __init__(self, jsonl_path, data_path, processor=None):
        self.root_dir = data_path
        if processor is None:
            processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")
        self.processor = processor
        # Read the manifest as UTF-8 and skip blank lines (e.g. a trailing newline at
        # the end of the file), which would otherwise make json.loads raise.
        with open(jsonl_path, "r", encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        self.df = pd.DataFrame(records)

    def __len__(self):
        """Return the number of examples listed in the manifest."""
        return len(self.df)

    def __getitem__(self, i):
        """Load example ``i``.

        Returns:
            A dict with ``"input_features"`` (the processor's features for the audio
            clip) and ``"labels"`` (the token ids of the transcript, as a list).
        """
        audio_path = os.path.join(self.root_dir, self.df.at[i, "audio"])
        waveform, sampling_rate = torchaudio.load(audio_path)

        # torchaudio returns (channels, frames). Average the channels so that mono
        # and multi-channel files both become a 1-D signal.
        waveform = waveform.mean(dim=0)

        # The feature extractor only accepts audio at its own sampling rate, so
        # resample files that were recorded at a different rate.
        target_rate = self.processor.feature_extractor.sampling_rate
        if sampling_rate != target_rate:
            waveform = torchaudio.functional.resample(waveform, sampling_rate, target_rate)

        input_features = self.processor(
            waveform.numpy(),
            sampling_rate=target_rate,
            return_tensors="pt",
        ).input_features[0]

        transcript = self.df.at[i, "transcript"]
        tokens = self.processor.tokenizer(transcript).input_ids
        return {"input_features": input_features, "labels": tokens}
    
# Build the training dataset from the manifest and the directory holding its audio.
dataset = ASRDataset(
    jsonl_path="advanced/asr/asr.jsonl",
    data_path="advanced/asr",
)
        

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech-to-text examples into one padded training batch.

    Audio features and label token ids are padded independently (by the
    processor's feature extractor and tokenizer respectively) because they have
    different lengths and need different padding.
    """

    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`.
    processor: Any
    # Token id that, when it leads every label row, is stripped from the labels.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and return the batch with a ``"labels"`` entry added.

        Args:
            features: Examples, each holding ``"input_features"`` and ``"labels"``.

        Returns:
            The padded audio batch with ``"labels"`` set to the padded label ids,
            where padding positions are replaced by -100.
        """
        # Audio side: wrap each example and let the feature extractor build tensors.
        audio_examples = []
        for example in features:
            audio_examples.append({"input_features": example["input_features"]})
        batch = self.processor.feature_extractor.pad(audio_examples, return_tensors="pt")

        # Label side: pad the token id sequences to the longest one in the batch.
        label_examples = []
        for example in features:
            label_examples.append({"input_ids": example["labels"]})
        padded_labels = self.processor.tokenizer.pad(label_examples, return_tensors="pt")

        # Positions the attention mask marks as padding become -100 so the loss
        # ignores them.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # When every row already begins with the decoder start token, drop that
        # column here, since the token is appended again later anyway.
        every_row_has_start = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if every_row_has_start:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


In [None]:
# Sanity check: load the first example and display the shape of its feature tensor.
dataset[0]["input_features"].shape

In [None]:
# Build the batch collator. The decoder start token id is read from the model config;
# the collator uses it to detect (and strip) a leading start token on the labels.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [None]:
# Hold out a small, reproducible validation split. Step-based evaluation and
# `load_best_model_at_end` both need an `eval_dataset`, and the trainer was
# previously given none.
EVAL_FRACTION = 0.1
assert len(dataset) >= 2, "need at least two examples to build a train/eval split"
num_eval_examples = max(1, int(len(dataset) * EVAL_FRACTION))
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset,
    [len(dataset) - num_eval_examples, num_eval_examples],
    generator=torch.Generator().manual_seed(42),  # fixed seed => same split on every run
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-finetune",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=1,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    load_best_model_at_end=True,
    # No `compute_metrics` is passed to the trainer, so a "wer" metric is never
    # produced. Select the best checkpoint on the evaluation loss instead; switch
    # back to "wer" once a WER `compute_metrics` function is supplied.
    metric_for_best_model="loss",
    greater_is_better=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    # compute_metrics=compute_metrics,  # not defined in this notebook yet
)

In [None]:
# Run fine-tuning with the trainer configured above.
trainer.train()