# Notes

## Setup notes
- When installing the TTS package in a venv, install a CUDA-enabled build of torch first. Otherwise pip installs the default torch build, which may be CPU-only and prevents CUDA from being used.
- In "TTS\tts\layers\tortoise\arch_utils.py", replace references to `LogitsWarper` with `LogitsProcessor` (recent `transformers` releases no longer provide `LogitsWarper`).
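- In "TTS\tts\models\xtts.py", function `get_compatible_checkpoint_state_dict` (around line 714), add the argument `weights_only=False` to the `load_fsspec` call. Change `checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]` to `checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"), weights_only=False)["model"]`. This is needed because newer PyTorch releases default `torch.load` to `weights_only=True`, which refuses to load this checkpoint. Only do this for checkpoints from a trusted source, since `weights_only=False` unpickles arbitrary objects.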


In [1]:
'''Imports'''
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager
import torch
import sys
import os
from datetime import datetime
import wandb
from trainer.logging.wandb_logger import WandbLogger

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
'''Display device used'''
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)
# CUDA version torch was built against (None for CPU-only torch builds).
print(torch.version.cuda)
print(torch.cuda.is_available())
# torch.cuda.get_device_name() raises when no CUDA device is present, so only
# query it on the CUDA path; otherwise the CPU fallback above would crash here.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(device))
else:
    print("No CUDA device available; running on CPU.")

cuda:0
12.4
True
NVIDIA GeForce RTX 2070


In [3]:
'''DOWNLOADS'''
# Fetch the pretrained XTTS-v2 assets (DVAE, mel stats, tokenizer vocab and
# GPT checkpoint) into CHECKPOINT_PATH. Each group is downloaded only when at
# least one of its files is missing on disk.
CHECKPOINT_PATH = './XTTS-files/'
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

_XTTS_BASE_URL = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/"


def _local_path(url):
    """Return the location of the file named by *url* inside CHECKPOINT_PATH."""
    return os.path.join(CHECKPOINT_PATH, os.path.basename(url))


def _fetch_if_missing(urls, banner):
    """Download all *urls* (in the given order) unless every target already exists."""
    if all(os.path.isfile(_local_path(url)) for url in urls):
        return
    print(banner)
    ModelManager._download_model_files(urls, CHECKPOINT_PATH, progress_bar=True)


# DVAE checkpoint and mel-normalisation statistics.
DVAE_LINK = _XTTS_BASE_URL + "dvae.pth"
MEL_NORM_LINK = _XTTS_BASE_URL + "mel_stats.pth"
DVAE_CHECKPOINT = _local_path(DVAE_LINK)
MEL_NORM_FILE = _local_path(MEL_NORM_LINK)
_fetch_if_missing([MEL_NORM_LINK, DVAE_LINK], " > Downloading DVAE files!")

# Base XTTS v2.0 model used for transfer learning: tokenizer + checkpoint.
TOKENIZER_FILE_LINK = _XTTS_BASE_URL + "vocab.json"
XTTS_CHECKPOINT_LINK = _XTTS_BASE_URL + "model.pth"
TOKENIZER_FILE = _local_path(TOKENIZER_FILE_LINK)  # vocab.json
XTTS_CHECKPOINT = _local_path(XTTS_CHECKPOINT_LINK)  # model.pth
_fetch_if_missing([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], " > Downloading XTTS v2.0 files!")

print("Paths set.")

Paths set.


In [None]:
'''DATA LOADING'''
# Language code of the transcriptions.
LANGUAGE = 'en'
# Folder name under ./datasets/ that contains metadata.csv and a wavs/ dir
# with the .wav examples.
DATASET = "tom_hanks_dutch_house"
training_dir = f'./datasets/{DATASET}/'  # change to folder w/ training examples

# Dataset uses the ljspeech format (metadata.csv + wavs/).
dataset_config = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",  # metadata file w/ transcriptions
    language=LANGUAGE,
    path=training_dir,
)

# Hold out a small eval split so the trainer's eval pass (run_eval /
# start_with_eval in the training config) has samples to run on.
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_size=0.02,  # fraction held out for eval; very small datasets may need more
)


'''MODIFY'''
# Audio config: sample rates (Hz) for loaded audio, the DVAE, and generated output.
# NOTE(review): the released XTTS-v2 config and Coqui's fine-tuning recipe use
# 22050 for sample_rate and dvae_sample_rate (presumably the rate the DVAE and
# mel_stats.pth were built for) — confirm 16000 is intended here.
audio_config = XttsAudioConfig(sample_rate=16000, dvae_sample_rate=16000, output_sample_rate=24000)

# Test sentences synthesized during eval in the voice of SPEAKER_REFERENCE.
# Only one reference clip is needed, and its words do not need to match these texts.
SPEAKER_TEXT = [
    "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    "This cake is great. It's so delicious and moist.",
]
# Derive the reference path from training_dir so changing the dataset folder
# above cannot leave this pointing at a different dataset.
SPEAKER_REFERENCE = os.path.join(training_dir, "wavs", "chunk_0009.wav")
# Fail now rather than at the first test-sentence synthesis mid-training.
if not os.path.isfile(SPEAKER_REFERENCE):
    raise FileNotFoundError(f"Speaker reference clip not found: {SPEAKER_REFERENCE}")


In [None]:
'''Set Model arguments'''
# GPT fine-tuning arguments. The *_length audio values below are sample counts
# at the audio config's sample_rate, e.g. 66150 samples = 3 s at 22050 Hz
# (about 4.1 s at 16000 Hz).
model_args = GPTArgs(
    max_conditioning_length=143677,  # audio used for conditioning latents should be less than this
    min_conditioning_length=66150,  # shortest conditioning clip (samples)
    debug_loading_failures=True,  # presumably reports samples that fail to load — verify in GPTArgs
    max_wav_length=255995,  # set >= longest clip in dataset (~11.6 s at 22050 Hz, ~16 s at 16000 Hz)
    # Cap on a sample's text length. Was 66150 (a copy of min_conditioning_length),
    # which effectively disables the limit; 200 is the value in Coqui's XTTS-v2
    # fine-tuning recipe.
    max_text_length=200,
    mel_norm_file=MEL_NORM_FILE,
    dvae_checkpoint=DVAE_CHECKPOINT,
    xtts_checkpoint=XTTS_CHECKPOINT,  # base checkpoint being fine-tuned
    tokenizer_file=TOKENIZER_FILE,
    gpt_num_audio_tokens=1026,
    gpt_start_audio_token=1024,
    gpt_stop_audio_token=1025,
    gpt_use_masking_gt_prompt_approach=True,
    gpt_use_perceiver_resampler=True,
)

In [None]:
'''TRAINING CONFIG: set up the GPT trainer configuration'''
OUT_PATH = './training_outputs/'

# Timestamp the run name so each launch gets a distinct name.
RUN_NAME = f"xttsv2_finetune_{datetime.now().strftime('%Y%m%d_%H%M')}"
PROJECT_NAME = "XTTS-v2 Finetune"
DASHBOARD_LOGGER = 'wandb'
LOGGER_URI = None

# NOTE(review): Coqui's recipe advises setting this to False for multi-GPU training.
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True

BATCH_SIZE = 3  # 4 is common

config = GPTTrainerConfig(
    run_eval=True,
    # Training stops on its own after this many epochs. For open-ended runs
    # that you stop manually (keyboard interrupt), set this much higher.
    epochs=2,
    output_path=OUT_PATH,
    model_args=model_args,
    run_name=RUN_NAME,
    project_name=PROJECT_NAME,
    run_description="""
        GPT XTTS training
        """,
    dashboard_logger=DASHBOARD_LOGGER,
    wandb_entity=None,
    logger_uri=LOGGER_URI,
    audio=audio_config,
    batch_size=BATCH_SIZE,
    batch_group_size=48,
    eval_batch_size=BATCH_SIZE,
    num_loader_workers=0,  # on Windows, num_loader_workers > 0 can break multiprocessing in PyTorch
    eval_split_max_size=256,
    print_step=50,
    plot_step=100,
    log_model_step=1000,
    save_step=1000,  # needs to be an int
    save_n_checkpoints=3,  # rotate: keep the last 3 checkpoints
    save_checkpoints=True,
    print_eval=True,
    optimizer="AdamW",
    optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
    optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
    lr=5e-06,
    # MultiStepLR multiplies the lr by gamma at each milestone. NOTE(review):
    # whether milestones count steps or epochs depends on how the trainer
    # steps the scheduler — with epochs=2 they are likely never reached.
    lr_scheduler="MultiStepLR",
    lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
    # One eval synthesis prompt per SPEAKER_TEXT entry, all using the same reference clip.
    test_sentences=[
        {"text": text, "speaker_wav": SPEAKER_REFERENCE, "language": LANGUAGE}
        for text in SPEAKER_TEXT
    ],
)

In [None]:
'''Set up Trainer'''
# Build the GPT trainer model from the config assembled above.
model = GPTTrainer.init_from_config(config)

# Gradient accumulation: derive the step count (ceiling division) so the
# effective batch BATCH_SIZE * GRAD_ACUMM_STEPS stays >= TARGET_EFFECTIVE_BATCH
# whatever BATCH_SIZE is set to. With BATCH_SIZE = 3 this gives 84 (3 * 84 = 252).
TARGET_EFFECTIVE_BATCH = 252
GRAD_ACUMM_STEPS = -(-TARGET_EFFECTIVE_BATCH // BATCH_SIZE)
START_WITH_EVAL = True  # evaluate before training starts

trainer = Trainer(
    TrainerArgs(
        restore_path=None,  # change to a checkpoint path to resume training
        skip_train_epoch=False,
        start_with_eval=START_WITH_EVAL,
        grad_accum_steps=GRAD_ACUMM_STEPS,
    ),
    config,
    output_path=OUT_PATH,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)

In [None]:
'''TRAINING: run the fine-tuning loop.'''
# Checkpoints are written every config.save_step steps (keeping the last
# config.save_n_checkpoints) under OUT_PATH. Interrupting the kernel stops
# training early. NOTE(review): whether a checkpoint is also saved on
# KeyboardInterrupt depends on the trainer's save-on-interrupt behaviour — confirm.
trainer.fit()