In [None]:
# ===== COMPLETE VITS TTS TRAINING SCRIPT - FIXED VERSION =====
import os
import sys
import torch
import json
import time
import pandas as pd
from pathlib import Path
import torchaudio

# Select the "soundfile" backend only on torchaudio builds that still expose
# the setter; builds without `set_audio_backend` skip this silently.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("soundfile")

print("=" * 70)
print("VITS TTS TRAINING - FIXED VERSION")
print("=" * 70)

# ===== STEP 1: SETUP PATHS =====
print("\nüìÅ Setting up paths...")

# Project root. Set the NEPALI_TTS_BASE environment variable to run the
# notebook on another machine without editing this cell; the default keeps
# the original location.
BASE = os.environ.get("NEPALI_TTS_BASE", r"C:\Users\ReticleX\Pictures\nepali_tts")
# All training artefacts are written below the project root.
OUTPUT = os.path.join(BASE, "vits_output")

# Create output directory (no error if it already exists)
os.makedirs(OUTPUT, exist_ok=True)

print(f"‚úÖ Base directory: {BASE}")
print(f"‚úÖ Output directory: {OUTPUT}")

VITS TTS TRAINING - FIXED VERSION

üìÅ Setting up paths...
‚úÖ Base directory: C:\Users\ReticleX\Pictures\nepali_tts
‚úÖ Output directory: C:\Users\ReticleX\Pictures\nepali_tts\vits_output


In [None]:
# ===== STEP 2: IMPORTS =====
# Import the Coqui TTS and `trainer` names used by the later cells. The imports
# are wrapped in a try block so that a missing package produces an install hint
# instead of a bare traceback.
print("\nüì¶ Importing modules...")
try:
    from TTS.config.shared_configs import BaseDatasetConfig, BaseAudioConfig
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits
    from TTS.tts.utils.text.tokenizer import TTSTokenizer
    from TTS.tts.utils.text.characters import Graphemes
    from TTS.utils.audio import AudioProcessor
    from TTS.tts.datasets import load_tts_samples
    from trainer import Trainer, TrainerArgs
    print("‚úÖ All imports successful")
except ImportError as e:
    print(f"‚ùå Import error: {e}")
    print("\nüí° Install missing packages:")
    print("   pip install TTS")
    print("   pip install trainer")
    # Raises SystemExit with status 1; nothing below can work without these names.
    sys.exit(1)


üì¶ Importing modules...
‚úÖ All imports successful


