# Imports & Logging Init

In [1]:
# Imports
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager

import sys
import os
from datetime import datetime
import wandb
from trainer.logging.wandb_logger import WandbLogger

!pip list

Package                    Version
-------------------------- ----------------------
absl-py                    2.1.0
accelerate                 1.0.1
aiofiles                   23.2.1
aiohttp                    3.9.3
aiosignal                  1.3.1
altair                     5.5.0
aniso8601                  9.0.1
annotated-types            0.6.0
ansicon                    1.89.0
antlr4-python3-runtime     4.9.3
anyascii                   0.3.2
anyio                      4.4.0
argon2-cffi                23.1.0
argon2-cffi-bindings       21.2.0
arrow                      1.3.0
asttokens                  2.4.1
async-lru                  2.0.4
attrs                      23.2.0
audioread                  3.0.1
awscli                     1.37.16
Babel                      2.14.0
bangla                     0.0.2
beautifulsoup4             4.12.3
bleach                     6.1.0
blessed                    1.20.0
blinker                    1.7.0
blis                       0.7.11
bnnumerizer  

# Downloads

In [17]:
# Local directory that caches the pretrained XTTS-v2 files
CHECKPOINT_PATH = './XTTS-files/'
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Remote sources: DVAE checkpoint (dvae.pth) and mel stats (mel_stats.pth) ...
DVAE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
# ... and the XTTS v2.0 tokenizer vocabulary (vocab.json) and model checkpoint (model.pth)
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"

# Every file lives in CHECKPOINT_PATH under the basename of its URL.
# These four paths are the transfer-learning inputs used by the model arguments.
DVAE_CHECKPOINT, MEL_NORM_FILE, TOKENIZER_FILE, XTTS_CHECKPOINT = (
    os.path.join(CHECKPOINT_PATH, os.path.basename(link))
    for link in (DVAE_LINK, MEL_NORM_LINK, TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK)
)

# Fetch the DVAE pair unless both files are already on disk
if not (os.path.isfile(DVAE_CHECKPOINT) and os.path.isfile(MEL_NORM_FILE)):
    print(" > Downloading DVAE files!")
    ModelManager._download_model_files([MEL_NORM_LINK, DVAE_LINK], CHECKPOINT_PATH, progress_bar=True)

# Fetch the XTTS v2.0 pair unless both files are already on disk
if not (os.path.isfile(TOKENIZER_FILE) and os.path.isfile(XTTS_CHECKPOINT)):
    print(" > Downloading XTTS v2.0 files!")
    ModelManager._download_model_files(
        [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINT_PATH, progress_bar=True
    )

# Data Prep

**TODO (open question):** does the data-prep step lose information?

# Data Loading

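Make sure your metadata follows the LJSpeech format, one clip per line: \<file id\>|\<transcription\>|\<normalized transcription\>. The `ljspeech` formatter reads the text from the third column and expects the audio at `wavs/<file id>.wav` inside the dataset folder.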

In [None]:
# Hold out 2% of the samples (eval_split_size below) to evaluate by listening to tests

# DATASET = '<name of your folder under ./datasets/>'
# DATASET has no default: fail fast with an actionable message instead of a
# bare "name 'DATASET' is not defined" further down.
if "DATASET" not in globals():
    raise NameError(
        "DATASET is not set: assign the name of your dataset folder under ./datasets/ "
        "before running this cell."
    )

# LANGUAGE is also assigned in the training-config cell further down; fall back
# to the same value here so this cell works on a fresh kernel run top to bottom.
LANGUAGE = globals().get("LANGUAGE", "en")

training_dir = f'./datasets/{DATASET}/'  # change to folder w/ training examples

dataset_config = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",  # metadata file w/ transcriptions
    language=LANGUAGE,
    path=training_dir,
)

# Split the loaded samples into training and evaluation sets
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_size=0.02,  # Might change
)

# MODIFY

## Model args
max_wav_length: maximum audio length from your dataset

Ensure the length of the reference audio is within the min-max range

## Sample rates

Defaults:

input: 22050

output: 24000

sample_rate and dvae_sample_rate should match data (probably 22050)

## Speaker reference

Choosing the right speaker reference is VERY important for XTTS-v2. Select ~10 clips, each saying a full sentence with an intonation/rhythm/speed/style that sounds pretty good, and then pick the best one.

In [19]:
# GPT model arguments, grouped by concern
model_args = GPTArgs(
    # --- Paths to the pretrained files ---
    xtts_checkpoint=XTTS_CHECKPOINT,
    dvae_checkpoint=DVAE_CHECKPOINT,
    mel_norm_file=MEL_NORM_FILE,
    tokenizer_file=TOKENIZER_FILE,
    # --- Length limits ---
    # NOTE(review): presumably sample counts at the audio config's sample_rate - confirm
    min_conditioning_length=66150,  # conditioning audio should be longer than this ...
    max_conditioning_length=143677,  # ... and shorter than this
    max_wav_length=223997,  # set >= longest audio in dataset
    max_text_length=200,
    # --- Audio-token vocabulary ---
    gpt_num_audio_tokens=1026,
    gpt_start_audio_token=1024,
    gpt_stop_audio_token=1025,
    # --- Feature switches ---
    gpt_use_masking_gt_prompt_approach=True,
    gpt_use_perceiver_resampler=True,
    debug_loading_failures=True,
)

In [None]:
# Reference clip used for the test sentences during training
SPEAKER_REFERENCE = './audio/FILE.wav'  # To be changed

# Sample rates for the input audio, the DVAE, and the generated output.
# NOTE(review): the notes above say sample_rate / dvae_sample_rate should match
# the data (probably 22050); 16000 is used here - confirm this matches the dataset.
audio_config = XttsAudioConfig(
    output_sample_rate=24000,
    dvae_sample_rate=16000,
    sample_rate=16000,
)

# Training Config

In [25]:
# --- Where results go and how the run is labelled ---
OUT_PATH = './training_outputs/'
PROJECT_NAME = 'XTTS-v2 Finetune'
RUN_NAME = 'xtts-v2_test'
# Timestamped alternative for the run name:
# RUN_NAME = f"xttsv2_finetune_{datetime.now().strftime('%Y%m%d_%H%M')}"

# --- Experiment tracking ---
DASHBOARD_LOGGER, LOGGER_URI = 'wandb', None

# --- Optimisation and data ---
BATCH_SIZE = 4  # 4 is common
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True
LANGUAGE = 'en'

In [None]:
# Full trainer configuration: run identity, logging, data loading,
# checkpointing, optimiser/scheduler, and the sentences synthesised as tests.
config = GPTTrainerConfig(
    # --- Run identity & logging ---
    output_path=OUT_PATH,
    run_name=RUN_NAME,
    project_name=PROJECT_NAME,
    run_description="""
        GPT XTTS training
        """,
    dashboard_logger=DASHBOARD_LOGGER,
    logger_uri=LOGGER_URI,
    wandb_entity=None,
    # --- Model & audio ---
    model_args=model_args,
    audio=audio_config,
    # --- Schedule ---
    epochs=1000,  # assuming you want to end training manually w/ keyboard interrupt
    run_eval=True,
    print_eval=True,
    # --- Batching & loading ---
    batch_size=BATCH_SIZE,
    eval_batch_size=BATCH_SIZE,
    batch_group_size=48,
    eval_split_max_size=256,
    num_loader_workers=0,  # On Windows, num_loader_workers > 0 can break multiprocessing in PyTorch
    # --- Reporting & checkpoints ---
    print_step=50,
    plot_step=100,
    log_model_step=1000,
    # NOTE(review): save_step=None together with save_checkpoints=True - confirm
    # the trainer accepts None here and still writes checkpoints as intended.
    save_step=None,
    save_checkpoints=True,
    save_n_checkpoints=3,  # Rotate last 3 checkpoints
    # --- Optimiser & LR schedule ---
    optimizer="AdamW",
    optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
    optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
    lr=5e-06,
    lr_scheduler="MultiStepLR",
    lr_scheduler_params={
        "milestones": [steps * 18 for steps in (50000, 150000, 300000)],
        "gamma": 0.5,
        "last_epoch": -1,
    },
    # --- Test sentences: same speaker reference and language for each ---
    test_sentences=[
        {"text": sentence, "speaker_wav": SPEAKER_REFERENCE, "language": LANGUAGE}
        for sentence in (
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "This cake is great. It's so delicious and moist.",
            "Hello my name is Rob and this is my presentation.",
        )
    ],
)

# Build the trainable model from the configuration
model = GPTTrainer.init_from_config(config)

# Training

In [None]:
# Gradient accumulation and whether to run an evaluation pass before training.
# NOTE(review): with BATCH_SIZE = 4 this accumulates 4 * 252 = 1008 samples per
# optimiser step - confirm that is the intended effective batch size.
START_WITH_EVAL = True
GRAD_ACUMM_STEPS = 252

# Runtime arguments for the trainer
trainer_args = TrainerArgs(
    grad_accum_steps=GRAD_ACUMM_STEPS,
    start_with_eval=START_WITH_EVAL,
    skip_train_epoch=False,
    restore_path=None,  # Change to model path if resuming
)

# Assemble the trainer from the config, model and data splits, then start training
trainer = Trainer(
    trainer_args,
    config,
    output_path=OUT_PATH,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)
trainer.fit()