# 🎯 Sinhala ASR Model Tester

This notebook is designed to load and test your trained Sinhala ASR model with any audio file.

## 📋 Requirements:
- Trained model saved in `./sinhala-whisper-asr-final/` directory
- Audio files to test (any format librosa can decode, e.g. .wav, .flac, .mp3; some compressed formats such as .m4a may additionally need ffmpeg)
- Required libraries: transformers, librosa, torch

## 🚀 Features:
- Load your trained Sinhala ASR model
- Test single audio files
- Test multiple audio files
- Batch testing from a CSV dataset, with WER/CER against the reference transcriptions
- Performance analysis and metrics

In [None]:
# ================================
# INSTALL REQUIRED PACKAGES
# ================================
# Run these once in a fresh environment (remove the leading "# " to enable):
#   %pip install transformers datasets librosa torch   # model + audio stack
#   %pip install soundfile numpy pandas                # audio I/O and tables
#   %pip install jiwer                                 # optional: WER/CER metrics

print("📦 All required packages should be installed!")

In [None]:
# ================================
# IMPORTS
# ================================

import os
import torch
import librosa
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")

from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperFeatureExtractor,
    WhisperTokenizer
)

# jiwer is optional: it only supplies the WER/CER scores used in the dataset test.
try:
    from jiwer import wer, cer
except ImportError:
    METRICS_AVAILABLE = False
    print("⚠️ jiwer not available - metrics will be limited")
else:
    METRICS_AVAILABLE = True
    print("📊 Evaluation metrics available (WER, CER)")

