In [None]:
!pip install coqui-tts -U -q
!pip install install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

In [None]:
import os
import torch
import logging
from datetime import datetime
from pathlib import Path
import traceback
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.tts.configs.shared_configs import (
    BaseDatasetConfig,
    CharactersConfig,
)
from TTS.tts.datasets import load_tts_samples

In [None]:
# NOTE: VITS takes significantly longer to train than StyleTTS2.

# Training hyper-parameters
BATCH_SIZE = 104
EPOCHS = 1000
MAX_SAMPLES = 5000        # cap on train + eval samples combined
VAL_SPLIT = 0.1           # fraction of samples held out for evaluation
LOG_FILE = "training.log"

# Run switches
male = False              # True -> male metadata file, False -> female
pretrained = True         # True -> continue training from ./output/vits/

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Setup

In [None]:

# Checkpoint directory to continue from; empty string when not resuming.
pretrained_path = "./output/vits/" if pretrained else ""

# Both voices share one dataset root; only the metadata file differs.
# Deriving meta_file from root_path keeps it absolute and consistent with
# root_path (the formatter opens meta_file exactly as given).
root_path = "/content/drive/MyDrive/mono"
meta_file = os.path.join(root_path, "metadata_male.txt" if male else "metadata_female.txt")

# force=True replaces any handlers already on the root logger, so the
# FileHandler is installed even if logging was configured earlier and
# re-running this cell does not keep stale handlers around.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler(),
    ],
    force=True,
)
logger = logging.getLogger('bangla_tts_training')


# Format Dataset

In [None]:
def formatter(root_path, meta_file, **kwargs):
    """Parse a pipe-separated, LJSpeech-style metadata file into TTS samples.

    Each non-blank line is expected to look like ``<wav_id>|<text>[|...]``;
    only the first two columns are used. The audio for a line is looked up at
    ``<root_path>/wav/<wav_id>.wav``. Lines with fewer than two columns, a
    missing WAV file, empty text, or any other per-line error are logged as
    warnings and skipped.

    Args:
        root_path: Dataset root directory that contains the ``wav`` folder.
        meta_file: Path to the metadata file. It is opened exactly as given
            (not joined with ``root_path``), so it should be absolute or
            relative to the current working directory.
        **kwargs: Extra keyword arguments from the caller; accepted and ignored.

    Returns:
        A list of dicts with keys ``text``, ``audio_file``, ``speaker_name``
        (always ``"ljspeech"``) and ``root_path``.
    """
    items = []
    speaker_name = "ljspeech"
    skipped_lines = 0

    logger.info(f"Reading metadata from {meta_file}")
    # "utf-8-sig" drops a leading byte-order mark if present; with plain
    # "utf-8" the BOM would be glued onto the first WAV id and that file
    # would be reported as missing.
    with open(meta_file, "r", encoding="utf-8-sig") as ttf:
        for line_num, line in enumerate(ttf, 1):
            if not line.strip():
                # Blank lines (e.g. a trailing newline at EOF) are not entries.
                continue
            try:
                cols = line.split("|")
                if len(cols) < 2:
                    logger.warning(f"Line {line_num} has invalid format: {line.rstrip()}")
                    skipped_lines += 1
                    continue

                # Strip stray whitespace around the id so it maps to the file name.
                wav_id = cols[0].strip()
                wav_file = os.path.join(root_path, "wav", wav_id + ".wav")
                if not os.path.exists(wav_file):
                    logger.warning(f"Line {line_num}: WAV file not found: {wav_file}")
                    skipped_lines += 1
                    continue

                text = cols[1].strip()
                if not text:
                    logger.warning(f"Line {line_num}: Empty text for file {wav_id}")
                    skipped_lines += 1
                    continue

                items.append(
                    {
                        "text": text,
                        "audio_file": wav_file,
                        "speaker_name": speaker_name,
                        "root_path": root_path,
                    }
                )
            except Exception as e:
                # Best effort: one malformed line must not abort the whole load.
                logger.warning(f"Line {line_num}: Error processing line: {e}")
                skipped_lines += 1

    if skipped_lines > 0:
        logger.warning(f"Skipped {skipped_lines} lines in metadata file due to errors")

    logger.info(f"Successfully loaded {len(items)} items from metadata")
    return items


