### Install packages and download models

In [None]:
%%shell
# Clone the StyleTTS2 code base and install its Python dependencies.
git clone https://github.com/yl4579/StyleTTS2.git
cd StyleTTS2
pip install Soundfile torchaudio munch torch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git
# eSpeak-NG: the system backend that phonemizer's 'espeak' backend calls for
# the Bengali phonemization done later in this notebook.
sudo apt-get install espeak-ng
# Pretrained LibriTTS checkpoint. Because of the `cd` above, it is cloned inside
# StyleTTS2/, and its Models/ folder is moved into the repo root (the fine-tune
# config below points at Models/LibriTTS/epochs_2nd_00020.pth).
git clone https://huggingface.co/yl4579/StyleTTS2-LibriTTS
mv StyleTTS2-LibriTTS/Models .

### Download dataset

In [None]:
# Work from the repository root for the rest of the notebook: all relative
# paths below (Data/, Models/, Configs/) are resolved against it.
%cd StyleTTS2


In [None]:
# ==========================================
# = Code Cell 3: Load Voice Dataset =
# ==========================================

# Download and extract the Kaggle dataset into Data/, move the *contents* of the
# 'female/mono' folder directly into Data/, and delete everything else.
# NOTE(review): the kaggle CLI must be installed and authenticated in this
# runtime for the download to work — confirm before running.
!rm -rf Data

!kaggle datasets download -d mobassir/comprehensive-bangla-tts --unzip -p Data

