In [6]:
import os
import random

ROOT = os.getcwd()

# Fixed seed so that re-running this cell reproduces the same train/val split
# instead of moving samples between the two files on every run.
SPLIT_SEED = 42

# Read every metadata line; the context manager closes the handle promptly
# (the bare open(...).read() used previously left that to the garbage collector).
with open(os.path.join(ROOT, "dataset", "metadata.txt")) as f_metadata:
    samples = f_metadata.read().splitlines()

# Shuffle with a private RNG so the global `random` state is left untouched.
randomized = random.Random(SPLIT_SEED).sample(samples, len(samples))

# 85% of the shuffled lines go to training, the remainder to validation.
train_count = int(len(randomized) * 0.85)

train_dataset = randomized[:train_count]
val_dataset = randomized[train_count:]

train_metadata_outpath = os.path.join(ROOT, "dataset", "train.txt")
val_metadata_outpath = os.path.join(ROOT, "dataset", "val.txt")

# Lines are newline-joined with no trailing newline, same as before.
with open(train_metadata_outpath, "w") as f_train:
    f_train.write("\n".join(train_dataset))

with open(val_metadata_outpath, "w") as f_val:
    f_val.write("\n".join(val_dataset))

In [None]:
# Download dvae.pth and autoregressive.pth from Hugging Face into experiments/.
# `-O` writes to the given path, overwriting any file already there.
# NOTE(review): assumes an experiments/ directory already exists in the current
# working directory -- confirm, wget will not create it.
!wget https://huggingface.co/Gatozu35/tortoise-tts/resolve/main/dvae.pth -O experiments/dvae.pth
!wget https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth -O experiments/autoregressive.pth

In [7]:
import os
from pathlib import Path

# Starting batch sizes; they are adjusted further down to fit the dataset size.
DEFAULT_TRAIN_BS, DEFAULT_VAL_BS = 64, 32

# Both split files are resolved against the directory the notebook runs from.
ROOT = os.getcwd()
Dataset_Training_Path = os.path.join(ROOT, "dataset/train.txt")
ValidationDataset_Training_Path = os.path.join(ROOT, "dataset/val.txt")

# Pointing both settings at the same file is allowed, but worth shouting about.
if ValidationDataset_Training_Path == Dataset_Training_Path:
    print(
        "WARNING: training dataset path == validation dataset path!!!",
        "\tThis is technically okay but will make all of the validation metrics useless. ",
        "it will also SUBSTANTIALLY slow down the rate of training, because validation datasets are supposed to be much smaller than training ones.",
        sep="\n",
    )

def txt_file_lines(p: str) -> int:
    """Count the lines of the text file at ``p``.

    Leading and trailing whitespace (including blank lines at either end) is
    ignored; blank lines in the middle of the file still count.

    Args:
        p: Path of the text file to read.

    Returns:
        The number of lines, or 0 when the file is empty or whitespace-only.
    """
    text = Path(p).read_text().strip()
    if not text:
        # "".split("\n") is [""], which would report one sample for an empty file.
        return 0
    return len(text.split("\n"))


# Line counts of the two split files (training first, then validation).
training_samples, val_samples = map(
    txt_file_lines, (Dataset_Training_Path, ValidationDataset_Training_Path)
)

# Warn when either split is small.
if 128 > training_samples:
    print(
        "WARNING: very small dataset! the smallest dataset tested thus far had ~200 samples."
    )
if 20 > val_samples:
    print(
        "WARNING: very small validation dataset! val batch size will be scaled down to account"
    )


def div_spillover(n: int, bs: int) -> int:  # returns new batch size
    """Adjust batch size ``bs`` so that ``n`` samples leave few leftovers per epoch.

    Args:
        n: Number of samples in the dataset.
        bs: Preferred batch size.

    Returns:
        ``bs`` unchanged when there are many steps per epoch or ``n`` divides
        evenly; ``n`` itself when the dataset is smaller than one batch;
        otherwise a nearby batch size whose remainder is smaller than
        ``epoch_steps + 2`` (checked by the assertion).
    """
    epoch_steps, remain = divmod(n, bs)
    if epoch_steps * 2 > bs:
        return bs  # don't bother optimising this stuff if epoch_steps are high
    if not remain:
        return bs  # unlikely but still
    if epoch_steps == 0:
        # Fewer samples than one batch: a single batch holding everything.
        # Without this guard `n // epoch_steps` below raised ZeroDivisionError
        # whenever 0 < n < bs / 2; for bs / 2 <= n < bs the result was already n.
        return n

    if remain * 2 < bs:  # "easier" to get rid of remainder -- should increase bs
        target_bs = n // epoch_steps
    else:  # easier to increase epoch_steps by 1 -- decrease bs
        target_bs = n // (epoch_steps + 1)
    assert n % target_bs < epoch_steps + 2  # should be very few extra
    return target_bs