# Create Dataset

In [None]:
output_path = "./output"
os.makedirs(output_path, exist_ok=True)
logger.info(f"Output path: {output_path}")

dataset_config = BaseDatasetConfig(
    meta_file_train=meta_file,
    path=os.path.join(root_path, ""),  # root path with a trailing separator
)
logger.info(f"Dataset config: {dataset_config}")

eval_split_size = VAL_SPLIT
eval_split_max_size = int(MAX_SAMPLES * VAL_SPLIT)
logger.info(f"Using validation split: {eval_split_size}, max validation samples: {eval_split_max_size}")

logger.info(f"Loading samples from {meta_file}")
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    formatter=formatter,
    eval_split=True,
    eval_split_size=eval_split_size,
    eval_split_max_size=eval_split_max_size,
)

# Keep train + eval within MAX_SAMPLES by trimming only the training split.
# Clamp at 0: a negative slice bound would drop items from the END instead.
max_train_samples = max(MAX_SAMPLES - len(eval_samples), 0)
if len(train_samples) > max_train_samples:
    logger.info(f"Limiting training samples from {len(train_samples)} to {max_train_samples}")
    train_samples = train_samples[:max_train_samples]

if not train_samples:
    raise ValueError(
        f"No training samples available after loading {meta_file}; "
        "check the metadata file and MAX_SAMPLES"
    )

logger.info(f"Training samples: {len(train_samples)}, Validation samples: {len(eval_samples)}")
logger.debug(f"Sample example: {train_samples[0]}")

# Collect every character used in the transcripts so vocabulary coverage can
# be checked against the character configuration.
logger.info("Analyzing training data for character coverage...")
all_train_texts = [sample['text'] for sample in train_samples + eval_samples]
unique_chars = set().union(*all_train_texts)
logger.info(f"Found {len(unique_chars)} unique characters in train + eval text: {''.join(sorted(unique_chars))}")

# Configure

In [None]:
# 22.05 kHz audio, 1024-sample window advanced by 256 samples, 80 mel bins.
audio_config = VitsAudioConfig(
    sample_rate=22050,
    # framing
    hop_length=256,
    win_length=1024,
    # mel filterbank
    num_mels=80,
    mel_fmin=0,
    mel_fmax=None,
)
logger.info("Audio configuration created")

# Grapheme inventory: Bangla letters, vowel signs, digits and the em dash,
# plus quote/currency symbols that appear in the transcripts.
bangla_chars = "অআইঈউঊঋঌএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহড়ঢ়য়ৎংঃঁ০১২৩৪৫৬৭৮৯ািীুূৃৄেৈোৌ্ৗৢৣ—"
additional_chars = "$%&''\"`\"„‘’"

# Male and female voices use the same character set; only the log line differs.
voice_label = "male" if male else "female"
logger.info(f"Using {voice_label} character configuration with comprehensive Bangla set and common punctuation")
characters_config = CharactersConfig(
    pad="<PAD>",
    eos="।",  # the Bangla danda doubles as the end-of-sentence token
    bos="<BOS>",
    blank="<BLNK>",
    phonemes=None,
    # \u200c / \u200d are ZERO WIDTH NON-JOINER / ZERO WIDTH JOINER.
    characters=bangla_chars + "''\u200c\u200d" + additional_chars,
    punctuations="!,.:;?()- |",
)

# Coverage check: report characters seen in the raw transcripts (before text
# cleaning) that are not part of the vocabulary. NOTE(review): presumably the
# tokenizer cannot encode these unless the text cleaner removes/normalizes them.
vocab_chars = set(characters_config.characters) | set(characters_config.punctuations) | {characters_config.eos}
missing_chars = sorted(unique_chars - vocab_chars)
if missing_chars:
    logger.warning(
        f"{len(missing_chars)} character(s) in the dataset text are not in the vocabulary: "
        f"{''.join(missing_chars)!r}"
    )