In [None]:
# ===== STEP 3: DEFINE FIX METADATA FUNCTION =====
def fix_metadata_extensions(metadata_path, audio_dir):
    """Write a copy of a pipe-delimited metadata file whose filenames end in ``.wav``.

    Each non-blank line is expected to look like ``filename|text[|more...]``.
    For every entry the function looks for ``<filename>.wav`` in ``audio_dir``
    (or ``<filename>`` itself when it already ends in ``.wav``, or when only the
    extension-less file exists). Entries whose audio file cannot be found are
    reported and left out of the result. Columns after the filename are copied
    through unchanged.

    Args:
        metadata_path: Path of the metadata file to read (UTF-8).
        audio_dir: Directory that holds the audio files.

    Returns:
        Path of the newly written file (``<name>_fixed_wav<ext>`` next to the
        input, never the input itself), or ``None`` if anything went wrong.
    """
    try:
        # Informational only: how many wav files the directory actually holds.
        actual_files = [f for f in os.listdir(audio_dir) if f.lower().endswith('.wav')]
        print(f"   Found {len(actual_files)} .wav files in {audio_dir}")

        # Read original metadata
        with open(metadata_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        fixed_lines = []
        missing_ext = 0  # entries written with an added ".wav"
        with_ext = 0     # entries written that already had ".wav"

        for line in lines:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not entries.
                continue

            parts = line.split('|')
            filename = parts[0].strip()
            if not filename:
                print(f"   Warning: Entry without a filename skipped: {line[:40]}")
                continue
            # Keep every remaining column, not just the first one after the name.
            text = '|'.join(parts[1:])

            if filename.lower().endswith('.wav'):
                if os.path.isfile(os.path.join(audio_dir, filename)):
                    fixed_lines.append(f"{filename}|{text}")
                    with_ext += 1
                else:
                    print(f"   Warning: File not found: {filename}")
                continue

            filename_with_ext = filename + '.wav'
            if os.path.isfile(os.path.join(audio_dir, filename_with_ext)):
                fixed_lines.append(f"{filename_with_ext}|{text}")
                missing_ext += 1
            elif os.path.isfile(os.path.join(audio_dir, filename)):
                # The file exists exactly as named (some other extension): keep it as is.
                fixed_lines.append(f"{filename}|{text}")
            else:
                print(f"   Warning: File not found: {filename} or {filename_with_ext}")

        # Build the output name from the real extension so that the input file
        # can never be overwritten and directory names are left alone.
        root, ext = os.path.splitext(metadata_path)
        fixed_path = f"{root}_fixed_wav{ext or '.csv'}"
        with open(fixed_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(fixed_lines))

        print(f"‚úÖ Fixed metadata saved to: {fixed_path}")
        print(f"   Original entries: {len(lines)}")
        print(f"   Fixed entries: {len(fixed_lines)}")
        print(f"   Added .wav extension to: {missing_ext} files")
        print(f"   Already had .wav: {with_ext} files")

        # Show first few fixed entries
        print("\nüìÑ First 3 fixed entries:")
        for i, entry in enumerate(fixed_lines[:3]):
            filename, text = entry.split('|', 1)
            print(f"   {i+1}. {filename} | {text[:30]}...")

        return fixed_path

    except Exception as e:
        # Best-effort helper: report the problem and let the caller decide.
        print(f"‚ùå Error fixing metadata: {e}")
        import traceback
        traceback.print_exc()
        return None

def verify_audio_files(metadata_path, audio_dir):
    """Check that every entry of a pipe-delimited metadata file has its audio file.

    The first ``|``-separated column of each non-blank line is taken as a file
    name relative to ``audio_dir``. Blank lines are ignored; an entry with an
    empty file name counts as missing. For the first few missing entries the
    function also looks for the same base name without extension or with a
    differently-cased ``.wav`` extension and reports what it finds.

    Args:
        metadata_path: Path of the metadata file to read (UTF-8).
        audio_dir: Directory that holds the audio files.

    Returns:
        ``True`` when no entry is missing, ``False`` when at least one is
        missing or the metadata file could not be processed.
    """
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # (1-based line number, file name) for every non-blank line.
        entries = []
        for line_num, line in enumerate(lines, start=1):
            stripped = line.strip()
            if stripped:
                entries.append((line_num, stripped.split('|')[0].strip()))

        valid = 0
        missing = []

        print(f"   Checking {len(entries)} entries...")

        for checked, (line_num, filename) in enumerate(entries, start=1):
            # isfile (not exists): an empty name would otherwise resolve to
            # audio_dir itself and be counted as a valid audio file.
            if filename and os.path.isfile(os.path.join(audio_dir, filename)):
                valid += 1
            else:
                missing.append((line_num, filename))

            # Show progress every 500 files
            if checked % 500 == 0:
                print(f"   Checked {checked}/{len(entries)} files...")

        print(f"\nüìä Verification results:")
        print(f"   Total entries: {len(entries)}")
        print(f"   Valid files: {valid}")
        print(f"   Missing files: {len(missing)}")

        if not missing:
            print("‚úÖ All audio files found!")
            return True

        print(f"\n‚ùå Missing files detected!")
        print(f"   First 5 missing:")
        for line_num, filename in missing[:5]:
            print(f"      Line {line_num}: {filename}")

        # Diagnose (first 10 only): is the file there under another spelling?
        print(f"\nüîç Checking for files without .wav extension...")
        found_without_ext = 0
        for line_num, filename in missing[:10]:
            if not filename:
                continue
            # Remove .wav extension if present
            base_name = filename[:-4] if filename.lower().endswith('.wav') else filename

            possible_paths = [
                os.path.join(audio_dir, base_name),
                os.path.join(audio_dir, base_name + '.WAV'),  # uppercase
                os.path.join(audio_dir, base_name + '.Wav'),  # mixed case
            ]

            for path in possible_paths:
                if os.path.isfile(path):
                    print(f"   Found: {os.path.basename(path)} (different case/extension)")
                    found_without_ext += 1
                    break

        if found_without_ext > 0:
            print(f"\n‚ö†Ô∏è Found {found_without_ext} files with different extensions/case")
            print("üí° Run the fix_metadata_extensions function again")

        return False

    except Exception as e:
        # Treat an unreadable metadata file as a failed verification.
        print(f"‚ùå Error during verification: {e}")
        return False

In [None]:
# ===== STEP 4: FIX METADATA - ADD .WAV EXTENSION =====
# Locate the training audio folder and its metadata, then produce a metadata
# copy whose file names carry a ".wav" extension.
print("\nüîç Fixing metadata file extensions...")

# Both inputs live under <BASE>/dataset/ljspeech_train
train_audio_dir = os.path.join(BASE, "dataset", "ljspeech_train", "wavs")
train_metadata = os.path.join(
    BASE, "dataset", "ljspeech_train", "metadata_fixed.csv"
)

print(f"üìÅ Audio directory: {train_audio_dir}")
print(f"üìÑ Metadata file: {train_metadata}")

# Without an audio directory there is nothing to train on: create an empty
# one as a hint for the user and stop.
audio_dir_present = os.path.exists(train_audio_dir)
if not audio_dir_present:
    print(f"‚ùå Audio directory not found: {train_audio_dir}")
    print("üí° Creating directory...")
    os.makedirs(train_audio_dir, exist_ok=True)
    print("‚úÖ Created audio directory")
    print("üí° Please add your .wav files to this directory")
    sys.exit(1)

# The helper returns the path of the rewritten metadata, or None on failure.
fixed_metadata = fix_metadata_extensions(train_metadata, train_audio_dir)
if not fixed_metadata:
    print("‚ùå Failed to fix metadata")
    sys.exit(1)


üîç Fixing metadata file extensions...
üìÅ Audio directory: C:\Users\ReticleX\Pictures\nepali_tts\dataset\ljspeech_train\wavs
üìÑ Metadata file: C:\Users\ReticleX\Pictures\nepali_tts\dataset\ljspeech_train\metadata_fixed.csv
   Found 6082 .wav files in C:\Users\ReticleX\Pictures\nepali_tts\dataset\ljspeech_train\wavs
‚úÖ Fixed metadata saved to: C:\Users\ReticleX\Pictures\nepali_tts\dataset\ljspeech_train\metadata_fixed_fixed_wav.csv
   Original entries: 6082
   Fixed entries: 6082
   Added .wav extension to: 6082 files
   Already had .wav: 0 files

üìÑ First 3 fixed entries:
   1. 97c06f8b39.wav | ‡§≠‡•å‡§§‡§ø‡§ï ‡§µ‡§ø‡§ú‡•ç‡§û‡§æ‡§®‡§≤‡•á ‡§ï‡•Å‡§®‡•à...
   2. 48e8538fbd.wav | ‡§Æ‡§®‡•ç‡§¶‡§ø‡§∞ ‡§≤‡§ø‡§ö‡•ç‡§õ‡§µ‡§ø ‡§ï‡§≤‡§æ...
   3. 06c5b9782c.wav | ‡§®‡§ø‡§Æ‡•ç‡§§‡§ø ‡§¶‡§æ‡§® ‡§ó‡§∞‡•á...


In [None]:
# ===== STEP 5: CREATE NEPALI CHARACTER SET =====
# Build the grapheme inventory by concatenating several literal tables, then
# wrap it in a Graphemes object and a character-level (non-phoneme) tokenizer.
print("\nüìù Creating Nepali character set...")

# Accumulates every symbol; de-duplicated and sorted further down.
nepali_vocab = []

# Vowels
vowels = ['‡§Ö', '‡§Ü', '‡§á', '‡§à', '‡§â', '‡§ä', '‡§ã', '‡§è', '‡§ê', '‡§ì', '‡§î']
nepali_vocab.extend(vowels)

# Consonants
consonants = [
    '‡§ï', '‡§ñ', '‡§ó', '‡§ò', '‡§ô',
    '‡§ö', '‡§õ', '‡§ú', '‡§ù', '‡§û',
    '‡§ü', '‡§†', '‡§°', '‡§¢', '‡§£',
    '‡§§', '‡§•', '‡§¶', '‡§ß', '‡§®',
    '‡§™', '‡§´', '‡§¨', '‡§≠', '‡§Æ',
    '‡§Ø', '‡§∞', '‡§≤', '‡§µ', '‡§∂', '‡§∑', '‡§∏', '‡§π'
]
nepali_vocab.extend(consonants)

# Vowel signs
vowel_signs = ['‡§æ', '‡§ø', '‡•Ä', '‡•Å', '‡•Ç', '‡•É', '‡•á', '‡•à', '‡•ã', '‡•å', '‡•ç']
nepali_vocab.extend(vowel_signs)

# Diacritics
diacritics = ['‡§Ç', '‡§É', '‡§Å']
nepali_vocab.extend(diacritics)

# Nepali digits
digits = ['‡•¶', '‡•ß', '‡•®', '‡•©', '‡•™', '‡•´', '‡•¨', '‡•≠', '‡•Æ', '‡•Ø']
nepali_vocab.extend(digits)

# Latin alphabet and numbers
latin = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
nepali_vocab.extend(latin)

# Common punctuation (split into single characters by list())
common_punct = list(" !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~‡•§")
nepali_vocab.extend(common_punct)

# Remove duplicates and sort
nepali_vocab = sorted(set(nepali_vocab))

print(f"‚úÖ Character set ready ({len(nepali_vocab)} characters)")

# Create Graphemes object
# NOTE(review): common_punct above already contains "_", "~", "^" and "#",
# which are passed here again as pad/eos/bos/blank, and it also contains the
# characters listed in `punctuations`. Confirm that Graphemes tolerates symbols
# that appear both in `characters` and as special/punctuation symbols.
chars_obj = Graphemes(
    characters=nepali_vocab,
    punctuations="‡•§!?,.:; -\"",
    pad="_",
    eos="~",
    bos="^",
    blank="#",
)

# Create tokenizer (graphemes only; a blank token is interleaved, see the
# decoded test string printed below)
tokenizer = TTSTokenizer(
    use_phonemes=False,
    characters=chars_obj,
    add_blank=True,
)

print(f"‚úÖ Tokenizer ready (vocab: {len(tokenizer.characters.characters)})")

# Test tokenizer: round-trip one word through ids and back
test_text = "‡§®‡§Æ‡§∏‡•ç‡§§‡•á"
test_ids = tokenizer.text_to_ids(test_text)
test_decoded = tokenizer.ids_to_text(test_ids)
print(f"   Test: '{test_text}' ‚Üí {len(test_ids)} tokens ‚Üí '{test_decoded}'")


üìù Creating Nepali character set...
‚úÖ Character set ready (164 characters)
‚úÖ Tokenizer ready (vocab: 164)
   Test: '‡§®‡§Æ‡§∏‡•ç‡§§‡•á' ‚Üí 13 tokens ‚Üí '#‡§®#‡§Æ#‡§∏#‡•ç#‡§§#‡•á#'


In [None]:
# ===== STEP 6: DATASET CONFIGURATION =====
# Describe the dataset for Coqui TTS: which formatter to use, where the audio
# lives and which metadata files list the training / validation entries.
print("\nüìä Setting up dataset...")

# Validation metadata comes from the separate "ljspeech_val" folder.
val_metadata = os.path.join(BASE, "dataset", "ljspeech_val", "metadata.csv")

# NOTE(review): `path` points at the wavs directory of the *training* set
# while `meta_file_val` points into another dataset folder - confirm the
# formatter resolves validation audio correctly. Also confirm what layout the
# "ljspeech" formatter expects below `path`; later cells build their sample
# lists by hand instead of calling load_tts_samples.
dataset_config = BaseDatasetConfig(
    formatter="ljspeech",
    language="ne",
    path=train_audio_dir,
    meta_file_train=fixed_metadata,
    meta_file_val=val_metadata,
)

print(f"‚úÖ Dataset path: {dataset_config.path}")
print(f"‚úÖ Metadata file: {dataset_config.meta_file_train}")


üìä Setting up dataset...
‚úÖ Dataset path: C:\Users\ReticleX\Pictures\nepali_tts\dataset\ljspeech_train\wavs
‚úÖ Metadata file: C:\Users\ReticleX\Pictures\nepali_tts\dataset\ljspeech_train\metadata_fixed_fixed_wav.csv


In [None]:
# ===== STEP 7: VERIFY AUDIO FILES EXIST =====
# Run the existence check on the configured training metadata. On failure,
# print troubleshooting hints plus a short directory listing and ask the user
# whether to carry on.
print("\nüîç Final verification of audio files...")

files_ok = verify_audio_files(dataset_config.meta_file_train, dataset_config.path)

if not files_ok:
    for hint in (
        "\n‚ùå Audio file verification failed!",
        "\nüí° Solutions:",
        "   1. Make sure .wav files are in the correct directory",
        "   2. Check file extensions (should be .wav)",
        "   3. Check filename case sensitivity",
    ):
        print(hint)

    # Show what the audio directory really contains
    print(f"\nüìÅ Actual files in {dataset_config.path}:")
    actual_files = os.listdir(dataset_config.path)
    wav_files = [name for name in actual_files if name.lower().endswith('.wav')]
    print(f"   Total files: {len(actual_files)}")
    print(f"   .wav files: {len(wav_files)}")

    if len(wav_files) > 0:
        print("   First 5 .wav files:")
        for f in wav_files[:5]:
            print(f"      {f}")

    # Interactive escape hatch: anything other than "y" stops the run.
    continue_anyway = input("\n‚ùì Continue anyway? (y/n): ").strip().lower()
    if continue_anyway != 'y':
        sys.exit(1)


üîç Final verification of audio files...
   Checking 100 entries...

üìä Verification results:
   Total entries: 100
   Valid files: 100
   Missing files: 0
‚úÖ All audio files found!


In [None]:
# ===== STEP 8: DEBUG AND FIX DATASET LOADING =====
# 1. Dump the first metadata lines column by column for inspection.
# 2. Write a three-column placeholder metadata file ("proper_metadata.csv").
# 3. Point dataset_config at that file.
print("\nüìÇ Debugging dataset loading...")

# First, let's manually parse the metadata to see what's happening
print(f"üîç Manually parsing: {dataset_config.meta_file_train}")

with open(dataset_config.meta_file_train, 'r', encoding='utf-8') as f:
    lines = f.readlines()

print(f"üìÑ Total lines: {len(lines)}")

# Parse first 5 lines manually
print("\nüìã Manual parsing of first 5 lines:")
for i in range(min(5, len(lines))):
    line = lines[i].strip()
    print(f"\nLine {i+1}: '{line}'")

    # Split by pipe
    parts = line.split('|')
    print(f"  Parts: {len(parts)}")
    for j, part in enumerate(parts):
        print(f"    Part {j}: '{part}'")

    # Check for empty parts or weird characters
    if len(parts) >= 2:
        text = parts[1]
        print(f"  Text analysis:")
        print(f"    Length: {len(text)}")
        print(f"    Is empty: {text.strip() == ''}")
        print(f"    First 10 chars ascii: {[ord(c) for c in text[:10]]}")
        # Real newline escape (was the two-character string '/n'). Kept in a
        # variable because backslashes are not allowed inside f-string
        # expressions on older Python versions.
        newline = '\n'
        print(f"    Contains newline: {newline in text}")

print("\n" + "="*50)
print("üîß CREATING PROPER LJSPEECH METADATA")
print("="*50)

# Create a properly formatted LJSpeech metadata file
proper_meta_path = os.path.join(BASE, "dataset", "ljspeech_train", "proper_metadata.csv")

print(f"\nüìù Creating proper metadata at: {proper_meta_path}")

# Get actual .wav files. Sorted so that the "first 100" below is the same
# subset on every run and platform (os.listdir order is not guaranteed).
wav_files = sorted(f for f in os.listdir(dataset_config.path) if f.lower().endswith('.wav'))
print(f"üìÅ Found {len(wav_files)} .wav files")

# Create proper metadata
# NOTE(review): every entry gets an English placeholder sentence, not the
# transcript of the recording. Training on this file cannot produce a usable
# voice; it only exercises the pipeline. Replace with real transcripts before
# a real run.
proper_lines = []
for i, wav_file in enumerate(wav_files[:100]):  # Use first 100
    # Create proper LJSpeech format: filename|text|normalized_text
    # For testing, use simple text
    text = f"Test text number {i+1} for Nepali TTS"
    proper_lines.append(f"{wav_file}|{text}|{text}")  # Note: THREE parts with pipe

# Write the file
with open(proper_meta_path, 'w', encoding='utf-8', newline='\n') as f:
    f.write('\n'.join(proper_lines))

print(f"‚úÖ Created proper metadata with {len(proper_lines)} entries")

# Verify the file
print("\nüîç Verifying proper metadata:")
with open(proper_meta_path, 'r', encoding='utf-8') as f:
    verify_lines = f.readlines()

print(f"üìÑ First 3 entries:")
for i in range(min(3, len(verify_lines))):
    line = verify_lines[i].strip()
    print(f"  {i+1}. {line}")

# Update dataset config
# NOTE: this mutates dataset_config in place, so re-running earlier cells that
# read dataset_config.meta_file_train will see the placeholder file.
print(f"\nüîÑ Updating dataset config to use proper metadata")
dataset_config.meta_file_train = proper_meta_path

# ===== TRY CUSTOM DATASET LOADING =====
print("\n" + "="*50)
print("üîÑ TRYING CUSTOM DATASET LOADING")
print("="*50)

# Instead of using load_tts_samples, let's create samples manually
print("\nüß™ Creating samples manually...")

def create_manual_samples(metadata_path, audio_dir):
    """Build the training sample list straight from a pipe-delimited metadata file.

    Each usable line ``filename|text[|...]`` whose audio file exists in
    ``audio_dir`` becomes a dict with the keys ``audio_file``, ``text``,
    ``speaker_name`` and ``language``. Blank lines, lines with fewer than two
    columns and lines with an empty filename or text are ignored; entries whose
    audio file does not exist are reported and left out.
    """
    with open(metadata_path, 'r', encoding='utf-8') as handle:
        raw_lines = handle.readlines()

    collected = []
    for raw in raw_lines:
        entry = raw.strip()
        if entry == "":
            continue

        columns = entry.split('|')
        if len(columns) < 2:
            continue

        wav_name = columns[0].strip()
        transcript = columns[1].strip()
        if not (wav_name and transcript):
            # Nothing to train on without both a file name and a transcript.
            continue

        # A doubled ".wav.wav" suffix is reduced to a single ".wav".
        if wav_name.endswith('.wav.wav'):
            wav_name = wav_name[:-4]

        wav_path = os.path.join(audio_dir, wav_name)
        if not os.path.exists(wav_path):
            print(f"‚ö†Ô∏è File not found: {wav_path}")
            continue

        collected.append({
            'audio_file': wav_path,
            'text': transcript,
            'speaker_name': 'nepali_speaker',
            'language': 'ne'
        })

    return collected

# Create manual samples and split them into disjoint training / validation
# sets. Validation takes the first 20 samples (at most a fifth of the data) and
# those samples are removed from the training list, so evaluation is not run on
# data the model was trained on.
all_samples = create_manual_samples(proper_meta_path, dataset_config.path)
n_eval = min(20, len(all_samples) // 5)
if n_eval > 0:
    eval_samples = all_samples[:n_eval]
    train_samples = all_samples[n_eval:]
else:
    # Too few samples to hold anything out: reuse the training samples for
    # validation so the rest of the notebook still has both lists.
    train_samples = all_samples
    eval_samples = all_samples[:20]
    print("   Warning: too few samples for a held-out split; validation reuses training samples")

print(f"‚úÖ Manual samples created:")
print(f"   Training samples: {len(train_samples)}")
print(f"   Validation samples: {len(eval_samples)}")

if train_samples:
    sample = train_samples[0]
    print(f"\nüìÑ Sample from manual creation:")
    print(f"   Audio file: {os.path.basename(sample['audio_file'])}")
    print(f"   Text: '{sample['text']}'")
    print(f"   Text length: {len(sample['text'])}")
    print(f"   File exists: {os.path.exists(sample['audio_file'])}")

    # Test tokenizer on the sample text
    print(f"\nüîç Testing tokenizer on sample text:")
    test_ids = tokenizer.text_to_ids(sample['text'])
    print(f"   Token IDs: {len(test_ids)} tokens")
    print(f"   Decoded: '{tokenizer.ids_to_text(test_ids)}'")

# ===== OVERRIDE THE DATASET LOADING =====
print("\n" + "="*50)
print("üöÄ OVERRIDING DATASET CONFIGURATION")
print("="*50)

# The sample lists built above are handed to the trainer directly, so
# load_tts_samples is never called in this notebook.
print("\nüìä Using manually created samples for training")

# Skip the load_tts_samples step since we already have samples
print("‚úÖ Bypassing load_tts_samples function")

# Continue with the rest of the training setup...
print("\n‚û°Ô∏è Continuing with model creation and training...")

# ===== CONTINUE WITH THE REST OF THE CODE =====
# (Keep all the code after STEP 8 the same, starting with Step 9: Audio Configuration)


üìÇ Debugging dataset loading...
üîç Manually parsing: C:\Users\ReticleX\Pictures\nepali_tts\dataset\ljspeech_train\proper_metadata.csv
üìÑ Total lines: 100

üìã Manual parsing of first 5 lines:

Line 1: '0009e14baa.wav|Test text number 1 for Nepali TTS|Test text number 1 for Nepali TTS'
  Parts: 3
    Part 0: '0009e14baa.wav'
    Part 1: 'Test text number 1 for Nepali TTS'
    Part 2: 'Test text number 1 for Nepali TTS'
  Text analysis:
    Length: 33
    Is empty: False
    First 10 chars ascii: [84, 101, 115, 116, 32, 116, 101, 120, 116, 32]
    Contains newline: False

Line 2: '00157312d0.wav|Test text number 2 for Nepali TTS|Test text number 2 for Nepali TTS'
  Parts: 3
    Part 0: '00157312d0.wav'
    Part 1: 'Test text number 2 for Nepali TTS'
    Part 2: 'Test text number 2 for Nepali TTS'
  Text analysis:
    Length: 33
    Is empty: False
    First 10 chars ascii: [84, 101, 115, 116, 32, 116, 101, 120, 116, 32]
    Contains newline: False

Line 3: '00168daa2b.wav|Test te

In [None]:
# ===== STEP 9: AUDIO CONFIGURATION =====
# Spectrogram settings shared by the AudioProcessor and the VITS config.
print("\nüéµ Creating audio configuration...")

audio_config = BaseAudioConfig(
    sample_rate=22050,
    hop_length=256,
    win_length=1024,
    fft_size=1024,
    num_mels=80,
    mel_fmin=0.0,
    mel_fmax=8000.0,
)

# Create audio processor from config
ap = AudioProcessor.init_from_config(audio_config)
print(f"‚úÖ Audio processor: {ap.sample_rate} Hz")

# ===== STEP 10: VITS CONFIGURATION =====
# Start from the VitsConfig defaults and override individual attributes below.
print("\n‚öôÔ∏è VITS configuration...")

config = VitsConfig(
    output_path=OUTPUT,
    # Timestamped so every execution of this cell gets its own run name.
    run_name=f"nepali_vits_{time.strftime('%Y%m%d_%H%M%S')}",
)

# Set attributes
config.datasets = [dataset_config]
config.audio = audio_config

# Training parameters
config.batch_size = 4
config.eval_batch_size = 2
config.num_loader_workers = 0  # Set to 0 to avoid multiprocessing issues on Windows
config.num_eval_loader_workers = 0
config.epochs = 50  # Reduced for testing

# Text processing
config.text_cleaner = "basic_cleaners"
config.use_phonemes = False
config.add_blank = True
# The character set is supplied through the tokenizer object rather than
# through the config.
config.characters = None
# NOTE(review): this counts only `characters`; confirm whether the model needs
# the full vocabulary size including punctuation and pad/eos/bos/blank.
config.num_chars = len(tokenizer.characters.characters)

# Optimizer
# NOTE(review): presumably VITS keeps separate generator/discriminator learning
# rates and schedulers in its config - confirm that `lr`, `lr_scheduler` and
# `lr_scheduler_params` are actually read by the Coqui trainer for this model.
config.optimizer = "AdamW"
config.optimizer_params = {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}
config.lr = 2e-4
config.lr_scheduler = "ExponentialLR"
config.lr_scheduler_params = {"gamma": 0.999875}

# Training monitoring
config.print_step = 25
config.save_step = 500  # Reduced for testing
config.save_n_checkpoints = 3
config.run_eval = True

# Test sentences
config.test_sentences = [
    "‡§®‡§Æ‡§∏‡•ç‡§§‡•á",
    "‡§ß‡§®‡•ç‡§Ø‡§µ‡§æ‡§¶",
    "‡§ï‡•á ‡§õ",
    "‡§Æ ‡§®‡•á‡§™‡§æ‡§≤‡•Ä ‡§π‡•Å‡§Å"
]

print(f"‚úÖ Config created:")
print(f"   Run name: {config.run_name}")
print(f"   Batch size: {config.batch_size}")
print(f"   Learning rate: {config.lr}")
print(f"   Characters: {config.num_chars}")
print(f"   Epochs: {config.epochs}")


üéµ Creating audio configuration...
 > Setting up Audio Processor...
 | > sample_rate:22050
 | > resample:False
 | > num_mels:80
 | > log_func:np.log10
 | > min_level_db:-100
 | > frame_shift_ms:None
 | > frame_length_ms:None
 | > ref_level_db:20
 | > fft_size:1024
 | > power:1.5
 | > preemphasis:0.0
 | > griffin_lim_iters:60
 | > signal_norm:True
 | > symmetric_norm:True
 | > mel_fmin:0
 | > mel_fmax:8000.0
 | > pitch_fmin:1.0
 | > pitch_fmax:640.0
 | > spec_gain:20.0
 | > stft_pad_mode:reflect
 | > max_norm:4.0
 | > clip_norm:True
 | > do_trim_silence:True
 | > trim_db:45
 | > do_sound_norm:False
 | > do_amp_to_db_linear:True
 | > do_amp_to_db_mel:True
 | > do_rms_norm:False
 | > db_level:None
 | > stats_path:None
 | > base:10
 | > hop_length:256
 | > win_length:1024
‚úÖ Audio processor: 22050 Hz

‚öôÔ∏è VITS configuration...
‚úÖ Config created:
   Run name: nepali_vits_20251215_154211
   Batch size: 4
   Learning rate: 0.0002
   Characters: 164
   Epochs: 50


In [None]:
# ===== STEP 11: CREATE MODEL =====
# Instantiate a single-speaker VITS model (no speaker manager) from the config,
# audio processor and tokenizer built earlier, then place it on the GPU when
# one is available.
print("\nüß† Creating VITS model...")

model = Vits(config=config, ap=ap, tokenizer=tokenizer, speaker_manager=None)

if not torch.cuda.is_available():
    print("üíª Using CPU (training will be slower)")
else:
    print("üéÆ Moving model to GPU...")
    model.cuda()
    print(f"‚úÖ Using GPU: {torch.cuda.get_device_name(0)}")

# Total number of scalar parameters across all model tensors
n_params = sum(param.numel() for param in model.parameters())
print(f"‚úÖ Model ready ({n_params:,} parameters)")


üß† Creating VITS model...
üíª Using CPU (training will be slower)
‚úÖ Model ready (83,068,204 parameters)


In [None]:
# ===== STEP 12: CREATE TRAINER WITH MANUAL SAMPLES =====
# Build the Coqui `Trainer` around the model, handing it the hand-made sample
# lists instead of letting it load the dataset through a formatter.
print("\nüë®‚Äçüè´ Creating trainer with manual samples...")

# Fresh run: nothing to continue from or restore.
trainer_args = TrainerArgs(
    continue_path=None,
    restore_path=None,
    use_ddp=False,  # Disable distributed training for simplicity
)

# Create trainer with our manually created samples
# NOTE(review): the following cells build their own DataLoaders and a custom
# SimpleTrainer; confirm whether this `trainer` object is still used anywhere.
trainer = Trainer(
    trainer_args,
    config,
    OUTPUT,
    model=model,
    train_samples=train_samples,  # Use our manually created samples
    eval_samples=eval_samples,    # Use our manually created samples
)

print("‚úÖ Trainer ready!")

 > Training Environment:
 | > Backend: Torch
 | > Mixed precision: False
 | > Precision: float32
 | > Num. of CPUs: 16
 | > Num. of Torch Threads: 12
 | > Torch seed: 54321
 | > Torch CUDNN: True
 | > Torch CUDNN deterministic: False
 | > Torch CUDNN benchmark: False
 | > Torch TF32 MatMul: False



üë®‚Äçüè´ Creating trainer with manual samples...


 > Start Tensorboard: tensorboard --logdir=C:\Users\ReticleX\Pictures\nepali_tts\vits_output\nepali_vits_20251215_154211-December-15-2025_03+43PM-0000000

 > Model has 83068204 parameters


‚úÖ Trainer ready!


In [None]:
# ===== STEP 13: CREATE CUSTOM DATA LOADER (FIXED) =====
print("\nüîç Creating custom data loader...")

# NOTE(review): these imports sit in the middle of the notebook; consider
# moving them to the first import cell. `np` is not referenced in the part of
# the notebook visible here - confirm it is needed.
from torch.utils.data import Dataset, DataLoader
import numpy as np

class NepaliTTSDataset(Dataset):
    """Map-style dataset that yields tokenized text and raw waveform per sample.

    At construction time the sample list is filtered: a sample is kept only if
    its token sequence is at most ``max_text_length`` long and its audio file,
    judged by file size, is at most ``max_audio_length`` samples long. Samples
    whose audio file cannot be inspected are reported and dropped.

    Args:
        samples: Iterable of dicts with at least ``'text'`` and ``'audio_file'``.
        tokenizer: Object with ``text_to_ids(text)`` returning a sequence of ids.
        ap: Object with ``load_wav(path)`` returning the waveform samples.
        max_text_length: Maximum number of token ids per sample.
        max_audio_length: Maximum (approximate) number of audio samples.
    """
    def __init__(self, samples, tokenizer, ap, max_text_length=200, max_audio_length=100000):
        self.samples = samples
        self.tokenizer = tokenizer
        self.ap = ap
        self.max_text_length = max_text_length
        self.max_audio_length = max_audio_length

        # Filter samples that are too long
        self.filtered_samples = []
        for sample in samples:
            # Tokenize to check text length
            token_ids = tokenizer.text_to_ids(sample['text'])
            if len(token_ids) > max_text_length:
                continue

            # Check audio length from the file size alone (cheap, no decoding)
            try:
                file_size = os.stat(sample['audio_file']).st_size
            except Exception as err:
                # Best effort: a sample whose file cannot be inspected is
                # dropped, but visibly, and without swallowing
                # KeyboardInterrupt/SystemExit as a bare `except:` would.
                print(f"   Skipping sample, cannot inspect audio file "
                      f"{sample.get('audio_file') if isinstance(sample, dict) else sample!r}: {err}")
                continue

            # Approximation: assumes 16-bit mono PCM (2 bytes per sample) and
            # counts any file header bytes as audio.
            approx_length = file_size / 2
            if approx_length <= max_audio_length:
                self.filtered_samples.append(sample)

        print(f"   Original samples: {len(samples)}")
        print(f"   Filtered samples: {len(self.filtered_samples)}")

    def __len__(self):
        """Number of samples that passed the length filter."""
        return len(self.filtered_samples)

    def __getitem__(self, idx):
        """Return one sample as a dict of text, token ids, waveform and lengths."""
        sample = self.filtered_samples[idx]

        # Tokenize text
        text = sample['text']
        token_ids = self.tokenizer.text_to_ids(text)

        # Load and process audio
        waveform = self.ap.load_wav(sample['audio_file'])

        # Convert to tensor
        token_ids = torch.LongTensor(token_ids)
        waveform = torch.FloatTensor(waveform)

        return {
            'text': text,
            'token_ids': token_ids,
            'text_lengths': len(token_ids),
            'waveform': waveform,
            'waveform_lengths': len(waveform),
            'speaker_id': 0,  # Single speaker
            'audio_file': sample['audio_file'],
        }

def collate_fn(batch):
    """Pad a list of dataset items into one batch dict of tensors.

    Token id sequences and waveforms are right-padded with zeros to the longest
    item in the batch. The waveform tensor gets an extra channel axis, giving
    shape ``[batch, 1, time]``. Original lengths, speaker ids, raw texts and
    audio file names are passed along unchanged.
    """
    text_lens = [item['text_lengths'] for item in batch]
    audio_lens = [item['waveform_lengths'] for item in batch]
    n_items = len(batch)

    # Zero-filled canvases sized for the longest text / waveform in the batch
    padded_text = torch.zeros(n_items, max(text_lens), dtype=torch.long)
    padded_audio = torch.zeros(n_items, max(audio_lens), dtype=torch.float32)

    for row, item in enumerate(batch):
        padded_text[row, :text_lens[row]] = item['token_ids']
        padded_audio[row, :audio_lens[row]] = item['waveform']

    return {
        'text_input': padded_text,
        'text_lengths': torch.tensor(text_lens, dtype=torch.long),
        # Insert the channel axis: [B, T] -> [B, 1, T]
        'waveform': padded_audio.unsqueeze(1),
        'waveform_lengths': torch.tensor(audio_lens, dtype=torch.long),
        'speaker_ids': torch.tensor([item['speaker_id'] for item in batch], dtype=torch.long),
        'raw_text': [item['text'] for item in batch],
        'audio_unique_name': [item['audio_file'] for item in batch],
    }

# Create datasets
print("üìä Creating datasets...")
train_dataset = NepaliTTSDataset(train_samples, tokenizer, ap)
eval_dataset = NepaliTTSDataset(eval_samples, tokenizer, ap)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0,  # Set to 0 for Windows compatibility
)

eval_loader = DataLoader(
    eval_dataset,
    batch_size=config.eval_batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0,
)

print(f"‚úÖ Data loaders created:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Eval batches: {len(eval_loader)}")

# Smoke test: pull exactly one batch (the loop breaks after the first one;
# nothing happens if the loader is empty) and try to push it through the model.
print("\nüß™ Testing one batch...")
for batch in train_loader:
    print(f"‚úÖ Batch loaded successfully!")
    print(f"   Batch size: {batch['text_input'].shape[0]}")
    print(f"   Text shape: {batch['text_input'].shape}")
    print(f"   Text lengths: {batch['text_lengths']}")
    print(f"   Audio shape: {batch['waveform'].shape}")
    print(f"   Audio lengths: {batch['waveform_lengths']}")

    # Test model forward pass
    print("\nüß™ Testing model forward pass...")
    try:
        # Move tensor entries to GPU if available; lists (texts, names) stay as is
        if torch.cuda.is_available():
            batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        # The exact forward() signature is not known here, so two keyword
        # variants are tried in turn. If the second one fails too, its
        # exception is handled by the outer `except` below.
        try:
            # Try signature 1: Standard VITS forward
            outputs = model.forward(
                x=batch['text_input'],
                x_lengths=batch['text_lengths'],
                y=batch['waveform'],
                y_lengths=batch['waveform_lengths'],
                sid=batch['speaker_ids']
            )
            print(f"‚úÖ Forward pass successful with standard signature!")
        except Exception as e1:
            print(f"   Standard signature failed: {e1}")

            # Try signature 2: Without speaker IDs
            outputs = model.forward(
                x=batch['text_input'],
                x_lengths=batch['text_lengths'],
                y=batch['waveform'],
                y_lengths=batch['waveform_lengths'],
            )
            print(f"‚úÖ Forward pass successful without speaker IDs!")

        print(f"   Output keys: {list(outputs.keys())}")

        # Report every output entry whose key mentions "loss". (A second,
        # dict-only branch used to follow here but could never run, because
        # any dict already has `keys`.)
        for key in outputs.keys():
            if 'loss' in key:
                print(f"   {key}: {outputs[key].item():.4f}")

    except Exception as e:
        print(f"‚ùå Forward pass failed: {e}")

        # Debug: Check model forward method signature
        print("\nüîç Checking model forward method signature...")
        import inspect
        try:
            sig = inspect.signature(model.forward)
            print(f"   Forward signature: {sig}")

            # Show docstring
            if model.forward.__doc__:
                print(f"   Docstring: {model.forward.__doc__[:200]}...")
        except (TypeError, ValueError):
            # inspect.signature raises these when no signature can be provided
            print("   Could not inspect signature")

        # Try one more time with minimal (positional) arguments
        print("\nüîÑ Trying minimal forward pass...")
        try:
            outputs = model.forward(
                batch['text_input'],
                batch['text_lengths'],
                batch['waveform'],
                batch['waveform_lengths'],
            )
            print(f"‚úÖ Minimal forward pass successful!")
        except Exception as e2:
            print(f"‚ùå Minimal forward also failed: {e2}")

    break

# ===== STEP 14: CREATE CUSTOM TRAINER WRAPPER (FIXED) =====
# Probe which way of calling the model works on one real batch. Each probe that
# succeeds records its number in `forward_signature`; when several succeed, the
# last successful one wins.
print("\nüîß Creating custom training loop...")

# Let's check the actual forward method by looking at the source
print("üîç Inspecting VITS model structure...")

# Try to get a batch and see what happens
test_batch = next(iter(train_loader))
if torch.cuda.is_available():
    test_batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in test_batch.items()}