# Training batch size: fit to the dataset, or shrink to the whole dataset
# (with a warning) when there is not even one full batch.
if training_samples >= DEFAULT_TRAIN_BS:
    train_bs = div_spillover(training_samples, DEFAULT_TRAIN_BS)
else:
    print(
        "WARNING: dataset is smaller than a single batch. This will almost certainly perform poorly. Trying anyway"
    )
    train_bs = training_samples

# Validation batch size follows the same rule, without a warning.
val_bs = (
    div_spillover(val_samples, DEFAULT_VAL_BS)
    if val_samples >= DEFAULT_VAL_BS
    else val_samples
)

# Learning-rate decay milestones are given in epochs and converted to steps.
steps_per_epoch = training_samples // train_bs
lr_decay_epochs = [20, 40, 56, 72]
lr_decay_steps = [epoch * steps_per_epoch for epoch in lr_decay_epochs]

# Print once per epoch, clamped to the range [20, 100] steps.
print_freq = max(20, min(100, steps_per_epoch))
# Validate and checkpoint every third print.
save_checkpoint_freq = print_freq * 3
val_freq = save_checkpoint_freq

print(
    "===CALCULATED SETTINGS===",
    f"{train_bs=} {val_bs=}",
    f"{val_freq=} {lr_decay_steps=}",
    f"{print_freq=} {save_checkpoint_freq=}",
    sep="\n",
)

===CALCULATED SETTINGS===
train_bs=64 val_bs=32
val_freq=300 lr_decay_steps=[224720, 449440, 629216, 808992]
print_freq=100 save_checkpoint_freq=300


In [8]:
# --- Names substituted into the CHANGEME_* placeholders of the config template ---
Experiment_Name = "Test1"  # @param {type:"string"}
Dataset_Training_Name = "TestDataset"  # @param {type:"string"}
ValidationDataset_Name = (
    "TestValidation"  # this seems to be useless??? @param {type:"string"}
)
# --- Checkpoint / state-saving options written into the config ---
SaveTrainingStates = False  # @param {type:"boolean"}
Keep_Last_N_Checkpoints = 0  # @param {type:"slider", min:0, max:10, step:1}

# --- Training options written into the config ---
Fp16 = False  # @param {type:"boolean"}
Use8bit = True  # @param {type:"boolean"}
# Kept as a string; the config is only touched when this differs from "1e-5".
TrainingRate = "1e-5"  # @param {type:"string"}
TortoiseCompat = True  # @param {type:"boolean"}

# --- Optional overrides for the calculated settings ---
# Leave a field as "" to keep the calculated value; a non-empty string is
# converted to the calculated value's type. LRDecaySteps is evaluated as a
# Python expression and must produce exactly 4 entries.
TrainBS = ""  # @param {type:"string"}
ValBS = ""  # @param {type:"string"}
ValFreq = ""  # @param {type:"string"}
LRDecaySteps = ""  # @param {type:"string"}
PrintFreq = ""  # @param {type:"string"}
SaveCheckpointFreq = ""  # @param {type:"string"}


def take(orig, override):
    """Return ``orig``, or ``override`` converted to ``orig``'s type.

    An ``override`` equal to the empty string means "no override"; any other
    value is passed to ``type(orig)`` and the converted result is returned.
    """
    return orig if override == "" else type(orig)(override)


# Apply any non-empty form overrides on top of the calculated settings.
train_bs, val_bs = take(train_bs, TrainBS), take(val_bs, ValBS)
val_freq = take(val_freq, ValFreq)
print_freq = take(print_freq, PrintFreq)
save_checkpoint_freq = take(save_checkpoint_freq, SaveCheckpointFreq)

