In [2]:
# ================================
# SINHALA ASR TRAINING WITH WHISPER - KAGGLE VERSION
# ================================

# Install required packages
%pip install transformers datasets evaluate jiwer
%pip install librosa scikit-learn pandas
%pip install soundfile
%pip install tensorboard
%pip install accelerate

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━

In [3]:
# ================================
# IMPORTS
# ================================

import pandas as pd
import numpy as np
import torch
import librosa
import os
import glob
from datasets import Dataset, Audio
from transformers import (
    WhisperFeatureExtractor, 
    WhisperTokenizer, 
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback
)
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
from jiwer import wer, cer, mer

2025-07-24 03:55:42.756490: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753329343.099337      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753329343.236106      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
# ================================
# 1. DATA PREPARATION WITH KAGGLE PATHS
# ================================

# Define Kaggle paths
AUDIO_small_PATH = "/kaggle/input/large-sinhala-asr-training-dataset"
CSV_small_PATH = "/kaggle/input/dataset"

print("🚀 Loading Sinhala ASR data for Kaggle training...")
print(f"📁 Audio small path: {AUDIO_small_PATH}")
print(f"📁 CSV small path: {CSV_small_PATH}")

# Verify paths exist
if not os.path.exists(AUDIO_small_PATH):
    print(f"❌ Audio path not found: {AUDIO_small_PATH}")
    raise FileNotFoundError("Audio dataset not found")
    
if not os.path.exists(CSV_small_PATH):
    print(f"❌ CSV path not found: {CSV_small_PATH}")
    raise FileNotFoundError("CSV files not found")

print("✅ All paths verified successfully!")

# Load CSV files
train_csv = os.path.join(CSV_small_PATH, "10-train.csv")
test_csv = os.path.join(CSV_small_PATH, "10-test.csv")

train_df = pd.read_csv(train_csv)
val_df = pd.read_csv(test_csv)

print(f"\n📊 Dataset Information:")
print(f"   🏋️ Training samples: {len(train_df):,}")
print(f"   🧪 Validation samples: {len(val_df):,}")
print(f"   📈 Total samples: {len(train_df) + len(val_df):,}")

# Check data structure
print(f"\n🔍 Data Structure:")
print(f"   📋 Train columns: {list(train_df.columns)}")
print(f"   📋 Val columns: {list(val_df.columns)}")

# Show sample data
print(f"\n📝 Sample Training Data:")
print(train_df.head(3))

🚀 Loading Sinhala ASR data for Kaggle training...
📁 Audio small path: /kaggle/input/large-sinhala-asr-training-dataset
📁 CSV small path: /kaggle/input/dataset
✅ All paths verified successfully!

📊 Dataset Information:
   🏋️ Training samples: 8,000
   🧪 Validation samples: 2,000
   📈 Total samples: 10,000

🔍 Data Structure:
   📋 Train columns: ['file', 'sentence_cleaned']
   📋 Val columns: ['file', 'sentence_cleaned']

📝 Sample Training Data:
                                  file              sentence_cleaned
0  asr_sinhala/data/aa/aaaee62687.flac  අක්කයි මායි දෙන්නත් ළඟ නැතුව
1  asr_sinhala/data/07/07031079ca.flac      ශ්‍රී ලංකාව බැහැර කොට ඇත
2  asr_sinhala/data/31/3128fc4733.flac  ඔන්න ඔය විදිහට යෙදෙන නැකතෙන්


In [5]:
# ================================
# 2. PATH CONVERSION AND VALIDATION
# ================================

def convert_audio_path(relative_path, small_path=AUDIO_small_PATH):
    """Convert relative audio path to absolute Kaggle path"""
    if pd.isna(relative_path) or relative_path == "":
        return None
    
    # Handle already absolute paths
    if os.path.isabs(relative_path):
        return relative_path
    
    # Join with small path
    absolute_path = os.path.join(small_path, relative_path)
    return os.path.normpath(absolute_path)

def verify_audio_file(audio_path):
    """Verify if audio file exists and is readable"""
    try:
        if not os.path.exists(audio_path):
            return False
        audio, sr = librosa.load(audio_path, sr=16000)
        return len(audio) > 0
    except Exception as e:
        return False

# Ensure consistent column naming
audio_col = train_df.columns[0]
text_col = train_df.columns[1]

print(f"🔄 Using columns: '{audio_col}' as audio, '{text_col}' as sentence")

# Rename columns for consistency
train_df = train_df[[audio_col, text_col]].copy()
val_df = val_df[[audio_col, text_col]].copy()
train_df.columns = ["audio", "sentence"]
val_df.columns = ["audio", "sentence"]

