<a href="https://colab.research.google.com/github/1pawn0/persian-speech-to-text-via-whisper/blob/main/whisper_finetune_persian.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
torch_ver = torch.__version__.split('+')[0].split('.')
if torch_ver[1] == '8':
    print('Installing torchcodec 0.7')
    %pip install torchcodec==0.7
    import torchcodec
elif torch_ver[1] == '9':
    print('Installing torchcodec 0.8')
    %pip install torchcodec==0.8
    import torchcodec

In [None]:
import os, gc
import httpx
import torch, torchcodec
import polars as pl
from torch.utils.data import Dataset, TensorDataset, DataLoader
from pathlib import Path
from tqdm.notebook import tqdm
from google.colab import userdata
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
)
# Checkpoint identifier handed to the from_pretrained() calls that load the
# model, the feature extractor and the tokenizer.
model_name = "aictsharif/whisper-base-fa"

Model Docs: https://huggingface.co/docs/transformers/main/en/model_doc/whisper

https://huggingface.co/blog/fine-tune-whisper#fine-tuning-whisper-in-a-google-colab

In [None]:
# Model weights, feature extractor and tokenizer are all loaded from the
# same checkpoint identifier.
model = WhisperForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=model_name,
)
fe = WhisperFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=model_name,
)
# The tokenizer is configured for Persian transcription.
tok = WhisperTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name,
    language="persian",
    task="transcribe",
)


In [None]:
# @title download and extract the dataset

base_url = "https://datacollective.mozillafoundation.org/api"
api_key = userdata.get("MOZILLA_API_KEY")
client_id = userdata.get("MOZILLA_CLIENT_ID")
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
res: dict = httpx.post(f"{base_url}/datasets/cmflnuzw5j6l58skwzpc4ze0q/download", headers=headers).json()
filename = res["filename"]
download_url = res["downloadUrl"]
filesize = res["sizeBytes"]
content_type = res["contentType"]
expires_at = res["expiresAt"]
one_MB = int(2**20)
!wget --header="Authorization: Bearer {api_key}" -O "{filename}" "{download_url}"
print("Be Patient!")
!tar -xzf "{filename}"
print("The dataset has extracted.")
ds_path = Path('./cv-corpus-23.0-2025-09-05/fa')
for fpath in ds_path.iterdir():
    print(fpath)

In [None]:
# @title Run it if you already have the tarfile in your google drive
# Unpack the archive stored on Drive into /content/ (alternative to downloading it).
# NOTE(review): assumes Drive is already mounted at /content/drive — no mount call is visible here.
!tar -xzf "/content/drive/MyDrive/tmp/mcv-scripted-fa-v23.0.tar.gz" -C "/content/"

In [None]:
# @title read tsv files using polars

# Directory (relative to the working directory) holding the corpus TSV files.
corpus_dir = "cv-corpus-23.0-2025-09-05/fa"


def _read_tsv(name: str) -> pl.DataFrame:
    """Load ``<corpus_dir>/<name>.tsv`` as a tab-separated polars DataFrame."""
    return pl.read_csv(
        source=f"{corpus_dir}/{name}.tsv",
        separator="\t",
        use_pyarrow=True,
    )


train = _read_tsv("train")
validated = _read_tsv("validated")
# Read validated_sentences.tsv here; loading validated.tsv a second time
# would only duplicate `validated` under another name.
validated_sentences = _read_tsv("validated_sentences")
other = _read_tsv("other")
dev = _read_tsv("dev")
clip_durations = _read_tsv("clip_durations")

# Keep only sentences shorter than 64 characters, then number the rows.
# max_chars = train["sentence"].str.len_chars().quantile(.99) # = 63
train = train.filter(pl.col("sentence").str.len_chars() < 64).with_row_index()

# Attach each clip's duration (matched on file name) and convert ms -> seconds.
train = (
    train.join(clip_durations, left_on="path", right_on="clip", how="left")
    .with_columns(pl.col("duration[ms]").truediv(1000).alias("duration"))
    .drop("duration[ms]")
)