# Try different forward signatures
print("\nüß™ Testing different forward signatures...")

# Defined up front so the name exists (as None) even when every probe fails,
# and so a value left over from an earlier execution cannot leak into this one.
forward_signature = None

# Signature 1: Direct call
try:
    output1 = model(
        test_batch['text_input'],
        test_batch['text_lengths'],
        test_batch['waveform'],
        test_batch['waveform_lengths'],
        test_batch['speaker_ids']
    )
    print("‚úÖ Signature 1 (model() with 5 args) works!")
    forward_signature = 1
except Exception as e:
    print(f"‚ùå Signature 1 failed: {e}")

# Signature 2: Without speaker IDs
try:
    output2 = model(
        test_batch['text_input'],
        test_batch['text_lengths'],
        test_batch['waveform'],
        test_batch['waveform_lengths'],
    )
    print("‚úÖ Signature 2 (model() with 4 args) works!")
    forward_signature = 2
except Exception as e:
    print(f"‚ùå Signature 2 failed: {e}")

# Signature 3: Using model.forward_train (only if the model has such a method)
try:
    if hasattr(model, 'forward_train'):
        output3 = model.forward_train(
            test_batch['text_input'],
            test_batch['text_lengths'],
            test_batch['waveform'],
            test_batch['waveform_lengths'],
            test_batch['speaker_ids']
        )
        print("‚úÖ Signature 3 (forward_train) works!")
        forward_signature = 3
    else:
        print("‚ÑπÔ∏è model.forward_train not available")
