In [2]:
import os

# Trainer: Where the ✨️ happens.
# TrainerArgs: defines the set of arguments of the Trainer.
from trainer import Trainer, TrainerArgs

# FastSpeechConfig: all model-related values for training, validation and testing.
from TTS.tts.configs.fast_speech_config import FastSpeechConfig

# BaseDatasetConfig: defines name, formatter and path of the dataset.
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

In [3]:
# Absolute path of the directory the notebook kernel runs in. Despite its name, this is
# used as the base directory for the dataset path in the next cell.
output_path = os.path.abspath(os.curdir)

In [4]:
# Dataset definition: LJSpeech-formatted metadata.csv inside LJSpeech-1.1/ under the base dir.
dataset_config = BaseDatasetConfig(
    path=os.path.join(output_path, "LJSpeech-1.1/"),
    meta_file_train="metadata.csv",
    formatter="ljspeech",
)

In [5]:
# INITIALIZE THE TRAINING CONFIGURATION
# Configure the model. Every config class inherits the BaseTTSConfig.
config = FastSpeechConfig(
    # NOTE: a later cell also sets config.batch_size = 2; keep the two values in sync.
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=-1,
    epochs=100,
    text_cleaner="phoneme_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    # Anchor relative paths to the notebook base dir so they do not depend on whatever
    # the process cwd happens to be when the phonemizer / trainer resolve them.
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    print_step=10,
    print_eval=True,
    # fp16 autocast + GradScaler (the trainer log reports "Precision: fp16").
    mixed_precision=True,
    output_path=os.path.join(output_path, "output"),
    datasets=[dataset_config],
    grad_clip=1,
    log_model_step=20,
    plot_step=10,
    use_noise_augment=True,
    # Base learning rate, ramped up over the first 300 steps
    # (the step-0 log shows current_lr = 1e-5 / 300).
    lr=0.00001,
    lr_scheduler_params={"warmup_steps": 300},
)

In [6]:
# INITIALIZE THE AUDIO PROCESSOR from config.audio; its settings are echoed in the output below.
ap = AudioProcessor.init_from_config(config)

# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
# If characters are not defined in the config, default characters are passed to the config.
# init_from_config returns the (possibly updated) config, so `config` is rebound here and
# every later cell must use this returned object.
tokenizer, config = TTSTokenizer.init_from_config(config)

 > Setting up Audio Processor...
 | > sample_rate:22050
 | > resample:False
 | > num_mels:80
 | > log_func:np.log10
 | > min_level_db:-100
 | > frame_shift_ms:None
 | > frame_length_ms:None
 | > ref_level_db:20
 | > fft_size:1024
 | > power:1.5
 | > preemphasis:0.0
 | > griffin_lim_iters:60
 | > signal_norm:True
 | > symmetric_norm:True
 | > mel_fmin:0
 | > mel_fmax:None
 | > pitch_fmin:1.0
 | > pitch_fmax:640.0
 | > spec_gain:20.0
 | > stft_pad_mode:reflect
 | > max_norm:4.0
 | > clip_norm:True
 | > do_trim_silence:True
 | > trim_db:45
 | > do_sound_norm:False
 | > do_amp_to_db_linear:True
 | > do_amp_to_db_mel:True
 | > do_rms_norm:False
 | > db_level:None
 | > stats_path:None
 | > base:10
 | > hop_length:256
 | > win_length:1024


In [7]:
# Enable resampling and sound normalization for loaded audio.
# Previously only the `ap` object was mutated, so config.audio (which is what the Trainer
# receives) still said resample=False / do_sound_norm=False. Set the values on config.audio
# first and mirror them onto `ap`, so any config saved with the run (presumably the run
# folder's config file; confirm) matches the preprocessing actually used in training.
config.audio.resample = True
config.audio.do_sound_norm = True
ap.resample = config.audio.resample
ap.do_sound_norm = config.audio.do_sound_norm

