# Sinhala XTTS-v2 Fine-tuning - Corrected Notebook
## üá±üá∞ Complete End-to-End Pipeline with Dataset Column Mapping

**Corrected for:**
- Dataset: https://www.kaggle.com/datasets/amalshaf/sinhala-tts-dataset
- CSV format: `audio_file_path | transcript | speaker_id`
- Repository: https://github.com/amalshafernando/XTTSv2-sinhala (sinhala-tokenization branch)

**This notebook will:**
1. ✅ Clone your repository
2. ✅ Mount and validate dataset
3. ✅ Prepare dataset with correct column mapping
4. ✅ Download XTTS-v2 model
5. ✅ Extend vocabulary (15,000 Sinhala tokens)
6. ✅ Fine-tune GPT for Sinhala
7. ✅ Generate sample speech

**Time:** ~5-9 hours (mostly training)

## Cell 1: Setup & Install Dependencies

In [None]:
print("\n" + "="*80)
print("PHASE 1: ENVIRONMENT SETUP")
print("="*80)

print("\nüì¶ Installing PyTorch with CUDA support...")
!pip install -q torch==2.1.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118

print("üì¶ Installing TTS framework...")
!pip install -q TTS>=0.22.0 transformers>=4.30.0 tokenizers>=0.13.0

print("üì¶ Installing dependencies...")
!pip install -q pandas numpy tqdm pyyaml regex

print("\n‚úÖ Verifying installation...")
import torch
import torchaudio
from TTS.utils.manage import ModelManager

print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

print("\n" + "="*80)
print("‚úÖ PHASE 1 COMPLETE")
print("="*80)

## Cell 2: Clone Repository

In [None]:
import os
import subprocess
import sys

# Phase 2: fetch the fine-tuning repository (unless it is already on disk),
# report which helper scripts it contains, and make it importable.
print("\n" + "=" * 80)
print("PHASE 2: CLONE REPOSITORY")
print("=" * 80)

repo_url = "https://github.com/amalshafernando/XTTSv2-sinhala.git"
branch = "sinhala-tokenization"
repo_path = "/kaggle/working/XTTSv2-sinhala"

print("\nüì• Cloning repository...")
print(f"   URL: {repo_url}")
print(f"   Branch: {branch}")

if os.path.exists(repo_path):
    # An existing directory is trusted as-is; no fetch or pull is attempted.
    print("‚úÖ Repository already exists")
else:
    clone_cmd = ["git", "clone", "-b", branch, repo_url, repo_path]
    result = subprocess.run(clone_cmd, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"‚ùå Clone failed: {result.stderr}")
        raise RuntimeError("Failed to clone repository")
    print("‚úÖ Repository cloned successfully")

# Verify scripts: a missing script is only reported, it does not stop the run.
print("\nüìã Verifying scripts...")
scripts = [
    'extend_vocab_sinhala.py',
    'train_gpt_xtts.py',
    'config_sinhala.py',
    'prepare_dataset_sinhala.py',
    'inference_sinhala.py',
]

for script in scripts:
    script_path = os.path.join(repo_path, script)
    if not os.path.exists(script_path):
        print(f"   ‚ö†Ô∏è  {script} not found")
        continue
    size_kb = os.path.getsize(script_path) / 1024
    print(f"   ‚úÖ {script} ({size_kb:.1f} KB)")

# Put the repository first on the import path
sys.path.insert(0, repo_path)

print("\n" + "=" * 80)
print("‚úÖ PHASE 2 COMPLETE")
print("=" * 80)

## Cell 3: Setup Configuration

In [None]:
import os
from pathlib import Path

# Phase 3: define the paths and hyper-parameters used by the later phases
# and create the working directories.
print("\n" + "=" * 80)
print("PHASE 3: SETUP CONFIGURATION")
print("=" * 80)

# Paths
DATASET_INPUT = "/kaggle/input/sinhala-tts-dataset"
WORKING_DIR = "/kaggle/working"
DATASET_PROCESSED = os.path.join(WORKING_DIR, "datasets")
CHECKPOINTS_DIR = os.path.join(WORKING_DIR, "checkpoints")
OUTPUT_DIR = os.path.join(WORKING_DIR, "output")
# The original model files live inside the checkpoints folder
XTTS_MODEL_DIR = os.path.join(CHECKPOINTS_DIR, "XTTS_v2.0_original_model_files")

