In [None]:
# Launch the main Gradio app with LoRA support and public sharing
# This runs the gradio_app.py file which has the full UI and all features

print("\n" + "="*60)
print("üöÄ Launching Echo-TTS Gradio Interface...")
print("="*60)
print("\nThe Gradio app will:")
print("  - Load the base Echo-TTS model")
print("  - Allow loading your LoRA checkpoint via path input")
print("  - Generate a public shareable URL")
print("\nTo use your fine-tuned LoRA:")
print(f"  1. Enter this path in the LoRA section: {OUTPUT_DIR}lora_best.pt")
print("  2. Click Generate to test with rap style!\n")

# Set environment variable to enable public sharing
import os
os.environ["GRADIO_SHARE"] = "true"

# Run the gradio app
!python gradio_app.py

# Echo-TTS Controllable Rap Style Fine-Tuning

This notebook implements LoRA-based fine-tuning for Echo-TTS with **controllable rhythm/timing** - enabling you to separate voice identity from delivery style.

**New Feature: Controllable Rhythm/Timing**
- **Speaker Reference**: Controls WHO it sounds like (voice identity)
- **Content Reference**: Controls HOW it's delivered (rhythm, pacing, flow)
- Train the model to follow timing patterns from any audio!

**Dataset:**
- 173+ rap acapellas with transcriptions
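- Transcriptions: generated automatically for any audio file that lacks one and cached in `transcriptions.json` under `OUTPUT_DIR` (see Section 4)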

**Requirements:**
- GPU with 16GB+ VRAM (T4/A100 on Colab)

**What you'll get:**
- Fine-tuned model with controllable rhythm transfer
- LoRA checkpoint (~50-100MB) loadable on base model
- Preserved voice cloning - use any speaker reference

## 1. Setup & Dependencies

In [1]:
# Clone the echo-tts repository
!rm -rf /content/echo-tts
!git clone https://github.com/CoreBedtime/echo-tts.git 2>/dev/null || echo "Repo already exists"

# Add to Python path
import sys
sys.path.insert(0, '/content/echo-tts')

# Change working directory
import os
os.chdir('/content/echo-tts')

# Install dependencies
!pip install -q torch torchaudio safetensors huggingface-hub einops
!pip install -q transformers accelerate  # For Parakeet transcription (5-10x faster!)
!pip install -q torchcodec  # For audio decoding

print(f"Working directory: {os.getcwd()}")

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m2.1/2.1 MB[0m [31m95.1 MB/s[0m eta [36m0:00:00[0m
[?25hWorking directory: /content/echo-tts


In [None]:
# Mount Google Drive for saving checkpoints (Colab)
# IN_COLAB ends up True when the google.colab import and the mount both succeed,
# and False when either of them raises ImportError.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

print("Running in Google Colab" if IN_COLAB else "Running locally")

In [2]:
import os
import sys
from pathlib import Path

import torch
import torchaudio
from IPython.display import Audio, display

# Check GPU
# Print the PyTorch build and CUDA status; device name and memory are reported
# only when CUDA is available.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM: {vram_gb:.1f} GB")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: NVIDIA A100-SXM4-40GB
VRAM: 42.5 GB


## 2. Configuration

In [None]:
import torch

# =============================================================================
# PATHS - Update these!
# =============================================================================

# Training audio lives here (put your rap acapellas in this folder)
AUDIO_DIR = "/content/dataset/"

# Checkpoints are written here
OUTPUT_DIR = "./checkpoints/"

# =============================================================================
# LoRA CONFIGURATION (optimized for controllable rhythm training)
# =============================================================================

LORA_RANK = 32          # Adapter rank; set high for the larger dataset (more expressiveness)
LORA_ALPHA = 32.0       # Scaling factor, typically equal to rank
LORA_DROPOUT = 0.05     # Kept low: more data means less regularization needed

# Projection names adapted inside every block, grouped by role.
# The speaker projections (wk_speaker, wv_speaker) are deliberately absent so
# they stay untrained and voice cloning is preserved.
_ATTENTION_PROJECTIONS = (
    # Main decoder attention (style)
    "wq", "wk", "wv", "wo",
    # Text cross-attention (text-to-audio mapping)
    "wk_text", "wv_text",
    # Latent cross-attention (CONTROLLABLE RHYTHM/TIMING!)
    "wk_latent", "wv_latent",
)
# MLP layers (feature transformation)
_MLP_PROJECTIONS = ("w1", "w2", "w3")

# Which modules to train, as "blocks.*.<group>.<projection>" patterns
TARGET_MODULES = (
    [f"blocks.*.attention.{name}" for name in _ATTENTION_PROJECTIONS]
    + [f"blocks.*.mlp.{name}" for name in _MLP_PROJECTIONS]
)

# =============================================================================
# CONTROLLABLE RHYTHM CONFIGURATION
# =============================================================================

# Enable rhythm/timing training: the model learns to follow timing from the
# content latent (same text + different rhythm reference = different delivery).
USE_CONTENT_CONDITIONING = True

# =============================================================================
# TRAINING CONFIGURATION (optimized for 173+ samples)
# =============================================================================

LEARNING_RATE = 1e-4        # Slightly higher LR for larger dataset
NUM_EPOCHS = 10             # Fewer epochs needed with more data
BATCH_SIZE = 1              # Per-step batch size (1 for memory efficiency)
GRADIENT_ACCUMULATION = 8   # Steps accumulated per optimizer update (larger effective batch)
MAX_GRAD_NORM = 1.0         # Gradient clipping threshold
WARMUP_STEPS = 100          # Upper bound on LR warmup steps

# Audio settings
MAX_LATENT_LENGTH = 640     # Max ~30 seconds (reduce to 320 if OOM)
SEGMENT_DURATION = 25.0     # Split long audio into ~25 second chunks
MIN_SEGMENT_DURATION = 5.0  # Minimum segment length to keep

# Transcription model - for new files without transcriptions
# NOTE(review): this is a distil-whisper model id, yet the transcription cell
# hands it to transcribe_audio_files_parakeet - confirm the helper accepts it.
TRANSCRIPTION_MODEL = "distil-whisper/distil-large-v3"
TRANSCRIPTION_BATCH_SIZE = 8

# Device: CUDA when available, otherwise CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DTYPE = torch.bfloat16      # Use bfloat16 for training

# Validation split
VAL_SPLIT = 0.1             # 10% for validation

# Create output directory
# Both working folders are created up front (no error if they already exist).
for folder in (OUTPUT_DIR, AUDIO_DIR):
    os.makedirs(folder, exist_ok=True)

# Echo the effective configuration so the run is self-describing in the output.
rhythm_state = 'ENABLED' if USE_CONTENT_CONDITIONING else 'DISABLED'
summary_lines = [
    f"Audio directory: {AUDIO_DIR}",
    f"Output directory: {OUTPUT_DIR}",
    f"Device: {DEVICE}",
    f"\nControllable Rhythm Training: {rhythm_state}",
    "\nTraining config:",
    f"  LoRA rank: {LORA_RANK}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}",
    f"\nLoRA targets: {len(TARGET_MODULES)} module patterns",
    "  (includes wk_latent, wv_latent for rhythm control)",
]
print("\n".join(summary_lines))

## 3. Load Base Models

In [None]:
from inference import (
    load_model_from_hf,
    load_fish_ae_from_hf,
    load_pca_state_from_hf,
    ae_decode,
    get_speaker_latent_and_mask,
    get_text_input_ids_and_mask,
    get_content_latent,
    sample_euler_cfg_independent_guidances,
    sample_pipeline,
    load_audio,
)
from functools import partial

# Load the EchoDiT model first; blockwise modules are kept so the latent encoder
# stays available for rhythm control.
print("Loading EchoDiT model...")
print("NOTE: Keeping latent_encoder for controllable rhythm/timing!")
echo_dit_options = {
    "device": DEVICE,
    "dtype": DTYPE,
    "compile": False,                   # Don't compile for training
    "delete_blockwise_modules": False,  # KEEP latent_encoder for rhythm control!
}
model = load_model_from_hf(**echo_dit_options)
parameter_total = sum(tensor.numel() for tensor in model.parameters())
print(f"Model loaded: {parameter_total / 1e6:.1f}M parameters")

# The autoencoder is loaded in float32 rather than DTYPE (AE needs float32 for quality)
print("\nLoading Fish-S1-DAC autoencoder...")
fish_ae = load_fish_ae_from_hf(device=DEVICE, dtype=torch.float32)
print("Autoencoder loaded")

print("\nLoading PCA state...")
pca_state = load_pca_state_from_hf(device=DEVICE)
print("PCA state loaded")

banner = "=" * 50
print("\n" + banner)
print("All models loaded successfully!")
print("Latent encoder available for controllable rhythm!")
print(banner)

## 4. Data Preparation

With a dataset of this size (173+ rap acapellas), we need to:
1. List and validate all audio files
2. Transcribe FULL SONGS using **Parakeet** (5-10x faster than Whisper!)
   - Long songs are chunked into 25-second segments
   - Each segment gets its own precise transcription
   - This allows training on the ENTIRE song, not just intros!
3. Create train/validation split (from all segments)
4. Pre-encode all audio segments to latents (cached for fast training)

In [5]:
import os
from pathlib import Path

# List all audio files
AUDIO_EXTENSIONS = (".mp3", ".wav", ".flac", ".m4a", ".ogg")

# Walk AUDIO_DIR once and keep regular files whose suffix matches one of the
# extensions case-insensitively, so "TRACK.MP3" or "verse.WAV" are not silently
# skipped. sorted() gives a stable listing: directory traversal order depends on
# the filesystem, and later cells iterate over / sample from this list.
audio_files = sorted(
    path
    for path in Path(AUDIO_DIR).rglob("*")
    if path.is_file() and path.suffix.lower() in AUDIO_EXTENSIONS
)

print(f"Found {len(audio_files)} audio files in {AUDIO_DIR}")

if not audio_files:
    print("\n⚠️  No audio files found!")
    print(f"Please add your rap acapella files to: {AUDIO_DIR}")
else:
    print("\nSample files:")
    for audio_path in audio_files[:5]:
        print(f"  {audio_path.name}")

Found 0 audio files in /content/dataset/

‚ö†Ô∏è  No audio files found!
Please add your rap acapella files to: /content/dataset/


In [None]:
import os
from pathlib import Path

# Transcribe all audio files using NVIDIA Parakeet (5-10x FASTER than Whisper!)
# NOW WITH FULL SONG SUPPORT - chunks long audio and transcribes each segment!

from train_utils import transcribe_audio_files_parakeet
import json

TRANSCRIPTION_CACHE = os.path.join(OUTPUT_DIR, "transcriptions.json")
SEGMENT_MARKER = "#segment_"  # cache keys look like "<audio path>#segment_<n>"
CHUNK_OVERLAP = 0.1           # fraction shared by consecutive chunks (10%)


def _base_audio_paths(keys):
    """Return the set of source-file paths behind the given transcription keys.

    A key is either a plain audio path or "<audio path>#segment_<n>"; the part
    before the first marker identifies the source file.
    """
    return {key.split(SEGMENT_MARKER)[0] for key in keys}


# IMPORTANT: Delete old transcriptions.json to re-transcribe with full song support!
# The old version only transcribed the first 30 seconds (intros)
# The new version chunks the FULL song and transcribes each segment

# Check if we have cached transcriptions
if os.path.exists(TRANSCRIPTION_CACHE):
    print(f"Loading cached transcriptions from {TRANSCRIPTION_CACHE}")
    with open(TRANSCRIPTION_CACHE, "r", encoding="utf-8") as cache_file:
        transcriptions = json.load(cache_file)
    print(f"Loaded {len(transcriptions)} cached transcriptions (segments)")
    print(f"NOTE: If these are old transcriptions (only intros), delete the file and re-run this cell!")

    # Find base files that need transcription (not counting segments)
    base_files_transcribed = _base_audio_paths(transcriptions)
    files_to_transcribe = [
        audio_path for audio_path in audio_files
        if str(audio_path) not in base_files_transcribed
    ]
    print(f"Base files still needing transcription: {len(files_to_transcribe)}")
else:
    transcriptions = {}
    files_to_transcribe = audio_files

if len(files_to_transcribe) > 0:
    print(f"\n{'='*60}")
    print(f"FULL SONG TRANSCRIPTION WITH PARAKEET")
    print(f"{'='*60}")
    print(f"Files to transcribe: {len(files_to_transcribe)}")
    print(f"\nEach song will be:")
    # Report the values actually passed below instead of hard-coded numbers.
    print(f"  - Chunked into {SEGMENT_DURATION:g}-second segments ({CHUNK_OVERLAP:.0%} overlap)")
    print(f"  - Each segment transcribed separately")
    print(f"  - Result: FULL song coverage, not just intros!")
    print(f"\nEstimated time:")
    print(f"  Whisper large-v3: ~30-45 minutes")
    print(f"  Parakeet: ~10-15 minutes")
    print(f"  Speedup: 3-5x faster!\n")

    # Use Parakeet with chunking; the chunk length comes from the config cell
    # (SEGMENT_DURATION) so the two cannot drift apart.
    # NOTE(review): TRANSCRIPTION_MODEL is configured as a distil-whisper id while
    # this helper is named for Parakeet - confirm the pairing is intended.
    batch_transcriptions = transcribe_audio_files_parakeet(
        audio_paths=[str(audio_path) for audio_path in files_to_transcribe],
        model_name=TRANSCRIPTION_MODEL,
        batch_size=TRANSCRIPTION_BATCH_SIZE,
        chunk_duration=SEGMENT_DURATION,
        overlap=CHUNK_OVERLAP,
    )

    transcriptions.update(batch_transcriptions)

    # Save progress
    with open(TRANSCRIPTION_CACHE, "w", encoding="utf-8") as cache_file:
        json.dump(transcriptions, cache_file, indent=2)

    print(f"\nSaved to: {TRANSCRIPTION_CACHE}")

print(f"\n{'='*60}")
print(f"Transcription complete: {len(transcriptions)} segments")
print(f"{'='*60}")

# Count base files
base_files = _base_audio_paths(transcriptions)
print(f"Total files transcribed: {len(base_files)}")
print(f"Total segments created: {len(transcriptions)}")
# max(..., 1) guards the division when nothing has been transcribed yet
print(f"Average segments per file: {len(transcriptions) / max(len(base_files), 1):.1f}")

# Show a few samples
print(f"\nSample transcriptions (first 3 segments):")
for path, text in list(transcriptions.items())[:3]:
    if SEGMENT_MARKER in path:
        key_parts = path.split(SEGMENT_MARKER)
        base_name = Path(key_parts[0]).name
        segment_num = key_parts[1]
        print(f"\n{base_name} [segment {segment_num}]:")
    else:
        print(f"\n{Path(path).name}:")
    print(f"  {text[:150]}{'...' if len(text) > 150 else ''}")

In [None]:
# Optional: Edit transcriptions manually if needed
# Uncomment and modify as needed:

# transcriptions["/path/to/file.mp3"] = "[S1] Your corrected transcription here."

# Tips for transcriptions:
# - Start with [S1] for single speaker
# - Use commas for pauses
# - Exclamation marks increase expressiveness
# - Keep punctuation natural

In [None]:
# Create training dataset with train/val split
from train_utils import (
    TrainingSample,
    EchoTTSDataset,
    collate_fn,
)
from torch.utils.data import DataLoader
import random

# Create training samples from transcriptions
MIN_TRANSCRIPT_CHARS = 10  # transcripts not longer than this (after stripping) are dropped

all_samples = [
    TrainingSample(
        audio_path=path,
        text=text,
        speaker_audio_path=None,  # Use same audio as speaker reference
    )
    for path, text in transcriptions.items()
    if text and len(text.strip()) > MIN_TRANSCRIPT_CHARS
]

print(f"Created {len(all_samples)} training samples (filtered {len(transcriptions) - len(all_samples)} empty)")

if len(all_samples) == 0:
    print("\n⚠️  ERROR: No valid training samples!")
    print("Please check that your audio files were transcribed correctly.")
elif len(all_samples) == 1:
    # A single sample cannot be split: it would all go to validation and leave
    # the training set empty.
    print("\n⚠️  ERROR: Only 1 valid training sample - need at least 2 for a train/validation split!")
    print("Please check that your audio files were transcribed correctly.")
else:
    # Shuffle and split into train/val
    # NOTE(review): the split is per segment. Segments cut from the same song
    # (with overlap, per the transcription cell) can land on both sides, so the
    # validation loss may be optimistic - consider splitting by source file.
    random.seed(42)
    random.shuffle(all_samples)

    # At least one validation sample, but never the whole set: clamping to
    # len - 1 keeps the training split non-empty even for large VAL_SPLIT values.
    val_size = min(max(1, int(len(all_samples) * VAL_SPLIT)), len(all_samples) - 1)
    train_samples = all_samples[val_size:]
    val_samples = all_samples[:val_size]

    print(f"Train samples: {len(train_samples)}")
    print(f"Validation samples: {len(val_samples)}")

    # Settings shared by the train and validation datasets / dataloaders
    dataset_kwargs = dict(
        fish_ae=fish_ae,
        pca_state=pca_state,
        device=DEVICE,
        max_latent_length=MAX_LATENT_LENGTH,
        cache_latents=True,
    )
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=0,
    )

    # Create datasets
    print("\nCreating training dataset and encoding audio to latents...")
    print("This will take a few minutes...\n")
    train_dataset = EchoTTSDataset(samples=train_samples, **dataset_kwargs)

    print("\nCreating validation dataset...")
    val_dataset = EchoTTSDataset(samples=val_samples, **dataset_kwargs)

    # Create dataloaders (only the training loader is shuffled)
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print(f"\n{'='*50}")
    print(f"Dataset ready!")
    print(f"  Train: {len(train_dataset)} samples, {len(train_dataloader)} batches/epoch")
    print(f"  Val: {len(val_dataset)} samples, {len(val_dataloader)} batches")
    print(f"{'='*50}")

## 5. Apply LoRA to Model

In [None]:
from lora import (
    apply_lora_to_model,
    count_parameters,
    get_lora_params,
    save_lora_checkpoint,
    load_lora_checkpoint,
)

# Apply LoRA adapters to the model
# NOTE(review): this cell rebinds `model`; re-running it without reloading the
# base model presumably wraps already-adapted layers again - confirm against
# apply_lora_to_model before re-executing.
lora_settings = dict(
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
    dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES,
)

print("Applying LoRA adapters...")
print(f"  Rank: {LORA_RANK}")
print(f"  Alpha: {LORA_ALPHA}")
print(f"  Dropout: {LORA_DROPOUT}")
print(f"  Target modules: {len(TARGET_MODULES)} patterns")

model, lora_modules = apply_lora_to_model(model, **lora_settings)

# Count parameters and report the trainable share
total_params, trainable_params = count_parameters(model)
print("\nParameter counts:")
print(f"  Total: {total_params / 1e6:.1f}M")
trainable_pct = 100 * trainable_params / total_params
print(f"  Trainable (LoRA): {trainable_params / 1e6:.2f}M ({trainable_pct:.2f}%)")
print(f"  LoRA modules applied: {len(lora_modules)}")

## 6. Training

In [None]:
from train_utils import train_epoch, get_cosine_schedule_with_warmup, training_step

# Setup optimizer (only LoRA params)
lora_params = get_lora_params(model)
optimizer = torch.optim.AdamW(
    lora_params,
    lr=LEARNING_RATE,
    weight_decay=0.01,
    betas=(0.9, 0.999),
)

# Learning rate scheduler
# Total optimizer updates across all epochs. Clamped to >= 1 so a very small
# dataset (fewer batches than GRADIENT_ACCUMULATION) does not hand the scheduler
# a zero-length schedule.
# NOTE(review): assumes train_epoch steps the scheduler once per optimizer
# update - TODO confirm in train_utils.
num_training_steps = max(1, len(train_dataloader) * NUM_EPOCHS // GRADIENT_ACCUMULATION)
# Warmup is capped at 10% of the schedule; computed once and reused in the summary
num_warmup_steps = min(WARMUP_STEPS, num_training_steps // 10)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# Mixed precision scaler
# torch.amp.GradScaler("cuda") replaces the deprecated torch.cuda.amp.GradScaler()
# and matches the torch.amp.autocast API used by the validation function.
scaler = torch.amp.GradScaler("cuda")

print("Training setup:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Gradient accumulation: {GRADIENT_ACCUMULATION}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Training steps: {num_training_steps}")
print(f"  Warmup steps: {num_warmup_steps}")
print(f"  Train batches/epoch: {len(train_dataloader)}")
print(f"  Val batches: {len(val_dataloader)}")
print(f"\n  Controllable Rhythm: {'ENABLED' if USE_CONTENT_CONDITIONING else 'DISABLED'}")

# Validation function
@torch.no_grad()
def validate(model, val_dataloader, device, use_content_conditioning=True):
    """Return the mean finite loss of `model` over `val_dataloader`.

    Each batch is scored with train_utils.training_step under bfloat16 autocast
    and with gradients disabled. Batches whose loss is NaN or infinite are
    skipped so one bad batch cannot poison the average.

    Args:
        model: Module to evaluate. It is switched to eval mode for the duration
            of the call and returned to whatever mode it was in beforehand.
        val_dataloader: Iterable of batches accepted by training_step.
        device: Device passed through to training_step; its type ("cuda",
            "cpu", ...) also selects the autocast backend.
        use_content_conditioning: Forwarded unchanged to training_step.

    Returns:
        The average loss as a float, or NaN when no batch produced a finite loss
        (including an empty dataloader).
    """
    was_training = model.training
    model.eval()
    # Autocast follows the requested device instead of assuming CUDA
    autocast_device = torch.device(device).type

    total_loss = 0.0
    num_batches = 0
    try:
        for batch in val_dataloader:
            with torch.amp.autocast(autocast_device, dtype=torch.bfloat16):
                loss = training_step(model, batch, device, use_content_conditioning=use_content_conditioning)

            # Skip NaN/inf losses in validation too
            if torch.isfinite(loss):
                total_loss += loss.item()
                num_batches += 1
    finally:
        # Restore the caller's mode even if a batch raised; an eval-mode model
        # must not silently come back in train mode (dropout re-enabled).
        model.train(was_training)

    if num_batches == 0:
        return float('nan')

    return total_loss / num_batches

In [None]:
# Training loop with validation and controllable rhythm
def _checkpoint_config(epoch, train_loss, val_loss):
    """Return the metadata dict stored alongside the LoRA weights in a checkpoint."""
    return {
        "rank": LORA_RANK,
        "alpha": LORA_ALPHA,
        "dropout": LORA_DROPOUT,
        "target_modules": TARGET_MODULES,
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "use_content_conditioning": USE_CONTENT_CONDITIONING,
    }


divider = "=" * 50
print("\n" + divider)
print("Starting training with CONTROLLABLE RHYTHM...")
print(divider + "\n")

history = {"train_loss": [], "val_loss": [], "epoch": [], "lr": []}
best_val_loss = float("inf")

for epoch_number in range(1, NUM_EPOCHS + 1):
    # One pass over the training data, with content conditioning for rhythm learning
    train_loss = train_epoch(
        model=model,
        dataloader=train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=DEVICE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        max_grad_norm=MAX_GRAD_NORM,
        scaler=scaler,
        use_content_conditioning=USE_CONTENT_CONDITIONING,  # Controllable rhythm!
    )

    # Score the held-out split and read the learning rate the scheduler ended on
    val_loss = validate(model, val_dataloader, DEVICE, use_content_conditioning=USE_CONTENT_CONDITIONING)
    current_lr = scheduler.get_last_lr()[0]

    # Record history
    epoch_record = {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "epoch": epoch_number,
        "lr": current_lr,
    }
    for metric_name, metric_value in epoch_record.items():
        history[metric_name].append(metric_value)

    # Print progress
    print(f"Epoch {epoch_number}/{NUM_EPOCHS} - Train: {train_loss:.4f} - Val: {val_loss:.4f} - LR: {current_lr:.2e}")

    # Save best checkpoint (based on validation loss); a NaN val_loss compares
    # False here, so such an epoch never replaces the best checkpoint.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_lora_checkpoint(
            model,
            os.path.join(OUTPUT_DIR, "lora_best.pt"),
            config=_checkpoint_config(epoch_number, train_loss, val_loss),
        )
        print(f"  -> Saved best checkpoint (val_loss: {best_val_loss:.4f})")

    # Periodic checkpoint every 2 epochs
    if epoch_number % 2 == 0:
        save_lora_checkpoint(
            model,
            os.path.join(OUTPUT_DIR, f"lora_epoch_{epoch_number}.pt"),
            config=_checkpoint_config(epoch_number, train_loss, val_loss),
        )

# Save final checkpoint, tagged with the last recorded losses
save_lora_checkpoint(
    model,
    os.path.join(OUTPUT_DIR, "lora_final.pt"),
    config=_checkpoint_config(NUM_EPOCHS, history["train_loss"][-1], history["val_loss"][-1]),
)

print("\n" + divider)
print("Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Checkpoints saved to: {OUTPUT_DIR}")
print(f"Controllable rhythm: {'ENABLED' if USE_CONTENT_CONDITIONING else 'DISABLED'}")
print(divider)

In [None]:
# Plot training curves
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
loss_ax, train_ax, lr_ax = axes
epochs = history["epoch"]

# Left panel: training and validation loss on shared axes
loss_ax.plot(epochs, history["train_loss"], 'b-', linewidth=2, label='Train')
loss_ax.plot(epochs, history["val_loss"], 'r--', linewidth=2, label='Val')
loss_ax.set(xlabel="Epoch", ylabel="Loss", title="Training vs Validation Loss")
loss_ax.legend()

# Middle panel: training loss only (zoomed)
train_ax.plot(epochs, history["train_loss"], 'b-', linewidth=2)
train_ax.set(xlabel="Epoch", ylabel="Loss", title="Training Loss")

# Right panel: learning rate recorded at the end of each epoch
lr_ax.plot(epochs, history["lr"], 'g-', linewidth=2)
lr_ax.set(xlabel="Epoch", ylabel="Learning Rate", title="Learning Rate Schedule")

# Same light grid on every panel
for panel in axes:
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "training_curves.png"), dpi=150)
plt.show()

# Print summary
final_train_loss = history['train_loss'][-1]
final_val_loss = history['val_loss'][-1]
print(f"\nTraining Summary:")
print(f"  Final train loss: {final_train_loss:.4f}")
print(f"  Final val loss: {final_val_loss:.4f}")
print(f"  Best val loss: {best_val_loss:.4f}")

## 7. Evaluation & Inference

Generate samples with your fine-tuned model and compare to the base model.

In [None]:
# Set model to eval mode
model.eval()

# Test prompts - rap style! Every prompt carries the "[S1]" speaker tag.
_PROMPT_LINES = (
    "Yeah, I'm spitting fire on the mic tonight, gonna show them what I got, rising to the top!",
    "The rhythm flows through me like water, every beat hits harder, I'm a natural born starter!",
    "Check it out, I'm the one they've been waiting for, coming through the door, ready to explore!",
    "Money on my mind, grind never stops, from the bottom to the top, watch me drop!",
    "Real recognize real, that's the deal, keep it trill, got the skill to make you feel!",
)
TEST_PROMPTS = [f"[S1] {line}" for line in _PROMPT_LINES]

# Use a random training file as speaker reference (None when no audio was found)
import random
SPEAKER_AUDIO_PATH = None
if audio_files:
    SPEAKER_AUDIO_PATH = random.choice(audio_files)

print(f"Using speaker reference: {SPEAKER_AUDIO_PATH}")
print(f"\nWill generate {len(TEST_PROMPTS)} samples...")

In [None]:
# Generate samples with fine-tuned model - NOW WITH CONTROLLABLE RHYTHM!
@torch.inference_mode()
def generate_audio(
    model,
    text,
    speaker_audio_path=None,
    content_audio_path=None,
    seed=0,
    *,
    num_steps=40,
    cfg_scale_text=3.0,
    cfg_scale_speaker=8.0,
    sequence_length=640,
):
    """
    Generate audio using the (fine-tuned) model.

    Args:
        model: Fine-tuned EchoDiT model
        text: Text to speak
        speaker_audio_path: Audio for voice identity (WHO it sounds like);
            skipped when empty/None
        content_audio_path: Audio for rhythm/timing (HOW it's delivered);
            skipped when empty/None
        seed: Random seed for reproducibility
        num_steps: Number of sampler steps (keyword-only)
        cfg_scale_text: Text guidance scale (keyword-only); this is the knob the
            tips section suggests raising when output sounds monotone
        cfg_scale_speaker: Speaker guidance scale (keyword-only)
        sequence_length: Latent sequence length handed to the sampler (keyword-only)

    Returns:
        (audio, normalized_text): the first item of the pipeline's audio output,
        moved to CPU, and the text as normalized by sample_pipeline.
    """
    # Load the optional reference clips: speaker = voice identity, content = rhythm/timing
    speaker_audio = load_audio(speaker_audio_path) if speaker_audio_path else None
    content_audio = load_audio(content_audio_path) if content_audio_path else None

    # Bind the sampler settings; the tunable ones come from the keyword
    # arguments, the remaining ones keep their fixed values.
    sample_fn = partial(
        sample_euler_cfg_independent_guidances,
        num_steps=num_steps,
        cfg_scale_text=cfg_scale_text,
        cfg_scale_speaker=cfg_scale_speaker,
        cfg_min_t=0.5,
        cfg_max_t=1.0,
        truncation_factor=0.8,
        rescale_k=None,
        rescale_sigma=None,
        speaker_kv_scale=None,
        speaker_kv_max_layers=None,
        speaker_kv_min_t=None,
        sequence_length=sequence_length,
    )

    # Generate with both speaker and content references
    audio_out, normalized_text = sample_pipeline(
        model=model,
        fish_ae=fish_ae,
        pca_state=pca_state,
        sample_fn=sample_fn,
        text_prompt=text,
        speaker_audio=speaker_audio,
        rng_seed=seed,
        content_audio=content_audio,  # Controllable rhythm!
    )

    return audio_out[0].cpu(), normalized_text

# Generate and play samples
print("Generating samples with fine-tuned model...\n")
print("Using CONTROLLABLE RHYTHM - speaker for voice, content for timing!\n")

# The same reference clip is used for both speaker and content (for testing).
# In practice, a DIFFERENT clip can be passed as content to transfer rhythm!
reference_path = str(SPEAKER_AUDIO_PATH) if SPEAKER_AUDIO_PATH else None
SAMPLE_RATE_HZ = 44100  # rate used both for the saved WAV and for notebook playback

for sample_number, prompt in enumerate(TEST_PROMPTS, start=1):
    print(f"Prompt {sample_number}: {prompt}")

    # Seeds follow the zero-based prompt index
    audio, _ = generate_audio(
        model,
        prompt,
        speaker_audio_path=reference_path,
        content_audio_path=reference_path,  # Can be different!
        seed=sample_number - 1,
    )

    # Save audio (a leading dimension is added before writing)
    output_path = os.path.join(OUTPUT_DIR, f"sample_{sample_number}.wav")
    torchaudio.save(output_path, audio.unsqueeze(0), SAMPLE_RATE_HZ)
    print(f"Saved to: {output_path}")

    # Play audio
    display(Audio(audio.numpy(), rate=SAMPLE_RATE_HZ))
    print()

## 8. Load Checkpoint for Later Use

Use this section to load a saved LoRA checkpoint onto a fresh model.

In [None]:
# Example: Load a saved LoRA checkpoint
# Uncomment and run this cell to load a checkpoint

# CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "lora_best.pt")
# 
# # Load fresh base model
# model_fresh = load_model_from_hf(
#     device=DEVICE,
#     dtype=DTYPE,
#     compile=False,
#     delete_blockwise_modules=False,  # match training: keep latent_encoder so rhythm control still works
# )
# 
# # Load checkpoint to get config
# checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
# config = checkpoint["config"]
# 
# # Apply LoRA with saved config
# model_fresh, _ = apply_lora_to_model(
#     model_fresh,
#     rank=config["rank"],
#     alpha=config["alpha"],
#     dropout=0.0,  # No dropout for inference
#     target_modules=config["target_modules"],
# )
# 
# # Load LoRA weights
# load_lora_checkpoint(model_fresh, CHECKPOINT_PATH, device=DEVICE)
# model_fresh.eval()
# 
# print(f"Loaded checkpoint from epoch {config['epoch']} (val_loss: {config['val_loss']:.4f})")

## 9. Tips & Notes for Controllable Rap Training

### NEW: Controllable Rhythm/Timing
This training enables **separate control** over:
- **Speaker Reference** → Voice identity (timbre, pitch, who it sounds like)
- **Content Reference** → Rhythm/timing (pacing, flow, cadence)

**Use Cases:**
- Same lyrics with Eminem's flow vs Drake's flow
- Transfer a specific verse's rhythm to any voice
- Keep your voice but rap like someone else

### What We Configured
- **LoRA rank 32**: Higher rank captures more style nuance
- **Content conditioning**: Enabled for rhythm learning
- **Latent path training**: `wk_latent`, `wv_latent` included in targets
- **Parakeet transcription**: 5-10x faster than Whisper

### Training Expectations
With 173+ rap acapellas:
- **Training time**: ~2-4 hours on T4, ~1-2 hours on A100
- **Expected final loss**: ~0.05-0.15 (lower is better)
- **Checkpoint size**: ~80-100MB

### If Results Sound Off
1. **Rhythm not transferring**: Make sure `USE_CONTENT_CONDITIONING = True`
2. **Too monotone**: Increase `cfg_scale_text` to 4-5 during inference
3. **Voice doesn't match**: Try different speaker reference audio
4. **Gibberish output**: Check if transcriptions were accurate

### How to Use at Inference
```python
# Same text, DIFFERENT rhythm sources:
audio1, _ = generate_audio(
    model,
    text="Your rap lyrics here",
    speaker_audio_path="your_voice.wav",      # WHO it sounds like
    content_audio_path="fast_flow_ref.wav",   # HOW it's delivered (fast)
)

audio2, _ = generate_audio(
    model,
    text="Your rap lyrics here",
    speaker_audio_path="your_voice.wav",      # Same voice
    content_audio_path="slow_flow_ref.wav",   # Different rhythm (slow)
)
```

### Common Issues

**Out of Memory (OOM)**:
- Reduce `MAX_LATENT_LENGTH` to 320 (15 seconds max)
- Use A100 GPU on Colab instead of T4
- Reduce `LORA_RANK` to 16

**Loss not decreasing**:
- Check transcriptions are accurate
- Ensure you have enough training data

**Validation loss increasing (overfitting)**:
- Increase `LORA_DROPOUT` to 0.1
- Reduce `NUM_EPOCHS`
- Reduce `LORA_RANK` to 16

### Voice Cloning Still Works!
The speaker path (wk_speaker, wv_speaker) was kept frozen, so you can:
- Use ANY speaker reference audio at inference time
- The rap style AND rhythm control transfer to any voice
- Original voice cloning quality is preserved

## 10. Interactive Testing with Gradio

Launch a Gradio interface to test your fine-tuned rap model interactively!