# Whisper ASR Fine-Tuning Pipeline for Punjabi (Indic-ASR)

In [None]:
# 📦 1. Imports
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_dataset, Audio
import torch
import torchaudio
import evaluate
import numpy as np
import os

In [None]:
# 🧠 2. Environment and GPU Check — report the torch build and which device training will use
_cuda_ok = torch.cuda.is_available()
print("PyTorch version:", torch.__version__)
print("CUDA available:", _cuda_ok)
print("Device:", torch.device("cuda" if _cuda_ok else "cpu"))

In [None]:
# 📁 3. Load the Punjabi ("pa") subset of Indic-ASR
try:
    # Only the first 1% of the training split is pulled, keeping this a quick test run.
    dataset = load_dataset("ai4bharat/indic-asr", "pa", split="train[:1%]")
    # Decode the audio column at 16 kHz.
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
except Exception as load_err:
    print("❌ Failed to load dataset:", str(load_err))
    raise
else:
    print("✅ Dataset loaded successfully")

In [None]:
# 🧩 4. Load Whisper Model and Processor
try:
    # Configure the tokenizer for Punjabi transcription. The training labels then start
    # with the same language/task prefix tokens the decoder is prompted with at generation.
    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-tiny", language="punjabi", task="transcribe"
    )
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    # Generation settings are read from `generation_config`. Setting language/task there
    # replaces the deprecated `config.forced_decoder_ids`, which is cleared so the two
    # mechanisms do not conflict.
    model.generation_config.language = "punjabi"
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []
    print("✅ Whisper model and processor loaded")
except Exception as e:
    print("❌ Failed to load model:", str(e))
    raise

In [None]:
# 🧼 5. Preprocess Dataset
def prepare_dataset(batch):
    """Turn one raw dataset row into Whisper model inputs.

    Computes ``input_features`` (log-Mel features from the feature extractor) from
    ``batch["audio"]`` and tokenizes ``batch["text"]`` into ``labels``.

    Args:
        batch: a single dataset example with ``"audio"`` (decoded audio dict) and
            ``"text"`` columns.

    Returns:
        The same example with ``input_features`` and ``labels`` added. If anything
        fails, the error is printed and both fields are set to empty lists so the
        row can be filtered out downstream.
    """
    try:
        audio = batch["audio"]
        # Pass the rate the audio was actually decoded at. A mismatch with what the
        # feature extractor expects is then reported instead of silently producing
        # wrong features.
        inputs = processor.feature_extractor(
            audio["array"], sampling_rate=audio["sampling_rate"]
        )
        batch["input_features"] = inputs.input_features[0]
        # Tokenize the transcript directly; `as_target_processor()` is deprecated.
        batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    except Exception as e:
        print("Preprocessing error:", str(e))
        batch["input_features"] = []
        batch["labels"] = []
    return batch

In [None]:
try:
    dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names)
    # prepare_dataset stores empty lists for rows it could not process; drop those rows
    # so they never reach the trainer or the evaluation loop.
    num_before_filter = len(dataset)
    dataset = dataset.filter(
        lambda feats, labels: len(feats) > 0 and len(labels) > 0,
        input_columns=["input_features", "labels"],
    )
    num_dropped = num_before_filter - len(dataset)
    if num_dropped:
        print(f"⚠️ Dropped {num_dropped} examples that failed preprocessing")
    print("✅ Preprocessing completed")
except Exception as e:
    print("❌ Preprocessing failed:", str(e))
    raise

In [None]:
# ⚙️ 6. Training Arguments
import dataclasses

# transformers renamed `evaluation_strategy` to `eval_strategy` (the old name was later
# removed), so pass the setting under whichever field this installed version defines.
_training_arg_names = {field.name for field in dataclasses.fields(Seq2SeqTrainingArguments)}
_eval_strategy_kwarg = (
    "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-punjabi-finetuned",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    warmup_steps=100,
    max_steps=200,
    save_steps=100,
    eval_steps=100,
    logging_steps=50,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    # Use generate() during evaluation so compute_metrics receives decodable token ids.
    # Without it, the evaluation predictions are the model's logits.
    predict_with_generate=True,
    push_to_hub=False,
    **{_eval_strategy_kwarg: "steps"},
)

