# 01: Audio Encoder Training (A1, A2, A3)

Train three MuQ domain adaptation strategies on Thunder Compute:
- **A1**: MuQ + LoRA multi-task (MuQLoRAModel)
- **A2**: Staged domain adaptation (MuQStagedModel) - Stage 1: self-supervised, Stage 2: supervised
- **A3**: Full unfreeze with gradual layer unfreezing (MuQFullUnfreezeModel)

All models use pre-extracted MuQ embeddings and PercePiano labels with 4-fold cross-validation.

---

## 1. Setup

In [None]:
import subprocess
import sys
import os
# Set ahead of the torch import in a later cell.
# NOTE(review): presumably needed so cuBLAS can run deterministically for the
# Trainer(deterministic=True) used further down — confirm against the PyTorch
# reproducibility notes.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Install rclone (used below for Google Drive transfers), keeping only the
# installer's "successfully"/"already" status line.
# NOTE(review): when grep matches nothing — including when the installer
# fails — the `||` fallback still prints "rclone installed".
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Fetch the project and make its model/ directory the working directory;
# later relative paths ('src', './data') resolve against it.
# NOTE(review): re-running this cell will presumably fail at `git clone`
# once /workspace/crescendai already exists.
!git clone https://github.com/Jai-Dhiman/crescendAI.git /workspace/crescendai
%cd /workspace/crescendai/model

# Install uv, then let it install the project's dependencies.
# NOTE(review): assumes the freshly installed `uv` is already on PATH for the
# next `!` command — verify on a clean machine.
!curl -LsSf https://astral.sh/uv/install.sh | sh
!uv sync

# Pull the dataset from the `gdrive` rclone remote into ./data.
# NOTE(review): assumes a `gdrive` remote has already been configured; nothing
# visible in this notebook sets it up.
!rclone sync gdrive:crescendai_data/model_improvement/data ./data --progress

from pathlib import Path

def env_flag(name, environ=None):
    """Return True when environment variable *name* holds a truthy value.

    Unset, empty, and the usual "off" spellings ('0', 'false', 'no', 'off',
    compared case-insensitively) count as False; any other value is True.

    Args:
        name: Name of the variable to look up.
        environ: Mapping to read from; defaults to ``os.environ``.

    Returns:
        A plain ``bool``.
    """
    env = os.environ if environ is None else environ
    raw = env.get(name, '')
    return str(raw).strip().lower() not in {'', '0', 'false', 'no', 'off'}


# Parse the flag instead of testing the raw string: a bare
# os.environ.get(...) is truthy for any non-empty value, so
# THUNDER_COMPUTE=0 or THUNDER_COMPUTE=false would still pick remote paths.
IS_REMOTE = env_flag('THUNDER_COMPUTE')
if IS_REMOTE:
    # Absolute locations inside the cloned repository on the remote machine.
    DATA_DIR = Path('/workspace/crescendai/model/data')
    CHECKPOINT_DIR = Path('/workspace/crescendai/model/checkpoints/model_improvement')
else:
    # Local runs resolve relative to the current working directory.
    DATA_DIR = Path('../data')
    CHECKPOINT_DIR = Path('../checkpoints/model_improvement')

# Make sure the checkpoint root exists before any training cell writes to it.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
import json
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader

sys.path.insert(0, 'src')

from model_improvement.audio_encoders import MuQLoRAModel, MuQStagedModel, MuQFullUnfreezeModel
from model_improvement.data import PairedPerformanceDataset, AugmentedEmbeddingDataset, multi_task_collate_fn
from model_improvement.augmentation import AudioAugmentor

## 2. Load Data

In [None]:
# All inputs for this notebook are read from the PercePiano cache directory.
cache_dir = DATA_DIR / 'percepiano_cache'
labels_path = cache_dir / 'labels.json'
embeddings_path = cache_dir / 'muq_embeddings.pt'

# Labels are stored as JSON.
with open(labels_path) as labels_file:
    labels = json.load(labels_file)

# Load the saved embeddings onto the CPU, with weights_only=True as a
# restricted-unpickling safeguard.
embeddings = torch.load(
    embeddings_path,
    map_location='cpu',
    weights_only=True,
)

print(f'Loaded {len(labels)} labeled segments')
print(f'Embedding keys: {len(embeddings)}')