# Create directories (exist_ok makes re-running the cell harmless)
print("\nüìÅ Creating directories...")
required_dirs = (DATASET_PROCESSED, CHECKPOINTS_DIR, XTTS_MODEL_DIR, OUTPUT_DIR)
for folder in required_dirs:
    os.makedirs(folder, exist_ok=True)

# Configuration
LANGUAGE_CODE = "si"
VOCAB_SIZE = 15000
NUM_EPOCHS = 5
BATCH_SIZE = 8
GRAD_ACCUM = 4
LEARNING_RATE = 5e-6
SAVE_STEP = 50000

summary_lines = [
    f"   Language: {LANGUAGE_CODE} (Sinhala)",
    f"   Vocab size: {VOCAB_SIZE:,} tokens",
    f"   Epochs: {NUM_EPOCHS}",
    f"   Batch size: {BATCH_SIZE}",
    f"   Learning rate: {LEARNING_RATE}",
]
print("\nüìã Configuration:")
print("\n".join(summary_lines))

print("\n" + "=" * 80)
print("‚úÖ PHASE 3 COMPLETE")
print("=" * 80)

## Cell 4: Validate Dataset (CORRECTED FOR YOUR CSV FORMAT)

In [None]:
import os
import pandas as pd

# Phase 4: confirm the dataset is present, locate the train/eval CSV files
# and check that they contain the columns the preparation phase reads.
print("\n" + "="*80)
print("PHASE 4: VALIDATE DATASET")
print("="*80)

print(f"\n[1/4] Checking dataset at {DATASET_INPUT}")
if not os.path.exists(DATASET_INPUT):
    print(f"‚ùå Dataset not found!")
    raise FileNotFoundError(f"Dataset path not found: {DATASET_INPUT}")

print(f"‚úÖ Dataset found")

# List contents (sorted so the listing is stable between runs)
print(f"\n[2/4] Dataset contents:")
for item in sorted(os.listdir(DATASET_INPUT)):
    item_path = os.path.join(DATASET_INPUT, item)
    if os.path.isdir(item_path):
        count = len(os.listdir(item_path))
        print(f"   üìÅ {item}/ ({count} items)")
    else:
        print(f"   üìÑ {item}")


def _find_split_csvs(dataset_dir):
    """Locate the train and eval CSV files directly inside ``dataset_dir``.

    Only names ending in ``.csv`` are considered. A name containing "train"
    is the train split; a name containing "val" (which also covers "eval"
    and "validation") is the eval split. A name containing "test" is used
    for eval only when no "val" file exists.

    Returns:
        A ``(train_csv, eval_csv)`` tuple of paths; either item is ``None``
        when no matching file was found.
    """
    train_path = None
    eval_path = None
    test_path = None
    for fname in sorted(os.listdir(dataset_dir)):
        lowered = fname.lower()
        # Filter on the extension first so that every branch below applies
        # to CSV files only (the extension check must not bind to just one
        # side of an "or").
        if not lowered.endswith('.csv'):
            continue
        full_path = os.path.join(dataset_dir, fname)
        if 'train' in lowered:
            train_path = full_path
        elif 'val' in lowered:
            eval_path = full_path
        elif 'test' in lowered:
            test_path = full_path
    return train_path, (eval_path or test_path)


# Find CSV files
print(f"\n[3/4] Validating CSV files:")
train_csv, eval_csv = _find_split_csvs(DATASET_INPUT)

if not (train_csv and eval_csv):
    print(f"   ‚ùå Could not find train/eval CSV files")
    print(f"   Please ensure dataset has train and eval CSV files")
    raise FileNotFoundError("CSV files not found")

print(f"   ‚úÖ Found train: {os.path.basename(train_csv)}")
print(f"   ‚úÖ Found eval: {os.path.basename(eval_csv)}")

# Validate CSV format
print(f"\n[4/4] Validating CSV format:")

try:
    df_train = pd.read_csv(train_csv, sep=',')
    df_eval = pd.read_csv(eval_csv, sep=',')
except Exception as e:
    print(f"   ‚ùå Error reading CSV: {str(e)}")
    raise

