<a href="https://colab.research.google.com/github/1pawn0/persian-speech-to-text-via-whisper/blob/main/whisper_finetune_persian.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
torch_ver = torch.__version__.split('+')[0].split('.')
if torch_ver[1] == '8':
    print('Installing torchcodec 0.7')
    %pip install torchcodec==0.7
    import torchcodec
elif torch_ver[1] == '9':
    print('Installing torchcodec 0.8')
    %pip install torchcodec==0.8
    import torchcodec

In [None]:
import os, gc
import httpx
import torch, torchcodec
import polars as pl
from torch.utils.data import Dataset, TensorDataset, DataLoader
from pathlib import Path
from tqdm.notebook import tqdm
from google.colab import userdata
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
)
# Checkpoint identifier passed to the `from_pretrained` calls that load the
# model, feature extractor and tokenizer.
model_name = "aictsharif/whisper-base-fa"

Model Docs: https://huggingface.co/docs/transformers/main/en/model_doc/whisper

Fine-tuning guide (Hugging Face blog): https://huggingface.co/blog/fine-tune-whisper#fine-tuning-whisper-in-a-google-colab

In [None]:
# Load the model, the audio feature extractor and the tokenizer from the same checkpoint.
model = WhisperForConditionalGeneration.from_pretrained(model_name)
fe = WhisperFeatureExtractor.from_pretrained(model_name)
# The tokenizer is configured for Persian transcription.
tok = WhisperTokenizer.from_pretrained(
    model_name,
    language="persian",
    task="transcribe",
)


In [None]:
# @title download and extract the dataset

base_url = "https://datacollective.mozillafoundation.org/api"
api_key = userdata.get("MOZILLA_API_KEY")
client_id = userdata.get("MOZILLA_CLIENT_ID")
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
res: dict = httpx.post(f"{base_url}/datasets/cmflnuzw5j6l58skwzpc4ze0q/download", headers=headers).json()
filename = res["filename"]
download_url = res["downloadUrl"]
filesize = res["sizeBytes"]
content_type = res["contentType"]
expires_at = res["expiresAt"]
one_MB = int(2**20)
!wget --header="Authorization: Bearer {api_key}" -O "{filename}" "{download_url}"
print("Be Patient!")
!tar -xzf "{filename}"
print("The dataset has extracted.")
ds_path = Path('./cv-corpus-23.0-2025-09-05/fa')
for fpath in ds_path.iterdir():
    print(fpath)

In [None]:
# @title Run it if you already have the tarfile in your google drive
# Alternative to the download cell: unpack an archive already stored in Drive into /content/.
# NOTE(review): assumes Google Drive is mounted at /content/drive; no mount call
# is visible in this notebook, so confirm before running.
!tar -xzf "/content/drive/MyDrive/tmp/mcv-scripted-fa-v23.0.tar.gz" -C "/content/"

In [None]:
# @title read tsv files using polars

def read_cv_tsv(name: str) -> pl.DataFrame:
    """Read one tab-separated metadata file of the extracted corpus.

    Args:
        name: File stem inside ``cv-corpus-23.0-2025-09-05/fa`` (without ``.tsv``).

    Returns:
        The file's contents as a polars DataFrame, parsed via pyarrow.
    """
    return pl.read_csv(
        source=f"cv-corpus-23.0-2025-09-05/fa/{name}.tsv",
        separator="\t",
        use_pyarrow=True,
    )


# Only `train` gets a row-index column.
train = read_cv_tsv("train").with_row_index()
validated = read_cv_tsv("validated")
# This previously re-read validated.tsv, which made it a duplicate of `validated`.
# It now reads the sentence-level file the variable name refers to.
validated_sentences = read_cv_tsv("validated_sentences")
other = read_cv_tsv("other")
dev = read_cv_tsv("dev")
clip_durations = read_cv_tsv("clip_durations")


In [None]:
audio_files_path = Path("./cv-corpus-23.0-2025-09-05/fa/clips")
# Accumulators are re-created on every run so re-executing the cell does not append duplicates.
waveforms = []
batch_features = []
sentences = []  # holds tokenised transcripts (token-id tensors), not raw text
sample_rate_target = 16000
max_duration = 30  # seconds of audio kept per clip
# Number of training rows to preprocess. The earlier `if i > 100: break` guard
# let indices 0..100 through, i.e. 101 rows; that count is preserved here.
max_samples = 101

# Slice up front so the progress bar total matches the rows actually processed
# (it previously showed len(train) and then stopped early).
train_subset = train.head(max_samples)

for clip_name, sentence in tqdm(
    train_subset.select("path", "sentence").iter_rows(),
    total=len(train_subset),
):
    fpath = audio_files_path / clip_name

    # Tokenise the transcript.
    input_ids = tok(
        sentence,
        return_tensors="pt",
        return_attention_mask=False,
    ).input_ids

    # Decode the whole clip as mono audio at the target sample rate.
    waveform, pts, duration, sample_rate = torchcodec.decoders.AudioDecoder(
        source=fpath,
        sample_rate=sample_rate_target,
        num_channels=1,
    ).get_all_samples()

    # Drop the channel axis and keep at most `max_duration` seconds; shorter clips are left as they are.
    waveform_30sec = waveform.squeeze(0)[: sample_rate_target * max_duration]

    features: BatchFeature = fe(
        raw_speech=waveform_30sec,
        return_tensors="pt",
        sampling_rate=sample_rate_target,
        do_normalize=True,
        return_attention_mask=False,
    )
    batch_features.append(features.input_features)
    waveforms.append(waveform_30sec)
    sentences.append(input_ids)


In [None]:
from torch.nn.utils.rnn import pad_sequence

# Join the per-clip feature tensors along their first axis.
all_features = torch.cat(batch_features, dim=0)

# Remove the leading axis from each tokenised transcript, then pad them all to
# a common length using the tokenizer's pad id.
# NOTE(review): if these ids are later used as training labels, the padded
# positions may need masking (e.g. -100). Confirm against the training loop.
label_rows = [ids.squeeze(0) for ids in sentences]
all_input_ids = pad_sequence(
    label_rows,
    batch_first=True,
    padding_value=tok.pad_token_id,
)

# Pair each feature tensor with its padded token ids.
dataset = TensorDataset(all_features, all_input_ids)


In [None]:
# Run a garbage-collection pass (presumably to release memory held by the
# preprocessing intermediates; confirm intent).
gc.collect()