except Exception as e:
    print(f"‚ùå Signature 3 failed: {e}")
# Based on which signature works, create the trainer
class SimpleTrainer:
    """Minimal single-optimizer training loop around ``model``.

    Each batch is moved to the selected device, the model is called with one
    of three call signatures, a scalar loss is pulled out of whatever the
    model returns, and AdamW is stepped. An exponential LR schedule is stepped
    once per epoch in ``fit``.

    Args:
        model: ``torch.nn.Module`` to train; moved to the device on construction.
        train_loader: Iterable of dict batches with the keys ``text_input``,
            ``text_lengths``, ``waveform`` and ``waveform_lengths`` (plus
            ``speaker_ids`` for signatures 1 and 3). ``fit`` also reads
            ``train_loader.dataset``.
        eval_loader: Same batch format; iterated by ``evaluate``.
        config: Object exposing ``lr``, ``optimizer_params``,
            ``lr_scheduler_params``, ``print_step``, ``save_step`` and
            ``batch_size``. Its ``__dict__`` is stored in checkpoints.
        output_dir: Directory for checkpoints; created if it does not exist.
        forward_signature: Which call form ``forward_model`` uses:
            1 = ``model(text, text_len, wav, wav_len, speaker_ids)``,
            2 = ``model(text, text_len, wav, wav_len)`` (default),
            3 = ``model.forward_train(text, text_len, wav, wav_len, speaker_ids)``.

    Raises:
        ValueError: If ``forward_signature`` is not 1, 2 or 3.
    """

    def __init__(self, model, train_loader, eval_loader, config, output_dir,
                 forward_signature=2):
        # Validate before any side effect (directory creation, device move).
        if forward_signature not in (1, 2, 3):
            raise ValueError(
                f"forward_signature must be 1, 2 or 3, got {forward_signature!r}"
            )

        self.model = model
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.config = config
        self.output_dir = output_dir
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Move model to device
        self.model.to(self.device)

        # Setup optimizer; betas is normalised to a tuple because the config
        # may hold it as a list.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.lr,
            betas=tuple(config.optimizer_params.get('betas', (0.8, 0.99))),
            eps=config.optimizer_params.get('eps', 1e-9),
            weight_decay=config.optimizer_params.get('weight_decay', 0.01)
        )

        # Setup scheduler (stepped once per epoch in fit())
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer,
            gamma=config.lr_scheduler_params.get('gamma', 0.999875)
        )

        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_loss = float('inf')

        # Call form used by forward_model(); defaults to the 4-argument version
        self.forward_signature = forward_signature

        print(f"‚úÖ SimpleTrainer initialized")
        print(f"   Device: {self.device}")
        print(f"   Forward signature: {self.forward_signature}")
        print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    def _to_device(self, batch):
        """Return a copy of ``batch`` with every tensor value moved to ``self.device``."""
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    @staticmethod
    def _extract_loss(outputs):
        """Pick the loss tensor out of a model output.

        Lookup order: ``outputs['loss']``; else the first dict key containing
        ``'loss'``; else the first dict value; the first element of a tuple;
        otherwise ``outputs`` itself.

        Raises:
            ValueError: If the selected value is not a single-element tensor.
                Such a value cannot be back-propagated or turned into a float,
                so this fails with a clear message instead of a later torch error.
        """
        if isinstance(outputs, dict):
            if 'loss' in outputs:
                loss = outputs['loss']
            else:
                loss_keys = [key for key in outputs if 'loss' in key]
                if loss_keys:
                    loss = outputs[loss_keys[0]]
                else:
                    # Best-effort fallback kept from the original: first item.
                    loss = next(iter(outputs.values()), None)
        elif isinstance(outputs, tuple):
            # Usually first element is loss
            loss = outputs[0] if outputs else None
        else:
            # Assume outputs is the loss tensor
            loss = outputs

        if not torch.is_tensor(loss) or loss.numel() != 1:
            raise ValueError(
                "Could not find a scalar loss tensor in model outputs "
                f"(got {type(loss).__name__})"
            )
        return loss

    def forward_model(self, batch):
        """Run the model on ``batch`` using the configured call signature."""
        core_args = (
            batch['text_input'],
            batch['text_lengths'],
            batch['waveform'],
            batch['waveform_lengths'],
        )
        if self.forward_signature == 1:
            return self.model(*core_args, batch['speaker_ids'])
        if self.forward_signature == 3:
            return self.model.forward_train(*core_args, batch['speaker_ids'])
        # signature 2
        return self.model(*core_args)

    def train_epoch(self):
        """Train for one pass over ``train_loader`` and return the mean loss."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(self.train_loader):
            batch = self._to_device(batch)

            # Forward pass
            self.optimizer.zero_grad()
            loss = self._extract_loss(self.forward_model(batch))

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Optimizer step
            self.optimizer.step()

            # Update statistics
            total_loss += loss.item()
            num_batches += 1
            self.global_step += 1

            # Print progress
            if (batch_idx + 1) % self.config.print_step == 0:
                avg_loss = total_loss / num_batches
                print(f"   Step {self.global_step}: Loss = {avg_loss:.4f}")

            # Save checkpoint periodically. This is checked on every step,
            # independently of print_step, so save_step is honoured even when
            # the two intervals never coincide.
            if self.global_step % self.config.save_step == 0:
                self.save_checkpoint(f"checkpoint_{self.global_step}")

        return total_loss / max(num_batches, 1)

    def evaluate(self):
        """Return the mean loss over ``eval_loader`` without updating weights."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                batch = self._to_device(batch)
                loss = self._extract_loss(self.forward_model(batch))
                total_loss += loss.item()
                num_batches += 1

        return total_loss / max(num_batches, 1)

    def save_checkpoint(self, name):
        """Write model, optimizer, scheduler and progress to ``<output_dir>/<name>.pth``."""
        checkpoint_path = os.path.join(self.output_dir, f"{name}.pth")

        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_loss': self.best_loss,
            'config': self.config.__dict__,
        }

        torch.save(checkpoint, checkpoint_path)
        print(f"   üíæ Checkpoint saved: {checkpoint_path}")

    def fit(self, epochs):
        """Train for ``epochs`` epochs, evaluating and checkpointing after each.

        Saves ``best_model`` whenever validation loss improves, ``epoch_N``
        every fifth epoch, and ``final_model`` at the end.
        """
        print(f"\nüöÄ Starting training for {epochs} epochs...")
        print(f"   Training samples: {len(self.train_loader.dataset)}")
        print(f"   Validation samples: {len(self.eval_loader.dataset)}")
        print(f"   Batch size: {self.config.batch_size}")
        print(f"   Learning rate: {self.config.lr}")

        for epoch in range(epochs):
            self.current_epoch = epoch
            print(f"\nüìà Epoch {epoch + 1}/{epochs}")

            # Train
            train_loss = self.train_epoch()
            print(f"   Train Loss: {train_loss:.4f}")

            # Evaluate
            val_loss = self.evaluate()
            print(f"   Val Loss: {val_loss:.4f}")

            # Update learning rate
            self.scheduler.step()

            # Save best model
            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self.save_checkpoint("best_model")

            # Save periodic checkpoint
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(f"epoch_{epoch + 1}")

        print(f"\nüéâ Training completed!")
        self.save_checkpoint("final_model")