# Report the runtime environment.
_cuda_ok = torch.cuda.is_available()
print("✅ All imports successful!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"💻 CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# ================================
# MODEL CONFIGURATION
# ================================

# Fine-tuned checkpoint directory, and the hub model used when it is missing.
MODEL_PATH = "./sinhala-whisper-asr-final"
FALLBACK_MODEL = "openai/whisper-base"

# Audio is resampled to SAMPLE_RATE on load; clips are cut at MAX_AUDIO_LENGTH.
SAMPLE_RATE = 16000
MAX_AUDIO_LENGTH = 30  # seconds

# Use the GPU when PyTorch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

for _config_line in (
    f"🎯 Configuration:",
    f"   📁 Model Path: {MODEL_PATH}",
    f"   🔄 Fallback: {FALLBACK_MODEL}",
    f"   📈 Sample Rate: {SAMPLE_RATE} Hz",
    f"   ⏱️ Max Length: {MAX_AUDIO_LENGTH} seconds",
    f"   💻 Device: {DEVICE}",
):
    print(_config_line)

In [None]:
# ================================
# MODEL LOADING CLASS
# ================================

class SinhalaASRTester:
    """Load a Whisper checkpoint and transcribe Sinhala audio files with it.

    ``load_model()`` tries the fine-tuned checkpoint at ``model_path`` first
    and falls back to ``fallback_model`` if the directory is missing or fails
    to load. ``transcribe()`` then turns one audio file into a result dict.
    """

    def __init__(self, model_path=MODEL_PATH, fallback_model=FALLBACK_MODEL):
        self.model_path = model_path
        self.fallback_model = fallback_model
        self.device = DEVICE
        self.model = None       # set by load_model()
        self.processor = None   # set by load_model()
        self.model_type = None  # "trained" or "base" once a model is loaded

    def load_model(self):
        """Load the trained model or fall back to the base model.

        Returns:
            True once a model is loaded, configured and moved to ``self.device``.

        Raises:
            Exception: whatever ``_load_base_model`` re-raises when the
            fallback model cannot be loaded either.
        """

        print(f"🔄 Loading Sinhala ASR model...")

        # Try to load trained model first
        if os.path.exists(self.model_path):
            try:
                print(f"📁 Loading trained model from: {self.model_path}")
                self.model = WhisperForConditionalGeneration.from_pretrained(self.model_path)
                self.processor = WhisperProcessor.from_pretrained(self.model_path)
                self.model_type = "trained"
                print(f"✅ Trained model loaded successfully!")

            except Exception as e:
                print(f"❌ Error loading trained model: {e}")
                print(f"🔄 Falling back to base model...")
                self._load_base_model()
        else:
            print(f"📁 Trained model not found at: {self.model_path}")
            print(f"🔄 Loading base model...")
            self._load_base_model()

        # Drop any forced decoder prompt / token suppression stored in the
        # model config; language and task are passed to generate() instead.
        self.model.config.forced_decoder_ids = None
        self.model.config.suppress_tokens = []
        # Recent transformers versions read generate() defaults from
        # generation_config, so a forced prompt saved there would still apply.
        generation_config = getattr(self.model, "generation_config", None)
        if generation_config is not None:
            generation_config.forced_decoder_ids = None

        # Move to device and switch to inference mode
        self.model.to(self.device)
        self.model.eval()

        print(f"🎯 Model Configuration:")
        print(f"   🤖 Model Type: {self.model_type}")
        print(f"   🏗️ Architecture: {self.model.config.model_type}")
        print(f"   🌐 Language: Sinhala (si)")
        print(f"   💻 Device: {self.device}")

        return True

    def _load_base_model(self):
        """Load ``self.fallback_model`` with a Sinhala/transcribe processor.

        Raises:
            Exception: re-raised after printing if the model cannot be loaded.
        """
        try:
            self.model = WhisperForConditionalGeneration.from_pretrained(self.fallback_model)
            self.processor = WhisperProcessor.from_pretrained(
                self.fallback_model,
                language="si",
                task="transcribe"
            )
            self.model_type = "base"
            print(f"✅ Base model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading base model: {e}")
            raise

    def preprocess_audio(self, audio_path):
        """Load an audio file at SAMPLE_RATE, cap its length and peak-normalize it.

        Args:
            audio_path: Path to a file librosa can decode.

        Returns:
            ``(audio_array, duration)``. ``duration`` is the length in seconds
            of the ORIGINAL audio, before trimming to MAX_AUDIO_LENGTH.
            Returns ``(None, None)`` if loading or processing fails.
        """
        try:
            # librosa resamples to SAMPLE_RATE; the returned rate is not needed.
            audio_array, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

            duration = len(audio_array) / SAMPLE_RATE

            # int() keeps the slice bound valid even if MAX_AUDIO_LENGTH is fractional.
            max_samples = int(MAX_AUDIO_LENGTH * SAMPLE_RATE)
            if len(audio_array) > max_samples:
                audio_array = audio_array[:max_samples]
                print(f"⚠️ Audio trimmed to {MAX_AUDIO_LENGTH} seconds")

            # Peak-normalize; skip all-zero audio to avoid dividing by zero.
            max_val = np.max(np.abs(audio_array))
            if max_val > 0:
                audio_array = audio_array / max_val

            return audio_array, duration

        except Exception as e:
            print(f"❌ Error preprocessing audio: {e}")
            return None, None

    def transcribe(self, audio_path, show_details=True):
        """Transcribe a single audio file.

        Args:
            audio_path: Path to the audio file.
            show_details: Print progress and a formatted result block.

        Returns:
            A dict with keys ``file``, ``path``, ``duration`` (original length,
            seconds), ``processed_duration`` (length actually transcribed,
            seconds), ``transcription``, ``word_count``, ``char_count`` and
            ``model_type``; or None if the model is not loaded, the file is
            missing, or any step fails (the reason is printed).
        """

        if self.model is None:
            print(f"❌ Model not loaded. Call load_model() first.")
            return None

        if not os.path.exists(audio_path):
            print(f"❌ Audio file not found: {audio_path}")
            return None

        try:
            if show_details:
                print(f"\n🎵 Processing: {os.path.basename(audio_path)}")

            audio_array, duration = self.preprocess_audio(audio_path)
            if audio_array is None:
                return None

            if show_details:
                print(f"📊 Audio info: {duration:.2f}s, {len(audio_array):,} samples")

            inputs = self.processor(audio_array, sampling_rate=SAMPLE_RATE, return_tensors="pt")
            input_features = inputs.input_features.to(self.device)

            # Greedy decoding with the language/task given explicitly.
            with torch.no_grad():
                predicted_ids = self.model.generate(
                    input_features,
                    language="si",
                    task="transcribe",
                    max_length=448,
                    num_beams=1,
                    do_sample=False,
                    temperature=1.0
                )

            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

            # Only the (possibly trimmed) audio reaches the model, so rates
            # derived from the transcription must use this length.
            processed_duration = len(audio_array) / SAMPLE_RATE

            result = {
                'file': os.path.basename(audio_path),
                'path': audio_path,
                'duration': duration,
                'processed_duration': processed_duration,
                'transcription': transcription,
                'word_count': len(transcription.split()),
                'char_count': len(transcription),
                'model_type': self.model_type
            }

            if show_details:
                self._display_result(result)

            return result

        except Exception as e:
            print(f"❌ Error during transcription: {e}")
            return None

    def _display_result(self, result):
        """Pretty-print a result dict produced by ``transcribe``."""
        # Older result dicts may lack 'processed_duration'; fall back to 'duration'.
        processed = result.get('processed_duration', result['duration'])

        print(f"\n" + "=" * 60)
        print(f"🎯 TRANSCRIPTION RESULT")
        print(f"=" * 60)
        print(f"🎵 File: {result['file']}")
        print(f"⏱️ Duration: {result['duration']:.2f} seconds")
        if processed < result['duration']:
            print(f"✂️ Transcribed: first {processed:.2f} seconds")
        print(f"🤖 Model: {result['model_type']} Whisper")
        print(f"📝 Transcription: '{result['transcription']}'")

        if result['transcription'].strip():
            speaking_rate = result['word_count'] / (processed / 60) if processed > 0 else 0
            print(f"📊 Analysis:")
            print(f"   🔤 Words: {result['word_count']}")
            print(f"   📝 Characters: {result['char_count']}")
            print(f"   ⏱️ Speaking rate: ~{speaking_rate:.1f} words/min")
        else:
            print(f"⚠️ Empty transcription")

        print(f"=" * 60)

# The class is only defined here; an instance is created and loaded in the next cell.
print("🎯 Sinhala ASR Tester initialized!")

In [None]:
# ================================
# LOAD THE MODEL
# ================================

# Create the tester instance (weights are not loaded until load_model()).
asr_tester = SinhalaASRTester()

# load_model() returns True on success but raises if even the fallback model
# cannot be loaded; catch that so the failure branch below is reachable.
try:
    success = asr_tester.load_model()
except Exception as e:
    print(f"❌ Model loading error: {e}")
    success = False

if success:
    print(f"\n🎉 Ready to test audio files!")
    print(f"💡 Use asr_tester.transcribe('your_audio_file.wav') to test")
else:
    print(f"\n❌ Failed to load model")

In [None]:
# ================================
# TEST SINGLE AUDIO FILE
# ================================

# Test your specific audio file
audio_file = "test_audio.wav"  # Change this to your audio file path

print(f"🎯 Testing audio file: {audio_file}")
print(f"=" * 50)

result = asr_tester.transcribe(audio_file)

if result:
    print(f"\n✅ Transcription completed!")
else:
    print(f"\n❌ Transcription failed")
    print(f"💡 Make sure the audio file exists and is in a supported format")

In [None]:
# ================================
# TEST MULTIPLE AUDIO FILES
# ================================

def test_multiple_files(file_list, show_individual=True):
    """Test multiple audio files"""
    
    print(f"🎯 Testing {len(file_list)} audio files...")
    print(f"=" * 50)
    
    results = []
    successful = 0
    
    for i, audio_file in enumerate(file_list, 1):
        print(f"\n📁 [{i}/{len(file_list)}] Processing: {os.path.basename(audio_file)}")
        
        result = asr_tester.transcribe(audio_file, show_details=show_individual)
        
        if result:
            results.append(result)
            successful += 1
            if not show_individual:
                print(f"   ✅ Success: '{result['transcription'][:50]}...'")
        else:
            print(f"   ❌ Failed")
    
    # Summary
    print(f"\n" + "=" * 60)
    print(f"📊 BATCH TESTING SUMMARY")
    print(f"=" * 60)
    print(f"📁 Total files: {len(file_list)}")
    print(f"✅ Successful: {successful}")
    print(f"❌ Failed: {len(file_list) - successful}")
    print(f"📈 Success rate: {successful/len(file_list)*100:.1f}%")
    
    if results:
        total_duration = sum(r['duration'] for r in results)
        total_words = sum(r['word_count'] for r in results)
        avg_speaking_rate = total_words / (total_duration / 60) if total_duration > 0 else 0
        
        print(f"⏱️ Total audio: {total_duration:.1f} seconds")
        print(f"🔤 Total words: {total_words}")
        print(f"⏱️ Avg speaking rate: {avg_speaking_rate:.1f} words/min")
    
    return results

# Example: Test multiple files from dataset
# Uncomment and modify paths as needed

# Sample FLAC files from your dataset
sample_files = [
    "asr_sinhala/data/00/0000f47c22.flac",
    "asr_sinhala/data/00/000101700f.flac",
    "asr_sinhala/data/00/000107b539.flac"
]

# Filter existing files
existing_files = [f for f in sample_files if os.path.exists(f)]

if existing_files:
    print(f"Found {len(existing_files)} sample files to test")
    batch_results = test_multiple_files(existing_files, show_individual=False)
else:
    print(f"No sample audio files found. Update the paths above to test multiple files.")

In [None]:
# ================================
# TEST FROM CSV DATASET
# ================================

def test_from_csv(csv_path, audio_base_path="asr_sinhala/data", max_samples=5):
    """Test audio files from CSV dataset with ground truth comparison"""
    
    if not os.path.exists(csv_path):
        print(f"❌ CSV file not found: {csv_path}")
        return
    
    # Load CSV
    try:
        df = pd.read_csv(csv_path)
        print(f"📊 Loaded CSV with {len(df)} samples")
        
        # Take sample
        sample_df = df.head(max_samples)
        
        print(f"🎯 Testing {len(sample_df)} samples from dataset...")
        print(f"=" * 60)
        
        results = []
        
        for idx, row in sample_df.iterrows():
            audio_path = os.path.join(audio_base_path, row.iloc[0])  # First column is audio path
            ground_truth = row.iloc[1] if len(row) > 1 else "N/A"    # Second column is text
            
            print(f"\n📁 [{idx+1}/{len(sample_df)}] {os.path.basename(audio_path)}")
            print(f"📝 Ground truth: '{ground_truth}'")
            
            if os.path.exists(audio_path):
                result = asr_tester.transcribe(audio_path, show_details=False)
                
                if result:
                    result['ground_truth'] = ground_truth
                    results.append(result)
                    
                    print(f"🤖 Prediction: '{result['transcription']}'")
                    
                    # Calculate metrics if available
                    if METRICS_AVAILABLE and ground_truth != "N/A":
                        try:
                            wer_score = wer(ground_truth, result['transcription']) * 100
                            cer_score = cer(ground_truth, result['transcription']) * 100
                            print(f"📊 WER: {wer_score:.1f}%, CER: {cer_score:.1f}%")
                            result['wer'] = wer_score
                            result['cer'] = cer_score
                        except:
                            print(f"📊 Metrics calculation failed")
                    
                    print(f"✅ Success")
                else:
                    print(f"❌ Transcription failed")
            else:
                print(f"❌ Audio file not found: {audio_path}")
        
        # Summary
        if results:
            print(f"\n" + "=" * 60)
            print(f"📊 DATASET TESTING SUMMARY")
            print(f"=" * 60)
            print(f"✅ Successful transcriptions: {len(results)}/{len(sample_df)}")
            
            if METRICS_AVAILABLE and any('wer' in r for r in results):
                avg_wer = np.mean([r['wer'] for r in results if 'wer' in r])
                avg_cer = np.mean([r['cer'] for r in results if 'cer' in r])
                print(f"📈 Average WER: {avg_wer:.1f}%")
                print(f"📈 Average CER: {avg_cer:.1f}%")
        
        return results
        
    except Exception as e:
        print(f"❌ Error processing CSV: {e}")
        return None

# Test with your CSV files
csv_files = ["test.csv", "train.csv"]

for csv_file in csv_files:
    if os.path.exists(csv_file):
        print(f"\n🎯 Testing from {csv_file}...")
        csv_results = test_from_csv(csv_file, max_samples=3)
        break
else:
    print(f"No CSV files found for testing")

In [None]:
# ================================
# CUSTOM AUDIO TEST
# ================================

def test_custom_audio():
    """Prompt repeatedly for audio paths and transcribe each one.

    Typing 'quit', 'exit' or 'q' (any case) ends the loop. Blank input is
    ignored, and surrounding single or double quotes are stripped from a path.
    """
    banner = (
        f"🎯 Custom Audio Testing",
        f"=" * 40,
        f"Enter the path to your audio file, or 'quit' to exit",
        f"Supported formats: .wav, .flac, .mp3, .m4a, .ogg",
    )
    for banner_line in banner:
        print(banner_line)

    exit_words = {'quit', 'exit', 'q'}
    while True:
        raw_path = input(f"\n🎵 Audio file path: ").strip()

        if raw_path.lower() in exit_words:
            print(f"👋 Goodbye!")
            return

        if not raw_path:
            continue

        cleaned_path = raw_path.strip('"\'')
        outcome = asr_tester.transcribe(cleaned_path)
        print(f"\n✅ Transcription successful!" if outcome else f"\n❌ Transcription failed")

# Interactive loop; remove the leading "# " below to start it.
# test_custom_audio()

print(f"💡 Uncomment the line above to run interactive audio testing")

In [None]:
# ================================
# SAVE RESULTS TO FILE
# ================================

def save_results_to_csv(results, filename="transcription_results.csv"):
    """Write transcription result dicts to a CSV file, one row per result.

    Nothing is written when ``results`` is empty. Any error raised while
    building or writing the table is printed, not propagated.
    """
    if not results:
        print(f"❌ No results to save")
        return

    try:
        table = pd.DataFrame(results)
        table.to_csv(filename, index=False)
        print(f"💾 Results saved to: {filename}")
        print(f"📊 Saved {len(results)} transcription results")
    except Exception as e:
        print(f"❌ Error saving results: {e}")

# Example (remove the leading "# " once you have results to save):
# save_results_to_csv(batch_results, "my_transcription_results.csv")

print(f"💡 Use save_results_to_csv(your_results, 'filename.csv') to save results")

## 🎯 Quick Usage Guide

### Single File Testing:
```python
result = asr_tester.transcribe("your_audio.wav")
```

### Multiple Files Testing:
```python
files = ["audio1.wav", "audio2.wav", "audio3.wav"]
results = test_multiple_files(files)
```

### Dataset Testing:
```python
results = test_from_csv("your_dataset.csv", "path/to/audio/files")
```

### Save Results:
```python
save_results_to_csv(results, "transcription_results.csv")
```

## 📝 Notes:
- Make sure your trained model is in `./sinhala-whisper-asr-final/`
- If the trained model is not found (or fails to load), the notebook falls back to the base `openai/whisper-base` model
- Supports various audio formats: WAV, FLAC, MP3, M4A, OGG
- Only the first 30 seconds of each file are transcribed (`MAX_AUDIO_LENGTH`); anything after that is discarded
- Install `jiwer` package for WER/CER evaluation metrics

## 🚀 Happy Testing!