## ⚡ QUICK START - Execute These Cells in Order

**IMPORTANT: You must run these cells sequentially before training:**

1. ✅ **Cell 6** - Load Dataset from S3 (creates `datasets` variable)
2. ✅ **Cell 9** - Model Configuration (creates `tiny_feature_extractor`, `tiny_tokenizer`)  
3. ✅ **Cell 8** - Preprocess Dataset (converts audio to training format)
4. ✅ **Cell 10** - Training Setup
5. ✅ **Cell 11** - Pre-Training Checklist

**Current Status:**
- ❌ `tiny_feature_extractor` missing → **RUN CELL 9 FIRST**
- ❌ `tiny_tokenizer` missing → **RUN CELL 9 FIRST**

**After Cell 9, Cell 8 (preprocessing) will work!**
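
The ordering above follows from which variables each cell creates and which it consumes. Here is a minimal sketch of that idea using the standard-library `graphlib` module (Python 3.9+). The dependency map is an assumption inferred from the list above; the notebook itself does not define it.

```python
from graphlib import TopologicalSorter

# Assumed dependencies: each cell maps to the cells that must run before it.
CELL_DEPENDENCIES = {
    "Cell 6": set(),                  # creates `datasets`
    "Cell 9": set(),                  # creates `tiny_feature_extractor`, `tiny_tokenizer`
    "Cell 8": {"Cell 6", "Cell 9"},   # preprocessing needs all three variables
    "Cell 10": {"Cell 9"},            # training setup uses the model config
    "Cell 11": {"Cell 8", "Cell 10"}, # checklist runs last
}


def execution_order(dependencies):
    """Return one valid run order in which every prerequisite comes first."""
    return list(TopologicalSorter(dependencies).static_order())


order = execution_order(CELL_DEPENDENCIES)
print(" -> ".join(order))
```

Any order this prints places Cell 9 before Cell 8, which is the constraint the status messages above point at.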

# 🚀 Whisper Fine-tuning on SageMaker ml.g4dn.2xlarge

## Khmer Speech Recognition with Mega Dataset

This notebook fine-tunes Whisper Tiny and Base models on Amazon SageMaker using the comprehensive Mega Khmer dataset (160+ hours, 110K+ samples) on ml.g4dn.2xlarge instances.

### 🎯 What We'll Accomplish:
- ✅ Fine-tune Whisper Tiny (~39M params)
- ✅ Fine-tune Whisper Base (~74M params)  
- ✅ GPU-optimized for ml.g4dn.2xlarge (16GB VRAM)
- ✅ Production-ready model deployment

### 📊 Dataset Overview:
- **Total Duration**: 160+ hours
- **Samples**: 110,000+
- **Sources**: Original + LSR42 + Rinabuoy datasets
- **Language**: Khmer (km)

## 🚀 SageMaker Execution Guide

### 📋 **IMPORTANT: Run Cells in This Order**

**Before Training, Execute These Cells Sequentially:**

1. **Cell 3**: 🔧 SageMaker Environment Setup 
2. **Cell 5**: 📦 Package Installation
3. **Cell 6**: 📊 Dataset Loading from S3  
4. **Cell 9**: 🤖 Model Configuration (Tiny & Base)
5. **Cell 8**: 🔄 Dataset Preprocessing for Training (needs the feature extractor and tokenizer from Cell 9)
6. **Cell 10**: ⚙️ Training Arguments Setup
7. **Cell 11**: ✅ Pre-Training Checklist (verify all variables)

**Then Choose Your Training:**
- **Cell 12**: 🎯 Train Whisper Tiny (~3-4 hours)
- **Cell 14**: 🎯 Train Whisper Base (~5-6 hours)

### ⚠️ **Common Issues:**
- **`NameError: name 'datasets' is not defined`** → Run cells 3, 5, 6 first
- **Package installation errors** → Use conda commands in cell 5
- **S3 access errors** → Verify your bucket name in cell 3
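
A `NameError` like the one above can be caught before it happens by checking the notebook namespace for the names a cell needs. The sketch below is illustrative only: `missing_names` and the `REQUIRED` hints are hypothetical helpers, and a plain dict stands in for `globals()` so the example runs on its own.

```python
def missing_names(namespace, required):
    """Return (name, hint) pairs for required names that are absent from namespace."""
    return [(name, hint) for name, hint in required.items() if name not in namespace]


# Hypothetical hints; adjust them to your own cell numbering.
REQUIRED = {
    "datasets": "run cells 3, 5, 6",
    "tiny_feature_extractor": "run cell 9",
    "tiny_tokenizer": "run cell 9",
}

# In the notebook you would pass globals(); a dict keeps this example self-contained.
fake_namespace = {"datasets": object()}
for name, hint in missing_names(fake_namespace, REQUIRED):
    print(f"Missing '{name}': {hint}")
```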

### 💡 **Tips:**
- This notebook is designed for **ml.g4dn.2xlarge** instances
- Training uses your uploaded S3 dataset: `s3://pan-sea-khmer-speech-dataset-sg/`
- Both models will be saved locally and can be uploaded to S3

## 1. Setup SageMaker Environment

## 2. Install Required Dependencies

## 3. Load and Prepare Mega Dataset
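
The loader cell in this section rewrites each manifest `audio_filepath` so that it carries exactly one `audio/` prefix before it is turned into an S3 URI. A standalone sketch of that rule, with a hypothetical helper name (`build_s3_audio_uri`) and made-up file names, makes the three cases easy to check:

```python
def build_s3_audio_uri(bucket, split, audio_filepath):
    """Build the S3 URI for a manifest entry, keeping exactly one 'audio/' prefix."""
    prefix = "audio/"
    if audio_filepath.startswith(prefix + prefix):
        # Duplicated prefix: drop the first copy.
        audio_filepath = audio_filepath[len(prefix):]
    elif not audio_filepath.startswith(prefix):
        # Missing prefix: add it.
        audio_filepath = prefix + audio_filepath
    return f"s3://{bucket}/khmer-whisper-dataset/data/{split}/{audio_filepath}"


BUCKET = "pan-sea-khmer-speech-dataset-sg"
for raw in ["audio/audio/clip_001.wav", "audio/clip_001.wav", "clip_001.wav"]:
    print(raw, "->", build_s3_audio_uri(BUCKET, "train", raw))
```

All three inputs map to the same URI, so a mixed manifest does not produce broken keys.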

In [4]:
# ✅ Load Dataset from S3 CSV Manifests (UPDATED)
import json
import pandas as pd
import librosa
import numpy as np
from pathlib import Path
from datasets import Dataset, DatasetDict, Audio
from transformers import WhisperFeatureExtractor, WhisperTokenizer
import torch
from torch.utils.data import DataLoader
import boto3
import io

class S3CSVDatasetLoader:
    """Load dataset from uploaded CSV manifest files in S3"""
    
    def __init__(self):
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
        self.tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="km", task="transcribe")
        self.s3_client = boto3.client('s3', region_name='ap-southeast-1')
        self.bucket_name = 'pan-sea-khmer-speech-dataset-sg'
        
    def check_csv_manifests(self):
        """Check if CSV manifest files exist in S3"""
        print("🔍 Checking for uploaded CSV manifest files...")
        
        csv_files = {
            'train': 'khmer-whisper-dataset/data/train/train_manifest.csv',
            'validation': 'khmer-whisper-dataset/data/validation/validation_manifest.csv',
            'test': 'khmer-whisper-dataset/data/test/test_manifest.csv'
        }
        
        found_files = {}
        
        for split, s3_key in csv_files.items():
            try:
                response = self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
                size_mb = response['ContentLength'] / 1024 / 1024
                print(f"✅ {split}_manifest.csv: {size_mb:.2f} MB")
                found_files[split] = s3_key
            except Exception as e:
                # head_object also fails on permission or region errors, so report the cause
                print(f"❌ {split}_manifest.csv: not accessible ({e})")
        
        return found_files
    
    def load_csv_from_s3(self, split, s3_key):
        """Download and load CSV manifest from S3"""
        try:
            print(f"📥 Loading {split} manifest from S3...")
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
            csv_content = response['Body'].read()
            
            # Read CSV into pandas
            df = pd.read_csv(io.BytesIO(csv_content))
            print(f"✅ Loaded {len(df)} entries")
            return df
            
        except Exception as e:
            print(f"❌ Failed to load {split} CSV: {e}")
            return None
    
    def load_datasets(self, max_duration=20.0, min_duration=1.0, max_samples=None):
        """Load datasets from S3 CSV manifests"""
        print("🚀 Loading Dataset from S3 CSV Manifests")
        print("=" * 50)
        
        # Check which CSV files exist
        csv_files = self.check_csv_manifests()
        
        if not csv_files:
            print("❌ No CSV manifest files found!")
            print("💡 Make sure you uploaded the CSV files correctly")
            return self.create_dummy_datasets()
        
        dataset_dict = {}
        
        for split in ['train', 'validation', 'test']:
            if split not in csv_files:
                print(f"⚠️ Skipping {split} - no CSV found")
                continue
                
            print(f"\n📊 Processing {split} split...")
            
            # Load CSV from S3
            df = self.load_csv_from_s3(split, csv_files[split])
            if df is None:
                continue
            
            # Show sample of CSV structure for debugging
            print(f"📋 CSV structure for {split}:")
            if len(df) > 0:
                sample_row = df.iloc[0]
                print(f"   Sample audio_filepath: {sample_row['audio_filepath']}")
                print(f"   Available columns: {list(df.columns)}")
            
            # Filter by duration (keep reasonable lengths)
            original_count = len(df)
            df_filtered = df[
                (df['duration'] >= min_duration) & 
                (df['duration'] <= max_duration)
            ]
            print(f"📊 Duration filter ({min_duration}-{max_duration}s): {original_count:,} → {len(df_filtered):,} samples")
            
            # Limit samples for memory efficiency
            if max_samples and len(df_filtered) > max_samples:
                df_filtered = df_filtered.head(max_samples)
                print(f"🔧 Limited to {max_samples} samples for testing")
            else:
                print(f"🎯 Using ALL {len(df_filtered):,} samples from {split}")
            
            # Convert to dataset format with corrected S3 paths
            data = []
            for _, row in df_filtered.iterrows():
                # Create S3 audio path - remove double audio/ paths
                audio_filename = row['audio_filepath']
                
                # Normalize the path so it starts with exactly one 'audio/' prefix
                if audio_filename.startswith('audio/audio/'):
                    audio_filename = audio_filename[len('audio/'):]  # Drop the duplicated prefix
                elif not audio_filename.startswith('audio/'):
                    audio_filename = f"audio/{audio_filename}"  # Add the missing prefix
                
                s3_audio_path = f"s3://{self.bucket_name}/khmer-whisper-dataset/data/{split}/{audio_filename}"
                
                entry = {
                    'audio_filepath': s3_audio_path,
                    'text': str(row['text']),
                    'duration': float(row['duration']),
                    'language': row.get('language', 'km'),
                    'source': row.get('source', 'mega_dataset'),
                    'speaker': row.get('speaker', 'unknown')
                }
                data.append(entry)
            
            # Create HuggingFace dataset
            if data:
                dataset = Dataset.from_list(data)
                try:
                    dataset = dataset.cast_column("audio_filepath", Audio(sampling_rate=16000))
                    print(f"✅ Created dataset with {len(data)} samples")
                except Exception as e:
                    print(f"⚠️ Audio casting warning: {e}")
                
                dataset_dict[split] = dataset
            else:
                print(f"❌ No valid data for {split}")
        
        return DatasetDict(dataset_dict) if dataset_dict else self.create_dummy_datasets()
    
    def create_dummy_datasets(self):
        """Create small dummy datasets for testing"""
        print("🎯 Creating dummy datasets for testing...")
        
        dataset_dict = {}
        sizes = {'train': 100, 'validation': 30, 'test': 30}
        
        for split, size in sizes.items():
            data = []
            for i in range(size):
                data.append({
                    'text': f'សាកល្បង ទី {i+1}',  # "Test number i" in Khmer
                    'audio_filepath': f'dummy_audio_{split}_{i}.wav',
                    'duration': 2.0 + (i % 3),
                    'language': 'km',
                    'source': 'dummy',
                    'speaker': 'unknown'
                })
            
            dataset = Dataset.from_list(data)
            dataset_dict[split] = dataset
            print(f"📝 {split}: {size} dummy samples")
        
        return DatasetDict(dataset_dict)

# Initialize loader and load datasets
print("🚀 Initializing S3 CSV Dataset Loader...")
dataset_loader = S3CSVDatasetLoader()

# Load FULL dataset (all samples)
datasets = dataset_loader.load_datasets(
    max_duration=20.0,
    min_duration=1.0, 
    max_samples=None  # Use ALL samples - no limit!
)

# Show results
if datasets:
    print(f"\n📊 Final Dataset Summary:")
    print("=" * 40)
    
    total_samples = 0
    for split, dataset in datasets.items():
        count = len(dataset)
        total_samples += count
        
        if 'duration' in dataset.column_names:
            hours = sum(dataset['duration']) / 3600
            print(f"  {split.capitalize()}: {count:,} samples ({hours:.2f} hours)")
        else:
            print(f"  {split.capitalize()}: {count:,} samples")
    
    print(f"  Total: {total_samples:,} samples")
    print(f"\n✅ Dataset variable 'datasets' created successfully!")
    print(f"🎯 You can now proceed with training!")
    
    # Show sample
    if total_samples > 0:
        sample_split = list(datasets.keys())[0]
        sample = datasets[sample_split][0]
        print(f"\n🔍 Sample from {sample_split}:")
        print(f"   Text: {sample['text'][:100]}...")
        print(f"   Duration: {sample['duration']} seconds")
        
        # The audio column may be a string or a decoded dict, so stringify before slicing
        audio_path = sample['audio_filepath']
        print(f"   Audio path: {str(audio_path)[:60]}...")
        

else:
    print(f"\n" + "="*50)
    print("❌ Dataset loading failed!")
    print("💡 Check CSV uploads and try again")

🚀 Initializing S3 CSV Dataset Loader...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


🚀 Loading Dataset from S3 CSV Manifests
🔍 Checking for uploaded CSV manifest files...
✅ train_manifest.csv: 20.76 MB
✅ validation_manifest.csv: 2.39 MB
✅ test_manifest.csv: 2.40 MB

📊 Processing train split...
📥 Loading train manifest from S3...
✅ Loaded 90606 entries
📋 CSV structure for train:
   Sample audio_filepath: audio/c01b83b9-c71b-45e5-b68f-9b9e86a29cab_1_0.wav
   Available columns: ['audio_filepath', 'text', 'duration', 'language', 'speaker', 'session_id', 'source']
📊 Duration filter (1.0-20.0s): 90,606 → 89,034 samples
🎯 Using ALL 89,034 samples from train
✅ Created dataset with 89034 samples

📊 Processing validation split...
📥 Loading validation manifest from S3...
✅ Loaded 9973 entries
📋 CSV structure for validation:
   Sample audio_filepath: audio/88bda0d0-c935-44c6-89ba-c71f40f5d02b_0_1.wav
   Available columns: ['audio_filepath', 'text', 'duration', 'language', 'speaker', 'session_id', 'source']
📊 Duration filter (1.0-20.0s): 9,973 → 9,917 samples
🎯 Using ALL 9,917 samp

## 4. Preprocess Dataset for Training
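
The preprocessing cell in this section splits each `s3://` path into a bucket and a key before downloading the audio. The sketch below shows the same split with explicit validation, so a malformed path fails with a clear message instead of an `IndexError`. The helper name `parse_s3_uri` and the example key are hypothetical.

```python
def parse_s3_uri(uri):
    """Split 's3://bucket/path/to/file.wav' into (bucket, key)."""
    scheme = "s3://"
    if not uri.startswith(scheme):
        raise ValueError(f"Not an S3 URI: {uri!r}")
    bucket, _, key = uri[len(scheme):].partition("/")
    if not bucket or not key:
        raise ValueError(f"S3 URI needs both a bucket and a key: {uri!r}")
    return bucket, key


bucket, key = parse_s3_uri(
    "s3://pan-sea-khmer-speech-dataset-sg/khmer-whisper-dataset/data/train/audio/clip_001.wav"
)
print(bucket)
print(key)
```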

In [None]:
# 🔄 CRITICAL: Preprocess Raw Dataset for Training
# This converts raw datasets (audio_filepath, text) -> training format (input_features, labels)

import torch
import librosa
import boto3
from io import BytesIO

def preprocess_dataset_for_training(raw_datasets, feature_extractor, tokenizer, max_samples=None):
    """
    Convert raw dataset with audio_filepath/text to training format with input_features/labels
    NO DUMMY DATA - FAIL FAST ON REAL ERRORS
    """
    print("🔄 PREPROCESSING: Converting raw datasets to training format...")
    print("=" * 60)
    
    # Initialize S3 client for audio loading
    s3_client = boto3.client('s3', region_name='ap-southeast-1')
    
    processed_datasets = {}
    
    for split_name, dataset in raw_datasets.items():
        print(f"\n📊 Processing {split_name} split...")
        
        # Debug: Show first example structure
        if len(dataset) > 0:
            sample = dataset[0]
            print(f"🔍 DEBUG - Sample structure:")
            for key, value in sample.items():
                print(f"   {key}: {type(value)} = {str(value)[:100]}...")
        
        # Limit samples for memory efficiency during testing
        if max_samples and len(dataset) > max_samples:
            dataset = dataset.select(range(max_samples))
            print(f"🔧 Limited to {max_samples} samples for preprocessing test")
        
        processed_examples = []
        errors = 0
        
        for i, example in enumerate(dataset):
            try:
                # Show progress
                if i % 100 == 0:
                    print(f"   Processing {i}/{len(dataset)}...")
                
                # Initialize audio_array variable
                audio_array = None
                sampling_rate = 16000
                
                # Get audio file path - handle both dict and string formats
                audio_path = example['audio_filepath']
                
                # If audio_filepath is a dict (from HuggingFace Audio type), extract the path
                if isinstance(audio_path, dict):
                    if 'path' in audio_path:
                        audio_path = audio_path['path']
                    elif 'array' in audio_path:
                        # Audio data is already loaded, use it directly
                        audio_array = audio_path['array']
                        sampling_rate = audio_path.get('sampling_rate', 16000)
                        
                        # Resample if needed
                        if sampling_rate != 16000:
                            audio_array = librosa.resample(audio_array, orig_sr=sampling_rate, target_sr=16000)
                            sampling_rate = 16000
                    else:
                        print(f"⚠️ Unexpected audio dict format: {list(audio_path.keys())}")
                        errors += 1  # Count skipped samples so the success rate stays accurate
                        continue
                
                # If we don't have audio_array yet, try to load from path
                if audio_array is None:
                    # Handle S3 paths
                    if isinstance(audio_path, str) and audio_path.startswith('s3://'):
                        # Parse S3 path: s3://bucket/path/to/file.wav
                        path_parts = audio_path[5:].split('/', 1)  # Remove 's3://'
                        bucket_name = path_parts[0]
                        s3_key = path_parts[1]
                        
                        # Download audio from S3
                        try:
                            response = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
                            audio_bytes = response['Body'].read()
                            
                            # Load audio with librosa from bytes
                            audio_array, sampling_rate = librosa.load(
                                BytesIO(audio_bytes), 
                                sr=16000
                            )
                            
                        except Exception as s3_error:
                            print(f"❌ S3 ERROR for {s3_key}: {s3_error}")
                            errors += 1
                            if errors > 10:  # Stop after too many errors
                                raise Exception(f"Too many S3 errors ({errors}). Check your S3 setup!")
                            continue
                            
                    elif isinstance(audio_path, str):
                        # Local file path
                        try:
                            audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
                        except Exception as file_error:
                            print(f"❌ FILE ERROR for {audio_path}: {file_error}")
                            errors += 1
                            continue
                    else:
                        print(f"❌ INVALID AUDIO PATH TYPE: {type(audio_path)} - {audio_path}")
                        errors += 1
                        continue
                # Convert audio to input_features (mel-spectrograms)
                input_features = feature_extractor(
                    audio_array, 
                    sampling_rate=16000, 
                    return_tensors="pt"
                ).input_features[0]  # Remove batch dimension
                
                # Convert text to labels (tokenized)
                text = str(example['text']).strip()
                if not text:
                    print(f"⚠️ Empty text for sample {i}, skipping...")
                    continue
                    
                labels = tokenizer(text).input_ids
                
                # Create processed example
                processed_example = {
                    'input_features': input_features,
                    'labels': labels,
                    'text': text,  # Keep original for reference
                    'duration': example.get('duration', 0.0),
                    'source': example.get('source', 'unknown')
                }
                
                processed_examples.append(processed_example)
                
            except Exception as process_error:
                print(f"❌ PROCESSING ERROR for sample {i}: {process_error}")
                errors += 1
                
                # FAIL FAST - don't create dummy data
                if errors > 20:
                    raise Exception(f"TOO MANY PROCESSING ERRORS ({errors})! Check your dataset and S3 setup!")
        
        # Create processed dataset
        if processed_examples:
            from datasets import Dataset
            processed_dataset = Dataset.from_list(processed_examples)
            processed_datasets[split_name] = processed_dataset
            
            print(f"✅ {split_name}: {len(processed_examples)} samples processed")
            print(f"   Errors: {errors}")
            print(f"   Success rate: {(len(processed_examples)/(len(processed_examples)+errors)*100):.1f}%")
            
            # Show sample
            sample = processed_examples[0]
            print(f"   Sample input_features shape: {sample['input_features'].shape}")
            print(f"   Sample labels length: {len(sample['labels'])}")
            print(f"   Sample text: {sample['text'][:50]}...")
        else:
            raise Exception(f"❌ NO VALID SAMPLES PROCESSED for {split_name}! All samples failed!")
    
    if not processed_datasets:
        raise Exception("❌ PREPROCESSING COMPLETELY FAILED! No datasets processed successfully!")
    
    from datasets import DatasetDict
    final_datasets = DatasetDict(processed_datasets)
    
    print(f"\n🎉 PREPROCESSING COMPLETE!")
    print(f"✅ Converted {sum(len(d) for d in processed_datasets.values())} total samples")
    print(f"✅ Datasets now have 'input_features' and 'labels' columns for training")
    
    return final_datasets

# EXECUTE PREPROCESSING
print("🚀 Starting dataset preprocessing...")

# Check if we have the required components
try:
    # Check for datasets variable
    if 'datasets' not in locals() and 'datasets' not in globals():
        print("❌ Missing 'datasets' variable!")
        print("💡 Run Cell 6: Dataset Loading first")
        raise Exception("Dataset loading required!")
    
    # Check for model components - they should be created in Cell 9
    if 'tiny_feature_extractor' not in locals() and 'tiny_feature_extractor' not in globals():
        print("❌ Missing 'tiny_feature_extractor' variable!")
        print("💡 Run Cell 9: Model Configuration first")
        raise Exception("Model configuration required!")
        
    if 'tiny_tokenizer' not in locals() and 'tiny_tokenizer' not in globals():
        print("❌ Missing 'tiny_tokenizer' variable!")
        print("💡 Run Cell 9: Model Configuration first")
        raise Exception("Model configuration required!")
    
    print("✅ All required variables found!")
    print(f"   - datasets: {len(datasets)} splits available")
    print(f"   - tiny_feature_extractor: {type(tiny_feature_extractor).__name__}")
    print(f"   - tiny_tokenizer: {type(tiny_tokenizer).__name__}")
    
    # Apply preprocessing - process FULL dataset
    print("🚀 Processing FULL dataset (all samples)...")
    processed_datasets = preprocess_dataset_for_training(
        datasets, 
        tiny_feature_extractor, 
        tiny_tokenizer,
        max_samples=None  # Process ALL samples!
    )
    
    # Preprocessing ran on the full dataset (max_samples=None)
    print("\n✅ PREPROCESSING SUCCESSFUL!")
    print("💡 Processed the full dataset")
    print("📊 Ready to proceed with training")
    
    # Replace the datasets variable with processed version
    datasets = processed_datasets
    
    print(f"\n🎯 Variable 'datasets' updated with processed data!")
    print(f"✅ Training can now proceed - datasets have 'input_features' and 'labels'")
    
except Exception as e:
    print(f"❌ PREPROCESSING FAILED: {e}")
    print("🚫 Cannot proceed with training until preprocessing works!")
    print("\n💡 Debug steps:")
    print("   1. Run Cell 6: Dataset Loading (creates 'datasets' variable)")
    print("   2. Run Cell 9: Model Configuration (creates tokenizer & feature_extractor)")
    print("   3. Then run this cell again for preprocessing")
    print("   4. Check if S3 credentials are working")
    print("   5. Verify audio files exist in S3")
    
    # Don't create dummy data - re-raise with the original traceback
    raise
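
The loop above keeps every processed example in a Python list before building the `Dataset`, so peak memory grows with the number of samples. A rough estimate follows, assuming each Whisper feature tensor is 80 mel bins by 3,000 frames of float32 (the padded 30-second input used by Whisper Tiny and Base) and using the 89,034 training samples reported in the loader output. The constants and the helper name are assumptions for illustration.

```python
MEL_BINS = 80            # assumed: mel bins for Whisper Tiny/Base
FRAMES = 3000            # assumed: frames in a padded 30-second window
BYTES_PER_FLOAT32 = 4


def feature_memory_bytes(num_samples):
    """Approximate bytes needed to hold all input_features in memory at once."""
    return num_samples * MEL_BINS * FRAMES * BYTES_PER_FLOAT32


train_samples = 89_034
total_bytes = feature_memory_bytes(train_samples)
total_gib = total_bytes / 1024**3
print(f"{total_gib:.1f} GiB for {train_samples:,} samples")
```

Under these assumptions the training split alone needs roughly 80 GiB, which would be well beyond the host memory of a single ml.g4dn.2xlarge instance (about 32 GiB). If that applies to your run, one option is to test with a small `max_samples` first; another is to preprocess with `Dataset.map`, which writes results to an on-disk Arrow cache instead of holding them all in a list.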

In [5]:
from transformers import (
    WhisperForConditionalGeneration, 
    WhisperTokenizer, 
    WhisperFeatureExtractor,
    WhisperProcessor
)
from dataclasses import dataclass
import torch  # torch.cuda is used below for the GPU memory check

@dataclass
class WhisperModelConfig:
    """Configuration for Whisper models optimized for ml.g4dn.2xlarge"""
    model_name: str
    max_length: int = 448
    language: str = "km"
    task: str = "transcribe"
    batch_size_tiny: int = 16
    batch_size_base: int = 8
    gradient_accumulation_steps: int = 2
    learning_rate: float = 1e-5
    warmup_steps: int = 500
    max_steps: int = 5000
    eval_steps: int = 500
    save_steps: int = 1000

class WhisperModelManager:
    """Manage Whisper model configurations and initialization"""
    
    def __init__(self):
        self.models = {
            'tiny': WhisperModelConfig("openai/whisper-tiny"),
            'base': WhisperModelConfig("openai/whisper-base")
        }
    
    def load_model_components(self, model_size):
        """Load tokenizer, feature extractor, and model"""
        config = self.models[model_size]
        
        print(f"🔧 Loading Whisper {model_size.upper()} components...")
        
        # Load tokenizer
        tokenizer = WhisperTokenizer.from_pretrained(
            config.model_name, 
            language=config.language, 
            task=config.task
        )
        
        # Load feature extractor
        feature_extractor = WhisperFeatureExtractor.from_pretrained(config.model_name)
        
        # Load processor (combines tokenizer and feature extractor)
        processor = WhisperProcessor.from_pretrained(config.model_name)
        processor.tokenizer = tokenizer
        
        # Load model
        model = WhisperForConditionalGeneration.from_pretrained(config.model_name)
        
        # Configure model for Khmer
        model.generation_config.language = config.language
        model.generation_config.task = config.task
        model.generation_config.forced_decoder_ids = None
        
        print(f"✅ Whisper {model_size.upper()} loaded successfully!")
        print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"   Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        
        return tokenizer, feature_extractor, processor, model, config

# Initialize model manager
model_manager = WhisperModelManager()

# Load Whisper Tiny components
print("🚀 Setting up Whisper models for training...")
tiny_tokenizer, tiny_feature_extractor, tiny_processor, tiny_model, tiny_config = model_manager.load_model_components('tiny')

print(f"\n📊 Model Configurations:")
print(f"  Whisper Tiny - Batch Size: {tiny_config.batch_size_tiny}")
print(f"  Whisper Base - Batch Size: {tiny_config.batch_size_base}")
print(f"  Max Length: {tiny_config.max_length}")
print(f"  Learning Rate: {tiny_config.learning_rate}")
print(f"  Gradient Accumulation: {tiny_config.gradient_accumulation_steps}")

# Check GPU memory
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"\n💾 GPU Memory: {gpu_memory:.1f} GB")
    print("✅ Configurations optimized for ml.g4dn.2xlarge (16GB VRAM)")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


🚀 Setting up Whisper models for training...
🔧 Loading Whisper TINY components...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


✅ Whisper TINY loaded successfully!
   Parameters: 37,760,640
   Trainable: 37,184,640

📊 Model Configurations:
  Whisper Tiny - Batch Size: 16
  Whisper Base - Batch Size: 8
  Max Length: 448
  Learning Rate: 1e-05
  Gradient Accumulation: 2

💾 GPU Memory: 14.6 GB
✅ Configurations optimized for ml.g4dn.2xlarge (16GB VRAM)
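
The batch size, gradient accumulation, and step count printed above together decide how much of the training split each model sees. The sketch below works through that arithmetic, assuming a single GPU and the 89,034 training samples from the loader output. `training_coverage` is a hypothetical helper, not part of the notebook.

```python
def training_coverage(batch_size, grad_accum_steps, max_steps, train_samples):
    """Return (effective batch size, samples seen, approximate epochs) on one GPU."""
    effective_batch = batch_size * grad_accum_steps
    samples_seen = effective_batch * max_steps
    return effective_batch, samples_seen, samples_seen / train_samples


TRAIN_SAMPLES = 89_034
for name, batch_size in [("tiny", 16), ("base", 8)]:
    effective, seen, epochs = training_coverage(batch_size, 2, 5000, TRAIN_SAMPLES)
    print(f"{name}: effective batch {effective}, {seen:,} samples, ~{epochs:.1f} epochs")
```

With these settings Tiny covers the training split a little under twice, while Base sees slightly less than one full pass. Early stopping can end either run sooner.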


In [6]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import EarlyStoppingCallback
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Data collator for speech-to-text training - Fixed for input_features"""
    
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Handle both 'input_features' and model_input_names
        model_input_name = self.processor.model_input_names[0] if hasattr(self.processor, 'model_input_names') else 'input_features'
        
        # Extract input features - handle both naming conventions
        input_features = []
        for feature in features:
            if 'input_features' in feature:
                input_features.append({'input_features': feature['input_features']})
            elif model_input_name in feature:
                input_features.append({model_input_name: feature[model_input_name]})
            else:
                raise KeyError(f"Expected 'input_features' or '{model_input_name}' in feature dict, got keys: {list(feature.keys())}")
        
        # Extract labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Pad input features
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        
        # Pad labels
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If the BOS token was prepended during tokenization,
        # strip it here because it is added again later
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

def create_training_arguments(model_size, config, output_dir):
    """Create optimized training arguments for ml.g4dn.2xlarge"""
    
    # Adjust batch size based on model size
    if model_size == 'tiny':
        per_device_train_batch_size = config.batch_size_tiny
        per_device_eval_batch_size = config.batch_size_tiny
    else:  # base
        per_device_train_batch_size = config.batch_size_base
        per_device_eval_batch_size = config.batch_size_base
    
    return Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps,
        max_steps=config.max_steps,
        gradient_checkpointing=True,  # Save memory
        fp16=True,  # Mixed precision for faster training
        evaluation_strategy="steps",
        eval_steps=config.eval_steps,
        save_steps=config.save_steps,
        logging_steps=100,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",  # Use eval_loss instead of wer
        greater_is_better=False,
        save_total_limit=3,
        predict_with_generate=True,
        generation_max_length=config.max_length,
        dataloader_num_workers=4,
        remove_unused_columns=False,
        label_names=["labels"]
    )

def compute_metrics(eval_preds, tokenizer):
    """Compute WER and CER, falling back to jiwer and then to exact-match accuracy"""
    pred_ids, label_ids = eval_preds
    
    # Replace -100 with pad token id
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    
    # Decode predictions and labels
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    
    try:
        # Try to use evaluate library first
        wer_metric = evaluate.load("wer")
        cer_metric = evaluate.load("cer")
        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        cer = cer_metric.compute(predictions=pred_str, references=label_str)
    except Exception as e:
        print(f"⚠️ Evaluate library failed ({e}), using jiwer fallback...")
        try:
            # Import here so a missing jiwer install is caught by this fallback
            import jiwer
            wer = jiwer.wer(label_str, pred_str)
            cer = jiwer.cer(label_str, pred_str)
        except Exception as e2:
            print(f"⚠️ jiwer also failed ({e2}), using basic accuracy...")
            # Sentence-level exact match: a coarse stand-in, not a true WER or CER
            correct = sum(1 for p, l in zip(pred_str, label_str) if p.strip() == l.strip())
            accuracy = correct / len(pred_str) if len(pred_str) > 0 else 0
            wer = 1.0 - accuracy
            cer = wer  # Use same value for CER
    
    return {"wer": wer, "cer": cer}

def setup_trainer(model, tokenizer, processor, config, datasets, model_size):
    """Setup Seq2SeqTrainer with all components"""
    
    # Create data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.generation_config.decoder_start_token_id,
    )
    
    # Create training arguments
    output_dir = f"./whisper-{model_size}-khmer"
    training_args = create_training_arguments(model_size, config, output_dir)
    
    # Create trainer
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=datasets["train"],
        eval_dataset=datasets["validation"],
        data_collator=data_collator,
        compute_metrics=lambda eval_preds: compute_metrics(eval_preds, tokenizer),
        tokenizer=processor.feature_extractor,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )
    
    return trainer, training_args

print("⚙️ Training configuration setup complete!")
print("🎯 Optimized for ml.g4dn.2xlarge with 16GB VRAM")
print("✅ Mixed precision (FP16) enabled")
print("✅ Gradient checkpointing enabled for memory efficiency")
print("✅ Early stopping configured")

⚙️ Training configuration setup complete!
🎯 Optimized for ml.g4dn.2xlarge with 16GB VRAM
✅ Mixed precision (FP16) enabled
✅ Gradient checkpointing enabled for memory efficiency
✅ Early stopping configured
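
The trainer above registers `EarlyStoppingCallback(early_stopping_patience=3)` and tracks `eval_loss` with `greater_is_better=False`. As a rough illustration of what that means in practice, the sketch below simulates the patience rule on a plain list of evaluation losses. The function name `stopping_evaluation` and the loss values are hypothetical, and this is a simplified model of the behavior (no improvement threshold), not the library's implementation.

```python
def stopping_evaluation(eval_losses, patience=3):
    """Return the 1-based index of the evaluation at which training would stop,
    or None if the patience budget is never exhausted.

    Simplified model: an evaluation counts as an improvement only if its loss
    is strictly lower than the best loss seen so far.
    """
    best = float("inf")
    evals_without_improvement = 0
    for index, loss in enumerate(eval_losses, start=1):
        if loss < best:
            best = loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return index
    return None


# Hypothetical loss curve: best value at evaluation 2, then three evaluations
# in a row without improvement, so training would stop at evaluation 5.
print(stopping_evaluation([1.0, 0.9, 0.95, 0.92, 0.91]))
```

Because evaluations run every `config.eval_steps` steps, a patience of 3 corresponds to roughly `3 * config.eval_steps` training steps without a better `eval_loss` before training halts.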


In [None]:
# ✅ Pre-Training Checklist - Run This Before Training
print("🔍 Checking Prerequisites for Whisper Training...")
print("=" * 50)

# Check if all required variables are defined
required_vars = {
    'datasets': 'Dataset loading (Cell 6)', 
    'tiny_model': 'Whisper Tiny model (Cell 9)', 
    'tiny_tokenizer': 'Whisper Tiny tokenizer (Cell 9)',
    'tiny_processor': 'Whisper Tiny processor (Cell 9)',
    'tiny_config': 'Whisper Tiny config (Cell 9)',
    'setup_trainer': 'Training setup function (Cell 10)'
}

missing_vars = []
available_vars = []

for var_name, description in required_vars.items():
    # At notebook top level, globals() holds every name defined by earlier cells
    if var_name in globals():
        available_vars.append(f"✅ {var_name} - {description}")
    else:
        missing_vars.append(f"❌ {var_name} - {description}")

# Show status
print("🎯 Available Variables:")
for var in available_vars:
    print(f"   {var}")

if missing_vars:
    print(f"\n⚠️ Missing Variables ({len(missing_vars)}):")
    for var in missing_vars:
        print(f"   {var}")
    
    print("\n💡 Before training, you need to run these cells in order:")
    print("   1. Cell 3: SageMaker Setup")
    print("   2. Cell 5: Package Installation")
    print("   3. Cell 6: Dataset Loading")
    print("   4. Cell 9: Model Configuration")
    print("   5. Cell 8: Dataset Preprocessing")
    print("   6. Cell 10: Training Setup")
    print("   7. Then you can run the training cells")
    
    print(f"\n🚫 Cannot start training - missing prerequisites!")
    
else:
    print(f"\n🎉 All prerequisites available!")
    print(f"✅ Ready to start Whisper training")
    
    # Show some stats if datasets is available
    try:
        if 'datasets' in locals() or 'datasets' in globals():
            print(f"\n📊 Dataset Info:")
            for split in datasets.keys():
                print(f"   {split}: {len(datasets[split]):,} samples")
    except:
        pass
        
print("\n" + "=" * 50)

🔍 Checking Prerequisites for Whisper Training...
🎯 Available Variables:
   ✅ datasets - Dataset loading (Cell 6)
   ✅ tiny_model - Whisper Tiny model (Cell 9)
   ✅ tiny_tokenizer - Whisper Tiny tokenizer (Cell 9)
   ✅ tiny_processor - Whisper Tiny processor (Cell 9)
   ✅ tiny_config - Whisper Tiny config (Cell 9)
   ✅ setup_trainer - Training setup function (Cell 10)

🎉 All prerequisites available!
✅ Ready to start Whisper training

📊 Dataset Info:
   train: 89,034 samples
   validation: 9,917 samples
   test: 9,916 samples
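
The checklist cell inspects the notebook's global namespace directly. A minimal sketch of the same check as a reusable function is shown below, assuming the namespace is passed in as a dictionary (for example `globals()`). The helper name `split_prerequisites` is hypothetical and not part of the notebook.

```python
def split_prerequisites(required, namespace):
    """Partition required variable names into (available, missing).

    `required` maps a variable name to a human-readable description;
    `namespace` is any mapping of defined names, such as globals().
    """
    available = [name for name in required if name in namespace]
    missing = [name for name in required if name not in namespace]
    return available, missing


# Example with a stand-in namespace instead of the real notebook globals
required = {
    "datasets": "Dataset loading (Cell 6)",
    "tiny_model": "Whisper Tiny model (Cell 9)",
}
available, missing = split_prerequisites(required, {"datasets": object()})
print("available:", available)
print("missing:", missing)
```

In the notebook itself the call would be `split_prerequisites(required_vars, globals())`, which keeps the check testable outside a live kernel.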



## 7. Fine-tune Whisper Base Model

## 8. Model Evaluation and Metrics
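
The `compute_metrics` function defined earlier reports WER and CER through the `evaluate` or `jiwer` libraries. To make clear what those numbers measure, here is a dependency-free sketch that computes both as corpus-level edit-distance rates (total edits divided by total reference length). The function names `edit_distance`, `word_error_rate`, and `char_error_rate` are hypothetical, and the sketch skips the text normalization a library may apply, so treat it as an illustration rather than a drop-in replacement.

```python
def edit_distance(reference, hypothesis):
    """Levenshtein distance between two sequences (insert, delete, substitute)."""
    previous = list(range(len(hypothesis) + 1))
    for i, ref_item in enumerate(reference, start=1):
        current = [i]
        for j, hyp_item in enumerate(hypothesis, start=1):
            cost = 0 if ref_item == hyp_item else 1
            current.append(min(
                previous[j] + 1,         # deletion
                current[j - 1] + 1,      # insertion
                previous[j - 1] + cost,  # substitution or match
            ))
        previous = current
    return previous[-1]


def word_error_rate(references, predictions):
    """Total word-level edits divided by total reference words."""
    errors = sum(edit_distance(r.split(), p.split())
                 for r, p in zip(references, predictions))
    total = sum(len(r.split()) for r in references)
    return errors / total if total else 0.0


def char_error_rate(references, predictions):
    """Total character-level edits divided by total reference characters."""
    errors = sum(edit_distance(list(r), list(p))
                 for r, p in zip(references, predictions))
    total = sum(len(r) for r in references)
    return errors / total if total else 0.0


# One substitution and one deletion against four reference words
print(word_error_rate(["a b c d"], ["a x c"]))  # 0.5
```

Word-level scoring splits on whitespace, so for text written with few spaces between words the character-level rate is often the more informative of the two.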

In [None]:
# 🎯 Baseline Evaluation (Run this to get WER before fine-tuning)
print("🚀 BASELINE EVALUATION - Original Whisper Models")
print("="*60)

# Evaluate original models on your dataset
baseline_results = []

try:
    # Load original Whisper Tiny
    print("📊 Evaluating original Whisper Tiny...")
    orig_tiny_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    orig_tiny_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="km", task="transcribe")
    
    # Configure for Khmer
    orig_tiny_model.generation_config.language = "km"
    orig_tiny_model.generation_config.task = "transcribe"
    orig_tiny_model.generation_config.forced_decoder_ids = None
    
    baseline_tiny_results, baseline_tiny_preds, baseline_tiny_refs = evaluator.evaluate_model(
        orig_tiny_model, orig_tiny_processor, datasets["test"], "Original Whisper Tiny"
    )
    baseline_results.append(baseline_tiny_results)
    
    print(f"✅ Original Tiny WER: {baseline_tiny_results['wer']*100:.2f}%")
    
except Exception as e:
    print(f"❌ Baseline Tiny evaluation failed: {e}")

try:
    # Load original Whisper Base
    print("📊 Evaluating original Whisper Base...")
    orig_base_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
    orig_base_processor = WhisperProcessor.from_pretrained("openai/whisper-base", language="km", task="transcribe")
    
    # Configure for Khmer
    orig_base_model.generation_config.language = "km"
    orig_base_model.generation_config.task = "transcribe" 
    orig_base_model.generation_config.forced_decoder_ids = None
    
    baseline_base_results, baseline_base_preds, baseline_base_refs = evaluator.evaluate_model(
        orig_base_model, orig_base_processor, datasets["test"], "Original Whisper Base"
    )
    baseline_results.append(baseline_base_results)
    
    print(f"✅ Original Base WER: {baseline_base_results['wer']*100:.2f}%")
    
except Exception as e:
    print(f"❌ Baseline Base evaluation failed: {e}")

# Show baseline results with graphs
if baseline_results:
    print("\n🎯 BASELINE RESULTS:")
    baseline_df = evaluator.compare_models(baseline_results)
    
    # Save baseline for comparison later
    evaluator.baseline_results = baseline_results
    
    print("\n💡 These are the baseline WER results before fine-tuning!")
    print("📈 After training, you can compare fine-tuned vs original models")
else:
    print("❌ No baseline evaluation completed")

# Clear memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("🧹 GPU memory cleared")
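
Once fine-tuned results are available, the baseline WER stored above can be compared against them. The sketch below computes the absolute and relative WER reduction. The function name `wer_reduction` is hypothetical, and the numbers in the usage line are placeholders chosen for illustration, not results from this notebook.

```python
def wer_reduction(baseline_wer, finetuned_wer):
    """Return (absolute, relative) WER reduction.

    Both inputs are fractions (e.g. 0.25 for 25% WER). The relative reduction
    is expressed as a fraction of the baseline; it is 0.0 if the baseline is 0.
    """
    absolute = baseline_wer - finetuned_wer
    relative = absolute / baseline_wer if baseline_wer else 0.0
    return absolute, relative


# Placeholder values only: a drop from 80% to 40% WER
absolute, relative = wer_reduction(0.80, 0.40)
print(f"Absolute reduction: {absolute * 100:.1f} points")
print(f"Relative reduction: {relative * 100:.1f}%")
```

With the notebook's own variables, the baseline side would come from `baseline_tiny_results['wer']` and the fine-tuned side from whatever the post-training evaluation returns.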

In [None]:
import os
import json
import tarfile
from datetime import datetime
from pathlib import Path

import boto3
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role

class WhisperModelDeployment:
    """Handle model packaging and deployment to SageMaker"""
    
    def __init__(self, bucket, prefix, role):
        self.bucket = bucket
        self.prefix = prefix
        self.role = role
        self.s3_client = boto3.client('s3')
        self.sagemaker_client = boto3.client('sagemaker')
    
    def create_model_tar(self, model_dir, tar_name):
        """Create tar.gz file for SageMaker model"""
        print(f"📦 Creating model archive: {tar_name}")
        
        with tarfile.open(tar_name, 'w:gz') as tar:
            for file_path in Path(model_dir).rglob('*'):
                if file_path.is_file():
                    arcname = file_path.relative_to(model_dir)
                    tar.add(file_path, arcname=arcname)
        
        print(f"✅ Model archive created: {tar_name}")
        return tar_name
    
    def upload_to_s3(self, file_path, s3_key):
        """Upload model to S3"""
        print(f"☁️ Uploading to S3: s3://{self.bucket}/{s3_key}")
        
        self.s3_client.upload_file(file_path, self.bucket, s3_key)
        s3_uri = f"s3://{self.bucket}/{s3_key}"
        
        print(f"✅ Upload complete: {s3_uri}")
        return s3_uri
    
    def create_inference_script(self, model_dir):
        """Create inference script for SageMaker endpoint"""
        inference_script = '''
import torch
import json
import librosa
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

def model_fn(model_dir):
    """Load the model and processor"""
    model = WhisperForConditionalGeneration.from_pretrained(model_dir)
    processor = WhisperProcessor.from_pretrained(model_dir)
    
    return {'model': model, 'processor': processor}

def input_fn(request_body, request_content_type):
    """Parse input data"""
    if request_content_type == 'application/json':
        input_data = json.loads(request_body)
        
        # Expect audio file path or base64 encoded audio
        if 'audio_path' in input_data:
            audio, sr = librosa.load(input_data['audio_path'], sr=16000)
        elif 'audio_base64' in input_data:
            import base64
            import io
            audio_bytes = base64.b64decode(input_data['audio_base64'])
            audio, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
        else:
            raise ValueError("Must provide either 'audio_path' or 'audio_base64'")
        
        return audio
    else:
        raise ValueError(f"Unsupported content type: {request_content_type}")

def predict_fn(input_data, model_dict):
    """Make prediction"""
    model = model_dict['model']
    processor = model_dict['processor']
    
    # Process audio
    input_features = processor(
        input_data, 
        sampling_rate=16000, 
        return_tensors="pt"
    ).input_features
    
    # Generate transcription
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features,
            max_length=448,
            num_beams=5,
            do_sample=False
        )
    
    # Decode transcription
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    
    return {"transcription": transcription}

def output_fn(prediction, content_type):
    """Format output"""
    if content_type == 'application/json':
        return json.dumps(prediction)
    else:
        raise ValueError(f"Unsupported content type: {content_type}")
'''
        
        script_path = Path(model_dir) / "inference.py"
        with open(script_path, 'w') as f:
            f.write(inference_script)
        
        print(f"✅ Inference script created: {script_path}")
        return script_path
    
    def deploy_model(self, model_s3_uri, model_name, instance_type="ml.t2.medium"):
        """Deploy model to SageMaker endpoint"""
        print(f"🚀 Deploying model: {model_name}")
        
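        # Create PyTorch model.
        # Note: entry_point is resolved relative to the current working directory
        # (or to source_dir if one is passed), so inference.py must be reachable there.
        # The inference script also imports transformers and librosa, which must be
        # available in the serving container.
        pytorch_model = PyTorchModel(
            model_data=model_s3_uri,
            role=self.role,
            py_version='py38',  # PyTorch 1.12 containers are published for Python 3.8
            framework_version='1.12',
            entry_point='inference.py'
        )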
        
        # Deploy to endpoint
        predictor = pytorch_model.deploy(
            initial_instance_count=1,
            instance_type=instance_type,
            endpoint_name=f"{model_name}-endpoint"
        )
        
        print(f"✅ Model deployed to endpoint: {model_name}-endpoint")
        return predictor

# Initialize deployment manager
deployment_manager = WhisperModelDeployment(bucket, prefix, role)

# Deploy Whisper Tiny (if available)
if os.path.exists("./whisper-tiny-khmer"):
    try:
        print("🚀 Preparing Whisper Tiny for deployment...")
        
        # Create inference script
        deployment_manager.create_inference_script("./whisper-tiny-khmer")
        
        # Create model archive
        tiny_tar = deployment_manager.create_model_tar("./whisper-tiny-khmer", "whisper-tiny-khmer.tar.gz")
        
        # Upload to S3
        tiny_s3_key = f"{prefix}/models/whisper-tiny-khmer.tar.gz"
        tiny_s3_uri = deployment_manager.upload_to_s3(tiny_tar, tiny_s3_key)
        
        print(f"📍 Whisper Tiny S3 URI: {tiny_s3_uri}")
        
        # Optional: Deploy to endpoint (commented out to avoid costs)
        # tiny_predictor = deployment_manager.deploy_model(tiny_s3_uri, "whisper-tiny-khmer")
        print("💡 To deploy endpoint, uncomment the deploy_model call")
        
    except Exception as e:
        print(f"❌ Whisper Tiny deployment prep failed: {e}")

# Deploy Whisper Base (if available)
if os.path.exists("./whisper-base-khmer"):
    try:
        print("\n🚀 Preparing Whisper Base for deployment...")
        
        # Create inference script
        deployment_manager.create_inference_script("./whisper-base-khmer")
        
        # Create model archive
        base_tar = deployment_manager.create_model_tar("./whisper-base-khmer", "whisper-base-khmer.tar.gz")
        
        # Upload to S3
        base_s3_key = f"{prefix}/models/whisper-base-khmer.tar.gz"
        base_s3_uri = deployment_manager.upload_to_s3(base_tar, base_s3_key)
        
        print(f"📍 Whisper Base S3 URI: {base_s3_uri}")
        
        # Optional: Deploy to endpoint (commented out to avoid costs)
        # base_predictor = deployment_manager.deploy_model(base_s3_uri, "whisper-base-khmer", instance_type="ml.m5.large")
        print("💡 To deploy endpoint, uncomment the deploy_model call")
        
    except Exception as e:
        print(f"❌ Whisper Base deployment prep failed: {e}")

# Create summary report
summary_report = {
    "training_date": datetime.now().isoformat(),
    "instance_type": "ml.g4dn.2xlarge",
    "dataset": "Mega Khmer Dataset (160+ hours, 110K+ samples)",
    "models_trained": [],
    "s3_artifacts": []
}

if os.path.exists("./whisper-tiny-khmer"):
    summary_report["models_trained"].append("Whisper Tiny")
    if 'tiny_s3_uri' in locals():
        summary_report["s3_artifacts"].append({"model": "Whisper Tiny", "uri": tiny_s3_uri})

if os.path.exists("./whisper-base-khmer"):
    summary_report["models_trained"].append("Whisper Base")
    if 'base_s3_uri' in locals():
        summary_report["s3_artifacts"].append({"model": "Whisper Base", "uri": base_s3_uri})

# Save summary report
with open("training_summary.json", "w") as f:
    json.dump(summary_report, f, indent=2)

print("\n🎉 DEPLOYMENT PREPARATION COMPLETE!")
print("=" * 50)
print(f"📊 Models Trained: {len(summary_report['models_trained'])}")
print(f"☁️ S3 Artifacts: {len(summary_report['s3_artifacts'])}")
print(f"📁 S3 Bucket: {bucket}")
print(f"📝 Summary saved to: training_summary.json")

if summary_report['s3_artifacts']:
    print("\n📍 Model S3 Locations:")
    for artifact in summary_report['s3_artifacts']:
        print(f"   {artifact['model']}: {artifact['uri']}")

print("\n💡 Next Steps:")
print("   1. Review model performance metrics above")
print("   2. Uncomment deploy_model calls to create SageMaker endpoints")
print("   3. Use the S3 URIs for batch inference or custom deployments")
print("   4. Consider fine-tuning with additional domain-specific data")
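
The generated `inference.py` accepts `application/json` requests containing either `audio_path` or `audio_base64`. A minimal sketch of building the base64 variant on the client side follows. The helper names `build_request_body` and `decode_request_body` are hypothetical, and the byte string in the usage line stands in for the contents of a real audio file.

```python
import base64
import json


def build_request_body(audio_bytes: bytes) -> str:
    """Wrap raw audio file bytes in the JSON shape that input_fn expects."""
    encoded = base64.b64encode(audio_bytes).decode("ascii")
    return json.dumps({"audio_base64": encoded})


def decode_request_body(body: str) -> bytes:
    """Mirror of the server-side decoding step, useful for local checks."""
    return base64.b64decode(json.loads(body)["audio_base64"])


# Stand-in bytes; in practice this would be open("sample.wav", "rb").read()
body = build_request_body(b"RIFF....WAVEfmt ")
assert decode_request_body(body) == b"RIFF....WAVEfmt "
print(body)
```

If an endpoint has been deployed, a body like this could be sent with the `sagemaker-runtime` client's `invoke_endpoint` call using `ContentType="application/json"`. Keep in mind that base64 inflates the payload by about a third, so long recordings may exceed the endpoint's request size limit.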