In [1]:
!pip install jiwer openai-whisper torchcodec
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from transformers import AutoFeatureExtractor, WhisperModel
from transformers import LogitsProcessorList, EpsilonLogitsWarper

from transformers import LlamaTokenizer
from datasets import load_dataset
import torch, torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from jiwer import wer
import pickle
from datasets import Dataset, Audio, Value
import os, random
from typing import Optional
from whisper.normalizers import EnglishTextNormalizer
import math
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
from pathlib import Path
import whisper
import copy, heapq
import pandas as pd
import random
from tqdm import tqdm

Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting openai-whisper
  Downloading openai_whisper-20250625.tar.gz (803 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/803.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m803.2/803.2 kB[0m [31m51.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting torchcodec
  Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
from google.colab import drive
from google.colab import userdata
userdata.get('HF_TOKEN')
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
feature_extractor = WhisperFeatureExtractor.from_pretrained('openai/whisper-tiny', language='en', task='transcribe')
tokenizer = WhisperTokenizer.from_pretrained('openai/whisper-tiny', language='en', task='transcribe')
processor = WhisperProcessor.from_pretrained('openai/whisper-tiny', language='en', task='transcribe')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
normalizer = EnglishTextNormalizer()


preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

In [4]:
def create_val_dataset(dataset):
  random.seed(42)

  val_ratio = 0.1
  n = len(dataset)
  n_val = int(n * val_ratio)

  indices = list(range(n))
  random.shuffle(indices)

  val_idx = set(indices[:n_val])
  val_list = [dataset[i] for i in range(n) if i in val_idx]
  train_list2 = [dataset[i] for i in range(n) if i not in val_idx]

  print(len(train_list2), len(val_list))
  return train_list2, val_list


In [15]:
def train(model, train_dataset,val_dataset, device, num_epochs=4, batch_size=1, accum_steps=16):
    # --- CONFIGURATION ---
    # Gradient Checkpointing is REMOVED to avoid graph errors.
    # We still disable caching because we don't need past key/values during training.
    model.config.use_cache = False
    model.config.output_attentions = False

    # ---------------------

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    print(f" Phase 2: Training on {len(train_dataset)} samples")
    print(f" Config: BS={batch_size}, Accum={accum_steps}, Checkpointing=OFF")

    for epoch in range(num_epochs):
        model.train()
        random.shuffle(train_dataset)
        epoch_loss = 0
        optimizer.zero_grad()
        step_count = 0

        # Iterate through data
        for i in tqdm(range(0, len(train_dataset), batch_size)):
            batch = train_dataset[i:i + batch_size]

            # 1. Prepare Inputs
            # Stack features (BS, 80, 3000)
            mels = torch.cat([item['mel'].unsqueeze(0) for item in batch]).to(device)

            # Pad labels
            labels = pad_sequence([item['labels'] for item in batch],
                                  batch_first=True, padding_value=-100).to(device)

            # 2. Forward Pass (Mixed Precision)
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(input_features=mels, labels=labels)
                loss = outputs.loss
                loss = loss / accum_steps
                loss.backward()

            epoch_loss += loss.item() / accum_steps
            step_count += 1
            if step_count % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            # Optional: Clear cache occasionally if VRAM is tight
            # if i % 50 == 0: torch.cuda.empty_cache()

        avg_loss = epoch_loss / len(train_dataset)
        wer = evaluate(model, val_dataset, processor, feature_extractor, device)
        print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {avg_loss:.4f}")
        print(f"Validation WER: {wer * 100:.2f}%")

    return model
def calculate_wer(reference, hypothesis):
    return wer(reference, hypothesis)

def evaluate(model, dataset, processor, feature_extractor, device):
    model.eval()
    preds, targets = [], []
    print(" Phase 3: Evaluating ")

    for item in tqdm(dataset):
        if 'mel' in item:
            # Use pre-extracted features for pseudo-labeled data
            mel = item['mel'].unsqueeze(0).to(device) # Add batch dimension
            text_label = item['text']
        elif 'audio' in item:
            # Extract features from raw audio for real-labeled data
            audio_array = item['audio']['array']
            sr = item['audio']['sampling_rate']
            inputs = feature_extractor(audio_array, sampling_rate=sr, return_tensors="pt")
            mel = inputs.input_features.to(device)
            text_label = item['text']
        else:
            raise ValueError("Dataset item must contain either 'audio' or 'input_features'.")

        with torch.no_grad():
            gen_ids = model.generate(mel, max_new_tokens=225)
            pred_text = processor.batch_decode(gen_ids, skip_special_tokens=True)[0]

            preds.append(pred_text.lower())
            targets.append(text_label.lower())

    return calculate_wer(targets, preds)


### Zero-shot Model

In [None]:
test_files = [f for f in os.listdir(f'/content/drive/MyDrive/data/') if f.startswith("test_") and f.endswith(".pt")]
model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny')
model.to(device)
for t_file in test_files:
    test_path = os.path.join(f'/content/drive/MyDrive/data/', t_file)
    test_dataset = torch.load(test_path, weights_only=False)

    accent_name = t_file.replace("test_", "").replace(".pt", "")
    print(f"Evaluating {accent_name}...")

    score = evaluate(model, test_dataset, processor, feature_extractor, device)
    print(f"WER for {accent_name}: {score * 100:.2f}%")


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

Evaluating in...
 Phase 3: Evaluating 


  0%|          | 0/100 [00:00<?, ?it/s]Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 100/100 [00:13<00:00,  7.48it/s]


WER for in: 41.06%
Evaluating au...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:09<00:00, 10.48it/s]


WER for au: 20.58%
Evaluating sa...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:10<00:00,  9.69it/s]


WER for sa: 34.98%
Evaluating sg...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:10<00:00,  9.34it/s]