In [None]:
# 📦 7. Data Collator
def data_collator(features):
    """Pad a list of preprocessed examples into one Whisper training batch.

    Args:
        features: examples produced by ``prepare_dataset``, each with
            ``input_features`` and ``labels``.

    Returns:
        The padded feature batch from the feature extractor, with a ``labels`` tensor
        added in which padding positions are ``-100`` (ignored by the loss).

    Raises:
        ValueError: if no example in the batch was preprocessed successfully.
    """
    # Keep a row only when BOTH fields are present. Filtering the two lists independently
    # could pair audio with the wrong transcript. `len()` is used instead of truthiness so
    # array-formatted rows also work.
    valid = [f for f in features if len(f["input_features"]) > 0 and len(f["labels"]) > 0]
    if not valid:
        raise ValueError("data_collator received a batch with no successfully preprocessed examples")

    input_features = [{"input_features": f["input_features"]} for f in valid]
    label_features = [f["labels"] for f in valid]
    batch = processor.feature_extractor.pad(input_features, return_tensors="pt")
    labels_batch = processor.tokenizer.pad({"input_ids": label_features}, return_tensors="pt")
    # Mask label padding with -100 so it does not contribute to the loss.
    labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
    # The model prepends decoder_start_token_id itself when it shifts labels right to
    # build decoder inputs. If the tokenizer already put that token at the start of
    # every sequence, drop it to avoid training on a duplicated start token.
    if (labels[:, 0] == model.config.decoder_start_token_id).all().item():
        labels = labels[:, 1:]
    batch["labels"] = labels
    return batch

In [None]:
# 📏 8. WER (word error rate) metric, used by compute_metrics and the final sample check
metric = evaluate.load(path="wer")

In [None]:
def compute_metrics(pred):
    """Compute the word error rate for one ``Seq2SeqTrainer`` evaluation pass.

    Args:
        pred: the evaluation prediction object from the trainer. ``predictions`` holds
            generated token ids (possibly wrapped in a tuple) and ``label_ids`` holds
            reference ids whose padding is ``-100``.

    Returns:
        ``{"wer": float}`` as computed by the loaded WER metric.
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    # Some trainer paths wrap the generated ids in a tuple; the ids come first.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    # -100 is the ignore-index used by the collator (labels) and by the trainer when
    # padding predictions across batches. It is not a valid token id, so swap it for
    # the pad token, which skip_special_tokens then removes.
    pred_ids = np.where(np.asarray(pred_ids) != -100, pred_ids, pad_id)
    label_ids = np.where(np.asarray(pred.label_ids) != -100, pred.label_ids, pad_id)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    wer_score = metric.compute(predictions=pred_str, references=label_str)
    print(f"WER: {wer_score:.4f}")
    return {"wer": wer_score}

In [None]:
# 🏋️‍♂️ 9. Trainer Initialization
import inspect

try:
    # Hold out part of the data for evaluation. Scoring WER on the same rows the model
    # trains on only measures memorisation.
    splits = dataset.train_test_split(test_size=0.1, seed=42)
    # Newer transformers renamed the trainer's `tokenizer` argument to `processing_class`.
    # Pass the feature extractor under whichever name this version accepts.
    _trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
    _processing_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        **{_processing_kwarg: processor.feature_extractor},
    )
    print("✅ Trainer initialized")
except Exception as e:
    print("❌ Failed to initialize trainer:", str(e))
    raise

In [None]:
# 🚀 10. Fine-Tuning the Model — run training; report and re-raise any failure
try:
    trainer.train()
except Exception as train_err:
    print("❌ Training failed:", str(train_err))
    raise
else:
    print("✅ Training completed")

In [None]:
# 🧪 11. Evaluate Final WER on a Few Samples
# Transcribe up to five preprocessed examples and collect (hypothesis, reference) pairs.
model.eval()  # make sure dropout is off after training
results = []
num_samples = min(5, len(dataset))
for example in dataset.select(range(num_samples)):
    try:
        # `input_features` already holds the feature extractor's output from preprocessing,
        # so it goes straight to the model (adding a batch dimension). Re-running it
        # through the processor would treat the features as raw audio.
        features = torch.tensor(example["input_features"], dtype=torch.float32).unsqueeze(0)
        features = features.to(model.device, dtype=model.dtype)
        with torch.no_grad():
            generated_ids = model.generate(features)
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # The preprocessing map removes the raw "text" column, so fall back to decoding
        # the label ids when it is not available.
        if "text" in example:
            reference = example["text"]
        else:
            reference = processor.tokenizer.decode(example["labels"], skip_special_tokens=True)
        results.append((transcription, reference))
        print("Ref:", reference)
        print("Hyp:", transcription)
        print("---")
    except Exception as e:
        print("❌ Inference failed:", str(e))

In [None]:
# Aggregate WER over the sample transcriptions collected above. `zip(*results)` cannot
# unpack an empty list, so guard against the case where every sample failed.
if results:
    preds, refs = zip(*results)
    final_wer = metric.compute(predictions=list(preds), references=list(refs))
    print("🔍 Final WER on samples:", final_wer)
else:
    print("⚠️ No successful transcriptions; final WER not computed")