# 🎓 Voice AI Training System - Training & Fine-tuning
## Dataset Preparation, Preprocessing, and Model Training

This notebook will:
1. ✅ Download and prepare LJSpeech dataset
2. ✅ Preprocess audio (resample, normalize, trim)
3. ✅ Split data into train/validation sets
4. ✅ Select optimal model based on GPU memory
5. ✅ Train/fine-tune TTS model with progress tracking
6. ✅ Validate and save checkpoints to Google Drive

**⚠️ Important:** Ensure setup.ipynb has been run first!

## Step 1: Import Libraries and Setup Paths

In [None]:
import os
import sys
import torch
import torchaudio
import librosa
import soundfile as sf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Working locations. BASE_DIR holds the raw and preprocessed data; DRIVE_DIR
# holds the artifacts written under the Google Drive path (checkpoints,
# outputs, logs).
BASE_DIR = "/content/voiceai"
DRIVE_DIR = "/content/drive/MyDrive/voiceai"
DATASET_DIR = BASE_DIR + "/dataset"
PROCESSED_DIR = BASE_DIR + "/processed"
CHECKPOINT_DIR = DRIVE_DIR + "/checkpoints"
OUTPUT_DIR = DRIVE_DIR + "/outputs"
LOG_DIR = DRIVE_DIR + "/logs"

# Make sure every working directory is present (no-op when it already exists).
_required_dirs = (
    BASE_DIR,
    DATASET_DIR,
    PROCESSED_DIR,
    CHECKPOINT_DIR,
    OUTPUT_DIR,
    LOG_DIR,
)
for dir_path in _required_dirs:
    os.makedirs(dir_path, exist_ok=True)

# Pick the compute device and report what hardware is available.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print(f"🔧 Device: {device}")
if not _cuda_ok:
    print("⚠️ No GPU detected! Training will be very slow.")
else:
    # total_memory is reported in bytes; convert to GiB for display.
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"🔧 GPU Memory: {gpu_memory:.1f} GB")
    print(f"🔧 GPU Name: {torch.cuda.get_device_name(0)}")

# Summarize the configured locations.
print("\n📂 Paths configured:")
for _label, _location in (
    ("Dataset", DATASET_DIR),
    ("Processed", PROCESSED_DIR),
    ("Checkpoints", CHECKPOINT_DIR),
    ("Outputs", OUTPUT_DIR),
):
    print(f"  • {_label}: {_location}")

## Step 2: Download LJSpeech Dataset

In [None]:
import urllib.request
import tarfile
from pathlib import Path

# LJSpeech download URL and the directory the archive unpacks into.
LJSPEECH_URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
LJSPEECH_PATH = f"{DATASET_DIR}/LJSpeech-1.1"

print("📥 Downloading LJSpeech dataset...")
print("="*60)

# Skip the download/extract step when the dataset is already on disk.
if os.path.exists(LJSPEECH_PATH) and os.path.exists(f"{LJSPEECH_PATH}/metadata.csv"):
    print("✅ LJSpeech dataset already exists!")
    print(f"   Path: {LJSPEECH_PATH}")
else:
    try:
        tar_path = f"{DATASET_DIR}/LJSpeech-1.1.tar.bz2"
        # Download to a temporary name and rename only on success, so an
        # interrupted download is never mistaken for a complete archive on
        # the next run. A leftover .part file is simply overwritten.
        partial_path = f"{tar_path}.part"

        if not os.path.exists(tar_path):
            print("⏳ Downloading... This may take 5-10 minutes (2.6 GB)")

            def download_progress(block_num, block_size, total_size):
                """Report hook for urlretrieve: print a one-line progress indicator.

                total_size may be -1 (or 0) when the server does not report a
                size; in that case only the downloaded byte count is shown.
                """
                downloaded = block_num * block_size
                if total_size > 0:
                    # The final block is usually partial, so clamp the estimate.
                    downloaded = min(downloaded, total_size)
                    percent = downloaded * 100 / total_size
                    sys.stdout.write(f'\r  Progress: {percent:.1f}% ({downloaded/(1024**3):.2f} GB / {total_size/(1024**3):.2f} GB)')
                else:
                    sys.stdout.write(f'\r  Progress: {downloaded/(1024**3):.2f} GB downloaded')
                sys.stdout.flush()

            urllib.request.urlretrieve(LJSPEECH_URL, partial_path, download_progress)
            os.replace(partial_path, tar_path)
            print("\n✅ Download complete!")

        # Extract dataset
        print("📦 Extracting dataset...")
        with tarfile.open(tar_path, 'r:bz2') as tar:
            if hasattr(tarfile, 'data_filter'):
                # Extraction filters reject members that would escape
                # DATASET_DIR (absolute paths, "..", unsafe links).
                tar.extractall(path=DATASET_DIR, filter='data')
            else:
                # Older interpreters have no extraction filters.
                tar.extractall(path=DATASET_DIR)

        print("✅ Extraction complete!")

        # Remove tar file to save space
        if os.path.exists(tar_path):
            os.remove(tar_path)
            print("✅ Cleaned up temporary files")

    except Exception as e:
        print(f"❌ Error downloading dataset: {e}")
        print("You can manually download from: https://keithito.com/LJ-Speech-Dataset/")
        raise

# Verify that both the metadata file and the audio directory are present.
metadata_path = f"{LJSPEECH_PATH}/metadata.csv"
wavs_path = f"{LJSPEECH_PATH}/wavs"

if os.path.exists(metadata_path) and os.path.exists(wavs_path):
    num_wavs = sum(1 for _ in Path(wavs_path).glob("*.wav"))
    print(f"\n✅ Dataset verified!")
    print(f"   • Metadata: {metadata_path}")
    print(f"   • Audio files: {num_wavs} wav files")
    print(f"   • Total size: ~2.6 GB")
else:
    print("❌ Dataset verification failed!")

print("="*60)

## Step 3: Load and Explore Dataset

In [None]:
import csv

import pandas as pd

print("📊 Loading dataset metadata...")
print("="*60)

# Load metadata.
# The file is pipe-delimited with no header row. Quote handling is disabled
# because transcripts may contain literal double-quote characters; with the
# default quoting, an unbalanced quote makes the parser swallow following
# lines into a single field, silently merging or dropping samples.
metadata_path = f"{LJSPEECH_PATH}/metadata.csv"
df = pd.read_csv(
    metadata_path,
    sep='|',
    header=None,
    names=['filename', 'transcript', 'normalized_transcript'],
    quoting=csv.QUOTE_NONE,
)

print(f"✅ Loaded {len(df)} samples")
print(f"\n📊 Dataset Statistics:")
print(f"   • Total samples: {len(df)}")
print(f"   • Columns: {list(df.columns)}")

# Display sample data
print(f"\n📝 Sample transcripts:")
print(df[['filename', 'transcript']].head())

# Check for missing values
print(f"\n🔍 Data Quality:")
print(f"   • Missing transcripts: {df['transcript'].isna().sum()}")
print(f"   • Missing filenames: {df['filename'].isna().sum()}")

# Calculate transcript lengths (character counts of the raw transcript column)
df['transcript_length'] = df['transcript'].str.len()
print(f"\n📏 Transcript Length Statistics:")
print(f"   • Mean: {df['transcript_length'].mean():.1f} characters")
print(f"   • Min: {df['transcript_length'].min()} characters")
print(f"   • Max: {df['transcript_length'].max()} characters")

print("="*60)