print(f"   Train CSV columns: {list(df_train.columns)}")
print(f"   Train CSV rows: {len(df_train)}")
print(f"   Eval CSV rows: {len(df_eval)}")

# Check for expected columns (case- and whitespace-insensitive matching).
# Note that the preparation phase indexes rows by these exact names.
required_columns = ('audio_file_path', 'transcript', 'speaker_id')
train_cols = [c.lower().strip() for c in df_train.columns]
eval_cols = [c.lower().strip() for c in df_eval.columns]
print(f"\n   Column mapping:")

missing_columns = []
for column in required_columns:
    if column in train_cols:
        print(f"   ‚úÖ {column} column found")
    else:
        missing_columns.append(f"{column} (train)")
    if column not in eval_cols:
        missing_columns.append(f"{column} (eval)")

# A missing column must fail validation instead of being silently ignored.
if missing_columns:
    print(f"   ‚ùå Missing columns: {', '.join(missing_columns)}")
    raise ValueError(f"Missing required CSV columns: {', '.join(missing_columns)}")

# An empty train split cannot be used and has no sample row to show.
if df_train.empty:
    raise ValueError(f"Train CSV has no rows: {train_csv}")

print(f"\n   Sample row:")
print(f"   {df_train.iloc[0].to_dict()}")

print(f"\n‚úÖ Dataset validation PASSED")

print(f"\n" + "="*80)
print(f"‚úÖ PHASE 4 COMPLETE")
print("="*80)

## Cell 5: Download XTTS-v2 Model

In [None]:
from TTS.utils.manage import ModelManager
import os

# Phase 5: download the XTTS-v2 model files into XTTS_MODEL_DIR, skipping
# any file that is already present, and fail if a file could not be fetched.
print("\n" + "="*80)
print("PHASE 5: DOWNLOAD XTTS-v2 MODEL")
print("="*80)

# (download URL, local file name) pairs
files_to_download = [
    ("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth", "mel_stats.pth"),
    ("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth", "dvae.pth"),
    ("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json", "vocab.json"),
    ("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth", "model.pth"),
    ("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json", "config.json"),
    ("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/speakers_xtts.pth", "speakers_xtts.pth"),
]

print(f"\nDownloading {len(files_to_download)} model files...")

failed_files = []
for idx, (url, filename) in enumerate(files_to_download, 1):
    filepath = os.path.join(XTTS_MODEL_DIR, filename)

    if os.path.exists(filepath):
        size_mb = os.path.getsize(filepath) / (1024 * 1024)
        print(f"[{idx}/{len(files_to_download)}] ‚úÖ {filename} ({size_mb:.1f} MB)")
        continue

    print(f"[{idx}/{len(files_to_download)}] üì• Downloading {filename}...")
    try:
        # NOTE(review): _download_model_files is a private TTS helper and may
        # change between TTS releases.
        ModelManager._download_model_files([url], XTTS_MODEL_DIR, progress_bar=True)
        # getsize also raises if the download did not actually create the file
        size_mb = os.path.getsize(filepath) / (1024 * 1024)
        print(f"       ‚úÖ Done ({size_mb:.1f} MB)")
    except Exception as e:
        # Keep going so every file is attempted, but remember the failure.
        print(f"       ‚ö†Ô∏è  Error: {str(e)}")
        failed_files.append(filename)

# Do not report the phase as complete when files are missing; the later
# phases read these files from XTTS_MODEL_DIR.
if failed_files:
    raise RuntimeError(f"Failed to download model files: {', '.join(failed_files)}")

print(f"\n" + "="*80)
print(f"‚úÖ PHASE 5 COMPLETE")
print("="*80)

## Cell 6: Prepare Dataset with Column Mapping (CORRECTED)

In [None]:
import os
import pandas as pd
import shutil

# Phase 6 (part 1): load both CSV splits and create the output wavs folder.
print("\n" + "=" * 80)
print("PHASE 6: PREPARE DATASET WITH CORRECT COLUMN MAPPING")
print("=" * 80)

print("\nDataset format: audio_file_path | transcript | speaker_id")
print("Output format: audio_file | text | speaker_name")

