In [1]:
# ================================
# SINHALA ASR TRAINING WITH WHISPER - KAGGLE VERSION
# ================================

# Install required packages
%pip install transformers datasets evaluate jiwer
%pip install librosa scikit-learn pandas
%pip install soundfile
%pip install tensorboard
%pip install accelerate

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# ================================
# IMPORTS
# ================================

import pandas as pd
import numpy as np
import torch
import librosa
import os
import glob
from datasets import Dataset, Audio
from transformers import (
    WhisperFeatureExtractor, 
    WhisperTokenizer, 
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback
)
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
from jiwer import wer, cer, mer

2025-07-21 18:13:34.191596: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753121614.547955      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753121614.648770      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# ================================
# 1. DATA PREPARATION WITH KAGGLE PATHS
# ================================

# Define Kaggle paths
AUDIO_BASE_PATH = "/kaggle/input/large-sinhala-asr-training-dataset"
CSV_BASE_PATH = "/kaggle/input/new-dataset-10"

print("🚀 Loading Sinhala ASR data for Kaggle training...")
print(f"📁 Audio base path: {AUDIO_BASE_PATH}")
print(f"📁 CSV base path: {CSV_BASE_PATH}")

# Verify paths exist
if not os.path.exists(AUDIO_BASE_PATH):
    print(f"❌ Audio path not found: {AUDIO_BASE_PATH}")
    raise FileNotFoundError("Audio dataset not found")
    
if not os.path.exists(CSV_BASE_PATH):
    print(f"❌ CSV path not found: {CSV_BASE_PATH}")
    raise FileNotFoundError("CSV files not found")

print("✅ All paths verified successfully!")

# Load CSV files
train_csv = os.path.join(CSV_BASE_PATH, "10-train.csv")
test_csv = os.path.join(CSV_BASE_PATH, "10-test.csv")

train_df = pd.read_csv(train_csv)
val_df = pd.read_csv(test_csv)

print(f"\n📊 Dataset Information:")
print(f"   🏋️ Training samples: {len(train_df):,}")
print(f"   🧪 Validation samples: {len(val_df):,}")
print(f"   📈 Total samples: {len(train_df) + len(val_df):,}")

# Check data structure
print(f"\n🔍 Data Structure:")
print(f"   📋 Train columns: {list(train_df.columns)}")
print(f"   📋 Val columns: {list(val_df.columns)}")

# Show sample data
print(f"\n📝 Sample Training Data:")
print(train_df.head(3))

🚀 Loading Sinhala ASR data for Kaggle training...
📁 Audio base path: /kaggle/input/large-sinhala-asr-training-dataset
📁 CSV base path: /kaggle/input/new-dataset-10
✅ All paths verified successfully!

📊 Dataset Information:
   🏋️ Training samples: 8,000
   🧪 Validation samples: 2,000
   📈 Total samples: 10,000

🔍 Data Structure:
   📋 Train columns: ['file', 'sentence_cleaned']
   📋 Val columns: ['file', 'sentence_cleaned']

📝 Sample Training Data:
                                  file              sentence_cleaned
0  asr_sinhala/data/aa/aaaee62687.flac  අක්කයි මායි දෙන්නත් ළඟ නැතුව
1  asr_sinhala/data/07/07031079ca.flac      ශ්‍රී ලංකාව බැහැර කොට ඇත
2  asr_sinhala/data/31/3128fc4733.flac  ඔන්න ඔය විදිහට යෙදෙන නැකතෙන්


In [4]:
# ================================
# 2. PATH CONVERSION AND VALIDATION
# ================================

def convert_audio_path(relative_path, base_path=AUDIO_BASE_PATH):
    """Convert relative audio path to absolute Kaggle path"""
    if pd.isna(relative_path) or relative_path == "":
        return None
    
    # Handle already absolute paths
    if os.path.isabs(relative_path):
        return relative_path
    
    # Join with base path
    absolute_path = os.path.join(base_path, relative_path)
    return os.path.normpath(absolute_path)

