# Imports & Logging Init

In [1]:
import os
import codecs

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager

import torch
import sys
from datetime import datetime
import wandb
from trainer.logging.wandb_logger import WandbLogger

In [2]:
# Torch / CUDA environment report: selected device, CUDA build version, availability, GPU name.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)
print(torch.version.cuda)  # CUDA version this torch build targets (None on CPU-only builds)
print(torch.cuda.is_available())
# get_device_name() raises when no CUDA device exists, so only query it when CUDA is available;
# otherwise the cell would crash on exactly the machines the "cpu" fallback above is meant for.
print(torch.cuda.get_device_name() if torch.cuda.is_available() else "No CUDA device available")

cuda:0
12.6
True
NVIDIA GeForce RTX 4070 Ti


# Downloads

In [3]:
# Pretrained XTTS v2 assets all live in one local folder; each group of files is
# fetched only when at least one of its files is missing on disk.
CHECKPOINT_PATH = './XTTS-files/'
os.makedirs(CHECKPOINT_PATH, exist_ok=True)


def _download_if_missing(local_files, remote_links, banner):
    """Download ``remote_links`` into CHECKPOINT_PATH unless every path in ``local_files`` exists.

    ``banner`` is printed just before the download starts.
    """
    if all(os.path.isfile(local_path) for local_path in local_files):
        return
    print(banner)
    ModelManager._download_model_files(remote_links, CHECKPOINT_PATH, progress_bar=True)


# DVAE weights and the mel-spectrogram normalisation statistics
DVAE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
DVAE_CHECKPOINT, MEL_NORM_FILE = (
    os.path.join(CHECKPOINT_PATH, os.path.basename(link)) for link in (DVAE_LINK, MEL_NORM_LINK)
)
_download_if_missing(
    [DVAE_CHECKPOINT, MEL_NORM_FILE],
    [MEL_NORM_LINK, DVAE_LINK],
    " > Downloading DVAE files!",
)

# Tokenizer vocabulary (vocab.json) and GPT checkpoint (model.pth) used as the transfer-learning start point
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
TOKENIZER_FILE, XTTS_CHECKPOINT = (
    os.path.join(CHECKPOINT_PATH, os.path.basename(link))
    for link in (TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK)
)
_download_if_missing(
    [TOKENIZER_FILE, XTTS_CHECKPOINT],
    [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK],
    " > Downloading XTTS v2.0 files!",
)

# Data Loading

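Make sure `metadata.csv` follows the LJSpeech format: `<file_id>|<transcription>|<normalized_transcription>`. Coqui's `ljspeech` formatter reads the text from the third column and loads audio from `wavs/<file_id>.wav`, so leave the `.wav` extension off the first column.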

In [4]:
# 1% to evaluate by listening to tests

# Set lang
LANGUAGE ='en'

# Set to folder name that contains metadata.csv and wavs dir (with the .wav examples)
DATASET= "Sherlock Holmes Stories  Read by Benedict Cumberbatch"
training_dir = f'./datasets/{DATASET}' # change to folder w/ training examples

dataset_config = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv", # metadata file w/ transcriptions
    language=LANGUAGE,
    path=training_dir
)

train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_size=0.02, # Might change
)

In [5]:
# Audio config.
# NOTE(review): the upstream XTTS v2 fine-tuning recipe uses
#   XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000);
# 16 kHz here presumably matches the dataset's native rate — confirm this is intended. If it is
# changed, the sample-count lengths passed to GPTArgs must be rescaled accordingly.
audio_config = XttsAudioConfig(sample_rate=16000, dvae_sample_rate=16000, output_sample_rate=24000)

# Test sentences synthesized by the trainer so the fine-tuned voice can be checked by ear.
SPEAKER_TEXT = [
    "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    "This cake is great. It's so delicious and moist."
]

# Reference clip that conditions the voice used for the test sentences. Derived from
# training_dir so it always points into the configured dataset folder.
SPEAKER_REFERENCE = os.path.join(training_dir, "wavs", "chunk_0220.wav")
# Fail now, before the model is built, rather than later when the trainer tries to
# synthesize the test sentences with a missing reference clip.
if not os.path.isfile(SPEAKER_REFERENCE):
    raise FileNotFoundError(f"Speaker reference clip not found: {SPEAKER_REFERENCE}")


In [6]:
# GPT fine-tuning arguments.
# The *_length audio limits are sample counts — presumably at audio_config.sample_rate
# (16 kHz in this notebook), so the upstream recipe's "N seconds" figures, which assume
# 22.05 kHz, do not carry over directly. TODO confirm against the XTTS dataset loader.
model_args = GPTArgs(
    max_conditioning_length=143677,  # conditioning audio must be shorter than this (upstream recipe: 132300, ~6 s @ 22.05 kHz)
    min_conditioning_length=66150,  # ...and longer than this (upstream recipe: 66150, ~3 s @ 22.05 kHz)
    debug_loading_failures=True,
    max_wav_length=255995,  # set >= the longest clip in the dataset (upstream recipe: 255995, ~11.6 s @ 22.05 kHz)
    # Maximum number of *text tokens* per sample — not an audio length. The former value
    # (66150) was copied from min_conditioning_length and effectively disabled text-length
    # filtering; 200 is the upstream recipe value.
    max_text_length=200,
    mel_norm_file=MEL_NORM_FILE,
    dvae_checkpoint=DVAE_CHECKPOINT,
    xtts_checkpoint=XTTS_CHECKPOINT,
    tokenizer_file=TOKENIZER_FILE,
    # Audio-token ids of the pretrained XTTS v2 GPT (start=1024, stop=1025, vocabulary of 1026)
    gpt_num_audio_tokens=1026,
    gpt_start_audio_token=1024,
    gpt_stop_audio_token=1025,
    gpt_use_masking_gt_prompt_approach=True,
    gpt_use_perceiver_resampler=True,
)

# Training Config

In [7]:
# Output location, experiment-tracking and optimisation constants used by the config below.
OUT_PATH = './run/training/'
os.makedirs(OUT_PATH, exist_ok=True)  # idempotent: no separate existence check needed

RUN_NAME = 'Sherlock-Holmes-4-epochs'
# Alternative timestamped name: f"xttsv2_finetune_{datetime.now().strftime('%Y%m%d_%H%M')}"
PROJECT_NAME = 'XTTS-v2 Finetune'
DASHBOARD_LOGGER = 'wandb'
LOGGER_URI = None

OPTIMIZER_WD_ONLY_ON_WEIGHTS = True
BATCH_SIZE = 3  # 4 is common
GRAD_ACUMM_STEPS = 84  # effective batch = BATCH_SIZE * GRAD_ACUMM_STEPS = 3 * 84 = 252
# Recommendation: keep BATCH_SIZE * GRAD_ACUMM_STEPS >= 252 for efficient training;
# if you change BATCH_SIZE, adjust GRAD_ACUMM_STEPS to keep the product.
START_WITH_EVAL = True

In [8]:
# Full trainer configuration for XTTS GPT fine-tuning.
config = GPTTrainerConfig(
    run_eval=True,
    epochs=4,  # fixed number of epochs; a keyboard interrupt can still stop training early
    output_path=OUT_PATH,
    model_args=model_args,
    run_name=RUN_NAME,
    project_name=PROJECT_NAME,
    run_description="""
        GPT XTTS training
        """,
    dashboard_logger=DASHBOARD_LOGGER,
    wandb_entity=None,
    logger_uri=LOGGER_URI,
    audio=audio_config,
    batch_size=BATCH_SIZE,
    batch_group_size=48,
    eval_batch_size=BATCH_SIZE,
    num_loader_workers=0,  # On Windows, num_loader_workers > 0 can break multiprocessing in PyTorch
    # NOTE(review): this value is not passed to the load_tts_samples() call in the data-loading
    # cell; confirm whether it has any effect on the eval split in this notebook.
    eval_split_max_size=256,
    print_step=50,
    plot_step=100,
    log_model_step=1000,
    save_step=1000,
    save_n_checkpoints=1,
    save_checkpoints=True,
    print_eval=True,
    optimizer="AdamW",
    optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
    optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
    lr=5e-06,
    lr_scheduler="MultiStepLR",
    # Milestones from the upstream recipe; presumably never reached in a short 4-epoch run,
    # so the learning rate effectively stays at `lr`.
    lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
    # One test sentence per entry of SPEAKER_TEXT, all conditioned on the same reference clip.
    test_sentences=[
        {
            "text": text,
            "speaker_wav": SPEAKER_REFERENCE,
            "language": LANGUAGE,
        }
        for text in SPEAKER_TEXT
    ],
)

# Training

In [9]:
# Build the XTTS GPT model from the config, wrap it in the generic Trainer, and train.
model = GPTTrainer.init_from_config(config)

trainer_args = TrainerArgs(
    restore_path=None,  # set to a checkpoint path to resume a previous run
    skip_train_epoch=False,
    start_with_eval=START_WITH_EVAL,
    grad_accum_steps=GRAD_ACUMM_STEPS,
)

trainer = Trainer(
    trainer_args,
    config,
    output_path=OUT_PATH,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)

# Runs for config.epochs epochs. NOTE(review): whether a manual (keyboard) interrupt saves a
# checkpoint depends on the installed `trainer` package version — confirm before relying on it.
trainer.fit()

 > Training Environment:
 | > Backend: Torch
 | > Mixed precision: False
 | > Precision: float32
 | > Current device: 0
 | > Num. of GPUs: 1
 | > Num. of CPUs: 20
 | > Num. of Torch Threads: 1
 | > Torch seed: 1
 | > Torch CUDNN: True
 | > Torch CUDNN deterministic: False
 | > Torch CUDNN benchmark: False
 | > Torch TF32 MatMul: False
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: robcaamano (robcaamano-new-jersey-institute-of-technology) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin



 > Model has 518442047 parameters

[4m[1m > EPOCH: 0/4[0m
 --> run\training\Sherlock-Holmes-4-epochs-April-25-2025_03+32PM-0000000

[1m > EVALUATION [0m

[1m   --> STEP: 0[0m
     | > loss_text_ce: 0.02065003104507923  (0.02065003104507923)
     | > loss_mel_ce: 4.8800153732299805  (4.8800153732299805)
     | > loss: 4.900665283203125  (4.900665283203125)

[1m   --> STEP: 1[0m
     | > loss_text_ce: 0.02369523048400879  (0.02369523048400879)
     | > loss_mel_ce: 4.812253952026367  (4.812253952026367)
     | > loss: 4.835948944091797  (4.835948944091797)

[1m   --> STEP: 2[0m
     | > loss_text_ce: 0.021101634949445724  (0.022398432716727257)
     | > loss_mel_ce: 4.779322624206543  (4.795788288116455)
     | > loss: 4.800424098968506  (4.818186521530151)

[1m   --> STEP: 3[0m
     | > loss_text_ce: 0.021626941859722137  (0.02214126909772555)
     | > loss_mel_ce: 4.4567179679870605  (4.682764848073323)
     | > loss: 4.478344917297363  (4.704905986785889)

[1m   --> STE

0,1
EvalStats/avg_loader_time,█▂▁
EvalStats/avg_loss,█▃▁
EvalStats/avg_loss_mel_ce,█▃▁
EvalStats/avg_loss_text_ce,█▁▁
TrainEpochStats/avg_grad_norm,▁▁
TrainEpochStats/avg_loader_time,█▁
TrainEpochStats/avg_loss,█▁
TrainEpochStats/avg_loss_mel_ce,█▁
TrainEpochStats/avg_loss_text_ce,█▁
TrainEpochStats/avg_step_time,▁█

0,1
EvalStats/avg_loader_time,0.05187
EvalStats/avg_loss,4.1927
EvalStats/avg_loss_mel_ce,4.17117
EvalStats/avg_loss_text_ce,0.02153
TrainEpochStats/avg_grad_norm,0.0
TrainEpochStats/avg_loader_time,0.05743
TrainEpochStats/avg_loss,0.05243
TrainEpochStats/avg_loss_mel_ce,4.38085
TrainEpochStats/avg_loss_text_ce,0.02336
TrainEpochStats/avg_step_time,1.3192
