In [1]:
import sagemaker
import boto3
import os
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch
import json
from datetime import datetime

# Initialize SageMaker session for Singapore region
sagemaker_session = sagemaker.Session(boto3.Session(region_name='ap-southeast-1'))
role = get_execution_role()
region = 'ap-southeast-1'  # Singapore region

# Use your existing S3 bucket with uploaded dataset
bucket = 'pan-sea-khmer-speech-dataset-sg'  # Your Singapore bucket
prefix = 'khmer-whisper-training'

print(f"🌟 SageMaker Session Initialized!")
print(f"📦 S3 Bucket: {bucket}")
print(f"🌍 Region: {region}")
print(f"🔑 Role: {role}")
print(f"📁 Prefix: {prefix}")

# Create S3 paths - Updated with your actual uploaded dataset path
dataset_s3_path = 's3://pan-sea-khmer-speech-dataset-sg/khmer-whisper-dataset/data'  # Your uploaded dataset
model_artifacts_path = f's3://{bucket}/{prefix}/model-artifacts'
logs_path = f's3://{bucket}/{prefix}/logs'

print(f"\n📊 S3 Paths:")
print(f"   Dataset: {dataset_s3_path}")
print(f"   Models: {model_artifacts_path}")
print(f"   Logs: {logs_path}")

# Verify dataset exists
s3_client = boto3.client('s3', region_name=region)
try:
    response = s3_client.list_objects_v2(
        Bucket='pan-sea-khmer-speech-dataset-sg',
        Prefix='khmer-whisper-dataset/data/',
        MaxKeys=5
    )
    if 'Contents' in response:
        print(f"✅ Dataset verified in S3: {len(response.get('Contents', []))} files found")
        print(f"📂 Sample files:")
        for obj in response['Contents'][:3]:
            print(f"   - {obj['Key']}")
    else:
        print("❌ No files found in dataset path")
except Exception as e:
    print(f"⚠️ Could not verify dataset: {e}")

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
🌟 SageMaker Session Initialized!
📦 S3 Bucket: pan-sea-khmer-speech-dataset-sg
🌍 Region: ap-southeast-1
🔑 Role: arn:aws:iam::248619912656:role/AmazonSageMaker-ExecutionRole
📁 Prefix: khmer-whisper-training

📊 S3 Paths:
   Dataset: s3://pan-sea-khmer-speech-dataset-sg/khmer-whisper-dataset/data
   Models: s3://pan-sea-khmer-speech-dataset-sg/khmer-whisper-training/model-artifacts
   Logs: s3://pan-sea-khmer-speech-dataset-sg/khmer-whisper-training/logs
✅ Dataset verified in S3: 5 files found
📂 Sample files:
   - khmer-whisper-dataset/data/test/audio/01a47357-e136-4e15-8842-e84a1f6707c4_0_0.wav
   - khmer-whisper-dataset/data/test/audio/01a47357-e136-4e15-8842-e84a1f6707c4_0_1.wav
   - khmer-whisper-dataset/data/test/audio/01a47357-e136-4e15-8842-e84a1f6707c4_0_2.wav


In [None]:
# Final SageMaker Installation - Fix Version Compatibility
import subprocess
import sys
import warnings
warnings.filterwarnings('ignore')  # Suppress dependency warnings

print("🚀 Installing Whisper Dependencies for SageMaker (Fixed Versions)")
print("=" * 60)

# Essential packages with compatible versions
def install_essential():
    """Install only what's needed for Whisper training with compatible versions"""
    
    # Core packages with version compatibility
    essential_packages = [
        "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
        "transformers==4.35.2",
        "datasets==2.14.7", 
        "accelerate==0.25.0",  # Compatible version with transformers 4.35.2
        "soundfile",
        "evaluate==0.4.1",
        "jiwer==3.0.3", 
        "boto3",
        "tqdm",
        "tensorboard"
    ]
    
    # Try conda for librosa first (most reliable)
    print("🎵 Installing librosa via conda...")
    try:
        subprocess.run("conda install -c conda-forge librosa -y", shell=True, check=True, capture_output=True)
        print("✅ Librosa installed via conda")
    except:
        print("⚠️ Conda librosa failed, trying pip...")
        subprocess.run("pip install librosa==0.10.1", shell=True)
    
    # Install other packages with version constraints
    print("\n📦 Installing packages with compatible versions...")
    for package in essential_packages:
        try:
            cmd = f"pip install {package}"
            print(f"Installing: {package.split()[0]}")
            result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
            print(f"✅ {package.split()[0]}")
        except subprocess.CalledProcessError as e:
            print(f"⚠️ Warning: {package.split()[0]} installation issue:")
            print(f"   {e.stderr[:100] if e.stderr else 'Unknown error'}")