In [8]:
# Effective training batch size: this overrides whatever batch_size the FastSpeechConfig
# cell set. The training log confirms it (613 training samples / 2 -> 307 steps per epoch).
# NOTE(review): presumably lowered to fit GPU memory; confirm and keep the config cell in sync.
config.batch_size = 2

In [9]:
# config.eval_split_size = 0.012195121951219513

In [10]:
# Load the dataset metadata and split it into train / eval samples.
# The split sizes come from the config (eval_split_size / eval_split_max_size).
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split_size=config.eval_split_size,
    eval_split_max_size=config.eval_split_max_size,
    eval_split=True,
)

 | > Found 619 files in /mnt/45b9faff-45f3-43f2-903f-9b92a9a6338c/major-project/notebooks/tensorflow-tts/LJSpeech-1.1


In [11]:
# Build the ForwardTTS model from the config, audio processor and tokenizer.
# No speaker manager is passed (speaker_manager=None).
model = ForwardTTS(
    config=config,
    ap=ap,
    tokenizer=tokenizer,
    speaker_manager=None,
)

In [12]:
# Create the trainer. The run directory is taken from config.output_path instead of
# repeating an "output" literal, so the two cannot drift apart.
trainer = Trainer(
    TrainerArgs(),
    config,
    config.output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)

 > Training Environment:
 | > Backend: Torch
 | > Mixed precision: True
 | > Precision: fp16
 | > Current device: 0
 | > Num. of GPUs: 1
 | > Num. of CPUs: 8
 | > Num. of Torch Threads: 4
 | > Torch seed: 54321
 | > Torch CUDNN: True
 | > Torch CUDNN deterministic: False
 | > Torch CUDNN benchmark: False
 | > Torch TF32 MatMul: False
 > Start Tensorboard: tensorboard --logdir=output/run-January-01-2025_08+26PM-252392a
  self.scaler = torch.cuda.amp.GradScaler()

 > Model has 37022561 parameters


In [None]:
# AND... 3,2,1... 🚀
# Start the training loop (config.epochs = 100). Progress is logged to the run folder
# printed at the start of the output (e.g. output/run-...).
# NOTE(review): long-running cell; expect it to dominate a Restart & Run All.
trainer.fit()