# Convert relative paths to absolute paths
print(f"\n🔗 Converting audio paths...")
train_df['audio'] = train_df['audio'].apply(convert_audio_path)
val_df['audio'] = val_df['audio'].apply(convert_audio_path)

# Remove rows with missing data
initial_train_size = len(train_df)
initial_val_size = len(val_df)

train_df = train_df.dropna()
val_df = val_df.dropna()
train_df = train_df[train_df["sentence"].str.strip() != ""]
val_df = val_df[val_df["sentence"].str.strip() != ""]

print(f"\n🔍 Validating audio files (this may take a moment)...")

# Validate audio files (sample check for speed)
sample_size = min(1000, len(train_df))
train_sample = train_df.head(sample_size)
valid_count = sum(verify_audio_file(path) for path in train_sample['audio'])

print(f"📊 Audio validation results:")
print(f"   ✅ Valid files in sample: {valid_count}/{sample_size}")
print(f"   📈 Estimated validity rate: {valid_count/sample_size*100:.1f}%")

if valid_count < sample_size * 0.1:  # Less than 10% valid
    print("⚠️ Warning: Low audio file validity rate detected")
    print("💡 Check if audio paths are correctly mapped")

print(f"\n🧹 Data Cleaning Results:")
print(f"   🏋️ Training: {initial_train_size} → {len(train_df)} samples")
print(f"   🧪 Validation: {initial_val_size} → {len(val_df)} samples")

# Display final sample data
print(f"\n📝 Final Sample Data:")
print(train_df.head(3))

🔄 Using columns: 'file' as audio, 'sentence_cleaned' as sentence

🔗 Converting audio paths...

🔍 Validating audio files (this may take a moment)...
📊 Audio validation results:
   ✅ Valid files in sample: 1000/1000
   📈 Estimated validity rate: 100.0%

🧹 Data Cleaning Results:
   🏋️ Training: 8000 → 8000 samples
   🧪 Validation: 2000 → 2000 samples

📝 Final Sample Data:
                                               audio  \
0  /kaggle/input/large-sinhala-asr-training-datas...   
1  /kaggle/input/large-sinhala-asr-training-datas...   
2  /kaggle/input/large-sinhala-asr-training-datas...   

                       sentence  
0  අක්කයි මායි දෙන්නත් ළඟ නැතුව  
1      ශ්‍රී ලංකාව බැහැර කොට ඇත  
2  ඔන්න ඔය විදිහට යෙදෙන නැකතෙන්  


In [6]:
# ================================
# 3. WHISPER PROCESSOR SETUP
# ================================

print("🤖 Setting up Whisper processor for Sinhala...")

# Initialize Whisper components for Sinhala
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="si", task="transcribe")
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="si", task="transcribe")

print("✅ Whisper processor setup completed!")
print(f"   🌐 Language: Sinhala (si)")
print(f"   🎯 Task: Transcribe")
print(f"   📏 Model: whisper-small")

🤖 Setting up Whisper processor for Sinhala...


preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

✅ Whisper processor setup completed!
   🌐 Language: Sinhala (si)
   🎯 Task: Transcribe
   📏 Model: whisper-small


In [7]:
# ================================
# 4. AUDIO PREPROCESSING FUNCTIONS
# ================================

def normalize_audio(audio_array):
    """Normalize audio to prevent clipping"""
    max_val = np.max(np.abs(audio_array))
    if max_val > 0:
        return audio_array / max_val
    return audio_array

def add_noise_augmentation(audio_array, noise_factor=0.005):
    """Add Gaussian noise for robustness"""
    noise = np.random.normal(0, noise_factor, audio_array.shape)
    return audio_array + noise

# Filter valid audio files for training
print("🔍 Filtering valid audio files for training...")

valid_train = []
valid_val = []

# Process training data
print("Processing training data...")
for idx, row in train_df.iterrows():
    if verify_audio_file(row['audio']):
        valid_train.append(row)
    else:
        if idx < 10:  # Only print first 10 invalid files
            print(f"   ⚠️ Invalid audio: {os.path.smallname(row['audio'])}")

# Process validation data
print("Processing validation data...")
for idx, row in val_df.iterrows():
    if verify_audio_file(row['audio']):
        valid_val.append(row)
    else:
        if idx < 10:  # Only print first 10 invalid files
            print(f"   ⚠️ Invalid audio: {os.path.smallname(row['audio'])}")

# Update dataframes
train_df = pd.DataFrame(valid_train)
val_df = pd.DataFrame(valid_val)