# Read CSVs (train first, then eval)
print("\n[1/3] Reading CSV files...")
df_train, df_eval = (pd.read_csv(csv_path, sep=',') for csv_path in (train_csv, eval_csv))

for split_label, frame in (("Train", df_train), ("Eval", df_eval)):
    print(f"   {split_label}: {len(frame)} samples")

# Create output directories
print("\n[2/3] Creating output structure...")
wavs_dir = os.path.join(DATASET_PROCESSED, "wavs")
os.makedirs(wavs_dir, exist_ok=True)

# Copy audio files and create metadata
print("\n[3/3] Processing data...")

def _resolve_audio_path(audio_file, audio_source_dir):
    """Return the first existing file for ``audio_file``, or ``None`` if none exists."""
    candidates = (
        # Path relative to the audio folder (the original lookup)
        os.path.join(audio_source_dir, audio_file),
        # Path relative to the folder containing it, e.g. "wavs/clip.wav"
        os.path.join(os.path.dirname(audio_source_dir), audio_file),
        # Bare file name inside the audio folder
        os.path.join(audio_source_dir, os.path.basename(audio_file)),
    )
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    return None


def process_dataset(df, split_name):
    """Copy one split's audio into ``wavs_dir`` and build its metadata entries.

    Reads the ``audio_file_path``, ``transcript`` and ``speaker_id`` columns
    of ``df``. For each row whose audio file can be located, the file is
    copied into the module-level ``wavs_dir`` (unless a file with the same
    base name is already there) and a metadata entry is recorded. Rows with a
    missing path, a missing or blank transcript, an unlocatable audio file, or
    a failed copy are counted as skipped.

    Args:
        df: DataFrame holding one split of the dataset.
        split_name: Label used only in the printed summary line.

    Returns:
        A list of dicts with the keys ``audio_file`` (``wavs/<basename>``),
        ``text`` and ``speaker``.

    Raises:
        KeyError: If ``df`` lacks one of the three required columns.
    """
    required_columns = ('audio_file_path', 'transcript', 'speaker_id')
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        # Fail once with a clear message instead of one error per row.
        raise KeyError(f"{split_name} data is missing column(s): {', '.join(missing_columns)}")

    metadata = []
    audio_source_dir = os.path.join(DATASET_INPUT, "wavs")

    if not os.path.exists(audio_source_dir):
        print(f"   ‚ö†Ô∏è  Audio directory not found: {audio_source_dir}")
        print(f"   Looking in dataset input...")
        # Try to find wavs directory
        for root, dirs, files in os.walk(DATASET_INPUT):
            if 'wavs' in dirs:
                audio_source_dir = os.path.join(root, 'wavs')
                print(f"   Found at: {audio_source_dir}")
                break

    processed = 0
    skipped = 0

    rows = zip(df.index, df['audio_file_path'], df['transcript'], df['speaker_id'])
    for idx, audio_value, text_value, speaker_value in rows:
        # A missing cell would otherwise be stringified to the literal "nan".
        if pd.isna(audio_value) or pd.isna(text_value):
            skipped += 1
            continue

        # Map columns
        audio_file = str(audio_value).strip()
        text = str(text_value).strip()
        speaker = str(speaker_value).strip()

        if not audio_file or not text:
            skipped += 1
            continue

        src_audio = _resolve_audio_path(audio_file, audio_source_dir)
        if src_audio is None:
            skipped += 1
            continue

        # Files are flattened to their base name; an existing copy is reused.
        base_name = os.path.basename(audio_file)
        dst_audio = os.path.join(wavs_dir, base_name)
        try:
            if not os.path.exists(dst_audio):
                shutil.copy(src_audio, dst_audio)
        except OSError as e:
            print(f"      Error processing row {idx}: {str(e)}")
            skipped += 1
            continue

        # Create metadata entry
        metadata.append({
            'audio_file': os.path.join('wavs', base_name),
            'text': text,
            'speaker': speaker
        })
        processed += 1

    print(f"   {split_name}: {processed} processed, {skipped} skipped")
    return metadata

# Phase 6 (part 2): build the metadata for both splits and write it out as
# pipe-delimited files in DATASET_PROCESSED.
train_metadata = process_dataset(df_train, "Train")
eval_metadata = process_dataset(df_eval, "Eval")

