# VITS Training and Inference Notebook

This notebook provides a complete workflow for training a VITS (Variational Inference with adversarial learning for end-to-end Text-to-Speech) model and using it for inference. It is designed to be run on Google Colab with a free GPU.

## 1. Setup and Installation

First, we install the required libraries and clone the original VITS repository. The rest of the notebook imports its data-processing utilities and model code.

In [None]:
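# Install dependencies
!pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install numpy scipy librosa unidecode tensorboard Cython
!pip install phonemizer
# phonemizer's espeak backend (used by english_cleaners2) needs the espeak system package
!apt-get install -y espeak-ng

# Clone VITS repository for utilities
!git clone https://github.com/jaywalnut310/vits.git
%cd vits

# The repository has no setup.py at its root, so `pip install -e .` fails.
# Instead, build the Cython monotonic alignment extension that models.py imports.
# The inner monotonic_align directory is created first because the build writes its output there.
!cd monotonic_align && mkdir -p monotonic_align && python setup.py build_ext --inplace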

## 2. Imports and Configuration

In [None]:
import os
import json
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import numpy as np
import commons
import utils
from data_utils import (
  TextAudioLoader,
  TextAudioCollate,
  DistributedBucketSampler
)
from models import (
  SynthesizerTrn,
  MultiPeriodDiscriminator,
)
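from text.symbols import symbols
from text import text_to_sequence
# Spectrogram helpers and loss functions used by the training loop
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from losses import generator_loss, discriminator_loss, feature_loss, kl_loss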
from torch.utils.tensorboard import SummaryWriter
import librosa
import IPython.display as ipd

# Check for GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

### Create a configuration file
We'll create a JSON config file for our model. It is based on the `base_vits.json` configuration and commonly used VITS parameters.

In [None]:
config_json = {
    "train": {
        "log_interval": 200,
        "eval_interval": 1000,
        "seed": 1234,
        "epochs": 10000,
        "learning_rate": 2e-4,
        "betas": [0.8, 0.99],
        "eps": 1e-9,
        "batch_size": 16,
        # Booleans and nulls must be Python literals here; json.dump writes them as true/false/null
        "fp16_run": True,
        "lr_decay": 0.999875,
        "segment_size": 8192,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "training_files": "filelists/ljs_audio_text_train_filelist.txt.cleaned",
        "validation_files": "filelists/ljs_audio_text_val_filelist.txt.cleaned",
        "text_cleaners": ["english_cleaners2"],
        "max_wav_value": 32768.0,
        "sampling_rate": 22050,
        "filter_length": 1024,
        "hop_length": 256,
        "win_length": 1024,
        "n_mel_channels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": None,
        "add_blank": True,
        "n_speakers": 0,
        "cleaned_text": True
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0.1,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [8, 8, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4],
        "n_layers_q": 3,
        "use_spectral_norm": False
    }
}

with open('config.json', 'w') as f:
    json.dump(config_json, f, indent=2)

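hps = utils.get_hparams_from_file("config.json")
# get_hparams_from_file does not set a model directory, so define one for logs and checkpoints
hps.model_dir = "./logs"
os.makedirs(hps.model_dir, exist_ok=True)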
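
Several values in this config are tied together, and a quick check can catch a mismatch before training starts. The sketch below is only an illustration in plain Python, using numbers copied from the config above; the variable names are chosen here and are not part of the VITS code. It checks that the product of `upsample_rates` equals `hop_length`, so that each latent frame decodes to exactly one hop of audio. It also derives the two sizes passed to `SynthesizerTrn` later in the notebook.

```python
import math

# Values copied from the config above
hop_length = 256
filter_length = 1024
segment_size = 8192
sampling_rate = 22050
upsample_rates = [8, 8, 2, 2]

# The decoder upsamples each latent frame by this factor
total_upsampling = math.prod(upsample_rates)

# Number of linear-spectrogram bins (second argument to SynthesizerTrn)
spec_channels = filter_length // 2 + 1

# Training segment length in frames (third argument to SynthesizerTrn)
segment_frames = segment_size // hop_length

# Training segment length in seconds
segment_seconds = segment_size / sampling_rate

assert total_upsampling == hop_length, "upsample_rates must multiply to hop_length"
print(spec_channels, segment_frames, round(segment_seconds, 3))
```

If you change `hop_length` or `upsample_rates` for a different sampling rate, rerun a check like this so the generated and real audio segments still have matching lengths.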

## 3. Data Preparation

VITS reads its training data from filelists: plain-text files with one utterance per line, in the form `path/to/audio.wav|text transcription`. To train on your own data, prepare filelists in this format.
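
When building filelists for your own data, it can help to validate them before preprocessing. The following is a minimal sketch under the single-speaker assumption used in this notebook (exactly two `|`-separated fields per line). The function name `find_bad_lines` and the sample lines are made up for illustration and are not part of the VITS repository.

```python
def find_bad_lines(lines):
    """Return 1-based numbers of lines that are not 'path|text' with both parts non-empty."""
    bad = []
    for number, line in enumerate(lines, start=1):
        parts = line.rstrip("\n").split("|")
        if len(parts) != 2 or not parts[0].strip() or not parts[1].strip():
            bad.append(number)
    return bad


# Hypothetical sample lines: one valid, one with an empty transcription, one with no separator
example_lines = [
    "wavs/sample_0001.wav|Hello world.",
    "wavs/sample_0002.wav|",
    "wavs/sample_0003.wav Hello again.",
]
print(find_bad_lines(example_lines))
```

To check a real filelist, pass the result of `open(path, encoding="utf-8").readlines()` to the function and inspect any line numbers it reports.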

For this example, we'll download the LJSpeech dataset.

In [None]:
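# Download and extract LJSpeech
!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xjf LJSpeech-1.1.tar.bz2
!mkdir -p filelists

# The Tacotron 2 filelists refer to audio as DUMMY/<name>.wav, so link that name to the extracted wavs
!ln -s LJSpeech-1.1/wavs DUMMY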

# Create filelists
!wget https://raw.githubusercontent.com/NVIDIA/tacotron2/master/filelists/ljs_audio_text_train_filelist.txt
!wget https://raw.githubusercontent.com/NVIDIA/tacotron2/master/filelists/ljs_audio_text_val_filelist.txt
!wget https://raw.githubusercontent.com/NVIDIA/tacotron2/master/filelists/ljs_audio_text_test_filelist.txt

# Clean the text
!python preprocess.py --text_index 1 --filelists ljs_audio_text_train_filelist.txt ljs_audio_text_val_filelist.txt --text_cleaners english_cleaners2

# Move filelists to the correct directory
!mv ljs_audio_text_train_filelist.txt.cleaned filelists/
!mv ljs_audio_text_val_filelist.txt.cleaned filelists/

## 4. Training

Now we will set up the training components and start the training process.

In [None]:
torch.manual_seed(hps.train.seed)
torch.backends.cudnn.benchmark = True

# Create logger
logger = utils.get_logger(hps.model_dir)
writer = SummaryWriter(log_dir=hps.model_dir)

# Create dataset and dataloader
train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
collate_fn = TextAudioCollate()
train_loader = DataLoader(train_dataset, batch_size=hps.train.batch_size, num_workers=2, shuffle=True, pin_memory=True, collate_fn=collate_fn)

# Create models
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model).to(device)

net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)

# Create optimizers
optim_g = torch.optim.AdamW(net_g.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps)
optim_d = torch.optim.AdamW(net_d.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps)

# Schedulers
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay)

# Scaler for mixed precision training
scaler = torch.cuda.amp.GradScaler(enabled=hps.train.fp16_run)

# Training loop
for epoch in range(1, hps.train.epochs + 1):
    net_g.train()
    net_d.train()
    for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(train_loader):
        x, x_lengths = x.to(device), x_lengths.to(device)
        spec, spec_lengths = spec.to(device), spec_lengths.to(device)
        y, y_lengths = y.to(device), y_lengths.to(device)

        with torch.cuda.amp.autocast(enabled=hps.train.fp16_run):
            y_hat, l_length, attn, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths)

            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax)
            y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax
            )

            y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            with torch.cuda.amp.autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
                loss_disc_all = loss_disc

        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with torch.cuda.amp.autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with torch.cuda.amp.autocast(enabled=False):
                # Duration loss; without this term the duration predictor receives no training signal
                loss_dur = torch.sum(l_length.float())
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl

        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

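        # Count training steps across epochs; the original VITS train.py treats both intervals as step counts
        global_step = (epoch - 1) * len(train_loader) + batch_idx
        if global_step % hps.train.log_interval == 0:
            logger.info(f'Epoch: {epoch}, Step: {global_step}, Gen_loss: {loss_gen_all.item():.4f}, Disc_loss: {loss_disc_all.item():.4f}')

        # Save by step rather than by epoch so checkpoints exist within a single Colab session
        if global_step % hps.train.eval_interval == 0:
            utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, global_step, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
            utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, global_step, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))

    # Decay both learning rates once per epoch
    scheduler_g.step()
    scheduler_d.step()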
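
The two `ExponentialLR` schedulers above are stepped once per epoch, and each step multiplies the learning rate by `lr_decay`. After `n` epochs the learning rate is therefore `learning_rate * lr_decay ** n`. The short calculation below uses the values from the config to show how slowly this schedule decays; the helper name `lr_after` is made up for this sketch.

```python
import math

# Values copied from the "train" section of the config
initial_lr = 2e-4
lr_decay = 0.999875


def lr_after(epochs):
    """Learning rate after the given number of per-epoch scheduler steps."""
    return initial_lr * lr_decay ** epochs


# Number of epochs needed for the learning rate to fall to half its initial value
halving_epochs = math.log(0.5) / math.log(lr_decay)

print(f"after 1000 epochs: {lr_after(1000):.3e}")
print(f"epochs to halve:   {halving_epochs:.0f}")
```

With a decay this close to 1, the learning rate changes very little over the length of a typical Colab session, so short runs train at nearly the initial rate throughout.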

## 5. Inference

Once the model is trained, you can use it to synthesize speech from text.
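
Because the config sets `add_blank` to true, the `get_text` helper in the next cell calls `commons.intersperse(text_norm, 0)` to surround every symbol ID with a blank token (ID 0). The stand-alone sketch below mimics that behavior for illustration; `intersperse_sketch` is a name made up here, and the input IDs are arbitrary rather than real phoneme IDs.

```python
def intersperse_sketch(sequence, item):
    """Return a list with `item` before, between, and after the elements of `sequence`."""
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result


# Arbitrary example IDs, not actual VITS symbol IDs
symbol_ids = [5, 12, 7]
with_blanks = intersperse_sketch(symbol_ids, 0)
print(with_blanks)  # [0, 5, 0, 12, 0, 7, 0]
```

A sequence of `n` symbols becomes `2n + 1` tokens, which is why the inference cell computes the input length from the interspersed tensor rather than from the raw text.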

In [None]:
def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

# Load the trained generator model
# Point model_path at a checkpoint saved during training (a G_<number>.pth file in the model directory)
model_path = "./logs/G_10000.pth" # Change this to your model path
hps_inf = utils.get_hparams_from_file("./config.json")
net_g_inf = SynthesizerTrn(
    len(symbols),
    hps_inf.data.filter_length // 2 + 1,
    hps_inf.train.segment_size // hps_inf.data.hop_length,
    **hps_inf.model).to(device)
_ = net_g_inf.eval()
_ = utils.load_checkpoint(model_path, net_g_inf, None)

# Text to synthesize
text = "Hello world, this is a test of the VITS model."
stn_tst = get_text(text, hps_inf)

with torch.no_grad():
    x_tst = stn_tst.to(device).unsqueeze(0)
    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
    audio = net_g_inf.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()

# Display and save the audio
print("Generated Audio:")
ipd.display(ipd.Audio(audio, rate=hps_inf.data.sampling_rate, normalize=False))
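# librosa.output.write_wav was removed in librosa 0.8, so write the file with SciPy instead
from scipy.io import wavfile
wavfile.write("./generated_audio.wav", hps_inf.data.sampling_rate, audio)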
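
The waveform returned by `infer` is a floating-point array, while some tools expect 16-bit PCM like the LJSpeech source files. The sketch below shows one common way to convert, assuming the samples are nominally in the range [-1, 1]. The helper name `float_to_int16` and the demo values are made up for illustration.

```python
import numpy as np


def float_to_int16(audio):
    """Clip float samples to [-1, 1] and scale them to 16-bit integers."""
    clipped = np.clip(audio, -1.0, 1.0)
    return (clipped * 32767.0).astype(np.int16)


# Demo values, including one out-of-range sample that gets clipped
demo = np.array([0.0, 0.5, -1.0, 1.5], dtype=np.float32)
pcm = float_to_int16(demo)
print(pcm.dtype, pcm.tolist())
```

Clipping before the cast avoids integer wrap-around when a sample falls slightly outside the expected range.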