# Test forward pass one more time to confirm
# Relies on `model`, `train_loader`, `eval_loader`, `config`, `OUTPUT` and
# `SimpleTrainer` from earlier in the notebook. On success it leaves
# `test_batch`, `outputs` and `simple_trainer` in the namespace.
print("\nüß™ Final forward pass test...")
try:
    test_batch = next(iter(train_loader))
    if torch.cuda.is_available():
        test_batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in test_batch.items()}

    # Try the 4-argument version (most likely). This is only a probe, so no
    # autograd graph is recorded for the stored outputs.
    with torch.no_grad():
        outputs = model(
            test_batch['text_input'],
            test_batch['text_lengths'],
            test_batch['waveform'],
            test_batch['waveform_lengths']
        )

    print(f"‚úÖ Forward test successful!")

    # Scalar tensors are checked first: every tensor has `.shape`, so testing
    # `hasattr(value, 'shape')` first would make the value branch unreachable.
    if isinstance(outputs, dict):
        print(f"   Output type: dict with keys: {list(outputs.keys())}")
        for key, value in outputs.items():
            if torch.is_tensor(value) and value.dim() == 0:
                print(f"   {key}: tensor with value {value.item():.4f}")
            elif hasattr(value, 'shape'):
                print(f"   {key}: shape {value.shape}")
    elif isinstance(outputs, tuple):
        print(f"   Output type: tuple with {len(outputs)} elements")
        for i, value in enumerate(outputs):
            if torch.is_tensor(value) and value.dim() == 0:
                print(f"   [{i}]: tensor with value {value.item():.4f}")
            elif hasattr(value, 'shape'):
                print(f"   [{i}]: shape {value.shape}")

    # Create simple trainer
    print("\nüë®‚Äçüè´ Creating SimpleTrainer...")
    simple_trainer = SimpleTrainer(
        model=model,
        train_loader=train_loader,
        eval_loader=eval_loader,
        config=config,
        output_dir=os.path.join(OUTPUT, config.run_name)
    )