# Stop here if a split produced no usable rows: writing an empty metadata
# file would only postpone the failure to the vocabulary/training phases.
for split_label, split_entries in (("train", train_metadata), ("eval", eval_metadata)):
    if not split_entries:
        raise RuntimeError(
            f"No usable samples in the {split_label} split - check the audio paths in the CSV"
        )

# Save metadata files
print(f"\n   Saving metadata...")
train_meta_path = os.path.join(DATASET_PROCESSED, "metadata_train.csv")
eval_meta_path = os.path.join(DATASET_PROCESSED, "metadata_eval.csv")

# Fix the column order explicitly so the written field order never depends
# on how the metadata dicts were built.
metadata_columns = ['audio_file', 'text', 'speaker']
df_train_out = pd.DataFrame(train_metadata, columns=metadata_columns)
df_eval_out = pd.DataFrame(eval_metadata, columns=metadata_columns)

# Save in pipe-delimited format (expected by training scripts)
# NOTE(review): the files are written without a header row; confirm that the
# repository's scripts do not expect an "audio_file|text|speaker_name" header.
df_train_out.to_csv(train_meta_path, sep='|', header=False, index=False)
df_eval_out.to_csv(eval_meta_path, sep='|', header=False, index=False)

print(f"   ‚úÖ metadata_train.csv: {len(train_metadata)} rows")
print(f"   ‚úÖ metadata_eval.csv: {len(eval_metadata)} rows")

print(f"\n" + "="*80)
print(f"‚úÖ PHASE 6 COMPLETE")
print("="*80)

## Cell 7: Extend Vocabulary (ByteLevel BPE)

In [None]:
import subprocess
import sys

# Phase 7: run the repository's vocabulary-extension script as a subprocess
# on the training metadata and stop the notebook if it fails or times out.
print("\n" + "=" * 80)
print("PHASE 7: EXTEND VOCABULARY FOR SINHALA")
print("=" * 80)

vocab_script = os.path.join(repo_path, "extend_vocab_sinhala.py")
train_metadata_path = train_meta_path

print("\nüìù Extending vocabulary...")
print(f"   Metadata: {train_metadata_path}")
print(f"   Output: {XTTS_MODEL_DIR}")
print(f"   Vocab size: {VOCAB_SIZE:,} tokens")
print("   Method: ByteLevel BPE")

# Same interpreter as the notebook kernel, followed by the script's options
cmd = [sys.executable, vocab_script]
cmd += ["--metadata_path", train_metadata_path]
cmd += ["--output_path", XTTS_MODEL_DIR]
cmd += ["--language", LANGUAGE_CODE]
cmd += ["--vocab_size", str(VOCAB_SIZE)]

print("\n[Running] Tokenization Training...\n")

try:
    # Output is captured and shown only once the script has finished
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)
except subprocess.TimeoutExpired:
    print("‚ùå Timeout - vocabulary extension took too long")
    raise

print(result.stdout)

if result.returncode != 0:
    print("‚ùå Error:")
    print(result.stderr)
    raise RuntimeError("Vocabulary extension failed")

print("‚úÖ Vocabulary extension SUCCESSFUL")

print("\n" + "=" * 80)
print("‚úÖ PHASE 7 COMPLETE")
print("=" * 80)

## Cell 8: GPT Fine-tuning (4-8 HOURS)

In [None]:
import subprocess
import sys

# Phase 8: run the repository's GPT fine-tuning script as a subprocess with
# the configured hyper-parameters, streaming its output into the notebook.
print("\n" + "="*80)
print("PHASE 8: GPT FINE-TUNING (MAIN TRAINING - 4-8 HOURS)")
print("="*80)

train_script = os.path.join(repo_path, "train_gpt_xtts.py")

print(f"\nüöÄ Starting GPT Fine-tuning...")
print(f"   Train data: {train_meta_path}")
print(f"   Eval data: {eval_meta_path}")
print(f"   Language: {LANGUAGE_CODE}")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"\n‚è±Ô∏è  This will take 4-8 hours...")
print(f"\n" + "="*80 + "\n")

# Build metadata string: "<train csv>,<eval csv>,<language code>"
metadata_string = f"{train_meta_path},{eval_meta_path},{LANGUAGE_CODE}"