else:
    logger.info("All dataset characters are covered by the vocabulary")

# Test sentences handed to the config (`test_sentences`); they are dense in
# words containing য়.
test_sentences = [
    "হয়,হয়ে,ওয়া,হয়েছ,হয়েছে,দিয়ে,যায়,দায়,নিশ্চয়,আয়,ভয়,নয়,আয়াত,নিয়ে,হয়েছে,দিয়েছ,রয়ে,রয়েছ,রয়েছে।",
    "দেয়,দেওয়া,বিষয়,হয়,হওয়া,সম্প্রদায়,সময়,হয়েছি,দিয়েছি,হয়,হয়েছিল,বিষয়ে,নয়,কিয়াম,ইয়া,দেয়া,দিয়েছে,আয়াতে,দয়া।",
    "ইয়াহুদ,নয়,ব্যয়,ইয়াহুদী,নেওয়া,উভয়ে,যায়,হয়েছিল,প্রয়োজন।",
]

# Run name is "vits_" plus the abbreviated month and day, e.g. "vits_Mar_05".
run_name = "vits_" + datetime.now().strftime("%b_%d")
logger.info(f"Creating model configuration with run name: {run_name}")
config = VitsConfig(
    # run bookkeeping / output
    run_name=run_name,
    output_path=output_path,
    # data
    audio=audio_config,
    datasets=[dataset_config],
    eval_split_size=eval_split_size,
    eval_split_max_size=eval_split_max_size,
    # text front-end: raw graphemes, no phonemizer
    characters=characters_config,
    text_cleaner="multilingual_cleaners",
    use_phonemes=False,
    phoneme_language="bn",  # presumably unused while use_phonemes=False
    compute_input_seq_cache=True,
    # NOTE(review): every sample uses the single speaker name "ljspeech" and
    # the model is built with speaker_manager=None; confirm this flag is intended.
    use_speaker_embedding=True,
    # batching / data loading
    batch_size=BATCH_SIZE,
    eval_batch_size=BATCH_SIZE,
    batch_group_size=0,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    # schedule, evaluation and logging
    epochs=EPOCHS,
    run_eval=True,
    test_delay_epochs=-1,
    test_sentences=test_sentences,
    print_step=50,
    print_eval=True,
    save_step=1000,
    # numerics / cuDNN
    mixed_precision=True,
    # NOTE(review): benchmark mode and deterministic mode are both enabled;
    # confirm which behaviour is actually wanted.
    cudnn_benchmark=True,
    cudnn_deterministic=True,
)

logger.info("Initializing audio processor")
ap = AudioProcessor.init_from_config(config)
logger.info(f"Audio processor resample: {ap.resample}")

logger.info("Initializing tokenizer")
# init_from_config hands back a (possibly updated) config; keep using it.
tokenizer, config = TTSTokenizer.init_from_config(config)

# Save only after the tokenizer step so the JSON on disk matches the config
# the model and trainer are actually built from.
config_path = os.path.join(output_path, "config.json")
config.save_json(config_path)
logger.info(f"Configuration saved to {config_path}")

logger.info(f"Creating VITS model with device: {device}")
model = Vits(config, ap, tokenizer, speaker_manager=None)
model = model.to(device)

logger.info("Setting up trainer")
# pretrained_path is "" unless `pretrained` is set, in which case training
# continues from that directory.
trainer_args = TrainerArgs(continue_path=pretrained_path)
trainer = Trainer(
    trainer_args,
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)

# Train

In [None]:
logger.info("Starting training")
try:
    trainer.fit()
except Exception as e:
    # Log with traceback, then re-raise so "Run All" stops here instead of
    # continuing into inference with a missing or stale checkpoint.
    logger.exception(f"Training failed with error: {e}")
    raise