def verify_audio_file(audio_path):
    """Verify if audio file exists and is readable"""
    try:
        if not os.path.exists(audio_path):
            return False
        audio, sr = librosa.load(audio_path, sr=16000)
        return len(audio) > 0
    except Exception as e:
        return False

# Ensure consistent column naming
audio_col = train_df.columns[0]
text_col = train_df.columns[1]

print(f"🔄 Using columns: '{audio_col}' as audio, '{text_col}' as sentence")

# Rename columns for consistency
train_df = train_df[[audio_col, text_col]].copy()
val_df = val_df[[audio_col, text_col]].copy()
train_df.columns = ["audio", "sentence"]
val_df.columns = ["audio", "sentence"]

# Convert relative paths to absolute paths
print(f"\n🔗 Converting audio paths...")
train_df['audio'] = train_df['audio'].apply(convert_audio_path)
val_df['audio'] = val_df['audio'].apply(convert_audio_path)

# Remove rows with missing data
initial_train_size = len(train_df)
initial_val_size = len(val_df)

train_df = train_df.dropna()
val_df = val_df.dropna()
train_df = train_df[train_df["sentence"].str.strip() != ""]
val_df = val_df[val_df["sentence"].str.strip() != ""]

print(f"\n🔍 Validating audio files (this may take a moment)...")

# Validate audio files (sample check for speed)
sample_size = min(1000, len(train_df))
train_sample = train_df.head(sample_size)
valid_count = sum(verify_audio_file(path) for path in train_sample['audio'])

print(f"📊 Audio validation results:")
print(f"   ✅ Valid files in sample: {valid_count}/{sample_size}")
print(f"   📈 Estimated validity rate: {valid_count/sample_size*100:.1f}%")

if valid_count < sample_size * 0.1:  # Less than 10% valid
    print("⚠️ Warning: Low audio file validity rate detected")
    print("💡 Check if audio paths are correctly mapped")

print(f"\n🧹 Data Cleaning Results:")
print(f"   🏋️ Training: {initial_train_size} → {len(train_df)} samples")
print(f"   🧪 Validation: {initial_val_size} → {len(val_df)} samples")

# Display final sample data
print(f"\n📝 Final Sample Data:")
print(train_df.head(3))

🔄 Using columns: 'file' as audio, 'sentence_cleaned' as sentence

🔗 Converting audio paths...

🔍 Validating audio files (this may take a moment)...
📊 Audio validation results:
   ✅ Valid files in sample: 1000/1000
   📈 Estimated validity rate: 100.0%

🧹 Data Cleaning Results:
   🏋️ Training: 8000 → 8000 samples
   🧪 Validation: 2000 → 2000 samples

📝 Final Sample Data:
                                               audio  \
0  /kaggle/input/large-sinhala-asr-training-datas...   
1  /kaggle/input/large-sinhala-asr-training-datas...   
2  /kaggle/input/large-sinhala-asr-training-datas...   

                       sentence  
0  අක්කයි මායි දෙන්නත් ළඟ නැතුව  
1      ශ්‍රී ලංකාව බැහැර කොට ඇත  
2  ඔන්න ඔය විදිහට යෙදෙන නැකතෙන්  


In [5]:
# ================================
# 3. WHISPER PROCESSOR SETUP
# ================================

print("🤖 Setting up Whisper processor for Sinhala...")

# Initialize Whisper components for Sinhala
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base", language="si", task="transcribe")
processor = WhisperProcessor.from_pretrained("openai/whisper-base", language="si", task="transcribe")

print("✅ Whisper processor setup completed!")
print(f"   🌐 Language: Sinhala (si)")
print(f"   🎯 Task: Transcribe")
print(f"   📏 Model: whisper-base")

🤖 Setting up Whisper processor for Sinhala...


preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

✅ Whisper processor setup completed!
   🌐 Language: Sinhala (si)
   🎯 Task: Transcribe
   📏 Model: whisper-base


In [6]:
# ================================
# 4. AUDIO PREPROCESSING FUNCTIONS
# ================================

def normalize_audio(audio_array):
    """Normalize audio to prevent clipping"""
    max_val = np.max(np.abs(audio_array))
    if max_val > 0:
        return audio_array / max_val
    return audio_array

def add_noise_augmentation(audio_array, noise_factor=0.005):
    """Add Gaussian noise for robustness"""
    noise = np.random.normal(0, noise_factor, audio_array.shape)
    return audio_array + noise

# Filter valid audio files for training
print("🔍 Filtering valid audio files for training...")

valid_train = []
valid_val = []

# Process training data
print("Processing training data...")
for idx, row in train_df.iterrows():
    if verify_audio_file(row['audio']):
        valid_train.append(row)
    else:
        if idx < 10:  # Only print first 10 invalid files
            print(f"   ⚠️ Invalid audio: {os.path.basename(row['audio'])}")

# Process validation data
print("Processing validation data...")
for idx, row in val_df.iterrows():
    if verify_audio_file(row['audio']):
        valid_val.append(row)
    else:
        if idx < 10:  # Only print first 10 invalid files
            print(f"   ⚠️ Invalid audio: {os.path.basename(row['audio'])}")

# Update dataframes
train_df = pd.DataFrame(valid_train)
val_df = pd.DataFrame(valid_val)

print(f"\n📊 Final Valid Dataset:")
print(f"   🏋️ Training samples: {len(train_df):,}")
print(f"   🧪 Validation samples: {len(val_df):,}")
print(f"   📈 Total valid samples: {len(train_df) + len(val_df):,}")

if len(train_df) == 0:
    raise ValueError("No valid training samples found! Check audio paths.")

🔍 Filtering valid audio files for training...
Processing training data...
Processing validation data...

📊 Final Valid Dataset:
   🏋️ Training samples: 8,000
   🧪 Validation samples: 2,000
   📈 Total valid samples: 10,000


In [7]:
# ================================
# 5. DATASET CREATION
# ================================

print("📦 Creating HuggingFace datasets...")

# Create datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)

# Cast audio columns with target sampling rate
train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
val_dataset = val_dataset.cast_column("audio", Audio(sampling_rate=16000))

def prepare_dataset(examples):
    """Prepare dataset with robust preprocessing"""
    # Load and process audio
    audio = examples["audio"]
    audio_array = audio["array"]
    
    # Apply normalization
    audio_array = normalize_audio(audio_array)
    
    # Optional: Add noise augmentation (uncomment if needed)
    # audio_array = add_noise_augmentation(audio_array, noise_factor=0.001)
    
    # Ensure audio length constraints
    min_length = 1000  # ~0.06 seconds at 16kHz
    max_length = 480000  # ~30 seconds at 16kHz
    
    if len(audio_array) < min_length:
        # Pad short audio
        audio_array = np.pad(audio_array, (0, min_length - len(audio_array)), 'constant')
    elif len(audio_array) > max_length:
        # Truncate long audio
        audio_array = audio_array[:max_length]
    
    # Compute log-Mel input features
    examples["input_features"] = feature_extractor(
        audio_array, sampling_rate=16000
    ).input_features[0]
    
    # Clean up audio data
    del examples["audio"]
    
    # Process text
    sentences = examples["sentence"]
    
    # Clean and normalize text
    if isinstance(sentences, str):
        sentences = sentences.strip()
    
    # Encode target text to label ids
    examples["labels"] = tokenizer(sentences).input_ids
    del examples["sentence"]
    
    return examples

# Apply preprocessing
print("🔄 Preprocessing training dataset...")
train_dataset = train_dataset.map(prepare_dataset, num_proc=1)

