# Whisper-Small Baseline Evaluation on FLEURS Hindi
**JoshTalks ASR Research | Q1 — Baseline Evaluation**

Evaluates the pretrained `openai/whisper-small` model on a subset (the first `NUM_EVAL_SAMPLES` utterances) of the FLEURS Hindi test set.

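**Outputs** (written to `OUTPUT_DIR`, `outputs/q1` by default):
- `predictions.csv` — Reference vs. prediction pairs
- `baseline_metrics.json` — Corpus WER/CER plus substitution/insertion/deletion counts and rates
- `baseline_analysis.png` — Per-utterance WER histogram and error-type breakdown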

## 1. Setup & Dependencies

In [None]:
!pip install -q torch transformers datasets evaluate jiwer librosa soundfile tqdm pandas numpy matplotlib

In [None]:
import torch
import json
import os
import numpy as np
import pandas as pd
from datetime import datetime
from tqdm import tqdm

# Report the runtime environment so baseline numbers can be reproduced.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties field is `total_memory` (in bytes); the previous
    # `total_mem` does not exist and raised AttributeError on GPU machines.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

In [None]:
# Configuration ------------------------------------------------------------

# Model / decoding language
MODEL_NAME = "openai/whisper-small"
LANGUAGE = "hi"  # language code used to build Whisper's decoder prompt

# Evaluation data
FLEURS_DATASET = "google/fleurs"
FLEURS_LANG = "hi_in"  # dataset config name passed to load_dataset
NUM_EVAL_SAMPLES = 200  # upper bound on the number of utterances evaluated
BATCH_SIZE = 4

# Run settings
OUTPUT_DIR = "outputs/q1"
SEED = 42

# Make sure the output directory exists, then seed torch and NumPy (in that order).
os.makedirs(OUTPUT_DIR, exist_ok=True)
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

## 2. Load FLEURS Hindi Test Set

In [None]:
from datasets import load_dataset

# Load the full FLEURS Hindi test split (NOT streaming): the code below needs
# len() and .select(), which a streaming IterableDataset does not provide.
fleurs_test = load_dataset(FLEURS_DATASET, FLEURS_LANG, split="test", trust_remote_code=True)

# Evaluate at most NUM_EVAL_SAMPLES utterances, capped by the split size.
num_eval = min(NUM_EVAL_SAMPLES, len(fleurs_test))
print(f"Total test samples: {len(fleurs_test)}")
print(f"Using subset: {num_eval} samples")
print(f"Sample features: {fleurs_test.features}")

# Deterministic subset: the first `num_eval` utterances of the split.
eval_subset = fleurs_test.select(range(num_eval))
print(f"\nSubset size: {len(eval_subset)}")
# Guard against an empty subset (e.g. NUM_EVAL_SAMPLES = 0), where indexing would raise.
if len(eval_subset) > 0:
    print(f"Sample transcription: {eval_subset[0]['transcription'][:100]}")

## 3. Load Whisper Model

In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Choose the compute device up front so the model goes there right after loading.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Processor (feature extraction + decoding) and the model, moved to `device`
# and switched to eval mode.
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME).to(device).eval()

# Decoder prompt ids for the configured language and the "transcribe" task.
forced_decoder_ids = processor.get_decoder_prompt_ids(language=LANGUAGE, task="transcribe")

n_params_millions = sum(param.numel() for param in model.parameters()) / 1e6
print(f"Model loaded on: {device}")
print(f"Model parameters: {n_params_millions:.1f}M")

## 4. Run Inference

In [None]:
import librosa

TARGET_SR = 16000  # sampling rate passed to the Whisper feature extractor

references = []
predictions = []
audio_ids = []

num_eval = len(eval_subset)
print(f"Running inference on {num_eval} samples...")

# Batched inference using BATCH_SIZE from the config cell (previously unused).
for start in tqdm(range(0, num_eval, BATCH_SIZE), desc="Evaluating"):
    stop = min(start + BATCH_SIZE, num_eval)
    batch = eval_subset[start:stop]  # slicing a Dataset yields a dict: column -> list

    # Resample each clip to 16 kHz if needed.
    audios = []
    for audio in batch["audio"]:
        array = audio["array"]
        sr = audio["sampling_rate"]
        if sr != TARGET_SR:
            array = librosa.resample(array, orig_sr=sr, target_sr=TARGET_SR)
        audios.append(array)

    # The feature extractor pads every clip to the same fixed-length input window.
    input_features = processor(
        audios, sampling_rate=TARGET_SR, return_tensors="pt"
    ).input_features.to(device)

    # `language`/`task` replace the deprecated `forced_decoder_ids` argument.
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features,
            language=LANGUAGE,
            task="transcribe",
            max_new_tokens=225,
        )

    predictions.extend(processor.batch_decode(predicted_ids, skip_special_tokens=True))
    references.extend(batch["transcription"])
    # Fall back to the positional index when the dataset has no "id" column.
    audio_ids.extend(batch.get("id", range(start, stop)))

print(f"\nInference complete. {len(predictions)} predictions generated.")

## 5. Compute Metrics

In [None]:
import jiwer

# NOTE(review): no text normalization is applied to either side before scoring.
# jiwer's default transform only collapses/strips whitespace and splits on
# spaces, so punctuation or formatting differences between Whisper's raw
# output and the FLEURS `transcription` field count as errors. Confirm this
# is the intended baseline definition.

# jiwer (3.x) raises ValueError on empty reference strings, so only utterances
# with a non-empty reference are scored (section 7 applies the same filter).
scored = [(ref, hyp) for ref, hyp in zip(references, predictions) if ref.strip()]
if not scored:
    raise ValueError("No utterances with a non-empty reference to score.")
scored_refs = [ref for ref, _ in scored]
scored_hyps = [hyp for _, hyp in scored]

# A single word alignment gives both the corpus WER and the S/I/D/H breakdown.
wer_output = jiwer.process_words(scored_refs, scored_hyps)
corpus_wer = wer_output.wer
corpus_cer = jiwer.cer(scored_refs, scored_hyps)

# Reference length N = S + D + H; the error rates below are normalized by it.
total_ref = wer_output.substitutions + wer_output.deletions + wer_output.hits

metrics = {
    "label": "whisper-small-pretrained",
    "model": MODEL_NAME,
    "eval_dataset": f"{FLEURS_DATASET}/{FLEURS_LANG}",
    "num_samples": len(references),
    "num_scored": len(scored_refs),  # utterances with a non-empty reference
    "corpus_wer": round(corpus_wer, 4),
    "corpus_cer": round(corpus_cer, 4),
    "total_substitutions": wer_output.substitutions,
    "total_insertions": wer_output.insertions,
    "total_deletions": wer_output.deletions,
    "total_hits": wer_output.hits,
    "total_ref_words": total_ref,
    "substitution_rate": round(wer_output.substitutions / max(total_ref, 1), 4),
    "insertion_rate": round(wer_output.insertions / max(total_ref, 1), 4),
    "deletion_rate": round(wer_output.deletions / max(total_ref, 1), 4),
    "timestamp": datetime.now().isoformat(),
    "device": device,
}

print("\n" + "="*50)
print("BASELINE EVALUATION RESULTS")
print("="*50)
print(f"Model: {MODEL_NAME}")
print(f"Samples: {metrics['num_samples']} (scored: {metrics['num_scored']})")
print(f"WER: {metrics['corpus_wer']:.4f}")
print(f"CER: {metrics['corpus_cer']:.4f}")
print(f"Substitution Rate: {metrics['substitution_rate']:.4f}")
print(f"Insertion Rate: {metrics['insertion_rate']:.4f}")
print(f"Deletion Rate: {metrics['deletion_rate']:.4f}")
print("="*50)

## 6. Save Outputs

In [None]:
# Save predictions CSV (DataFrame.to_csv writes UTF-8 by default).
pred_df = pd.DataFrame({
    "id": audio_ids,
    "reference": references,
    "prediction": predictions,
})
pred_df.to_csv(os.path.join(OUTPUT_DIR, "predictions.csv"), index=False)
print(f"Predictions saved to {OUTPUT_DIR}/predictions.csv")

# Save metrics JSON. With ensure_ascii=False any non-ASCII text is written
# verbatim, so the file must be opened as UTF-8 explicitly; the platform default
# encoding (e.g. cp1252 on Windows) may be unable to encode it.
with open(os.path.join(OUTPUT_DIR, "baseline_metrics.json"), "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2, ensure_ascii=False)
print(f"Metrics saved to {OUTPUT_DIR}/baseline_metrics.json")

# Show the first few reference/hypothesis pairs (truncated to 80 chars) as a sanity check.
print("\nSample Predictions:")
print("-"*60)
for ref, hyp in zip(references[:5], predictions[:5]):
    print(f"REF: {ref[:80]}")
    print(f"HYP: {hyp[:80]}")
    print("-"*60)

## 7. Quick Error Analysis

In [None]:
# Per-utterance WER distribution
import matplotlib.pyplot as plt

# Score each utterance on its own; empty references are skipped because WER is
# undefined for them (jiwer rejects empty references).
per_utt_wer = [
    jiwer.wer(ref, hyp)
    for ref, hyp in zip(references, predictions)
    if ref.strip()
]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: histogram of per-utterance WER (can exceed 1.0 when insertions dominate).
axes[0].hist(per_utt_wer, bins=30, color='#4A90D9', edgecolor='white', alpha=0.85)
if per_utt_wer:
    # Only draw the mean line when there is data: np.mean([]) is NaN, and
    # legend() with no labelled artists just emits a warning.
    mean_wer = np.mean(per_utt_wer)
    axes[0].axvline(mean_wer, color='red', linestyle='--', label=f'Mean: {mean_wer:.3f}')
    axes[0].legend()
axes[0].set_xlabel('WER per utterance')
axes[0].set_ylabel('Count')
axes[0].set_title('WER Distribution (Baseline)', fontweight='bold')
axes[0].grid(axis='y', alpha=0.3)

# Right: corpus-level error counts by type.
error_types = ['Substitutions', 'Insertions', 'Deletions']
error_counts = [metrics['total_substitutions'], metrics['total_insertions'], metrics['total_deletions']]
colors = ['#E8636F', '#F5A623', '#4A90D9']
bars = axes[1].bar(error_types, error_counts, color=colors, edgecolor='white')
axes[1].set_ylabel('Count')
axes[1].set_title('Error Type Breakdown', fontweight='bold')
axes[1].grid(axis='y', alpha=0.3)
# bar_label offsets labels by a fixed number of points above each bar, so the
# placement no longer depends on the data scale (the old `v + 5` data-unit
# offset was too small for large counts and too large for small ones).
axes[1].bar_label(bars, labels=[str(v) for v in error_counts], padding=3, fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'baseline_analysis.png'), dpi=150, bbox_inches='tight')
plt.show()
print("Analysis plot saved.")