# Full Fine-tuning MMS-TTS Amharic Model (Ablation Study)

This notebook fine-tunes `facebook/mms-tts-amh` for legal domain Amharic text-to-speech using **full fine-tuning** (no LoRA) as an ablation study.

## Model and Approach
- **Base Model**: `facebook/mms-tts-amh` (VITS architecture, ~36.3M parameters)
- **Fine-tuning Method**: Full fine-tuning (all parameters trainable)
- **Task**: Text-to-Speech (TTS)
- **Domain**: Legal Amharic text
- **Text Format**: Romanized Amharic (using uroman package)
- **Optimization**: Memory-optimized for T4 15GB GPU


## 1. Installation and Setup


In [None]:
%pip install -q transformers datasets accelerate torchaudio librosa soundfile uroman scipy torch-audio


In [None]:
import os
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Union, Optional, Tuple
import librosa
import soundfile as sf
import scipy.io.wavfile
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

from transformers import (
    VitsModel,
    AutoTokenizer,
    TrainingArguments,
    Trainer
)

from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
import torch.nn.functional as F

# Import uroman for romanization
import uroman

# Report the PyTorch build and, if a GPU is visible, its name and memory state.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    gpu_index = 0
    gpu_name = torch.cuda.get_device_name(gpu_index)
    total_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1e9
    allocated_gb = torch.cuda.memory_allocated(gpu_index) / 1e9
    reserved_gb = torch.cuda.memory_reserved(gpu_index) / 1e9
    print(f"CUDA device: {gpu_name}")
    print(f"CUDA memory: {total_gb:.2f} GB")
    print(f"CUDA allocated: {allocated_gb:.2f} GB")
    print(f"CUDA reserved: {reserved_gb:.2f} GB")


In [None]:
# Mount Google Drive (Colab only) so the dataset paths under
# /content/drive/MyDrive used in the configuration cell are reachable.
from google.colab import drive
drive.mount('/content/drive')


## 2. Configuration (Optimized for T4 15GB)


In [None]:
# Checkpoint that this notebook fine-tunes.
MODEL_NAME = "facebook/mms-tts-amh"

# Dataset layout on the mounted Google Drive.
AUDIO_DIR = "/content/drive/MyDrive/Dataset_1.5h/audio"
TRAIN_CSV = "/content/drive/MyDrive/Dataset_1.5h/train.csv"
VAL_CSV = "/content/drive/MyDrive/Dataset_1.5h/val.csv"
TEST_CSV = "/content/drive/MyDrive/Dataset_1.5h/test.csv"

# Checkpoints, metrics and plots are written below this directory.
OUTPUT_DIR = "mms_tts_full_finetune_amharic_legal"

# Hyperparameters, sized for a T4 (15 GB) GPU.
TRAINING_ARGS = dict(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=2,   # small batches keep activation memory low
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,   # 2 x 8 -> effective batch size of 16
    learning_rate=5e-5,              # slightly lower LR for full fine-tuning
    warmup_steps=100,
    max_steps=1200,                  # ~2h15min training time
    gradient_checkpointing=True,     # CRITICAL for memory savings
    fp16=True,                       # mixed precision for memory efficiency
    eval_strategy="steps",
    eval_steps=300,
    save_strategy="steps",
    save_steps=300,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    logging_steps=50,
    report_to="none",
    push_to_hub=False,
    dataloader_num_workers=2,        # few workers to save memory
    dataloader_pin_memory=False,     # no pinned host memory, to save memory
    remove_unused_columns=False,
)

# Audio processing configuration
TARGET_SAMPLE_RATE = 16000  # MMS-TTS uses 16kHz
MAX_AUDIO_LENGTH = 10.0  # clips are cut to this many seconds (memory efficiency)

# Echo the settings that matter most for memory use and convergence.
_batch = TRAINING_ARGS['per_device_train_batch_size']
_accum = TRAINING_ARGS['gradient_accumulation_steps']
print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  Batch size: {_batch}")
print(f"  Gradient accumulation: {_accum}")
print(f"  Effective batch size: {_batch * _accum}")
print(f"  Learning rate: {TRAINING_ARGS['learning_rate']}")
print(f"  Gradient checkpointing: {TRAINING_ARGS['gradient_checkpointing']}")
print(f"  FP16: {TRAINING_ARGS['fp16']}")


## 3. Load and Prepare Data


In [None]:
def load_csv_split(csv_path, audio_dir):
    """Load one CSV split and pair every usable row with its audio file.

    Args:
        csv_path: Path to a CSV file with ``file_name`` and ``transcription``
            columns.
        audio_dir: Directory that the ``file_name`` values are relative to.

    Returns:
        A list of ``{'audio_path': str, 'transcription': str}`` dicts in CSV
        order. A row is skipped, with a printed warning, when its file name
        or transcription is missing/blank or its audio file does not exist.
    """
    df = pd.read_csv(csv_path)
    audio_root = Path(audio_dir)

    data = []
    for file_name, raw_text in zip(df['file_name'], df['transcription']):
        if pd.isna(file_name):
            print("Warning: Row without file_name skipped")
            continue
        audio_path = audio_root / str(file_name)

        # pandas reads empty cells as NaN; str(NaN) would otherwise turn a
        # missing transcription into the literal training text "nan".
        transcription = "" if pd.isna(raw_text) else str(raw_text).strip()
        if not transcription:
            print(f"Warning: Empty transcription skipped: {audio_path}")
            continue

        if not audio_path.exists():
            print(f"Warning: Audio file not found: {audio_path}")
            continue

        data.append({
            'audio_path': str(audio_path),
            'transcription': transcription,
        })

    return data

# Read the three splits in train / validation / test order.
train_data, val_data, test_data = (
    load_csv_split(split_csv, AUDIO_DIR)
    for split_csv in (TRAIN_CSV, VAL_CSV, TEST_CSV)
)