WER for sg: 40.48%





### Real-Labels dataset

In [8]:
def train_real_labels(model,train_dataset,val_dataset,tokenizer,feature_extractor,device,num_epochs=4,batch_size=1,learning_rate=1e-5,gradient_accumulation_steps=16):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    best_wer = float("inf")
    best_state_dict = copy.deepcopy(model.state_dict())
    for epoch in range(num_epochs):
        model.train()
        random.shuffle(train_dataset)

        optimizer.zero_grad()
        for step, i in enumerate(range(0, len(train_dataset), batch_size)):
            batch = train_dataset[i:i + batch_size]

            input_features_batch = []
            for item in batch:
                audio_array = item["audio"]["array"]
                sampling_rate = item["audio"]["sampling_rate"]
                input_features_batch.append(feature_extractor(audio_array, sampling_rate=sampling_rate).input_features[0])

            padded_input_features_dicts = [{"input_features": feat_tensor} for feat_tensor in input_features_batch]
            feats_padded = feature_extractor.pad(padded_input_features_dicts, return_tensors="pt")["input_features"]

            labels_batch = []
            for item in batch:
                tokenized_labels = tokenizer(item["text"], add_special_tokens=True).input_ids
                labels_batch.append(torch.tensor(tokenized_labels, dtype=torch.long))

            labels_padded = pad_sequence(
                labels_batch,
                batch_first=True,
                padding_value=tokenizer.pad_token_id
            )
            labels_padded[labels_padded == tokenizer.pad_token_id] = -100

            feats_padded = feats_padded.to(device)
            labels_padded = labels_padded.to(device)

            outputs = model(input_features=feats_padded, labels=labels_padded)
            loss = outputs.loss
            loss = loss / gradient_accumulation_steps
            loss.backward()

            if step % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
        val_wer = evaluate_real(model, val_dataset, tokenizer, device)
        print(f"Epoch {epoch + 1}/{num_epochs}, Validation WER: {val_wer:.4f}")

        if val_wer < best_wer:
            best_wer = val_wer
            best_state_dict = copy.deepcopy(model.state_dict())
            print(f"  -> New best model (WER={best_wer:.4f}), saving in memory")

    model.load_state_dict(best_state_dict)
    print(f"Training done. Best WER: {best_wer:.4f}")
    return model

def evaluate_real(model,dataset,tokenizer,device,batch_size=16):
    model.eval()
    wer_scores = []

    for i in range(0, len(dataset), batch_size):
        batch = dataset[i:i + batch_size]

        input_features_batch = []
        for sample in batch:
            audio_array = sample["audio"]["array"]
            sampling_rate = sample["audio"]["sampling_rate"]
            input_features_batch.append(feature_extractor(audio_array, sampling_rate=sampling_rate).input_features[0])

        padded_input_features_dicts = [{"input_features": feat_tensor} for feat_tensor in input_features_batch]
        feats_padded = feature_extractor.pad(padded_input_features_dicts, return_tensors="pt")["input_features"].to(device)

        with torch.no_grad():
            generated_ids = model.generate(feats_padded)

        transcriptions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        for transcription, sample in zip(transcriptions, batch):
            transcription = transcription.lower().strip()
            reference = sample["text"].lower().strip()
            wer_score = calculate_wer(reference, transcription)
            wer_scores.append(wer_score)

    return np.mean(wer_scores)