In [None]:
# @title find tokenized input_ids' max_length
# # Tokenize every training sentence to find the length of the longest input_ids sequence
# target_labels = tok.__call__(
#     train["sentence"].to_list(),
#     return_tensors="pt",
#     return_attention_mask=False,
#     padding=True,
#     truncation=True,
# ).input_ids.squeeze(0)

# print(target_labels.shape) # torch.Size([29547, 52]), therefore the max_length is `52`


In [None]:
class AudioDataset(Dataset):
    """Dataset yielding ``(input_features, label_ids)`` for one audio clip.

    Each item decodes the clip named in the ``path`` column of ``df``, runs it
    through the feature extractor, and tokenizes the matching ``sentence``
    column entry into a fixed-length tensor of token ids.

    Args:
        df: Table with at least ``path`` and ``sentence`` columns.
        feature_extractor: Callable turning a waveform into model input features.
        tokenizer: Callable turning a sentence into token ids.
        audio_files_dir: Directory that the ``path`` entries are relative to.
        target_sample_rate: Sample rate (Hz) requested from the audio decoder
            and reported to the feature extractor.
        max_duration: Clips are cut to at most this many seconds.
        max_label_length: Token ids are padded / truncated to this length.
            The default of 52 is the longest tokenized training sentence
            measured beforehand.
    """

    def __init__(
        self,
        df: pl.DataFrame,
        feature_extractor: WhisperFeatureExtractor,
        tokenizer: WhisperTokenizer,
        audio_files_dir: Path,
        target_sample_rate: int = 16000,
        max_duration: int = 30,
        max_label_length: int = 52,
    ):
        self.df = df
        self.fe = feature_extractor
        self.tok = tokenizer
        self.fdir = audio_files_dir
        self.sr = target_sample_rate
        self.duration = max_duration
        self.max_label_length = max_label_length

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(input_features, label_ids)`` for row ``idx``."""
        decoder = torchcodec.decoders.AudioDecoder(
            source=self.fdir / self.df["path"][idx],
            sample_rate=self.sr,
            num_channels=1,
        )
        # squeeze(0) drops the leading axis (the single channel requested
        # above); the slice keeps at most `max_duration` seconds of samples.
        waveform: torch.Tensor = (
            decoder.get_all_samples().data.squeeze(0)[: self.sr * self.duration]
        )

        # The extractor output carries a leading batch axis of size 1, which
        # is removed so that the DataLoader can add its own batch axis.
        features: torch.Tensor = self.fe(
            raw_speech=waveform,
            return_tensors="pt",
            sampling_rate=self.sr,
            do_normalize=True,
            return_attention_mask=False,
        ).input_features.squeeze(0)

        # Fixed-length padding gives every label tensor the same shape, so
        # default batching can stack them.
        # NOTE(review): padded positions hold ordinary pad-token ids; whatever
        # computes the loss presumably has to mask them — confirm downstream.
        target_label: torch.Tensor = self.tok(
            self.df["sentence"][idx],
            return_tensors="pt",
            return_attention_mask=False,
            max_length=self.max_label_length,
            padding="max_length",
            truncation=True,
        ).input_ids.squeeze(0)

        return features, target_label


# Directory containing the audio clips that the "path" column refers to.
audio_files_path = Path("./cv-corpus-23.0-2025-09-05/fa/clips")

# Dataset over the filtered training table, using the loaded feature
# extractor and tokenizer.
ds = AudioDataset(
    df=train,
    feature_extractor=fe,
    tokenizer=tok,
    audio_files_dir=audio_files_path,
)

# Batching configuration: shuffled batches of 32, loaded by two worker
# processes that stay alive between passes, with pinned host memory.
loader_options = dict(
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
)
loader = DataLoader(ds, **loader_options)


In [None]:
# Smoke test: pull a single batch and show the shapes of both tensors.
# Nothing is printed when the loader yields no batches.
first_batch = next(iter(loader), None)
if first_batch is not None:
    Xb, yb = first_batch
    print(f"Xb: {Xb.shape}")
    print(f"yb: {yb.shape}")