# Run installation
install_essential()

# Test imports (the important ones)
print("\n🔍 Testing critical imports...")

def test_package(name, import_name=None):
    if import_name is None:
        import_name = name
    try:
        __import__(import_name)
        print(f"✅ {name} - Working")
        return True
    except ImportError as e:
        print(f"❌ {name} - Failed: {str(e)[:50]}...")
        return False

# Test critical packages for Whisper
critical_tests = [
    ("PyTorch", "torch"),
    ("Transformers", "transformers"),
    ("Datasets", "datasets"), 
    ("Accelerate", "accelerate"),
    ("Librosa", "librosa"),
    ("SoundFile", "soundfile"),
    ("Evaluate", "evaluate"),
    ("Boto3", "boto3")
]

working_count = 0
for name, import_name in critical_tests:
    if test_package(name, import_name):
        working_count += 1

print(f"\n📊 Status: {working_count}/{len(critical_tests)} critical packages working")

# Show versions of working packages
try:
    import torch
    import transformers
    import accelerate
    print(f"\n🎯 Key Versions:")
    print(f"   Python: {sys.version.split()[0]}")
    print(f"   PyTorch: {torch.__version__}")
    print(f"   Transformers: {transformers.__version__}")
    print(f"   Accelerate: {accelerate.__version__}")
    print(f"   CUDA Available: {torch.cuda.is_available()}")
    
    if torch.cuda.is_available():
        print(f"   GPU: {torch.cuda.get_device_name(0)}")
        print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        
    # Check for compatibility
    from packaging import version
    trans_version = version.parse(transformers.__version__)
    accel_version = version.parse(accelerate.__version__)
    
    print(f"\n🔧 Compatibility Check:")
    if trans_version >= version.parse("4.35.0") and accel_version >= version.parse("0.25.0"):
        print("✅ Transformers and Accelerate versions are compatible")
    else:
        print("⚠️ Version mismatch detected - may cause training issues")
        
except Exception as e:
    print(f"⚠️ Could not get version info: {e}")

# Final status
if working_count >= 6:  # At least 6/8 critical packages
    print("\n🎉 SUCCESS: Ready for Whisper training!")
    print("💡 Version compatibility issues fixed")
    print("🚀 You can proceed with the training cells")
else:
    print("\n⚠️ Some packages need manual fixes:")
    print("   Try running this cell again if issues persist")

print("\n" + "="*60)

🚀 Installing Whisper Dependencies for SageMaker
🎵 Installing librosa via conda...
✅ Librosa installed via conda

📦 Installing other packages...
Installing: torch torchvision torchaudio
✅ torch torchvision torchaudio
Installing: transformers==4.35.2
✅ transformers==4.35.2
Installing: datasets==2.14.7
✅ datasets==2.14.7
Installing: soundfile
✅ soundfile
Installing: evaluate
✅ evaluate
Installing: jiwer
✅ jiwer
Installing: accelerate
✅ accelerate
Installing: boto3
✅ boto3
Installing: tqdm
✅ tqdm

🔍 Testing critical imports...
✅ PyTorch - Working
✅ Transformers - Working
✅ Datasets - Working
✅ Librosa - Working
✅ SoundFile - Working
✅ Evaluate - Working
✅ Boto3 - Working
✅ Accelerate - Working

📊 Status: 8/8 critical packages working

🎯 Key Versions:
   Python: 3.10.18
   PyTorch: 2.6.0+cu124
   Transformers: 4.35.2
   CUDA Available: True
   GPU: Tesla T4
   GPU Memory: 14.6 GB

🎉 SUCCESS: Ready for Whisper training!
🚀 You can proceed with the training cells



