# Piper TTS Fine-Tuning: Indian English Voice

This notebook provides a complete pipeline for fine-tuning the Piper TTS `en_US/hfc_female/medium` model using the IISc SPICOR English dataset to create an Indian English voice.

**Compatible with:** Google Colab and AWS SageMaker

## Overview
1. Environment Setup
2. Configuration
3. Dataset ETL from S3
4. Data Preprocessing
5. Download Pre-trained Checkpoint
6. Model Fine-Tuning with Checkpointing
7. Resume Training from Checkpoint
8. ONNX Export for Web Deployment
9. Upload Results to S3 (Optional)
10. Test Inference

## 1. Environment Setup

In [None]:
# Detect environment (Colab vs SageMaker)
import os
import sys

def detect_environment():
    """Detect if running on Colab or SageMaker."""
    if 'google.colab' in sys.modules:
        return 'colab'
    elif os.path.exists('/opt/ml'):
        return 'sagemaker'
    else:
        return 'local'

ENV = detect_environment()
print(f"Running on: {ENV}")

In [None]:
# Install system dependencies
if ENV == 'colab':
    !apt-get update -qq
    !apt-get install -y -qq build-essential cmake ninja-build espeak-ng
elif ENV == 'sagemaker':
    !sudo yum install -y espeak-ng cmake ninja-build gcc gcc-c++ make
else:
    print("Please ensure build-essential, cmake, ninja-build, and espeak-ng are installed.")

In [None]:
# Clone Piper repository (skipped if already cloned) and install
!test -d /tmp/piper1-gpl || git clone https://github.com/OHF-voice/piper1-gpl.git /tmp/piper1-gpl
%cd /tmp/piper1-gpl
!pip install -e '.[train]' --quiet
!./build_monotonic_align.sh
!python setup.py build_ext --inplace

In [None]:
# Install additional dependencies (librosa for resampling, onnx/onnxruntime for export checks)
!pip install boto3 huggingface_hub tqdm soundfile librosa onnx onnxruntime --quiet

## 2. Configuration

Configure all paths and parameters. Replace placeholder values with your actual settings.

In [None]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

@dataclass
class Config:
    """Configuration for Piper fine-tuning pipeline."""
    
    # ==================== PLACEHOLDERS - UPDATE THESE ====================
    
    # S3 Configuration
    s3_bucket: str = "YOUR_S3_BUCKET_NAME"  # e.g., "my-tts-training-bucket"
    s3_dataset_prefix: str = "datasets/spicor/"  # Path to SPICOR dataset in S3
    s3_checkpoint_prefix: str = "checkpoints/piper-indian-english/"  # Where to save checkpoints
    aws_region: str = "us-east-1"  # Your AWS region
    
    # HuggingFace Configuration
    hf_checkpoint_repo: str = "rhasspy/piper-checkpoints"
    hf_checkpoint_path: str = "en/en_US/hfc_female/medium/epoch=2868-step=1575188.ckpt"
    hf_model_repo: str = "rhasspy/piper-voices"  # Alternative: "the-vedantic-coder/text-to-speech-en-US-web"
    
    # ==================== LOCAL PATHS ====================
    
    # Base directories
    base_dir: Path = field(default_factory=lambda: Path("/tmp/piper-training"))
    
    @property
    def data_dir(self) -> Path:
        return self.base_dir / "data"
    
    @property
    def audio_dir(self) -> Path:
        return self.data_dir / "audio"
    
    @property
    def cache_dir(self) -> Path:
        return self.base_dir / "cache"
    
    @property
    def checkpoint_dir(self) -> Path:
        return self.base_dir / "checkpoints"
    
    @property
    def output_dir(self) -> Path:
        return self.base_dir / "output"
    
    @property
    def csv_path(self) -> Path:
        return self.data_dir / "metadata.csv"
    
    @property
    def config_path(self) -> Path:
        return self.output_dir / "config.json"
    
    @property
    def pretrained_checkpoint_path(self) -> Path:
        return self.checkpoint_dir / "pretrained.ckpt"
    
    # ==================== TRAINING PARAMETERS ====================
    
    # Voice settings
    voice_name: str = "en_IN-spicor-medium"
    espeak_voice: str = "en-us"  # Same eSpeak voice as the en_US base checkpoint, so phoneme IDs stay compatible
    sample_rate: int = 22050
    
    # Training hyperparameters
    batch_size: int = 16  # Adjust based on GPU memory (8-32 typical)
    learning_rate: float = 1e-4  # Lower LR for fine-tuning
    max_epochs: int = 500
    
    # Checkpointing
    checkpoint_every_n_steps: int = 500  # Save checkpoint every N steps
    checkpoint_every_n_epochs: int = 10  # Also save every N epochs
    keep_last_n_checkpoints: int = 5  # Passed as save_top_k: keeps the N best checkpoints by val_loss
    
    # Resume training
    resume_from_checkpoint: Optional[str] = None  # Path to checkpoint to resume from
    
    # Hardware
    num_workers: int = 4
    accelerator: str = "auto"  # "gpu", "cpu", or "auto"
    devices: int = 1  # Number of GPUs
    precision: str = "16-mixed"  # Use mixed precision for faster training
    
    def create_directories(self):
        """Create all necessary directories."""
        for path in [self.data_dir, self.audio_dir, self.cache_dir, 
                     self.checkpoint_dir, self.output_dir]:
            path.mkdir(parents=True, exist_ok=True)

# Initialize configuration
config = Config()
config.create_directories()
print(f"Configuration initialized. Base directory: {config.base_dir}")
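
Before running the later cells, it can help to confirm that the placeholders were actually replaced. The cell below is a minimal sketch of such a check. `MiniConfig` and `find_unset_placeholders` are hypothetical names introduced here for illustration; `MiniConfig` only mirrors two fields of `Config` so the example runs on its own.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class MiniConfig:
    """Hypothetical stand-in for the notebook's Config (two fields only)."""
    s3_bucket: str = "YOUR_S3_BUCKET_NAME"
    aws_region: str = "us-east-1"


def find_unset_placeholders(cfg) -> List[str]:
    """Return names of string fields that still start with the YOUR_ marker."""
    unset = []
    for f in fields(cfg):
        value = getattr(cfg, f.name)
        if isinstance(value, str) and value.startswith("YOUR_"):
            unset.append(f.name)
    return unset


print(find_unset_placeholders(MiniConfig()))
```

The same function should also accept the full `config` object, since it only inspects dataclass fields and skips non-string values such as `base_dir`.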

## 3. Dataset ETL from S3

Download and process the SPICOR English dataset from S3.

In [None]:
import boto3
from botocore.exceptions import ClientError
from tqdm import tqdm
import json

class S3DatasetLoader:
    """Handles dataset loading from S3."""
    
    def __init__(self, config: Config):
        self.config = config
        self.s3_client = boto3.client('s3', region_name=config.aws_region)
    
    def list_audio_files(self) -> list:
        """List all audio files in the S3 dataset prefix."""
        audio_files = []
        paginator = self.s3_client.get_paginator('list_objects_v2')
        
        for page in paginator.paginate(Bucket=self.config.s3_bucket, 
                                        Prefix=self.config.s3_dataset_prefix):
            for obj in page.get('Contents', []):
                key = obj['Key']
                if key.endswith(('.wav', '.mp3', '.flac')):
                    audio_files.append(key)
        
        return audio_files
    
    def download_file(self, s3_key: str, local_path: Path) -> bool:
        """Download a single file from S3."""
        try:
            local_path.parent.mkdir(parents=True, exist_ok=True)
            self.s3_client.download_file(
                self.config.s3_bucket, s3_key, str(local_path)
            )
            return True
        except ClientError as e:
            print(f"Error downloading {s3_key}: {e}")
            return False
    
    def download_dataset(self, max_files: Optional[int] = None) -> int:
        """Download the entire dataset from S3."""
        audio_files = self.list_audio_files()
        
        if max_files:
            audio_files = audio_files[:max_files]
        
        downloaded = 0
        for s3_key in tqdm(audio_files, desc="Downloading audio files"):
            filename = Path(s3_key).name
            local_path = self.config.audio_dir / filename
            
            if local_path.exists():
                downloaded += 1
                continue
            
            if self.download_file(s3_key, local_path):
                downloaded += 1
        
        return downloaded
    
    def download_metadata(self, metadata_key: str) -> dict:
        """Download and parse metadata JSON from S3."""
        local_path = self.config.data_dir / "metadata.json"
        self.download_file(metadata_key, local_path)
        
        with open(local_path, 'r') as f:
            return json.load(f)
    
    def upload_checkpoint(self, local_path: Path, s3_key: str) -> bool:
        """Upload a checkpoint to S3 for backup."""
        try:
            self.s3_client.upload_file(
                str(local_path), self.config.s3_bucket, s3_key
            )
            return True
        except ClientError as e:
            print(f"Error uploading checkpoint: {e}")
            return False

In [None]:
# Initialize S3 loader and download dataset
# Uncomment and run when ready to download from S3

# s3_loader = S3DatasetLoader(config)
# num_downloaded = s3_loader.download_dataset()
# print(f"Downloaded {num_downloaded} audio files")
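
`download_dataset` stores every file under its bare file name, so two S3 keys with the same name in different folders would overwrite each other. If your copy of the dataset is organised in subfolders, one option is to keep the directory structure below the prefix. The sketch below shows that mapping only; `local_path_for_key` is a hypothetical helper, not part of `S3DatasetLoader`.

```python
from pathlib import Path, PurePosixPath


def local_path_for_key(s3_key: str, prefix: str, audio_dir: Path) -> Path:
    """Map an S3 key to a local path, keeping directories below the prefix."""
    key = PurePosixPath(s3_key)
    base = PurePosixPath(prefix)
    try:
        relative = key.relative_to(base)
    except ValueError:
        # Key is outside the prefix: fall back to the bare file name.
        relative = PurePosixPath(key.name)
    return audio_dir.joinpath(*relative.parts)


print(local_path_for_key(
    "datasets/spicor/spk1/0001.wav", "datasets/spicor/", Path("/tmp/audio")
))
```

If you adopt this layout, the metadata would need to reference the same relative paths rather than bare file names.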

## 4. Data Preprocessing

Process the SPICOR dataset into the format required by Piper.

In [None]:
import csv
import soundfile as sf
import librosa
from pathlib import Path
from typing import List, Tuple
import re

class DataPreprocessor:
    """Preprocesses audio data for Piper training."""
    
    def __init__(self, config: Config):
        self.config = config
    
    def parse_spicor_metadata(self, metadata_path: Path) -> List[Tuple[str, str]]:
        """
        Parse SPICOR dataset metadata.
        
        SPICOR format varies - adjust this method based on actual format.
        Expected: Returns list of (audio_filename, transcript) tuples.
        """
        entries = []
        
        # Handle different possible metadata formats
        if metadata_path.suffix == '.json':
            import json
            with open(metadata_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                for item in data:
                    audio_file = item.get('audio', item.get('file', item.get('path')))
                    text = item.get('text', item.get('transcript', item.get('sentence')))
                    if audio_file and text:
                        entries.append((audio_file, text))
        
        elif metadata_path.suffix == '.csv':
            with open(metadata_path, 'r', encoding='utf-8') as f:
                reader = csv.reader(f, delimiter='|')
                for row in reader:
                    if len(row) >= 2:
                        entries.append((row[0], row[-1]))
        
        elif metadata_path.suffix == '.txt':
            with open(metadata_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('|')
                    if len(parts) >= 2:
                        entries.append((parts[0], parts[-1]))
        
        return entries
    
    def normalize_text(self, text: str) -> str:
        """Normalize text for TTS training."""
        # Remove extra whitespace
        text = ' '.join(text.split())
        # Basic normalization
        text = text.strip()
        return text
    
    def validate_audio(self, audio_path: Path, min_duration: float = 0.5, 
                       max_duration: float = 15.0) -> bool:
        """Validate audio file for training."""
        try:
            info = sf.info(str(audio_path))
            duration = info.duration
            return min_duration <= duration <= max_duration
        except Exception as e:
            print(f"Error validating {audio_path}: {e}")
            return False
    
    def resample_audio(self, input_path: Path, output_path: Path, 
                       target_sr: int = 22050) -> bool:
        """Resample audio to target sample rate."""
        try:
            audio, sr = librosa.load(str(input_path), sr=target_sr, mono=True)
            sf.write(str(output_path), audio, target_sr)
            return True
        except Exception as e:
            print(f"Error resampling {input_path}: {e}")
            return False
    
    def create_metadata_csv(self, entries: List[Tuple[str, str]]) -> int:
        """
        Create Piper-compatible metadata CSV.
        
        Format: audio_filename|text
        """
        valid_entries = 0
        
        with open(self.config.csv_path, 'w', encoding='utf-8', newline='') as f:
            writer = csv.writer(f, delimiter='|')
            
            for audio_file, text in tqdm(entries, desc="Processing entries"):
                audio_path = self.config.audio_dir / audio_file
                
                # Check if audio file exists
                if not audio_path.exists():
                    # Try with .wav extension
                    audio_path = self.config.audio_dir / f"{audio_file}.wav"
                    if not audio_path.exists():
                        continue
                
                # Validate audio
                if not self.validate_audio(audio_path):
                    continue
                
                # Normalize text
                normalized_text = self.normalize_text(text)
                if not normalized_text:
                    continue
                
                # Write entry
                writer.writerow([audio_path.name, normalized_text])
                valid_entries += 1
        
        return valid_entries
    
    def process_dataset(self, metadata_path: Path) -> int:
        """Full preprocessing pipeline."""
        print("Parsing metadata...")
        entries = self.parse_spicor_metadata(metadata_path)
        print(f"Found {len(entries)} entries in metadata")
        
        print("Creating metadata CSV...")
        valid_entries = self.create_metadata_csv(entries)
        print(f"Created metadata CSV with {valid_entries} valid entries")
        
        return valid_entries

In [None]:
# Run preprocessing (point this at your actual SPICOR metadata file)
# preprocessor = DataPreprocessor(config)
# num_entries = preprocessor.process_dataset(config.data_dir / "spicor_metadata.json")
# print(f"Processed {num_entries} training samples")
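
After preprocessing, a quick sanity check of the generated file can catch formatting problems before training starts. The sketch below assumes the two-column `audio_filename|text` layout written by `create_metadata_csv`; `find_bad_rows` is a hypothetical helper name.

```python
import csv
import io
from typing import List, Tuple


def find_bad_rows(csv_text: str) -> List[Tuple[int, str]]:
    """Return (row_number, reason) for rows that are not `filename|text`."""
    problems = []
    reader = csv.reader(io.StringIO(csv_text), delimiter="|")
    for row_number, row in enumerate(reader, start=1):
        if len(row) != 2:
            problems.append((row_number, f"expected 2 columns, got {len(row)}"))
        elif not row[0].strip() or not row[1].strip():
            problems.append((row_number, "empty filename or text"))
    return problems


sample = "a.wav|Hello there.\nb.wav|\nc.wav|one|two\n"
for row_number, reason in find_bad_rows(sample):
    print(f"row {row_number}: {reason}")
```

To check the real file, pass `config.csv_path.read_text(encoding="utf-8")` instead of `sample`.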

## 5. Download Pre-trained Checkpoint

Download the base checkpoint from HuggingFace for fine-tuning.

In [None]:
from huggingface_hub import hf_hub_download, snapshot_download

class CheckpointManager:
    """Manages model checkpoints."""
    
    def __init__(self, config: Config):
        self.config = config
    
    def download_pretrained_checkpoint(self) -> Path:
        """Download pre-trained checkpoint from HuggingFace."""
        print(f"Downloading checkpoint from {self.config.hf_checkpoint_repo}...")
        
        # local_dir_use_symlinks is deprecated in recent huggingface_hub releases, so it is omitted
        checkpoint_path = hf_hub_download(
            repo_id=self.config.hf_checkpoint_repo,
            filename=self.config.hf_checkpoint_path,
            repo_type="dataset",
            local_dir=self.config.checkpoint_dir,
        )
        
        # Copy to standard location
        import shutil
        shutil.copy(checkpoint_path, self.config.pretrained_checkpoint_path)
        
        print(f"Checkpoint saved to: {self.config.pretrained_checkpoint_path}")
        return self.config.pretrained_checkpoint_path
    
    def get_latest_checkpoint(self) -> Optional[Path]:
        """Find the latest training checkpoint."""
        checkpoints = list(self.config.checkpoint_dir.glob("epoch=*-step=*.ckpt"))
        
        if not checkpoints:
            return None
        
        # Sort by modification time
        checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
        return checkpoints[0]
    
    def cleanup_old_checkpoints(self, keep_n: int = 5):
        """Remove old checkpoints, keeping the N most recent."""
        checkpoints = list(self.config.checkpoint_dir.glob("epoch=*-step=*.ckpt"))
        checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
        
        for ckpt in checkpoints[keep_n:]:
            print(f"Removing old checkpoint: {ckpt.name}")
            ckpt.unlink()

In [None]:
# Download pre-trained checkpoint
checkpoint_manager = CheckpointManager(config)
pretrained_ckpt = checkpoint_manager.download_pretrained_checkpoint()
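
The checkpoint file name encodes how far the original run had progressed. This matters because Lightning ends `fit()` once the restored epoch counter has reached `max_epochs`, so resuming the full training state from `epoch=2868-...` with `max_epochs=500` would finish without training. The sketch below parses the name to flag that case; both function names are hypothetical helpers.

```python
import re
from typing import Optional, Tuple

# Matches Lightning-style names such as "epoch=2868-step=1575188.ckpt"
_CKPT_PATTERN = re.compile(r"epoch=(\d+)-step=(\d+)")


def parse_epoch_step(checkpoint_name: str) -> Optional[Tuple[int, int]]:
    """Extract (epoch, step) from a checkpoint file name, or None if absent."""
    match = _CKPT_PATTERN.search(checkpoint_name)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def resume_would_stop_immediately(checkpoint_name: str, max_epochs: int) -> bool:
    """True if the epoch in the name has already reached max_epochs."""
    parsed = parse_epoch_step(checkpoint_name)
    if parsed is None:
        return False  # Unknown naming scheme: nothing can be inferred.
    epoch, _step = parsed
    return epoch >= max_epochs


name = "epoch=2868-step=1575188.ckpt"
print(parse_epoch_step(name), resume_would_stop_immediately(name, 500))
```

Two ways around this are to load only the weights for fine-tuning, or to raise `max_epochs` above the checkpoint's epoch.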

## 6. Model Fine-Tuning with Checkpointing

Fine-tune the Piper model with frequent checkpointing for fault tolerance.

In [None]:
import torch
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from pathlib import Path
import sys

# Add piper to path
sys.path.insert(0, '/tmp/piper1-gpl/src')

from piper.train.vits.lightning import VitsModel
from piper.train.vits.dataset import VitsDataModule

class PiperTrainer:
    """Handles Piper model training with checkpointing."""
    
    def __init__(self, config: Config):
        self.config = config
        self.trainer = None
        self.model = None
        self.data_module = None
    
    def setup_data_module(self) -> VitsDataModule:
        """Initialize the data module."""
        self.data_module = VitsDataModule(
            csv_path=self.config.csv_path,
            cache_dir=self.config.cache_dir,
            espeak_voice=self.config.espeak_voice,
            config_path=self.config.config_path,
            voice_name=self.config.voice_name,
            sample_rate=self.config.sample_rate,
            audio_dir=self.config.audio_dir,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
        )
        return self.data_module
    
    def setup_callbacks(self) -> list:
        """Setup training callbacks for checkpointing."""
        callbacks = [
            # Save checkpoint every N steps
            ModelCheckpoint(
                dirpath=self.config.checkpoint_dir,
                filename="{epoch}-{step}-{val_loss:.4f}",
                save_top_k=self.config.keep_last_n_checkpoints,
                monitor="val_loss",
                mode="min",
                every_n_train_steps=self.config.checkpoint_every_n_steps,
                save_last=True,
            ),
            # Save checkpoint every N epochs
            ModelCheckpoint(
                dirpath=self.config.checkpoint_dir,
                # Lightning expands {epoch} to "epoch=N", giving e.g. "periodic-epoch=10.ckpt"
                filename="periodic-{epoch}",
                save_top_k=-1,  # Keep all epoch checkpoints
                every_n_epochs=self.config.checkpoint_every_n_epochs,
            ),
            # Monitor learning rate
            LearningRateMonitor(logging_interval="step"),
            # Early stopping (optional)
            EarlyStopping(
                monitor="val_loss",
                patience=50,
                mode="min",
                verbose=True,
            ),
        ]
        return callbacks
    
    def setup_trainer(self, resume_checkpoint: Optional[Path] = None) -> L.Trainer:
        """Setup PyTorch Lightning trainer."""
        logger = TensorBoardLogger(
            save_dir=self.config.output_dir,
            name="piper_finetune",
            version="indian_english"
        )
        
        self.trainer = L.Trainer(
            accelerator=self.config.accelerator,
            devices=self.config.devices,
            precision=self.config.precision,
            max_epochs=self.config.max_epochs,
            callbacks=self.setup_callbacks(),
            logger=logger,
            log_every_n_steps=10,
            val_check_interval=0.25,  # Validate 4 times per epoch
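            # gradient_clip_val is omitted: Lightning rejects Trainer-level clipping under
            # manual optimization, which two-optimizer (GAN-style) models such as VITS
            # are assumed to use. Re-add it only if VitsModel uses automatic optimization.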
            enable_progress_bar=True,
        )
        
        return self.trainer
    
    def load_model(self, checkpoint_path: Optional[Path] = None) -> VitsModel:
        """Load model from checkpoint or create new."""
        if checkpoint_path and checkpoint_path.exists():
            print(f"Loading model from checkpoint: {checkpoint_path}")
            self.model = VitsModel.load_from_checkpoint(
                checkpoint_path,
                map_location="cpu",
                learning_rate=self.config.learning_rate,
            )
        else:
            print("Creating new model...")
            self.model = VitsModel(
                batch_size=self.config.batch_size,
                sample_rate=self.config.sample_rate,
                learning_rate=self.config.learning_rate,
            )
        
        return self.model
    
    def train(self, resume_from: Optional[Path] = None):
        """
        Run training with optional resume from checkpoint.
        
        Args:
            resume_from: Path to checkpoint to resume from.
                        If None, starts fresh from pretrained checkpoint.
        """
        # Setup data
        self.setup_data_module()
        
        # Setup trainer
        self.setup_trainer()
        
        # Determine checkpoint to use
        ckpt_path = None
        if resume_from and resume_from.exists():
            print(f"Resuming training from: {resume_from}")
            ckpt_path = str(resume_from)
            self.model = self.load_model(resume_from)
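        elif self.config.pretrained_checkpoint_path.exists():
            print(f"Fine-tuning from pretrained: {self.config.pretrained_checkpoint_path}")
            # Load weights only and leave ckpt_path as None. Passing ckpt_path to fit()
            # would also restore the pretrained run's epoch counter (2868 for the default
            # checkpoint), which already exceeds max_epochs, so training would end at once.
            self.model = self.load_model(self.config.pretrained_checkpoint_path)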
        else:
            print("Starting training from scratch...")
            self.model = self.load_model()
        
        # Run training
        print("Starting training...")
        self.trainer.fit(
            model=self.model,
            datamodule=self.data_module,
            ckpt_path=ckpt_path,
        )
        
        print("Training complete!")
        return self.trainer.checkpoint_callback.best_model_path

In [None]:
# Initialize trainer
piper_trainer = PiperTrainer(config)

# Option 1: Start fresh fine-tuning from pretrained checkpoint
# best_checkpoint = piper_trainer.train()

# Option 2: Resume from a specific checkpoint
# resume_ckpt = config.checkpoint_dir / "epoch=50-step=10000-val_loss=0.1234.ckpt"
# best_checkpoint = piper_trainer.train(resume_from=resume_ckpt)

# Option 3: Auto-resume from latest checkpoint
# latest_ckpt = checkpoint_manager.get_latest_checkpoint()
# best_checkpoint = piper_trainer.train(resume_from=latest_ckpt)

## 7. Resume Training from Checkpoint

Utility to easily resume training from an existing checkpoint.

In [None]:
def resume_training(config: Config, checkpoint_path: Optional[str] = None):
    """
    Resume training from a checkpoint.
    
    Args:
        config: Training configuration
        checkpoint_path: Explicit path to checkpoint, or None to auto-detect latest
    """
    checkpoint_manager = CheckpointManager(config)
    
    if checkpoint_path:
        resume_ckpt = Path(checkpoint_path)
    else:
        # Find latest checkpoint
        resume_ckpt = checkpoint_manager.get_latest_checkpoint()
    
    if resume_ckpt is None:
        # No training checkpoint yet. Passing None lets train() fall back to the
        # pretrained checkpoint instead of treating it as a run to resume.
        print("No checkpoint found. Starting fresh training.")
    else:
        print(f"Resuming from: {resume_ckpt}")
    # Create trainer and resume
    trainer = PiperTrainer(config)
    best_checkpoint = trainer.train(resume_from=resume_ckpt)
    
    return best_checkpoint

# Example usage:
# best_ckpt = resume_training(config)
# Or with explicit checkpoint:
# best_ckpt = resume_training(config, "/path/to/checkpoint.ckpt")

## 8. ONNX Export for Web Deployment

Export the trained model to ONNX format for lightweight web deployment.

In [None]:
import torch
from typing import Optional

class ONNXExporter:
    """Exports Piper models to ONNX format."""
    
    OPSET_VERSION = 15
    
    def __init__(self, config: Config):
        self.config = config
    
    def export(self, checkpoint_path: Path, output_path: Optional[Path] = None) -> Path:
        """
        Export model checkpoint to ONNX format.
        
        Args:
            checkpoint_path: Path to the trained .ckpt file
            output_path: Optional output path for .onnx file
        
        Returns:
            Path to the exported ONNX model
        """
        if output_path is None:
            output_path = self.config.output_dir / f"{self.config.voice_name}.onnx"
        
        output_path.parent.mkdir(parents=True, exist_ok=True)
        
        print(f"Loading checkpoint: {checkpoint_path}")
        model = VitsModel.load_from_checkpoint(checkpoint_path, map_location="cpu")
        model_g = model.model_g
        
        # Set to inference mode
        model_g.eval()
        
        # Remove weight normalization for inference
        with torch.no_grad():
            model_g.dec.remove_weight_norm()
        
        # Create inference forward function
        def infer_forward(text, text_lengths, scales, sid=None):
            noise_scale = scales[0]
            length_scale = scales[1]
            noise_scale_w = scales[2]
            audio = model_g.infer(
                text,
                text_lengths,
                noise_scale=noise_scale,
                length_scale=length_scale,
                noise_scale_w=noise_scale_w,
                sid=sid,
            )[0].unsqueeze(1)
            return audio
        
        model_g.forward = infer_forward
        
        # Prepare dummy inputs
        num_symbols = model_g.n_vocab
        num_speakers = model_g.n_speakers
        
        dummy_input_length = 50
        sequences = torch.randint(
            low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
        )
        sequence_lengths = torch.LongTensor([sequences.size(1)])
        
        sid = None
        if num_speakers > 1:
            sid = torch.LongTensor([0])
        
        scales = torch.FloatTensor([0.667, 1.0, 0.8])  # noise, length, noise_w
        dummy_input = (sequences, sequence_lengths, scales, sid)
        
        # Export to ONNX
        print(f"Exporting to ONNX: {output_path}")
        torch.onnx.export(
            model=model_g,
            args=dummy_input,
            f=str(output_path),
            verbose=False,
            opset_version=self.OPSET_VERSION,
            input_names=["input", "input_lengths", "scales", "sid"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size", 1: "phonemes"},
                "input_lengths": {0: "batch_size"},
                "output": {0: "batch_size", 2: "time"},
            },
        )
        
        print(f"Successfully exported model to: {output_path}")
        
        # Also copy the config file
        config_output = output_path.with_suffix(".onnx.json")
        if self.config.config_path.exists():
            import shutil
            shutil.copy(self.config.config_path, config_output)
            print(f"Copied config to: {config_output}")
        
        return output_path
    
    def verify_onnx(self, onnx_path: Path) -> bool:
        """Verify the exported ONNX model."""
        try:
            import onnx
            import onnxruntime as ort
            
            # Load and check model
            model = onnx.load(str(onnx_path))
            onnx.checker.check_model(model)
            
            # Test inference
            session = ort.InferenceSession(str(onnx_path))
            print("ONNX model verified successfully!")
            print(f"Input names: {[i.name for i in session.get_inputs()]}")
            print(f"Output names: {[o.name for o in session.get_outputs()]}")
            
            return True
        except Exception as e:
            print(f"ONNX verification failed: {e}")
            return False

In [None]:
# Export the best checkpoint to ONNX
# exporter = ONNXExporter(config)

# Use the best checkpoint from training
# best_ckpt = Path(best_checkpoint)  # From training step
# Or specify a checkpoint path
# best_ckpt = config.checkpoint_dir / "epoch=100-step=20000-val_loss=0.0500.ckpt"

# onnx_path = exporter.export(best_ckpt)
# exporter.verify_onnx(onnx_path)

## 9. Upload Results to S3 (Optional)

Upload the trained model and checkpoints to S3 for backup.

In [None]:
def upload_results_to_s3(config: Config, s3_loader: S3DatasetLoader):
    """Upload trained model and checkpoints to S3."""
    
    # Upload ONNX model
    onnx_path = config.output_dir / f"{config.voice_name}.onnx"
    if onnx_path.exists():
        s3_key = f"{config.s3_checkpoint_prefix}models/{onnx_path.name}"
        s3_loader.upload_checkpoint(onnx_path, s3_key)
        print(f"Uploaded ONNX model to s3://{config.s3_bucket}/{s3_key}")
    
    # Upload config
    config_path = config.output_dir / f"{config.voice_name}.onnx.json"
    if config_path.exists():
        s3_key = f"{config.s3_checkpoint_prefix}models/{config_path.name}"
        s3_loader.upload_checkpoint(config_path, s3_key)
        print(f"Uploaded config to s3://{config.s3_bucket}/{s3_key}")
    
    # Upload latest checkpoint
    checkpoint_manager = CheckpointManager(config)
    latest_ckpt = checkpoint_manager.get_latest_checkpoint()
    if latest_ckpt:
        s3_key = f"{config.s3_checkpoint_prefix}checkpoints/{latest_ckpt.name}"
        s3_loader.upload_checkpoint(latest_ckpt, s3_key)
        print(f"Uploaded checkpoint to s3://{config.s3_bucket}/{s3_key}")

# Example usage:
# upload_results_to_s3(config, s3_loader)

## 10. Test Inference

Test the exported ONNX model with sample text.

In [None]:
import sys
from pathlib import Path

import numpy as np
from IPython.display import Audio

def test_onnx_inference(onnx_path: Path, text: str, config: Config):
    """Test inference with the exported ONNX model."""
    import onnxruntime as ort
    import json
    
    # Add piper to path for phonemization
    sys.path.insert(0, '/tmp/piper1-gpl/src')
    from piper.phonemize_espeak import EspeakPhonemizer
    from piper.phoneme_ids import phonemes_to_ids
    
    # Load config
    config_path = onnx_path.with_suffix(".onnx.json")
    with open(config_path, 'r') as f:
        model_config = json.load(f)
    
    # Phonemize text
    phonemizer = EspeakPhonemizer()
    phonemes = phonemizer.phonemize(config.espeak_voice, text)
    
    # Convert to IDs
    phoneme_ids = []
    for sentence_phonemes in phonemes:
        phoneme_ids.extend(phonemes_to_ids(sentence_phonemes))
    
    # Prepare inputs
    input_ids = np.array([phoneme_ids], dtype=np.int64)
    input_lengths = np.array([len(phoneme_ids)], dtype=np.int64)
    scales = np.array([0.667, 1.0, 0.8], dtype=np.float32)  # noise, length, noise_w
    
    # Run inference
    session = ort.InferenceSession(str(onnx_path))
    
    inputs = {
        "input": input_ids,
        "input_lengths": input_lengths,
        "scales": scales,
    }
    
    # Add speaker ID if multi-speaker
    if model_config.get("num_speakers", 1) > 1:
        inputs["sid"] = np.array([0], dtype=np.int64)
    
    output = session.run(None, inputs)[0]
    
    # Convert to audio
    audio = output.squeeze()
    
    return audio, config.sample_rate

# Example usage:
# onnx_path = config.output_dir / f"{config.voice_name}.onnx"
# audio, sr = test_onnx_inference(onnx_path, "Hello, this is a test of Indian English voice.", config)
# Audio(audio, rate=sr)
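
To listen to the result outside the notebook, the float samples can be saved as a 16-bit PCM WAV file. The sketch below uses only the standard library and assumes mono samples in the range [-1, 1]; `write_wav_int16` is a hypothetical helper, and `soundfile.write` would be an equally valid choice.

```python
import struct
import wave
from typing import Iterable


def write_wav_int16(path: str, samples: Iterable[float], sample_rate: int) -> int:
    """Write mono float samples in [-1, 1] as 16-bit PCM; return frame count."""
    frames = bytearray()
    count = 0
    for sample in samples:
        # Clip out-of-range values so they cannot overflow int16.
        clipped = max(-1.0, min(1.0, float(sample)))
        frames += struct.pack("<h", int(round(clipped * 32767)))
        count += 1
    with wave.open(path, "wb") as wav:
        wav.setnchannels(1)      # mono
        wav.setsampwidth(2)      # 2 bytes = 16 bits per sample
        wav.setframerate(sample_rate)
        wav.writeframes(bytes(frames))
    return count
```

For example, `write_wav_int16("/tmp/test.wav", audio, sr)` would save the array returned by `test_onnx_inference`.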

## Summary

This notebook provides a complete pipeline for:

1. **Environment Setup**: Compatible with both Google Colab and AWS SageMaker
2. **Configuration**: Modular configuration with placeholders for S3 bucket, paths, etc.
3. **Dataset ETL**: Download SPICOR dataset from S3
4. **Data Preprocessing**: Convert to Piper-compatible format
5. **Checkpoint Management**: Download pretrained checkpoint from HuggingFace
6. **Fine-Tuning**: Train with frequent checkpointing (every N steps/epochs)
7. **Resume Training**: Easily resume from any checkpoint
8. **ONNX Export**: Export to lightweight format for web deployment
9. **S3 Backup**: Upload results to S3 for persistence
10. **Inference Testing**: Verify the exported model

### Next Steps

1. Update the configuration placeholders with your actual S3 bucket and paths
2. Upload your SPICOR dataset to S3
3. Run the cells in order to train and export your model
4. Use the exported `.onnx` and `.onnx.json` files for web deployment