In [None]:
# Cross-validation splits: a sequence of {'train': [...], 'val': [...]} entries.
folds_path = cache_dir / 'folds.json'
with open(folds_path) as folds_file:
    folds = json.load(folds_file)

print(f'Number of folds: {len(folds)}')
for fold_number, split in enumerate(folds):
    n_train = len(split['train'])
    n_val = len(split['val'])
    print(f'  Fold {fold_number}: {n_train} train, {n_val} val')

In [None]:
# Load the precomputed piece mapping from the cache (it is read from disk here,
# not rebuilt from the labels).
# NOTE(review): assumes each value is the collection of recording keys for one
# piece — confirm against whatever generates piece_mapping.json.
piece_map_path = cache_dir / 'piece_mapping.json'
with open(piece_map_path) as mapping_file:
    piece_to_keys = json.load(mapping_file)

print(f'Pieces with multiple performances: {len(piece_to_keys)}')
total_keys = sum(map(len, piece_to_keys.values()))
print(f'Total recordings in piece mapping: {total_keys}')

## 3. Training Utilities

In [None]:
def upload_checkpoint(local_path, remote_subdir):
    """Copy a local checkpoint path to Google Drive using ``rclone copy``.

    Args:
        local_path: File or directory to upload; passed to rclone via ``str``.
        remote_subdir: Sub-path beneath the remote checkpoints folder.

    Raises:
        subprocess.CalledProcessError: If rclone exits with a non-zero status.
    """
    destination = f'gdrive:crescendai_data/model_improvement/checkpoints/{remote_subdir}'
    command = ['rclone', 'copy', str(local_path), destination, '--progress']
    subprocess.run(command, check=True)


def build_dataloaders(fold, labels, piece_to_keys, embeddings, cache_dir, batch_size=16, include_augmented=False):
    """Create the training and validation DataLoaders for one fold.

    Args:
        fold: Mapping whose 'train' and 'val' entries hold the keys of each split.
        labels: Label data forwarded to ``PairedPerformanceDataset``.
        piece_to_keys: Piece mapping forwarded to ``PairedPerformanceDataset``.
        embeddings: Accepted for call-site compatibility; not used in this function.
        cache_dir: Cache directory forwarded to ``PairedPerformanceDataset``.
        batch_size: Batch size applied to both loaders.
        include_augmented: When true, wrap the training set (only) in
            ``AugmentedEmbeddingDataset``.

    Returns:
        A ``(train_loader, val_loader)`` tuple; only the training loader shuffles.
    """
    def make_split(split_name):
        # Both splits share every argument except the key list.
        return PairedPerformanceDataset(
            cache_dir=cache_dir,
            labels=labels,
            piece_to_keys=piece_to_keys,
            keys=fold[split_name],
        )

    training_set = make_split('train')
    validation_set = make_split('val')
    if include_augmented:
        training_set = AugmentedEmbeddingDataset(training_set)

    shared_loader_args = {
        'batch_size': batch_size,
        'collate_fn': multi_task_collate_fn,
        'num_workers': 2,
        'pin_memory': True,
    }
    return (
        DataLoader(training_set, shuffle=True, **shared_loader_args),
        DataLoader(validation_set, shuffle=False, **shared_loader_args),
    )


def train_model(model, train_loader, val_loader, model_name, fold_idx, max_epochs=200):
    """Fit ``model`` on one fold, upload its checkpoint directory, return the trainer.

    Args:
        model: Lightning module to train.
        train_loader: DataLoader used for training.
        val_loader: DataLoader used for validation.
        model_name: Sub-directory name under ``CHECKPOINT_DIR`` and on the remote.
        fold_idx: Fold index, used to name the ``fold_<idx>`` directory.
        max_epochs: Upper bound on training epochs.

    Returns:
        The ``pl.Trainer`` after ``fit`` has finished.
    """
    fold_dir = CHECKPOINT_DIR / model_name / f'fold_{fold_idx}'
    fold_dir.mkdir(parents=True, exist_ok=True)

    # Keep only the single checkpoint with the lowest val_loss.
    best_checkpoint = ModelCheckpoint(
        dirpath=str(fold_dir),
        filename='{epoch}-{val_loss:.4f}',
        monitor='val_loss',
        mode='min',
        save_top_k=1,
    )
    # Stop early when val_loss stops improving (patience=20).
    stop_on_plateau = EarlyStopping(monitor='val_loss', patience=20, mode='min')

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='auto',
        devices=1,
        callbacks=[best_checkpoint, stop_on_plateau],
        enable_progress_bar=True,
        log_every_n_steps=10,
        deterministic=True,
    )
    trainer.fit(model, train_loader, val_loader)

    # Push this fold's checkpoint directory to Google Drive.
    upload_checkpoint(fold_dir, f'{model_name}/fold_{fold_idx}')
    return trainer