cmd = [
    sys.executable,
    train_script,
    "--output_path", CHECKPOINTS_DIR,
    "--metadatas", metadata_string,
    "--num_epochs", str(NUM_EPOCHS),
    "--batch_size", str(BATCH_SIZE),
    "--grad_acumm", str(GRAD_ACCUM),
    "--max_text_length", "400",
    "--max_audio_length", "330750",
    "--lr", str(LEARNING_RATE),
    "--weight_decay", "1e-2",
    "--save_step", str(SAVE_STEP)
]

try:
    # Output is not captured, so training progress appears live
    result = subprocess.run(cmd, text=True)
except KeyboardInterrupt:
    # A manual stop is not treated as an error, but the phase is also not
    # reported as complete.
    print(f"\n‚ö†Ô∏è  Training interrupted")
else:
    if result.returncode != 0:
        print(f"\n‚ùå Training failed")
        # Stop the notebook here, as the vocabulary phase does on failure,
        # instead of reporting the phase as complete.
        raise RuntimeError(f"Training failed with exit code {result.returncode}")

    print(f"\n" + "="*80)
    print(f"‚úÖ TRAINING COMPLETED SUCCESSFULLY")
    print("="*80)

    print(f"\n" + "="*80)
    print(f"‚úÖ PHASE 8 COMPLETE")
    print("="*80)

## Cell 9: Generate Sample Speech

In [None]:
import torch
import torchaudio
from TTS.tts.models.xtts import Xtts
from TTS.tts.configs.xtts_config import XttsConfig
import os

# Phase 9: load the newest fine-tuned checkpoint and synthesize one Sinhala
# sentence, using a wav from the processed dataset as the reference voice.
# Failures are reported and do not raise.
print("\n" + "="*80)
print("PHASE 9: GENERATE SAMPLE SPEECH")
print("="*80)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\nüì± Using device: {device}")


def _find_latest_checkpoint(search_dir):
    """Return the most recently modified ``.pth`` file under ``search_dir``.

    The search is recursive, and the newest file is chosen by modification
    time because a plain name sort ranks ``checkpoint_9000`` after
    ``checkpoint_10000``. Returns ``None`` when no ``.pth`` file exists.
    """
    found = []
    for root, _dirs, files in os.walk(search_dir):
        found.extend(os.path.join(root, name) for name in files if name.endswith('.pth'))
    return max(found, key=os.path.getmtime, default=None)


# Find checkpoint
training_dir = os.path.join(CHECKPOINTS_DIR, "run", "training")
config_path = os.path.join(XTTS_MODEL_DIR, "config.json")
vocab_path = os.path.join(XTTS_MODEL_DIR, "vocab.json")

print(f"\nüîß Loading model...")

if os.path.exists(training_dir):
    # NOTE(review): the trainer presumably writes checkpoints into a per-run
    # subfolder of training_dir, hence the recursive search - confirm.
    checkpoint_path = _find_latest_checkpoint(training_dir)

    if checkpoint_path and os.path.exists(config_path):
        print(f"‚úÖ Found checkpoint: {os.path.relpath(checkpoint_path, training_dir)}")

        try:
            config = XttsConfig()
            config.load_json(config_path)

            model = Xtts.init_from_config(config)
            model.load_checkpoint(
                config,
                checkpoint_path=checkpoint_path,
                vocab_path=vocab_path,
                use_deepspeed=False
            )
            model.to(device)

            print(f"‚úÖ Model loaded successfully!")

            # Find reference audio (sorted so the same file is picked each run)
            wavs_dir = os.path.join(DATASET_PROCESSED, "wavs")
            audio_files = sorted(f for f in os.listdir(wavs_dir) if f.endswith('.wav')) if os.path.exists(wavs_dir) else []

            if audio_files:
                reference_audio = os.path.join(wavs_dir, audio_files[0])
                print(f"üé§ Reference audio: {audio_files[0]}")

                # Test text, written in Sinhala script
                test_text = "නිරන්තරයි ඉතා වැදගත්"
                print(f"\nüìù Test text: {test_text}")

                try:
                    print(f"\nüéµ Generating speech...")

                    # Get speaker embedding
                    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
                        audio_path=reference_audio,
                        gpt_cond_len=model.config.gpt_cond_len,
                        max_ref_length=model.config.max_ref_len,
                        sound_norm_refs=model.config.sound_norm_refs,
                    )

                    # Generate speech in the configured language
                    wav = model.inference(
                        text=test_text,
                        language=LANGUAGE_CODE,
                        gpt_cond_latent=gpt_cond_latent,
                        speaker_embedding=speaker_embedding,
                        temperature=0.7,
                        length_penalty=1.0,
                        repetition_penalty=5.0,
                        top_k=50,
                        top_p=0.85,
                    )

                    # Save audio (unsqueeze adds the leading channel dimension)
                    output_file = os.path.join(OUTPUT_DIR, "sinhala_sample.wav")
                    torchaudio.save(
                        output_file,
                        torch.tensor(wav["wav"]).unsqueeze(0),
                        24000
                    )

                    duration_sec = len(wav['wav']) / 24000
                    print(f"‚úÖ Speech generated!")
                    print(f"   File: {output_file}")
                    print(f"   Duration: {duration_sec:.2f} seconds")

                except Exception as e:
                    print(f"‚ùå Generation error: {str(e)}")
            else:
                print(f"‚ö†Ô∏è No audio files found for reference")

        except Exception as e:
            print(f"‚ùå Error loading model: {str(e)}")
    else:
        print(f"‚ö†Ô∏è Checkpoint or config not found")