print("🔄 Preprocessing validation dataset...")
val_dataset = val_dataset.map(prepare_dataset, num_proc=1)

print("✅ Dataset preprocessing completed!")

📦 Creating HuggingFace datasets...
🔄 Preprocessing training dataset...


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

🔄 Preprocessing validation dataset...


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

✅ Dataset preprocessing completed!


In [8]:
# ================================
# 6. DATA COLLATOR
# ================================

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Handle input features
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        
        # Handle labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        
        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        
        # Remove BOS token if present
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
print("✅ Data collator configured!")

✅ Data collator configured!


In [9]:
# ================================
# 7. EVALUATION METRICS
# ================================

def compute_metrics(pred):
    """Compute comprehensive ASR evaluation metrics"""
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # Decode token IDs to strings
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Calculate metrics
    wer_score = wer(label_str, pred_str) * 100
    cer_score = cer(label_str, pred_str) * 100
    
    # Sentence Error Rate
    ser_score = (
        sum(ref.strip() != pred.strip() for ref, pred in zip(label_str, pred_str))
        / len(label_str)
    ) * 100

    return {
        "wer": wer_score,
        "cer": cer_score,
        "ser": ser_score,
    }

print("📊 Evaluation metrics configured!")

📊 Evaluation metrics configured!


In [10]:
# ================================
# 8. MODEL SETUP
# ================================

print("🤖 Loading pre-trained Whisper model...")

try:
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise

# Configure model for Sinhala
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"💻 Using device: {device}")

print(f"📋 Model Information:")
print(f"   🏗️ Architecture: {model.config.model_type}")
print(f"   📏 Model size: whisper-base")
print(f"   🌐 Target language: Sinhala")

🤖 Loading pre-trained Whisper model...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/290M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

✅ Model loaded successfully!
💻 Using device: cuda
📋 Model Information:
   🏗️ Architecture: whisper
   📏 Model size: whisper-base
   🌐 Target language: Sinhala


In [11]:
# ================================
# 9. TRAINING ARGUMENTS
# ================================

# Output directory
output_dir = "./whisper-sinhala-asr-model"

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    
    # Training schedule
    num_train_epochs=12,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    
    # Optimization
    learning_rate=1e-5,
    warmup_steps=500,
    lr_scheduler_type="linear",
    weight_decay=0.01,
    
    # Evaluation and saving
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    
    # Memory optimization
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),  # Use FP16 if CUDA available
    dataloader_pin_memory=False,
    
    # Generation settings
    predict_with_generate=True,
    generation_max_length=225,
    
    # Logging
    logging_steps=25,
    logging_strategy="steps",
    report_to=["tensorboard"],
    
    # Additional settings
    remove_unused_columns=False,
    label_names=["labels"],
)

print("⚙️ Training arguments configured!")
print(f"   📁 Output directory: {output_dir}")
print(f"   🔄 Epochs: {training_args.num_train_epochs}")
print(f"   📦 Batch size: {training_args.per_device_train_batch_size}")
print(f"   📈 Learning rate: {training_args.learning_rate}")

⚙️ Training arguments configured!
   📁 Output directory: ./whisper-sinhala-asr-model
   🔄 Epochs: 12
   📦 Batch size: 16
   📈 Learning rate: 1e-05


In [12]:
# ================================
# 10. TRAINER SETUP
# ================================

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