# Move the contents of 'female/mono' directly into Data/ (not into a subfolder).
# The mkdir is presumably a no-op, since the download above targets Data/.
!mkdir -p Data
!mv Data/iitm_bangla_tts/comprehensive_bangla_tts/female/mono/* Data/

# Rename wav directory to wavs (the directory name used by the list files and
# the config's root_path later in this notebook).
!mv Data/wav Data/wavs

# Clean up the remaining dataset folders and files that are not needed.
# Plain `rm` (without -f) prints an error if a file is already absent; that is
# harmless here.
!rm -rf Data/comprehensive_bangla_tts_weights
!rm -rf Data/comprehensive_bangla_tts
!rm -rf Data/iitm_bangla_tts
!rm -rf Data/vits_m_phoneme
!rm Data/license.pdf
!rm Data/txt.done.data


# Expected final layout (wavs/ plus the metadata file read by the next cells):
# Data/
#   ├── wavs/
#   └── metadata_female.txt

In [None]:
# Clean up old files
# Remove list files from a previous run so the next cell regenerates them.
# (The next cell only writes OOD_texts.txt when it does not already exist.)
# `rm` without -f prints an error for files that do not exist yet; harmless.

!rm Data/OOD_texts.txt
!rm Data/main_list.txt
!rm Data/train_list.txt
!rm Data/val_list.txt

In [None]:
import os
import random
from phonemizer import phonemize

# Input metadata ("<wav_id>|<transcription>" per line) and the list files
# generated from it by the script below.
metadata_path = "Data/metadata_female.txt"
main_list_path = "Data/main_list.txt"
train_list_path = "Data/train_list.txt"
val_list_path = "Data/val_list.txt"
ood_file_path = "Data/OOD_texts.txt"

# How many utterances are randomly drawn (and phonemized) for the OOD file,
# which is then split into the train/val lists.
sample_count = 5_000

def phonemize_bengali(text):
    """Phonemize Bengali ``text`` using phonemizer's eSpeak backend.

    Punctuation and stress marks are preserved in the output. If the
    phonemizer raises, the error is printed and ``None`` is returned so the
    caller can skip that entry.
    """
    espeak_options = dict(
        language='bn',
        backend='espeak',
        strip=True,
        preserve_punctuation=True,
        with_stress=True,
    )
    try:
        phonemes = phonemize(text, **espeak_options)
    except Exception as exc:
        print(f"Error phonemizing text '{text}': {str(exc)}")
        return None
    return phonemes

import soundfile as sf  # hoisted: it was previously re-imported for every wav file

if os.path.exists(metadata_path):
    # 1) Build the main list. Each metadata line is expected to look like
    #    "<wav_id>|<transcription>[|...]": lines with fewer than two fields are
    #    skipped and any extra fields are ignored.
    with open(metadata_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    formatted_lines = []
    for line in lines:
        parts = line.strip().split('|')
        if len(parts) >= 2:
            wav_id, transcription = parts[0], parts[1]
            relative_path = f"Data/wavs/{wav_id}.wav"
            formatted_lines.append(f"{relative_path}|{transcription}")

    with open(main_list_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(formatted_lines))

    print(f"Created {main_list_path} with {len(formatted_lines)} entries")

    # 2) Verify that every referenced wav exists and can be decoded.
    broken_files = []
    missing_files = []

    for line in formatted_lines:
        wav_path = line.split("|")[0].strip()
        if not os.path.exists(wav_path):
            missing_files.append(wav_path)
            continue
        try:
            data, sr = sf.read(wav_path)
        except Exception as e:
            broken_files.append((wav_path, str(e)))
            continue
        if len(data) == 0:
            broken_files.append((wav_path, "Empty/0 Samples"))

    print("Missing files:", len(missing_files))
    for m in missing_files[:10]:
        print("  -", m)

    print("Defective/unreadable files:", len(broken_files))
    for b in broken_files[:10]:
        print("  -", b[0], "| Error:", b[1])

    # Drop entries whose audio is missing or unreadable so they cannot end up
    # in the OOD / train / val lists written below.
    bad_paths = set(missing_files) | {path for path, _ in broken_files}
    usable_lines = [
        line for line in formatted_lines
        if line.split("|")[0].strip() not in bad_paths
    ]
    if len(usable_lines) < len(formatted_lines):
        print(f"Excluded {len(formatted_lines) - len(usable_lines)} entries with missing/defective audio")

    # 3) Phonemize a random sample into the OOD file. This is skipped when the
    #    file already exists (delete it to resample).
    if not os.path.exists(ood_file_path):
        # Use every usable entry when there are fewer than sample_count, so the
        # train/val lists are still produced for small datasets.
        n_samples = min(sample_count, len(usable_lines))
        if n_samples < sample_count:
            print(f"Warning: Not enough entries to sample {sample_count} items; using {n_samples}")
        if n_samples > 0:
            sampled_lines = random.sample(usable_lines, n_samples)
            phonemized_lines = []
            for line in sampled_lines:
                parts = line.split('|')
                if len(parts) != 2:
                    continue
                wav_path, text = parts
                phonemes = phonemize_bengali(text)
                if phonemes:  # None (error) or empty output is skipped
                    phonemized_lines.append(f"{wav_path}|{phonemes}|0")

            with open(ood_file_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(phonemized_lines))

            print(f"Created {ood_file_path} with {len(phonemized_lines)} phonemized entries")

    # 4) Split the OOD entries 80/20 (in file order, which is the random sample
    #    order) into train/val lists. Paths are reduced to bare file names,
    #    which the config below resolves against root_path "Data/wavs". The
    #    trailing "|0" field is kept as-is (presumably a speaker id).
    if os.path.exists(ood_file_path):
        with open(ood_file_path, 'r', encoding='utf-8') as f:
            ood_lines = [line.strip() for line in f if line.strip()]
        train_samples = int(len(ood_lines) * 0.8)

        train_lines = []
        val_lines = []
        for index, line in enumerate(ood_lines):
            parts = line.split('|')
            if len(parts) != 3:
                continue
            entry = f"{os.path.basename(parts[0])}|{parts[1]}|0"
            if index < train_samples:
                train_lines.append(entry)
            else:
                val_lines.append(entry)

        with open(train_list_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(train_lines))

        with open(val_list_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(val_lines))

        print(f"Using {ood_file_path} for training and validation")
        print(f"Created {train_list_path} with {len(train_lines)} entries")
        print(f"Created {val_list_path} with {len(val_lines)} entries")
    else:
        print(f"Error: OOD file {ood_file_path} not found")
else:
    print(f"Error: Metadata file {metadata_path} not found")

### Change the finetuning config

Depending on your GPU, you may want to change the batch size, maximum audio length, number of epochs, and so on.

In [7]:
config_path = "Configs/config_ft.yml"
import yaml

# Load the stock fine-tuning config so the next cell can edit it. A `with`
# block closes the file handle (the bare open() call leaked it).
with open(config_path, encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

In [13]:
# Data produced by the dataset-preparation cells. The list files contain wav
# file names relative to root_path. The train/val/OOD list keys are set
# explicitly; the previous code wrote an unused 'train' key instead.
config['data_params']['root_path'] = "Data/wavs"
config['data_params']['train_data'] = "Data/train_list.txt"
config['data_params']['val_data'] = "Data/val_list.txt"
config['data_params']['OOD_data'] = "Data/OOD_texts.txt"

config['batch_size'] = 8
config['epochs'] = 500
config['max_len'] = 128
config['pretrained_model'] = "Models/LibriTTS/epochs_2nd_00020.pth"
# NOTE(review): joint_epoch (512) is larger than epochs (500). Presumably the
# training phase gated by joint_epoch therefore never starts — confirm this is
# intended.
config['loss_params']['joint_epoch'] = 512

# Overwrite the config in place; the training command below reads this file.
with open(config_path, 'w', encoding='utf-8') as outfile:
    yaml.dump(config, outfile, default_flow_style=True)

### Start finetuning


In [None]:
# Fine-tune with the config written above. Checkpoints are presumably written
# to the config's log_dir (the test cells below read Models/LJSpeech) — confirm.
!python train_finetune_accelerate.py --config_path ./Configs/config_ft.yml

### Test the model quality


In [None]:
import nltk

# word_tokenize(), used by the inference code below, needs the Punkt tables.
nltk.download("punkt_tab")

In [None]:
import random

import numpy as np
import torch

# Fix every RNG the inference pipeline may draw from, for reproducible output.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# cuDNN: disable kernel autotuning and require deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# load packages
import time
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
from nltk.tokenize import word_tokenize

from models import *
from utils import *
from text_utils import TextCleaner
# Converts phoneme strings into token id lists for the model (see inference()).
# NOTE: the misspelled name `textclenaer` is referenced by inference(), so it
# is kept as-is.
textclenaer = TextCleaner()

# Render matplotlib figures inline in the notebook.
%matplotlib inline

# Log-mel front end used by preprocess(); compute_style() feeds it audio loaded
# at 24 kHz. mean/std standardize the log-mel values.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_fft=2048,
    win_length=1200,
    hop_length=300,
    n_mels=80,
)
mean = -4
std = 4

def length_to_mask(lengths):
    """Build a padding mask from a 1-D tensor of sequence lengths.

    Returns a boolean tensor of shape ``(len(lengths), lengths.max())`` where
    ``mask[b, t]`` is True when position ``t`` is beyond ``lengths[b]``
    (padding), and False for valid positions.
    """
    batch_size = lengths.shape[0]
    positions = torch.arange(lengths.max()).type_as(lengths)
    positions = positions.unsqueeze(0).expand(batch_size, -1)
    return (positions + 1) > lengths.unsqueeze(1)

def preprocess(wave):
    """Turn a numpy waveform into a standardized log-mel tensor.

    The mel spectrogram comes from ``to_mel``. A leading batch axis is added,
    then the values are log-compressed (with a 1e-5 floor) and standardized
    with the module-level ``mean`` and ``std``.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std

def compute_style(path):
    """Compute the reference style vector for the audio file at ``path``.

    The clip is loaded at 24 kHz and its leading/trailing silence is trimmed
    (``top_db=30``). Its log-mel spectrogram is then encoded by the model's
    ``style_encoder`` and ``predictor_encoder``.

    Returns:
        ``torch.cat([style_encoder_out, predictor_encoder_out], dim=1)``.
        inference() uses the first 128 channels as the decoder style and the
        rest as the predictor style.
    """
    target_sr = 24000
    wave, sr = librosa.load(path, sr=target_sr)
    audio, _ = librosa.effects.trim(wave, top_db=30)
    if sr != target_sr:
        # librosa.load already resamples to target_sr, so this is only a
        # safeguard. The rates must be passed as keywords: positional sample
        # rates are rejected by librosa >= 0.10.
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    mel_tensor = preprocess(audio).to(device)

    with torch.no_grad():
        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))

    return torch.cat([ref_s, ref_p], dim=1)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load phonemizer: same Bengali eSpeak options (language, punctuation, stress)
# as phonemize_bengali() used when building the training lists.
import phonemizer
global_phonemizer = phonemizer.backend.EspeakBackend(language='bn', preserve_punctuation=True, with_stress=True)

# Config stored next to the fine-tuned checkpoints. "Models/LJSpeech" is
# presumably the (unchanged) log_dir of Configs/config_ft.yml — confirm.
# A `with` block closes the file handle instead of leaking it.
with open("Models/LJSpeech/config_ft.yml") as config_file:
    config = yaml.safe_load(config_file)

# load pretrained ASR model (used as the text aligner)
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model (used as the pitch extractor)
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load BERT model
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

model_params = recursive_munch(config['model_params'])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
# Put every sub-module in eval mode and move it to the target device
# (plain loop: the previous list comprehensions were used only for side effects).
for key in model:
    model[key].eval()
    model[key].to(device)

In [None]:
# Checkpoint files in the training output directory, ordered by the integer
# after the last "_" in the file name (presumably the epoch counter), so the
# newest checkpoint is last.
files = [name for name in os.listdir("Models/LJSpeech/") if name.endswith('.pth')]
sorted_files = sorted(files, key=lambda name: int(name.rsplit('_', 1)[-1].partition('.')[0]))

In [None]:
# Load the newest checkpoint onto the CPU. Its 'net' entry maps sub-module names
# to state dicts (they are loaded into model[key] in the next cell).
params_whole = torch.load(f"Models/LJSpeech/{sorted_files[-1]}", map_location='cpu')
params = params_whole['net']

In [None]:
from collections import OrderedDict

# Copy the checkpoint's per-module state dicts into the matching sub-modules.
for key in model:
    if key not in params:
        continue
    state_dict = params[key]
    try:
        model[key].load_state_dict(state_dict)
    except RuntimeError:
        # Presumably the weights were saved from a wrapped model whose keys
        # carry a "module." prefix. Strip that prefix only where it is present
        # (the old k[7:] mangled every key) and retry non-strictly, reporting
        # any keys that still do not match instead of dropping them silently.
        prefix = "module."
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        result = model[key].load_state_dict(new_state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f"{key}: missing keys {result.missing_keys}, "
                  f"unexpected keys {result.unexpected_keys}")
    print('%s loaded' % key)

for key in model:
    model[key].eval()

In [None]:
from Modules.diffusion.sampler import ADPM2Sampler, DiffusionSampler, KarrasSchedule

# Sampler over the model's style-diffusion network. inference() calls it with
# the BERT text embedding and the reference style as conditioning inputs.
# The Karras sigma-schedule values are empirical.
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(
        sigma_min=0.0001,
        sigma_max=3.0,
        rho=9.0,
    ),
    clamp=False,
)

In [None]:
def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
    """Synthesize speech for ``text`` in the style of the reference ``ref_s``.

    Args:
        text: Input text. It is phonemized with ``global_phonemizer`` (Bengali
            eSpeak backend), word-tokenized, and mapped to token ids.
        ref_s: Reference style from ``compute_style``. The first 128 channels
            (style_encoder output) condition the decoder; the remaining channels
            (predictor_encoder output) condition the duration/prosody predictor.
        alpha: Blend weight for the decoder style:
            ``ref = alpha * sampled + (1 - alpha) * ref_s[:, :128]``.
        beta: Same blend for the predictor style using ``ref_s[:, 128:]``.
        diffusion_steps: Number of steps for the diffusion ``sampler``.
        embedding_scale: Forwarded to the sampler as ``embedding_scale``.

    Returns:
        Numpy array of audio samples with the last 50 samples removed (the demo
        cells play it back at 24 kHz).
    """
    text = text.strip()
    ps = global_phonemizer.phonemize([text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)
    tokens = textclenaer(ps)
    # Prepend token id 0 (presumably a start/pad token — TODO confirm).
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        # True at padding positions (see length_to_mask); no padding for a single item.
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Draw a 256-dim style vector from the diffusion model, conditioned on the
        # BERT embedding and the reference style (random noise: depends on RNG state).
        s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
                                          embedding=bert_dur,
                                          embedding_scale=embedding_scale,
                                            features=ref_s,
                                             num_steps=diffusion_steps).squeeze(1)


        # Split like ref_s: [:, :128] decoder style, [:, 128:] predictor style.
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        # Interpolate sampled styles towards the reference styles.
        ref = alpha * ref + (1 - alpha)  * ref_s[:, :128]
        s = beta * s + (1 - beta)  * ref_s[:, 128:]

        d = model.predictor.text_encoder(d_en,
                                         s, input_lengths, text_mask)

        # Per-token durations: sum of sigmoid outputs over the last axis, rounded,
        # with at least one frame per token.
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)

        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Hard alignment matrix (tokens x frames): token i covers pred_dur[i]
        # consecutive frames.
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            # Shift the features one frame to the right along the last axis,
            # repeating frame 0 (presumably a HiFi-GAN alignment quirk — TODO confirm).
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        # Expand text-encoder features to frame rate with the same alignment.
        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            # Same one-frame shift as above.
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new

        out = model.decoder(asr,
                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))


    # Drop the last 50 samples (presumably to remove an end-of-clip artifact — TODO confirm).
    return out.squeeze().cpu().numpy()[..., :-50]

### Synthesize speech

In [None]:
# Bengali test sentence: "Hello, how are you? Is everything fine? Let's chat."
text = '''হ্যালো আপনি কেমন আছেন? সব ভাল তো? আসুন গল্প করি।'''

In [None]:
# get a random reference in the training set, note that it doesn't matter which one you use
path = "/content/StyleTTS2/Data/wavs/train_bengalifemale_00117.wav"
# this style vector ref_s can be saved as a parameter together with the model weights
ref_s = compute_style(path)

In [None]:
import IPython.display as ipd

# Time one synthesis call and report the real-time factor: synthesis time
# divided by the duration of the generated audio (24 kHz).
start = time.time()
wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)
elapsed = time.time() - start
rtf = elapsed / (len(wav) / 24000)
print(f"RTF = {rtf:5f}")

display(ipd.Audio(wav, rate=24000, normalize=False))