# Notebook Initialization

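### Import the libraries needed for fine-tuning:
- XTTS GPT trainer classes, dataset utilities, and audio config from Coqui TTS
- The `trainer` package (`Trainer`, `TrainerArgs`) that drives the training loop
- PyTorch for GPU detection, and Weights & Biases (`wandb`) for experiment logging
- OS utilities for file handling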

In [1]:
import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager

import torch
import sys
from datetime import datetime
import wandb
from trainer.logging.wandb_logger import WandbLogger

### Test GPU availability
This checks whether a CUDA GPU is available so that it can be used to accelerate training; without one, the device falls back to the CPU.

In [2]:
# Torch info: pick the first CUDA device when one is available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)
print(torch.version.cuda)
print(torch.cuda.is_available())
# get_device_name() raises on a machine without CUDA, so only query it when a GPU is present;
# this keeps the cell runnable on the CPU fallback selected above
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())

cuda:0
12.6
True
NVIDIA GeForce RTX 4070 Ti


# Downloads

This section ensures that all required XTTS model components are available locally. It checks for the presence of the pretrained model files (the DVAE, mel normalization statistics, tokenizer, and the main XTTS checkpoint) and downloads them from Coqui's model hub if they are not already present. These files are used later to initialize the model for fine-tuning. Model checkpoint files can also be found on [HuggingFace](https://huggingface.co/coqui/XTTS-v2).

In [9]:
# Local directory that holds every pretrained model file (created on first run)
CHECKPOINT_PATH = '<PATH>'
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Remote locations of the DVAE checkpoint and the mel normalization statistics
DVAE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"

# Local copies keep the file name found at the end of each URL
DVAE_CHECKPOINT = os.path.join(CHECKPOINT_PATH, os.path.basename(DVAE_LINK))
MEL_NORM_FILE = os.path.join(CHECKPOINT_PATH, os.path.basename(MEL_NORM_LINK))

# Fetch both files again as soon as either one is missing
if not (os.path.isfile(DVAE_CHECKPOINT) and os.path.isfile(MEL_NORM_FILE)):
    print("Downloading DVAE files")
    ModelManager._download_model_files(
        [MEL_NORM_LINK, DVAE_LINK],
        CHECKPOINT_PATH,
        progress_bar=True,
    )

# Remote locations of the XTTS v2.0 tokenizer vocabulary and model checkpoint
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"

# Base model for transfer learning: local paths mirror the remote file names
TOKENIZER_FILE = os.path.join(CHECKPOINT_PATH, os.path.basename(TOKENIZER_FILE_LINK))  # vocab.json
XTTS_CHECKPOINT = os.path.join(CHECKPOINT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK))  # model.pth

# Fetch both files again as soon as either one is missing
if not (os.path.isfile(TOKENIZER_FILE) and os.path.isfile(XTTS_CHECKPOINT)):
    print("Downloading XTTS v2.0 files")
    xtts_links = [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK]
    ModelManager._download_model_files(xtts_links, CHECKPOINT_PATH, progress_bar=True)

# Data Loading

This section prepares the dataset and configuration required for training and inference. It includes selecting a language, loading metadata and audio samples, specifying the reference speaker clip for voice cloning, and initializing model-specific arguments such as input length limits and file paths.

LJSpeech is a commonly used dataset format for text-to-speech systems, consisting of a metadata.csv file where each row links a transcript to an audio clip. The dataset follows the structure: audio_filename|transcript|normalized_transcript.

In [10]:
# Language code attached to every sample in the dataset
LANGUAGE = 'en'

# Dataset folder; it must contain 'metadata.csv' and a 'wavs/' directory with the audio clips
DATASET = "<DATASET NAME>"
training_dir = f"<DATASET PATH>/{DATASET}"

# Describe the dataset: LJSpeech-style layout rooted at training_dir
dataset_config = BaseDatasetConfig(
    path=training_dir,
    language=LANGUAGE,
    formatter="ljspeech",
    meta_file_train="metadata.csv",  # Transcriptions, one row per clip
)

# Split the samples into training and evaluation sets, holding out 2% for evaluation
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split_size=0.02,
    eval_split=True,
)

This sets up the audio configuration and speaker reference data. The speaker reference is used to guide the model in cloning or adapting to the target voice. It is very important that this matches the intended speaker for the generated outputs.

In [12]:
# Sample rates (Hz) used by the model, the DVAE encoder, and the generated audio
audio_config = XttsAudioConfig(
    output_sample_rate=24000,  # Sample rate of the final output audio
    dvae_sample_rate=22050,  # Sample rate fed to the DVAE encoder
    sample_rate=22050,  # Sample rate for internal processing
)

# Sentences synthesized as test outputs while training
SPEAKER_TEXT = [
    "Hello, I am not a real person but I have a real voice.",
    "I love my new voice it sounds so good.",
]

# Path of the reference audio clip that defines the speaker identity
SPEAKER_REFERENCE = "<PATH>"

This defines model-specific arguments such as audio/text length limits, pretrained checkpoints, and architectural settings like token usage and audio token encoding. These values should reflect the XTTS-v2 model constraints and performance recommendations.

In [13]:
# Model arguments for fine-tuning XTTS-v2 on top of the downloaded base checkpoint.
# Audio lengths are expressed in samples at the 22050 Hz training sample rate.
model_args = GPTArgs(
    max_conditioning_length=132300,  # Maximum speaker reference length (~6 secs)
    min_conditioning_length=66150,  # Minimum speaker reference length (~3 secs)
    debug_loading_failures=True,  # Verbose debugging for audio/text loading failures
    max_wav_length=255995,  # Maximum sample audio duration (~11.6 seconds)
    # Maximum text length per sample. The previous value (66150) duplicated the
    # 3-second sample count above, which is not a text length and effectively
    # disabled the limit; 200 is the value used in Coqui's XTTS-v2 fine-tuning recipe.
    # NOTE(review): confirm whether the limit is counted in characters or tokens.
    max_text_length=200,
    mel_norm_file=MEL_NORM_FILE,
    dvae_checkpoint=DVAE_CHECKPOINT,
    xtts_checkpoint=XTTS_CHECKPOINT,  # Base checkpoint to fine-tune from
    tokenizer_file=TOKENIZER_FILE,
    # Presumably the audio token vocabulary size (the start/stop ids below are 1024/1025) - TODO confirm
    gpt_num_audio_tokens=1026,
    gpt_start_audio_token=1024,  # [START] token for audio in GPT
    gpt_stop_audio_token=1025,  # [STOP] token for audio in GPT
    gpt_use_masking_gt_prompt_approach=True,  # Enables ground-truth masking strategy
    gpt_use_perceiver_resampler=True,  # Use Perceiver Resampler for conditioning
)

# Training Config

This section sets up all necessary training parameters for fine-tuning XTTS-v2. It specifies output paths, batch sizes, evaluation settings, logging preferences, optimizer configuration, and training hyperparameters. The model is trained using the GPT-based XTTS trainer and supports integration with Weights & Biases (wandb) for experiment tracking.

The training configuration is designed to handle smaller batch sizes using gradient accumulation (`BATCH_SIZE * GRAD_ACUMM_STEPS` = 3 × 84 = 252) to match Coqui’s recommendations. Additionally, `test_sentences` are provided to synthesize and log audio samples each epoch during training.

In [7]:
# Output directory for checkpoints, logs, and training artifacts.
# exist_ok=True creates it when missing without a separate existence check
# (same pattern as the checkpoint directory) and avoids the check-then-create race.
OUT_PATH = '<PATH>'
os.makedirs(OUT_PATH, exist_ok=True)

# Name of the run and project (used in logs and dashboard tracking)
RUN_NAME = '<RUN NAME>'
PROJECT_NAME = '<PROJECT NAME>'
DASHBOARD_LOGGER = 'wandb'  # Use Weights & Biases for logging
LOGGER_URI = None

# Batch size and gradient accumulation: 3 * 84 = 252 samples per optimizer step
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True
BATCH_SIZE = 3
GRAD_ACUMM_STEPS = 84  # Spelling kept as-is: the training cell references this exact name
START_WITH_EVAL = True  # Begin training with an initial evaluation pass

This defines the full training configuration using GPTTrainerConfig. It includes all required runtime options such as the number of epochs, evaluation strategy, optimizer and learning rate schedule, checkpoint saving intervals, and test sentence setup. Adjustments can be made here to control training behavior or debug model behavior.

In [14]:
# Full training configuration for the XTTS GPT trainer
config = GPTTrainerConfig(
    # --- Run identity and logging ---
    output_path=OUT_PATH,
    run_name=RUN_NAME,
    project_name=PROJECT_NAME,
    run_description="""
        GPT XTTS training
        """,
    dashboard_logger=DASHBOARD_LOGGER,
    wandb_entity=None,
    logger_uri=LOGGER_URI,
    # --- Model and audio ---
    model_args=model_args,
    audio=audio_config,
    # --- Schedule and data loading ---
    epochs=40,  # Total training epochs
    batch_size=BATCH_SIZE,
    batch_group_size=48,  # NOTE(review): presumably groups batches of similar length - confirm
    num_loader_workers=0,
    # --- Evaluation ---
    # Enables evaluation; the initial pre-training pass is set separately via TrainerArgs.start_with_eval
    run_eval=True,
    eval_batch_size=BATCH_SIZE,
    eval_split_max_size=256,  # Upper bound on the number of validation samples used for evaluation
    print_eval=True,  # Print evaluation results during training
    # --- Reporting and checkpointing (all intervals are in steps) ---
    print_step=50,  # Interval between printed training statistics
    plot_step=100,  # Interval between plotted loss/metric graphs
    log_model_step=1000,  # Interval between checkpoints logged for external tracking
    save_step=1000,  # Interval between checkpoints saved locally
    save_n_checkpoints=1,  # How many past checkpoints to keep (older ones are deleted)
    save_checkpoints=True,  # Save checkpoints at regular intervals
    # --- Optimizer and learning-rate schedule ---
    optimizer="AdamW",
    optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
    optimizer_params=dict(betas=[0.9, 0.96], eps=1e-8, weight_decay=1e-2),
    lr=5e-06,
    lr_scheduler="MultiStepLR",
    lr_scheduler_params=dict(
        milestones=[steps * 18 for steps in (50000, 150000, 300000)],
        gamma=0.5,
        last_epoch=-1,
    ),
    # --- Test sentences: one entry per reference text, all using the same speaker clip ---
    test_sentences=[
        {
            "text": SPEAKER_TEXT[index],
            "speaker_wav": SPEAKER_REFERENCE,
            "language": LANGUAGE,
        }
        for index in (0, 1)
    ],
)

# Training

This section kicks off the training process. It begins by initializing the model using the configuration defined earlier, then wraps it in a Trainer class that handles batching, evaluation, and checkpointing. Training is started with `.fit()`, and it can be manually interrupted (Ctrl+C in a terminal, or "Interrupt Kernel" in Jupyter) to safely save progress.

In [None]:
# Build the XTTS GPT model from the training configuration
model = GPTTrainer.init_from_config(config)

# Runtime options for the trainer
trainer_args = TrainerArgs(
    restore_path=None,  # Set to a model path to resume a previous run
    skip_train_epoch=False,  # True skips training (eval/debug only)
    start_with_eval=START_WITH_EVAL,
    grad_accum_steps=GRAD_ACUMM_STEPS,
)

# Wire the model, configuration, and data splits into the trainer
trainer = Trainer(
    trainer_args,
    config,
    output_path=OUT_PATH,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)

# Run the training loop; it can be interrupted manually to trigger a checkpoint save
trainer.fit()