# Fine-tuning MMS-TTS Amharic Model with LoRA

This notebook fine-tunes `facebook/mms-tts-amh` for legal domain Amharic text-to-speech using LoRA (Low-Rank Adaptation) for efficient training on Colab free tier.

## Model and Approach
- **Base Model**: `facebook/mms-tts-amh` (VITS architecture, ~36.3M parameters)
- **Fine-tuning Method**: LoRA (via PEFT library)
- **Task**: Text-to-Speech (TTS)
- **Domain**: Legal Amharic text
- **Text Format**: Romanized Amharic (using uroman package)


## 1. Installation and Setup


In [1]:
%pip install -q transformers datasets accelerate peft torchaudio librosa soundfile uroman scipy


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/930.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━[0m [32m614.4/930.7 kB[0m [31m18.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m930.7/930.7 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Union, Optional
import librosa
import soundfile as sf
import scipy.io.wavfile
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

from transformers import (
    VitsModel,
    AutoTokenizer,
    TrainingArguments,
    Trainer
)

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType
)

from datasets import Dataset, DatasetDict

# Import uroman for romanization
import uroman

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA device: Tesla T4
CUDA memory: 15.83 GB


In [3]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


## 2. Configuration


In [12]:
MODEL_NAME = "facebook/mms-tts-amh"

AUDIO_DIR = "/content/drive/MyDrive/Dataset_1.5h/audio"
TRAIN_CSV = "/content/drive/MyDrive/Dataset_1.5h/train.csv"
VAL_CSV = "/content/drive/MyDrive/Dataset_1.5h/val.csv"
TEST_CSV = "/content/drive/MyDrive/Dataset_1.5h/test.csv"

OUTPUT_DIR = "mms_tts_lora_amharic_legal"

# LoRA Configuration
# Note: VITS models may have limited LoRA support - we'll adapt if needed
LORA_CONFIG = {
    "r": 8,
    "lora_alpha": 32,
    "target_modules": ["q_proj", "v_proj", "k_proj", "out_proj"],  # May need adjustment for VITS
    "lora_dropout": 0.1,
    "bias": "none",
    "task_type": "FEATURE_EXTRACTION"
}

TRAINING_ARGS = {
    "output_dir": OUTPUT_DIR,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 1e-4,
    "warmup_steps": 100,
    "max_steps": 1200,  # ~2h15min training time
    "gradient_checkpointing": True,
    "fp16": True,
    "eval_strategy": "steps",
    "eval_steps":1500,
    "save_strategy": "steps",
    "save_steps": 1500,
    "save_total_limit": 3,
    "load_best_model_at_end": True,
    "metric_for_best_model": "loss",
    "greater_is_better": False,
    "logging_steps": 50,
    "report_to": "none",
    "push_to_hub": False
}

print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  LoRA rank (r): {LORA_CONFIG['r']}")
print(f"  LoRA alpha: {LORA_CONFIG['lora_alpha']}")


Configuration:
  Model: facebook/mms-tts-amh
  Output directory: mms_tts_lora_amharic_legal
  LoRA rank (r): 8
  LoRA alpha: 32


## 3. Load and Prepare Data


In [13]:
def load_csv_split(csv_path, audio_dir):
    """Load a CSV split and return list of (audio_path, transcription) tuples"""
    df = pd.read_csv(csv_path)

    data = []
    for _, row in df.iterrows():
        audio_path = Path(audio_dir) / row['file_name']
        transcription = str(row['transcription']).strip()

        if audio_path.exists():
            data.append({
                'audio_path': str(audio_path),
                'transcription': transcription
            })
        else:
            print(f"Warning: Audio file not found: {audio_path}")

    return data

train_data = load_csv_split(TRAIN_CSV, AUDIO_DIR)
val_data = load_csv_split(VAL_CSV, AUDIO_DIR)
test_data = load_csv_split(TEST_CSV, AUDIO_DIR)

print(f"Train samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print(f"Test samples: {len(test_data)}")
print(f"\nTotal samples: {len(train_data) + len(val_data) + len(test_data)}")


Train samples: 302
Validation samples: 37
Test samples: 39

Total samples: 378


## 4. Romanize Text Using Uroman


In [14]:
# Initialize uroman instance
try:
    uroman_obj = uroman.Uroman()
    print("Uroman initialized successfully")
except Exception as e:
    print(f"Error initializing uroman: {e}")
    uroman_obj = None

def romanize_text(text):
    """Convert Ge'ez script Amharic text to Romanized format using uroman"""
    if uroman_obj is None:
        return text  # Return original if uroman not available

    try:
        # Uroman.romanize_string() is the correct method
        romanized = uroman_obj.romanize_string(text)
        return romanized.strip()
    except Exception as e:
        print(f"Error romanizing text: {text[:50]}... Error: {e}")
        return text  # Fallback to original if romanization fails

# Test romanization on a sample
sample_text = train_data[0]['transcription']
print(f"Original (Ge'ez): {sample_text}")
romanized_sample = romanize_text(sample_text)
print(f"Romanized: {romanized_sample}")


Uroman initialized successfully
Original (Ge'ez): የፌደራል የከተማ ልማት ሚኒስቴር የከተማ ፕላን መመሪያዎች የሚከተሉትን ያካትታሉ አንደኛ የመሬት አጠቃቀም እቅድ ሁለተኛ የከተማ መሰረተ ልማት ሶስተኛ የመኖሪያ ቤቶች ልማት አራተኛ የኢንዱስትሪ ዞኖች እና አምስተኛ የአረንጓዴ ቦታዎች።
Romanized: yafeedaraale yakatamaa lemaate miniseteere yakatamaa pelaane mamariyaawoche yamikatalutene yaakaatetaalu anedanyaa yamareete ataqaaqame eqede hulatanyaa yakatamaa masarata lemaate sosetanyaa yamanoriyaa beetoche lemaate araatanyaa yaineduseteri zonoche enaa amesetanyaa yaaranegwaadee botaawoche.


In [15]:
# Romanize all transcriptions
print("Romanizing training data...")
for item in tqdm(train_data, desc="Train"):
    item['transcription_romanized'] = romanize_text(item['transcription'])

print("Romanizing validation data...")
for item in tqdm(val_data, desc="Val"):
    item['transcription_romanized'] = romanize_text(item['transcription'])

print("Romanizing test data...")
for item in tqdm(test_data, desc="Test"):
    item['transcription_romanized'] = romanize_text(item['transcription'])

print("\nRomanization complete!")
print(f"Sample - Original: {train_data[0]['transcription']}")
print(f"Sample - Romanized: {train_data[0]['transcription_romanized']}")


Romanizing training data...


Train: 100%|██████████| 302/302 [00:01<00:00, 190.78it/s]


Romanizing validation data...


Val: 100%|██████████| 37/37 [00:00<00:00, 254.36it/s]


Romanizing test data...


Test: 100%|██████████| 39/39 [00:00<00:00, 225.25it/s]


Romanization complete!
Sample - Original: የፌደራል የከተማ ልማት ሚኒስቴር የከተማ ፕላን መመሪያዎች የሚከተሉትን ያካትታሉ አንደኛ የመሬት አጠቃቀም እቅድ ሁለተኛ የከተማ መሰረተ ልማት ሶስተኛ የመኖሪያ ቤቶች ልማት አራተኛ የኢንዱስትሪ ዞኖች እና አምስተኛ የአረንጓዴ ቦታዎች።
Sample - Romanized: yafeedaraale yakatamaa lemaate miniseteere yakatamaa pelaane mamariyaawoche yamikatalutene yaakaatetaalu anedanyaa yamareete ataqaaqame eqede hulatanyaa yakatamaa masarata lemaate sosetanyaa yamanoriyaa beetoche lemaate araatanyaa yaineduseteri zonoche enaa amesetanyaa yaaranegwaadee botaawoche.





## 5. Load Model and Tokenizer


In [16]:
print(f"Loading model: {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = VitsModel.from_pretrained(MODEL_NAME)

print(f"Model loaded successfully!")
print(f"Sampling rate: {model.config.sampling_rate} Hz")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

if torch.cuda.is_available():
    model = model.to("cuda")
    print("Model moved to CUDA")


Loading model: facebook/mms-tts-amh...


tokenizer_config.json:   0%|          | 0.00/286 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/47.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/145M [00:00<?, ?B/s]

Model loaded successfully!
Sampling rate: 16000 Hz
Model parameters: 36.28M
Model moved to CUDA


## 6. Test Inference (Before Training)

Test the model with romanized text to verify it works correctly.


In [19]:
def synthesize_speech(model, tokenizer, text_romanized, output_path):
    """Synthesize speech from romanized text"""
    inputs = tokenizer(text_romanized, return_tensors="pt")

    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    with torch.no_grad():
        output = model(**inputs).waveform

    # Move tensor to CPU before converting to numpy
    if isinstance(output, torch.Tensor):
        output = output.cpu()

    # Save audio
    scipy.io.wavfile.write(
        output_path,
        rate=model.config.sampling_rate,
        data=output.numpy().T
    )

    return output_path

# Test with sample text (after romanization)
if len(train_data) > 0 and 'transcription_romanized' in train_data[0]:
    test_text_romanized = train_data[0]['transcription_romanized']
    test_output = "test_synthesized_before_training.wav"

    print(f"Testing synthesis with: {test_text_romanized}")
    synthesize_speech(model, tokenizer, test_text_romanized, test_output)
    print(f"Test audio saved to: {test_output}")
else:
    print("Please run romanization cells first!")


Testing synthesis with: yafeedaraale yakatamaa lemaate miniseteere yakatamaa pelaane mamariyaawoche yamikatalutene yaakaatetaalu anedanyaa yamareete ataqaaqame eqede hulatanyaa yakatamaa masarata lemaate sosetanyaa yamanoriyaa beetoche lemaate araatanyaa yaineduseteri zonoche enaa amesetanyaa yaaranegwaadee botaawoche.
Test audio saved to: test_synthesized_before_training.wav


## 7. Apply LoRA (if supported)

**Important Note**: VITS models use a different architecture than transformers. LoRA support may be limited. We'll attempt to apply it, but may need to use standard fine-tuning if LoRA is not compatible.


In [None]:
# Attempt to apply LoRA
# Note: This may fail for VITS models - if so, we'll need to use standard fine-tuning
try:
    # Check available modules in the model
    print("Checking model structure for LoRA-compatible modules...")
    target_modules_found = []
    for name, module in model.named_modules():
        module_type = type(module).__name__
        if any(target in name.lower() for target in ["linear", "proj", "dense"]):
            if "Linear" in module_type or "Conv" in module_type:
                target_modules_found.append(name)

    print(f"Found {len(target_modules_found)} potential target modules")
    if len(target_modules_found) > 0:
        print("Sample modules:", target_modules_found[:5])

    # Try to configure LoRA with found modules
    # VITS models may not have standard transformer modules
    # We may need to identify the correct target modules
    lora_config = LoraConfig(
        r=LORA_CONFIG["r"],
        lora_alpha=LORA_CONFIG["lora_alpha"],
        target_modules=LORA_CONFIG["target_modules"],  # Try standard modules first
        lora_dropout=LORA_CONFIG["lora_dropout"],
        bias=LORA_CONFIG["bias"],
        task_type=TaskType.FEATURE_EXTRACTION
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    print("\nLoRA adapters applied successfully!")

except Exception as e:
    print(f"LoRA application failed: {e}")
    print("\nNote: VITS models may not support standard LoRA.")
    print("We may need to use standard fine-tuning instead.")
    print("Continuing with standard model for now...")

# Remove any existing wrapper code (safety check)
# If you have any wrapper code, remove it first before running this

# Check if LoRA was successfully applied
from peft import PeftModel

if isinstance(model, PeftModel):
    print("Model is a PeftModel")
    print(f"Base model type: {type(model.base_model)}")
    print(f"Has base_model.forward: {hasattr(model.base_model, 'forward')}")

    # Try to check if we can access the underlying VITS model
    if hasattr(model.base_model, 'model'):
        print(f"Base model has 'model' attribute: {type(model.base_model.model)}")
else:
    print("Model is NOT a PeftModel - LoRA may not have been applied")

# Fix: Patch PEFT's forward to filter out inputs_embeds (VitsModel doesn't accept it)
# This is necessary because PEFT includes inputs_embeds in its forward signature,
# but VitsModel doesn't accept it
if isinstance(model, PeftModel):
    import functools
    
    # Store the original forward method
    original_peft_forward = model.forward
    
    @functools.wraps(original_peft_forward)
    def patched_peft_forward(*args, **kwargs):
        # Remove inputs_embeds if present (VitsModel doesn't accept it)
        kwargs.pop('inputs_embeds', None)
        # Also remove task_ids which PEFT might pass
        kwargs.pop('task_ids', None)
        # Filter to only valid VitsModel arguments
        valid_kwargs = {k: v for k, v in kwargs.items() 
                       if k in ['input_ids', 'attention_mask', 'speaker_id', 
                               'output_attentions', 'output_hidden_states', 'return_dict', 'labels']}
        # Call the original forward with filtered kwargs
        return original_peft_forward(*args, **valid_kwargs)
    
    # Replace the forward method
    model.forward = patched_peft_forward
    
    # Also patch the base_model's forward to catch any internal calls
    if hasattr(model, 'base_model'):
        # Helper function to create a patched forward
        def create_patched_forward(original_forward):
            @functools.wraps(original_forward)
            def patched_forward(*args, **kwargs):
                kwargs.pop('inputs_embeds', None)
                kwargs.pop('task_ids', None)
                valid_kwargs = {k: v for k, v in kwargs.items() 
                               if k in ['input_ids', 'attention_mask', 'speaker_id', 
                                       'output_attentions', 'output_hidden_states', 'return_dict', 'labels']}
                return original_forward(*args, **valid_kwargs)
            return patched_forward
        
        # Recursively patch nested base_models
        current = model.base_model
        depth = 0
        max_depth = 5  # Safety limit
        
        while depth < max_depth:
            if hasattr(current, 'forward'):
                current.forward = create_patched_forward(current.forward)
            
            # Move to next level
            if hasattr(current, 'base_model'):
                current = current.base_model
            elif hasattr(current, 'model'):
                current = current.model
            else:
                break
            depth += 1
    
    print("✓ Patched PEFT model forward to filter out unsupported arguments (inputs_embeds, etc.)")

Checking model structure for LoRA-compatible modules...
Found 108 potential target modules
Sample modules: ['base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.text_encoder.encoder.layers.0.attention.k_proj', 'base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.text_encoder.encoder.layers.0.attention.k_proj.base_layer', 'base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.text_encoder.encoder.layers.0.attention.k_proj.lora_A.default', 'base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.text_encoder.encoder.layers.0.attention.k_proj.lora_B.default', 'base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.text_encoder.encoder.layers.0.attention.v_proj']
trainable params: 73,728 || all params: 36,356,400 || trainable%: 0.2028

LoRA adapters applied successfully!
Model is a PeftModel
Base model type: <class 'peft.tuners.lora.mod

## 8. Training Setup

**Important**: VITS models require custom training loops as they don't use the standard HuggingFace Trainer interface. This section provides a basic structure. For actual training, you may need to implement a custom training loop.

**Alternative Approach**: If LoRA doesn't work, consider:
1. Standard fine-tuning (full model fine-tuning)
2. Using a different TTS model that supports LoRA better
3. Implementing custom LoRA for VITS architecture


In [36]:
# Custom training loop for VITS model
# Note: This is a simplified training loop - VITS training is complex
# You may need to adapt this based on your specific requirements

from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn

# Prepare datasets for training
train_dataset_list = [
    {'text': item['transcription_romanized'], 'audio_path': item['audio_path']}
    for item in train_data
]

val_dataset_list = [
    {'text': item['transcription_romanized'], 'audio_path': item['audio_path']}
    for item in val_data
]

def load_audio(audio_path, target_sr=16000):
    """Load and resample audio"""
    audio, sr = librosa.load(audio_path, sr=target_sr)
    return audio

print(f"Prepared {len(train_dataset_list)} training samples")
print(f"Prepared {len(val_dataset_list)} validation samples")
print("\nNOTE: Full VITS training requires complex loss functions (mel-spectrogram loss, etc.)")
print("This is a simplified structure. For production use, implement full VITS training objectives.")


Prepared 302 training samples
Prepared 37 validation samples

NOTE: Full VITS training requires complex loss functions (mel-spectrogram loss, etc.)
This is a simplified structure. For production use, implement full VITS training objectives.


In [None]:
# Memory optimization setup before training
import os

# Set PyTorch CUDA memory allocation config to reduce fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Clear any existing CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    # Print current memory usage
    allocated = torch.cuda.memory_allocated(0) / 1e9
    reserved = torch.cuda.memory_reserved(0) / 1e9
    total = torch.cuda.get_device_properties(0).total_memory / 1e9
    
    print("GPU Memory Status:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved: {reserved:.2f} GB")
    print(f"  Total: {total:.2f} GB")
    print(f"  Free: {total - reserved:.2f} GB")
    print()
    
    # Check model memory usage
    if hasattr(model, 'parameters'):
        model_params = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9
        print(f"Model parameters memory: ~{model_params:.2f} GB (FP32)")
        if TRAINING_ARGS.get("fp16", False):
            print(f"  With FP16: ~{model_params / 2:.2f} GB")
    print()


## 9. Train Model

**Note**: This is a simplified training approach. Full VITS training requires complex loss functions. For production, you may need to implement proper VITS training objectives.


In [None]:
# Simplified training loop with MEMORY OPTIMIZATION
# WARNING: This is a basic structure - full VITS training is more complex

# Enable gradient checkpointing to save memory
if TRAINING_ARGS.get("gradient_checkpointing", False):
    try:
        if hasattr(model, 'gradient_checkpointing_enable'):
            model.gradient_checkpointing_enable()
            print("✓ Gradient checkpointing enabled")
        elif hasattr(model, 'base_model') and hasattr(model.base_model, 'gradient_checkpointing_enable'):
            model.base_model.gradient_checkpointing_enable()
            print("✓ Gradient checkpointing enabled (via base_model)")
    except Exception as e:
        print(f"⚠ Could not enable gradient checkpointing: {e}")

# Enable FP16 mixed precision training
use_fp16 = TRAINING_ARGS.get("fp16", False)
if use_fp16:
    from torch.cuda.amp import autocast, GradScaler
    scaler = GradScaler()
    print("✓ FP16 mixed precision enabled")
else:
    scaler = None

model.train()
optimizer = AdamW(model.parameters(), lr=TRAINING_ARGS["learning_rate"])
scheduler = CosineAnnealingLR(optimizer, T_max=TRAINING_ARGS["max_steps"])

# Reduce batch size if memory is tight (T4 15GB)
batch_size = TRAINING_ARGS["per_device_train_batch_size"]
if batch_size > 2:
    print(f"⚠ Reducing batch size from {batch_size} to 2 for T4 GPU memory constraints")
    batch_size = 2

gradient_accumulation_steps = TRAINING_ARGS["gradient_accumulation_steps"]
max_steps = TRAINING_ARGS["max_steps"]

print("Starting training...")
print(f"Training steps: {max_steps}")
print(f"Learning rate: {TRAINING_ARGS['learning_rate']}")
print(f"Batch size: {batch_size}")
print(f"Gradient accumulation: {gradient_accumulation_steps}")
print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
print(f"FP16: {use_fp16}")
print()

# Clear CUDA cache before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"Initial GPU memory: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB / {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print()

step = 0
epoch = 0
accumulated_loss = 0.0

while step < max_steps:
    epoch_loss = 0.0
    num_batches = 0

    # Simple batching (you may want to implement proper DataLoader)
    for i in tqdm(range(0, len(train_dataset_list), batch_size), desc=f"Epoch {epoch+1}"):
        batch = train_dataset_list[i:i+batch_size]

        # Clear gradients at the start of each batch
        optimizer.zero_grad()

        # Simplified forward pass (actual VITS training needs proper loss)
        try:
            # Tokenize batch text
            texts = [item['text'] for item in batch]
            inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

            if torch.cuda.is_available():
                inputs = {k: v.to("cuda") for k, v in inputs.items()}

            # Forward pass with FP16 if enabled
            # Filter out inputs_embeds which VitsModel doesn't accept
            valid_inputs = {k: v for k, v in inputs.items() if k in ['input_ids', 'attention_mask']}
            
            # Use autocast for FP16
            if use_fp16:
                with autocast():
                    outputs = model(**valid_inputs)
            else:
                outputs = model(**valid_inputs)

            # PLACEHOLDER: Compute loss (VITS needs complex loss function)
            # For now, we'll skip actual loss computation
            # In real training, you'd compute mel-spectrogram loss, duration loss, etc.

            # Dummy loss for demonstration (REPLACE with actual VITS loss)
            # Use a simple tensor that doesn't require keeping outputs in memory
            loss = torch.tensor(0.0, requires_grad=True, device=inputs['input_ids'].device)

            # Delete outputs immediately to free memory
            del outputs
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

            # Backward pass with FP16 scaling
            if use_fp16:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Delete loss and inputs to free memory
            loss_value = loss.item()
            del loss, inputs, valid_inputs
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

            accumulated_loss += loss_value
            num_batches += 1

            # Update weights after gradient accumulation
            if (step + 1) % gradient_accumulation_steps == 0:
                if use_fp16:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                
                scheduler.step()
                step += 1

                if step % TRAINING_ARGS["logging_steps"] == 0:
                    avg_loss = accumulated_loss / (TRAINING_ARGS["logging_steps"] * gradient_accumulation_steps)
                    mem_used = torch.cuda.memory_allocated(0) / 1e9 if torch.cuda.is_available() else 0
                    mem_total = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0
                    print(f"Step {step}/{max_steps} - Loss: {avg_loss:.4f} - GPU Memory: {mem_used:.2f}/{mem_total:.2f} GB")
                    accumulated_loss = 0.0
                    
                    # Periodic memory cleanup
                    torch.cuda.empty_cache()

                if step >= max_steps:
                    break

        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"\n⚠ CUDA OOM at step {step}. Clearing cache and reducing batch size...")
                torch.cuda.empty_cache()
                # Try with smaller batch
                if batch_size > 1:
                    batch_size = 1
                    print(f"Reduced batch size to {batch_size}")
                else:
                    print("Batch size already at minimum. Skipping this batch.")
                    continue
            else:
                print(f"Error in training step: {e}")
            # Clear memory on any error
            torch.cuda.empty_cache()
            continue
        except Exception as e:
            print(f"Error in training step: {e}")
            # Clear memory on any error
            torch.cuda.empty_cache()
            continue

    epoch += 1

    if step >= max_steps:
        break

print("\nTraining completed!")

# Final memory cleanup
torch.cuda.empty_cache()

# Save final model
final_model_path = f"{OUTPUT_DIR}_final"
print(f"\nSaving model to: {final_model_path}")
model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)

print(f"Model saved successfully to: {final_model_path}")
print("\nNOTE: This was a simplified training loop.")
print("For production use, implement proper VITS training objectives.")


Starting training...
Training steps: 1200
Learning rate: 0.0001
Batch size: 4
Gradient accumulation: 4



Epoch 1:  16%|█▌        | 12/76 [00:00<00:00, 116.90it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 1:  32%|███▏      | 24/76 [00:00<00:00, 110.51it/s]

Error in training step: maximum recursion depth exceeded


Epoch 1:  47%|████▋     | 36/76 [00:00<00:00, 104.82it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 1:  62%|██████▏   | 47/76 [00:00<00:00, 106.04it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 1:  76%|███████▋  | 58/76 [00:00<00:00, 104.82it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 1: 100%|██████████| 76/76 [00:00<00:00, 102.72it/s]


Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 2:  16%|█▌        | 12/76 [00:00<00:00, 118.25it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 2:  32%|███▏      | 24/76 [00:00<00:00, 108.49it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 2:  46%|████▌     | 35/76 [00:00<00:00, 103.32it/s]

Error in training step: maximum recursion depth exceeded


Epoch 2:  61%|██████    | 46/76 [00:00<00:00, 105.16it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 2:  75%|███████▌  | 57/76 [00:00<00:00, 104.41it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 2:  89%|████████▉ | 68/76 [00:00<00:00, 104.68it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 2: 100%|██████████| 76/76 [00:01<00:00, 67.55it/s] 


Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 3:  16%|█▌        | 12/76 [00:00<00:00, 112.64it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 3:  46%|████▌     | 35/76 [00:00<00:00, 103.79it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 3:  75%|███████▌  | 57/76 [00:00<00:00, 103.84it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 3: 100%|██████████| 76/76 [00:00<00:00, 103.94it/s]


Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 4:   0%|          | 0/76 [00:00<?, ?it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 4:  13%|█▎        | 10/76 [00:00<00:00, 96.83it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 4:  28%|██▊       | 21/76 [00:00<00:00, 103.01it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 4:  42%|████▏     | 32/76 [00:00<00:00, 98.35it/s] 

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 4:  57%|█████▋    | 43/76 [00:00<00:00, 101.25it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 4:  71%|███████   | 54/76 [00:00<00:00, 98.36it/s] 

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 4:  86%|████████▌ | 65/76 [00:00<00:00, 100.16it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 4: 100%|██████████| 76/76 [00:00<00:00, 98.48it/s]


Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 5:   0%|          | 0/76 [00:00<?, ?it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 5:  22%|██▏       | 17/76 [00:00<00:01, 35.88it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 5:  50%|█████     | 38/76 [00:00<00:00, 66.51it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 5:  79%|███████▉  | 60/76 [00:00<00:00, 85.73it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 5: 100%|██████████| 76/76 [00:01<00:00, 66.51it/s]


Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 6:  16%|█▌        | 12/76 [00:00<00:00, 114.47it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum

Epoch 6:  46%|████▌     | 35/76 [00:00<00:00, 99.69it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded


Epoch 6:  57%|█████▋    | 43/76 [00:00<00:00, 98.84it/s]

Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded
Error in training step: maximum recursion depth exceeded





KeyboardInterrupt: 

## 10. Zip Model and Copy to Google Drive


In [None]:
import shutil
import zipfile
from pathlib import Path

# Zip the final model directory
final_model_path = f"{OUTPUT_DIR}_final"
zip_filename = f"{final_model_path}.zip"

print(f"Creating zip file: {zip_filename}...")

with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
    for file_path in Path(final_model_path).rglob("*"):
        if file_path.is_file():
            # Get relative path for archive
            arcname = file_path.relative_to(final_model_path)
            zipf.write(file_path, arcname)
            print(f"  Added: {arcname}")

print(f"\nZip file created: {zip_filename}")

# Copy to Google Drive
drive_dest = f"/content/drive/MyDrive/{zip_filename}"
shutil.copy2(zip_filename, drive_dest)

print(f"Model zip file copied to Google Drive: {drive_dest}")
print(f"\nFile size: {Path(zip_filename).stat().st_size / (1024*1024):.2f} MB")


## 11. Test Inference (After Training)

Test the fine-tuned model with romanized text.


In [None]:
# Load the fine-tuned model for testing
final_model_path = f"{OUTPUT_DIR}_final"

print(f"Loading fine-tuned model from: {final_model_path}")
test_model = VitsModel.from_pretrained(final_model_path)
test_tokenizer = AutoTokenizer.from_pretrained(final_model_path)

if torch.cuda.is_available():
    test_model = test_model.to("cuda")
    test_model.eval()

# Test with sample text
if len(test_data) > 0 and 'transcription_romanized' in test_data[0]:
    test_text_romanized = test_data[0]['transcription_romanized']
    test_output = "test_synthesized_after_training.wav"

    print(f"\nTesting synthesis with: {test_text_romanized}")
    synthesize_speech(test_model, test_tokenizer, test_text_romanized, test_output)
    print(f"Test audio saved to: {test_output}")
else:
    print("Please run romanization cells first!")