if LRDecaySteps:
    # NOTE(review): eval() executes whatever is typed into the form field;
    # only enter a trusted list literal here.
    lr_decay_steps = eval(LRDecaySteps)
# Exactly four decay milestones are expected.
assert len(lr_decay_steps) == 4

# Comma-separated form of the milestones, e.g. "500, 1000, 1400, 1800".
gen_lr_steps = ", ".join(map(str, lr_decay_steps))

In [9]:
import urllib.request

# Download the example GPT config template to experiments/EXAMPLE_gpt.yml.
# urlretrieve returns a (local filename, headers) tuple, shown as the cell output.
# NOTE(review): the URL points at the `master` branch, so the template text the
# later substitutions rely on may change upstream.
urllib.request.urlretrieve(
    "https://raw.githubusercontent.com/152334H/DL-Art-School/master/experiments/EXAMPLE_gpt.yml",
    "experiments/EXAMPLE_gpt.yml",
)

('experiments/EXAMPLE_gpt.yml', <http.client.HTTPMessage at 0x762e2af434a0>)

In [10]:

import os
import re
from pathlib import Path

# Config template downloaded by the previous cell; it is rewritten in place.
config_path = Path("./experiments/EXAMPLE_gpt.yml")

# Literal template text -> replacement text.
# All substitutions are applied in ONE pass over the original template, so the
# output of one can never be re-matched by another. The previous chain of
# `sed -i` calls first wrote `batch_size: <train_bs>` and then replaced every
# `batch_size: 64` with val_bs, which silently overwrote the training batch
# size whenever train_bs was 64.
replacements = {
    "batch_size: 128": f"batch_size: {train_bs}",
    "batch_size: 64": f"batch_size: {val_bs}",
    "val_freq: 500": f"val_freq: {val_freq}",
    "500, 1000, 1400, 1800": gen_lr_steps,
    "print_freq: 100": f"print_freq: {print_freq}",
    "save_checkpoint_freq: 500": f"save_checkpoint_freq: {save_checkpoint_freq}",
    "CHANGEME_validation_dataset_name": ValidationDataset_Name,
    "CHANGEME_path_to_validation_dataset": ValidationDataset_Training_Path,
    "use_8bit: true": f"use_8bit: {Use8bit}",
    # The config key is the negation of the form field: saving states means
    # NOT disabling state saving (the flag was previously passed un-negated).
    "disable_state_saving: true": f"disable_state_saving: {not SaveTrainingStates}",
    "tortoise_compat: True": f"tortoise_compat: {TortoiseCompat}",
    "number_of_checkpoints_to_save: 0": (
        f"number_of_checkpoints_to_save: {Keep_Last_N_Checkpoints}"
    ),
    "CHANGEME_training_dataset_name": Dataset_Training_Name,
    "CHANGEME_your_experiment_name": Experiment_Name,
    "CHANGEME_path_to_training_dataset": Dataset_Training_Path,
}
if Fp16:
    replacements["fp16: false"] = "fp16: true"
if TrainingRate != "1e-5":
    # Only touch the learning-rate line when a non-default rate was requested.
    replacements["!!float 1e-5 # CHANGEME:"] = f"!!float {TrainingRate} #"

# Keys are escaped so they match literally, and the replacement comes from a
# function so characters such as `\`, `&`, `/` or `+` inside names and paths
# are inserted verbatim (they were special to the sed expressions).
template_pattern = re.compile("|".join(re.escape(old) for old in replacements))
config_text = config_path.read_text(encoding="utf-8")
config_path.write_text(
    template_pattern.sub(lambda match: replacements[match.group(0)], config_text),
    encoding="utf-8",
)



In [None]:
# Change into the DL-Art-School `codes` directory and start training with the
# config edited above.
# NOTE(review): this absolute /content/... path is hardcoded, whereas earlier
# cells build their paths from os.getcwd() -- confirm it matches where the
# repository actually lives.
%cd /content/DL-Art-School/codes

!python3 train.py -opt ../experiments/EXAMPLE_gpt.yml

In [None]:
# Next step: load the tortoise-tts-fast fork and use its new --ar-checkpoint
# option, pointing it at the checkpoint produced by training:
# /path/to/DL-Art-School/experiments/<INSERT EXPERIMENT NAME HERE>/models/<MOST RECENT STEPS>_gpt.pth