In [3]:
# 🔍 S3 Bucket Explorer - Run this to see your exact S3 structure
import boto3

def explore_s3_bucket():
    """Explore the S3 bucket structure to understand the dataset layout"""
    print("🔍 S3 Bucket Explorer")
    print("=" * 50)
    
    s3_client = boto3.client('s3', region_name='ap-southeast-1')
    bucket_name = 'pan-sea-khmer-speech-dataset-sg'
    prefix = 'khmer-whisper-dataset/data'
    
    try:
        # Get all objects with prefix
        paginator = s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
        
        all_objects = []
        for page in pages:
            all_objects.extend(page.get('Contents', []))
        
        print(f"📊 Total objects found: {len(all_objects)}")
        
        # Analyze structure
        folders = set()
        file_types = {}
        manifest_files = []
        
        for obj in all_objects:
            key = obj['Key']
            
            # Track folders
            parts = key.split('/')
            if len(parts) > 1:
                folder_path = '/'.join(parts[:-1])
                folders.add(folder_path)
            
            # Track file types
            if '.' in key:
                ext = key.split('.')[-1].lower()
                file_types[ext] = file_types.get(ext, 0) + 1
            
            # Track manifest files
            if 'manifest' in key.lower():
                manifest_files.append(key)
        
        print(f"\n📁 Folder Structure:")
        for folder in sorted(folders)[:20]:  # Show first 20 folders
            print(f"   {folder}")
        
        print(f"\n📄 File Types:")
        for ext, count in sorted(file_types.items()):
            print(f"   .{ext}: {count:,} files")
        
        print(f"\n📋 Manifest Files Found:")
        if manifest_files:
            for mf in manifest_files:
                print(f"   {mf}")
        else:
            print("   ❌ No manifest files found!")
        
        # Show sample files
        print(f"\n📝 Sample Files (first 10):")
        for i, obj in enumerate(all_objects[:10]):
            size_mb = obj['Size'] / 1024 / 1024
            print(f"   {i+1:2d}. {obj['Key']} ({size_mb:.2f} MB)")
        
        # Check for specific splits
        splits = ['train', 'validation', 'test']
        print(f"\n🎯 Split Analysis:")
        for split in splits:
            split_files = [obj for obj in all_objects if f'/{split}/' in obj['Key'] or f'_{split}_' in obj['Key']]
            audio_files = [obj for obj in split_files if obj['Key'].endswith('.wav')]
            manifest_files_split = [obj for obj in split_files if 'manifest' in obj['Key']]
            
            print(f"   {split.capitalize()}:")
            print(f"     Audio files: {len(audio_files):,}")
            print(f"     Manifest files: {len(manifest_files_split)}")
            
            if manifest_files_split:
                print(f"     Manifest path: {manifest_files_split[0]['Key']}")
    
    except Exception as e:
        print(f"❌ Error exploring S3: {e}")
        return None
    
    return {
        'total_objects': len(all_objects),
        'folders': folders,
        'file_types': file_types,
        'manifest_files': manifest_files
    }

# Run the exploration
bucket_info = explore_s3_bucket()

print("\n💡 Next Steps:")
print("1. Check if manifest files exist at the expected paths")
print("2. If no manifests found, we'll need to create them or use a different loading approach")
print("3. The dataset loader will fallback to dummy data if needed for testing")

🔍 S3 Bucket Explorer
📊 Total objects found: 110558

📁 Folder Structure:
   khmer-whisper-dataset/data/test
   khmer-whisper-dataset/data/test/audio
   khmer-whisper-dataset/data/train
   khmer-whisper-dataset/data/train/audio
   khmer-whisper-dataset/data/validation
   khmer-whisper-dataset/data/validation/audio

📄 File Types:
   .csv: 3 files
   .wav: 110,555 files

📋 Manifest Files Found:
   khmer-whisper-dataset/data/test/test_manifest.csv
   khmer-whisper-dataset/data/train/train_manifest.csv
   khmer-whisper-dataset/data/validation/validation_manifest.csv