except Exception as e:
    # NOTE: this handler also catches failures from SimpleTrainer construction
    # above, which are then reported under the "Forward test failed" message.
    print(f"‚ùå Forward test failed: {e}")
    print("\nüí° Manual workaround: Let's check the actual VITS source code...")

    # Imported only on this failure path, as in the original cell; it is not
    # used further within this block.
    import inspect
    print("üîç Model class source location:")
    print(f"   {model.__class__.__module__}.{model.__class__.__name__}")

    # Show available methods
    print("\nüìã Available methods in model:")
    methods = [m for m in dir(model) if not m.startswith('_')]
    for method in methods[:10]:  # Show first 10
        print(f"   - {method}")

    # Check if there's a train_step method
    if hasattr(model, 'train_step'):
        print("\n‚úÖ Found train_step method! Using that...")
    else:
        print("\n‚ùì No train_step found. Let's try to call the model directly...")


üîç Creating custom data loader...
üìä Creating datasets...
   Original samples: 100
   Filtered samples: 55
   Original samples: 20
   Filtered samples: 8
‚úÖ Data loaders created:
   Train batches: 14
   Eval batches: 4

üß™ Testing one batch...
‚úÖ Batch loaded successfully!
   Batch size: 4
   Text shape: torch.Size([4, 69])
   Text lengths: tensor([69, 69, 69, 67])
   Audio shape: torch.Size([4, 1, 94976])
   Audio lengths: tensor([87760, 55255, 56890, 94976])

