# Echo-TTS Rap Style Fine-Tuning

This notebook implements LoRA-based fine-tuning for Echo-TTS to adapt the model to **rapping style** while preserving voice cloning capabilities.

**Dataset:**
- 225 rap acapellas (unprocessed)
- Requires preprocessing: segmentation, transcription

**Requirements:**
- GPU with 16GB+ VRAM (T4/A100 on Colab)

**What you'll get:**
- Fine-tuned model that generates rap-style speech
- LoRA checkpoint (~50-100MB) that can be loaded on top of base model
- Preserved voice cloning - can still use any speaker reference

## 1. Setup & Dependencies

In [None]:
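# Clone a fresh copy of the echo-tts repository
import os
import sys

# Leave the repo directory before deleting it (matters when this cell is re-run)
os.chdir('/content')
!rm -rf /content/echo-tts
!git clone https://github.com/CoreBedtime/echo-tts.git /content/echo-tts || echo "Clone failed - check network access and the repository URL"

# Add to Python path
sys.path.insert(0, '/content/echo-tts')

# Change working directory
os.chdir('/content/echo-tts')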

# Install dependencies
!pip install -q torch torchaudio safetensors huggingface-hub einops
!pip install -q openai-whisper  # For automatic transcription
!pip install -q torchcodec  # For audio decoding

print(f"Working directory: {os.getcwd()}")

In [None]:
# Mount Google Drive for saving checkpoints (Colab)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    IN_COLAB = True
    print("Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("Running locally")

In [None]:
import os
import sys
from pathlib import Path

import torch
import torchaudio
from IPython.display import Audio, display

# Check GPU
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Configuration

In [None]:
import os  # needed for os.makedirs below

import torch

# =============================================================================
# PATHS - Update these!
# =============================================================================

# Directory containing your training audio files
AUDIO_DIR = "/content/dataset/"  # Put your rap acapellas here

# Where to save checkpoints
OUTPUT_DIR = "./checkpoints/"  # Point this at a folder under /content/drive to keep checkpoints after the Colab session ends

# =============================================================================
# LoRA CONFIGURATION (optimized for 225 samples)
# =============================================================================

LORA_RANK = 32          # Higher rank for larger dataset (more expressiveness)
LORA_ALPHA = 32.0       # Scaling factor, typically equal to rank
LORA_DROPOUT = 0.05     # Lower dropout - more data means less regularization needed

# Which modules to train (default preserves voice cloning path)
# Speaker path (wk_speaker, wv_speaker) is NOT trained to preserve cloning
TARGET_MODULES = [
    # Main decoder attention (style)
    "blocks.*.attention.wq",
    "blocks.*.attention.wk",
    "blocks.*.attention.wv",
    "blocks.*.attention.wo",
    # Text cross-attention (text-to-audio mapping)
    "blocks.*.attention.wk_text",
    "blocks.*.attention.wv_text",
    # MLP layers (feature transformation)
    "blocks.*.mlp.w1",
    "blocks.*.mlp.w2",
    "blocks.*.mlp.w3",
]

# =============================================================================
# TRAINING CONFIGURATION (optimized for 225 samples)
# =============================================================================

LEARNING_RATE = 1e-4     # Slightly higher LR for larger dataset
NUM_EPOCHS = 10          # Fewer epochs needed with more data
BATCH_SIZE = 1           # Batch size (1 for memory efficiency)
GRADIENT_ACCUMULATION = 8  # Larger effective batch for stability
MAX_GRAD_NORM = 1.0      # Gradient clipping
WARMUP_STEPS = 100       # More warmup steps for larger dataset

# Audio settings
MAX_LATENT_LENGTH = 640  # Max ~30 seconds (reduce to 320 if OOM)
SEGMENT_DURATION = 25.0  # Split long audio into ~25 second chunks
MIN_SEGMENT_DURATION = 5.0  # Minimum segment length to keep

# Whisper model for transcription - LARGE-V3 for best rap lyrics accuracy!
WHISPER_MODEL = "large-v3"  # Options: tiny, base, small, medium, large-v3

# Parallel transcription settings
NUM_TRANSCRIPTION_WORKERS = 4  # Number of parallel workers (lower this if transcription runs out of memory)
TRANSCRIPTION_BATCH_SIZE = 8   # Files per batch for progress updates

# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16   # Use bfloat16 for training

# Validation split
VAL_SPLIT = 0.1  # 10% for validation (~22 samples)

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(AUDIO_DIR, exist_ok=True)

print(f"Audio directory: {AUDIO_DIR}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Device: {DEVICE}")
print(f"\nTraining config for 225 samples:")
print(f"  LoRA rank: {LORA_RANK}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"\nTranscription config:")
print(f"  Whisper model: {WHISPER_MODEL}")
print(f"  Parallel workers: {NUM_TRANSCRIPTION_WORKERS}")
print("  Parallel transcription should be faster than sequential, ideally by up to the worker count")

## 3. Load Base Models

In [None]:
from inference import (
    load_model_from_hf,
    load_fish_ae_from_hf,
    load_pca_state_from_hf,
    ae_decode,
    get_speaker_latent_and_mask,
    get_text_input_ids_and_mask,
    sample_euler_cfg_independent_guidances,
    sample_pipeline,
    load_audio,
)
from functools import partial

print("Loading EchoDiT model...")
model = load_model_from_hf(
    device=DEVICE,
    dtype=DTYPE,
    compile=False,  # Don't compile for training
    delete_blockwise_modules=True,  # Save memory
)
print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")

print("\nLoading Fish-S1-DAC autoencoder...")
fish_ae = load_fish_ae_from_hf(
    device=DEVICE,
    dtype=torch.float32,  # AE needs float32 for quality
)
print("Autoencoder loaded")

print("\nLoading PCA state...")
pca_state = load_pca_state_from_hf(device=DEVICE)
print("PCA state loaded")

print("\n" + "="*50)
print("All models loaded successfully!")
print("="*50)

## 4. Data Preparation

With 225 rap acapellas, we need to:
1. List and validate all audio files
2. Segment long tracks into training-sized chunks (~25 seconds)
3. Transcribe using Whisper (the `large-v3` model set in `WHISPER_MODEL`, chosen for better rap lyrics accuracy)
4. Create train/validation split
5. Pre-encode all audio to latents (cached for fast training)
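
Step 2 deserves a concrete picture. The cells below import `segment_audio` from `train_utils`, but its behavior is not shown in this notebook, so the following is only a minimal sketch of one possible chunking policy. The name `plan_segments` is hypothetical, and dropping a too-short tail is an assumption; the defaults mirror `SEGMENT_DURATION` and `MIN_SEGMENT_DURATION` from the configuration.

```python
def plan_segments(total_seconds, segment_duration=25.0, min_duration=5.0):
    """Return (start, end) times in seconds for consecutive chunks.

    Chunks are at most `segment_duration` long. A trailing chunk shorter
    than `min_duration` is dropped (an assumed policy, see the text above).
    """
    if segment_duration <= 0:
        raise ValueError("segment_duration must be positive")
    segments = []
    start = 0.0
    while start < total_seconds:
        end = min(start + segment_duration, total_seconds)
        if end - start >= min_duration:
            segments.append((start, end))
        start = end
    return segments


# Example: a 3-minute acapella (180 s)
for start, end in plan_segments(180.0):
    print(f"{start:6.1f}s -> {end:6.1f}s")
```

Boundaries in seconds can be converted to sample indices by multiplying with the file's sample rate.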

In [None]:
import os
from pathlib import Path

# List all audio files
audio_files = []
for ext in [".mp3", ".wav", ".flac", ".m4a", ".ogg"]:
    audio_files.extend(Path(AUDIO_DIR).glob(f"**/*{ext}"))

print(f"Found {len(audio_files)} audio files in {AUDIO_DIR}")

if len(audio_files) == 0:
    print("\n⚠️  No audio files found!")
    print(f"Please add your rap acapella files to: {AUDIO_DIR}")
else:
    print(f"\nSample files:")
    for f in list(audio_files)[:5]:
        print(f"  {f.name}")

In [None]:
import os
from pathlib import Path

# Transcribe all audio files in parallel using the Whisper model set in WHISPER_MODEL.
# Results are cached, so re-running this cell only processes files not yet transcribed.

from train_utils import transcribe_audio_files_parallel
import json

TRANSCRIPTION_CACHE = os.path.join(OUTPUT_DIR, "transcriptions.json")

# Check if we have cached transcriptions
if os.path.exists(TRANSCRIPTION_CACHE):
    print(f"Loading cached transcriptions from {TRANSCRIPTION_CACHE}")
    with open(TRANSCRIPTION_CACHE, "r") as f:
        transcriptions = json.load(f)
    print(f"Loaded {len(transcriptions)} cached transcriptions")
    
    # Find files that still need transcription
    cached_paths = set(transcriptions.keys())
    files_to_transcribe = [f for f in audio_files if str(f) not in cached_paths]
    print(f"Files still needing transcription: {len(files_to_transcribe)}")
else:
    transcriptions = {}
    files_to_transcribe = audio_files

if len(files_to_transcribe) > 0:
    print(f"\n{'='*60}")
    print(f"PARALLEL TRANSCRIPTION with Whisper {WHISPER_MODEL}")
    print(f"{'='*60}")
    print(f"Files to transcribe: {len(files_to_transcribe)}")
    print(f"Workers: {NUM_TRANSCRIPTION_WORKERS}")
    print(f"\nRough time estimate (assumes ~0.5 min per file and ideal linear scaling):")
    print(f"  Sequential: ~{len(files_to_transcribe) * 0.5:.0f} minutes")
    print(f"  Parallel: ~{len(files_to_transcribe) * 0.5 / NUM_TRANSCRIPTION_WORKERS:.0f} minutes")
    print(f"  Ideal speedup: up to {NUM_TRANSCRIPTION_WORKERS}x\n")
    
    # Transcribe in parallel
    batch_transcriptions = transcribe_audio_files_parallel(
        audio_paths=[str(f) for f in files_to_transcribe],
        model_name=WHISPER_MODEL,
        language="en",
        num_workers=NUM_TRANSCRIPTION_WORKERS,
        batch_size=TRANSCRIPTION_BATCH_SIZE,
    )
    
    transcriptions.update(batch_transcriptions)
    
    # Save progress
    with open(TRANSCRIPTION_CACHE, "w") as f:
        json.dump(transcriptions, f, indent=2)
    
    print(f"\nSaved to: {TRANSCRIPTION_CACHE}")

print(f"\n{'='*60}")
print(f"Transcription complete: {len(transcriptions)} files")
print(f"{'='*60}")

# Show a few samples
print(f"\nSample transcriptions:")
for i, (path, text) in enumerate(list(transcriptions.items())[:3]):
    filename = Path(path).name
    print(f"\n{filename}:")
    print(f"  {text[:150]}{'...' if len(text) > 150 else ''}")

In [None]:
# Optional: Edit transcriptions manually if Whisper made mistakes
# Uncomment and modify as needed:

# transcriptions["/path/to/file.mp3"] = "[S1] Your corrected transcription here."

# Tips for transcriptions:
# - Start with [S1] for single speaker
# - Use commas for pauses
# - Exclamation marks increase expressiveness
# - Keep punctuation natural

In [None]:
# Create training dataset with train/val split
from train_utils import (
    TrainingSample,
    EchoTTSDataset,
    collate_fn,
    segment_audio,
    load_audio_tensor,
)
from torch.utils.data import DataLoader
import random

# Create training samples
all_samples = []
for path, text in transcriptions.items():
    if text and len(text.strip()) > 10:  # Filter out empty/very short transcriptions
        all_samples.append(TrainingSample(
            audio_path=path,
            text=text,
            speaker_audio_path=None,  # Use same audio as speaker reference
        ))

print(f"Created {len(all_samples)} training samples (filtered {len(transcriptions) - len(all_samples)} empty)")

# Shuffle and split into train/val
random.seed(42)
random.shuffle(all_samples)

val_size = max(1, int(len(all_samples) * VAL_SPLIT))
train_samples = all_samples[val_size:]
val_samples = all_samples[:val_size]

print(f"Train samples: {len(train_samples)}")
print(f"Validation samples: {len(val_samples)}")

# Create datasets
print("\nCreating training dataset and encoding audio to latents...")
print("This will take a few minutes for 225 files...\n")

train_dataset = EchoTTSDataset(
    samples=train_samples,
    fish_ae=fish_ae,
    pca_state=pca_state,
    device=DEVICE,
    max_latent_length=MAX_LATENT_LENGTH,
    cache_latents=True,
)

print("\nCreating validation dataset...")
val_dataset = EchoTTSDataset(
    samples=val_samples,
    fish_ae=fish_ae,
    pca_state=pca_state,
    device=DEVICE,
    max_latent_length=MAX_LATENT_LENGTH,
    cache_latents=True,
)

# Create dataloaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0,
)

print(f"\n{'='*50}")
print(f"Dataset ready!")
print(f"  Train: {len(train_dataset)} samples, {len(train_dataloader)} batches/epoch")
print(f"  Val: {len(val_dataset)} samples, {len(val_dataloader)} batches")
print(f"{'='*50}")

## 5. Apply LoRA to Model

In [None]:
from lora import (
    apply_lora_to_model,
    count_parameters,
    get_lora_params,
    save_lora_checkpoint,
    load_lora_checkpoint,
)

# Apply LoRA adapters to the model
print("Applying LoRA adapters...")
print(f"  Rank: {LORA_RANK}")
print(f"  Alpha: {LORA_ALPHA}")
print(f"  Dropout: {LORA_DROPOUT}")
print(f"  Target modules: {len(TARGET_MODULES)} patterns")

model, lora_modules = apply_lora_to_model(
    model,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
    dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES,
)

# Count parameters
total_params, trainable_params = count_parameters(model)
print(f"\nParameter counts:")
print(f"  Total: {total_params / 1e6:.1f}M")
print(f"  Trainable (LoRA): {trainable_params / 1e6:.2f}M ({100 * trainable_params / total_params:.2f}%)")
print(f"  LoRA modules applied: {len(lora_modules)}")

## 6. Training

In [None]:
from train_utils import train_epoch, get_cosine_schedule_with_warmup, training_step

# Setup optimizer (only LoRA params)
lora_params = get_lora_params(model)
optimizer = torch.optim.AdamW(
    lora_params,
    lr=LEARNING_RATE,
    weight_decay=0.01,
    betas=(0.9, 0.999),
)

# Learning rate scheduler
num_training_steps = len(train_dataloader) * NUM_EPOCHS // GRADIENT_ACCUMULATION
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=min(WARMUP_STEPS, num_training_steps // 10),
    num_training_steps=num_training_steps,
)

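# Gradient scaler passed to train_epoch. Loss scaling mainly matters for float16;
# with bfloat16 autocast it is generally unnecessary but harmless.
scaler = torch.amp.GradScaler("cuda")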

print("Training setup:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Gradient accumulation: {GRADIENT_ACCUMULATION}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Training steps: {num_training_steps}")
print(f"  Warmup steps: {min(WARMUP_STEPS, num_training_steps // 10)}")
print(f"  Train batches/epoch: {len(train_dataloader)}")
print(f"  Val batches: {len(val_dataloader)}")

# Validation function
@torch.no_grad()
def validate(model, val_dataloader, device):
    model.eval()
    total_loss = 0.0
    num_batches = 0
    
    for batch in val_dataloader:
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            loss = training_step(model, batch, device)
        
        # Skip NaN losses in validation too
        if not torch.isnan(loss) and not torch.isinf(loss):
            total_loss += loss.item()
            num_batches += 1
    
    model.train()
    
    if num_batches == 0:
        return float('nan')
    
    return total_loss / num_batches

In [None]:
# Debug: Test a single training step to diagnose NaN issues
print("Running diagnostic training step...")

# Get a single batch
test_batch = next(iter(train_dataloader))

# Check batch data
print("\nBatch data statistics:")
print(f"  Latent shape: {test_batch['latent'].shape}")
print(f"  Latent range: [{test_batch['latent'].min():.4f}, {test_batch['latent'].max():.4f}]")
print(f"  Latent has NaN: {torch.isnan(test_batch['latent']).any()}")
print(f"  Speaker latent shape: {test_batch['speaker_latent'].shape}")
print(f"  Speaker latent range: [{test_batch['speaker_latent'].min():.4f}, {test_batch['speaker_latent'].max():.4f}]")
print(f"  Speaker latent has NaN: {torch.isnan(test_batch['speaker_latent']).any()}")
print(f"  Texts: {test_batch['text']}")

# Try forward pass
try:
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        loss = training_step(model, test_batch, DEVICE)
    print(f"\nTest loss: {loss.item():.4f}")
    print("✓ Forward pass successful!")
    
    # Try backward pass
    loss.backward()
    print("✓ Backward pass successful!")
    
    # Check gradients
    grad_norms = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            grad_norms.append((name, grad_norm))
            if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                print(f"  ⚠️  NaN/Inf gradient in {name}")
    
    if grad_norms:
        # Sort by gradient magnitude
        grad_norms.sort(key=lambda x: x[1], reverse=True)
        print(f"\nTop 5 gradient magnitudes:")
        for name, norm in grad_norms[:5]:
            print(f"  {name}: {norm:.4f}")
    
    optimizer.zero_grad()
    
except Exception as e:
    print(f"\n❌ Error during test step: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# Training loop with validation
print("\n" + "="*50)
print("Starting training...")
print("="*50 + "\n")

history = {"train_loss": [], "val_loss": [], "epoch": [], "lr": []}
best_val_loss = float("inf")

for epoch in range(NUM_EPOCHS):
    # Train one epoch (the scheduler is stepped inside train_epoch)
    train_loss = train_epoch(
        model=model,
        dataloader=train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,  # Pass scheduler to training loop
        device=DEVICE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        max_grad_norm=MAX_GRAD_NORM,
        scaler=scaler,
    )
    
    # Validate
    val_loss = validate(model, val_dataloader, DEVICE)
    
    # Get current LR (already updated inside train_epoch)
    current_lr = scheduler.get_last_lr()[0]
    
    # Record history
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["epoch"].append(epoch + 1)
    history["lr"].append(current_lr)
    
    # Print progress
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Train: {train_loss:.4f} - Val: {val_loss:.4f} - LR: {current_lr:.2e}")
    
    # Save best checkpoint (based on validation loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_lora_checkpoint(
            model,
            os.path.join(OUTPUT_DIR, "lora_best.pt"),
            config={
                "rank": LORA_RANK,
                "alpha": LORA_ALPHA,
                "dropout": LORA_DROPOUT,
                "target_modules": TARGET_MODULES,
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_loss,
            }
        )
        print(f"  -> Saved best checkpoint (val_loss: {best_val_loss:.4f})")
    
    # Periodic checkpoint every 2 epochs
    if (epoch + 1) % 2 == 0:
        save_lora_checkpoint(
            model,
            os.path.join(OUTPUT_DIR, f"lora_epoch_{epoch + 1}.pt"),
            config={
                "rank": LORA_RANK,
                "alpha": LORA_ALPHA,
                "dropout": LORA_DROPOUT,
                "target_modules": TARGET_MODULES,
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_loss,
            }
        )

# Save final checkpoint
save_lora_checkpoint(
    model,
    os.path.join(OUTPUT_DIR, "lora_final.pt"),
    config={
        "rank": LORA_RANK,
        "alpha": LORA_ALPHA,
        "dropout": LORA_DROPOUT,
        "target_modules": TARGET_MODULES,
        "epoch": NUM_EPOCHS,
        "train_loss": history["train_loss"][-1],
        "val_loss": history["val_loss"][-1],
    }
)

print("\n" + "="*50)
print("Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Checkpoints saved to: {OUTPUT_DIR}")
print("="*50)

In [None]:
# Plot training curves
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Training and validation loss
axes[0].plot(history["epoch"], history["train_loss"], 'b-', linewidth=2, label='Train')
axes[0].plot(history["epoch"], history["val_loss"], 'r--', linewidth=2, label='Val')
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training vs Validation Loss")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Training loss only (zoomed)
axes[1].plot(history["epoch"], history["train_loss"], 'b-', linewidth=2)
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Loss")
axes[1].set_title("Training Loss")
axes[1].grid(True, alpha=0.3)

# Learning rate
axes[2].plot(history["epoch"], history["lr"], 'g-', linewidth=2)
axes[2].set_xlabel("Epoch")
axes[2].set_ylabel("Learning Rate")
axes[2].set_title("Learning Rate Schedule")
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "training_curves.png"), dpi=150)
plt.show()

# Print summary
print(f"\nTraining Summary:")
print(f"  Final train loss: {history['train_loss'][-1]:.4f}")
print(f"  Final val loss: {history['val_loss'][-1]:.4f}")
print(f"  Best val loss: {best_val_loss:.4f}")

## 7. Evaluation & Inference

Generate samples with your fine-tuned model and compare to the base model.

In [None]:
# Set model to eval mode
model.eval()

# Test prompts - rap style!
TEST_PROMPTS = [
    "[S1] Yeah, I'm spitting fire on the mic tonight, gonna show them what I got, rising to the top!",
    "[S1] The rhythm flows through me like water, every beat hits harder, I'm a natural born starter!",
    "[S1] Check it out, I'm the one they've been waiting for, coming through the door, ready to explore!",
    "[S1] Money on my mind, grind never stops, from the bottom to the top, watch me drop!",
    "[S1] Real recognize real, that's the deal, keep it trill, got the skill to make you feel!",
]

# Use a random file from the dataset as speaker reference (may come from the train or validation split)
SPEAKER_AUDIO_PATH = random.choice(audio_files) if audio_files else None

print(f"Using speaker reference: {SPEAKER_AUDIO_PATH}")
print(f"\nWill generate {len(TEST_PROMPTS)} samples...")

In [None]:
# Generate samples with fine-tuned model
@torch.inference_mode()
def generate_audio(model, text, speaker_audio_path=None, seed=0):
    """Generate audio using the (fine-tuned) model."""
    
    # Load speaker audio
    if speaker_audio_path:
        speaker_audio = load_audio(speaker_audio_path)
    else:
        speaker_audio = None
    
    # Create sample function
    sample_fn = partial(
        sample_euler_cfg_independent_guidances,
        num_steps=40,
        cfg_scale_text=3.0,
        cfg_scale_speaker=8.0,
        cfg_min_t=0.5,
        cfg_max_t=1.0,
        truncation_factor=0.8,
        rescale_k=None,
        rescale_sigma=None,
        speaker_kv_scale=None,
        speaker_kv_max_layers=None,
        speaker_kv_min_t=None,
        sequence_length=640,
    )
    
    # Generate
    audio_out, normalized_text = sample_pipeline(
        model=model,
        fish_ae=fish_ae,
        pca_state=pca_state,
        sample_fn=sample_fn,
        text_prompt=text,
        speaker_audio=speaker_audio,
        rng_seed=seed,
    )
    
    return audio_out[0].cpu(), normalized_text

# Generate and play samples
print("Generating samples with fine-tuned model...\n")

for i, prompt in enumerate(TEST_PROMPTS):
    print(f"Prompt {i + 1}: {prompt}")
    
    audio, _ = generate_audio(
        model,
        prompt,
        speaker_audio_path=str(SPEAKER_AUDIO_PATH) if SPEAKER_AUDIO_PATH else None,
        seed=i,
    )
    
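    # Save audio. torchaudio.save expects a 2-D [channels, samples] tensor;
    # reshape handles both 1-D and [1, samples] mono outputs.
    wav = audio.float().reshape(1, -1)
    output_path = os.path.join(OUTPUT_DIR, f"sample_{i + 1}.wav")
    torchaudio.save(output_path, wav, 44100)
    print(f"Saved to: {output_path}")
    
    # Play audio
    display(Audio(wav.squeeze(0).numpy(), rate=44100))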
    print()

## 8. Load Checkpoint for Later Use

Use this section to load a saved LoRA checkpoint onto a fresh model.

In [None]:
# Example: Load a saved LoRA checkpoint
# Uncomment and run this cell to load a checkpoint

# CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "lora_best.pt")
# 
# # Load fresh base model
# model_fresh = load_model_from_hf(
#     device=DEVICE,
#     dtype=DTYPE,
#     compile=False,
#     delete_blockwise_modules=True,
# )
# 
# # Load checkpoint to get config
# checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
# config = checkpoint["config"]
# 
# # Apply LoRA with saved config
# model_fresh, _ = apply_lora_to_model(
#     model_fresh,
#     rank=config["rank"],
#     alpha=config["alpha"],
#     dropout=0.0,  # No dropout for inference
#     target_modules=config["target_modules"],
# )
# 
# # Load LoRA weights
# load_lora_checkpoint(model_fresh, CHECKPOINT_PATH, device=DEVICE)
# model_fresh.eval()
# 
# # The saved config stores "train_loss" and "val_loss" (there is no "loss" key)
# print(f"Loaded checkpoint from epoch {config['epoch']} (val_loss: {config['val_loss']:.4f})")

## 9. Tips & Notes for Rap Training

### What We Configured for 225 Samples
- **LoRA rank 32**: Higher rank captures more style nuance with larger dataset
- **Lower dropout (0.05)**: Less regularization needed with more data
- **Higher learning rate (1e-4)**: Can train faster with more data
- **Fewer epochs (10)**: More data means fewer passes needed
- **Whisper `large-v3`**: Best rap lyrics accuracy among the listed Whisper options, at the cost of slower transcription
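
To see what these settings mean in optimizer steps, the arithmetic from the training cell can be reproduced on its own. This is a sketch under assumptions: all 225 files pass the transcription filter (so 203 remain for training after the 10% split), and `get_cosine_schedule_with_warmup` from `train_utils` follows the common linear-warmup, cosine-decay formulation, which this notebook does not confirm. The helper names `schedule_summary` and `lr_multiplier` are hypothetical.

```python
import math


def schedule_summary(num_train_samples, batch_size=1, grad_accum=8,
                     num_epochs=10, warmup_steps=100):
    """Mirror the notebook's arithmetic for optimizer steps and warmup steps."""
    batches_per_epoch = math.ceil(num_train_samples / batch_size)
    total_steps = batches_per_epoch * num_epochs // grad_accum
    warmup = min(warmup_steps, total_steps // 10)
    return total_steps, warmup


def lr_multiplier(step, warmup, total_steps):
    """Linear warmup, then cosine decay to zero (assumed formulation)."""
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


# Assumption: 203 training samples (225 files minus a 22-file validation split)
total_steps, warmup = schedule_summary(num_train_samples=203)
print(f"optimizer steps: {total_steps}, warmup steps: {warmup}")
for step in (0, warmup, total_steps // 2, total_steps):
    lr = 1e-4 * lr_multiplier(step, warmup, total_steps)
    print(f"step {step:4d}: lr = {lr:.2e}")
```

Under these assumptions the run has 253 optimizer steps and the warmup is capped at 25 steps, so `WARMUP_STEPS = 100` never takes full effect unless the dataset or epoch count grows.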

### Training Expectations
With 225 rap acapellas:
- **Training time**: ~2-4 hours on T4, ~1-2 hours on A100
- **Expected final loss**: ~0.05-0.15 (lower is better)
- **Checkpoint size**: ~80-100MB
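
Checkpoint size follows from the number of LoRA parameters: each adapted linear layer of shape `d_in x d_out` gains two low-rank factors with `rank * (d_in + d_out)` weights in total. The sketch below turns that into a size estimate. The block count and layer widths are placeholders chosen for illustration, not Echo-TTS's real dimensions, and the helper names are hypothetical; substitute the shapes of your loaded model to get a meaningful number.

```python
def lora_param_count(layer_shapes, rank):
    """Total LoRA weights: rank * (d_in + d_out) per adapted linear layer."""
    return sum(rank * (d_in + d_out) for d_in, d_out in layer_shapes)


def checkpoint_megabytes(num_params, bytes_per_param=4):
    """Approximate file size, ignoring metadata (4 bytes = float32, 2 = bfloat16)."""
    return num_params * bytes_per_param / 1e6


# PLACEHOLDER dimensions, not taken from Echo-TTS: adjust to your model.
num_blocks, hidden, mlp_hidden = 12, 1024, 4096
per_block = (
    [(hidden, hidden)] * 6  # wq, wk, wv, wo, wk_text, wv_text
    + [(hidden, mlp_hidden), (mlp_hidden, hidden), (hidden, mlp_hidden)]  # w1, w2, w3
)
shapes = per_block * num_blocks

for rank in (16, 32):
    n = lora_param_count(shapes, rank)
    print(f"rank {rank}: {n / 1e6:.1f}M params, ~{checkpoint_megabytes(n):.0f} MB in float32")
```

The count scales linearly with rank, so halving `LORA_RANK` halves both the trainable parameters and the checkpoint size.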

### If Results Sound Off
1. **Too monotone**: Increase `cfg_scale_text` to 4-5 during inference
2. **Wrong rhythm**: The model learned general rap style, not specific flows
3. **Voice doesn't match**: Try different speaker reference audio
4. **Gibberish output**: Check if transcriptions were accurate

### Common Issues

**Out of Memory (OOM)**:
- Reduce `MAX_LATENT_LENGTH` to 320 (15 seconds max)
- Use A100 GPU on Colab instead of T4
- Reduce `LORA_RANK` to 16
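
When shrinking `MAX_LATENT_LENGTH`, it helps to know how many latent frames a segment needs. The sketch below derives the conversion purely from this notebook's own comments (640 frames for about 30 seconds); the true frame rate of the autoencoder is not stated here, and the helper names are hypothetical.

```python
import math

# Reference point taken from the configuration comments: 640 frames ~ 30 seconds.
REF_FRAMES = 640
REF_SECONDS = 30.0


def latent_frames_for(seconds):
    """Approximate number of latent frames needed for `seconds` of audio."""
    return math.ceil(seconds * REF_FRAMES / REF_SECONDS)


def seconds_for(frames):
    """Approximate audio duration covered by `frames` latent frames."""
    return frames * REF_SECONDS / REF_FRAMES


print(latent_frames_for(25.0))  # frames needed for one SEGMENT_DURATION chunk
print(seconds_for(320))         # duration that fits in the reduced setting
```

By this estimate a 25-second segment needs about 534 frames, so after lowering `MAX_LATENT_LENGTH` to 320 it would be worth lowering `SEGMENT_DURATION` to 15 seconds or less as well.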

**Loss not decreasing**:
- Check transcriptions are accurate (rap lyrics are hard!)
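- `WHISPER_MODEL` already defaults to `"large-v3"`; if you switched to a smaller model for speed, switch back and re-transcribe (delete `transcriptions.json` in `OUTPUT_DIR` first, otherwise the cached text is reused)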

**Validation loss increasing (overfitting)**:
- Increase `LORA_DROPOUT` to 0.1
- Reduce `NUM_EPOCHS`
- Reduce `LORA_RANK` to 16
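
A simple way to act on rising validation loss is patience-based early stopping. The following sketch is not part of the notebook's training loop; `best_epoch_and_stop` and the `patience` parameter are hypothetical, and the loss values are made up for illustration.

```python
import math


def best_epoch_and_stop(val_losses, patience=2):
    """Return (best_epoch, should_stop) for per-epoch validation losses.

    best_epoch is 1-based. NaN and inf entries are ignored. should_stop is
    True once `patience` epochs have passed without a new best loss.
    """
    best_epoch, best_loss = None, math.inf
    for epoch, loss in enumerate(val_losses, start=1):
        if math.isfinite(loss) and loss < best_loss:
            best_epoch, best_loss = epoch, loss
    if best_epoch is None:
        return None, False
    should_stop = len(val_losses) - best_epoch >= patience
    return best_epoch, should_stop


example_val_losses = [0.21, 0.17, 0.15, 0.16, 0.18]  # made-up numbers
print(best_epoch_and_stop(example_val_losses))
```

Called with `history["val_loss"]` at the end of each epoch, a true `should_stop` could be used to `break` out of the loop; `lora_best.pt` already holds the weights from the best epoch.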

### Voice Cloning Still Works!
The speaker path (wk_speaker, wv_speaker) was kept frozen, so you can:
- Use ANY speaker reference audio at inference time
- The rap style transfers to any voice
- Original voice cloning quality is preserved
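
One way to sanity-check that the speaker path stays frozen is to test module names against `TARGET_MODULES`. The sketch assumes the patterns are shell-style globs matched against full module names; how `lora.py` actually matches them is not shown in this notebook, and the module names below are hypothetical examples.

```python
from fnmatch import fnmatchcase

TARGET_MODULES = [
    "blocks.*.attention.wq",
    "blocks.*.attention.wk",
    "blocks.*.attention.wv",
    "blocks.*.attention.wo",
    "blocks.*.attention.wk_text",
    "blocks.*.attention.wv_text",
    "blocks.*.mlp.w1",
    "blocks.*.mlp.w2",
    "blocks.*.mlp.w3",
]


def is_lora_target(module_name, patterns=TARGET_MODULES):
    """True if the full module name matches any glob pattern (assumed matching rule)."""
    return any(fnmatchcase(module_name, pattern) for pattern in patterns)


# Hypothetical module names, for illustration only
for name in [
    "blocks.3.attention.wk",
    "blocks.3.attention.wk_text",
    "blocks.3.attention.wk_speaker",
    "blocks.3.attention.wv_speaker",
    "blocks.3.mlp.w2",
]:
    print(f"{name:32s} -> {'LoRA' if is_lora_target(name) else 'frozen'}")
```

If the real matcher used prefix or substring matching instead, a pattern ending in `wk` could also catch `wk_speaker`, so it is worth inspecting the modules returned by `apply_lora_to_model` after applying LoRA.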