📝 Sample Files (first 10):
    1. khmer-whisper-dataset/data/test/audio/01a47357-e136-4e15-8842-e84a1f6707c4_0_0.wav (0.15 MB)
    2. khmer-whisper-dataset/data/test/audio/01a47357-e136-4e15-8842-e84a1f6707c4_0_1.wav (0.15 MB)
    3. khmer-whisper-dataset/data/test/audio/01a47357-e136-4e15-8842-e84a1f6707c4_0_2.wav (0.13 MB)
    4. khmer-whisper-dataset/data/test/audio/01a47357-e136-4e15-8842-e84a1f6707c4_1.wav (0.13 MB)
    5. 

In [6]:
# ✅ Load Dataset from S3 CSV Manifests (UPDATED)
import json
import pandas as pd
import librosa
import numpy as np
from pathlib import Path
from datasets import Dataset, DatasetDict, Audio
from transformers import WhisperFeatureExtractor, WhisperTokenizer
import torch
from torch.utils.data import DataLoader
import boto3
import io

class S3CSVDatasetLoader:
    """Load dataset from uploaded CSV manifest files in S3"""
    
    def __init__(self):
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
        self.tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="km", task="transcribe")
        self.s3_client = boto3.client('s3', region_name='ap-southeast-1')
        self.bucket_name = 'pan-sea-khmer-speech-dataset-sg'
        
    def check_csv_manifests(self):
        """Check if CSV manifest files exist in S3"""
        print("🔍 Checking for uploaded CSV manifest files...")
        
        csv_files = {
            'train': 'khmer-whisper-dataset/data/train/train_manifest.csv',
            'validation': 'khmer-whisper-dataset/data/validation/validation_manifest.csv',
            'test': 'khmer-whisper-dataset/data/test/test_manifest.csv'
        }
        
        found_files = {}
        
        for split, s3_key in csv_files.items():
            try:
                response = self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
                size_mb = response['ContentLength'] / 1024 / 1024
                print(f"✅ {split}_manifest.csv: {size_mb:.2f} MB")
                found_files[split] = s3_key
            except Exception as e:
                print(f"❌ {split}_manifest.csv: Not found")
        
        return found_files
    
    def load_csv_from_s3(self, split, s3_key):
        """Download and load CSV manifest from S3"""
        try:
            print(f"📥 Loading {split} manifest from S3...")
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
            csv_content = response['Body'].read()
            
            # Read CSV into pandas
            df = pd.read_csv(io.BytesIO(csv_content))
            print(f"✅ Loaded {len(df)} entries")
            return df
            
        except Exception as e:
            print(f"❌ Failed to load {split} CSV: {e}")
            return None
    
    def load_datasets(self, max_duration=20.0, min_duration=1.0, max_samples=1000):
        """Load datasets from S3 CSV manifests"""
        print("🚀 Loading Dataset from S3 CSV Manifests")
        print("=" * 50)
        
        # Check which CSV files exist
        csv_files = self.check_csv_manifests()
        
        if not csv_files:
            print("❌ No CSV manifest files found!")
            print("💡 Make sure you uploaded the CSV files correctly")
            return self.create_dummy_datasets()
        
        dataset_dict = {}
        
        for split in ['train', 'validation', 'test']:
            if split not in csv_files:
                print(f"⚠️ Skipping {split} - no CSV found")
                continue
                
            print(f"\n📊 Processing {split} split...")
            
            # Load CSV from S3
            df = self.load_csv_from_s3(split, csv_files[split])
            if df is None:
                continue
            
            # Filter by duration
            original_count = len(df)
            df_filtered = df[
                (df['duration'] <= max_duration)
            ]
            print(f"📊 Duration filter: {original_count} → {len(df_filtered)} samples")
            
            # Limit samples for memory efficiency
            if len(df_filtered) > max_samples:
                df_filtered = df_filtered.head(max_samples)
                print(f"🔧 Limited to {max_samples} samples for SageMaker efficiency")
            
            # Convert to dataset format
            data = []
            for _, row in df_filtered.iterrows():
                # Create S3 audio path
                audio_filename = row['audio_filepath']
                s3_audio_path = f"s3://{self.bucket_name}/khmer-whisper-dataset/data/{split}/audio/{audio_filename}"
                
                entry = {
                    'audio_filepath': s3_audio_path,
                    'text': str(row['text']),
                    'duration': float(row['duration']),
                    'language': row.get('language', 'km'),
                    'source': row.get('source', 'mega_dataset'),
                    'speaker': row.get('speaker', 'unknown')
                }
                data.append(entry)
            
            # Create HuggingFace dataset
            if data:
                dataset = Dataset.from_list(data)
                try:
                    dataset = dataset.cast_column(sample['audio_filepath'], Audio(sampling_rate=16000))
                    print(f"✅ Created dataset with {len(data)} samples")
                except Exception as e:
                    print(f"⚠️ Audio casting warning: {e}")
                
                dataset_dict[split] = dataset
            else:
                print(f"❌ No valid data for {split}")
        
        return DatasetDict(dataset_dict) if dataset_dict else self.create_dummy_datasets()
    
    def create_dummy_datasets(self):
        """Create small dummy datasets for testing"""
        print("🎯 Creating dummy datasets for testing...")
        
        dataset_dict = {}
        sizes = {'train': 100, 'validation': 30, 'test': 30}
        
        for split, size in sizes.items():
            data = []
            for i in range(size):
                data.append({
                    'text': f'សាកល្បង ទី {i+1}',  # "Test number i" in Khmer
                    'audio_filepath': f'dummy_audio_{split}_{i}.wav',
                    'duration': 2.0 + (i % 3),
                    'language': 'km',
                    'source': 'dummy',
                    'speaker': 'unknown'
                })
            
            dataset = Dataset.from_list(data)
            dataset_dict[split] = dataset
            print(f"📝 {split}: {size} dummy samples")
        
        return DatasetDict(dataset_dict)