In [None]:
training = torch.load(f'/content/drive/MyDrive/data/real_labels/train.pt', weights_only=False)
train_list = training
validation = torch.load(f'/content/drive/MyDrive/data/real_labels/validation.pt', weights_only=False)
val_list = validation
model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny')
model.to(device)
trained_model = train_real_labels(model, train_list,val_list,tokenizer,feature_extractor, device)
torch.save(trained_model.state_dict(), f'/content/drive/MyDrive/data/real_labels/real_trained_model.pth')

In [None]:
test_files = [f for f in os.listdir(f'/content/drive/MyDrive/data/') if f.startswith("test_") and f.endswith(".pt")]
for t_file in test_files:
    test_path = os.path.join(f'/content/drive/MyDrive/data/', t_file)
    test_dataset = torch.load(test_path, weights_only=False)

    accent_name = t_file.replace("test_", "").replace(".pt", "")
    print(f"Evaluating {accent_name}...")

    score = evaluate(trained_model, test_dataset,tokenizer,feature_extractor, device)
    print(f"WER for {accent_name}: {score * 100:.2f}%")


Evaluating in...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:11<00:00,  8.38it/s]


WER for in: 30.15%
Evaluating au...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:10<00:00,  9.20it/s]


WER for au: 16.37%
Evaluating sa...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:12<00:00,  8.28it/s]


WER for sa: 26.03%
Evaluating sg...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:12<00:00,  7.87it/s]

WER for sg: 36.24%





### Pseudo-labels dataset

In [6]:
training, val_list = create_val_dataset(torch.load(f'/content/drive/MyDrive/data/pseudo/train_pseudo_only.pt', weights_only=False))
torch.save(training, f'/content/drive/MyDrive/data/pseudo/train_list.pt')
torch.save(val_list, f'/content/drive/MyDrive/data/pseudo/val_list.pt')

7200 800


In [16]:
training = torch.load(f'/content/drive/MyDrive/data/pseudo/train_list.pt')
val_list = torch.load(f'/content/drive/MyDrive/data/pseudo/val_list.pt')
model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny')
model.to(device)
trained_model = train(model, training,val_list, device)
torch.save(trained_model.state_dict(), f'/content/drive/MyDrive/data/pseudo/pseudo_trained_model.pth')


 Phase 2: Training on 7200 samples
 Config: BS=1, Accum=16, Checkpointing=OFF


100%|██████████| 7200/7200 [05:46<00:00, 20.80it/s]


 Phase 3: Evaluating 


  0%|          | 0/800 [00:00<?, ?it/s]Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 800/800 [01:25<00:00,  9.34it/s]


Epoch 1/4 - Avg Loss: 0.0020
Validation WER: 25.61%


100%|██████████| 7200/7200 [05:27<00:00, 22.00it/s]


 Phase 3: Evaluating 


100%|██████████| 800/800 [01:26<00:00,  9.28it/s]


Epoch 2/4 - Avg Loss: 0.0012
Validation WER: 27.02%


100%|██████████| 7200/7200 [05:27<00:00, 21.96it/s]


 Phase 3: Evaluating 


100%|██████████| 800/800 [01:27<00:00,  9.16it/s]


Epoch 3/4 - Avg Loss: 0.0007
Validation WER: 28.61%


100%|██████████| 7200/7200 [05:27<00:00, 22.01it/s]


 Phase 3: Evaluating 


100%|██████████| 800/800 [01:25<00:00,  9.34it/s]


Epoch 4/4 - Avg Loss: 0.0005
Validation WER: 27.10%


In [20]:
test_files = [f for f in os.listdir(f'/content/drive/MyDrive/data/test') if f.startswith("test_") and f.endswith(".pt")]
for t_file in test_files:
    test_path = os.path.join(f'/content/drive/MyDrive/data/test/', t_file)
    test_dataset = torch.load(test_path, weights_only=False)

    accent_name = t_file.replace("test_", "").replace(".pt", "")
    print(f"Evaluating {accent_name}...")

    score = evaluate(trained_model, test_dataset, processor, feature_extractor, device)
    print(f"WER for {accent_name}: {score * 100:.2f}%")


Evaluating in...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:10<00:00,  9.27it/s]


WER for in: 39.82%
Evaluating au...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:09<00:00, 10.32it/s]


WER for au: 23.35%
Evaluating sa...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:10<00:00,  9.23it/s]


WER for sa: 41.52%
Evaluating sg...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:10<00:00,  9.80it/s]

WER for sg: 43.64%