else:
    logger.info("Training completed successfully")

# Inference

In [None]:
import torch
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from IPython.display import Audio, display
import numpy as np

# Rebuild the VITS model from a saved training config and load a checkpoint
# for inference. Paths are hard-coded; point them at the run to evaluate.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config_path = "/content/output/vits/config.json"
checkpoint_path = "/content/output/vits/best_model_10082.pth"

config = VitsConfig()
config.load_json(config_path)

ap = AudioProcessor.init_from_config(config)
tokenizer, config = TTSTokenizer.init_from_config(config)

model = Vits(config, ap, tokenizer, speaker_manager=None)
model.load_checkpoint(config, checkpoint_path=checkpoint_path, eval=True)
model.eval()
model = model.to(device)

# Sample passage (a short multi-paragraph Bangla story) synthesized below to
# listen to the trained voice. Paragraphs are separated by blank lines.
text ="""একটি ছোট্ট গ্রামে বাস করত রমেশ নামের এক কৃষক। সে খুবই গরিব ছিল, কিন্তু তার বুদ্ধি ছিল অসাধারণ। একদিন রমেশের জমিতে একটি বড়ো কুমড়া ফলল। এত বড়ো কুমড়া সে আগে কখনো দেখেনি!

রমেশ ভাবল, "এটি রাজাকে উপহার দিলে যদি কিছু পুরস্কার পাই!" তাই সে কুমড়াটি নিয়ে রাজপ্রাসাদে গেল।

রাজা খুবই খুশি হলেন এবং বললেন, "তোমাকে আমি পুরস্কার দেব। তুমি কী চাও?"

রমেশ বলল, "মহারাজ, আমি কিছু চাই না। আপনি খুশি থাকলেই আমি খুশি।"

রাজা তার বিনয় দেখে খুশি হয়ে তাকে অনেক স্বর্ণমুদ্রা দিলেন।

এ ঘটনা শুনে গ্রামের এক ধনী লোভী ব্যবসায়ী ভাবল, "যদি কুমড়া উপহার দিয়ে এত পুরস্কার পাওয়া যায়, তাহলে আমি রাজাকে ঘোড়া উপহার দেব, নিশ্চয়ই অনেক কিছু পাব!"

সে এক দুর্লভ ঘোড়া কিনে রাজাকে উপহার দিল।

রাজা তখন হেসে বললেন, "তুমি আমাকে এত সুন্দর ঘোড়া উপহার দিয়েছো, আমি তোমাকে সেই কৃষকের দেওয়া বড় কুমড়াটি উপহার দিচ্ছি!"

ধনী ব্যবসায়ী হতভম্ভ হয়ে গেল। সে বুঝতে পারল, শুধু লোভ করলে সবকিছু পাওয়া যায় না, বুদ্ধি আর বিনয়ই প্রকৃত সম্পদ।"""

# Synthesize the whole passage in a single forward pass; no gradients needed.
with torch.no_grad():
    inputs = tokenizer.text_to_ids(text)
    inputs = torch.LongTensor(inputs).unsqueeze(0).to(device)  # add batch dim
    output = model.inference(inputs)

# NOTE(review): model.inference is expected to return a dict holding the
# waveform under "model_outputs"; fall back to the raw value otherwise.
if isinstance(output, dict) and "model_outputs" in output:
    output_wav = output["model_outputs"]
else:
    output_wav = output

if isinstance(output_wav, torch.Tensor):
    output_wav = output_wav.detach().cpu().numpy()

# Drop singleton batch/channel dimensions before saving and playback.
if output_wav.ndim > 1:
    output_wav = output_wav.squeeze()

ap.save_wav(output_wav, "output.wav")
print("Audio saved to output.wav")

# Play back at the audio processor's sample rate rather than a hard-coded
# 22050, so playback stays correct if the config's sample rate changes.
display(Audio(output_wav, rate=ap.sample_rate))