# Initialize loader and load datasets
print("🚀 Initializing S3 CSV Dataset Loader...")
dataset_loader = S3CSVDatasetLoader()

# Load with SageMaker-friendly settings
datasets = dataset_loader.load_datasets(
    max_duration=20.0,
    min_duration=1.0, 
    max_samples=1000  # Start with 1K samples per split for testing
)

# Show results
if datasets:
    print(f"\n📊 Final Dataset Summary:")
    print("=" * 40)
    
    total_samples = 0
    for split, dataset in datasets.items():
        count = len(dataset)
        total_samples += count
        
        if 'duration' in dataset.column_names:
            hours = sum(dataset['duration']) / 3600
            print(f"  {split.capitalize()}: {count:,} samples ({hours:.2f} hours)")
        else:
            print(f"  {split.capitalize()}: {count:,} samples")
    
    print(f"  Total: {total_samples:,} samples")
    print(f"\n✅ Dataset variable 'datasets' created successfully!")
    print(f"🎯 You can now proceed with training!")
    
    # Show sample
    if total_samples > 0:
        sample_split = list(datasets.keys())[0]
        sample = datasets[sample_split][0]
        print(f"\n🔍 Sample from {sample_split}:")
        print(f"   Text: {sample['text'][:100]}...")
        print(f"   Duration: {sample['duration']} seconds")
        print(f"   Audio path: {sample['audio_filepath'][:60]}...")
        
else:
    print(f"\n❌ Dataset loading failed!")
    print(f"💡 Check CSV uploads and try again")

print(f"\n" + "="*50)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


🚀 Initializing S3 CSV Dataset Loader...
🚀 Loading Dataset from S3 CSV Manifests
🔍 Checking for uploaded CSV manifest files...
✅ train_manifest.csv: 20.76 MB
✅ validation_manifest.csv: 2.39 MB
✅ test_manifest.csv: 2.40 MB

📊 Processing train split...
📥 Loading train manifest from S3...
✅ Loaded 90606 entries
📊 Duration filter: 90606 → 90040 samples
🔧 Limited to 1000 samples for SageMaker efficiency

📊 Processing validation split...
📥 Loading validation manifest from S3...
✅ Loaded 9973 entries
📊 Duration filter: 9973 → 9917 samples
🔧 Limited to 1000 samples for SageMaker efficiency

📊 Processing test split...
📥 Loading test manifest from S3...
✅ Loaded 9976 entries
📊 Duration filter: 9976 → 9916 samples
🔧 Limited to 1000 samples for SageMaker efficiency