print("🏃‍♂️ Trainer configured successfully!")
print(f"   🏋️ Training samples: {len(train_dataset):,}")
print(f"   🧪 Validation samples: {len(val_dataset):,}")

  trainer = Seq2SeqTrainer(


🏃‍♂️ Trainer configured successfully!
   🏋️ Training samples: 8,000
   🧪 Validation samples: 2,000


In [13]:
# ================================
# 11. TRAINING EXECUTION
# ================================

print("🚀 Starting Sinhala ASR training...")
print("=" * 50)
print(f"📊 Dataset: {len(train_dataset):,} training, {len(val_dataset):,} validation")
print(f"🤖 Model: Whisper-base fine-tuned for Sinhala")
print(f"💻 Device: {device}")
print(f"⏱️ Estimated time: ~{len(train_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs // 100} minutes")
print("=" * 50)

# Start training
try:
    trainer.train()
    print("✅ Training completed successfully!")
except Exception as e:
    print(f"❌ Training failed: {e}")
    raise

🚀 Starting Sinhala ASR training...
📊 Dataset: 8,000 training, 2,000 validation
🤖 Model: Whisper-base fine-tuned for Sinhala
💻 Device: cuda
⏱️ Estimated time: ~30 minutes


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Epoch,Training Loss,Validation Loss,Wer,Cer,Ser
1,1.5288,1.493721,123.609408,154.63509,100.0
2,1.2758,1.215741,106.621335,111.225256,100.0
3,0.4232,0.409705,87.961208,36.633738,99.7
4,0.2789,0.299378,73.013042,22.987642,97.15
5,0.2111,0.261623,68.910935,21.646696,96.45
6,0.1729,0.251005,66.692676,20.788416,95.2
7,0.1427,0.250083,65.856649,20.651317,95.7
8,0.1193,0.255152,67.461821,21.250423,96.15
9,0.1011,0.258103,65.120945,20.506705,94.75
10,0.0822,0.265802,65.243563,20.305751,95.7


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


✅ Training completed successfully!


In [14]:
# ================================
# 12. MODEL SAVING
# ================================

import shutil

# Clean up any existing model directory
model_save_dir = "./sinhala-whisper-asr-final"
if os.path.exists(model_save_dir):
    try:
        shutil.rmtree(model_save_dir)
        print(f"🗑️ Cleaned up existing directory: {model_save_dir}")
    except:
        pass

# Create fresh directory
os.makedirs(model_save_dir, exist_ok=True)

try:
    print("💾 Saving trained Sinhala ASR model...")
    
    # Save model
    model.save_pretrained(model_save_dir, safe_serialization=False)
    print("✅ Model saved successfully!")
    
    # Save processor
    processor.save_pretrained(model_save_dir)
    print("✅ Processor saved successfully!")
    
    # Verify saved files
    saved_files = os.listdir(model_save_dir)
    print(f"📁 Saved files: {saved_files}")
    
    print(f"\n🎉 Sinhala ASR model training completed!")
    print(f"📁 Model saved to: {model_save_dir}")
    
except Exception as e:
    print(f"❌ Error saving model: {e}")
    print("💡 Trying alternative save method...")
    
    # Alternative save method
    torch.save(model.state_dict(), os.path.join(model_save_dir, "pytorch_model.bin"))
    processor.save_pretrained(model_save_dir)
    print("✅ Alternative save completed!")

💾 Saving trained Sinhala ASR model...
✅ Model saved successfully!
✅ Processor saved successfully!
📁 Saved files: ['config.json', 'pytorch_model.bin', 'preprocessor_config.json', 'tokenizer_config.json', 'generation_config.json', 'merges.txt', 'normalizer.json', 'added_tokens.json', 'special_tokens_map.json', 'vocab.json']

🎉 Sinhala ASR model training completed!
📁 Model saved to: ./sinhala-whisper-asr-final


In [15]:
model_save_dir = "./sinhala-whisper-asr-final"
model.save_pretrained(model_save_dir)
processor.save_pretrained(model_save_dir)


[]

In [16]:
import shutil

shutil.make_archive("sinhala-whisper-asr-final", 'zip', model_save_dir)

'/kaggle/working/sinhala-whisper-asr-final.zip'

In [17]:
from IPython.display import FileLink

FileLink("sinhala-whisper-asr-final.zip")

In [18]:
# ================================
# 13. MODEL TESTING
# ================================

print("🧪 Testing the trained Sinhala ASR model...")

def test_model_inference(audio_file_path=None):
    """Test model with audio file"""
    
    # Load the saved model
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = WhisperForConditionalGeneration.from_pretrained(model_save_dir)
        processor = WhisperProcessor.from_pretrained(model_save_dir)
        model.to(device)
        print(f"✅ Model loaded successfully on {device}")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None
    
    # If no specific audio file provided, use a sample from validation set
    if audio_file_path is None and len(val_df) > 0:
        sample_row = val_df.iloc[0]
        audio_file_path = sample_row['audio']
        expected_text = sample_row['sentence']
        print(f"🎵 Testing with sample: {os.path.basename(audio_file_path)}")
        print(f"📝 Expected: {expected_text}")
    
    if audio_file_path and os.path.exists(audio_file_path):
        try:
            # Load and process audio
            audio_array, sr = librosa.load(audio_file_path, sr=16000)
            print(f"📊 Audio duration: {len(audio_array)/16000:.2f} seconds")
            
            # Process with model
            inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
            input_features = inputs.input_features.to(device)
            
            # Generate prediction
            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    max_length=448,
                    num_beams=1,
                    do_sample=False
                )
            
            # Decode prediction
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            
            print(f"\n🎯 RESULT:")
            print(f"   🎵 Audio: {os.path.basename(audio_file_path)}")
            print(f"   📝 Prediction: '{transcription}'")
            print(f"   🤖 Model: Fine-tuned Whisper for Sinhala")
            
            return transcription
            
        except Exception as e:
            print(f"❌ Error during inference: {e}")
            return None
    else:
        print(f"❌ Audio file not found: {audio_file_path}")
        return None

