# Piper TTS Fine-Tuning: Indian English Voice

This notebook provides a complete pipeline for fine-tuning the Piper TTS model.

**Supports two modes:**
- **COLAB mode (`COLAB = True`)**: Uses Google Drive for data storage (Proof of Concept)
- **Production mode (`COLAB = False`)**: Uses AWS S3 for data storage (Production-ready)

**Compatible with:** Google Colab, AWS SageMaker

## Overview
1. Environment Configuration (COLAB/AWS mode selection)
2. Environment Setup & GPU Check
3. Install Software Dependencies
4. Data ETL (Extract, Transform, Load)
5. Training Configuration & Execution
6. Model Export (ONNX)
7. Save & Test Model
8. Troubleshooting (Optional)

---

**Repository:** https://github.com/Vinit-source/piper1-gpl

# 🔧 **1. Environment Configuration** 🔧

Set the `COLAB` constant to select between:
- `COLAB = True`: Google Drive mode (Proof of Concept)
- `COLAB = False`: AWS S3 mode (Production)
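
Instead of editing the flag by hand, the mode could be inferred from the runtime. This is a minimal sketch. It assumes that the presence of `google.colab` in `sys.modules` is a reliable Colab signal, which is the same heuristic `detect_runtime_environment()` uses in the next cell. `infer_colab_flag` is a hypothetical helper and is not part of Piper:

```python
import sys
from typing import Mapping, Optional


def infer_colab_flag(modules: Optional[Mapping[str, object]] = None) -> bool:
    """Return True when running inside Google Colab (hypothetical helper).

    Assumption: Colab imports the `google.colab` package at kernel start-up,
    so its presence in the module table is treated as the Colab signal.
    """
    if modules is None:
        modules = sys.modules
    return "google.colab" in modules


# Example: COLAB = infer_colab_flag()
print(infer_colab_flag({}))                      # False
print(infer_colab_flag({"google.colab": None}))  # True
```

An explicit flag is still the safer choice when a Colab notebook is deliberately pointed at S3.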

In [None]:
# =============================================================================
# ENVIRONMENT MODE SELECTION
# =============================================================================
# Set COLAB = True for Google Drive mode (Proof of Concept)
# Set COLAB = False for AWS S3 mode (Production)

COLAB: bool = True  # Toggle between Colab (Google Drive) and AWS (S3) mode


In [None]:
# =============================================================================
# CONFIGURATION DATACLASS
# =============================================================================
# Centralized configuration for all pipeline parameters.
# Every placeholder value (e.g. "<MODEL_NAME>") must be replaced; validate() raises an error instead of falling back to a default.

from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import os
import sys

@dataclass
class PiperConfig:
    """
    Configuration for Piper TTS fine-tuning pipeline.
    
    Attributes:
        colab_mode: Whether running in Colab (Google Drive) or AWS (S3) mode
        
    Google Drive Configuration (COLAB = True):
        gdrive_dataset_path: Path to dataset folder in Google Drive
        gdrive_output_path: Path to save outputs in Google Drive
        
    AWS S3 Configuration (COLAB = False):
        s3_bucket: S3 bucket name
        s3_dataset_prefix: S3 prefix for dataset
        s3_checkpoint_prefix: S3 prefix for checkpoints
        aws_region: AWS region
        aws_access_key_id: AWS access key (optional, uses IAM role if not set)
        aws_secret_access_key: AWS secret key (optional, uses IAM role if not set)
        
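    Model Configuration:
        model_name: Name for the output model
        espeak_voice: eSpeak voice for phonemization
        sample_rate: Audio sample rate in Hz
        num_speakers: Number of speakers (1 for a single-speaker dataset)
        
    Training Configuration:
        batch_size: Training batch size
        max_epochs: Maximum training epochs
        validation_split: Fraction of data for validation
        num_test_examples: Number of test examples for audio generation
        learning_rate: Learning rate for fine-tuning
        precision: Training precision (e.g., "16-mixed", "32")
        checkpoint_epochs: Save checkpoint every N epochs
        device: Training device ("cpu" or "gpu")
        use_pretrained: Whether to use pretrained checkpoint
        resume_training: Whether to resume from existing checkpoint
        
    Local Paths (set automatically in __post_init__):
        base_dir, piper_dir, local_dataset_dir, local_wavs_dir,
        local_cache_dir, local_output_dir
        
    Hugging Face Checkpoint:
        hf_checkpoint_repo: Repository hosting the pretrained checkpoint
        hf_checkpoint_path: Checkpoint file path within the repository
        hf_config_path: Matching config.json path within the repository
    """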
    
    # Environment mode
    colab_mode: bool = True
    
    # ==================== GOOGLE DRIVE CONFIGURATION ====================
    # Required when COLAB = True
    gdrive_dataset_path: str = "<GDRIVE_DATASET_PATH>"  # e.g., "/content/drive/MyDrive/Piper-POC-Training/"
    gdrive_output_path: str = "<GDRIVE_OUTPUT_PATH>"    # e.g., "/content/drive/MyDrive/Piper-POC-Training/output"
    
    # ==================== AWS S3 CONFIGURATION ====================
    # Required when COLAB = False
    s3_bucket: str = "<YOUR_S3_BUCKET_NAME>"           # e.g., "my-tts-training-bucket"
    s3_dataset_prefix: str = "<S3_DATASET_PATH>"       # e.g., "datasets/spicor"
    s3_checkpoint_prefix: str = "<S3_CHECKPOINT_PATH>" # e.g., "checkpoints/piper"
    aws_region: str = "<AWS_REGION>"                   # e.g., "us-east-1"
    aws_access_key_id: Optional[str] = None            # Leave None to use IAM role
    aws_secret_access_key: Optional[str] = None        # Leave None to use IAM role
    
    # ==================== MODEL CONFIGURATION ====================
    model_name: str = "<MODEL_NAME>"                   # e.g., "en_IN-spicor-medium"
    espeak_voice: str = "en-us"                        # eSpeak voice for phonemization
    sample_rate: int = 22050                           # Audio sample rate
    num_speakers: int = 1                              # Number of speakers (1 for single speaker)
    
    # ==================== TRAINING CONFIGURATION ====================
    batch_size: int = 8                                # Reduce if out of memory
    max_epochs: int = 4000                             # Maximum training epochs
    validation_split: float = 0.0                      # Validation split (0.0 to disable)
    num_test_examples: int = 0                         # Test examples for audio generation
    learning_rate: float = 1e-4                        # Learning rate
    precision: str = "16-mixed"                        # Training precision
    checkpoint_epochs: int = 200                       # Save checkpoint every N epochs
    device: str = "gpu"                                # "cpu" or "gpu"
    use_pretrained: bool = True                        # Use pretrained checkpoint
    resume_training: bool = False                      # Resume from existing checkpoint
    
    # ==================== LOCAL PATHS (Auto-configured) ====================
    # Derived in __post_init__; init=False keeps them out of the constructor.
    base_dir: str = field(default="", init=False)
    local_dataset_dir: str = field(default="", init=False)
    local_wavs_dir: str = field(default="", init=False)
    local_cache_dir: str = field(default="", init=False)
    local_output_dir: str = field(default="", init=False)
    piper_dir: str = field(default="", init=False)
    
    # ==================== HUGGING FACE CHECKPOINT ====================
    hf_checkpoint_repo: str = "rhasspy/piper-checkpoints"
    hf_checkpoint_path: str = "en/en_US/ljspeech/high/ljspeech-2000.ckpt"
    hf_config_path: str = "en/en_US/ljspeech/high/config.json"
    
    def __post_init__(self):
        """Initialize derived paths based on environment mode."""
        if self.colab_mode:
            self.base_dir = "/content"
            self.piper_dir = "/content/piper1-gpl"
        else:
            self.base_dir = "./piper_training"
            self.piper_dir = "./piper1-gpl"
        
        self.local_dataset_dir = f"{self.base_dir}/dataset"
        self.local_wavs_dir = f"{self.local_dataset_dir}/wavs"
        self.local_cache_dir = f"{self.base_dir}/audio_cache"
        self.local_output_dir = f"{self.base_dir}/output/{self.model_name}"
    
    def validate(self) -> None:
        """
        Validate configuration and raise errors for missing required fields.
        No fallback values - all placeholders must be replaced.
        """
        errors = []
        
        # Validate model name
        if self.model_name == "<MODEL_NAME>" or not self.model_name:
            errors.append("model_name: Must be set (e.g., 'en_IN-spicor-medium')")
        
        if self.colab_mode:
            # Validate Google Drive configuration
            if self.gdrive_dataset_path == "<GDRIVE_DATASET_PATH>" or not self.gdrive_dataset_path:
                errors.append("gdrive_dataset_path: Must be set for COLAB mode")
            if self.gdrive_output_path == "<GDRIVE_OUTPUT_PATH>" or not self.gdrive_output_path:
                errors.append("gdrive_output_path: Must be set for COLAB mode")
        else:
            # Validate AWS S3 configuration
            if self.s3_bucket == "<YOUR_S3_BUCKET_NAME>" or not self.s3_bucket:
                errors.append("s3_bucket: Must be set for AWS mode")
            if self.s3_dataset_prefix == "<S3_DATASET_PATH>" or not self.s3_dataset_prefix:
                errors.append("s3_dataset_prefix: Must be set for AWS mode")
            if self.s3_checkpoint_prefix == "<S3_CHECKPOINT_PATH>" or not self.s3_checkpoint_prefix:
                errors.append("s3_checkpoint_prefix: Must be set for AWS mode")
            if self.aws_region == "<AWS_REGION>" or not self.aws_region:
                errors.append("aws_region: Must be set for AWS mode")
        
        # Validate training parameters
        if self.batch_size <= 0:
            errors.append("batch_size: Must be positive")
        if self.max_epochs <= 0:
            errors.append("max_epochs: Must be positive")
        if self.sample_rate <= 0:
            errors.append("sample_rate: Must be positive")
        if self.device not in ("cpu", "gpu"):
            errors.append("device: Must be 'cpu' or 'gpu'")
        
        if errors:
            error_msg = "Configuration validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            raise ValueError(error_msg)
    
    def create_directories(self) -> None:
        """Create all required local directories."""
        dirs = [
            self.local_dataset_dir,
            self.local_wavs_dir,
            self.local_cache_dir,
            self.local_output_dir,
        ]
        for d in dirs:
            Path(d).mkdir(parents=True, exist_ok=True)
            print(f"Created directory: {d}")
        
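        # Create the Google Drive output directory only once Drive is mounted.
        # Creating it earlier would put files under /content/drive, and
        # drive.mount() refuses a mountpoint that already contains files.
        if self.colab_mode:
            if os.path.isdir("/content/drive/MyDrive"):
                Path(self.gdrive_output_path).mkdir(parents=True, exist_ok=True)
                print(f"Created Google Drive output directory: {self.gdrive_output_path}")
            else:
                print("Google Drive is not mounted yet; skipping the Drive output directory for now.")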


def detect_runtime_environment() -> str:
    """
    Detect the current runtime environment.
    
    Returns:
        str: 'colab', 'sagemaker', or 'local'
    """
    if 'google.colab' in sys.modules:
        return 'colab'
    elif os.environ.get('SM_CURRENT_HOST'):
        return 'sagemaker'
    return 'local'


# Display detected environment
RUNTIME_ENV = detect_runtime_environment()
print(f"Detected runtime environment: {RUNTIME_ENV}")
print(f"Mode: {'Google Colab (Google Drive)' if COLAB else 'Production (AWS S3)'}")
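
`validate()` checks each placeholder field by name. A more generic alternative scans every string field for the `<UPPER_CASE>` placeholder pattern, so that fields added later are covered automatically. This is a sketch that assumes all placeholders follow that convention. `find_unset_placeholders` and `DemoConfig` are hypothetical:

```python
import re
from dataclasses import dataclass, fields
from typing import List, Optional

# Matches placeholder values such as "<MODEL_NAME>" or "<YOUR_S3_BUCKET_NAME>"
PLACEHOLDER_RE = re.compile(r"^<[A-Z0-9_]+>$")


def find_unset_placeholders(cfg, skip: Optional[List[str]] = None) -> List[str]:
    """Return the names of string fields that still hold a <PLACEHOLDER> value."""
    skip_set = set(skip or [])
    unset = []
    for f in fields(cfg):
        value = getattr(cfg, f.name)
        if f.name in skip_set:
            continue
        if isinstance(value, str) and PLACEHOLDER_RE.match(value):
            unset.append(f.name)
    return unset


@dataclass
class DemoConfig:  # hypothetical stand-in for PiperConfig
    model_name: str = "<MODEL_NAME>"
    s3_bucket: str = "<YOUR_S3_BUCKET_NAME>"
    espeak_voice: str = "en-us"


cfg = DemoConfig(model_name="en_IN-spicor-medium")
print(find_unset_placeholders(cfg))                      # ['s3_bucket']
print(find_unset_placeholders(cfg, skip=["s3_bucket"]))  # []
```

With `PiperConfig`, you would pass the S3 field names in `skip` when `colab_mode` is true, and pass the Drive field names otherwise.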

In [None]:
# =============================================================================
# INITIALIZE CONFIGURATION
# =============================================================================
# Update the configuration values below before running the pipeline.
# All placeholder values (e.g., "<MODEL_NAME>") must be replaced.

config = PiperConfig(
    colab_mode=COLAB,
    
    # ----- Google Drive Configuration (for COLAB = True) -----
    gdrive_dataset_path="/content/drive/MyDrive/Piper-POC-Training/",
    gdrive_output_path="/content/drive/MyDrive/Piper-POC-Training/output",
    
    # ----- AWS S3 Configuration (for COLAB = False) -----
    s3_bucket="<YOUR_S3_BUCKET_NAME>",
    s3_dataset_prefix="<S3_DATASET_PATH>",
    s3_checkpoint_prefix="<S3_CHECKPOINT_PATH>",
    aws_region="<AWS_REGION>",
    aws_access_key_id=None,  # Set to None to use IAM role
    aws_secret_access_key=None,
    
    # ----- Model Configuration -----
    model_name="en_IN-spicor-medium",
    espeak_voice="en-us",
    sample_rate=22050,
    num_speakers=1,  # 1 for single speaker
    
    # ----- Training Configuration -----
    batch_size=8,
    max_epochs=4000,
    validation_split=0.0,
    num_test_examples=0,
    learning_rate=1e-4,
    precision="16-mixed",
    checkpoint_epochs=200,
    device="gpu",
    use_pretrained=True,
    resume_training=False,
)

# Validate configuration - raises error if any required field is missing
config.validate()

# Create local directories
config.create_directories()

# Display configuration summary
print("\n" + "=" * 60)
print("CONFIGURATION SUMMARY")
print("=" * 60)
print(f"Mode: {'COLAB (Google Drive)' if config.colab_mode else 'AWS (S3)'}")
print(f"Model name: {config.model_name}")
print(f"Sample rate: {config.sample_rate}")
print(f"Batch size: {config.batch_size}")
print(f"Max epochs: {config.max_epochs}")
print(f"Device: {config.device}")
print(f"Use pretrained: {config.use_pretrained}")
print(f"Local output: {config.local_output_dir}")
print("=" * 60)
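
The pretrained checkpoint is identified by `hf_checkpoint_repo` together with `hf_checkpoint_path`. The sketch below assembles a direct download URL. It assumes that `rhasspy/piper-checkpoints` is a Hugging Face *dataset* repository and that the `main` revision is wanted. `checkpoint_url` is a hypothetical helper. `huggingface_hub.hf_hub_download(..., repo_type="dataset")` is an alternative that also caches the file:

```python
from urllib.parse import quote


def checkpoint_url(repo: str, path: str, revision: str = "main") -> str:
    """Build a Hugging Face "resolve" URL for a file in a dataset repository."""
    # quote() keeps "/" separators but escapes characters such as spaces
    return f"https://huggingface.co/datasets/{repo}/resolve/{revision}/{quote(path)}"


url = checkpoint_url("rhasspy/piper-checkpoints", "en/en_US/ljspeech/high/ljspeech-2000.ckpt")
print(url)
# https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/en/en_US/ljspeech/high/ljspeech-2000.ckpt
```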

# 🖥️ **2. Environment Setup** 🖥️

Set up the runtime environment: check the GPU, mount Google Drive, and enable the Colab anti-disconnect helper. The last two steps run only in COLAB mode.

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
# =============================================================================
# GPU CHECK
# =============================================================================
# Check the available GPU. A more capable GPU trains faster.
# The default Colab GPU is a Tesla T4.

print("Checking GPU availability...")
!nvidia-smi

import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"\n✅ GPU available: {gpu_name} ({gpu_memory:.1f} GB)")
else:
    print("\n⚠️ No GPU detected. Training will be slow on CPU.")
    if config.device == "gpu":
        raise RuntimeError("Configuration specifies 'gpu' but no GPU is available. Set config.device='cpu' or use a GPU-enabled runtime.")
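
The default `precision="16-mixed"` is a CUDA setting. Recent PyTorch Lightning releases reject `16-mixed` on a CPU accelerator and suggest `bf16-mixed` instead. Switching `device` to `"cpu"` without also changing `precision` is therefore likely to fail when the `Trainer` is built. The following sketch shows a guard for this case, assuming Lightning-style precision strings. `resolve_training_precision` is hypothetical:

```python
def resolve_training_precision(device: str, precision: str, cuda_available: bool) -> str:
    """Return a precision string compatible with the device actually available.

    Assumption: Lightning-style strings ("32", "16-mixed", "bf16-mixed").
    fp16 mixed precision needs CUDA, so on CPU it falls back to bf16-mixed.
    """
    if device not in ("cpu", "gpu"):
        raise ValueError(f"Unknown device: {device!r}")
    if device == "gpu" and not cuda_available:
        raise RuntimeError("device='gpu' requested but CUDA is not available")
    if device == "cpu" and precision == "16-mixed":
        return "bf16-mixed"
    return precision


print(resolve_training_precision("gpu", "16-mixed", cuda_available=True))   # 16-mixed
print(resolve_training_precision("cpu", "16-mixed", cuda_available=False))  # bf16-mixed
```

In the notebook, this would be called as `config.precision = resolve_training_precision(config.device, config.precision, torch.cuda.is_available())`. On CPUs without fast bfloat16 support, falling back to `"32"` is the more conservative choice.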

Mounted at /content/drive


In [None]:
# =============================================================================
# MOUNT GOOGLE DRIVE (COLAB MODE ONLY)
# =============================================================================
# Mount Google Drive to access dataset and save outputs.
# This cell only runs in COLAB mode.

if COLAB:
    print("Mounting Google Drive...")
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
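    print("✅ Google Drive mounted successfully.")
    
    # Verify dataset path exists
    if not os.path.exists(config.gdrive_dataset_path):
        raise FileNotFoundError(
            f"Dataset path not found in Google Drive: {config.gdrive_dataset_path}\n"
            "Please verify the path exists and contains your dataset."
        )
    print(f"✅ Dataset path verified: {config.gdrive_dataset_path}")
    
    # Create the Drive output directory now that Drive is mounted
    Path(config.gdrive_output_path).mkdir(parents=True, exist_ok=True)
    print(f"✅ Google Drive output directory ready: {config.gdrive_output_path}")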
else:
    print("Skipping Google Drive mount (AWS mode).")

In [None]:
# =============================================================================
# COLAB ANTI-DISCONNECT (COLAB MODE ONLY)
# =============================================================================
# Prevents automatic disconnection in Google Colab.
# Note: Colab will still disconnect after 6-12 hours regardless.

if COLAB:
    from IPython.display import Javascript, display
    # The selector depends on Colab's current UI and may stop matching after UI changes
    js_code = '''
    function ClickConnect(){
        console.log("Anti-disconnect: clicking connect button");
        document.querySelector("colab-toolbar-button#connect")?.click();
    }
    setInterval(ClickConnect, 60000)
    '''
    display(Javascript(js_code))
    print("✅ Anti-disconnect script activated.")
else:
    print("Skipping anti-disconnect (not running in Colab).")

# 📦 **3. Install Software Dependencies** 📦

Install Piper TTS and all required dependencies for training.

In [None]:
# =============================================================================
# INSTALL SYSTEM DEPENDENCIES
# =============================================================================
# Install required system packages for audio processing and building native extensions.

print("Installing system dependencies...")
!apt-get update -qq
!apt-get install -y -qq espeak-ng build-essential cmake ninja-build libespeak-ng1 libespeak-ng-dev
print("✅ System dependencies installed.")

W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Selecting previously unselected package libpcaudio0:amd64.
(Reading database ... 121713 files and directories currently installed.)
Preparing to unpack .../0-libpcaudio0_1.1-6build2_amd64.deb ...
Unpacking libpcaudio0:amd64 (1.1-6build2) ...
Selecting previously unselected package libsonic0:amd64.
Preparing to unpack .../1-libsonic0_0.2.0-11build1_amd64.deb ...
Unpacking libsonic0:amd64 (0.2.0-11build1) ...
Selecting previously unselected package espeak-ng-data:amd64.
Preparing to unpack .../2-espeak-ng-data_1.50+dfsg-10ubuntu0.1_amd64.deb ...
Unpacking espeak-ng-data:amd64 (1.50+dfsg-10ubuntu0.1) ...
Selecting previously unselected package libespeak-ng1:amd64.
Preparing to unpack .../3-libespeak-ng1_1.50+dfsg-10ubuntu0.1_amd64.deb ...
Unpacking libespeak-ng1:amd64 (1.50+dfsg-10ubuntu0.1) ...
Selecti

In [None]:
# =============================================================================
# CLONE AND INSTALL PIPER TTS
# =============================================================================
# Clone the piper1-gpl repository and install with training dependencies.

import os

PIPER_REPO_URL = "https://github.com/Vinit-source/piper1-gpl.git"
PIPER_BRANCH = "main"
PIPER_COMMIT = "fee9b9cefae4ebf9e196cfe994dea418f051506c"  # Stable release commit

# Clone repository if not exists
if not os.path.exists(config.piper_dir):
    print(f"Cloning Piper repository from {PIPER_REPO_URL}...")
    !git clone -b {PIPER_BRANCH} {PIPER_REPO_URL} {config.piper_dir}
else:
    print(f"Repository already exists at {config.piper_dir}")

# Change to piper directory
%cd {config.piper_dir}

# Checkout specific commit for reproducibility
print(f"Checking out commit {PIPER_COMMIT}...")
!git checkout -b release0.3.1 {PIPER_COMMIT} 2>/dev/null || git checkout release0.3.1

# Uninstall previous installation to ensure clean rebuild
print("Removing any previous Piper TTS installation...")
!pip uninstall -y piper-tts 2>/dev/null || true

# Install Piper with training dependencies
print("Installing Piper TTS with training dependencies...")
!pip install -e ".[train]"

# Build monotonic alignment module (required for VITS training)
print("Building monotonic alignment module...")
!bash build_monotonic_align.sh

print("\n✅ Piper TTS installed successfully!")

/content/piper1-gpl
Collecting scikit-build
  Using cached scikit_build-0.18.1-py3-none-any.whl.metadata (18 kB)
Using cached scikit_build-0.18.1-py3-none-any.whl (85 kB)
Installing collected packages: scikit-build
Successfully installed scikit-build-0.18.1


--------------------------------------------------------------------------------
-- Trying 'Ninja' generator
--------------------------------
---------------------------
----------------------
-----------------
------------
-------
--
Not searching for unused variables given on the command line.
  Compatibility with CMake < 3.10 will be removed from a future version of
  CMake.

  Update the VERSION argument <min> value.  Or, use the <min>...<max> syntax
  to tell CMake that the project requires at least <min> but has been updated
  to work with policies introduced by <max> or earlier.

-- The C compiler identification is GNU 11.4.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working 

In [None]:
# =============================================================================
# BUILD NATIVE EXTENSIONS
# =============================================================================
# Build the eSpeak bridge and other native extensions.

%cd {config.piper_dir}

print("Installing scikit-build...")
!pip install scikit-build

print("Building native extensions...")
!python3 setup.py build_ext --inplace

print("Installing additional dependencies...")
!pip install onnxscript

print("\n✅ Native extensions built successfully!")

Collecting onnxscript
  Downloading onnxscript-0.5.6-py3-none-any.whl.metadata (13 kB)
Collecting onnx_ir<2,>=0.1.12 (from onnxscript)
  Downloading onnx_ir-0.1.12-py3-none-any.whl.metadata (3.2 kB)
Downloading onnxscript-0.5.6-py3-none-any.whl (683 kB)
Downloading onnx_ir-0.1.12-py3-none-any.whl (129 kB)
Installing collected packages: onnx_ir, onnxscript
Successfully installed onnx_ir-0.1.12 onnxscript-0.5.6


# 📥 **4. Data ETL (Extract, Transform, Load)** 📥

Download and process the dataset. Supports both Google Drive (COLAB mode) and AWS S3 (Production mode).
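
Both backends implement the same two-method interface, `download_dataset` and `upload_checkpoint`. Supporting another storage target therefore only requires a new subclass. The sketch below mirrors that pattern using a plain local or network-mounted folder. `StorageBackend`, `LocalFolderBackend`, and the constructor arguments are hypothetical, and the sketch is deliberately independent of `PiperConfig`:

```python
import shutil
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path


class StorageBackend(ABC):
    """Same two-method contract as DataLoaderBase."""

    @abstractmethod
    def download_dataset(self) -> int: ...

    @abstractmethod
    def upload_checkpoint(self, local_path: str, remote_key: str) -> bool: ...


class LocalFolderBackend(StorageBackend):
    """Hypothetical backend that reads from and writes to ordinary folders."""

    def __init__(self, source_dir: str, local_wavs_dir: str, output_dir: str):
        self.source_dir = Path(source_dir)
        self.local_wavs_dir = Path(local_wavs_dir)
        self.output_dir = Path(output_dir)

    def download_dataset(self) -> int:
        # Copy every WAV from <source_dir>/wavs into the local wavs directory
        self.local_wavs_dir.mkdir(parents=True, exist_ok=True)
        count = 0
        for wav in sorted((self.source_dir / "wavs").glob("*.wav")):
            shutil.copy2(wav, self.local_wavs_dir / wav.name)
            count += 1
        return count

    def upload_checkpoint(self, local_path: str, remote_key: str) -> bool:
        dest = self.output_dir / remote_key
        try:
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(local_path, dest)
            return True
        except OSError:  # missing source file, permission error, full disk, ...
            return False


# Usage on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    src_wavs = Path(tmp, "src", "wavs")
    src_wavs.mkdir(parents=True)
    for name in ("1.wav", "2.wav", "notes.txt"):
        (src_wavs / name).write_bytes(b"RIFF")
    backend = LocalFolderBackend(
        source_dir=str(Path(tmp, "src")),
        local_wavs_dir=str(Path(tmp, "local", "wavs")),
        output_dir=str(Path(tmp, "out")),
    )
    copied_count = backend.download_dataset()  # notes.txt is ignored
    ckpt = Path(tmp, "model.ckpt")
    ckpt.write_bytes(b"weights")
    upload_ok = backend.upload_checkpoint(str(ckpt), "epoch=200/model.ckpt")
    uploaded_bytes = Path(tmp, "out", "epoch=200", "model.ckpt").read_bytes()
    missing_ok = backend.upload_checkpoint(str(Path(tmp, "missing.ckpt")), "x.ckpt")

print(copied_count, upload_ok, uploaded_bytes, missing_ok)  # 2 True b'weights' False
```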

In [None]:
# =============================================================================
# DATA LOADER CLASSES
# =============================================================================
# Abstract data loading with support for both Google Drive and S3.

import os
import wave
import zipfile
import datetime
import shutil
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Tuple, List, Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DataLoaderBase(ABC):
    """Abstract base class for data loading operations."""
    
    def __init__(self, config: PiperConfig):
        self.config = config
    
    @abstractmethod
    def download_dataset(self) -> int:
        """Download dataset to local directory. Returns number of files downloaded."""
        pass
    
    @abstractmethod
    def upload_checkpoint(self, local_path: str, remote_key: str) -> bool:
        """Upload checkpoint to remote storage. Returns success status."""
        pass


class GoogleDriveDataLoader(DataLoaderBase):
    """Data loader for Google Drive (COLAB mode)."""
    
    def __init__(self, config: PiperConfig):
        super().__init__(config)
        if not config.colab_mode:
            raise ValueError("GoogleDriveDataLoader requires COLAB mode (config.colab_mode=True)")
    
    def download_dataset(self) -> int:
        """
        Copy dataset from Google Drive to local directory.
        
        Expected folder structure in Google Drive:
            your_dataset_folder/
            ├── wavs/           # Audio files (WAV format, 22050 Hz recommended)
            │   ├── 1.wav
            │   ├── 2.wav
            │   └── ...
            └── metadata.csv    # Transcript file (format: wavs/filename.wav|text)
        
        Returns:
            int: Number of audio files copied
        """
        gdrive_path = self.config.gdrive_dataset_path.strip()
        
        # Validate source path exists
        if not os.path.exists(gdrive_path):
            raise FileNotFoundError(f"Dataset folder not found in Google Drive: {gdrive_path}")
        
        # Check for wavs folder or zip file
        gdrive_wavs_path = os.path.join(gdrive_path, "wavs")
        gdrive_wavs_zip = os.path.join(gdrive_path, "wavs.zip")
        
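        if os.path.exists(gdrive_wavs_zip):
            logger.info("Found wavs.zip, extracting...")
            with zipfile.ZipFile(gdrive_wavs_zip, 'r') as zip_ref:
                zip_ref.extractall(self.config.local_dataset_dir)
            
            # Flatten a nested wavs/wavs folder if the archive contained one
            nested_wavs = os.path.join(self.config.local_wavs_dir, "wavs")
            if os.path.isdir(nested_wavs):
                for f in os.listdir(nested_wavs):
                    shutil.move(os.path.join(nested_wavs, f), os.path.join(self.config.local_wavs_dir, f))
                os.rmdir(nested_wavs)
            
            # A flat archive (1.wav, 2.wav, ... at the top level) extracts next to wavs/; move those files in
            for f in os.listdir(self.config.local_dataset_dir):
                if f.endswith(".wav"):
                    shutil.move(
                        os.path.join(self.config.local_dataset_dir, f),
                        os.path.join(self.config.local_wavs_dir, f),
                    )
                    
        elif os.path.exists(gdrive_wavs_path):
            logger.info("Found wavs folder, copying...")
            # Use shell copy for better performance with large datasets; fail loudly on error
            exit_code = os.system(f'cp -r "{gdrive_wavs_path}"/* "{self.config.local_wavs_dir}/"')
            if exit_code != 0:
                raise RuntimeError(f"Copying {gdrive_wavs_path} failed with exit status {exit_code}")
        else:
            raise FileNotFoundError(
                f"No 'wavs' folder or 'wavs.zip' found in {gdrive_path}\n"
                "Expected structure:\n"
                "  your_dataset_folder/\n"
                "  ├── wavs/\n"
                "  │   ├── 1.wav\n"
                "  │   └── ...\n"
                "  └── metadata.csv"
            )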
        
        # Clean up macOS ghost files
        self._cleanup_ghost_files(self.config.local_wavs_dir)
        
        # Count files
        audio_count = len([f for f in os.listdir(self.config.local_wavs_dir) if f.endswith('.wav')])
        logger.info(f"Copied {audio_count} audio files from Google Drive")
        
        return audio_count
    
    def upload_checkpoint(self, local_path: str, remote_key: str) -> bool:
        """
        Copy checkpoint to Google Drive.
        
        Args:
            local_path: Path to local checkpoint file
            remote_key: Relative path within gdrive_output_path
            
        Returns:
            bool: Success status
        """
        try:
            dest_path = os.path.join(self.config.gdrive_output_path, remote_key)
            Path(dest_path).parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(local_path, dest_path)
            logger.info(f"Uploaded checkpoint to Google Drive: {dest_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to upload checkpoint: {e}")
            return False
    
    def _cleanup_ghost_files(self, directory: str) -> None:
        """Delete macOS ghost files (._*) from directory."""
        count = 0
        for filename in os.listdir(directory):
            if filename.startswith("._"):
                file_path = os.path.join(directory, filename)
                os.remove(file_path)
                count += 1
        if count > 0:
            logger.info(f"Cleaned up {count} macOS artifact files.")


class S3DataLoader(DataLoaderBase):
    """Data loader for AWS S3 (Production mode)."""
    
    def __init__(self, config: PiperConfig):
        super().__init__(config)
        if config.colab_mode:
            raise ValueError("S3DataLoader requires AWS mode (config.colab_mode=False)")
        
        self.s3_client = self._create_s3_client()
    
    def _create_s3_client(self):
        """Create S3 client with optional credentials."""
        try:
            import boto3
        except ImportError as e:
            raise ImportError("boto3 is required for S3 operations. Install with: pip install boto3") from e
        
        kwargs = {'region_name': self.config.aws_region}
        if self.config.aws_access_key_id and self.config.aws_secret_access_key:
            kwargs['aws_access_key_id'] = self.config.aws_access_key_id
            kwargs['aws_secret_access_key'] = self.config.aws_secret_access_key
        
        return boto3.client('s3', **kwargs)
    
    def _list_objects(self, prefix: str) -> List[str]:
        """List all objects under a prefix."""
        objects = []
        paginator = self.s3_client.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=self.config.s3_bucket, Prefix=prefix):
            if 'Contents' in page:
                objects.extend([obj['Key'] for obj in page['Contents']])
        return objects
    
    def _download_file(self, s3_key: str, local_path: str) -> bool:
        """Download a single file from S3."""
        from botocore.exceptions import ClientError
        try:
            Path(local_path).parent.mkdir(parents=True, exist_ok=True)
            self.s3_client.download_file(self.config.s3_bucket, s3_key, local_path)
            return True
        except ClientError as e:
            logger.error(f"Failed to download {s3_key}: {e}")
            return False
    
    def download_dataset(self) -> int:
        """
        Download dataset from S3 to local directory.
        
        Returns:
            int: Number of files downloaded
        """
        from tqdm import tqdm
        
        # Treat the prefix as a folder so "datasets/spicor" does not also match "datasets/spicor2/..."
        prefix = self.config.s3_dataset_prefix.rstrip('/') + '/'
        objects = self._list_objects(prefix)
        
        if not objects:
            raise FileNotFoundError(
                f"No objects found in s3://{self.config.s3_bucket}/{prefix}\n"
                "Please verify the S3 bucket and prefix are correct."
            )
        
        downloaded = 0
        for s3_key in tqdm(objects, desc="Downloading dataset from S3"):
            # Skip zero-byte "folder" marker keys (e.g. created by the S3 console)
            if s3_key.endswith('/'):
                continue
            relative_path = s3_key[len(prefix):]
            local_path = os.path.join(self.config.local_dataset_dir, relative_path)
            if self._download_file(s3_key, local_path):
                downloaded += 1
        
        logger.info(f"Downloaded {downloaded} files from S3")
        return downloaded
    
    def upload_checkpoint(self, local_path: str, remote_key: str) -> bool:
        """
        Upload checkpoint to S3.
        
        Args:
            local_path: Path to local checkpoint file
            remote_key: S3 key (relative to s3_checkpoint_prefix)
            
        Returns:
            bool: Success status
        """
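        from boto3.exceptions import S3UploadFailedError
        from botocore.exceptions import ClientError
        try:
            s3_key = f"{self.config.s3_checkpoint_prefix.rstrip('/')}/{remote_key}"
            self.s3_client.upload_file(local_path, self.config.s3_bucket, s3_key)
            logger.info(f"Uploaded checkpoint to s3://{self.config.s3_bucket}/{s3_key}")
            return True
        except (ClientError, S3UploadFailedError, OSError) as e:
            # upload_file wraps most service errors in S3UploadFailedError;
            # OSError covers a missing or unreadable local file
            logger.error(f"Failed to upload {local_path}: {e}")
            return False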


def get_data_loader(config: PiperConfig) -> DataLoaderBase:
    """
    Factory function to get the appropriate data loader based on configuration.
    
    Args:
        config: PiperConfig instance
        
    Returns:
        DataLoaderBase: Appropriate data loader instance
    """
    if config.colab_mode:
        return GoogleDriveDataLoader(config)
    else:
        return S3DataLoader(config)


# Initialize data loader
data_loader = get_data_loader(config)
print(f"✅ Data loader initialized: {type(data_loader).__name__}")

Dataset path: /content/drive/MyDrive/Piper-POC-Training/
Output path: /content/drive/MyDrive/Piper-POC-Training/output
Local dataset dir: /content/dataset
Local output dir: /content/output/en_IN-spicor-medium
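
`S3DataLoader.download_dataset` maps each S3 key to a local path by stripping the dataset prefix. Below is a standalone sketch of that mapping. It assumes the prefix should be treated as a folder, so that `datasets/spicor` does not also match `datasets/spicor2/...`. It also assumes that zero-byte "folder" marker keys ending in `/` should be skipped. `relative_dataset_path` is hypothetical:

```python
from typing import Optional


def relative_dataset_path(s3_key: str, prefix: str) -> Optional[str]:
    """Map an S3 key under `prefix` to a path relative to the dataset root.

    Returns None for keys outside the prefix and for folder-marker keys.
    """
    folder = prefix.rstrip("/") + "/"
    if not s3_key.startswith(folder) or s3_key.endswith("/"):
        return None
    return s3_key[len(folder):]


keys = [
    "datasets/spicor/metadata.csv",
    "datasets/spicor/wavs/1.wav",
    "datasets/spicor/wavs/",
    "datasets/spicor2/wavs/1.wav",
]
print([relative_dataset_path(k, "datasets/spicor") for k in keys])
# ['metadata.csv', 'wavs/1.wav', None, None]
```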


In [None]:
# =============================================================================
# DOWNLOAD DATASET
# =============================================================================
# Download/copy dataset from remote storage to local directory.

print("Downloading dataset...")
audio_count = data_loader.download_dataset()
print(f"\n✅ Dataset loaded: {audio_count} audio files")
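
Before training, it can help to confirm two things: every transcript row points at an existing WAV file, and every WAV matches `config.sample_rate`. This sketch assumes the `wavs/filename.wav|text` metadata format described in `GoogleDriveDataLoader.download_dataset`. It also assumes uncompressed PCM WAVs that the standard `wave` module can read. `check_dataset` and `_write_silence` are hypothetical:

```python
import os
import tempfile
import wave
from typing import Dict, List


def check_dataset(dataset_dir: str, expected_rate: int,
                  metadata_name: str = "metadata.csv") -> Dict[str, List[str]]:
    """Report transcript rows without audio and WAVs with the wrong sample rate."""
    problems: Dict[str, List[str]] = {"missing_audio": [], "wrong_rate": []}

    # 1. Every metadata row must reference an existing file
    with open(os.path.join(dataset_dir, metadata_name), encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            audio_ref = line.split("|", 1)[0]
            if not os.path.isfile(os.path.join(dataset_dir, audio_ref)):
                problems["missing_audio"].append(audio_ref)

    # 2. Every WAV must use the sample rate the model is configured for
    wavs_dir = os.path.join(dataset_dir, "wavs")
    for name in sorted(os.listdir(wavs_dir)):
        if name.endswith(".wav"):
            with wave.open(os.path.join(wavs_dir, name), "rb") as wf:
                rate = wf.getframerate()
            if rate != expected_rate:
                problems["wrong_rate"].append(f"{name} ({rate} Hz)")
    return problems


def _write_silence(path: str, rate: int, seconds: float = 0.1) -> None:
    """Write a short mono 16-bit silent WAV (test fixture only)."""
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(rate)
        wf.writeframes(b"\x00\x00" * int(rate * seconds))


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "wavs"))
    _write_silence(os.path.join(root, "wavs", "1.wav"), 22050)
    _write_silence(os.path.join(root, "wavs", "2.wav"), 16000)
    with open(os.path.join(root, "metadata.csv"), "w", encoding="utf-8") as fh:
        fh.write("wavs/1.wav|Hello there.\nwavs/3.wav|Missing file.\n")
    report = check_dataset(root, expected_rate=22050)

print(report)
# {'missing_audio': ['wavs/3.wav'], 'wrong_rate': ['2.wav (16000 Hz)']}
```

In the notebook, pass a directory that contains both `metadata.csv` and `wavs/`, together with `config.sample_rate`.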

In [None]:
# =============================================================================
# DATASET STATISTICS UTILITY
# =============================================================================
# Calculate and display dataset statistics.

import wave
import datetime

def get_dataset_duration(wav_path: str) -> Tuple[int, str]:
    """
    Calculate total duration of all WAV files in a directory.
    
    Args:
        wav_path: Path to directory containing WAV files
        
    Returns:
        Tuple[int, str]: (count of files, formatted duration string)
    """
    total_duration = 0.0
    total_count = 0
    
    if not os.path.exists(wav_path):
        raise FileNotFoundError(f"WAV directory not found: {wav_path}")
    
    for file_name in os.listdir(wav_path):
        if not file_name.endswith(".wav"):
            continue
        
        full_path = os.path.join(wav_path, file_name)
        try:
            with wave.open(full_path, "rb") as wave_file:
                frames = wave_file.getnframes()
                rate = wave_file.getframerate()
                duration = frames / float(rate)
                total_duration += duration
                total_count += 1
        except Exception as e:
            logger.warning(f"Skipping bad file {file_name}: {e}")
            continue
    
    duration_str = str(datetime.timedelta(seconds=round(total_duration, 0)))
    return total_count, duration_str


# Calculate and display dataset statistics
audio_count, dataset_duration = get_dataset_duration(config.local_wavs_dir)
print(f"\n📊 Dataset Statistics:")
print(f"   Audio files: {audio_count}")
print(f"   Total duration: {dataset_duration}")

Found wavs.zip, extracting...

✅ Dataset loaded: 10 audio files, total duration: 0:01:24

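For reference, the duration calculation above reduces to frames divided by frame rate, summed over files and formatted with `datetime.timedelta`. The following is a minimal, stdlib-only sketch of the same idea that can be run without the dataset. `write_silent_wav` is a hypothetical helper that only exists to generate test input, and matching `.wav` case-insensitively is an assumption that differs slightly from the notebook's `endswith(".wav")` check.

```python
import datetime
import os
import tempfile
import wave


def write_silent_wav(path: str, seconds: float, rate: int = 8000) -> None:
    """Write a mono, 16-bit WAV file of silence (hypothetical helper for testing only)."""
    n_frames = int(seconds * rate)
    with wave.open(path, "wb") as wav_out:
        wav_out.setnchannels(1)
        wav_out.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav_out.setframerate(rate)
        wav_out.writeframes(b"\x00\x00" * n_frames)


def wav_seconds(path: str) -> float:
    """Duration of one file: frame count divided by frames per second."""
    with wave.open(path, "rb") as wav_in:
        return wav_in.getnframes() / float(wav_in.getframerate())


def summarize_wavs(wav_dir: str) -> tuple[int, str]:
    """Return (file count, total duration as H:MM:SS) for the WAV files in wav_dir."""
    total, count = 0.0, 0
    for name in sorted(os.listdir(wav_dir)):
        if name.lower().endswith(".wav"):  # assumption: case-insensitive extension
            total += wav_seconds(os.path.join(wav_dir, name))
            count += 1
    return count, str(datetime.timedelta(seconds=round(total)))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        write_silent_wav(os.path.join(tmp, "a.wav"), 30)
        write_silent_wav(os.path.join(tmp, "b.WAV"), 54)
        print(summarize_wavs(tmp))  # (2, '0:01:24')
```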

In [None]:
# =============================================================================
# TEXT NORMALIZATION AND TRANSCRIPT PROCESSING
# =============================================================================
# Process transcript file with text normalization protocols.

import re
import pandas as pd


class TranscriptProcessor:
    """
    Process transcript files with text normalization.
    
    Applies normalization protocols:
    1. Orthographic expansion of non-standard words (currency, percentages, symbols)
    2. Acronym handling (spacing out consecutive capitals)
    3. Punctuation and prosodic boundaries (ensure terminal punctuation)
    """
    
    # Supported transcript file names
    TRANSCRIPT_FILES = ["metadata.csv", "transcripts.txt", "transcript.txt", "metadata.txt"]
    
    @staticmethod
    def clean_line(text: str) -> str:
        """
        Apply text normalization protocols to a single line of text.
        
        Args:
            text: Input text line
            
        Returns:
            str: Normalized text
        """
        # --- Protocol 1: Orthographic Expansion ---
        # Currency: $50 -> 50 dollars
        text = re.sub(r'\$(\d+(?:\.\d+)?)', r'\1 dollars', text)
        
        # Percentages: 50% -> 50 percent
        text = re.sub(r'(\d+)%', r'\1 percent', text)
        
        # Ampersand: & -> and
        text = text.replace('&', ' and ')
        
        # Plus sign: + -> plus
        text = text.replace('+', ' plus ')
        
        # --- Protocol 2: Acronym Handling ---
        # Space out consecutive capitals: "IAS" -> "I A S"
        def space_acronym(match):
            return " ".join(match.group(1))
        text = re.sub(r'\b([A-Z]{2,})\b', space_acronym, text)
        
        # --- Protocol 3: Punctuation and Prosodic Boundaries ---
        text = text.strip()
        if text and text[-1] not in ['.', '!', '?']:
            text += '.'
        
        return text
    
    @classmethod
    def process_file(cls, input_filename: str, output_filename: str) -> int:
        """
        Process transcript file: read, normalize each line, and write to output.
        
        Args:
            input_filename: Path to input transcript file
            output_filename: Path to output processed file
            
        Returns:
            int: Number of lines processed
            
        Raises:
            FileNotFoundError: If input file doesn't exist
            ValueError: If file contains no valid entries
        """
        if not os.path.exists(input_filename):
            raise FileNotFoundError(f"Transcript file not found: {input_filename}")
        
        lines_processed = 0
        
        with open(input_filename, 'r', encoding='utf-8') as infile, \
             open(output_filename, 'w', encoding='utf-8') as outfile:
            
            for line in infile:
                if '|' in line:
                    parts = line.strip().split('|', 1)
                    if len(parts) == 2 and parts[1].strip():
                        file_id, original_text = parts
                        cleaned_text = cls.clean_line(original_text)
                        outfile.write(f"{file_id}|{cleaned_text}\n")
                        lines_processed += 1
                    else:
                        logger.warning(f"Skipping malformed line (missing text after '|'): {line.strip()}")
                else:
                    logger.warning(f"Skipping non-ID|Text line: {line.strip()}")
        
        if lines_processed == 0:
            raise ValueError(
                f"No valid transcript entries found in {input_filename}.\n"
                "Expected format: filename.wav|transcript text"
            )
        
        logger.info(f"Processed {lines_processed} lines for text normalization.")
        return lines_processed
    
    @classmethod
    def find_and_process_transcript(cls, wavs_dir: str, output_dir: str) -> str:
        """
        Find transcript file in wavs directory and process it.
        
        Args:
            wavs_dir: Directory containing wavs and transcript file
            output_dir: Directory to write processed metadata.csv
            
        Returns:
            str: Path to processed metadata.csv
            
        Raises:
            FileNotFoundError: If no transcript file is found
        """
        source_path = None
        
        for tf in cls.TRANSCRIPT_FILES:
            current_attempt = os.path.join(wavs_dir, tf)
            if os.path.exists(current_attempt):
                logger.info(f"Found transcript file: {tf}")
                source_path = current_attempt
                break
        
        if source_path is None:
            raise FileNotFoundError(
                f"No transcript file found in {wavs_dir}.\n"
                f"Expected one of: {cls.TRANSCRIPT_FILES}"
            )
        
        output_path = os.path.join(output_dir, "metadata.csv")
        logger.info(f"Processing transcript: {source_path} -> {output_path}")
        cls.process_file(source_path, output_path)
        
        return output_path


# Process transcript file
metadata_csv_path = TranscriptProcessor.find_and_process_transcript(
    config.local_wavs_dir,
    config.local_dataset_dir
)

# Display processed metadata
df_metadata = pd.read_csv(metadata_csv_path, sep='|', header=None, names=['filename', 'text'])
print(f"\n✅ Processed transcript: {len(df_metadata)} entries")
print("\n📋 Sample entries:")
display(df_metadata.head())

Found transcript file: transcripts.txt in local wavs directory.
Applying text normalization to transcripts.txt and saving to /content/dataset/metadata.csv
Processed 10 lines for text normalization.


   filename  text
0         1  Vegetables are exported to all the areas like ...
1         2  They are also used in autism therapy which hel...
2         3  Agriculturist Machaiah from Pulikotu village s...
3         4  Yo Yo Honey Singh cheering up a singer in Delh...
4         5  No F O Bs in Cotton Green railway station, pas...

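The three normalization protocols can be checked in isolation. This sketch re-implements them outside the class under two assumptions that are not part of the notebook's `clean_line`: whitespace is collapsed after the symbol expansions (the `' and '` and `' plus '` replacements otherwise leave double spaces), and the acronym pattern also accepts a trailing lowercase plural `s`. The pattern `\b([A-Z]{2,})\b` as written requires a word boundary right after the capitals, so it leaves a plural such as `FOBs` unchanged.

```python
import re


def expand_symbols(text: str) -> str:
    # Protocol 1 (as in the notebook): currency, percentages, "&" and "+".
    text = re.sub(r"\$(\d+(?:\.\d+)?)", r"\1 dollars", text)
    text = re.sub(r"(\d+)%", r"\1 percent", text)
    return text.replace("&", " and ").replace("+", " plus ")


def space_acronyms(text: str) -> str:
    # Protocol 2, variant (assumption): also spell out plural acronyms, "FOBs" -> "F O Bs".
    return re.sub(
        r"\b([A-Z]{2,})(s?)\b",
        lambda m: " ".join(m.group(1)) + m.group(2),
        text,
    )


def normalize(text: str) -> str:
    text = space_acronyms(expand_symbols(text))
    # Collapse the runs of spaces introduced by " and " / " plus " (assumption).
    text = re.sub(r"\s+", " ", text).strip()
    # Protocol 3: guarantee terminal punctuation for a prosodic boundary.
    if text and text[-1] not in ".!?":
        text += "."
    return text


if __name__ == "__main__":
    print(normalize("No FOBs in IAS"))  # No F O Bs in I A S.
```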

# ü§ñ **5. Training** ü§ñ

Configure and run the Piper TTS fine-tuning process.

In [None]:
# =============================================================================
# CHECKPOINT MANAGER
# =============================================================================
# Manage pretrained checkpoints from Hugging Face.

import shutil
from huggingface_hub import hf_hub_download


class CheckpointManager:
    """Manage model checkpoints for training and fine-tuning."""
    
    def __init__(self, config: PiperConfig):
        self.config = config
        self.pretrained_ckpt_path = os.path.join(config.base_dir, "pretrained.ckpt")
    
    def download_pretrained_checkpoint(self) -> str:
        """
        Download pretrained checkpoint from Hugging Face.
        
        Returns:
            str: Path to downloaded checkpoint
            
        Raises:
            RuntimeError: If download fails
        """
        if os.path.exists(self.pretrained_ckpt_path):
            logger.info(f"Pretrained checkpoint already exists: {self.pretrained_ckpt_path}")
            return self.pretrained_ckpt_path
        
        logger.info("Downloading pretrained checkpoint from Hugging Face...")
        
        try:
            downloaded_path = hf_hub_download(
                repo_id=self.config.hf_checkpoint_repo,
                filename=self.config.hf_checkpoint_path,
                repo_type="dataset",
                local_dir=os.path.join(self.config.base_dir, "checkpoints"),
            )
            
            # Also download the config for reference
            hf_hub_download(
                repo_id=self.config.hf_checkpoint_repo,
                filename=self.config.hf_config_path,
                repo_type="dataset",
                local_dir=os.path.join(self.config.base_dir, "checkpoints"),
            )
            
            # Copy to expected location
            shutil.copy(downloaded_path, self.pretrained_ckpt_path)
            
            logger.info(f"Pretrained checkpoint downloaded to: {self.pretrained_ckpt_path}")
            return self.pretrained_ckpt_path
            
        except Exception as e:
            raise RuntimeError(f"Failed to download pretrained checkpoint: {e}") from e
    
    def find_latest_checkpoint(self, checkpoint_dir: str) -> Optional[str]:
        """
        Find the latest checkpoint in a directory for resuming training.
        
        Args:
            checkpoint_dir: Directory to search for checkpoints
            
        Returns:
            Optional[str]: Path to latest checkpoint, or None if not found
        """
        import glob
        
        # Look for 'last.ckpt' first
        last_ckpt = os.path.join(checkpoint_dir, "last.ckpt")
        if os.path.exists(last_ckpt):
            return last_ckpt
        
        # Find most recent checkpoint by modification time
        checkpoints = glob.glob(f"{checkpoint_dir}/**/*.ckpt", recursive=True)
        if checkpoints:
            latest = max(checkpoints, key=os.path.getmtime)
            return latest
        
        return None
    
    def upgrade_checkpoint_for_cpu(self, ckpt_path: str) -> None:
        """
        Upgrade the checkpoint file in place to the installed Lightning version, loading tensors onto the CPU.
        
        Args:
            ckpt_path: Path to checkpoint file
        """
        import torch
        import pathlib
        from argparse import Namespace
        from lightning.pytorch.utilities.upgrade_checkpoint import _upgrade
        
        logger.info(f"Upgrading checkpoint for compatibility: {ckpt_path}")
        
        with torch.serialization.safe_globals([pathlib.PosixPath]):
            args = Namespace(path=str(ckpt_path), extension=".ckpt", map_to_cpu=True)
            _upgrade(args)
        
        logger.info("Checkpoint upgrade complete.")


# Initialize checkpoint manager
ckpt_manager = CheckpointManager(config)
print("✅ Checkpoint manager initialized.")

Language: en-us
Sample rate: 22050
Batch size: 8
Max epochs: 4000
Validation split: 0.0
Use pretrained: True
Resume training: False
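`find_latest_checkpoint` falls back to file modification time, which is fragile once checkpoints have been copied between Drive, S3 and local disk (`shutil.copy`, for example, does not preserve mtime). A hedged alternative is to read the step number out of the filename. The sketch below assumes the `piper-{epoch:04d}-{step:08d}` template used in the training command, which Lightning expands to names like `piper-epoch=0012-step=00001234.ckpt`; `latest_checkpoint` is a hypothetical helper, not part of Piper.

```python
import os
import re
import tempfile
from typing import Optional

# Lightning's ModelCheckpoint expands "piper-{epoch:04d}-{step:08d}" into names like
# "piper-epoch=0012-step=00001234.ckpt" (auto_insert_metric_name defaults to True).
_EPOCH_STEP = re.compile(r"epoch=(\d+)-step=(\d+)")


def latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Prefer last.ckpt; otherwise pick the checkpoint with the highest (step, epoch)."""
    last = os.path.join(checkpoint_dir, "last.ckpt")
    if os.path.exists(last):
        return last
    candidates = []
    for root, _dirs, files in os.walk(checkpoint_dir):
        for name in files:
            match = _EPOCH_STEP.search(name)
            if name.endswith(".ckpt") and match:
                epoch, step = int(match.group(1)), int(match.group(2))
                candidates.append(((step, epoch), os.path.join(root, name)))
    # Tuples compare element-wise, so max() picks the highest step, then epoch.
    return max(candidates)[1] if candidates else None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("piper-epoch=0009-step=00000900.ckpt", "piper-epoch=0010-step=00001000.ckpt"):
            open(os.path.join(tmp, name), "w").close()
        print(os.path.basename(latest_checkpoint(tmp)))  # piper-epoch=0010-step=00001000.ckpt
```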


In [None]:
# =============================================================================
# DOWNLOAD PRETRAINED CHECKPOINT
# =============================================================================
# Download pretrained checkpoint from Hugging Face for fine-tuning.

if config.use_pretrained and not config.resume_training:
    pretrained_ckpt_path = ckpt_manager.download_pretrained_checkpoint()
    
    # Upgrade checkpoint for compatibility
    ckpt_manager.upgrade_checkpoint_for_cpu(pretrained_ckpt_path)
    
    print(f"\n✅ Pretrained checkpoint ready: {pretrained_ckpt_path}")
    
elif config.resume_training:
    print("Resume training mode - will look for existing checkpoint in output folder.")
else:
    print("Training from scratch (no pretrained checkpoint).")

Downloading pretrained checkpoint from Hugging Face...


en/en_US/ljspeech/high/ljspeech-2000.ckp(…):   0%|          | 0.00/998M [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]


✅ Pretrained checkpoint downloaded to: /content/pretrained.ckpt


In [None]:
# =============================================================================
# LAUNCH TENSORBOARD
# =============================================================================
# TensorBoard allows monitoring training progress in real-time.

%load_ext tensorboard
%tensorboard --logdir {config.local_output_dir}

print("✅ TensorBoard launched. Monitor training progress above.")

INFO: Creating a backup of the existing checkpoint files before overwriting in the upgrade process.
INFO:lightning.pytorch.utilities.upgrade_checkpoint:Creating a backup of the existing checkpoint files before overwriting in the upgrade process.


Attempting to upgrade: /content/pretrained.ckpt


INFO: Upgrading checkpoints ...
INFO:lightning.pytorch.utilities.upgrade_checkpoint:Upgrading checkpoints ...
100%|██████████| 1/1 [00:06<00:00,  6.77s/it]
INFO: Done.
INFO:lightning.pytorch.utilities.upgrade_checkpoint:Done.


Upgrade process complete.


In [None]:
# =============================================================================
# START TRAINING
# =============================================================================
# Run the Piper TTS training process.

import glob
import re

# Change to piper directory
%cd {config.piper_dir}

# Prepare paths
csv_path = metadata_csv_path
config_path = f"{config.local_output_dir}/{config.model_name}.json"

# Determine checkpoint path argument
if config.resume_training:
    # The ModelCheckpoint callback below writes to {local_output_dir}/checkpoints;
    # fall back to Lightning's default lightning_logs/version_*/checkpoints location.
    explicit_last = f"{config.local_output_dir}/checkpoints/last.ckpt"
    checkpoints = glob.glob(f"{config.local_output_dir}/lightning_logs/**/checkpoints/last.ckpt", recursive=True)
    if os.path.exists(explicit_last):
        latest_ckpt = explicit_last
    elif checkpoints:
        # Sort by version number to get the latest
        def get_version(path):
            match = re.findall(r'version_(\d+)', path)
            return int(match[0]) if match else 0
        latest_ckpt = sorted(checkpoints, key=get_version)[-1]
    else:
        raise FileNotFoundError(
            f"resume_training=True but no last.ckpt found in {config.local_output_dir}/checkpoints/ "
            f"or {config.local_output_dir}/lightning_logs/\n"
            "Set resume_training=False to start fresh or ensure checkpoints exist."
        )
    print(f"Resuming from checkpoint: {latest_ckpt}")
    ckpt_path_arg = f'--ckpt_path "{latest_ckpt}"'
elif config.use_pretrained:
    # Use the local checkpoint downloaded and upgraded earlier; a Hugging Face
    # ".../blob/main/..." URL points to an HTML page, not the raw .ckpt file.
    pretrained_ckpt = ckpt_manager.pretrained_ckpt_path
    print(f"Fine-tuning from pretrained checkpoint: {pretrained_ckpt}")
    ckpt_path_arg = f'--ckpt_path "{pretrained_ckpt}"'
else:
    ckpt_path_arg = ""
    print("Training from scratch (no checkpoint).")

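# Build training command.
# ModelCheckpoint monitors "val_loss", which comes from the validation loop,
# so save_top_k only has an effect when validation_split > 0.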
train_cmd = f"""
python -m piper.train fit \\
    --data.csv_path "{csv_path}" \\
    --data.cache_dir "{config.local_cache_dir}" \\
    --data.audio_dir "{config.local_wavs_dir}" \\
    --data.espeak_voice "{config.espeak_voice}" \\
    --data.config_path "{config_path}" \\
    --data.voice_name "{config.model_name}" \\
    --data.batch_size {config.batch_size} \\
    --data.validation_split {config.validation_split} \\
    --data.num_test_examples {config.num_test_examples} \\
    --model.sample_rate {config.sample_rate} \\
    --model.num_speakers {config.num_speakers} \\
    --trainer.max_epochs {config.max_epochs} \\
    --trainer.accelerator {config.device} \\
    --trainer.devices 1 \\
    --trainer.precision 32 \\
    --trainer.default_root_dir "{config.local_output_dir}" \\
    --trainer.callbacks+=ModelCheckpoint \\
    --trainer.callbacks.dirpath "{config.local_output_dir}/checkpoints" \\
    --trainer.callbacks.filename "piper-{{epoch:04d}}-{{step:08d}}" \\
    --trainer.callbacks.save_top_k 3 \\
    --trainer.callbacks.monitor "val_loss" \\
    --trainer.callbacks.save_last true \\
    --trainer.callbacks.every_n_epochs {config.checkpoint_epochs} \\
    --model.learning_rate {config.learning_rate} \\
    {ckpt_path_arg}
"""

print("\n" + "=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print(f"Training command:\n{train_cmd}")
print("=" * 60)

!{train_cmd}
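The training cell assembles one long shell string and runs it with `!`, which relies on the embedded quotes and line continuations surviving shell parsing. As a sketch of an alternative (not how Piper documents it), the same flags can be passed as an argument list to `subprocess.run`, which avoids shell quoting altogether. `build_train_args` and `run_training` are hypothetical helpers; the flag names are the ones already used in the command above, written in `--flag=value` form.

```python
import shlex
import subprocess
import sys
from typing import List, Optional, Sequence, Tuple


def build_train_args(
    flags: Sequence[Tuple[str, object]],
    ckpt_path: Optional[str] = None,
) -> List[str]:
    """Turn (flag, value) pairs into an argv list for `python -m piper.train fit`.

    Each pair becomes one "--flag=value" token, so values containing spaces need no
    shell quoting. A list (not a dict) allows repeated keys such as "trainer.callbacks+".
    """
    args = [sys.executable, "-m", "piper.train", "fit"]
    for key, value in flags:
        args.append(f"--{key}={value}")
    if ckpt_path:
        args.append(f"--ckpt_path={ckpt_path}")
    return args


def run_training(args: List[str], cwd: str) -> None:
    """Run training without a shell; raises CalledProcessError on a non-zero exit."""
    print(shlex.join(args))  # copy-pasteable form of the exact command
    subprocess.run(args, cwd=cwd, check=True)


if __name__ == "__main__":
    example = build_train_args(
        [
            ("data.csv_path", "/content/dataset/metadata.csv"),
            ("data.batch_size", 8),
            ("trainer.callbacks+", "ModelCheckpoint"),
            ("trainer.callbacks.save_last", "true"),
        ],
        ckpt_path="/content/pretrained.ckpt",
    )
    print(shlex.join(example))
```

In the notebook this would be called with the `config` fields, e.g. `run_training(build_train_args([...], ckpt_path=...), cwd=config.piper_dir)`.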

# üíæ **6. Save Training Outputs** üíæ

Save trained model checkpoints and configuration to remote storage (Google Drive or S3).

In [None]:
# =============================================================================
# SAVE TRAINING OUTPUTS TO REMOTE STORAGE
# =============================================================================
# Copy trained model, checkpoints, and logs to Google Drive or S3.

import os
import shutil


def save_training_outputs(config: PiperConfig, data_loader: DataLoaderBase) -> None:
    """
    Save all training outputs to remote storage.
    
    Args:
        config: PiperConfig instance
        data_loader: DataLoader instance for upload operations
    """
    if config.colab_mode:
        # Create model-specific output directory in Google Drive
        gdrive_model_dir = os.path.join(config.gdrive_output_path, config.model_name)
        os.makedirs(gdrive_model_dir, exist_ok=True)
        
        files_to_copy = []
        
        # Copy last checkpoint
        last_ckpt = f"{config.local_output_dir}/checkpoints/last.ckpt"
        if os.path.exists(last_ckpt):
            files_to_copy.append((last_ckpt, f"{gdrive_model_dir}/last.ckpt"))
        else:
            logger.warning(f"Last checkpoint not found: {last_ckpt}")
        
        # Copy config file
        config_file = f"{config.local_output_dir}/{config.model_name}.json"
        if os.path.exists(config_file):
            files_to_copy.append((config_file, f"{gdrive_model_dir}/{config.model_name}.json"))
        else:
            logger.warning(f"Config file not found: {config_file}")
        
        # Copy lightning logs directory
        lightning_logs = f"{config.local_output_dir}/lightning_logs"
        if os.path.exists(lightning_logs):
            shutil.copytree(lightning_logs, f"{gdrive_model_dir}/lightning_logs", dirs_exist_ok=True)
            logger.info(f"Copied lightning_logs to {gdrive_model_dir}")
        
        # Copy individual files
        for src, dst in files_to_copy:
            if os.path.exists(src):
                logger.info(f"Copying {src} -> {dst}")
                shutil.copy(src, dst)
        
        # Display saved files
        print(f"\n✅ Model saved to Google Drive: {gdrive_model_dir}")
        print("\nFiles saved:")
        for f in os.listdir(gdrive_model_dir):
            fpath = os.path.join(gdrive_model_dir, f)
            if os.path.isfile(fpath):
                size_mb = os.path.getsize(fpath) / (1024 * 1024)
                print(f"  - {f} ({size_mb:.2f} MB)")
            else:
                print(f"  - {f}/ (directory)")
    
    else:
        # Upload to S3
        last_ckpt = f"{config.local_output_dir}/checkpoints/last.ckpt"
        if os.path.exists(last_ckpt):
            data_loader.upload_checkpoint(last_ckpt, f"{config.model_name}/last.ckpt")
        
        config_file = f"{config.local_output_dir}/{config.model_name}.json"
        if os.path.exists(config_file):
            data_loader.upload_checkpoint(config_file, f"{config.model_name}/{config.model_name}.json")
        
        print(f"\n‚úÖ Model uploaded to S3: s3://{config.s3_bucket}/{config.s3_checkpoint_prefix}/{config.model_name}/")


# Save training outputs
save_training_outputs(config, data_loader)

Copied lightning_logs to /content/drive/MyDrive/Piper-POC-Training/output/en_IN-spicor-medium
Copying /content/output/en_IN-spicor-medium/checkpoints/last.ckpt -> /content/drive/MyDrive/Piper-POC-Training/output/en_IN-spicor-medium/last.ckpt
Copying /content/output/en_IN-spicor-medium/en_IN-spicor-medium.json -> /content/drive/MyDrive/Piper-POC-Training/output/en_IN-spicor-medium/en_IN-spicor-medium.json

✅ Model saved to Google Drive: /content/drive/MyDrive/Piper-POC-Training/output/en_IN-spicor-medium

Files saved:
  - lightning_logs (0.00 MB)
  - last.ckpt (806.70 MB)
  - en_IN-spicor-medium.json (0.00 MB)
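`os.path.getsize` on a directory reports the size of the directory entry, not of its contents, so a folder such as `lightning_logs` can show up as 0.00 MB in a size listing. A minimal sketch of a recursive size helper that could be used for the listing instead (`path_size_mb` is a hypothetical name):

```python
import os
import tempfile


def path_size_mb(path: str) -> float:
    """Size of a file, or the recursive total of all files under a directory, in MiB."""
    if os.path.isfile(path):
        return os.path.getsize(path) / (1024 * 1024)
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total / (1024 * 1024)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        os.makedirs(os.path.join(tmp, "version_0"))
        with open(os.path.join(tmp, "version_0", "events.bin"), "wb") as f:
            f.write(b"\0" * (1024 * 1024))
        print(f"{path_size_mb(tmp):.2f} MB")  # 1.00 MB
```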


In [None]:
# üéß **7. Test Model from Checkpoint** üéß

Test the trained model by loading from checkpoint and generating speech.

Only 'last.ckpt' found. Using: /content/drive/MyDrive/Piper-POC-Training/output/en_IN-spicor-medium_en-US-epochs-3000/last.ckpt
Loading model from /content/pretrained.ckpt...
Phonemizing using 'en-us' (Matching training data)...
Generated Phonemes (First 10): ['ð', 'ɪ', 's', ' ', 'ɪ', 'z', ' ', 'ɐ', ' ', 't']
ID Sequence (First 20): [1, 41, 1, 74, 1, 31, 1, 3, 1, 74, 1, 38, 1, 3, 1, 50, 1, 3, 1, 32]
Generating audio...

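The ID sequence printed above alternates every phoneme ID with `1`. In Piper's default phoneme map, `1` is the beginning-of-sentence symbol `^`, while the padding symbol `_` is `0` and end-of-sentence `$` is `2`. As we understand Piper's phoneme-ID conversion, it emits BOS and a pad, then each phoneme followed by a pad, then EOS. A minimal sketch of that convention, assuming these symbol names (check them against `phoneme_id_map` in the voice's `.json` config):

```python
from typing import Dict, List, Sequence

# Assumed special symbols; verify against the "phoneme_id_map" in the voice config.
PAD, BOS, EOS = "_", "^", "$"


def phonemes_to_ids(phonemes: Sequence[str], id_map: Dict[str, List[int]]) -> List[int]:
    """BOS + PAD, then each known phoneme followed by PAD, then EOS."""
    ids = list(id_map[BOS]) + list(id_map[PAD])
    for phoneme in phonemes:
        if phoneme not in id_map:
            continue  # unknown phonemes are skipped (the notebook logs them as missing)
        ids.extend(id_map[phoneme])
        ids.extend(id_map[PAD])
    ids.extend(id_map[EOS])
    return ids


if __name__ == "__main__":
    toy_map = {"_": [0], "^": [1], "$": [2], "a": [14], "b": [15]}
    print(phonemes_to_ids(["a", "b"], toy_map))  # [1, 0, 14, 0, 15, 0, 2]
```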

In [None]:
# =============================================================================
# TEST MODEL FROM CHECKPOINT
# =============================================================================
# Load model from checkpoint and generate test audio.

import json
import torch
import pathlib
import numpy as np
from scipy.io.wavfile import write
from IPython.display import Audio, display
import sys

# Add piper src to path
if f"{config.piper_dir}/src" not in sys.path:
    sys.path.append(f"{config.piper_dir}/src")

from piper.train.vits.lightning import VitsModel
from piper.phonemize_espeak import EspeakPhonemizer


class ModelTester:
    """Test trained Piper TTS model from checkpoint."""
    
    def __init__(self, config: PiperConfig):
        self.config = config
        self.model = None
        self.model_config = None
        self.phonemizer = None
    
    def load_model(self, checkpoint_path: str, config_path: str) -> None:
        """
        Load model from checkpoint.
        
        Args:
            checkpoint_path: Path to .ckpt file
            config_path: Path to model config JSON
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        
        # Load config
        with open(config_path, 'r') as f:
            self.model_config = json.load(f)
        
        # Load model
        logger.info(f"Loading model from {checkpoint_path}...")
        with torch.serialization.safe_globals([pathlib.PosixPath]):
            self.model = VitsModel.load_from_checkpoint(checkpoint_path, map_location='cpu')
        
        self.model.eval()
        with torch.no_grad():
            self.model.model_g.dec.remove_weight_norm()
        
        # Initialize phonemizer
        self.phonemizer = EspeakPhonemizer()
        
        logger.info("Model loaded successfully.")
    
    def synthesize(
        self,
        text: str,
        noise_scale: float = 0.667,
        length_scale: float = 1.0,
        noise_scale_w: float = 0.8,
    ) -> np.ndarray:
        """
        Synthesize speech from text.
        
        Args:
            text: Text to synthesize
            noise_scale: Noise scale for synthesis
            length_scale: Length scale (1.0 = normal speed)
            noise_scale_w: Noise scale for duration predictor
            
        Returns:
            np.ndarray: Audio data
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        
        # Phonemize text
        phoneme_lists = self.phonemizer.phonemize(self.config.espeak_voice, text)
        phonemes = []
        for sentence in phoneme_lists:
            phonemes.extend(sentence)
        
        # Map phonemes to IDs following Piper's convention: BOS ("^") and a pad,
        # then each phoneme followed by the pad symbol ("_"), then EOS ("$")
        id_map = self.model_config["phoneme_id_map"]
        pad_ids = id_map.get("_", [0])
        
        phoneme_ids = list(id_map.get("^", [])) + list(pad_ids)
        missing_phonemes = []
        
        for p in phonemes:
            if p in id_map:
                phoneme_ids.extend(id_map[p])
                phoneme_ids.extend(pad_ids)
            else:
                missing_phonemes.append(p)
        
        phoneme_ids.extend(id_map.get("$", []))
        
        if missing_phonemes:
            logger.warning(f"Missing phonemes in config: {missing_phonemes}")
        
        # Convert to tensors
        sequence = torch.tensor(phoneme_ids, dtype=torch.long).unsqueeze(0)
        sequence_lengths = torch.tensor([len(phoneme_ids)], dtype=torch.long)
        
        # Handle speaker ID
        sid = None
        if self.model_config.get("num_speakers", 1) > 1:
            sid = torch.tensor([0], dtype=torch.long)
        
        # Generate audio
        with torch.no_grad():
            audio = self.model.model_g.infer(
                x=sequence,
                x_lengths=sequence_lengths,
                sid=sid,
                noise_scale=noise_scale,
                length_scale=length_scale,
                noise_scale_w=noise_scale_w
            )[0]
        
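        # Process audio: peak-normalize, skipping all-zero (silent) output to avoid dividing by zero
        audio_data = audio.squeeze().cpu().numpy()
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak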
        
        return audio_data
    
    def synthesize_and_save(
        self,
        text: str,
        output_path: str,
        **kwargs
    ) -> str:
        """
        Synthesize speech and save to WAV file.
        
        Args:
            text: Text to synthesize
            output_path: Output WAV file path
            **kwargs: Additional arguments for synthesize()
            
        Returns:
            str: Path to saved WAV file
        """
        audio_data = self.synthesize(text, **kwargs)
        sample_rate = self.model_config["audio"]["sample_rate"]
        
        write(output_path, sample_rate, (audio_data * 32767).astype(np.int16))
        logger.info(f"Audio saved to {output_path}")
        
        return output_path


# Initialize tester
model_tester = ModelTester(config)
print("✅ Model tester initialized.")

🔧 Starting Final Library Patching Process...
  ✅ [transforms.py] Assertion removed.
  ✅ [modules.py] Inserted shape guard (size > 1).
  ✅ [lightning.py] Added sample_bytes param.
  ✅ [export_onnx.py] Added security globals.
🔧 Patching Complete.

📂 Checkpoint: epoch=2868-step=1575188.ckpt
📄 Config copied.
🚀 Exporting to /content/checkpoints/en/en_US/hfc_female/medium/epoch=2868-step=1575188.onnx...
Applied 194 of general pattern rewrite rules.


✅✅✅ EXPORT SUCCESSFUL! ✅✅✅
Saved to: /content/checkpoints/en/en_US/hfc_female/medium/epoch=2868-step=1575188.onnx


In [None]:
# =============================================================================
# TEST SYNTHESIS FROM CHECKPOINT
# =============================================================================
# Generate test audio from trained model.

# Configuration for testing
TEST_TEXT = "Hello, this is a test of the fine-tuned Piper text to speech model. Is this sounding correct?"

# Determine checkpoint and config paths
if config.colab_mode:
    gdrive_model_dir = os.path.join(config.gdrive_output_path, config.model_name)
    test_checkpoint_path = f"{gdrive_model_dir}/last.ckpt"
    test_config_path = f"{gdrive_model_dir}/{config.model_name}.json"
else:
    test_checkpoint_path = f"{config.local_output_dir}/checkpoints/last.ckpt"
    test_config_path = f"{config.local_output_dir}/{config.model_name}.json"

# Load model and synthesize
try:
    model_tester.load_model(test_checkpoint_path, test_config_path)
    
    print(f"\nSynthesizing: \"{TEST_TEXT}\"")
    audio_data = model_tester.synthesize(TEST_TEXT)
    
    sample_rate = model_tester.model_config["audio"]["sample_rate"]
    print(f"Audio duration: {len(audio_data) / sample_rate:.2f} seconds")
    
    # Display audio player
    display(Audio(audio_data, rate=sample_rate))
    
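    # Save the audio that was just played; synthesize() samples noise, so calling it
    # again via synthesize_and_save() would write a different take
    output_wav = f"{config.local_output_dir}/test_output.wav"
    write(output_wav, sample_rate, (audio_data * 32767).astype(np.int16))
    print(f"✅ Test audio saved to: {output_wav}")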
    
except FileNotFoundError as e:
    print(f"⚠️ Cannot test model: {e}")
    print("This is expected if training has not completed yet.")
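Dividing by `np.max(np.abs(audio_data))` fails on silent output, and converting with `* 32767` assumes the samples are already within [-1, 1]. A small sketch of a safer conversion; `to_int16_pcm` is a hypothetical helper, and its `peak` parameter (for leaving some headroom below full scale) is an assumption, not a Piper setting:

```python
import numpy as np


def to_int16_pcm(audio: np.ndarray, peak: float = 1.0) -> np.ndarray:
    """Peak-normalize float audio to `peak` and convert it to 16-bit PCM.

    All-zero (silent) or empty input is returned as zeros instead of dividing by zero.
    """
    audio = np.asarray(audio, dtype=np.float32)
    max_abs = float(np.max(np.abs(audio))) if audio.size else 0.0
    if max_abs > 0.0:
        audio = audio * (peak / max_abs)
    # Clip before scaling so values can never overflow the int16 range.
    return (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)


if __name__ == "__main__":
    print(to_int16_pcm(np.array([0.0, 0.5, -0.25])).dtype)  # int16
```

For example, `write(output_wav, sample_rate, to_int16_pcm(audio_data, peak=0.95))`, with `write` from `scipy.io.wavfile`.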

In [None]:
# üì¶ **8. Export to ONNX** üì¶

Export the trained model to ONNX format for production inference.

Using specified pretrained checkpoint for ONNX export: /content/checkpoints/en/en_US/hfc_female/medium/epoch=2868-step=1575188.ckpt
/content/piper1-gpl
Exporting ONNX from /content/checkpoints/en/en_US/hfc_female/medium/epoch=2868-step=1575188.ckpt...


INFO: Lightning automatically upgraded your loaded checkpoint from v1.9.0 to v2.6.0. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint ../checkpoints/en/en_US/hfc_female/medium/epoch=2868-step=1575188.ckpt`
INFO:lightning.pytorch.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.0 to v2.6.0. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint ../checkpoints/en/en_US/hfc_female/medium/epoch=2868-step=1575188.ckpt`
  WeightNorm.apply(module, name, dim)
  torch.onnx.export(
W1205 11:34:43.409000 165 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 15 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsucc


Error during export: Failed to export the model with torch.export. This is step 1/3 of exporting the model to ONNX. Next steps:
- Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
- Debug `torch.export.export` and submit a PR to PyTorch.
- Create an issue in the PyTorch GitHub repository against the *torch.export* component and attach the full error stack as well as reproduction scripts.

## Exception summary

<class 'torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode'>: Could not guard on data-dependent expression Eq(u2, 1) (unhinted: Eq(u2, 1)).  (Size-like symbols: none)

consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true
Caused by: (_export/non_strict_utils.py:1066 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TOR

In [None]:
# =============================================================================
# ONNX EXPORTER
# =============================================================================
# Export trained model to ONNX format.

import subprocess
import sys


class ONNXExporter:
    """Export Piper model to ONNX format for production inference."""
    
    def __init__(self, config: PiperConfig):
        self.config = config
    
    def export(
        self,
        checkpoint_path: str,
        output_path: str,
    ) -> str:
        """
        Export model checkpoint to ONNX format.
        
        Args:
            checkpoint_path: Path to .ckpt checkpoint file
            output_path: Path for output .onnx file
            
        Returns:
            str: Path to exported ONNX file
            
        Raises:
            FileNotFoundError: If checkpoint doesn't exist
            RuntimeError: If export fails
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        
        logger.info(f"Exporting ONNX from {checkpoint_path}...")
        logger.info(f"Output path: {output_path}")
        
        # Ensure output directory exists
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        
        # Build export command
        cmd = [
            sys.executable, "-m", "piper.train.export_onnx",
            "--checkpoint", checkpoint_path,
            "--output-file", output_path
        ]
        
        try:
            result = subprocess.run(
                cmd,
                cwd=self.config.piper_dir,
                check=True,
                capture_output=True,
                text=True
            )
            logger.info(result.stdout)
            logger.info(f"ONNX export successful: {output_path}")
            return output_path
            
        except subprocess.CalledProcessError as e:
            error_msg = f"ONNX export failed (exit code {e.returncode}):\n{e.stderr}"
            raise RuntimeError(error_msg) from e
    
    def copy_config(self, source_config: str, output_path: str) -> str:
        """
        Copy and rename config file to match ONNX model naming convention.
        
        Args:
            source_config: Path to source config JSON
            output_path: Path for output config (should end in .onnx.json)
            
        Returns:
            str: Path to copied config file
        """
        if not os.path.exists(source_config):
            raise FileNotFoundError(f"Source config not found: {source_config}")
        
        shutil.copy(source_config, output_path)
        logger.info(f"Config copied to: {output_path}")
        return output_path


# Initialize ONNX exporter
onnx_exporter = ONNXExporter(config)
print("✅ ONNX exporter initialized.")


In [None]:
# =============================================================================
# EXPORT MODEL TO ONNX
# =============================================================================
# Run the ONNX export process.

# Determine paths
if config.colab_mode:
    gdrive_model_dir = os.path.join(config.gdrive_output_path, config.model_name)
    export_checkpoint_path = f"{gdrive_model_dir}/last.ckpt"
    export_config_path = f"{gdrive_model_dir}/{config.model_name}.json"
    onnx_output_path = f"{gdrive_model_dir}/{config.model_name}.onnx"
    onnx_config_path = f"{gdrive_model_dir}/{config.model_name}.onnx.json"
else:
    export_checkpoint_path = f"{config.local_output_dir}/checkpoints/last.ckpt"
    export_config_path = f"{config.local_output_dir}/{config.model_name}.json"
    onnx_output_path = f"{config.local_output_dir}/{config.model_name}.onnx"
    onnx_config_path = f"{config.local_output_dir}/{config.model_name}.onnx.json"

try:
    # Export to ONNX
    onnx_exporter.export(export_checkpoint_path, onnx_output_path)
    
    # Copy config file
    onnx_exporter.copy_config(export_config_path, onnx_config_path)
    
    # Display results
    if os.path.exists(onnx_output_path):
        onnx_size_mb = os.path.getsize(onnx_output_path) / (1024 * 1024)
        print("\n✅ ONNX export complete!")
        print(f"   Model: {onnx_output_path} ({onnx_size_mb:.2f} MB)")
        print(f"   Config: {onnx_config_path}")
    
except FileNotFoundError as e:
    print(f"⚠️ Cannot export ONNX: {e}")
    print("Ensure training has completed and the checkpoint exists.")

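You can catch a mismatch between the exported graph and the test cell's feed dictionary before running inference by comparing input names. This is a minimal sketch, not part of Piper. `missing_onnx_inputs` is a hypothetical helper. The required names are the ones `ONNXModelTester.synthesize` feeds (`input`, `input_lengths`, `scales`). The extra `sid` input for multi-speaker models is an assumption about Piper's export. `InferenceSession.get_inputs()` is the ONNX Runtime call that lists a graph's inputs.

```python
from typing import Iterable, List

# Input names fed by ONNXModelTester.synthesize; "sid" for multi-speaker
# models is an assumption about Piper's ONNX export.
REQUIRED_INPUTS = ("input", "input_lengths", "scales")


def missing_onnx_inputs(input_names: Iterable[str], multi_speaker: bool = False) -> List[str]:
    """Return the required input names that are absent from an ONNX graph's inputs."""
    required = list(REQUIRED_INPUTS)
    if multi_speaker:
        required.append("sid")
    present = set(input_names)
    return [name for name in required if name not in present]


# Usage with ONNX Runtime, reusing onnx_output_path from the export cell:
# session = ort.InferenceSession(onnx_output_path)
# names = [inp.name for inp in session.get_inputs()]
# print(missing_onnx_inputs(names))  # an empty list means the names line up
```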
# 🎧 **9. Test ONNX Model** 🎧

Test the exported ONNX model by generating speech with ONNX Runtime.

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'node_index_put' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:640 onnxruntime::Broadcaster::Broadcaster(gsl::span<const long int>, gsl::span<const long int>) largest <= 1 was false. Can broadcast 0 by 0 or 1. 67 is invalid.

In [None]:
# =============================================================================
# TEST ONNX MODEL
# =============================================================================
# Test the exported ONNX model using ONNX Runtime.

import json
import numpy as np
from IPython.display import Audio, display

try:
    import onnxruntime as ort
except ImportError:
    print("Installing onnxruntime...")
    !pip install -q onnxruntime
    import onnxruntime as ort


class ONNXModelTester:
    """Test ONNX model inference."""
    
    def __init__(self, onnx_path: str, config_path: str):
        """
        Initialize ONNX model tester.
        
        Args:
            onnx_path: Path to .onnx model file
            config_path: Path to .onnx.json config file
        """
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        
        # Load config
        with open(config_path, 'r') as f:
            self.config = json.load(f)
        
        self.phoneme_id_map = self.config.get('phoneme_id_map', {})
        self.sample_rate = self.config.get('audio', {}).get('sample_rate', 22050)
        
        # Create ONNX session
        self.session = ort.InferenceSession(onnx_path)
        
        print(f"ONNX model loaded: {onnx_path}")
    
    def text_to_phoneme_ids(self, text: str) -> list:
        """
        Simple text to phoneme ID conversion for testing.
        
        Note: For production, use espeak-ng for proper phonemization.
        
        Args:
            text: Input text
            
        Returns:
            list: List of phoneme IDs
        """
        ids = []
        for char in text.lower():
            # Characters missing from the map (including space, if absent) are skipped
            value = self.phoneme_id_map.get(char)
            if value is None:
                continue
            # Piper configs store each entry as a list of IDs; accept bare ints too
            if isinstance(value, list):
                ids.extend(value)
            else:
                ids.append(value)
        
        if not ids:
            raise ValueError("No valid phoneme IDs generated from text. Check phoneme_id_map.")
        
        return ids
    
    def synthesize(
        self,
        text: str,
        noise_scale: float = 0.667,
        length_scale: float = 1.0,
        noise_scale_w: float = 0.8,
    ) -> np.ndarray:
        """
        Synthesize speech from text using ONNX model.
        
        Args:
            text: Text to synthesize
            noise_scale: Noise scale
            length_scale: Length scale (1.0 = normal speed)
            noise_scale_w: Noise scale for duration predictor
            
        Returns:
            np.ndarray: Audio data
        """
        phoneme_ids = self.text_to_phoneme_ids(text)
        
        # Prepare inputs
        input_array = np.array([phoneme_ids], dtype=np.int64)
        input_lengths = np.array([len(phoneme_ids)], dtype=np.int64)
        scales = np.array([noise_scale, length_scale, noise_scale_w], dtype=np.float32)
        
        inputs = {
            'input': input_array,
            'input_lengths': input_lengths,
            'scales': scales,
        }
        
        # Run inference
        output = self.session.run(None, inputs)
        audio = output[0].squeeze()
        
        return audio


# Test ONNX model
TEST_TEXT_ONNX = "Hello, this is a test of the exported ONNX model."

try:
    onnx_tester = ONNXModelTester(onnx_output_path, onnx_config_path)
    
    print(f"\nSynthesizing with ONNX: \"{TEST_TEXT_ONNX}\"")
    audio_data = onnx_tester.synthesize(TEST_TEXT_ONNX)
    
    print(f"Generated audio shape: {audio_data.shape}")
    print(f"Audio duration: {len(audio_data) / onnx_tester.sample_rate:.2f} seconds")
    
    # Display audio player
    display(Audio(audio_data, rate=onnx_tester.sample_rate))
    
    print("\n✅ ONNX model test successful!")
    
except FileNotFoundError as e:
    print(f"⚠️ Cannot test ONNX model: {e}")
    print("This is expected if ONNX export has not completed yet.")
except Exception as e:
    print(f"⚠️ ONNX test failed: {e}")

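`text_to_phoneme_ids` above looks raw characters up in `phoneme_id_map`, so it is only a smoke test. Piper voices are normally fed espeak-ng phonemes wrapped in special symbols. The sketch below shows that ID layout, modeled on the Python runtime of the original rhasspy/piper project: a BOS ID, each phoneme followed by a PAD ID, then an EOS ID.

Treat the symbols `^`, `_` and `$` and the exact layout as assumptions. Verify them against the `phoneme_id_map` in your `.onnx.json` and the phonemizer shipped with piper1-gpl. The phonemes themselves would still come from espeak-ng, and the toy map is illustrative only.

```python
from typing import Dict, List, Sequence

# Assumed Piper special symbols: beginning of sentence, end of sentence, padding
BOS, EOS, PAD = "^", "$", "_"


def phonemes_to_ids(phonemes: Sequence[str], id_map: Dict[str, List[int]]) -> List[int]:
    """Map phonemes to IDs as BOS, (phoneme, PAD) for each phoneme, then EOS."""
    ids: List[int] = list(id_map[BOS])
    for phoneme in phonemes:
        if phoneme not in id_map:
            continue  # unknown symbol: skip it rather than fail
        ids.extend(id_map[phoneme])
        ids.extend(id_map[PAD])
    ids.extend(id_map[EOS])
    return ids


# Toy map for illustration only; a real map comes from the model's .onnx.json
toy_map = {"_": [0], "^": [1], "$": [2], "h": [20], "ə": [59], "l": [24]}
print(phonemes_to_ids(["h", "ə", "l", "?"], toy_map))  # [1, 20, 0, 59, 0, 24, 0, 2]
```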
# üîß **10. Troubleshooting** üîß

This section contains utilities and code snippets for debugging common issues.

## 10.1 Checkpoint Upgrade Utility

Utility for making a checkpoint loadable on CPU and upgrading it to the installed PyTorch Lightning version. Use it when a checkpoint was saved on GPU or by an older Lightning release.

In [None]:
# =============================================================================
# CHECKPOINT UPGRADE UTILITY (TROUBLESHOOTING)
# =============================================================================
# Use this if you encounter checkpoint compatibility issues.

# import torch
# import pathlib
# from pathlib import Path
# from argparse import Namespace
# from lightning.pytorch.utilities.upgrade_checkpoint import _upgrade

# def upgrade_checkpoint_cpu_safe(checkpoint_path: str) -> None:
#     """
#     Upgrade checkpoint for CPU compatibility.
#     
#     Args:
#         checkpoint_path: Path to checkpoint file
#     """
#     path = Path(checkpoint_path)
#     
#     if not path.exists():
#         raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
#     
#     # Load checkpoint with CPU mapping
#     checkpoint = torch.load(
#         str(path),
#         map_location=torch.device("cpu"),
#         weights_only=False  # Needed for full Lightning checkpoints; only load files you trust
#     )
#     
#     # Save upgraded CPU-compatible version
#     torch.save(checkpoint, str(path))
#     
#     # Run Lightning upgrade
#     args = Namespace(path=str(path), extension=".ckpt")
#     _upgrade(args)
#     
#     print(f"Upgraded and saved CPU-compatible checkpoint: {path}")

# # Usage:
# # upgrade_checkpoint_cpu_safe("/content/pretrained.ckpt")

## 10.2 Library Patching for ONNX Export

Patches for known issues when exporting to ONNX format.

In [None]:
# =============================================================================
# LIBRARY PATCHING FOR ONNX EXPORT (TROUBLESHOOTING)
# =============================================================================
# Use this if you encounter ONNX export issues.

# def patch_piper_library_for_onnx(base_dir: str) -> None:
#     """
#     Apply patches to Piper library files for ONNX export compatibility.
#     
#     Patches:
#     1. transforms.py: Remove assertion that fails during export
#     2. modules.py: Fix dynamic convolution shape issues
#     3. lightning.py: Add sample_bytes parameter for legacy checkpoints
#     4. export_onnx.py: Add PyTorch 2.6 security globals
#     
#     Args:
#         base_dir: Path to piper repository root
#     """
#     import os
#     
#     transforms_file = os.path.join(base_dir, "src/piper/train/vits/transforms.py")
#     modules_file = os.path.join(base_dir, "src/piper/train/vits/modules.py")
#     lightning_file = os.path.join(base_dir, "src/piper/train/vits/lightning.py")
#     export_file = os.path.join(base_dir, "src/piper/train/export_onnx.py")
#     
#     print("üîß Starting library patching for ONNX export...")
#     
#     # --- PATCH 1: transforms.py (Remove Assertion) ---
#     if os.path.exists(transforms_file):
#         with open(transforms_file, "r") as f:
#             lines = f.readlines()
#         
#         new_lines = []
#         patched = False
#         for line in lines:
#             if "assert (discriminant >= 0).all()" in line and not line.strip().startswith("#"):
#                 new_lines.append(f"# {line.strip()} # PATCHED\n")
#                 patched = True
#             else:
#                 new_lines.append(line)
#         
#         if patched:
#             with open(transforms_file, "w") as f:
#                 f.writelines(new_lines)
#             print("  ✅ [transforms.py] Assertion disabled.")
#         else:
#             print("  ⏭️ [transforms.py] Already patched or not needed.")
#     
#     # --- PATCH 2: modules.py (Fix Dynamic Convolution Shape) ---
#     if os.path.exists(modules_file):
#         with open(modules_file, "r") as f:
#             lines = f.readlines()
#         
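#         # Remove earlier guards (lines tagged "# ONNX PATCH" and older x0.shape variants)
#         # so that re-running the patch does not insert duplicate guards
#         clean_lines = [
#             l for l in lines
#             if "# ONNX PATCH" not in l and not ("torch._check" in l and "x0.shape" in l)
#         ]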
#         
#         final_lines = []
#         patched = False
#         for line in clean_lines:
#             if "h = self.pre(x0) * x_mask" in line:
#                 indent = line.split("h = self.pre(x0)")[0]
#                 final_lines.append(f"{indent}torch._check(x0.size(2) > 1) # ONNX PATCH\n")
#                 final_lines.append(line)
#                 patched = True
#             else:
#                 final_lines.append(line)
#         
#         if patched:
#             with open(modules_file, "w") as f:
#                 f.writelines(final_lines)
#             print("  ✅ [modules.py] Shape guard inserted.")
#     
#     # --- PATCH 3: lightning.py (Legacy Checkpoint) ---
#     if os.path.exists(lightning_file):
#         with open(lightning_file, "r") as f:
#             lines = f.readlines()
#         
#         if not any("sample_bytes" in line for line in lines):
#             new_lines = []
#             for line in lines:
#                 new_lines.append(line)
#                 if "sample_rate: int = 22050," in line:
#                     new_lines.append("        sample_bytes: int = 2, # PATCHED\n")
#             with open(lightning_file, "w") as f:
#                 f.writelines(new_lines)
#             print("  ✅ [lightning.py] sample_bytes parameter added.")
#         else:
#             print("  ⏭️ [lightning.py] Already patched.")
#     
#     # --- PATCH 4: export_onnx.py (PyTorch 2.6 Security) ---
#     if os.path.exists(export_file):
#         with open(export_file, "r") as f:
#             content = f.read()
#         
#         patch_str = "torch.serialization.add_safe_globals([pathlib.PosixPath])"
#         if patch_str not in content:
#             # Replace only the first bare "import torch" line: a plain replace() would also
#             # rewrite lines such as "import torch.nn as nn" and leave broken syntax
#             new_content = content.replace(
#                 "import torch\n",
#                 "import torch\nimport pathlib\ntry:\n    torch.serialization.add_safe_globals([pathlib.PosixPath])\nexcept AttributeError:\n    pass\n",
#                 1,
#             )
#             with open(export_file, "w") as f:
#                 f.write(new_content)
#             print("  ✅ [export_onnx.py] Security globals added.")
#         else:
#             print("  ⏭️ [export_onnx.py] Already patched.")
#     
#     print("üîß Patching complete.\n")

# # Usage:
# # patch_piper_library_for_onnx("/content/piper1-gpl")

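A source patch is only safe to re-run if it recognizes its own earlier insertions. The minimal, library-agnostic sketch below shows that idea. Every inserted line carries a marker comment, and tagged lines are stripped before re-inserting, so a second run yields the same file.

`insert_guard_before` is a hypothetical helper that works on a list of lines, as `patch_piper_library_for_onnx` does. The anchor and guard strings mirror the `modules.py` patch above.

```python
from typing import List

PATCH_TAG = "# ONNX PATCH"  # marker comment appended to every inserted line


def insert_guard_before(lines: List[str], anchor: str, guard: str) -> List[str]:
    """Insert a tagged `guard` line before each line containing `anchor`.

    Lines tagged by a previous run are dropped first, so repeated runs are idempotent.
    """
    clean = [line for line in lines if PATCH_TAG not in line]
    out: List[str] = []
    for line in clean:
        if anchor in line:
            indent = line[: len(line) - len(line.lstrip())]  # reuse the anchor's indentation
            out.append(f"{indent}{guard} {PATCH_TAG}\n")
        out.append(line)
    return out


src = ["def forward(self, x0, x_mask):\n", "    h = self.pre(x0) * x_mask\n"]
once = insert_guard_before(src, "h = self.pre(x0) * x_mask", "torch._check(x0.size(2) > 1)")
twice = insert_guard_before(once, "h = self.pre(x0) * x_mask", "torch._check(x0.size(2) > 1)")
print(once == twice)  # True
```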
## 10.3 Clear Dataset Directory

Utility to clear the local dataset directory before a fresh download.

In [None]:
# =============================================================================
# CLEAR DATASET DIRECTORY (TROUBLESHOOTING)
# =============================================================================
# Use this to clear local dataset directory for a fresh download.

# import shutil
# import os

# def clear_dataset_directory(config: PiperConfig) -> None:
#     """
#     Clear local dataset directory for fresh download.
#     
#     WARNING: This will delete all local copies of audio files!
#     
#     Args:
#         config: PiperConfig instance
#     """
#     if os.path.exists(config.local_wavs_dir):
#         print(f"Removing: {config.local_wavs_dir}")
#         shutil.rmtree(config.local_wavs_dir)
#         os.makedirs(config.local_wavs_dir, exist_ok=True)
#         print("✅ Dataset directory cleared.")
#     else:
#         print("Dataset directory does not exist.")

# # Usage:
# # clear_dataset_directory(config)

# 📋 **Summary**

This notebook provides a complete pipeline for fine-tuning Piper TTS with support for both 
Google Colab (POC) and AWS (Production) environments.

## Pipeline Overview

| Step | Description |
|------|-------------|
| 1. Configuration | Set `COLAB` mode and configure paths/parameters |
| 2. Environment Setup | GPU check, Google Drive mount (Colab) or S3 setup (AWS) |
| 3. Install Dependencies | System packages, Piper TTS, and native extensions |
| 4. Data ETL | Download/copy dataset, process transcripts |
| 5. Training | Fine-tune model with checkpoint saving |
| 6. Save Outputs | Copy checkpoints to remote storage |
| 7. Test Model | Run test inference from the training checkpoint |
| 8. Export ONNX | Export to ONNX format |
| 9. Test ONNX | Test exported ONNX model |
| 10. Troubleshooting | Utilities for common issues |

## Output Files

After training, you'll find these files in your output folder:
- `{model_name}.onnx` - The trained model in ONNX format
- `{model_name}.onnx.json` - Model configuration file
- `last.ckpt` - Latest training checkpoint (for resuming training)
- `lightning_logs/` - Training logs and intermediate checkpoints
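To confirm that a run produced the deliverables listed above, a small file check can help. This is a sketch. `missing_outputs` is a hypothetical helper, and it checks only the three files (not `lightning_logs/`).

The optional `checkpoint_subdir` parameter is there because the export cell reads `last.ckpt` from a `checkpoints/` subfolder in AWS mode but from the model folder itself in COLAB mode.

```python
import os
import tempfile
from typing import List


def missing_outputs(output_dir: str, model_name: str, checkpoint_subdir: str = "") -> List[str]:
    """Return the expected output files that are not present under output_dir."""
    expected = [
        f"{model_name}.onnx",                          # exported model
        f"{model_name}.onnx.json",                     # model configuration
        os.path.join(checkpoint_subdir, "last.ckpt"),  # checkpoint for resuming training
    ]
    return [name for name in expected if not os.path.isfile(os.path.join(output_dir, name))]


# Demo on a temporary folder that only contains the .onnx file
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "demo.onnx"), "wb").close()
    demo_missing = missing_outputs(tmp, "demo")
print(demo_missing)  # ['demo.onnx.json', 'last.ckpt']
```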

## Using the Trained Model

To use the trained model with the Piper CLI, keep `{model_name}.onnx.json` in the same directory as `{model_name}.onnx`:
```bash
echo "Hello world" | piper --model {model_name}.onnx --output_file output.wav
```
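The CLI writes `output.wav` for you. If you instead synthesize directly with ONNX Runtime, as the Section 9 test cell does, you can save the returned float array with the standard-library `wave` module.

This is a minimal sketch. It assumes mono float samples roughly in [-1, 1]; values outside that range are clipped, and Piper's own runtime may scale audio differently. `float_audio_to_wav_bytes` is a hypothetical helper.

```python
import io
import wave

import numpy as np


def float_audio_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode mono float audio (assumed range [-1, 1]) as 16-bit PCM WAV bytes."""
    clipped = np.clip(audio, -1.0, 1.0)
    pcm = (clipped * 32767).astype(np.int16)
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)   # mono
        wav_file.setsampwidth(2)   # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    return buffer.getvalue()


# Round-trip check with one second of silence at 22050 Hz
wav_bytes = float_audio_to_wav_bytes(np.zeros(22050, dtype=np.float32), 22050)
with wave.open(io.BytesIO(wav_bytes), "rb") as check:
    print(check.getnframes(), check.getframerate())  # 22050 22050
```

With the test cell's variables, writing a file would be `open("onnx_test.wav", "wb").write(float_audio_to_wav_bytes(audio_data, onnx_tester.sample_rate))`.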

## Mode Comparison

| Feature | COLAB Mode | AWS Mode |
|---------|------------|----------|
| Data Storage | Google Drive | AWS S3 |
| Use Case | Proof of Concept | Production |
| Data Loader | GoogleDriveDataLoader | S3DataLoader |
| Credentials | Google Account | AWS IAM/Keys |

## Key Classes

- **PiperConfig**: Centralized configuration dataclass
- **DataLoaderBase**: Abstract base for data loading (GoogleDriveDataLoader, S3DataLoader)
- **TranscriptProcessor**: Text normalization and transcript processing
- **CheckpointManager**: Manage pretrained/training checkpoints
- **ModelTester**: Test model from checkpoint
- **ONNXExporter**: Export model to ONNX format
- **ONNXModelTester**: Test exported ONNX model