## 4. Train A1: MuQ + LoRA

In [None]:
# Constructor arguments for MuQLoRAModel; 'max_epochs' is also reused as the
# trainer's epoch limit below.
A1_CONFIG = {
    'input_dim': 1024,
    'hidden_dim': 512,
    'num_labels': 19,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'temperature': 0.07,
    'lambda_contrastive': 0.3,
    'lambda_regression': 0.5,
    'lambda_invariance': 0.1,
    'max_epochs': 200,
    'use_pretrained_muq': False,  # Using pre-extracted embeddings
}

print('Training A1: MuQ + LoRA Multi-Task')
print('=' * 50)

# One trainer per fold, kept for the summary cell.
a1_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # Fresh model for every fold so no weights carry over between folds.
    model = MuQLoRAModel(**A1_CONFIG)
    train_loader, val_loader = build_dataloaders(
        fold, labels, piece_to_keys, embeddings, cache_dir,
        batch_size=16,
    )

    trainer = train_model(model, train_loader, val_loader, 'A1', fold_idx, max_epochs=A1_CONFIG['max_epochs'])
    a1_trainers.append(trainer)

    # callback_metrics only holds the most recently logged values (the final
    # validation epoch, i.e. after early-stopping patience ran out). The best
    # val_loss is the score tracked by the ModelCheckpoint callback.
    ckpt_callback = trainer.checkpoint_callback
    best_val = ckpt_callback.best_model_score if ckpt_callback is not None else None
    if best_val is None:
        # No checkpoint was ever scored (e.g. validation never ran).
        best_val = float('inf')
    final_acc = trainer.callback_metrics.get('val_pairwise_acc', 0.0)
    print(f'  Best val_loss={best_val:.4f}, final-epoch val_pairwise_acc={final_acc:.4f}')

## 5. Train A2: Staged Domain Adaptation

In [None]:
# Constructor arguments for MuQStagedModel. The trainer epoch limits for this
# model come from STAGE1_EPOCHS / STAGE2_EPOCHS below.
A2_CONFIG = {
    'input_dim': 1024,
    'hidden_dim': 512,
    'num_labels': 19,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'temperature': 0.07,
    'lambda_contrastive': 0.3,
    'lambda_regression': 0.5,
    'lambda_invariance': 0.5,
    'max_epochs': 200,
    'use_pretrained_muq': False,
}

STAGE1_EPOCHS = 50
STAGE2_EPOCHS = 150

print('Training A2: Staged Domain Adaptation')
print('=' * 50)

# One stage-2 trainer per fold, kept for the summary cell.
a2_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # --- Stage 1: Self-supervised ---
    print('  Stage 1: Self-supervised (contrastive + invariance)')
    model = MuQStagedModel(**A2_CONFIG, stage='self_supervised')

    # Stage 1 trains on the augmented training set; validation is not augmented.
    train_loader_ss, val_loader_ss = build_dataloaders(
        fold, labels, piece_to_keys, embeddings, cache_dir,
        batch_size=16, include_augmented=True,
    )

    trainer_s1 = train_model(
        model, train_loader_ss, val_loader_ss,
        'A2_stage1', fold_idx, max_epochs=STAGE1_EPOCHS,
    )

    # --- Stage 2: Supervised ---
    # The same model object continues training, now in supervised mode.
    print('  Stage 2: Supervised (ranking + regression)')
    model.switch_to_supervised()

    train_loader, val_loader = build_dataloaders(
        fold, labels, piece_to_keys, embeddings, cache_dir,
        batch_size=16,
    )

    trainer_s2 = train_model(
        model, train_loader, val_loader,
        'A2', fold_idx, max_epochs=STAGE2_EPOCHS,
    )
    a2_trainers.append(trainer_s2)

    # callback_metrics only holds the most recently logged values (the final
    # validation epoch, i.e. after early-stopping patience ran out). The best
    # val_loss is the score tracked by the ModelCheckpoint callback.
    ckpt_callback = trainer_s2.checkpoint_callback
    best_val = ckpt_callback.best_model_score if ckpt_callback is not None else None
    if best_val is None:
        # No checkpoint was ever scored (e.g. validation never ran).
        best_val = float('inf')
    final_acc = trainer_s2.callback_metrics.get('val_pairwise_acc', 0.0)
    print(f'  Best val_loss={best_val:.4f}, final-epoch val_pairwise_acc={final_acc:.4f}')