# Test the model
result = test_model_inference()
if result:
    print(f"\n✅ Model testing completed successfully!")
else:
    print(f"\n⚠️ Model testing encountered issues")

🧪 Testing the trained Sinhala ASR model...


`generation_config` default values have been modified to match model-specific defaults: {'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257], 'forced_decoder_ids': [[1, None], [2, 50359]]}. If this is not desired, please set these values explicitly.


✅ Model loaded successfully on cuda
🎵 Testing with sample: af9a714290.flac
📝 Expected: එම භූමිය තුළට පිටස්තරයින්ට පැමිණීම තහනම් උණා
📊 Audio duration: 4.70 seconds
❌ Error during inference: You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument in favour of `input_ids` or `decoder_input_ids` respectively.

⚠️ Model testing encountered issues


In [None]:
# ================================
# TEST SPECIFIC AUDIO FILE: test_audio.wav
# ================================

print("🎯 Testing specific audio file: test_audio.wav")
print("=" * 50)

def test_audio_wav_file():
    """Test the specific test_audio.wav file"""
    
    # Define the audio file path
    audio_file_path = "test_audio.wav"
    
    # Check if file exists
    if not os.path.exists(audio_file_path):
        print(f"❌ File not found: {audio_file_path}")
        print("💡 Make sure test_audio.wav is in the same directory as this notebook")
        return None
    
    # Load the saved model
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = WhisperForConditionalGeneration.from_pretrained(model_save_dir)
        processor = WhisperProcessor.from_pretrained(model_save_dir)
        
        # Fix configuration issues
        model.config.forced_decoder_ids = None
        model.config.suppress_tokens = []
        
        model.to(device)
        print(f"✅ Model loaded successfully on {device}")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None
    
    try:
        print(f"\n🎵 Processing audio file: {audio_file_path}")
        
        # Load and analyze audio
        audio_array, sr = librosa.load(audio_file_path, sr=16000)
        audio_duration = len(audio_array) / 16000
        
        print(f"📊 Audio Information:")
        print(f"   📁 File: {audio_file_path}")
        print(f"   ⏱️ Duration: {audio_duration:.2f} seconds")
        print(f"   📈 Sample Rate: {sr} Hz")
        print(f"   🔢 Samples: {len(audio_array):,}")
        
        # Check audio quality
        max_amplitude = np.max(np.abs(audio_array))
        print(f"   🔊 Max Amplitude: {max_amplitude:.4f}")
        
        if audio_duration < 0.5:
            print(f"   ⚠️ Warning: Very short audio ({audio_duration:.2f}s)")
        elif audio_duration > 30:
            print(f"   ⚠️ Warning: Long audio ({audio_duration:.2f}s), may be truncated")
        
        # Process with model
        print(f"\n🔄 Processing with Sinhala ASR model...")
        inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        
        print(f"   📊 Input features shape: {input_features.shape}")
        
        # Generate transcription
        print(f"🤖 Generating transcription...")
        with torch.no_grad():
            # Clear any conflicting configuration
            model.config.forced_decoder_ids = None
            model.config.suppress_tokens = []
            
            predicted_ids = model.generate(
                input_features,
                max_length=448,
                num_beams=1,
                do_sample=False,
                language="si",
                task="transcribe"
            )
        
        # Decode prediction
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        
        # Display results
        print(f"\n" + "=" * 60)
        print(f"🎯 TRANSCRIPTION RESULT")
        print(f"=" * 60)
        print(f"🎵 Audio File: {audio_file_path}")
        print(f"⏱️ Duration: {audio_duration:.2f} seconds")
        print(f"🤖 Model: Fine-tuned Whisper for Sinhala ASR")
        print(f"📝 Transcription: '{transcription}'")
        print(f"=" * 60)
        
        # Additional analysis
        if transcription.strip() == "":
            print(f"⚠️ Warning: Empty transcription - audio might be silent or unclear")
        else:
            word_count = len(transcription.split())
            char_count = len(transcription)
            print(f"📊 Text Analysis:")
            print(f"   📝 Characters: {char_count}")
            print(f"   🔤 Words: {word_count}")
            print(f"   ⏱️ Speaking rate: ~{word_count/(audio_duration/60):.1f} words/minute")
        
        return {
            'file': audio_file_path,
            'duration': audio_duration,
            'transcription': transcription,
            'word_count': len(transcription.split()),
            'char_count': len(transcription)
        }
        
    except Exception as e:
        print(f"❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()
        return None

# Run the test
print("🚀 Starting test_audio.wav analysis...")
result = test_audio_wav_file()

if result:
    print(f"\n✅ test_audio.wav processed successfully!")
    print(f"💡 You can now use this transcription result for further analysis")
else:
    print(f"\n❌ Failed to process test_audio.wav")
    print(f"💡 Check if the file exists and is a valid audio file")

print(f"\n🔚 Test completed!")

In [19]:
# ================================
# 14. TRAINING SUMMARY
# ================================

print("\n" + "="*60)
print("🎉 SINHALA ASR TRAINING COMPLETED!")
print("="*60)
print(f"📊 Training Data: {len(train_df):,} samples")
print(f"🧪 Validation Data: {len(val_df):,} samples")
print(f"🤖 Model: Whisper-base fine-tuned for Sinhala")
print(f"💾 Saved to: {model_save_dir}")
print(f"💻 Device: {device}")
print(f"⏱️ Epochs: {training_args.num_train_epochs}")
print("="*60)
print("🚀 Your Sinhala ASR model is ready for use!")
print("💡 You can now use this model for Sinhala speech recognition tasks.")
print("="*60)


🎉 SINHALA ASR TRAINING COMPLETED!
📊 Training Data: 8,000 samples
🧪 Validation Data: 2,000 samples
🤖 Model: Whisper-base fine-tuned for Sinhala
💾 Saved to: ./sinhala-whisper-asr-final
💻 Device: cuda
⏱️ Epochs: 12
🚀 Your Sinhala ASR model is ready for use!
💡 You can now use this model for Sinhala speech recognition tasks.
