# Bengali ASR Model Analysis

This notebook demonstrates:
1. Loading a trained model
2. Running inference on sample audio
3. Evaluating WER/CER on a test set
4. Visualizing the per-sample error distribution
5. Analyzing error patterns

**Author:** BRAC Data Science Team  
**Date:** October 2025

In [None]:
# Install required packages (if needed); %pip installs into the active kernel's environment
# %pip install torch transformers datasets librosa soundfile jiwer numpy pandas matplotlib seaborn

In [None]:
import os
import torch
import librosa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Audio, display

from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
from jiwer import wer, cer

# Notebook-wide plotting defaults: seaborn grid theme and a wide default figure
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})

## 1. Load Model

Load a trained Wav2Vec2 or Whisper model checkpoint.

In [None]:
# Configuration: checkpoint location, model family, and compute device
MODEL_PATH = "../models/wav2vec2_bengali/checkpoint-best"
MODEL_TYPE = "wav2vec2"  # or "whisper"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {MODEL_TYPE} model from {MODEL_PATH}")
print(f"Device: {DEVICE}")

# Each supported model type maps to its (processor class, model class) pair
_LOADERS = {
    "wav2vec2": (Wav2Vec2Processor, Wav2Vec2ForCTC),
    "whisper": (WhisperProcessor, WhisperForConditionalGeneration),
}

# Reject unsupported types before touching the checkpoint
if MODEL_TYPE not in _LOADERS:
    raise ValueError(f"Unknown model type: {MODEL_TYPE}")

processor_cls, model_cls = _LOADERS[MODEL_TYPE]
processor = processor_cls.from_pretrained(MODEL_PATH)
model = model_cls.from_pretrained(MODEL_PATH).to(DEVICE)

# Switch to inference mode
model.eval()
print("✓ Model loaded successfully!")

## 2. Inference on Sample Audio

Transcribe a sample Bengali audio file.

In [None]:
# Path of the demo clip
# TODO: Replace with actual audio file path
SAMPLE_AUDIO_PATH = "../data/samples/sample_bengali.wav"

if os.path.exists(SAMPLE_AUDIO_PATH):
    # Decode the clip at 16 kHz and report its basic properties
    audio, sr = librosa.load(SAMPLE_AUDIO_PATH, sr=16000)
    duration_seconds = len(audio) / sr
    print(f"Audio duration: {duration_seconds:.2f} seconds")
    print(f"Sample rate: {sr} Hz")

    # Embed an inline player so the clip can be listened to
    display(Audio(audio, rate=sr))
else:
    # Nothing to preview: tell the reader how to supply a clip
    print(f"⚠️ Sample audio not found at {SAMPLE_AUDIO_PATH}")
    print("Please provide a sample audio file or download from OpenSLR.")

In [None]:
def transcribe_audio(audio_path, model, processor, device, model_type="wav2vec2"):
    """Transcribe one audio file with a Wav2Vec2 (CTC) or Whisper model.

    Args:
        audio_path: Path of the audio file; it is loaded as 16 kHz mono.
        model: Loaded model matching ``model_type``.
        processor: Processor paired with ``model``.
        device: Device the input tensors are moved to (e.g. "cuda" or "cpu").
        model_type: Either "wav2vec2" or "whisper".

    Returns:
        A ``(transcription, confidence)`` tuple. For "wav2vec2" the confidence
        is the mean over frames of the highest softmax probability. For
        "whisper" it is a fixed placeholder of 0.9, not a measured score.

    Raises:
        ValueError: If ``model_type`` is not a supported value.
    """
    # Fail fast on unsupported types (same message as the model-loading cell)
    # instead of silently running the Whisper path on a mismatched model.
    if model_type not in ("wav2vec2", "whisper"):
        raise ValueError(f"Unknown model type: {model_type}")

    sample_rate = 16000

    # Load audio, resampled to the fixed rate and down-mixed to mono
    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)

    if model_type == "wav2vec2":
        # Process with Wav2Vec2
        inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt", padding=True)

        with torch.no_grad():
            logits = model(inputs.input_values.to(device)).logits

        # Greedy decode: most likely token per frame, collapsed by the processor
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        # Confidence heuristic: average of the per-frame top probability
        probs = torch.nn.functional.softmax(logits, dim=-1)
        max_probs = torch.max(probs, dim=-1).values
        confidence = float(max_probs.mean().cpu())

    else:  # whisper (guaranteed by the validation above)
        # Process with Whisper
        input_features = processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(device)

        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        # NOTE: placeholder constant, not derived from the model output;
        # do not interpret it as a real confidence estimate.
        confidence = 0.9

    return transcription, confidence


# Run the model on the demo clip (only when the clip is present) and report the result
if os.path.exists(SAMPLE_AUDIO_PATH):
    transcript, confidence = transcribe_audio(
        SAMPLE_AUDIO_PATH, model, processor, DEVICE, MODEL_TYPE
    )

    divider = "=" * 60
    print(f"\n{divider}")
    print("TRANSCRIPTION RESULT")
    print(divider)
    print(f"Transcript: {transcript}")
    print(f"Confidence: {confidence:.2%}")
    print(divider)

## 3. Evaluate on Test Set

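Calculate WER and CER on a test set. The test TSV is expected to provide a `path` column (audio file) and a `transcript` column (reference text).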

In [None]:
# Load test data
TEST_TSV = "../data/processed/test.tsv"
# Number of leading rows to evaluate for a quick check; set to None to evaluate all rows
N_EVAL_SAMPLES = 10

if os.path.exists(TEST_TSV):
    test_df = pd.read_csv(TEST_TSV, sep='\t')
    print(f"Loaded {len(test_df)} test samples")

    # Quick-evaluation subset (or the full frame when N_EVAL_SAMPLES is None)
    test_df_sample = test_df if N_EVAL_SAMPLES is None else test_df.head(N_EVAL_SAMPLES)

    # Transcribe; the two lists stay index-aligned because both are appended
    # only after a sample has been fully processed.
    predictions = []
    references = []

    for _, row in test_df_sample.iterrows():
        reference = row['transcript']
        # Empty TSV fields are read as NaN; an error rate against a missing or
        # blank reference is not meaningful, so such rows are skipped.
        if pd.isna(reference) or not str(reference).strip():
            print(f"Skipping {row['path']}: empty reference transcript")
            continue

        try:
            pred, _ = transcribe_audio(row['path'], model, processor, DEVICE, MODEL_TYPE)
        except Exception as e:
            # Best effort: report the failing file and keep evaluating the rest
            print(f"Error on {row['path']}: {e}")
            continue

        predictions.append(pred)
        # str() guards against transcripts that pandas parsed as numbers
        references.append(str(reference))

    if references:
        # Calculate metrics over all successfully transcribed samples
        wer_score = wer(references, predictions)
        cer_score = cer(references, predictions)

        print(f"\nWord Error Rate (WER): {wer_score*100:.2f}%")
        print(f"Character Error Rate (CER): {cer_score*100:.2f}%")

        # Create results dataframe
        results_df = pd.DataFrame({
            'reference': references,
            'prediction': predictions
        })

        display(results_df.head())
    else:
        # Every sample was skipped or failed: metrics are undefined
        print("No samples were transcribed successfully; WER/CER cannot be computed.")
else:
    print(f"Test file not found: {TEST_TSV}")
    print("Please run preprocessing first: cd ../data && python preprocess.py")

## 4. Visualize Errors

Visualize error distribution and patterns.

In [None]:
# Per-sample WER: distribution plots followed by summary statistics
if 'results_df' in locals():
    # One WER value per (reference, prediction) pair
    sample_wers = [
        wer([ref_text], [hyp_text])
        for ref_text, hyp_text in zip(references, predictions)
    ]
    mean_wer = np.mean(sample_wers)

    # Side-by-side panels: histogram on the left, box plot on the right
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    hist_ax, box_ax = axes

    # Histogram with the mean marked by a dashed line
    hist_ax.hist(sample_wers, bins=20, edgecolor='black', alpha=0.7)
    hist_ax.set_xlabel('Word Error Rate')
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title('WER Distribution')
    hist_ax.axvline(mean_wer, color='red', linestyle='--', label=f'Mean: {mean_wer:.2%}')
    hist_ax.legend()

    # Box plot of the same values
    box_ax.boxplot(sample_wers, vert=True)
    box_ax.set_ylabel('Word Error Rate')
    box_ax.set_title('WER Box Plot')
    box_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Summary statistics, one line per statistic
    print("\nError Statistics:")
    summary_stats = (
        ("Mean WER", np.mean),
        ("Median WER", np.median),
        ("Std Dev", np.std),
        ("Min WER", np.min),
        ("Max WER", np.max),
    )
    for label, stat_fn in summary_stats:
        print(f"{label}: {stat_fn(sample_wers):.2%}")

## 5. Error Pattern Analysis

Identify common transcription errors.

In [None]:
from collections import Counter
import difflib

def classify_word_errors(ref_words, pred_words):
    """Split the word-level diff of one utterance into error categories.

    Args:
        ref_words: Reference transcript as a list of words.
        pred_words: Predicted transcript as a list of words.

    Returns:
        A ``(substitutions, deletions, insertions)`` tuple where
        ``substitutions`` is a list of ``(ref_word, pred_word)`` pairs,
        ``deletions`` is a list of reference words missing from the prediction,
        and ``insertions`` is a list of predicted words absent from the reference.
    """
    substitutions, deletions, insertions = [], [], []

    # autojunk=False disables difflib's popularity heuristic, which would
    # otherwise ignore frequent words in sequences of 200+ items.
    matcher = difflib.SequenceMatcher(None, ref_words, pred_words, autojunk=False)

    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'replace':
            # A 'replace' span may cover different numbers of reference and
            # predicted words. Pair them position by position, then count the
            # surplus on either side instead of silently dropping it.
            n_paired = min(i2 - i1, j2 - j1)
            substitutions.extend(
                zip(ref_words[i1:i1 + n_paired], pred_words[j1:j1 + n_paired])
            )
            deletions.extend(ref_words[i1 + n_paired:i2])
            insertions.extend(pred_words[j1 + n_paired:j2])
        elif tag == 'delete':
            deletions.extend(ref_words[i1:i2])
        elif tag == 'insert':
            insertions.extend(pred_words[j1:j2])

    return substitutions, deletions, insertions


if 'results_df' in locals():
    substitutions = []
    deletions = []
    insertions = []

    # Accumulate errors across all evaluated utterances
    for ref, pred in zip(references, predictions):
        subs, dels, ins = classify_word_errors(ref.split(), pred.split())
        substitutions.extend(subs)
        deletions.extend(dels)
        insertions.extend(ins)

    print("\nMost Common Substitutions:")
    print("="*60)
    for (ref_word, pred_word), count in Counter(substitutions).most_common(10):
        print(f"{ref_word:20s} → {pred_word:20s} ({count} times)")

    print("\nMost Common Deletions:")
    print("="*60)
    for word, count in Counter(deletions).most_common(10):
        print(f"{word:20s} (deleted {count} times)")

    print("\nMost Common Insertions:")
    print("="*60)
    for word, count in Counter(insertions).most_common(10):
        print(f"{word:20s} (inserted {count} times)")

## Conclusion

This notebook demonstrated:
- Loading and using a trained Bengali ASR model
- Transcribing audio samples
- Calculating WER and CER metrics
- Analyzing error patterns

**Next Steps:**
1. Collect more training data in error-prone areas
2. Apply targeted data augmentation
3. Fine-tune on BRAC dialect data
4. Add language-model decoding to improve accuracy