üß™ Testing model forward pass...
   Standard signature failed: Vits.forward() got an unexpected keyword argument 'sid'
‚ùå Forward pass failed: Vits.forward() missing 1 required positional argument: 'waveform'

üîç Checking model forward method signature...
   Forward signature: (x: <built-in method tensor of type object at 0x00007FFF9029E8C0>, x_lengths: <built-in method tensor of type object at 0x00007FFF9029E8C0>, y: <built-in method tensor of type object at 0x00007FFF9029E8C0>, y_lengths: <built-in method tensor 

In [None]:
# ===== STEP 14: CREATE CUSTOM TRAINER USING train_step METHOD =====
print("\nüîß Creating custom training loop using train_step method...")

# Smoke-test model.train_step on a single training batch before wiring it into
# a trainer. Needs `model` and `train_loader` from earlier cells; any failure
# (including those names being undefined) is printed with a traceback rather
# than raised.
print("üß™ Testing train_step method...")

try:
    test_batch = next(iter(train_loader))
    if torch.cuda.is_available():
        # Tensor entries go to the GPU; non-tensor entries pass through as-is.
        test_batch = dict(
            (name, item.cuda() if isinstance(item, torch.Tensor) else item)
            for name, item in test_batch.items()
        )

    # Positional arguments for train_step, looked up in call order.
    step_inputs = [
        test_batch[field]
        for field in ('text_input', 'text_lengths', 'waveform', 'waveform_lengths')
    ]
    outputs = model.train_step(*step_inputs)

    print("‚úÖ train_step successful!")

    if not isinstance(outputs, dict):
        print(f"   Output type: {type(outputs)}")
    else:
        print("   Output type: dict")
        print(f"   Keys: {list(outputs.keys())}")

        # Report every tensor-valued entry whose key mentions "loss".
        for key, value in outputs.items():
            if 'loss' in key.lower() and torch.is_tensor(value):
                print(f"   {key}: {value.item():.4f}")

except Exception as e:
    print(f"‚ùå train_step failed: {e}")
    import traceback
    traceback.print_exc()

# Create trainer that uses train_step
class VitsTrainer:
    """Training loop that drives ``model.train_step`` with a single AdamW optimizer.

    Besides training and evaluation it tracks per-epoch losses, optionally logs
    to TensorBoard, writes checkpoints (with tokenizer/audio metadata) and
    periodically synthesizes a few test utterances.

    Args:
        model: ``torch.nn.Module`` exposing ``train_step`` and ``inference``;
            moved to the selected device on construction.
        train_loader: Sized iterable of dict batches with the keys
            ``text_input``, ``text_lengths``, ``waveform`` and
            ``waveform_lengths``; must expose ``.dataset``.
        eval_loader: Same format; iterated by ``evaluate``.
        config: Object exposing ``lr``, ``optimizer_params``,
            ``lr_scheduler_params``, ``print_step``, ``save_step`` and
            ``batch_size``. Its ``__dict__`` is written to checkpoints and to
            ``config.json``.
        output_dir: Directory for checkpoints, logs and samples; created if missing.
        tokenizer: Object exposing ``characters.characters`` and ``text_to_ids``.
        ap: Audio processor exposing ``sample_rate`` and ``num_mels``.
    """

    def __init__(self, model, train_loader, eval_loader, config, output_dir, tokenizer, ap):
        self.model = model
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.config = config
        self.output_dir = output_dir
        self.tokenizer = tokenizer
        self.ap = ap
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # TensorBoard writer; stays None unless fit() manages to create one.
        self.writer = None

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Move model to device
        self.model.to(self.device)

        # Setup optimizer based on config; betas is normalised to a tuple
        # because the config may hold it as a list.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.lr,
            betas=tuple(config.optimizer_params.get('betas', (0.8, 0.99))),
            eps=config.optimizer_params.get('eps', 1e-9),
            weight_decay=config.optimizer_params.get('weight_decay', 0.01)
        )

        # Setup scheduler (stepped once per epoch in fit())
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer,
            gamma=config.lr_scheduler_params.get('gamma', 0.999875)
        )

        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_loss = float('inf')

        # Loss tracking (one entry per completed epoch)
        self.train_losses = []
        self.val_losses = []

        print(f"‚úÖ VitsTrainer initialized")
        print(f"   Device: {self.device}")
        print(f"   Training samples: {len(train_loader.dataset)}")
        print(f"   Validation samples: {len(eval_loader.dataset)}")
        print(f"   Batch size: {config.batch_size}")

    def _to_device(self, batch):
        """Return a copy of ``batch`` with every tensor value moved to ``self.device``."""
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    def _run_step(self, batch):
        """Call ``model.train_step`` on one device-resident batch.

        Kept in one place so the call form can be changed for both training
        and evaluation at once.
        """
        # NOTE(review): upstream Coqui `Vits.train_step` appears to take
        # (batch, criterion, optimizer_idx) and return a tuple - confirm
        # against the installed TTS version before relying on this call form.
        return self.model.train_step(
            batch['text_input'],
            batch['text_lengths'],
            batch['waveform'],
            batch['waveform_lengths']
        )

    @staticmethod
    def _extract_loss(outputs):
        """Return the main loss from a ``train_step`` output dict.

        Lookup order: ``outputs['loss']``; else the first key that contains
        ``'loss'`` but does not start with ``'loss_'``; else the first key
        containing ``'loss'``.

        Raises:
            ValueError: If ``outputs`` is not a dict or holds no loss entry.
        """
        if not isinstance(outputs, dict):
            raise ValueError(f"Unexpected output type from train_step: {type(outputs)}")
        if 'loss' in outputs:
            return outputs['loss']

        loss_keys = [key for key in outputs if 'loss' in key]
        if not loss_keys:
            raise ValueError("No loss found in train_step outputs")
        for key in loss_keys:
            if not key.startswith('loss_'):
                return outputs[key]
        return outputs[loss_keys[0]]

    def train_epoch(self):
        """Train for one pass over ``train_loader`` and return the mean loss."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(self.train_loader):
            batch = self._to_device(batch)

            # Forward pass using train_step
            self.optimizer.zero_grad()
            outputs = self._run_step(batch)
            loss = self._extract_loss(outputs)

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Optimizer step
            self.optimizer.step()

            # Update statistics
            total_loss += loss.item()
            num_batches += 1
            self.global_step += 1

            # Print progress
            if (batch_idx + 1) % self.config.print_step == 0:
                avg_loss = total_loss / num_batches
                current_lr = self.optimizer.param_groups[0]['lr']

                # Collect every tensor-valued entry whose key mentions "loss"
                loss_details = "".join(
                    f", {key}: {value.item():.4f}"
                    for key, value in outputs.items()
                    if 'loss' in key and torch.is_tensor(value)
                )

                print(f"   Step {self.global_step}: Loss = {avg_loss:.4f}, LR = {current_lr:.6f}{loss_details}")

                # Log to tensorboard if available
                if self.writer is not None:
                    self.writer.add_scalar('train/loss', avg_loss, self.global_step)
                    for key, value in outputs.items():
                        if 'loss' in key and torch.is_tensor(value):
                            self.writer.add_scalar(f'train/{key}', value.item(), self.global_step)

            # Save checkpoint periodically. This is checked on every step,
            # independently of print_step, so save_step is honoured even when
            # the two intervals never coincide.
            if self.global_step % self.config.save_step == 0:
                self.save_checkpoint(f"checkpoint_{self.global_step}")

        return total_loss / max(num_batches, 1)

    def evaluate(self):
        """Return the mean loss over ``eval_loader`` without updating weights.

        Uses the same loss lookup as ``train_epoch`` so both report the same
        quantity; an output without a loss entry raises ``ValueError``.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                batch = self._to_device(batch)
                loss = self._extract_loss(self._run_step(batch))
                total_loss += loss.item()
                num_batches += 1

        return total_loss / max(num_batches, 1)

    def save_checkpoint(self, name):
        """Write ``<output_dir>/<name>.pth`` and refresh ``<output_dir>/config.json``."""
        checkpoint_path = os.path.join(self.output_dir, f"{name}.pth")

        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_loss': self.best_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'config': self.config.__dict__,
            'tokenizer_config': {
                'vocab_size': len(self.tokenizer.characters.characters),
                'characters': self.tokenizer.characters.characters,
            },
            'audio_config': {
                'sample_rate': self.ap.sample_rate,
                'num_mels': self.ap.num_mels,
            }
        }

        torch.save(checkpoint, checkpoint_path)
        print(f"   üíæ Checkpoint saved: {checkpoint_path}")

        # Also save config separately; non-JSON values are stringified.
        config_path = os.path.join(self.output_dir, "config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.config.__dict__, f, indent=2, default=str)

    def fit(self, epochs):
        """Train for ``epochs`` epochs with per-epoch evaluation and checkpoints.

        Saves ``best_model`` on validation improvement, ``epoch_N`` every fifth
        epoch and on the last epoch, synthesizes test samples every tenth
        epoch, and saves ``final_model`` at the end. The TensorBoard writer,
        if one was created, is closed even when training raises.
        """
        print(f"\nüöÄ Starting training for {epochs} epochs...")
        print(f"   Total steps per epoch: {len(self.train_loader)}")
        print(f"   Total training steps: {len(self.train_loader) * epochs}")
        print(f"   Learning rate: {self.config.lr}")
        print(f"   Output directory: {self.output_dir}")

        # Try to setup tensorboard. Only ordinary errors are swallowed here;
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.output_dir)
            print(f"   TensorBoard: tensorboard --logdir={self.output_dir}")
        except Exception:
            self.writer = None
            print("   TensorBoard not available, skipping...")

        start_time = time.time()

        try:
            for epoch in range(epochs):
                self.current_epoch = epoch
                epoch_start_time = time.time()

                print(f"\n{'='*60}")
                print(f"üìà Epoch {epoch + 1}/{epochs}")
                print(f"{'='*60}")

                # Train
                train_loss = self.train_epoch()
                self.train_losses.append(train_loss)

                # Evaluate
                val_loss = self.evaluate()
                self.val_losses.append(val_loss)

                epoch_time = time.time() - epoch_start_time
                total_time = time.time() - start_time

                print(f"\nüìä Epoch {epoch + 1} Summary:")
                print(f"   Train Loss: {train_loss:.4f}")
                print(f"   Val Loss:   {val_loss:.4f}")
                print(f"   Epoch Time: {epoch_time:.1f}s")
                print(f"   Total Time: {total_time:.1f}s")
                print(f"   Global Step: {self.global_step}")
                print(f"   Learning Rate: {self.optimizer.param_groups[0]['lr']:.6f}")

                # Log to tensorboard
                if self.writer is not None:
                    self.writer.add_scalar('epoch/train_loss', train_loss, epoch)
                    self.writer.add_scalar('epoch/val_loss', val_loss, epoch)
                    self.writer.add_scalar('epoch/lr', self.optimizer.param_groups[0]['lr'], epoch)

                # Update learning rate
                self.scheduler.step()

                # Save best model
                if val_loss < self.best_loss:
                    self.best_loss = val_loss
                    self.save_checkpoint("best_model")
                    print(f"   üèÜ New best model! Loss: {val_loss:.4f}")

                # Save periodic checkpoint
                if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
                    self.save_checkpoint(f"epoch_{epoch + 1}")

                # Generate test samples every 10 epochs
                if (epoch + 1) % 10 == 0:
                    self.generate_test_samples(epoch + 1)

            print(f"\n{'='*60}")
            print(f"üéâ TRAINING COMPLETED!")
            print(f"{'='*60}")
            print(f"   Total epochs: {epochs}")
            print(f"   Total steps: {self.global_step}")
            print(f"   Best validation loss: {self.best_loss:.4f}")
            print(f"   Total training time: {time.time() - start_time:.1f}s")

            # Save final model
            self.save_checkpoint("final_model")
        finally:
            # Close tensorboard writer even if training was interrupted
            if self.writer is not None:
                self.writer.close()
                self.writer = None

    def generate_test_samples(self, epoch):
        """Synthesize a few fixed sentences into ``<output_dir>/test_samples/epoch_<epoch>``.

        Each sentence is handled independently: a failure is printed and the
        loop continues. The model is put in eval mode for synthesis and back
        in train mode afterwards.
        """
        print(f"\nüé§ Generating test samples for epoch {epoch}...")

        test_dir = os.path.join(self.output_dir, "test_samples", f"epoch_{epoch}")
        os.makedirs(test_dir, exist_ok=True)

        test_texts = [
            "‡§®‡§Æ‡§∏‡•ç‡§§‡•á",
            "‡§ß‡§®‡•ç‡§Ø‡§µ‡§æ‡§¶",
            "‡§ï‡•á ‡§õ",
            "‡§Æ ‡§®‡•á‡§™‡§æ‡§≤‡•Ä ‡§π‡•Å‡§Å"
        ]

        self.model.eval()

        with torch.no_grad():
            for i, text in enumerate(test_texts):
                try:
                    # The model is fed a batch of token ids, not raw text, so
                    # the text goes through the stored tokenizer first.
                    # NOTE(review): assumes inference() accepts a
                    # (1, num_tokens) LongTensor - confirm against the
                    # installed TTS version.
                    token_ids = torch.LongTensor(
                        self.tokenizer.text_to_ids(text)
                    ).unsqueeze(0).to(self.device)

                    outputs = self.model.inference(token_ids)
                    audio = outputs["model_outputs"].squeeze().cpu().numpy()

                    # Save audio; imported here so a missing soundfile only
                    # skips sample generation instead of aborting training.
                    import soundfile as sf
                    test_file = os.path.join(test_dir, f"test_{i:02d}_{text[:10]}.wav".replace(' ', '_'))
                    sf.write(test_file, audio, self.ap.sample_rate)

                    print(f"   ‚úÖ '{text}' ‚Üí {test_file}")

                except Exception as e:
                    print(f"   ‚ùå Failed to generate '{text}': {e}")

        self.model.train()

# Build the trainer from objects created in earlier cells (model, loaders,
# config, tokenizer, audio processor). Checkpoints go to OUTPUT/<run_name>.
print("\nüë®‚Äçüè´ Creating VitsTrainer...")
vits_trainer = VitsTrainer(
    model,
    train_loader,
    eval_loader,
    config,
    os.path.join(OUTPUT, config.run_name),
    tokenizer,
    ap,
)


üîß Creating custom training loop using train_step method...
üß™ Testing train_step method...
‚ùå train_step failed: name 'train_loader' is not defined

üë®‚Äçüè´ Creating VitsTrainer...


Traceback (most recent call last):
  File "C:\Users\ReticleX\AppData\Local\Temp\ipykernel_836\3278468428.py", line 8, in <module>
    test_batch = next(iter(train_loader))
NameError: name 'train_loader' is not defined


NameError: name 'model' is not defined