print(f"\n📊 Final Valid Dataset:")
print(f"   🏋️ Training samples: {len(train_df):,}")
print(f"   🧪 Validation samples: {len(val_df):,}")
print(f"   📈 Total valid samples: {len(train_df) + len(val_df):,}")

if len(train_df) == 0:
    raise ValueError("No valid training samples found! Check audio paths.")

🔍 Filtering valid audio files for training...
Processing training data...
Processing validation data...

📊 Final Valid Dataset:
   🏋️ Training samples: 8,000
   🧪 Validation samples: 2,000
   📈 Total valid samples: 10,000


In [8]:
# ================================
# 5. DATASET CREATION
# ================================

print("📦 Creating HuggingFace datasets...")

# Create datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)

# Cast audio columns with target sampling rate
train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
val_dataset = val_dataset.cast_column("audio", Audio(sampling_rate=16000))

def prepare_dataset(examples):
    """Prepare dataset with robust preprocessing"""
    # Load and process audio
    audio = examples["audio"]
    audio_array = audio["array"]
    
    # Apply normalization
    audio_array = normalize_audio(audio_array)
    
    # Optional: Add noise augmentation (uncomment if needed)
    # audio_array = add_noise_augmentation(audio_array, noise_factor=0.001)
    
    # Ensure audio length constraints
    min_length = 1000  # ~0.06 seconds at 16kHz
    max_length = 480000  # ~30 seconds at 16kHz
    
    if len(audio_array) < min_length:
        # Pad short audio
        audio_array = np.pad(audio_array, (0, min_length - len(audio_array)), 'constant')
    elif len(audio_array) > max_length:
        # Truncate long audio
        audio_array = audio_array[:max_length]
    
    # Compute log-Mel input features
    examples["input_features"] = feature_extractor(
        audio_array, sampling_rate=16000
    ).input_features[0]
    
    # Clean up audio data
    del examples["audio"]
    
    # Process text
    sentences = examples["sentence"]
    
    # Clean and normalize text
    if isinstance(sentences, str):
        sentences = sentences.strip()
    
    # Encode target text to label ids
    examples["labels"] = tokenizer(sentences).input_ids
    del examples["sentence"]
    
    return examples

# Apply preprocessing
print("🔄 Preprocessing training dataset...")
train_dataset = train_dataset.map(prepare_dataset, num_proc=1)

print("🔄 Preprocessing validation dataset...")
val_dataset = val_dataset.map(prepare_dataset, num_proc=1)

print("✅ Dataset preprocessing completed!")

📦 Creating HuggingFace datasets...
🔄 Preprocessing training dataset...


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

🔄 Preprocessing validation dataset...


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

✅ Dataset preprocessing completed!


In [9]:
# ================================
# 6. DATA COLLATOR
# ================================

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Handle input features
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        
        # Handle labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        
        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        
        # Remove BOS token if present
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
print("✅ Data collator configured!")

✅ Data collator configured!


In [10]:
# ================================
# 7. EVALUATION METRICS
# ================================

def compute_metrics(pred):
    """Compute comprehensive ASR evaluation metrics"""
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # Decode token IDs to strings
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Calculate metrics
    wer_score = wer(label_str, pred_str) * 100
    cer_score = cer(label_str, pred_str) * 100
    
    # Sentence Error Rate
    ser_score = (
        sum(ref.strip() != pred.strip() for ref, pred in zip(label_str, pred_str))
        / len(label_str)
    ) * 100

    return {
        "wer": wer_score,
        "cer": cer_score,
        "ser": ser_score,
    }

print("📊 Evaluation metrics configured!")

📊 Evaluation metrics configured!


In [11]:
# ================================
# 8. MODEL SETUP
# ================================

print("🤖 Loading pre-trained Whisper model...")

try:
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise

# Configure model for Sinhala
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"💻 Using device: {device}")

print(f"📋 Model Information:")
print(f"   🏗️ Architecture: {model.config.model_type}")
print(f"   📏 Model size: whisper-small")
print(f"   🌐 Target language: Sinhala")

🤖 Loading pre-trained Whisper model...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

✅ Model loaded successfully!
💻 Using device: cuda
📋 Model Information:
   🏗️ Architecture: whisper
   📏 Model size: whisper-small
   🌐 Target language: Sinhala


In [12]:
# ================================
# 9. TRAINING ARGUMENTS
# ================================

# Output directory
output_dir = "./whisper-sinhala-asr-model"

# training_args = Seq2SeqTrainingArguments(
#     output_dir=output_dir,
    
