# Piper TTS Fine-Tuning: Indian English Voice

This notebook provides a complete pipeline for fine-tuning the Piper TTS `en_US/hfc_female/medium` model
using the IISc SPICOR English dataset to create an Indian English voice.

**Compatible with:** Google Colab, AWS SageMaker

## Overview
1. Environment Setup
2. Configuration
3. Data ETL (Extract, Transform, Load from S3)
4. Model Checkpoint Download
5. Data Preprocessing
6. Fine-Tuning with Checkpointing
7. Resume Training from Checkpoint
8. Export to ONNX
9. Upload Final Model to S3 (Optional)
10. Test the Exported Model

## 1. Environment Setup

In [None]:
# Detect runtime environment
import os
import sys

def detect_environment():
    """Detect if running on Colab, SageMaker, or local."""
    if 'google.colab' in sys.modules:
        return 'colab'
    elif os.environ.get('SM_CURRENT_HOST'):
        return 'sagemaker'
    return 'local'

RUNTIME_ENV = detect_environment()
print(f"Detected environment: {RUNTIME_ENV}")

In [None]:
# Install system dependencies
if RUNTIME_ENV in ['colab', 'sagemaker']:
    !apt-get update -qq
    !apt-get install -y -qq build-essential cmake ninja-build espeak-ng
    print("System dependencies installed.")

In [None]:
# Install Python dependencies (quoted so the shell does not treat ">=" as a redirect)
!pip install -q "torch>=2.0" "lightning>=2.0" tensorboard tensorboardX
!pip install -q "jsonargparse[signatures]>=4.27.7" "pathvalidate>=3" "onnx>=1"
!pip install -q "pysilero-vad>=2.1" "cython>=3" librosa boto3 huggingface_hub
!pip install -q soundfile tqdm pandas onnxruntime
print("Python dependencies installed.")

In [None]:
# Clone and install Piper
PIPER_REPO_URL = "https://github.com/OHF-voice/piper1-gpl.git"
PIPER_DIR = "/content/piper1-gpl" if RUNTIME_ENV == 'colab' else "./piper1-gpl"

if not os.path.exists(PIPER_DIR):
    !git clone {PIPER_REPO_URL} {PIPER_DIR}

os.chdir(PIPER_DIR)
!pip install -q -e ".[train]"
!./build_monotonic_align.sh
!python setup.py build_ext --inplace
print("Piper installed successfully.")

## 2. Configuration

Configure all paths and parameters. Replace placeholders with your actual values.

In [None]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

@dataclass
class Config:
    """Configuration for Piper fine-tuning pipeline."""
    
    # ==================== PLACEHOLDERS - UPDATE THESE ====================
    # S3 Configuration
    s3_bucket: str = "<YOUR_S3_BUCKET_NAME>"  # e.g., "my-tts-training-bucket"
    s3_dataset_prefix: str = "<S3_DATASET_PATH>"  # e.g., "datasets/spicor"
    s3_checkpoint_prefix: str = "<S3_CHECKPOINT_PATH>"  # e.g., "checkpoints/piper"
    aws_region: str = "<AWS_REGION>"  # e.g., "us-east-1"
    
    # AWS Credentials (leave empty to use IAM role or env vars)
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    # =====================================================================
    
    # Local paths
    base_dir: str = "./piper_training"
    data_dir: str = field(default="")
    audio_dir: str = field(default="")
    cache_dir: str = field(default="")
    checkpoint_dir: str = field(default="")
    output_dir: str = field(default="")
    
    # Model configuration
    voice_name: str = "en_IN-spicor-medium"
    espeak_voice: str = "en-us"  # espeak voice for phonemization
    sample_rate: int = 22050
    
    # Training configuration
    batch_size: int = 16  # Reduce if OOM
    num_workers: int = 4
    max_epochs: int = 1000
    val_check_interval: int = 1000  # Validate every N training batches
    save_every_n_steps: int = 500  # Checkpoint frequency
    
    # Hugging Face checkpoint
    hf_checkpoint_repo: str = "rhasspy/piper-checkpoints"  # bare repo id; repo_type="dataset" is passed separately
    hf_checkpoint_path: str = "en/en_US/hfc_female/medium/epoch=2868-step=1575188.ckpt"
    
    def __post_init__(self):
        self.data_dir = f"{self.base_dir}/data"
        self.audio_dir = f"{self.base_dir}/data/audio"
        self.cache_dir = f"{self.base_dir}/cache"
        self.checkpoint_dir = f"{self.base_dir}/checkpoints"
        self.output_dir = f"{self.base_dir}/output"

config = Config()
print(f"Configuration initialized for voice: {config.voice_name}")

In [None]:
# Create directory structure
def create_directories(config: Config):
    """Create all required directories."""
    dirs = [
        config.data_dir,
        config.audio_dir,
        config.cache_dir,
        config.checkpoint_dir,
        config.output_dir,
    ]
    for d in dirs:
        Path(d).mkdir(parents=True, exist_ok=True)
        print(f"Created: {d}")

create_directories(config)

## 3. Data ETL (Extract, Transform, Load)

Download the raw SPICOR dataset from S3. Resampling, duration filtering, and metadata creation happen in Section 5.

In [None]:
import boto3
from botocore.exceptions import ClientError
from tqdm import tqdm
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class S3DataLoader:
    """Handle S3 data operations for dataset and checkpoints."""
    
    def __init__(self, config: Config):
        self.config = config
        self.s3_client = self._create_s3_client()
    
    def _create_s3_client(self):
        """Create S3 client with optional credentials."""
        kwargs = {'region_name': self.config.aws_region}
        if self.config.aws_access_key_id and self.config.aws_secret_access_key:
            kwargs['aws_access_key_id'] = self.config.aws_access_key_id
            kwargs['aws_secret_access_key'] = self.config.aws_secret_access_key
        return boto3.client('s3', **kwargs)
    
    def list_objects(self, prefix: str) -> list:
        """List all objects under a prefix."""
        objects = []
        paginator = self.s3_client.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=self.config.s3_bucket, Prefix=prefix):
            if 'Contents' in page:
                objects.extend([obj['Key'] for obj in page['Contents']])
        return objects
    
    def download_file(self, s3_key: str, local_path: str) -> bool:
        """Download a single file from S3."""
        try:
            Path(local_path).parent.mkdir(parents=True, exist_ok=True)
            self.s3_client.download_file(self.config.s3_bucket, s3_key, local_path)
            return True
        except ClientError as e:
            logger.error(f"Failed to download {s3_key}: {e}")
            return False
    
    def download_dataset(self, local_dir: str) -> int:
        """Download every object under the dataset prefix, mirroring the S3 layout."""
        prefix = self.config.s3_dataset_prefix
        objects = self.list_objects(prefix)
        
        downloaded = 0
        for s3_key in tqdm(objects, desc="Downloading dataset"):
            # Skip zero-byte "folder" placeholder keys such as "datasets/spicor/"
            if s3_key.endswith('/'):
                continue
            relative_path = s3_key[len(prefix):].lstrip('/')
            local_path = os.path.join(local_dir, relative_path)
            if self.download_file(s3_key, local_path):
                downloaded += 1
        
        logger.info(f"Downloaded {downloaded} files from S3")
        return downloaded
    
    def upload_checkpoint(self, local_path: str, s3_key: str) -> bool:
        """Upload checkpoint to S3 for backup."""
        try:
            self.s3_client.upload_file(local_path, self.config.s3_bucket, s3_key)
            logger.info(f"Uploaded checkpoint to s3://{self.config.s3_bucket}/{s3_key}")
            return True
        except ClientError as e:
            logger.error(f"Failed to upload {local_path}: {e}")
            return False

# Initialize S3 loader
s3_loader = S3DataLoader(config)
print("S3 Data Loader initialized.")
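`download_dataset` derives each local path by slicing the prefix off the S3 key. Two edge cases are worth handling: S3 console "folder" keys that end in `/`, and the fact that S3 prefixes are plain string prefixes, so `Prefix="datasets/spicor"` also matches sibling keys such as `datasets/spicor_v2/...`. The sketch below treats the prefix as a folder boundary; `s3_key_to_local_path` is a hypothetical helper that could replace the inline slicing in `download_dataset`.

```python
import os
from typing import Optional


def s3_key_to_local_path(s3_key: str, prefix: str, local_dir: str) -> Optional[str]:
    """Map an S3 key to a path under local_dir, or return None if it should be skipped."""
    # Treat the prefix as a folder so "datasets/spicor" does not match "datasets/spicor_v2/..."
    stripped = prefix.rstrip("/")
    folder = stripped + "/" if stripped else ""
    if not s3_key.startswith(folder):
        return None
    relative = s3_key[len(folder):]
    # Skip the prefix itself and zero-byte "folder" placeholder keys
    if not relative or relative.endswith("/"):
        return None
    parts = relative.split("/")
    # Refuse empty, "." or ".." segments so a key cannot escape local_dir
    if any(part in ("", ".", "..") for part in parts):
        return None
    return os.path.join(local_dir, *parts)


print(s3_key_to_local_path("datasets/spicor_v2/a.wav", "datasets/spicor", "raw"))  # None
```

Inside the download loop, this would mean computing `local_path = s3_key_to_local_path(s3_key, prefix, local_dir)`, skipping the key when it is `None`, and downloading otherwise.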

In [None]:
# Download the raw dataset from S3 into its own folder, separate from the
# resampled audio in config.audio_dir. Uncomment and run when ready.
# downloaded_count = s3_loader.download_dataset(f"{config.data_dir}/spicor_raw")
# print(f"Downloaded {downloaded_count} files")

## 4. Download Pre-trained Checkpoint from Hugging Face

In [None]:
from huggingface_hub import hf_hub_download

def download_pretrained_checkpoint(config: Config) -> str:
    """Download the pre-trained checkpoint and its config.json from Hugging Face."""
    # hf_hub_download expects the bare repo id; the dataset type is passed via repo_type
    repo_id = config.hf_checkpoint_repo.removeprefix("datasets/")
    config_filename = f"{os.path.dirname(config.hf_checkpoint_path)}/config.json"
    
    # With local_dir set, files keep their repo-relative path under checkpoint_dir
    checkpoint_path = os.path.join(config.checkpoint_dir, config.hf_checkpoint_path)
    if os.path.exists(checkpoint_path):
        print(f"Checkpoint already exists: {checkpoint_path}")
        return checkpoint_path
    
    print("Downloading pre-trained checkpoint from Hugging Face...")
    downloaded_path = hf_hub_download(
        repo_id=repo_id,
        filename=config.hf_checkpoint_path,
        repo_type="dataset",
        local_dir=config.checkpoint_dir,
    )
    
    # Also download the matching config.json from the same folder
    hf_hub_download(
        repo_id=repo_id,
        filename=config_filename,
        repo_type="dataset",
        local_dir=config.checkpoint_dir,
    )
    
    print(f"Checkpoint downloaded to: {downloaded_path}")
    return downloaded_path

pretrained_ckpt_path = download_pretrained_checkpoint(config)
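The checkpoint filename records where the base model stopped: `epoch=2868-step=1575188`. This is why `train()` in Section 6 loads the pre-trained weights with `load_from_checkpoint` and passes only a resume checkpoint to `trainer.fit(ckpt_path=...)`: restoring the full training state from this file would likely set the epoch counter to 2868, already past `max_epochs=1000`, so Lightning would end training immediately. The hypothetical helper below reads those numbers from a Lightning-style filename (including the `piper-epoch=...-step=...` names written in Section 6); it parses the name only and does not open the file.

```python
import os
import re
from typing import Optional, Tuple

# Lightning's default checkpoint names embed "epoch=<n>-step=<n>"
_EPOCH_STEP = re.compile(r"epoch=(\d+)-step=(\d+)")


def parse_epoch_step(checkpoint_path: str) -> Optional[Tuple[int, int]]:
    """Return (epoch, global_step) encoded in a checkpoint filename, or None."""
    match = _EPOCH_STEP.search(os.path.basename(checkpoint_path))
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


print(parse_epoch_step("en/en_US/hfc_female/medium/epoch=2868-step=1575188.ckpt"))
# (2868, 1575188)
```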

## 5. Data Preprocessing

Resample the SPICOR audio to the target sample rate, filter utterances by duration, and write the metadata CSV.

In [None]:
import pandas as pd
import soundfile as sf
import librosa
from pathlib import Path
from typing import List, Tuple

class SPICORDataProcessor:
    """Process SPICOR dataset for Piper training."""
    
    def __init__(self, config: Config):
        self.config = config
        self.audio_extensions = ['.wav', '.flac', '.mp3', '.ogg']
    
    def find_audio_files(self, directory: str) -> List[Path]:
        """Find all audio files in directory."""
        audio_files = []
        for ext in self.audio_extensions:
            audio_files.extend(Path(directory).rglob(f"*{ext}"))
        return sorted(audio_files)
    
    def get_audio_duration(self, audio_path: Path) -> float:
        """Get audio duration in seconds."""
        try:
            info = sf.info(str(audio_path))
            return info.duration
        except Exception:
            return 0.0
    
    def resample_audio(self, audio_path: Path, output_path: Path) -> bool:
        """Resample audio to target sample rate."""
        try:
            audio, sr = librosa.load(str(audio_path), sr=self.config.sample_rate, mono=True)
            sf.write(str(output_path), audio, self.config.sample_rate)
            return True
        except Exception as e:
            logger.error(f"Failed to resample {audio_path}: {e}")
            return False
    
    def parse_transcript(self, transcript_path: Path) -> dict:
        """Parse transcript file (adapt based on SPICOR format)."""
        transcripts = {}
        try:
            with open(transcript_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if '|' in line:
                        parts = line.split('|')
                        if len(parts) >= 2:
                            transcripts[parts[0]] = parts[1]
                    elif '\t' in line:
                        parts = line.split('\t')
                        if len(parts) >= 2:
                            transcripts[parts[0]] = parts[1]
        except Exception as e:
            logger.error(f"Failed to parse transcript: {e}")
        return transcripts
    
    def process_dataset(
        self,
        source_dir: str,
        transcript_file: Optional[str] = None,
        min_duration: float = 0.5,
        max_duration: float = 15.0,
    ) -> pd.DataFrame:
        """Process dataset and create metadata DataFrame."""
        audio_files = self.find_audio_files(source_dir)
        logger.info(f"Found {len(audio_files)} audio files")
        
        # Load transcripts if provided
        transcripts = {}
        if transcript_file and os.path.exists(transcript_file):
            transcripts = self.parse_transcript(Path(transcript_file))
        
        metadata = []
        processed_dir = Path(self.config.audio_dir)
        
        for audio_path in tqdm(audio_files, desc="Processing audio files"):
            duration = self.get_audio_duration(audio_path)
            
            # Filter by duration
            if duration < min_duration or duration > max_duration:
                continue
            
            # Get transcript
            audio_id = audio_path.stem
            transcript = transcripts.get(audio_id, transcripts.get(audio_path.name, ""))
            
            if not transcript:
                # Try to find transcript in .txt file with same name
                txt_path = audio_path.with_suffix('.txt')
                if txt_path.exists():
                    transcript = txt_path.read_text(encoding='utf-8').strip()
            
            if not transcript:
                continue
            
            # Resample and copy audio
            output_path = processed_dir / f"{audio_id}.wav"
            if not output_path.exists():
                if not self.resample_audio(audio_path, output_path):
                    continue
            
            metadata.append({
                'audio_file': f"{audio_id}.wav",
                'text': transcript,
                'duration': duration,
            })
        
        df = pd.DataFrame(metadata)
        logger.info(f"Processed {len(df)} valid utterances")
        return df
    
    def create_metadata_csv(self, df: pd.DataFrame, output_path: str) -> str:
        """Create metadata CSV in Piper format: one `audio_file|text` line per utterance."""
        csv_path = Path(output_path)
        with open(csv_path, 'w', encoding='utf-8') as f:
            for _, row in df.iterrows():
                # '|' is the column separator and each utterance must fit on one line
                text = " ".join(str(row['text']).replace('|', ' ').split())
                f.write(f"{row['audio_file']}|{text}\n")
        
        logger.info(f"Created metadata CSV: {csv_path}")
        return str(csv_path)

# Initialize processor
data_processor = SPICORDataProcessor(config)
print("Data processor initialized.")
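`process_dataset` names each output file after `audio_path.stem` and skips resampling when that file already exists. If two source files share a stem (for example `spk1/0001.wav` and `spk2/0001.wav`, or `0001.wav` and `0001.flac`), the second utterance is recorded in the metadata against the first file's audio. A quick check on the output of `find_audio_files` before processing can surface such collisions; `find_duplicate_stems` below is a hypothetical helper.

```python
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List


def find_duplicate_stems(audio_files: Iterable[Path]) -> Dict[str, List[Path]]:
    """Group audio paths by filename stem and return only stems used more than once."""
    by_stem: Dict[str, List[Path]] = defaultdict(list)
    for path in audio_files:
        by_stem[Path(path).stem].append(Path(path))
    return {stem: paths for stem, paths in by_stem.items() if len(paths) > 1}


files = [Path("spk1/0001.wav"), Path("spk2/0001.wav"), Path("spk1/0002.wav")]
print(find_duplicate_stems(files))
```

If this returns anything, one option is to derive `audio_id` from the parent folder plus the stem, and key the transcript lookup the same way.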

In [None]:
# Process SPICOR dataset
# Uncomment and adjust paths when dataset is downloaded

# SPICOR_SOURCE_DIR = f"{config.data_dir}/spicor_raw"  # Adjust to your dataset structure
# TRANSCRIPT_FILE = f"{config.data_dir}/transcripts.txt"  # Adjust path

# metadata_df = data_processor.process_dataset(
#     source_dir=SPICOR_SOURCE_DIR,
#     transcript_file=TRANSCRIPT_FILE,
#     min_duration=0.5,
#     max_duration=15.0,
# )

# metadata_csv_path = data_processor.create_metadata_csv(
#     metadata_df,
#     f"{config.data_dir}/metadata.csv"
# )

# print(f"Total utterances: {len(metadata_df)}")
# print(f"Total audio duration: {metadata_df['duration'].sum() / 3600:.2f} hours")
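Once `metadata_df` exists, it helps to check how the batch-based settings in `Config` relate to the dataset size. The arithmetic below is an estimate, not Piper's exact count: it assumes every utterance lands in the training split and the last partial batch is kept, whereas Piper holds some utterances out for validation. The key check is the last field: Lightning only accepts an integer `val_check_interval` larger than the number of batches per epoch when the Trainer is created with `check_val_every_n_epoch=None`. `estimate_schedule` is a hypothetical helper.

```python
import math


def estimate_schedule(num_utterances: int, batch_size: int,
                      val_check_interval: int, max_epochs: int) -> dict:
    """Rough batch counts for a run (assumes no held-out split, partial batches kept)."""
    batches_per_epoch = math.ceil(num_utterances / batch_size)
    total_batches = batches_per_epoch * max_epochs
    return {
        "batches_per_epoch": batches_per_epoch,
        "total_batches": total_batches,
        "validation_runs": total_batches // val_check_interval,
        # True means the Trainer needs check_val_every_n_epoch=None
        "interval_exceeds_epoch": val_check_interval > batches_per_epoch,
    }


# Example: 1,000 utterances with the Config defaults
print(estimate_schedule(1000, batch_size=16, val_check_interval=1000, max_epochs=1000))
# {'batches_per_epoch': 63, 'total_batches': 63000, 'validation_runs': 63, 'interval_exceeds_epoch': True}
```

With real data, call `estimate_schedule(len(metadata_df), config.batch_size, config.val_check_interval, config.max_epochs)`.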

## 6. Fine-Tuning with Checkpointing

In [None]:
import torch
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from datetime import datetime

class PiperTrainer:
    """Wrapper class for Piper model training with checkpointing."""
    
    def __init__(self, config: Config):
        self.config = config
        self._setup_torch()
    
    def _setup_torch(self):
        """Configure PyTorch settings."""
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.deterministic = False
        torch.manual_seed(42)
    
    def get_callbacks(self) -> list:
        """Create training callbacks."""
        callbacks = [
            # Save checkpoints frequently
            ModelCheckpoint(
                dirpath=self.config.checkpoint_dir,
                filename='piper-{epoch:04d}-{step:08d}-{val_loss:.4f}',
                save_top_k=3,
                monitor='val_loss',
                mode='min',
                every_n_train_steps=self.config.save_every_n_steps,
                save_last=True,
            ),
            # Step-based snapshots. save_top_k=-1 keeps every one, and each file
            # includes optimizer state, so watch disk usage on Colab
            ModelCheckpoint(
                dirpath=self.config.checkpoint_dir,
                filename='piper-step-{step:08d}',
                every_n_train_steps=self.config.save_every_n_steps,
                save_top_k=-1,  # Keep all
            ),
            # Learning rate monitoring
            LearningRateMonitor(logging_interval='step'),
            # Early stopping (optional)
            EarlyStopping(
                monitor='val_loss',
                patience=50,
                mode='min',
                verbose=True,
            ),
        ]
        return callbacks
    
    def get_logger(self) -> TensorBoardLogger:
        """Create TensorBoard logger."""
        return TensorBoardLogger(
            save_dir=self.config.output_dir,
            name='piper_training',
            version=datetime.now().strftime('%Y%m%d_%H%M%S'),
        )
    
    def find_latest_checkpoint(self) -> Optional[str]:
        """Find the latest checkpoint for resuming training."""
        ckpt_dir = Path(self.config.checkpoint_dir)
        
        # Look for 'last.ckpt' first
        last_ckpt = ckpt_dir / 'last.ckpt'
        if last_ckpt.exists():
            return str(last_ckpt)
        
        # Find most recent checkpoint by modification time
        checkpoints = list(ckpt_dir.glob('*.ckpt'))
        if checkpoints:
            latest = max(checkpoints, key=lambda p: p.stat().st_mtime)
            return str(latest)
        
        return None
    
    def train(
        self,
        metadata_csv: str,
        pretrained_ckpt: Optional[str] = None,
        resume_from_checkpoint: Optional[str] = None,
    ):
        """Run training with checkpointing."""
        from piper.train.vits.dataset import VitsDataModule
        from piper.train.vits.lightning import VitsModel
        
        # Determine checkpoint path
        ckpt_path = resume_from_checkpoint
        if ckpt_path is None and pretrained_ckpt:
            ckpt_path = pretrained_ckpt
        
        # Create data module
        data_module = VitsDataModule(
            csv_path=metadata_csv,
            cache_dir=self.config.cache_dir,
            espeak_voice=self.config.espeak_voice,
            config_path=f"{self.config.output_dir}/{self.config.voice_name}.json",
            voice_name=self.config.voice_name,
            sample_rate=self.config.sample_rate,
            audio_dir=self.config.audio_dir,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
        )
        
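        # Create trainer
        trainer = L.Trainer(
            max_epochs=self.config.max_epochs,
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            devices=1,
            callbacks=self.get_callbacks(),
            logger=self.get_logger(),
            # An int val_check_interval counts training batches; check_val_every_n_epoch=None
            # allows it to exceed the number of batches in one epoch
            val_check_interval=self.config.val_check_interval,
            check_val_every_n_epoch=None,
            log_every_n_steps=50,
            precision='16-mixed' if torch.cuda.is_available() else 32,
            # gradient_clip_val is omitted: VitsModel steps its generator and discriminator
            # optimizers manually, and Lightning rejects automatic clipping in that mode
        )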
        
        # Load or create model
        if ckpt_path and Path(ckpt_path).exists():
            print(f"Loading model from checkpoint: {ckpt_path}")
            model = VitsModel.load_from_checkpoint(ckpt_path)
        else:
            print("Creating new model")
            model = VitsModel(
                batch_size=self.config.batch_size,
                sample_rate=self.config.sample_rate,
            )
        
        # Start training
        print(f"Starting training for {self.config.max_epochs} epochs...")
        trainer.fit(
            model,
            data_module,
            ckpt_path=resume_from_checkpoint,  # Restores full training state (optimizer, epoch/step); None when fine-tuning
        )
        
        return trainer, model

# Initialize trainer
piper_trainer = PiperTrainer(config)
print("Piper Trainer initialized.")

In [None]:
# Start training
# Uncomment when ready to train

# METADATA_CSV = f"{config.data_dir}/metadata.csv"

# trainer, model = piper_trainer.train(
#     metadata_csv=METADATA_CSV,
#     pretrained_ckpt=pretrained_ckpt_path,
#     resume_from_checkpoint=None,  # Set to checkpoint path to resume
# )
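The training cell writes checkpoints only to `config.checkpoint_dir`, which disappears when a Colab runtime is recycled; `S3DataLoader.upload_checkpoint` is otherwise used only for the final model in Section 9. The sketch below is a hypothetical helper, not part of Piper or Lightning: it uploads any `.ckpt` files that are new or changed since the previous call, and takes the upload function as a parameter so `s3_loader.upload_checkpoint` can be passed in directly. Run it after `trainer.fit` returns or is interrupted, not while a checkpoint may still be half-written.

```python
from pathlib import Path
from typing import Callable, List, Set


def backup_new_checkpoints(
    checkpoint_dir: str,
    s3_prefix: str,
    upload: Callable[[str, str], bool],
    seen: Set[str],
) -> List[str]:
    """Upload *.ckpt files not uploaded before; return the S3 keys that succeeded.

    `seen` holds name+mtime fingerprints, so an overwritten last.ckpt is uploaded again.
    """
    uploaded: List[str] = []
    for ckpt in sorted(Path(checkpoint_dir).glob("*.ckpt")):
        fingerprint = f"{ckpt.name}:{ckpt.stat().st_mtime_ns}"
        if fingerprint in seen:
            continue
        s3_key = f"{s3_prefix.rstrip('/')}/{ckpt.name}"
        # Only mark as seen when the upload reports success, so failures are retried
        if upload(str(ckpt), s3_key):
            seen.add(fingerprint)
            uploaded.append(s3_key)
    return uploaded
```

For example, create `uploaded = set()` once, then call `backup_new_checkpoints(config.checkpoint_dir, config.s3_checkpoint_prefix, s3_loader.upload_checkpoint, uploaded)` whenever training stops.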

## 7. Resume Training from Checkpoint

In [None]:
def resume_training(config: Config, metadata_csv: str) -> tuple:
    """Resume training from the latest checkpoint."""
    trainer_instance = PiperTrainer(config)
    
    # Find latest checkpoint
    latest_ckpt = trainer_instance.find_latest_checkpoint()
    
    if latest_ckpt:
        print(f"Resuming from checkpoint: {latest_ckpt}")
        return trainer_instance.train(
            metadata_csv=metadata_csv,
            resume_from_checkpoint=latest_ckpt,
        )
    else:
        print("No checkpoint found. Starting fresh training.")
        return trainer_instance.train(
            metadata_csv=metadata_csv,
            pretrained_ckpt=download_pretrained_checkpoint(config),
        )

# Resume training
# Uncomment when ready

# METADATA_CSV = f"{config.data_dir}/metadata.csv"
# trainer, model = resume_training(config, METADATA_CSV)

## 8. Export to ONNX

Export the fine-tuned model to ONNX format for web deployment.

In [None]:
import torch
import onnx
from pathlib import Path

OPSET_VERSION = 15

class ONNXExporter:
    """Export Piper model to ONNX format."""
    
    def __init__(self, config: Config):
        self.config = config
    
    def export(
        self,
        checkpoint_path: str,
        output_path: Optional[str] = None,
    ) -> str:
        """Export model checkpoint to ONNX."""
        from piper.train.vits.lightning import VitsModel
        
        torch.manual_seed(1234)
        
        if output_path is None:
            output_path = f"{self.config.output_dir}/{self.config.voice_name}.onnx"
        
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        
        # Load model
        print(f"Loading checkpoint: {checkpoint_path}")
        model = VitsModel.load_from_checkpoint(checkpoint_path, map_location='cpu')
        model_g = model.model_g
        
        # Set to eval mode
        model_g.eval()
        
        # Remove weight normalization
        with torch.no_grad():
            model_g.dec.remove_weight_norm()
        
        # Create inference forward function
        def infer_forward(text, text_lengths, scales, sid=None):
            noise_scale = scales[0]
            length_scale = scales[1]
            noise_scale_w = scales[2]
            audio = model_g.infer(
                text,
                text_lengths,
                noise_scale=noise_scale,
                length_scale=length_scale,
                noise_scale_w=noise_scale_w,
                sid=sid,
            )[0].unsqueeze(1)
            return audio
        
        model_g.forward = infer_forward
        
        # Prepare dummy input
        num_symbols = model_g.n_vocab
        num_speakers = model_g.n_speakers
        
        dummy_input_length = 50
        sequences = torch.randint(
            low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
        )
        sequence_lengths = torch.LongTensor([sequences.size(1)])
        
        sid = None
        if num_speakers > 1:
            sid = torch.LongTensor([0])
        
        scales = torch.FloatTensor([0.667, 1.0, 0.8])
        dummy_input = (sequences, sequence_lengths, scales, sid)
        
        # Export to ONNX
        print(f"Exporting to ONNX: {output_path}")
        torch.onnx.export(
            model=model_g,
            args=dummy_input,
            f=str(output_path),
            verbose=False,
            opset_version=OPSET_VERSION,
            input_names=['input', 'input_lengths', 'scales', 'sid'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size', 1: 'phonemes'},
                'input_lengths': {0: 'batch_size'},
                'output': {0: 'batch_size', 2: 'time'},
            },
        )
        
        # Verify the exported model
        print("Verifying ONNX model...")
        onnx_model = onnx.load(str(output_path))
        onnx.checker.check_model(onnx_model)
        
        print(f"Model exported successfully to: {output_path}")
        print(f"Model size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")
        
        return str(output_path)
    
    def copy_config(self, source_config: str, output_path: Optional[str] = None):
        """Copy and rename config file to match ONNX model."""
        import shutil
        
        if output_path is None:
            output_path = f"{self.config.output_dir}/{self.config.voice_name}.onnx.json"
        
        shutil.copy(source_config, output_path)
        print(f"Config copied to: {output_path}")
        return output_path

# Initialize exporter
onnx_exporter = ONNXExporter(config)
print("ONNX Exporter initialized.")
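The exported graph takes a single `scales` input, and `infer_forward` reads it positionally: `scales[0]` is `noise_scale`, `scales[1]` is `length_scale`, and `scales[2]` is `noise_scale_w`. Getting the order wrong fails silently, so a small named wrapper can help. The class below is a hypothetical convenience, with defaults mirroring the dummy values used for export. In VITS, `length_scale` multiplies the predicted phoneme durations (values above 1.0 slow speech down), `noise_scale` scales the noise when sampling the acoustic prior, and `noise_scale_w` scales the noise in the stochastic duration predictor.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class SynthesisScales:
    """Named wrapper for the ONNX `scales` input (hypothetical helper)."""
    noise_scale: float = 0.667   # noise when sampling the acoustic prior
    length_scale: float = 1.0    # multiplier on predicted durations (>1.0 = slower)
    noise_scale_w: float = 0.8   # noise for the stochastic duration predictor

    def as_list(self) -> List[float]:
        # Order matches infer_forward: scales[0], scales[1], scales[2]
        return [self.noise_scale, self.length_scale, self.noise_scale_w]


print(SynthesisScales(length_scale=1.2).as_list())  # [0.667, 1.2, 0.8]
```

The list can be wrapped as `np.array(SynthesisScales().as_list(), dtype=np.float32)` for ONNX Runtime, or `torch.FloatTensor(...)` for the export dummy input.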

In [None]:
# Export model to ONNX
# Uncomment when training is complete

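# Pick a checkpoint: find_latest_checkpoint() returns last.ckpt (the most recent),
# which is not necessarily the one with the lowest val_loss
# best_ckpt = piper_trainer.find_latest_checkpoint()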
# Or specify manually:
# best_ckpt = f"{config.checkpoint_dir}/your-checkpoint.ckpt"

# Export to ONNX
# onnx_path = onnx_exporter.export(best_ckpt)

# Copy config file
# config_source = f"{config.output_dir}/{config.voice_name}.json"
# onnx_exporter.copy_config(config_source)
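Because the monitored `ModelCheckpoint` in Section 6 writes `val_loss` into each filename (Lightning expands `{val_loss:.4f}` to `val_loss=0.1234`), the lowest-loss checkpoint can be picked from the names alone. A minimal sketch, assuming that filename pattern; files without a `val_loss` field, such as `last.ckpt` or the step snapshots, are ignored. `find_best_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

# Assumes the Section 6 pattern, e.g. piper-epoch=0003-step=00001500-val_loss=0.4321.ckpt
_VAL_LOSS = re.compile(r"val_loss=(-?\d+(?:\.\d+)?)")


def find_best_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the .ckpt path with the lowest val_loss in its filename, or None."""
    best_path: Optional[str] = None
    best_loss = float("inf")
    for ckpt in Path(checkpoint_dir).glob("*.ckpt"):
        match = _VAL_LOSS.search(ckpt.name)
        if match is None:
            continue  # e.g. last.ckpt or step-only snapshots
        loss = float(match.group(1))
        if loss < best_loss:
            best_path, best_loss = str(ckpt), loss
    return best_path
```

A fallback chain such as `find_best_checkpoint(config.checkpoint_dir) or piper_trainer.find_latest_checkpoint()` still works before the first validation run has produced a `val_loss` checkpoint.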

## 9. Upload Final Model to S3 (Optional)

In [None]:
def upload_final_model(config: Config, s3_loader: S3DataLoader):
    """Upload final ONNX model and config to S3."""
    onnx_path = f"{config.output_dir}/{config.voice_name}.onnx"
    config_path = f"{config.output_dir}/{config.voice_name}.onnx.json"
    
    s3_prefix = f"{config.s3_checkpoint_prefix}/final"
    
    # Upload ONNX model
    s3_loader.upload_checkpoint(
        onnx_path,
        f"{s3_prefix}/{config.voice_name}.onnx"
    )
    
    # Upload config
    s3_loader.upload_checkpoint(
        config_path,
        f"{s3_prefix}/{config.voice_name}.onnx.json"
    )
    
    print(f"Model uploaded to s3://{config.s3_bucket}/{s3_prefix}/")

# Upload final model
# Uncomment when ready
# upload_final_model(config, s3_loader)

## 10. Test the Exported Model

In [None]:
import onnxruntime as ort
import numpy as np
import json

def test_onnx_model(onnx_path: str, config_path: str, text: str = "Hello, this is a test."):
    """Test the exported ONNX model."""
    # Load config
    with open(config_path, 'r') as f:
        model_config = json.load(f)
    
    phoneme_id_map = model_config['phoneme_id_map']
    
    # Create ONNX session
    session = ort.InferenceSession(onnx_path)
    
    # Smoke test only: this maps raw characters, not espeak-ng phonemes, so the
    # audio will not be intelligible. Use espeak-ng phonemization for real tests.
    test_phonemes = [phoneme_id_map[c][0] for c in text.lower() if c in phoneme_id_map]
    
    if not test_phonemes:
        test_phonemes = [1, 2, 3, 4, 5]  # Arbitrary IDs so the session still runs
    
    # Prepare inputs
    input_array = np.array([test_phonemes], dtype=np.int64)
    input_lengths = np.array([len(test_phonemes)], dtype=np.int64)
    scales = np.array([0.667, 1.0, 0.8], dtype=np.float32)
    
    # Run inference
    inputs = {
        'input': input_array,
        'input_lengths': input_lengths,
        'scales': scales,
    }
    
    output = session.run(None, inputs)
    audio = output[0]
    
    print(f"Generated audio shape: {audio.shape}")
    print(f"Audio duration: {audio.shape[-1] / model_config['audio']['sample_rate']:.2f} seconds")
    
    return audio

# Test the model
# Uncomment when ONNX export is complete

# onnx_path = f"{config.output_dir}/{config.voice_name}.onnx"
# config_path = f"{config.output_dir}/{config.voice_name}.onnx.json"
# audio = test_onnx_model(onnx_path, config_path)
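`test_onnx_model` feeds raw characters, which only proves the graph runs. Real input is espeak-ng phonemes mapped through `phoneme_id_map`. The sketch below assumes the convention used by Piper's Python inference code: a `^` (BOS) id first, each phoneme followed by a `_` (PAD) id, and a `$` (EOS) id last, with unknown phonemes skipped. The phonemes would come from espeak-ng using the training voice (`en-us` here), split into individual characters; for a quick manual look, `espeak-ng -q --ipa -v en-us "Hello"` prints an IPA transcription, although Piper's own phonemizer may format stress marks and word boundaries differently. The toy id map is invented for illustration.

```python
from typing import Dict, List, Sequence

BOS, EOS, PAD = "^", "$", "_"


def phonemes_to_ids(phonemes: Sequence[str], phoneme_id_map: Dict[str, List[int]]) -> List[int]:
    """Map phonemes to model input ids with BOS/EOS markers and interspersed padding."""
    ids = list(phoneme_id_map[BOS])
    for phoneme in phonemes:
        if phoneme not in phoneme_id_map:
            continue  # skip phonemes the model has no id for
        ids.extend(phoneme_id_map[phoneme])
        ids.extend(phoneme_id_map[PAD])
    ids.extend(phoneme_id_map[EOS])
    return ids


# Toy id map for illustration only; real ids come from the voice's .onnx.json
toy_map = {"_": [0], "^": [1], "$": [2], "h": [20], "ə": [59]}
print(phonemes_to_ids(list("həx"), toy_map))  # [1, 20, 0, 59, 0, 2]
```

The resulting list can stand in for `test_phonemes` inside `test_onnx_model`, using the `phoneme_id_map` loaded from the voice config.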

## Summary

This notebook provides a complete pipeline for:

1. **Environment Setup**: Install all dependencies for Colab/SageMaker
2. **Configuration**: Centralized config with S3 and training parameters
3. **Data ETL**: Download SPICOR dataset from S3 and process it
4. **Checkpoint Download**: Get pre-trained model from Hugging Face
5. **Data Preprocessing**: Create metadata CSV in Piper format
6. **Fine-Tuning**: Train with frequent checkpointing (every 500 steps)
7. **Resume Training**: Automatically find and resume from latest checkpoint
8. **ONNX Export**: Export to lightweight ONNX format for web deployment
9. **S3 Upload**: Backup final model to S3
10. **Testing**: Smoke-test the exported model with ONNX Runtime

### Key Features:
- Modular design with separate classes for each component
- Automatic checkpoint saving every 500 training steps
- Support for resuming training from any checkpoint
- S3 integration for dataset download and final-model upload
- ONNX export with model verification
- Compatible with both Google Colab and AWS SageMaker