# XTTS Fine-tuning Pipeline for Bangladeshi Bangla TTS

This notebook provides a complete pipeline for fine-tuning the XTTS (Cross-lingual Text-To-Speech) model on Bangladeshi Bangla.

## Requirements:
- NVIDIA GPU with sufficient VRAM (A100 recommended)
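- Bangla speech dataset with train/eval metadata CSV files
- Reference audio files for speaker conditioning
- Paths assume a Kaggle environment (`/kaggle/input/...`, `/kaggle/working/...`); adjust them for other setups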

## Pipeline Overview:
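1. **Setup Environment** - Clone the repository and install dependencies
2. **Download Pretrained Checkpoints** - Get the pretrained XTTS model
3. **Extend Vocabulary** - Add Bangla tokens to the tokenizer vocabulary
4. **Train DVAE** - Train the discrete variational autoencoder that converts mel spectrograms into audio tokens
5. **Fix PyTorch Compatibility** - Patch `torch.load` for PyTorch 2.6 and later
6. **Train GPT** - Train the GPT component that predicts audio tokens from text
7. **Load Fine-tuned Model** - Load the trained model for inference
8. **Prepare Speaker Conditioning** - Extract conditioning latents from reference audio
9. **Text-to-Speech Synthesis** - Synthesize Bangla speech with the fine-tuned model
10. **Save Generated Audio** - Write the result to a WAV file
11. **Audio Analysis and Visualization** - Plot the waveform and spectrogram (optional)
12. **Model Performance Evaluation** - Synthesize test sentences for manual listening
13. **Summary and Next Steps** - Recap and recommendations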

---

## 1. Setup Environment

Clone the XTTS fine-tuning repository and install required dependencies.

In [None]:
# Clone the XTTS fine-tuning repository
!git clone https://github.com/nguyenhoanganh2002/XTTSv2-Finetuning-for-New-Languages.git
%cd XTTSv2-Finetuning-for-New-Languages

# Install required packages
!pip install -r requirements.txt

In [None]:
# Check the contents of the repository
%ls

## 2. Download Pretrained Checkpoints

Download the pretrained XTTS model checkpoints that will be used as the starting point for fine-tuning.

In [None]:
# Download pretrained XTTS checkpoints
!python download_checkpoint.py --output_path checkpoints/

In [None]:
# Verify the checkpoint download
%ls checkpoints/

## 3. Extend Vocabulary for Bangla

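Extend the XTTS tokenizer vocabulary with new tokens learned from the Bangla transcripts in the training metadata; `--extended_vocab_size` controls the size of this extension. XTTS tokenizes raw text rather than phonemes, and the pretrained vocabulary has little or no coverage of Bengali script.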

In [None]:
# Extend vocabulary configuration for Bangla language
!python extend_vocab_config.py \
    --output_path=checkpoints/ \
    --metadata_path /kaggle/input/dataset-custom/dataset_custom/metadata_train.csv \
    --language bn \
    --extended_vocab_size 2000
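Because the extension is learned from the transcripts, it is worth checking which characters they actually contain: stray Latin letters, digits, or inconsistent Unicode forms can end up in the vocabulary. A minimal standard-library sketch follows, assuming a header row and a pipe-delimited `audio_file|text|speaker_name` layout (an assumption; adjust `text_column` and `delimiter` to match your `metadata_train.csv`). The function name `bengali_char_inventory` is hypothetical.

```python
import csv
import io
import unicodedata
from collections import Counter

# Bengali Unicode block: U+0980..U+09FF
BENGALI_START, BENGALI_END = 0x0980, 0x09FF


def bengali_char_inventory(lines, text_column="text", delimiter="|"):
    """Count characters in the text column, split into Bengali-block and other.

    Assumes a header row; `lines` is any iterable of CSV lines (an open file
    or io.StringIO). Text is NFC-normalized before counting. Note that the
    danda (।, U+0964) lives in the Devanagari block, so it is counted as other.
    """
    reader = csv.DictReader(lines, delimiter=delimiter)
    bengali, other = Counter(), Counter()
    for row in reader:
        text = unicodedata.normalize("NFC", row[text_column])
        for ch in text:
            if ch.isspace():
                continue
            if BENGALI_START <= ord(ch) <= BENGALI_END:
                bengali[ch] += 1
            else:
                other[ch] += 1
    return bengali, other


if __name__ == "__main__":
    sample = io.StringIO(
        "audio_file|text|speaker_name\n"
        "wavs/a.wav|ঢাকা বাংলাদেশের রাজধানী।|spk1\n"
    )
    bengali, other = bengali_char_inventory(sample)
    print(len(bengali), "distinct Bengali characters")
    print("Other characters:", dict(other))
```

With the real file, open it with `open(path, newline="", encoding="utf-8")` and pass the file object. A long tail in `other` usually points to text that needs cleaning before the vocabulary is extended.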

## 4. Train DVAE (Discrete Variational Autoencoder)

Train the DVAE component of XTTS, which encodes mel spectrograms into the discrete audio tokens that the GPT component learns to predict.

In [None]:
# Train DVAE component
!CUDA_VISIBLE_DEVICES=0,1 python train_dvae_xtts.py \
    --output_path=checkpoints/ \
    --train_csv_path=/kaggle/input/dataset-custom/dataset_custom/metadata_train.csv \
    --eval_csv_path=/kaggle/input/dataset-custom/dataset_custom/metadata_eval.csv \
    --language="bn" \
    --num_epochs=15 \
    --batch_size=64 \
    --lr=1e-6
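The "discrete" part of the DVAE is a codebook lookup: each encoded frame is replaced by the index of its nearest codebook vector, and those indices are the tokens the GPT later predicts. The sketch below shows only that nearest-neighbour assignment (vector quantization) with made-up numbers. It is a toy illustration, not the XTTS DVAE code, which quantizes learned encoder outputs against a learned codebook.

```python
import math


def quantize(frames, codebook):
    """Map each feature frame to the index of its nearest codebook vector.

    `frames` and `codebook` are lists of equal-length float vectors;
    distance is squared Euclidean.
    """
    codes = []
    for frame in frames:
        best_index, best_dist = 0, math.inf
        for index, code in enumerate(codebook):
            dist = sum((f - c) ** 2 for f, c in zip(frame, code))
            if dist < best_dist:
                best_index, best_dist = index, dist
        codes.append(best_index)
    return codes


# Three 2-D "frames" and a 3-entry codebook (illustrative values only)
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
frames = [[0.1, -0.2], [0.9, 1.2], [-0.8, 1.7]]
print(quantize(frames, codebook))  # [0, 1, 2]
```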

## 5. Fix PyTorch Compatibility

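Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which refuses to unpickle non-tensor objects such as the config classes stored in XTTS checkpoints. The cell below patches `TTS/utils/io.py` so the library's checkpoint loader passes `weights_only=False`. Only load checkpoints you trust this way, since full unpickling can execute arbitrary code.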
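Whether the patch is needed depends on the installed PyTorch version. A small sketch that parses a version string and reports whether it falls in the 2.6+ range where `weights_only=True` is the default; `torch_defaults_to_weights_only` is a hypothetical helper, not a PyTorch API.

```python
import re


def torch_defaults_to_weights_only(version: str) -> bool:
    """Return True for PyTorch versions >= 2.6, where torch.load defaults to weights_only=True.

    Accepts strings such as "2.6.0", "2.5.1+cu121", or "2.7.0.dev20250101".
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 6)


if __name__ == "__main__":
    for v in ["2.5.1+cu121", "2.6.0", "2.7.0.dev20250101"]:
        print(v, torch_defaults_to_weights_only(v))
```

In the notebook, `torch_defaults_to_weights_only(torch.__version__)` indicates whether this section's patch is needed.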

In [None]:
# Patch TTS/utils/io.py so torch.load is called with weights_only=False
file_path = "/kaggle/working/XTTSv2-Finetuning-for-New-Languages/TTS/utils/io.py"

old_call = "return torch.load(f, map_location=map_location, **kwargs)"
new_call = "return torch.load(f, map_location=map_location, weights_only=False, **kwargs)"

# Read the current contents
with open(file_path, 'r') as file:
    content = file.read()

if new_call in content:
    print("ℹ️ Patch already applied; nothing to do.")
elif old_call in content:
    # Write the modified content back
    with open(file_path, 'w') as file:
        file.write(content.replace(old_call, new_call))
    print("✅ PyTorch compatibility fix applied successfully!")
else:
    print("⚠️ Expected torch.load call not found; inspect TTS/utils/io.py manually.")

In [None]:
# Verify the changes (optional - shows the modified file content)
with open(file_path, 'r') as file:
    modified_content = file.read()
    
# Show only the relevant part
lines = modified_content.split('\n')
for i, line in enumerate(lines):
    if 'torch.load' in line:
        print(f"Line {i+1}: {line.strip()}")

## 6. Train GPT Component

Train the GPT component of XTTS, which autoregressively predicts DVAE audio tokens from the input text and the speaker conditioning. `--metadatas` takes `train_csv,eval_csv,language`.

In [None]:
# Train GPT component
!CUDA_VISIBLE_DEVICES=0 python train_gpt_xtts.py \
    --output_path checkpoints/ \
    --metadatas /kaggle/input/dataset-custom/dataset_custom/metadata_train.csv,/kaggle/input/dataset-custom/dataset_custom/metadata_eval.csv,bn \
    --num_epochs 55 \
    --batch_size 4 \
    --grad_acumm 4 \
    --max_text_length 400 \
    --max_audio_length 330750 \
    --weight_decay 1e-2 \
    --lr 1e-6 \
    --save_step 5000000
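Two of these settings are easier to read as derived quantities: the effective batch size per optimizer step and the maximum clip length in seconds. The sketch assumes one GPU (the command sets `CUDA_VISIBLE_DEVICES=0`) and that training audio is loaded at 22,050 Hz, the input rate commonly used for XTTS GPT training; check `train_gpt_xtts.py` for the value your copy uses. The helper names are hypothetical.

```python
def effective_batch_size(batch_size: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Number of samples contributing to each optimizer step."""
    return batch_size * grad_accum * num_gpus


def samples_to_seconds(num_samples: int, sample_rate: int) -> float:
    """Convert a length in audio samples to seconds."""
    return num_samples / sample_rate


ASSUMED_TRAIN_SAMPLE_RATE = 22050  # assumption: XTTS GPT training input rate

print(effective_batch_size(4, 4))                             # 16
print(samples_to_seconds(330750, ASSUMED_TRAIN_SAMPLE_RATE))  # 15.0
```

So `--batch_size 4 --grad_acumm 4` updates the weights every 16 samples, and `--max_audio_length 330750` corresponds to 15-second clips under the 22,050 Hz assumption.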

## 7. Load Fine-tuned XTTS Model

Load the fine-tuned XTTS model for inference and speech synthesis.
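Loading fails late and with less obvious errors if one of the three model files is missing, so a quick existence check before initializing the model can save a restart. A minimal sketch; `require_files` is a hypothetical helper.

```python
from pathlib import Path


def require_files(*paths):
    """Raise FileNotFoundError listing every path that is not an existing file."""
    missing = [str(p) for p in paths if not Path(p).is_file()]
    if missing:
        raise FileNotFoundError("Missing model files: " + ", ".join(missing))
```

After the paths are defined, call `require_files(xtts_checkpoint, xtts_config, xtts_vocab)` before `Xtts.init_from_config`.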

In [None]:
import torch
import torchaudio
from tqdm import tqdm
from underthesea import sent_tokenize

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Device configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"🖥️  Using device: {device}")

# Model paths (update these paths according to your setup)
xtts_checkpoint = "/kaggle/input/pretrain-weights/3-dataset/best_model.pth"
xtts_config = "/kaggle/input/pretrain-weights/3-dataset/config.json"
xtts_vocab = "/kaggle/input/pretrain-weights/3-dataset/vocab.json"

print(f"📋 Loading config from: {xtts_config}")
print(f"🤖 Loading model from: {xtts_checkpoint}")
print(f"📚 Loading vocab from: {xtts_vocab}")

In [None]:
# Load model configuration and initialize XTTS
config = XttsConfig()
config.load_json(xtts_config)

# Initialize XTTS model
XTTS_MODEL = Xtts.init_from_config(config)

# Load checkpoint
XTTS_MODEL.load_checkpoint(
    config, 
    checkpoint_path=xtts_checkpoint, 
    vocab_path=xtts_vocab, 
    use_deepspeed=False
)

# Move model to device
XTTS_MODEL.to(device)

print("✅ Model loaded successfully!")

## 8. Prepare Speaker Conditioning

Load reference audio and extract speaker embeddings for voice conditioning.
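The reference clip largely determines the cloned voice, so it helps to confirm its length before extracting latents: XTTS is advertised as cloning from roughly 6 seconds of audio, and the `max_ref_length` argument caps how many seconds of the clip are used. A standard-library sketch, assuming the reference is an uncompressed PCM WAV (Python's `wave` module cannot read float or compressed WAV files); `wav_duration_seconds` is a hypothetical helper.

```python
import wave


def wav_duration_seconds(path: str) -> float:
    """Duration in seconds of a PCM WAV file, read with the standard library."""
    with wave.open(path, "rb") as wav_file:
        return wav_file.getnframes() / wav_file.getframerate()
```

For example, `print(f"{wav_duration_seconds(speaker_audio_file):.2f} s")` once `speaker_audio_file` is set.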

In [None]:
# Speaker conditioning setup
speaker_audio_file = "/kaggle/input/dataset-custom/ref1.wav"
lang = "bn"  # Bangla language code

print(f"🎤 Loading speaker reference from: {speaker_audio_file}")

# Extract conditioning latents from reference audio
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
    audio_path=speaker_audio_file,
    gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
    max_ref_length=XTTS_MODEL.config.max_ref_len,
    sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
)

print("✅ Speaker conditioning prepared successfully!")
print(f"📊 GPT conditioning shape: {gpt_cond_latent.shape}")
print(f"📊 Speaker embedding shape: {speaker_embedding.shape}")

## 9. Text-to-Speech Synthesis

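Generate speech from Bangla text using the fine-tuned XTTS model. The next two cells each set `tts_text`; run the one you want to synthesize (if both run, the second overwrites the first).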
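The XTTS tokenizer in Coqui TTS warns when an input exceeds its per-language character limit, and very long inputs can produce truncated audio, so long texts are usually split before synthesis. The sketch below packs already-split sentences into chunks under a character budget, breaking overly long sentences at spaces. The 250-character default is an assumption for illustration; check the limit your XTTS tokenizer applies for `bn`. `pack_sentences` is a hypothetical helper.

```python
def pack_sentences(sentences, max_chars=250):
    """Greedily pack sentences into chunks of at most max_chars characters.

    Sentences longer than max_chars are split further at spaces; a single
    word longer than max_chars is kept whole (it cannot be split safely).
    """
    # First pass: make every piece fit the budget where possible
    pieces = []
    for sentence in sentences:
        sentence = sentence.strip()
        if len(sentence) <= max_chars:
            if sentence:
                pieces.append(sentence)
            continue
        current = ""
        for word in sentence.split():
            candidate = f"{current} {word}" if current else word
            if len(candidate) <= max_chars or not current:
                current = candidate
            else:
                pieces.append(current)
                current = word
        if current:
            pieces.append(current)

    # Second pass: merge neighbouring pieces while they still fit
    chunks, current = [], ""
    for piece in pieces:
        candidate = f"{current} {piece}" if current else piece
        if len(candidate) <= max_chars:
            current = candidate
        else:
            if current:
                chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return chunks
```

Applied after sentence splitting, `tts_texts = pack_sentences(tts_texts)` would reduce the number of `inference` calls for many short sentences while keeping each call under the budget.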

In [None]:
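# Test text 1: Professional context (Bangla with an English company name)
# English: "When applying for leave, Brain Station employees must clearly state the reason for their leave."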
tts_text = "লিভের আবেদন করার সময় Brain station এর কর্মীদেরকে তাদের ছুটির কারণ স্পষ্টভাবে বলতে হবে"

print(f"🎯 Synthesizing text: {tts_text}")
print(f"📝 Text length: {len(tts_text)} characters")

In [None]:
# Test text 2: Weather report
# English: "In Dhaka, there is rain with wind at a temperature of twenty-one degrees Celsius."
tts_text = "ঢাকায় একুশ ডিগ্রি সেলসিয়াস তাপমাত্রায় বাতাসের সাথে বৃষ্টিপাত আছে"

print(f"🎯 Synthesizing text: {tts_text}")
print(f"📝 Text length: {len(tts_text)} characters")

In [None]:
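import re

# Split the text into sentences for more stable synthesis. underthesea's
# sent_tokenize (imported above) targets Vietnamese and may not split on the
# Bangla danda (।), so split on sentence-final punctuation directly instead.
tts_texts = [s.strip() for s in re.split(r"(?<=[।?!.])\s+", tts_text) if s.strip()]
print(f"📄 Split into {len(tts_texts)} sentences:")
for i, sentence in enumerate(tts_texts, 1):
    print(f"  {i}. {sentence}")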

# Generate speech for each sentence
wav_chunks = []
print("\n🎵 Generating speech...")

for i, text in enumerate(tqdm(tts_texts, desc="Synthesizing")):
    wav_chunk = XTTS_MODEL.inference(
        text=text,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
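        temperature=0.1,          # Low temperature: near-deterministic, stable output
        length_penalty=1.0,       # Only used by beam search (num_beams > 1)
        repetition_penalty=10.0,  # Strongly discourages repeated tokens (prevents looping)
        top_k=20,                 # Sample only from the 20 most likely tokens
        top_p=0.9,                # Nucleus sampling over the top 90% probability mass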
    )
    wav_chunks.append(torch.tensor(wav_chunk["wav"]))
    print(f"  ✅ Sentence {i+1} synthesized ({len(wav_chunk['wav'])} samples)")

# Concatenate all audio chunks
out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0).cpu()
print(f"\n🎵 Final audio shape: {out_wav.shape}")
print(f"⏱️  Duration: {out_wav.shape[1] / 24000:.2f} seconds")
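The sampling arguments passed to `XTTS_MODEL.inference` (`temperature`, `top_k`, `top_p`) reshape the distribution over the next audio token before one is sampled. The plain-Python sketch below applies them in the order used by common Hugging Face-style samplers: temperature, then top-k, then top-p. It is illustrative only, not the XTTS code path, and it ignores `repetition_penalty`. `filter_logits` is a hypothetical helper.

```python
import math


def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return token probabilities after temperature scaling, top-k, and top-p.

    Tokens outside the kept set get probability 0 and the rest are
    renormalized. top_k=0 disables top-k; top_p=1.0 disables top-p.
    """
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Most likely tokens first; keep only the top_k of them if requested
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Smallest prefix of the ranking whose probability mass reaches top_p
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    kept_set = set(kept)
    kept_mass = sum(probs[i] for i in kept)
    return [probs[i] / kept_mass if i in kept_set else 0.0 for i in range(len(probs))]
```

With `temperature=0.1` the top token typically holds almost all of the mass, so `top_k` and `top_p` rarely change the outcome; they matter more at higher temperatures.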

## 10. Save Generated Audio

Save the synthesized speech to an audio file.

In [None]:
import torch
import torchaudio

# Ensure tensor format
if not isinstance(out_wav, torch.Tensor):
    out_wav = torch.from_numpy(out_wav)

# Peak-normalize to [-1, 1]; skip silent output to avoid dividing by zero
peak = out_wav.abs().max()
if peak > 0:
    out_wav = out_wav / peak

# Save audio file
output_filename = "bangla_xtts_output.wav"
torchaudio.save(
    output_filename,
    out_wav,
    sample_rate=24000,
    encoding="PCM_F",
    bits_per_sample=32
)

print(f"🎵 Audio saved as: {output_filename}")
print(f"📊 Sample rate: 24000 Hz")
print(f"📊 Duration: {out_wav.shape[1] / 24000:.2f} seconds")
print(f"📊 File size: ~{out_wav.shape[1] * 4 / 1024:.1f} KB")
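The cell above writes 32-bit float WAV, which some players and tools do not support. A 16-bit integer PCM copy is the more widely compatible choice; `torchaudio.save(..., encoding="PCM_S", bits_per_sample=16)` should produce one, and the sketch below does the same conversion with only the standard library (clip to [-1, 1], scale to int16). `write_pcm16_wav` is a hypothetical helper.

```python
import struct
import wave


def write_pcm16_wav(path, samples, sample_rate=24000):
    """Write mono float samples in [-1, 1] as a 16-bit PCM WAV file.

    Values outside [-1, 1] are clipped before scaling to int16.
    """
    frames = bytearray()
    for x in samples:
        x = max(-1.0, min(1.0, float(x)))
        frames += struct.pack("<h", int(round(x * 32767)))  # little-endian int16
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes = 16 bits
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(bytes(frames))
```

For example, `write_pcm16_wav("bangla_xtts_output_pcm16.wav", out_wav.squeeze().tolist())`. The per-sample Python loop is slow for long audio but keeps the sketch dependency-free.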

## 11. Audio Analysis and Visualization (Optional)

Analyze the generated audio and create visualizations.
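The RMS level and the peak-to-RMS ratio printed by the statistics cell are, for samples $x_1, \dots, x_N$:

$$
\mathrm{RMS} = \sqrt{\frac{1}{N}\sum_{n=1}^{N} x_n^2},
\qquad
\mathrm{CF}_{\mathrm{dB}} = 20 \log_{10} \frac{\max_n |x_n|}{\mathrm{RMS}}
$$

The peak-to-RMS ratio is usually called the crest factor; it describes how peaky the waveform is, not the recording's dynamic range. After peak normalization $\max_n |x_n| = 1$, so $\mathrm{CF}_{\mathrm{dB}} = -20 \log_{10} \mathrm{RMS}$.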

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Create audio waveform visualization
plt.figure(figsize=(15, 8))

# Waveform plot
plt.subplot(2, 1, 1)
time_axis = np.linspace(0, out_wav.shape[1] / 24000, out_wav.shape[1])
plt.plot(time_axis, out_wav.squeeze().numpy())
plt.title('Generated Bangla Speech Waveform (XTTS Fine-tuned)', fontsize=14, fontweight='bold')
plt.xlabel('Time (seconds)')
plt.ylabel('Amplitude')
plt.grid(True, alpha=0.3)

# Spectrogram
plt.subplot(2, 1, 2)
plt.specgram(out_wav.squeeze().numpy(), Fs=24000, cmap='viridis')
plt.title('Spectrogram', fontsize=14, fontweight='bold')
plt.xlabel('Time (seconds)')
plt.ylabel('Frequency (Hz)')
plt.colorbar(label='Power (dB)')

plt.tight_layout()
plt.show()

# Audio statistics
print("📊 Audio Statistics:")
print(f"   • Duration: {out_wav.shape[1] / 24000:.2f} seconds")
print(f"   • Sample Rate: 24000 Hz")
print(f"   • Max Amplitude: {out_wav.abs().max():.4f}")
print(f"   • RMS: {torch.sqrt(torch.mean(out_wav**2)):.4f}")
print(f"   • Crest Factor (peak/RMS): {20 * torch.log10(out_wav.abs().max() / (torch.sqrt(torch.mean(out_wav**2)) + 1e-8)):.1f} dB")

## 12. Model Performance Evaluation

Synthesize a fixed set of Bangla test sentences with the fine-tuned model and save each one as a WAV file for manual listening tests.

In [None]:
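# Test multiple Bangla sentences for evaluation. English glosses, in order:
# 1. "I can speak in Bangla."  2. "Dhaka is the capital of Bangladesh."
# 3. "Today's weather is very nice."  4. "Education is the backbone of the nation."
# 5. "Technology has made our lives easier."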
test_sentences = [
    "আমি বাংলা ভাষায় কথা বলতে পারি।",
    "ঢাকা বাংলাদেশের রাজধানী।",
    "আজকের আবহাওয়া খুবই সুন্দর।",
    "শিক্ষা জাতির মেরুদণ্ড।",
    "প্রযুক্তি আমাদের জীবনকে সহজ করেছে।"
]

print("🧪 Testing model with multiple sentences...")
print(f"📝 Total test sentences: {len(test_sentences)}")

# Generate audio for each test sentence
for i, sentence in enumerate(test_sentences, 1):
    print(f"\n🎯 Test {i}: {sentence}")
    
    # Generate speech
    wav_result = XTTS_MODEL.inference(
        text=sentence,
        language="bn",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=0.1,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=20,
        top_p=0.9,
    )
    
    # Convert to tensor and normalize
    test_wav = torch.tensor(wav_result["wav"]).unsqueeze(0)
    peak = test_wav.abs().max()
    if peak > 0:  # avoid dividing by zero on silent output
        test_wav = test_wav / peak
    
    # Save individual test file
    test_filename = f"test_{i}_bangla_xtts.wav"
    torchaudio.save(test_filename, test_wav, 24000)
    
    # Report statistics
    duration = test_wav.shape[1] / 24000
    print(f"   ✅ Generated: {duration:.2f}s, saved as {test_filename}")

print("\n🎉 Model evaluation completed!")
print("📁 All test audio files have been saved for manual evaluation.")
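Manual evaluation is easier to track with a listening sheet that pairs each file with its sentence and duration and leaves columns for listener scores. A minimal sketch using the standard `csv` module; `write_listening_sheet` and its column names are hypothetical.

```python
import csv


def write_listening_sheet(path, rows, sample_rate=24000):
    """Write a CSV for manual listening tests.

    rows: iterable of (sentence, wav_filename, num_samples). The empty
    'score' and 'notes' columns are left for listeners to fill in.
    """
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["file", "sentence", "duration_s", "score", "notes"])
        for sentence, filename, num_samples in rows:
            writer.writerow([filename, sentence, f"{num_samples / sample_rate:.2f}", "", ""])
```

Inside the evaluation loop, appending `(sentence, test_filename, test_wav.shape[1])` to a list and passing it to this function afterwards would produce one row per test file.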

## 13. Summary and Next Steps

Summary of the XTTS fine-tuning pipeline results and recommendations for further improvements.

In [None]:
print("🎉 XTTS Fine-tuning Pipeline Completed Successfully!")
print("="*60)
print("\n📋 Pipeline Summary:")
print("   ✅ Environment setup and repository cloning")
print("   ✅ Pretrained checkpoint download")
print("   ✅ Bangla vocabulary extension (2000 tokens)")
print("   ✅ DVAE training (15 epochs)")
print("   ✅ PyTorch compatibility fixes")
print("   ✅ GPT training (55 epochs)")
print("   ✅ Model loading and inference setup")
print("   ✅ Speech synthesis and evaluation")

print("\n🎯 Model Capabilities:")
print("   • Cross-lingual text-to-speech synthesis")
print("   • Bangla language support with extended vocabulary")
print("   • Speaker voice cloning and conditioning")
print("   • 24 kHz audio output")
print("   • Sentence-by-sentence synthesis of longer inputs")

print("\n📊 Training Configuration:")
print(f"   • Language: Bangla (bn)")
print(f"   • Vocabulary Size: 2000 tokens")
print(f"   • DVAE Epochs: 15")
print(f"   • GPT Epochs: 55")
print(f"   • Sample Rate: 24000 Hz")
print(f"   • Device: {device}")

print("\n🚀 Next Steps:")
print("   1. Evaluate audio quality manually")
print("   2. Test with diverse Bangla text samples")
print("   3. Compare with baseline VITS model")
print("   4. Optimize inference parameters for better quality")
print("   5. Deploy model for production use")

print("\n💡 Recommendations:")
print("   • Use more training data for better generalization")
print("   • Experiment with different temperature settings")
print("   • Tune repetition penalty for specific use cases")
print("   • Consider multi-speaker training for voice variety")

print("\n" + "="*60)
print("🎵 Fine-tuned XTTS model ready for Bangladeshi Bangla TTS!")