## 6. Train A3: Full Unfreeze

In [None]:
# Constructor arguments for MuQFullUnfreezeModel; 'max_epochs' is also reused
# as the trainer's epoch limit below.
# NOTE(review): 'unfreeze_schedule' looks like {epoch: [layer indices]} —
# confirm against MuQFullUnfreezeModel.
A3_CONFIG = {
    'input_dim': 1024,
    'hidden_dim': 512,
    'num_labels': 19,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'temperature': 0.07,
    'lambda_contrastive': 0.3,
    'lambda_regression': 0.5,
    'max_epochs': 200,
    'use_pretrained_muq': False,
    'unfreeze_schedule': {0: [12], 10: [11], 20: [10], 30: [9]},
    'lr_decay_factor': 0.8,
}

print('Training A3: Full Unfreeze with Gradual Layer Unfreezing')
print('=' * 50)

# One trainer per fold, kept for the summary cell.
a3_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # Fresh model for every fold so no weights carry over between folds.
    model = MuQFullUnfreezeModel(**A3_CONFIG)
    train_loader, val_loader = build_dataloaders(
        fold, labels, piece_to_keys, embeddings, cache_dir,
        batch_size=16,
    )

    trainer = train_model(model, train_loader, val_loader, 'A3', fold_idx, max_epochs=A3_CONFIG['max_epochs'])
    a3_trainers.append(trainer)

    # callback_metrics only holds the most recently logged values (the final
    # validation epoch, i.e. after early-stopping patience ran out). The best
    # val_loss is the score tracked by the ModelCheckpoint callback.
    ckpt_callback = trainer.checkpoint_callback
    best_val = ckpt_callback.best_model_score if ckpt_callback is not None else None
    if best_val is None:
        # No checkpoint was ever scored (e.g. validation never ran).
        best_val = float('inf')
    final_acc = trainer.callback_metrics.get('val_pairwise_acc', 0.0)
    print(f'  Best val_loss={best_val:.4f}, final-epoch val_pairwise_acc={final_acc:.4f}')

## 7. Training Summary

In [None]:
# Per-fold table of the metrics each trainer has in callback_metrics.
# NOTE(review): these are presumably the most recently logged values, not
# necessarily those of the saved best checkpoint — confirm before quoting them.
print('Training Summary')
print('=' * 60)
print(f'{"Model":<10} {"Fold":<6} {"Val Loss":<12} {"Pairwise Acc":<14}')
print('-' * 60)

trainers_by_model = {'A1': a1_trainers, 'A2': a2_trainers, 'A3': a3_trainers}
for name, trainers in trainers_by_model.items():
    for fold_idx, trainer in enumerate(trainers):
        logged = trainer.callback_metrics
        # NaN marks a metric the trainer never logged.
        val_loss = logged.get('val_loss', float('nan'))
        val_acc = logged.get('val_pairwise_acc', float('nan'))
        print(f'{name:<10} {fold_idx:<6} {val_loss:<12.4f} {val_acc:<14.4f}')

print('\nCheckpoints saved to:', CHECKPOINT_DIR)
print('Checkpoints synced to Google Drive via rclone')

## 8. Upload Final Results

In [None]:
# Final pass: re-upload each model's whole checkpoint tree.
# 'A2_stage1' is not repeated here; train_model already uploads every fold
# directory right after training.
for model_name in ('A1', 'A2', 'A3'):
    local = CHECKPOINT_DIR / model_name
    if not local.exists():
        # Nothing was trained for this model in this session.
        continue
    upload_checkpoint(local, model_name)
    print(f'Uploaded {model_name} checkpoints')

print('\nAll audio encoder training complete.')
print('Run 03_audio_comparison.ipynb to evaluate and compare results.')