else:
    print(f"‚ö†Ô∏è Training directory not found")
    print(f"   Make sure Phase 8 completed successfully")

print(f"\n" + "="*80)
print(f"‚úÖ PHASE 9 COMPLETE")
print("="*80)

## Cell 10: Summary

In [None]:
# Phase 10: print a closing summary of the pipeline. The figures come from
# the configuration variables so the summary always matches what was run.
print("\n\n" + "#"*80)
print("#" + " "*78 + "#")
print("#" + " "*15 + "‚úÖ SINHALA XTTS-v2 FINE-TUNING COMPLETE!" + " "*26 + "#")
print("#" + " "*78 + "#")
print("#"*80)

print("\n‚úÖ COMPLETED PHASES:")
phases = [
    "Environment Setup",
    "Clone Repository",
    "Setup Configuration",
    "Validate Dataset",
    "Download XTTS-v2 Model",
    "Prepare Dataset with Column Mapping",
    f"Extend Vocabulary ({VOCAB_SIZE:,} Sinhala tokens)",
    "Fine-tune GPT",
    "Generate Sample Speech"
]

for i, phase in enumerate(phases, 1):
    print(f"   {i}. ‚úÖ {phase}")

print("\nüìä MODEL SPECIFICATIONS:")
print(f"   Language: Sinhala (‡∑É‡∑í‡∂Ç‡∑Ñ‡∂Ω)")
print(f"   Language Code: {LANGUAGE_CODE}")
print(f"   Tokenization: ByteLevel BPE")
print(f"   Vocabulary: {VOCAB_SIZE:,} tokens")
print(f"   Training Epochs: {NUM_EPOCHS}")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Learning Rate: {LEARNING_RATE}")

print("\nüìÅ OUTPUT FILES:")
print(f"   Model Checkpoint: {os.path.join(CHECKPOINTS_DIR, 'run/training/')}")
print(f"   Vocabulary: {os.path.join(XTTS_MODEL_DIR, 'vocab.json')}")
print(f"   Config: {os.path.join(XTTS_MODEL_DIR, 'config.json')}")
print(f"   Sample Output: {os.path.join(OUTPUT_DIR, 'sinhala_sample.wav')}")

print("\nüéâ KEY FIXES IN THIS NOTEBOOK:")
print(f"   ‚úÖ Corrected CSV column mapping")
print(f"      From: audio_file_path | transcript | speaker_id")
print(f"      To: audio_file | text | speaker_name")
print(f"   ‚úÖ Automatic audio file discovery")
print(f"   ‚úÖ Error handling for missing files")
print(f"   ‚úÖ Proper metadata formatting")

print("\n" + "#"*80)
print("#" + " "*78 + "#")
print("#" + " "*18 + "üéµ Your Sinhala TTS Model is Ready! üéµ" + " "*22 + "#")
print("#" + " "*78 + "#")
print("#"*80 + "\n")