📊 Final Dataset Summary:
  Train: 1,000 samples (1.37 hours)
  Validation: 1,000 samples (1.36 hours)
  Test: 1,000 samples (1.19 hours)
  Total: 3,000 samples

✅ Dataset variable 'datasets' created successfully!
🎯 You can now proceed w

In [7]:
from transformers import (
    WhisperForConditionalGeneration, 
    WhisperTokenizer, 
    WhisperFeatureExtractor,
    WhisperProcessor
)
from dataclasses import dataclass
import torch.nn as nn

@dataclass
class WhisperModelConfig:
    """Configuration for Whisper models optimized for ml.g4dn.2xlarge"""
    model_name: str
    max_length: int = 448
    language: str = "km"
    task: str = "transcribe"
    batch_size_tiny: int = 16
    batch_size_base: int = 8
    gradient_accumulation_steps: int = 2
    learning_rate: float = 1e-5
    warmup_steps: int = 500
    max_steps: int = 5000
    eval_steps: int = 500
    save_steps: int = 1000

class WhisperModelManager:
    """Manage Whisper model configurations and initialization"""
    
    def __init__(self):
        self.models = {
            'tiny': WhisperModelConfig("openai/whisper-tiny"),
            'base': WhisperModelConfig("openai/whisper-base")
        }
    
    def load_model_components(self, model_size):
        """Load tokenizer, feature extractor, and model"""
        config = self.models[model_size]
        
        print(f"🔧 Loading Whisper {model_size.upper()} components...")
        
        # Load tokenizer
        tokenizer = WhisperTokenizer.from_pretrained(
            config.model_name, 
            language=config.language, 
            task=config.task
        )
        
        # Load feature extractor
        feature_extractor = WhisperFeatureExtractor.from_pretrained(config.model_name)
        
        # Load processor (combines tokenizer and feature extractor)
        processor = WhisperProcessor.from_pretrained(config.model_name)
        processor.tokenizer = tokenizer
        
        # Load model
        model = WhisperForConditionalGeneration.from_pretrained(config.model_name)
        
        # Configure model for Khmer
        model.generation_config.language = config.language
        model.generation_config.task = config.task
        model.generation_config.forced_decoder_ids = None
        
        print(f"✅ Whisper {model_size.upper()} loaded successfully!")
        print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"   Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        
        return tokenizer, feature_extractor, processor, model, config

# Initialize model manager
model_manager = WhisperModelManager()

# Load Whisper Tiny components
print("🚀 Setting up Whisper models for training...")
tiny_tokenizer, tiny_feature_extractor, tiny_processor, tiny_model, tiny_config = model_manager.load_model_components('tiny')

print(f"\n📊 Model Configurations:")
print(f"  Whisper Tiny - Batch Size: {tiny_config.batch_size_tiny}")
print(f"  Whisper Base - Batch Size: {tiny_config.batch_size_base}")
print(f"  Max Length: {tiny_config.max_length}")
print(f"  Learning Rate: {tiny_config.learning_rate}")
print(f"  Gradient Accumulation: {tiny_config.gradient_accumulation_steps}")

# Check GPU memory
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"\n💾 GPU Memory: {gpu_memory:.1f} GB")
    print("✅ Configurations optimized for ml.g4dn.2xlarge (16GB VRAM)")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


🚀 Setting up Whisper models for training...
🔧 Loading Whisper TINY components...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


✅ Whisper TINY loaded successfully!
   Parameters: 37,760,640
   Trainable: 37,184,640

📊 Model Configurations:
  Whisper Tiny - Batch Size: 16
  Whisper Base - Batch Size: 8
  Max Length: 448
  Learning Rate: 1e-05
  Gradient Accumulation: 2

💾 GPU Memory: 14.6 GB
✅ Configurations optimized for ml.g4dn.2xlarge (16GB VRAM)