#     # Training schedule
#     num_train_epochs=12,
#     per_device_train_batch_size=16,
#     per_device_eval_batch_size=8,
#     gradient_accumulation_steps=2,
    
#     # Optimization
#     learning_rate=1.7e-05,
#     warmup_steps=500,
#     lr_scheduler_type="linear",
#     weight_decay=0.01,
    
#     # Evaluation and saving
#     eval_strategy="epoch",
#     save_strategy="epoch",
#     save_total_limit=3,
#     load_best_model_at_end=True,
#     metric_for_best_model="wer",
#     greater_is_better=False,
    
#     # Memory optimization
#     gradient_checkpointing=True,
#     fp16=torch.cuda.is_available(),  # Use FP16 if CUDA available
#     dataloader_pin_memory=False,
    
#     # Generation settings
#     predict_with_generate=True,
#     generation_max_length=225,
    
#     # Logging
#     logging_steps=25,
#     logging_strategy="steps",
#     report_to=["tensorboard"],
    
#     # Additional settings
#     remove_unused_columns=False,
#     label_names=["labels"],
# )

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    
    # Evaluation and saving
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    
    # Batch sizes - adjusted for noisy data
    per_device_train_batch_size=8,  # Reduced for stability
    per_device_eval_batch_size=8,
    # gradient_accumulation_steps=1,   # Compensate for smaller batch size
    
    # Learning rate - slightly lower for noisy data
    learning_rate=1.7e-05,
    warmup_steps=100,
    lr_scheduler_type="linear",
    
    # Memory optimization
    gradient_checkpointing=True,
    fp16=True,
    dataloader_pin_memory=False,
    
    # Training duration
    num_train_epochs=4,  # Reduced epochs for noisy data
    
    # Generation settings
    predict_with_generate=True,
    generation_max_length=225,
    
    # Logging
    logging_steps=50,
    report_to=["tensorboard"],
    
    # Additional stability settings
    max_grad_norm=1.0,
    weight_decay=0.01,
)

print("⚙️ Training arguments configured!")
print(f"   📁 Output directory: {output_dir}")
print(f"   🔄 Epochs: {training_args.num_train_epochs}")
print(f"   📦 Batch size: {training_args.per_device_train_batch_size}")
print(f"   📈 Learning rate: {training_args.learning_rate}")

⚙️ Training arguments configured!
   📁 Output directory: ./whisper-sinhala-asr-model
   🔄 Epochs: 4
   📦 Batch size: 8
   📈 Learning rate: 1.7e-05


In [13]:
# ================================
# 10. TRAINER SETUP
# ================================

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