[4m[1m > EPOCH: 0/100[0m
 --> output/run-January-01-2025_08+26PM-252392a

[1m > TRAINING (2025-01-01 20:26:54) [0m




> DataLoader initialization
| > Tokenizer:
	| > add_blank: False
	| > use_eos_bos: False
	| > use_phonemes: True
	| > phonemizer:
		| > phoneme language: en-us
		| > phoneme backend: gruut
| > Number of instances : 613
 | > Preprocessing samples
 | > Max text length: 169
 | > Min text length: 21
 | > Avg text length: 79.71778140293638
 | 
 | > Max audio length: 1244201.0
 | > Min audio length: 95271.0
 | > Avg audio length: 387105.7243066884
 | > Num. instances discarded samples: 0
 | > Batch group size: 0.


  with autocast(enabled=False):

[1m   --> TIME: 2025-01-01 20:27:07 -- STEP: 0/307 -- GLOBAL_STEP: 0[0m
     | > loss_spec: 2.6727843284606934  (2.6727843284606934)
     | > loss_dur: 1.190316915512085  (1.190316915512085)
     | > loss_aligner: 17.09288787841797  (17.09288787841797)
     | > loss_binary_alignment: 3.511650323867798  (3.511650323867798)
     | > loss: 24.467639923095703  (24.467639923095703)
     | > duration_error: 5.6094279289245605  (5.6094279289245605)
     | > amp_scaler: 32768.0  (32768.0)
     | > grad_norm: 0  (0)
     | > current_lr: 3.333333333333334e-08 
     | > step_time: 8.5044  (8.504414796829224)
     | > loader_time: 4.2978  (4.2977516651153564)

  with autocast(enabled=False):

[1m   --> TIME: 2025-01-01 20:27:16 -- STEP: 10/307 -- GLOBAL_STEP: 10[0m
     | > loss_spec: 3.7792012691497803  (3.008633279800415)
     | > loss_dur: 1.5952441692352295  (1.5044141173362733)
     | > loss_aligner: 26.886241912841797  (21.139652442932128)
     | > loss_b

In [None]:
# Debug aid (IPython magic): print the traceback of the most recent exception.
# NOTE(review): leftover debugging cell; remove before sharing the notebook.
%tb

In [None]:
import librosa
import os


def check_audio_files(directory, extensions=(".wav",)):
    """Find audio files under ``directory`` that fail to decode or contain no samples.

    Walks ``directory`` recursively and loads every file whose name ends with one of
    ``extensions`` (case-insensitive) using ``librosa.load`` at its native sample rate.

    Args:
        directory: Root directory to scan.
        extensions: File-name suffix or tuple of suffixes to check. Defaults to ``(".wav",)``.

    Returns:
        List of paths of files that raised while loading or decoded to zero samples,
        in deterministic (sorted) traversal order. Each failure is also printed.

    Raises:
        FileNotFoundError: If ``directory`` is not an existing directory. Without this
            check ``os.walk`` silently yields nothing and the scan reports 0 bad files.
    """
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Not a directory: {directory}")
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    corrupted_files = []
    for root, dirs, files in os.walk(directory):
        dirs.sort()  # sort in place so os.walk descends in a deterministic order
        for file in sorted(files):
            if not file.lower().endswith(suffixes):
                continue
            file_path = os.path.join(root, file)
            try:
                # sr=None keeps the file's native sample rate (no resampling work).
                audio, _ = librosa.load(file_path, sr=None)
                if len(audio) == 0:
                    raise ValueError("Empty audio file")
            except Exception as e:  # best-effort scan: record any failure and keep going
                print(f"Corrupted: {file_path}, Error: {e}")
                corrupted_files.append(file_path)
    return corrupted_files


# Scan the same wavs/ directory that training loads (derived from dataset_config.path).
# The previous hard-coded absolute path pointed at a different copy of the dataset than
# the one load_tts_samples reported using.
dataset_directory = os.path.join(dataset_config.path, "wavs")
corrupted_audio_files = check_audio_files(dataset_directory)
print(f"Found {len(corrupted_audio_files)} corrupted files.")

In [30]:
import torchaudio


def check_audio_duration(directory, min_duration=0.5, extensions=(".wav", ".mp3")):
    """Find audio files under ``directory`` that are too short or cannot be loaded.

    Args:
        directory: Root directory to scan recursively.
        min_duration: Minimum acceptable duration in seconds; shorter files are flagged.
        extensions: File-name suffix or tuple of suffixes to check (case-insensitive).

    Returns:
        List of flagged file paths (too short, or raised while loading), in
        deterministic (sorted) traversal order. Each flagged file is also printed.

    Raises:
        FileNotFoundError: If ``directory`` is not an existing directory. Without this
            check ``os.walk`` silently yields nothing and no problems are reported.
    """
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Not a directory: {directory}")
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    corrupted_files = []
    for root, dirs, files in os.walk(directory):
        dirs.sort()  # sort in place so os.walk descends in a deterministic order
        for file in sorted(files):
            if not file.lower().endswith(suffixes):
                continue
            file_path = os.path.join(root, file)
            try:
                waveform, sr = torchaudio.load(file_path)
                # Dim 1 is the frame count (torchaudio's default channels-first layout).
                duration = waveform.size(1) / sr
                if duration < min_duration:
                    print(f"Short audio file: {file_path}, Duration: {duration}s")
                    corrupted_files.append(file_path)
            except Exception as e:  # best-effort scan: record any failure and keep going
                print(f"Corrupted: {file_path}, Error: {e}")
                corrupted_files.append(file_path)
    return corrupted_files


# Check durations in the same wavs/ directory that training loads (from dataset_config.path),
# rather than a hard-coded absolute path that pointed at a different dataset copy.
dataset_directory = os.path.join(dataset_config.path, "wavs")
short_or_corrupted_files = check_audio_duration(dataset_directory)
print(f"Found {len(short_or_corrupted_files)} short or corrupted files.")