In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import EarlyStoppingCallback
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Data collator for speech-to-text training"""
    
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they have to be of different lengths and need different padding methods
        model_input_name = self.processor.model_input_names[0]
        input_features = [{model_input_name: feature[model_input_name]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

def create_training_arguments(model_size, config, output_dir):
    """Create optimized training arguments for ml.g4dn.2xlarge - Compatible Version"""
    
    # Adjust batch size based on model size
    if model_size == 'tiny':
        per_device_train_batch_size = config.batch_size_tiny
        per_device_eval_batch_size = config.batch_size_tiny
    else:  # base
        per_device_train_batch_size = config.batch_size_base
        per_device_eval_batch_size = config.batch_size_base
    
    # Compatible training arguments (removed newer parameters)
    return Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps,
        max_steps=config.max_steps,
        gradient_checkpointing=True,  # Save memory
        fp16=True,  # Mixed precision for faster training
        evaluation_strategy="steps",
        eval_steps=config.eval_steps,
        save_steps=config.save_steps,
        logging_steps=100,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        save_total_limit=3,
        predict_with_generate=True,
        generation_max_length=config.max_length,
        # Removed dispatch_batches and other newer parameters
        remove_unused_columns=False,
        label_names=["labels"]
    )

def compute_metrics(eval_preds, tokenizer):
    """Compute WER and CER metrics"""
    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")
    
    pred_ids, label_ids = eval_preds
    
    # Replace -100 with pad token id
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    
    # Decode predictions and labels
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    
    # Compute metrics
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    
    return {"wer": wer, "cer": cer}

def setup_trainer(model, tokenizer, processor, config, datasets, model_size):
    """Setup Seq2SeqTrainer with all components - Compatible Version"""
    
    # Create data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.generation_config.decoder_start_token_id,
    )
    
    # Create training arguments
    output_dir = f"./whisper-{model_size}-khmer"
    training_args = create_training_arguments(model_size, config, output_dir)
    
    # Create trainer with compatible parameters
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=datasets["train"],
        eval_dataset=datasets["validation"],
        data_collator=data_collator,
        compute_metrics=lambda eval_preds: compute_metrics(eval_preds, tokenizer),
        tokenizer=processor.feature_extractor,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )
    
    return trainer, training_args

print("⚙️ Training configuration setup complete! (Compatible Version)")
print("🎯 Optimized for ml.g4dn.2xlarge with 16GB VRAM")
print("✅ Mixed precision (FP16) enabled")
print("✅ Gradient checkpointing enabled for memory efficiency")
print("✅ Early stopping configured")
print("🔧 Compatible with older accelerate versions")

⚙️ Training configuration setup complete!
🎯 Optimized for ml.g4dn.2xlarge with 16GB VRAM
✅ Mixed precision (FP16) enabled
✅ Gradient checkpointing enabled for memory efficiency
✅ Early stopping configured


In [9]:
# ✅ Pre-Training Checklist - Run This Before Training
print("🔍 Checking Prerequisites for Whisper Training...")
print("=" * 50)

# Check if all required variables are defined
required_vars = {
    'datasets': 'Dataset loading (Cell 7)',
    'tiny_model': 'Whisper Tiny model (Cell 9)', 
    'tiny_tokenizer': 'Whisper Tiny tokenizer (Cell 9)',
    'tiny_processor': 'Whisper Tiny processor (Cell 9)',
    'tiny_config': 'Whisper Tiny config (Cell 9)',
    'setup_trainer': 'Training setup function (Cell 11)'
}

missing_vars = []
available_vars = []

for var_name, description in required_vars.items():
    try:
        if var_name in locals() or var_name in globals():
            available_vars.append(f"✅ {var_name} - {description}")
        else:
            missing_vars.append(f"❌ {var_name} - {description}")
    except:
        missing_vars.append(f"❌ {var_name} - {description}")

# Show status
print("🎯 Available Variables:")
for var in available_vars:
    print(f"   {var}")

if missing_vars:
    print(f"\n⚠️ Missing Variables ({len(missing_vars)}):")
    for var in missing_vars:
        print(f"   {var}")
    
    print(f"\n💡 Before training, you need to run these cells in order:")
    print(f"   1. Cell 3: SageMaker Setup")
    print(f"   2. Cell 5: Package Installation") 
    print(f"   3. Cell 7: Dataset Loading")
    print(f"   4. Cell 9: Model Configuration")
    print(f"   5. Cell 11: Training Setup")
    print(f"   6. Then you can run the training cells")
    
    print(f"\n🚫 Cannot start training - missing prerequisites!")
    
else:
    print(f"\n🎉 All prerequisites available!")
    print(f"✅ Ready to start Whisper training")
    
    # Show some stats if datasets is available
    try:
        if 'datasets' in locals() or 'datasets' in globals():
            print(f"\n📊 Dataset Info:")
            for split in datasets.keys():
                print(f"   {split}: {len(datasets[split]):,} samples")
    except:
        pass
        
print("\n" + "=" * 50)

🔍 Checking Prerequisites for Whisper Training...
🎯 Available Variables:
   ✅ datasets - Dataset loading (Cell 7)
   ✅ tiny_model - Whisper Tiny model (Cell 9)
   ✅ tiny_tokenizer - Whisper Tiny tokenizer (Cell 9)
   ✅ tiny_processor - Whisper Tiny processor (Cell 9)
   ✅ tiny_config - Whisper Tiny config (Cell 9)
   ✅ setup_trainer - Training setup function (Cell 11)

🎉 All prerequisites available!
✅ Ready to start Whisper training

📊 Dataset Info:
   train: 1,000 samples
   validation: 1,000 samples
   test: 1,000 samples



In [10]:
import time
from datetime import datetime
import os

print("🚀 Starting Whisper Tiny Fine-tuning...")
print("=" * 60)

# Setup trainer for Tiny model
tiny_trainer, tiny_training_args = setup_trainer(
    model=tiny_model,
    tokenizer=tiny_tokenizer, 
    processor=tiny_processor,
    config=tiny_config,
    datasets=datasets,
    model_size="tiny"
)

# Print training info
print(f"📊 Whisper Tiny Training Configuration:")
print(f"   Model Parameters: {sum(p.numel() for p in tiny_model.parameters()):,}")
print(f"   Batch Size: {tiny_training_args.per_device_train_batch_size}")
print(f"   Gradient Accumulation: {tiny_training_args.gradient_accumulation_steps}")
print(f"   Effective Batch Size: {tiny_training_args.per_device_train_batch_size * tiny_training_args.gradient_accumulation_steps}")
print(f"   Max Steps: {tiny_training_args.max_steps}")
print(f"   Learning Rate: {tiny_training_args.learning_rate}")
print(f"   Output Directory: {tiny_training_args.output_dir}")

# Check GPU memory before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    allocated_memory = torch.cuda.memory_allocated(0) / 1024**3
    cached_memory = torch.cuda.memory_reserved(0) / 1024**3
    print(f"\n💾 Pre-training GPU Memory:")
    print(f"   Allocated: {allocated_memory:.2f} GB")
    print(f"   Cached: {cached_memory:.2f} GB")

print(f"\n⏰ Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("🎯 Expected training time: ~3-4 hours")

# Start training
start_time = time.time()

try:
    # Train the model
    tiny_trainer.train()
    
    training_time = time.time() - start_time
    print(f"\n🎉 Whisper Tiny Training Complete!")
    print(f"⏱️ Total Training Time: {training_time/3600:.2f} hours")
    
    # Save the final model
    tiny_trainer.save_model()
    tiny_trainer.tokenizer.save_pretrained(tiny_training_args.output_dir)
    
    print(f"💾 Model saved to: {tiny_training_args.output_dir}")
    
    # Get final metrics
    final_logs = tiny_trainer.state.log_history[-1]
    if 'eval_wer' in final_logs:
        print(f"📊 Final WER: {final_logs['eval_wer']:.4f}")
        print(f"📊 Final CER: {final_logs.get('eval_cer', 'N/A'):.4f}")
    
    # Memory usage after training
    if torch.cuda.is_available():
        allocated_memory = torch.cuda.memory_allocated(0) / 1024**3
        max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
        print(f"\n💾 Post-training GPU Memory:")
        print(f"   Current: {allocated_memory:.2f} GB")
        print(f"   Peak: {max_memory:.2f} GB")
        
        # Clear cache for next model
        torch.cuda.empty_cache()
        print("🧹 GPU cache cleared for next model")
    
except Exception as e:
    print(f"❌ Training failed: {e}")
    import traceback
    traceback.print_exc()

print("=" * 60)

🚀 Starting Whisper Tiny Fine-tuning...