print(f"Train samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print(f"Test samples: {len(test_data)}")
print(f"\nTotal samples: {sum(len(split) for split in (train_data, val_data, test_data))}")


In [None]:
# Visualization setup
import matplotlib.pyplot as plt
import matplotlib

# Prefer the seaborn dark-grid look. The style is tried under both of its
# names and the first one this matplotlib installation knows wins; an
# unknown style name surfaces as OSError, which is the only error that
# should trigger the fallback (a bare `except:` would also hide
# KeyboardInterrupt and genuine bugs).
for _style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    try:
        matplotlib.style.use(_style_name)
    except OSError:
        continue
    break
else:
    # Neither seaborn variant is available
    plt.style.use('default')

plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

print("Visualization libraries loaded successfully!")


## 4. Romanize Text Using Uroman


In [None]:
# Build the shared romanizer. It stays None when construction fails, in
# which case romanize_text() passes text through unchanged.
uroman_obj = None
try:
    uroman_obj = uroman.Uroman()
except Exception as init_error:
    print(f"Error initializing uroman: {init_error}")
else:
    print("Uroman initialized successfully")

def romanize_text(text):
    """Romanize Ge'ez-script Amharic text with the shared uroman instance.

    Returns the romanized string with surrounding whitespace stripped. The
    input is returned unchanged when no romanizer is available
    (``uroman_obj`` is ``None``) or when romanization raises; in the latter
    case the error is printed as well.
    """
    if uroman_obj is not None:
        try:
            return uroman_obj.romanize_string(text).strip()
        except Exception as err:
            print(f"Error romanizing text: {text[:50]}... Error: {err}")

    # No romanizer, or it failed: hand the original text back
    return text

# Test romanization on a sample: print the first training transcription
# before and after romanize_text() so the conversion can be checked by eye.
sample_text = train_data[0]['transcription']
print(f"Original (Ge'ez): {sample_text}")
romanized_sample = romanize_text(sample_text)
print(f"Romanized: {romanized_sample}")


In [None]:
# Create each tracking list only if it does not exist yet, so re-running
# this cell never discards values that were already recorded.
try:
    learning_rates
except NameError:
    learning_rates = []

try:
    gpu_memory_usage
except NameError:
    gpu_memory_usage = []

try:
    train_steps
except NameError:
    train_steps = []

try:
    val_steps
except NameError:
    val_steps = []

print("Metrics tracking initialized!")


In [None]:
# Save training metrics to a JSON file for later analysis.
#
# Every tracked variable is looked up defensively, so this cell also works
# when it is run before (or without) the training cell; missing histories
# are written as empty lists.
import json
from datetime import datetime


def _tracked_list(name):
    """Return the notebook-level sequence *name*, or [] if undefined or empty."""
    values = globals().get(name)
    if values is None or len(values) == 0:
        return []
    return values


def _loss_stats(losses):
    """Summarise a loss history as plain floats; {} for an empty history."""
    if len(losses) == 0:
        return {}
    return {
        'initial': float(losses[0]),
        'final': float(losses[-1]),
        'best': float(min(losses)),
        'worst': float(max(losses)),
        'mean': float(np.mean(losses)),
        'std': float(np.std(losses)),
    }


# Histories recorded during training (empty when nothing was recorded)
train_losses_safe = _tracked_list('train_losses')
val_losses_safe = _tracked_list('val_losses')
learning_rates_safe = _tracked_list('learning_rates')
gpu_memory_usage_safe = _tracked_list('gpu_memory_usage')
train_steps_safe = _tracked_list('train_steps')
val_steps_safe = _tracked_list('val_steps')
global_step_safe = int(globals().get('global_step', 0))

# An untouched "infinity" placeholder means no validation result exists yet
_best_val_loss = globals().get('best_val_loss')
if _best_val_loss is None or _best_val_loss == float('inf'):
    best_val_loss_safe = None
else:
    best_val_loss_safe = float(_best_val_loss)

# Reconstruct step numbers from the logging cadence if not explicitly tracked
if len(train_steps_safe) == 0 and len(train_losses_safe) > 0:
    train_steps_safe = [i * TRAINING_ARGS['logging_steps'] for i in range(len(train_losses_safe))]

if len(val_steps_safe) == 0 and len(val_losses_safe) > 0:
    val_steps_safe = [i * TRAINING_ARGS['eval_steps'] for i in range(len(val_losses_safe))]

# Calculate statistics
train_stats = _loss_stats(train_losses_safe)
val_stats = _loss_stats(val_losses_safe)

# Plain-Python copies of every series so json.dump never meets NumPy/tensor scalars
_metric_series = {
    'train_losses': [float(x) for x in train_losses_safe],
    'val_losses': [float(x) for x in val_losses_safe],
    'learning_rates': [float(x) for x in learning_rates_safe],
    'gpu_memory_usage': [float(x) for x in gpu_memory_usage_safe],
    'train_steps': [int(x) for x in train_steps_safe],
    'val_steps': [int(x) for x in val_steps_safe],
}

_enhanced_config = {
    'max_steps': TRAINING_ARGS['max_steps'],
    'batch_size': TRAINING_ARGS['per_device_train_batch_size'],
    'gradient_accumulation_steps': TRAINING_ARGS['gradient_accumulation_steps'],
    'effective_batch_size': TRAINING_ARGS['per_device_train_batch_size'] * TRAINING_ARGS['gradient_accumulation_steps'],
    'learning_rate': TRAINING_ARGS['learning_rate'],
    'fp16': TRAINING_ARGS['fp16'],
    'gradient_checkpointing': TRAINING_ARGS['gradient_checkpointing'],
    'logging_steps': TRAINING_ARGS['logging_steps'],
    'eval_steps': TRAINING_ARGS['eval_steps'],
}

# Flat structure, only written if the enhanced structure cannot be saved
metrics_data = {
    'training_config': {
        key: value for key, value in _enhanced_config.items()
        if key != 'effective_batch_size'
    },
    **_metric_series,
    'final_step': global_step_safe,
    'best_val_loss': best_val_loss_safe,
}

# Enhanced structure with metadata and summary statistics
enhanced_metrics_data = {
    'metadata': {
        'saved_at': datetime.now().isoformat(),
        'model_name': MODEL_NAME,
        'output_dir': OUTPUT_DIR,
    },
    'training_config': _enhanced_config,
    'training_progress': {
        'final_step': global_step_safe,
        'total_steps_completed': global_step_safe,
        'best_val_loss': best_val_loss_safe,
    },
    'metrics': _metric_series,
    'statistics': {
        'training_loss': train_stats,
        'validation_loss': val_stats,
    }
}

metrics_file = f'{OUTPUT_DIR}/training_metrics.json'
os.makedirs(OUTPUT_DIR, exist_ok=True)

try:
    with open(metrics_file, 'w') as f:
        json.dump(enhanced_metrics_data, f, indent=2)

    print(f"✓ Training metrics saved to: {metrics_file}")
    print(f"\nMetrics Summary:")
    print(f"  - Training loss points: {len(train_losses_safe)}")
    print(f"  - Validation loss points: {len(val_losses_safe)}")
    print(f"  - Learning rate points: {len(learning_rates_safe)}")
    print(f"  - GPU memory points: {len(gpu_memory_usage_safe)}")
    print(f"  - Final step: {global_step_safe}")

    if train_stats:
        print(f"\nTraining Loss Statistics:")
        print(f"  - Initial: {train_stats['initial']:.4f}")
        print(f"  - Final: {train_stats['final']:.4f}")
        print(f"  - Best: {train_stats['best']:.4f}")
        print(f"  - Mean: {train_stats['mean']:.4f} ± {train_stats['std']:.4f}")

    if val_stats:
        print(f"\nValidation Loss Statistics:")
        print(f"  - Initial: {val_stats['initial']:.4f}")
        print(f"  - Final: {val_stats['final']:.4f}")
        print(f"  - Best: {val_stats['best']:.4f}")
        print(f"  - Mean: {val_stats['mean']:.4f} ± {val_stats['std']:.4f}")

except (OSError, TypeError, ValueError) as e:
    # OSError: file cannot be written; TypeError/ValueError: payload not serialisable
    print(f"⚠ Error saving enhanced metrics: {e}")
    print("Falling back to original metrics structure...")
    try:
        with open(metrics_file, 'w') as f:
            json.dump(metrics_data, f, indent=2)
        print(f"✓ Basic metrics saved to: {metrics_file}")
    except (OSError, TypeError, ValueError) as e2:
        print(f"✗ Failed to save metrics: {e2}")


## Real-time Training Progress (Run this cell during training to see live updates)


In [None]:
# Snapshot of training progress, drawn from whatever metrics have been
# recorded so far: the training loss on its own, and training vs validation.
_has_train_history = 'train_losses' in locals() and len(train_losses) > 0

if not _has_train_history:
    print("Training hasn't started yet or no metrics available.")
    print("Run this cell after training has logged some steps.")
else:
    _has_val_history = 'val_losses' in locals() and len(val_losses) > 0
    _log_every = TRAINING_ARGS['logging_steps']

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    ax1, ax2 = axes

    # Left panel: training loss, one point per logging interval
    train_step_values = [_log_every * i for i in range(len(train_losses))]
    ax1.plot(train_step_values, train_losses, 'b-', linewidth=2, marker='o', markersize=3, alpha=0.7)
    ax1.set_xlabel('Training Step', fontsize=11)
    ax1.set_ylabel('Training Loss', fontsize=11)
    ax1.set_title('Training Loss (Live)', fontsize=13, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(left=0)

    # Call out the most recent point
    current_step, current_loss = train_step_values[-1], train_losses[-1]
    ax1.annotate(
        f'Current: {current_loss:.4f}',
        xy=(current_step, current_loss),
        xytext=(10, 10),
        textcoords='offset points',
        bbox={'boxstyle': 'round,pad=0.3', 'facecolor': 'yellow', 'alpha': 0.7},
        arrowprops={'arrowstyle': '->', 'connectionstyle': 'arc3,rad=0'},
    )

    # Right panel: training and validation loss on shared axes
    ax2.plot(train_step_values, train_losses, 'b-', label='Training', linewidth=2, alpha=0.7)
    if _has_val_history:
        _eval_every = TRAINING_ARGS['eval_steps']
        val_step_values = [_eval_every * i for i in range(len(val_losses))]
        ax2.plot(val_step_values, val_losses, 'r-', label='Validation', linewidth=2, marker='s', markersize=3, alpha=0.7)

    ax2.set_xlabel('Training Step', fontsize=11)
    ax2.set_ylabel('Loss', fontsize=11)
    ax2.set_title('Training vs Validation Loss (Live)', fontsize=13, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(left=0)

    plt.tight_layout()
    plt.show()

    # Text summary of the latest values
    print(f"\nCurrent Training Status:")
    print(f"  Steps completed: {len(train_losses) * TRAINING_ARGS['logging_steps']}")
    print(f"  Latest training loss: {train_losses[-1]:.4f}")
    if _has_val_history:
        print(f"  Latest validation loss: {val_losses[-1]:.4f}")
    if 'learning_rates' in locals() and len(learning_rates) > 0:
        print(f"  Current learning rate: {learning_rates[-1]:.2e}")


In [None]:
# Create comprehensive training visualizations
print("Generating training visualizations...")

# Read the tracked histories defensively: when this cell runs before (or
# without) the training cell these names do not exist yet, and the plots
# fall back to their "no data" variants instead of raising NameError.
_ns = globals()
_train_hist = _ns.get('train_losses', [])
_val_hist = _ns.get('val_losses', [])
_lr_hist = _ns.get('learning_rates', [])
_total_steps = _ns.get('global_step', 0)

# plt.savefig() does not create missing directories
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _interval_steps(count, interval):
    """Step numbers for *count* values recorded every *interval* steps, starting at 0."""
    return [i * interval for i in range(count)]


def _plot_loss_history(ax, linewidth, alpha, with_markers):
    """Draw the recorded training/validation losses on *ax*.

    Returns True when at least one curve was drawn, so the caller only adds
    a legend when there is something to label.
    """
    series = (
        (_train_hist, TRAINING_ARGS['logging_steps'], 'b-', 'Training Loss', 'o'),
        (_val_hist, TRAINING_ARGS['eval_steps'], 'r-', 'Validation Loss', 's'),
    )
    drawn = False
    for values, interval, fmt, label, marker in series:
        if len(values) == 0:
            continue
        marker_style = {'marker': marker, 'markersize': 4} if with_markers else {}
        ax.plot(_interval_steps(len(values), interval), values, fmt,
                label=label, linewidth=linewidth, alpha=alpha, **marker_style)
        drawn = True
    return drawn


def _plot_lr_schedule(ax, linewidth, alpha, with_markers):
    """Draw the recorded learning rates, or a theoretical cosine schedule if none.

    Returns True when recorded values were plotted, False for the
    theoretical curve (which ignores warmup).
    """
    if len(_lr_hist) > 0:
        marker_style = {'marker': 'o', 'markersize': 4} if with_markers else {}
        ax.plot(_interval_steps(len(_lr_hist), TRAINING_ARGS['logging_steps']), _lr_hist, 'g-',
                linewidth=linewidth, alpha=alpha, **marker_style)
        return True

    max_steps = TRAINING_ARGS['max_steps']
    steps = np.arange(0, max_steps, 50)
    lrs = TRAINING_ARGS['learning_rate'] * (1 + np.cos(np.pi * steps / max_steps)) / 2
    ax.plot(steps, lrs, 'g-', linewidth=linewidth, alpha=alpha)
    return False


# ---- Combined 2x2 overview figure ----
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

# 1. Training and Validation Loss
if _plot_loss_history(ax1, linewidth=2, alpha=0.7, with_markers=False):
    ax1.legend(fontsize=11)
ax1.set_xlabel('Training Step', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_xlim(left=0)

# 2. Learning Rate Schedule
if _plot_lr_schedule(ax2, linewidth=2, alpha=0.7, with_markers=False):
    ax2.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
else:
    ax2.set_title('Learning Rate Schedule (Theoretical)', fontsize=14, fontweight='bold')
ax2.set_xlabel('Training Step', fontsize=12)
ax2.set_ylabel('Learning Rate', fontsize=12)
ax2.grid(True, alpha=0.3)
ax2.set_yscale('log')
ax2.set_xlim(left=0)

# 3. Loss Comparison (Training vs Validation)
# NOTE(review): the two curves are aligned by list index, so point i pairs
# the i-th logged training loss with the i-th validation loss even though
# they are recorded at different step intervals.
if len(_train_hist) > 0 and len(_val_hist) > 0:
    _n_common = min(len(_train_hist), len(_val_hist))
    _x_range = range(_n_common)
    ax3.plot(_x_range, _train_hist[:_n_common], 'b-', label='Training', linewidth=2, alpha=0.7)
    ax3.plot(_x_range, _val_hist[:_n_common], 'r-', label='Validation', linewidth=2, alpha=0.7)
    ax3.set_xlabel('Logging Interval', fontsize=12)
    ax3.set_ylabel('Loss', fontsize=12)
    ax3.set_title('Loss Comparison (Normalized)', fontsize=14, fontweight='bold')
    ax3.legend(fontsize=11)
    ax3.grid(True, alpha=0.3)
else:
    ax3.text(0.5, 0.5, 'Insufficient data for comparison',
             ha='center', va='center', transform=ax3.transAxes, fontsize=12)
    ax3.set_title('Loss Comparison', fontsize=14, fontweight='bold')

# 4. Training Statistics Summary (text panel)
ax4.axis('off')

stats_text = [
    "Training Statistics Summary",
    "=" * 40,
    f"Total Training Steps: {_total_steps}",
    f"Max Steps Configured: {TRAINING_ARGS['max_steps']}",
]

for _heading, _history in (("Training Loss", _train_hist), ("Validation Loss", _val_hist)):
    if len(_history) > 0:
        stats_text.append(f"\n{_heading}:")
        stats_text.append(f"  Initial: {_history[0]:.4f}")
        stats_text.append(f"  Final: {_history[-1]:.4f}")
        stats_text.append(f"  Best: {min(_history):.4f}")
        stats_text.append(f"  Average: {np.mean(_history):.4f}")

stats_text.append(f"\nConfiguration:")
stats_text.append(f"  Batch Size: {TRAINING_ARGS['per_device_train_batch_size']}")
stats_text.append(f"  Gradient Accumulation: {TRAINING_ARGS['gradient_accumulation_steps']}")
stats_text.append(f"  Effective Batch Size: {TRAINING_ARGS['per_device_train_batch_size'] * TRAINING_ARGS['gradient_accumulation_steps']}")
stats_text.append(f"  Learning Rate: {TRAINING_ARGS['learning_rate']}")
stats_text.append(f"  FP16: {TRAINING_ARGS['fp16']}")
stats_text.append(f"  Gradient Checkpointing: {TRAINING_ARGS['gradient_checkpointing']}")

stats_str = "\n".join(stats_text)
ax4.text(0.1, 0.95, stats_str, transform=ax4.transAxes,
         fontsize=10, verticalalignment='top', family='monospace',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/training_visualizations.png', dpi=300, bbox_inches='tight')
print(f"Visualizations saved to: {OUTPUT_DIR}/training_visualizations.png")
plt.show()

# ---- Individual high-quality plots for the report ----
print("\nGenerating individual plots for report...")

# Individual plot 1: Training and Validation Loss
fig1, ax1 = plt.subplots(figsize=(10, 6))
if _plot_loss_history(ax1, linewidth=2.5, alpha=0.8, with_markers=True):
    ax1.legend(fontsize=12, loc='best')
ax1.set_xlabel('Training Step', fontsize=13, fontweight='bold')
ax1.set_ylabel('Loss', fontsize=13, fontweight='bold')
ax1.set_title('Training and Validation Loss Over Time', fontsize=15, fontweight='bold')
ax1.grid(True, alpha=0.3, linestyle='--')
ax1.set_xlim(left=0)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/loss_plot.png', dpi=300, bbox_inches='tight')
print(f"  Saved: {OUTPUT_DIR}/loss_plot.png")
plt.show()

# Individual plot 2: Learning Rate Schedule
fig2, ax2 = plt.subplots(figsize=(10, 6))
_plot_lr_schedule(ax2, linewidth=2.5, alpha=0.8, with_markers=True)
ax2.set_xlabel('Training Step', fontsize=13, fontweight='bold')
ax2.set_ylabel('Learning Rate', fontsize=13, fontweight='bold')
ax2.set_title('Learning Rate Schedule (Cosine Annealing)', fontsize=15, fontweight='bold')
ax2.grid(True, alpha=0.3, linestyle='--')
ax2.set_yscale('log')
ax2.set_xlim(left=0)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/learning_rate_plot.png', dpi=300, bbox_inches='tight')
print(f"  Saved: {OUTPUT_DIR}/learning_rate_plot.png")
plt.show()

print("\nAll visualizations generated successfully!")


In [None]:
# Add a 'transcription_romanized' field to every item of every split.
for _banner, _bar_label, _split in (
    ("Romanizing training data...", "Train", train_data),
    ("Romanizing validation data...", "Val", val_data),
    ("Romanizing test data...", "Test", test_data),
):
    print(_banner)
    for item in tqdm(_split, desc=_bar_label):
        item['transcription_romanized'] = romanize_text(item['transcription'])

print("\nRomanization complete!")
print(f"Sample - Original: {train_data[0]['transcription']}")
print(f"Sample - Romanized: {train_data[0]['transcription_romanized']}")


## 5. Load Model and Tokenizer


In [None]:
# Fetch the pretrained tokenizer and weights for MODEL_NAME.
print(f"Loading model: {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = VitsModel.from_pretrained(MODEL_NAME)

_total_params = sum(p.numel() for p in model.parameters())
_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model loaded successfully!")
print(f"Sampling rate: {model.config.sampling_rate} Hz")
print(f"Model parameters: {_total_params / 1e6:.2f}M")
print(f"Trainable parameters: {_trainable_params / 1e6:.2f}M")

# Turn on gradient checkpointing when the model exposes the hook
_enable_checkpointing = getattr(model, 'gradient_checkpointing_enable', None)
if _enable_checkpointing is not None:
    _enable_checkpointing()
    print("Gradient checkpointing enabled")

# Move to the GPU and report how much memory the weights occupy
if torch.cuda.is_available():
    model = model.to("cuda")
    print("Model moved to CUDA")
    print(f"Memory after model load: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")


## 6. Test Inference (Before Training)


In [None]:
def synthesize_speech(model, tokenizer, text_romanized, output_path):
    """Synthesize speech for romanized text and write it to a WAV file.

    Inference always runs in eval mode with gradients disabled; the model's
    previous train/eval mode is restored afterwards, so the function is safe
    to call between training steps.

    Args:
        model: Module whose forward pass returns an object with a
            ``waveform`` tensor and whose ``config.sampling_rate`` gives the
            output sample rate.
        tokenizer: Callable turning text into a mapping of tensors that is
            passed to the model as keyword arguments.
        text_romanized: Romanized input text.
        output_path: Destination path of the WAV file.

    Returns:
        ``output_path``, unchanged.
    """
    inputs = tokenizer(text_romanized, return_tensors="pt")

    # Follow the model's actual device instead of assuming "CUDA whenever it
    # is available", which breaks for a model that was left on the CPU.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Dropout and other train-only behaviour must be off while synthesizing
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            waveform = model(**inputs).waveform
    finally:
        model.train(was_training)

    # float32 on the CPU is what the WAV writer can store
    samples = waveform.detach().cpu().float().numpy()

    # Save audio (transposed so that samples run along the first axis)
    scipy.io.wavfile.write(
        output_path,
        rate=model.config.sampling_rate,
        data=samples.T
    )

    return output_path

# Smoke test: synthesize the first (already romanized) training sample.
if not len(train_data) > 0 or 'transcription_romanized' not in train_data[0]:
    print("Please run romanization cells first!")
else:
    test_text_romanized = train_data[0]['transcription_romanized']
    test_output = "test_synthesized_before_training.wav"

    print(f"Testing synthesis with: {test_text_romanized}")
    synthesize_speech(model, tokenizer, test_text_romanized, test_output)
    print(f"Test audio saved to: {test_output}")


## 7. Implement VITS Training Loss Function


In [None]:
def compute_mel_spectrogram(waveform, sample_rate=16000, n_mels=80, n_fft=1024, hop_length=256):
    """Compute a log-scaled (dB) mel-spectrogram of a waveform with librosa.

    Args:
        waveform: ``torch.Tensor`` or NumPy array. Input with more than one
            dimension is collapsed first: a leading axis of size 1 is
            dropped, otherwise the leading axis is averaged.
        sample_rate: Sample rate in Hz; the mel filters span 0 Hz to
            ``sample_rate // 2``.
        n_mels: Number of mel bands.
        n_fft: FFT size.
        hop_length: Hop between frames, in samples.

    Returns:
        A float32 ``torch.Tensor`` holding the spectrogram in dB relative to
        its own maximum. It is rebuilt from NumPy data, so it carries no
        autograd history: gradients cannot flow through this function.
    """
    # Convert to numpy if tensor. detach() is required because numpy() raises
    # for tensors that are part of an autograd graph.
    if isinstance(waveform, torch.Tensor):
        waveform_np = waveform.detach().cpu().numpy()
    else:
        waveform_np = waveform

    # Handle multi-channel audio
    if len(waveform_np.shape) > 1:
        waveform_np = waveform_np[0] if waveform_np.shape[0] == 1 else waveform_np.mean(axis=0)

    # Compute mel-spectrogram using librosa
    mel_spec = librosa.feature.melspectrogram(
        y=waveform_np,
        sr=sample_rate,
        n_mels=n_mels,
        n_fft=n_fft,
        hop_length=hop_length,
        fmin=0,
        fmax=sample_rate // 2
    )

    # Convert to log scale
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    return torch.tensor(mel_spec_db, dtype=torch.float32)

def _slaney_mel_filterbank(sample_rate, n_fft, n_mels, device):
    """Build a (n_mels, n_fft // 2 + 1) area-normalised triangular mel filterbank (Slaney scale)."""
    import math

    # Slaney mel scale: linear below 1 kHz, logarithmic above
    f_sp = 200.0 / 3.0
    min_log_hz = 1000.0
    min_log_mel = min_log_hz / f_sp
    logstep = math.log(6.4) / 27.0

    f_max = float(sample_rate // 2)
    if f_max >= min_log_hz:
        mel_max = min_log_mel + math.log(f_max / min_log_hz) / logstep
    else:
        mel_max = f_max / f_sp

    # Band edges: equally spaced in mel, converted back to Hz
    mels = torch.linspace(0.0, mel_max, n_mels + 2, dtype=torch.float64)
    band_edges = torch.where(
        mels >= min_log_mel,
        min_log_hz * torch.exp(logstep * (mels - min_log_mel)),
        f_sp * mels,
    )
    fft_freqs = torch.linspace(0.0, sample_rate / 2.0, n_fft // 2 + 1, dtype=torch.float64)

    # Triangle i rises from edge i to edge i+1 and falls to edge i+2
    widths = band_edges[1:] - band_edges[:-1]
    ramps = band_edges.unsqueeze(1) - fft_freqs.unsqueeze(0)
    rising = -ramps[:-2] / widths[:-1].unsqueeze(1)
    falling = ramps[2:] / widths[1:].unsqueeze(1)
    weights = torch.clamp(torch.minimum(rising, falling), min=0.0)

    # Scale every filter to equal area
    weights = weights * (2.0 / (band_edges[2:] - band_edges[:-2])).unsqueeze(1)
    return weights.to(device=device, dtype=torch.float32)


def _log_mel_db(waveform, sample_rate, n_mels=80, n_fft=1024, hop_length=256):
    """Differentiable log-mel spectrogram in dB relative to each item's peak, floored at -80 dB."""
    # Treat every leading dimension as a batch of independent signals
    wav = waveform.reshape(-1, waveform.shape[-1]).float()
    window = torch.hann_window(n_fft, device=wav.device)
    spec = torch.stft(
        wav,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=True,
        pad_mode="constant",
        return_complex=True,
    )
    # |X|^2 from real/imaginary parts keeps the gradient finite at zero magnitude
    power = torch.view_as_real(spec).pow(2).sum(dim=-1)
    mel = torch.matmul(_slaney_mel_filterbank(sample_rate, n_fft, n_mels, wav.device), power)

    mel_db = 10.0 * torch.log10(torch.clamp(mel, min=1e-10))
    peak_db = mel_db.amax(dim=(-2, -1), keepdim=True)
    return torch.clamp(mel_db - peak_db, min=-80.0)


def vits_training_loss(predicted_waveform, target_waveform, sample_rate=16000):
    """Compute the waveform + mel-spectrogram training loss.

    Both signals are truncated to the shorter length along the last axis
    (no padding), then compared with an L1 loss on the raw samples and an L1
    loss on their log-mel spectrograms. The mel term is computed entirely
    with torch operations, so gradients flow through it and the training and
    evaluation values are computed the same way.

    Args:
        predicted_waveform: Model output; the last axis is time.
        target_waveform: Reference audio; the last axis is time.
        sample_rate: Sample rate in Hz used for the mel filterbank.

    Returns:
        ``(total_loss, waveform_loss, mel_loss)`` with
        ``total_loss = waveform_loss + 0.5 * mel_loss``.
    """
    # Ensure same length by dropping the tail of the longer signal
    min_len = min(predicted_waveform.shape[-1], target_waveform.shape[-1])
    pred = predicted_waveform[..., :min_len]
    target = target_waveform[..., :min_len]

    # Waveform L1 loss
    waveform_loss = F.l1_loss(pred, target)

    # Mel-spectrogram loss
    try:
        # Power spectra exceed the float16 range, so keep this part in
        # float32 even when the caller runs under mixed precision.
        with torch.autocast(device_type=pred.device.type, enabled=False):
            pred_mel = _log_mel_db(pred, sample_rate)
            target_mel = _log_mel_db(target, sample_rate)
            mel_loss = F.l1_loss(pred_mel, target_mel)
    except RuntimeError as e:
        # Fallback to waveform loss only if mel computation fails
        print(f"Warning: Mel-spectrogram computation failed: {e}")
        mel_loss = torch.tensor(0.0, device=waveform_loss.device)

    # Combined loss (weighted)
    total_loss = waveform_loss + 0.5 * mel_loss

    return total_loss, waveform_loss, mel_loss


In [None]:
class VITSDataset:
    """Map-style dataset that pairs romanized transcriptions with audio clips.

    Each entry of ``data_list`` is expected to provide the keys
    ``'audio_path'`` and ``'transcription_romanized'``.
    """

    def __init__(self, data_list, tokenizer, target_sr=16000, max_length=10.0):
        # Nothing is loaded here; audio and tokens are produced per item.
        self.max_length = max_length
        self.target_sr = target_sr
        self.tokenizer = tokenizer
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load, resample and tokenize the example at position ``idx``.

        Returns a dict with ``'input_ids'``, ``'attention_mask'``, ``'audio'``
        (float32 tensor) and ``'text'`` (the romanized transcription).
        """
        record = self.data_list[idx]

        # Audio is resampled to target_sr; max_length is passed as the
        # ``duration`` limit of librosa.load.
        waveform, _ = librosa.load(
            record['audio_path'],
            sr=self.target_sr,
            duration=self.max_length,
        )

        # Reduce any multi-channel array to a single channel
        if waveform.ndim > 1:
            if waveform.shape[0] == 1:
                waveform = waveform[0]
            else:
                waveform = waveform.mean(axis=0)

        audio_tensor = torch.tensor(waveform, dtype=torch.float32)

        # Fixed-size tokenization: padded/truncated to 512 tokens
        romanized = record['transcription_romanized']
        encoding = self.tokenizer(
            romanized,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        # squeeze(0) drops the batch axis added by return_tensors="pt"
        sample = {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'audio': audio_tensor,
            'text': romanized,
        }
        return sample

# Create datasets
train_dataset = VITSDataset(train_data, tokenizer, target_sr=TARGET_SAMPLE_RATE, max_length=MAX_AUDIO_LENGTH)
val_dataset = VITSDataset(val_data, tokenizer, target_sr=TARGET_SAMPLE_RATE, max_length=MAX_AUDIO_LENGTH)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Fetch the first sample once: every train_dataset[0] access reloads the
# audio file and re-tokenizes the text.
sample_item = train_dataset[0]
print(f"\nSample item keys: {list(sample_item.keys())}")
print(f"Sample audio shape: {sample_item['audio'].shape}")
print(f"Sample input_ids shape: {sample_item['input_ids'].shape}")


## 9. Training Loop (Memory-Optimized)


In [None]:
# Setup training: put the model in train mode and build the scaler,
# optimizer, LR schedule and data loaders used by the training loop.
model.train()

# Mixed precision: a GradScaler exists only when FP16 training is requested
if TRAINING_ARGS['fp16']:
    scaler = torch.cuda.amp.GradScaler()
else:
    scaler = None

# AdamW over all model parameters
optimizer = AdamW(
    model.parameters(),
    lr=TRAINING_ARGS['learning_rate'],
    weight_decay=0.01,
    betas=(0.9, 0.999),
)

# Cosine annealing over the configured number of optimizer steps
scheduler = CosineAnnealingLR(optimizer, T_max=TRAINING_ARGS['max_steps'])

# Worker / pinning settings shared by both loaders
_loader_settings = {
    'num_workers': TRAINING_ARGS['dataloader_num_workers'],
    'pin_memory': TRAINING_ARGS['dataloader_pin_memory'],
}

# Training batches are shuffled and the last incomplete batch is dropped
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAINING_ARGS['per_device_train_batch_size'],
    shuffle=True,
    drop_last=True,
    **_loader_settings,
)

# Validation keeps dataset order and every sample
val_loader = DataLoader(
    val_dataset,
    batch_size=TRAINING_ARGS['per_device_eval_batch_size'],
    shuffle=False,
    **_loader_settings,
)

print("Training setup complete!")
print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Optimizer: {type(optimizer).__name__}")
print(f"Scheduler: {type(scheduler).__name__}")
print(f"FP16 scaler: {scaler is not None}")


In [None]:
# Training loop
#
# Runs optimizer updates until `max_steps` is reached, with gradient
# accumulation, optional FP16 autocast/GradScaler, periodic logging,
# validation, best-model tracking and checkpointing. The FP16 and FP32 paths
# share one code path so both evaluate and save checkpoints.
max_steps = TRAINING_ARGS['max_steps']
gradient_accumulation_steps = TRAINING_ARGS['gradient_accumulation_steps']
logging_steps = TRAINING_ARGS['logging_steps']
save_steps = TRAINING_ARGS['save_steps']
eval_steps = TRAINING_ARGS['eval_steps']

# One device string and one AMP switch shared by training and evaluation
train_device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_amp = bool(TRAINING_ARGS['fp16']) and scaler is not None

step = 0          # completed epochs
global_step = 0   # optimizer updates performed
micro_step = 0    # micro-batches processed, counted across epochs
best_val_loss = float('inf')
accumulated_loss = 0.0
train_losses = []
val_losses = []


def _save_checkpoint(checkpoint_path):
    """Write model and tokenizer to `checkpoint_path`, creating it if needed."""
    os.makedirs(checkpoint_path, exist_ok=True)
    model.save_pretrained(checkpoint_path)
    tokenizer.save_pretrained(checkpoint_path)


def _compute_loss(batch):
    """Run one forward pass on `batch` and return the total training loss."""
    input_ids = batch['input_ids'].to(train_device)
    attention_mask = batch['attention_mask'].to(train_device)
    target_audio = batch['audio'].to(train_device)

    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # The target is passed as (batch, samples), the layout documented for
        # VitsModel's `waveform` output. Adding a channel axis to the target
        # would make the L1 loss broadcast to (batch, batch, samples).
        loss, _, _ = vits_training_loss(
            outputs.waveform,
            target_audio,
            sample_rate=TARGET_SAMPLE_RATE
        )
    return loss


def _evaluate():
    """Return the mean loss over `val_loader` (0.0 if it yields no batches)."""
    model.eval()
    val_loss = 0.0
    val_batches = 0
    with torch.no_grad():
        for val_batch in val_loader:
            val_loss += _compute_loss(val_batch).item()
            val_batches += 1
    model.train()
    return val_loss / val_batches if val_batches > 0 else 0.0


print("Starting training...")
print(f"Max steps: {max_steps}")
print(f"Gradient accumulation: {gradient_accumulation_steps}")
print(f"Effective batch size: {TRAINING_ARGS['per_device_train_batch_size'] * gradient_accumulation_steps}")
print(f"Learning rate: {TRAINING_ARGS['learning_rate']}")
print()

# An empty loader would never advance global_step and loop forever
if len(train_loader) == 0:
    raise ValueError("train_loader yields no batches; cannot train")

# Gradients are cleared once here and then only after each optimizer update,
# so the micro-batches in between actually accumulate.
optimizer.zero_grad()

while global_step < max_steps:
    model.train()

    for batch in tqdm(train_loader, desc=f"Epoch {step + 1}"):
        # Scale so the accumulated gradient is the mean over micro-batches
        loss = _compute_loss(batch) / gradient_accumulation_steps

        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Summing the scaled losses over one accumulation window yields the
        # mean micro-batch loss of that optimizer update.
        accumulated_loss += loss.item()

        # Counted across epochs: a short epoch cannot stall the accumulation
        micro_step += 1
        if micro_step % gradient_accumulation_steps != 0:
            continue

        # Gradient clipping (on unscaled gradients when AMP is active)
        if use_amp:
            scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Optimizer step
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        global_step += 1

        # Logging
        if global_step % logging_steps == 0:
            avg_loss = accumulated_loss / logging_steps
            current_lr = scheduler.get_last_lr()[0]
            print(f"\nStep {global_step}/{max_steps}")
            print(f"  Loss: {avg_loss:.4f}")
            print(f"  Learning rate: {current_lr:.2e}")
            if torch.cuda.is_available():
                print(f"  GPU Memory: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
            train_losses.append(avg_loss)
            accumulated_loss = 0.0

        checkpoint_path = f"{OUTPUT_DIR}/checkpoint-{global_step}"
        saved_this_step = False

        # Evaluation
        if global_step % eval_steps == 0:
            avg_val_loss = _evaluate()
            val_losses.append(avg_val_loss)
            print(f"  Validation Loss: {avg_val_loss:.4f}")

            # Save best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                _save_checkpoint(checkpoint_path)
                saved_this_step = True
                print(f"  Saved best model to {checkpoint_path}")

        # Periodic checkpoint; skipped if the same path was just written
        if global_step % save_steps == 0 and not saved_this_step:
            _save_checkpoint(checkpoint_path)
            print(f"  Saved checkpoint to {checkpoint_path}")

        if global_step >= max_steps:
            break

    if global_step >= max_steps:
        break

    step += 1

print("\nTraining completed!")

# Save final model
final_model_path = f"{OUTPUT_DIR}_final"
print(f"\nSaving final model to: {final_model_path}")
_save_checkpoint(final_model_path)
print(f"Final model saved successfully!")


## 10. Zip Model and Copy to Google Drive


In [None]:
import shutil
import zipfile
from pathlib import Path

# Zip the final model directory
final_model_path = f"{OUTPUT_DIR}_final"
zip_filename = f"{final_model_path}.zip"

print(f"Creating zip file: {zip_filename}...")

with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
    for file_path in Path(final_model_path).rglob("*"):
        if file_path.is_file():
            # Store paths relative to the model directory so the archive
            # unpacks without the local directory prefix
            arcname = file_path.relative_to(final_model_path)
            zipf.write(file_path, arcname)
            print(f"  Added: {arcname}")

print(f"\nZip file created: {zip_filename}")

# Copy to Google Drive.
# Only the basename is appended: zip_filename carries whatever directory
# OUTPUT_DIR points to, and embedding that whole path under MyDrive would
# name a destination directory that does not exist.
drive_dest = f"/content/drive/MyDrive/{Path(zip_filename).name}"
shutil.copy2(zip_filename, drive_dest)

print(f"Model zip file copied to Google Drive: {drive_dest}")
print(f"\nFile size: {Path(zip_filename).stat().st_size / (1024*1024):.2f} MB")


## 11. Test Inference (After Training)


In [None]:
# Load the fine-tuned model for testing
final_model_path = f"{OUTPUT_DIR}_final"

print(f"Loading fine-tuned model from: {final_model_path}")
test_model = VitsModel.from_pretrained(final_model_path)
test_tokenizer = AutoTokenizer.from_pretrained(final_model_path)

if torch.cuda.is_available():
    test_model = test_model.to("cuda")

# Switch to inference mode on every device, not only when CUDA is present
test_model.eval()

# Test with sample text: synthesize the first test transcription, provided
# the romanized text is available
if len(test_data) > 0 and 'transcription_romanized' in test_data[0]:
    test_text_romanized = test_data[0]['transcription_romanized']
    test_output = "test_synthesized_after_training.wav"

    print(f"\nTesting synthesis with: {test_text_romanized}")
    synthesize_speech(test_model, test_tokenizer, test_text_romanized, test_output)
    print(f"Test audio saved to: {test_output}")
else:
    print("Please run romanization cells first!")