print("🏃‍♂️ Trainer configured successfully!")
print(f"   🏋️ Training samples: {len(train_dataset):,}")
print(f"   🧪 Validation samples: {len(val_dataset):,}")

  trainer = Seq2SeqTrainer(


🏃‍♂️ Trainer configured successfully!
   🏋️ Training samples: 8,000
   🧪 Validation samples: 2,000


In [None]:
# ================================
# 11. TRAINING EXECUTION
# ================================

print("🚀 Starting Sinhala ASR training...")
print("=" * 50)
print(f"📊 Dataset: {len(train_dataset):,} training, {len(val_dataset):,} validation")
print(f"🤖 Model: Whisper-small fine-tuned for Sinhala")
print(f"💻 Device: {device}")
print(f"⏱️ Estimated time: ~{len(train_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs // 100} minutes")
print("=" * 50)

# Start training
try:
    trainer.train()
    print("✅ Training completed successfully!")
except Exception as e:
    print(f"❌ Training failed: {e}")
    raise

🚀 Starting Sinhala ASR training...
📊 Dataset: 8,000 training, 2,000 validation
🤖 Model: Whisper-small fine-tuned for Sinhala
💻 Device: cuda
⏱️ Estimated time: ~40 minutes


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Epoch,Training Loss,Validation Loss,Wer,Cer,Ser
1,0.2125,0.211231,57.217701,15.330729,91.8
2,0.1346,0.17916,51.187159,13.501484,87.6
3,0.0675,0.184831,49.526251,12.823498,86.9


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [None]:
# ================================
# 12. MODEL SAVING
# ================================

import shutil

# Clean up any existing model directory
model_save_dir = "./sinhala-whisper-asr-final"
if os.path.exists(model_save_dir):
    try:
        shutil.rmtree(model_save_dir)
        print(f"🗑️ Cleaned up existing directory: {model_save_dir}")
    except:
        pass

# Create fresh directory
os.makedirs(model_save_dir, exist_ok=True)

try:
    print("💾 Saving trained Sinhala ASR model...")
    
    # Save model
    model.save_pretrained(model_save_dir, safe_serialization=False)
    print("✅ Model saved successfully!")
    
    # Save processor
    processor.save_pretrained(model_save_dir)
    print("✅ Processor saved successfully!")
    
    # Verify saved files
    saved_files = os.listdir(model_save_dir)
    print(f"📁 Saved files: {saved_files}")
    
    print(f"\n🎉 Sinhala ASR model training completed!")
    print(f"📁 Model saved to: {model_save_dir}")
    
except Exception as e:
    print(f"❌ Error saving model: {e}")
    print("💡 Trying alternative save method...")
    
    # Alternative save method
    torch.save(model.state_dict(), os.path.join(model_save_dir, "pytorch_model.bin"))
    processor.save_pretrained(model_save_dir)
    print("✅ Alternative save completed!")

In [None]:
model_save_dir = "./sinhala-whisper-asr-final"
model.save_pretrained(model_save_dir)
processor.save_pretrained(model_save_dir)


In [None]:
import shutil

shutil.make_archive("sinhala-whisper-asr-final", 'zip', model_save_dir)

In [None]:
from IPython.display import FileLink

FileLink("sinhala-whisper-asr-final.zip")

In [None]:
# ================================
# 12. CREATE DOWNLOADABLE ARCHIVES FOR KAGGLE
# ================================

import os
import shutil
import zipfile
from IPython.display import FileLink

def create_downloadable_archive(source_dir, archive_name):
    """Create a downloadable zip archive"""
    if os.path.exists(source_dir):
        # Create zip file
        shutil.make_archive(archive_name, 'zip', source_dir)
        zip_path = f"{archive_name}.zip"
        
        if os.path.exists(zip_path):
            file_size = os.path.getsize(zip_path) / (1024 * 1024)  # Size in MB
            print(f"✅ Created {zip_path} ({file_size:.2f} MB)")
            return zip_path
        else:
            print(f"❌ Failed to create {zip_path}")
            return None
    else:
        print(f"❌ Source directory {source_dir} does not exist")
        return None

# Create downloadable archives
print("\n🔄 Creating downloadable model archives...")

# 1. Final trained model
final_model_zip = create_downloadable_archive(
    "./whisper-sinhala-asr-model-final", 
    "whisper-sinhala-asr-model-final"
)

# 2. Last checkpoint from training
checkpoint_dir = "./whisper-sinhala-asr-model"
if os.path.exists(checkpoint_dir):
    # Find the last checkpoint
    checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("checkpoint-")]
    if checkpoints:
        # Sort by checkpoint number
        checkpoints.sort(key=lambda x: int(x.split("-")[1]))
        last_checkpoint = checkpoints[-1]
        last_checkpoint_path = os.path.join(checkpoint_dir, last_checkpoint)
        
        print(f"📂 Found last checkpoint: {last_checkpoint}")
        
        # Create archive for last checkpoint
        last_checkpoint_zip = create_downloadable_archive(
            last_checkpoint_path, 
            f"whisper-sinhala-asr-model-{last_checkpoint}"
        )
    else:
        print("❌ No checkpoints found in training directory")
        last_checkpoint_zip = None
else:
    print("❌ Training directory does not exist")
    last_checkpoint_zip = None

# 3. Training logs and metrics
if os.path.exists("./whisper-sinhala-asr-model/runs"):
    tensorboard_logs_zip = create_downloadable_archive(
        "./whisper-sinhala-asr-model/runs", 
        "whisper-sinhala-asr-model-tensorboard-logs"
    )
else:
    tensorboard_logs_zip = None

# Create a comprehensive package with all files
print("\n📦 Creating comprehensive model package...")
comprehensive_package = "whisper-sinhala-asr-model-complete"
os.makedirs(comprehensive_package, exist_ok=True)

# Copy final model
if os.path.exists("./whisper-sinhala-asr-model-final"):
    shutil.copytree("./whisper-sinhala-asr-model-final", 
                    f"{comprehensive_package}/final_model", 
                    dirs_exist_ok=True)

# Copy last checkpoint
if os.path.exists(last_checkpoint_path):
    shutil.copytree(last_checkpoint_path, 
                    f"{comprehensive_package}/last_checkpoint", 
                    dirs_exist_ok=True)


# Create comprehensive zip
comprehensive_zip = create_downloadable_archive(
    comprehensive_package, 
    "whisper-sinhala-asr-model-complete"
)