# Train with Original PercePiano Code

Use the original PercePiano model implementation directly (imported from the cloned repository),
so that the architecture is exactly the one released with the paper.

## Goal
- Use original `VirtuosoNetMultiLevel` from PercePiano
- Use original `HanEncoder` from PercePiano  
- Adapt data loading for our preprocessed data
- Target: R2 = 0.397 (Paper SOTA)

## Step 1: Environment Setup

In [None]:
# Check GPU: report whether CUDA is usable and, if so, the name and total memory of device 0.
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {gpu_mem_gb:.1f} GB")

In [None]:
# Install rclone
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Clone PercePiano and install dependencies
# This cell makes the original PercePiano sources importable:
#   1. clone the repository into /tmp (skipped if the directory already exists;
#      an existing checkout is not updated),
#   2. install the pinned Python dependencies,
#   3. prepend the PercePiano source directories to sys.path.
import os
import sys
from pathlib import Path

# Clone PercePiano repository
PERCEPIANO_ROOT = Path('/tmp/PercePiano')
if not PERCEPIANO_ROOT.exists():
    print("Cloning PercePiano repository...")
    !git clone https://github.com/JonghoKimSNU/PercePiano.git /tmp/PercePiano
else:
    print(f"PercePiano already present at {PERCEPIANO_ROOT}")

# Root of the virtuoso sources; later cells import from it and resolve the
# YAML config path relative to its parent.
PERCEPIANO_PATH = PERCEPIANO_ROOT / 'virtuoso' / 'virtuoso'

# Install dependencies - pin numpy<2.0 for PercePiano compatibility
# (numpy.lib.arraysetops was removed in numpy 2.0)
# NOTE(review): if numpy was already imported in this kernel (possibly transitively,
# e.g. via the earlier `import torch`), the downgrade only takes effect after a
# runtime restart - check the numpy version printed at the end of this cell.
!pip install "numpy<2.0" omegaconf tqdm scikit-learn --quiet

# IMPORTANT: Add virtuoso path FIRST, then pyScoreParser
# This ensures virtuoso/utils.py is found before pyScoreParser/utils.py
# (insert(0, ...) prepends, so inserting pyScoreParser and then virtuoso leaves
# virtuoso at sys.path[0], ahead of pyScoreParser at sys.path[1].)
sys.path.insert(0, str(PERCEPIANO_PATH / 'pyScoreParser'))
sys.path.insert(0, str(PERCEPIANO_PATH))

import numpy as np
print(f"\nnumpy version: {np.__version__}")
print(f"PercePiano path: {PERCEPIANO_PATH}")
print(f"Python path order:")
print(f"  1. {sys.path[0]}")
print(f"  2. {sys.path[1]}")

## Step 2: Import Original PercePiano Model

In [None]:
# Import original PercePiano components (resolved through the sys.path entries added in Step 1)
from model_m2pf import VirtuosoNetMultiLevel, VirtuosoNetSingle
from omegaconf import OmegaConf
import yaml

print("Successfully imported original PercePiano models!")
for shown_name, model_cls in (
    ("VirtuosoNetMultiLevel", VirtuosoNetMultiLevel),
    ("VirtuosoNetSingle", VirtuosoNetSingle),
):
    print(f"  {shown_name}: {model_cls}")

In [None]:
# Load SOTA config: the PercePiano YAML whose `nn_params` section configures the model.
CONFIG_PATH = PERCEPIANO_PATH.parent / 'ymls' / 'shared' / 'label19' / 'han_measnote_nomask_bigger256.yml'

with CONFIG_PATH.open('r') as config_file:
    config = yaml.safe_load(config_file)

net_param = OmegaConf.create(config['nn_params'])

# Override input_size to match our 84-dimension features
# (Step 6 re-checks this against the width of a real sample).
net_param.input_size = 84

# (printed label, dotted attribute path on net_param), in display order.
summary_fields = (
    ("score_encoder", "score_encoder_name"),
    ("hidden_size", "encoder.size"),
    ("note_layers", "note.layer"),
    ("voice_layers", "voice.layer"),
    ("beat_layers", "beat.layer"),
    ("measure_layers", "measure.layer"),
    ("attention_heads", "num_attention_head"),
    ("dropout", "drop_out"),
)

print("SOTA Configuration loaded:")
print(f"  Config: {CONFIG_PATH.name}")
for label, dotted_path in summary_fields:
    # Walk the dotted path one attribute at a time (same as net_param.a.b access).
    value = net_param
    for attr in dotted_path.split('.'):
        value = getattr(value, attr)
    print(f"  {label}: {value}")
print(f"  input_size: {net_param.input_size} (our 84-dim features)")

## Step 3: Setup Data

In [None]:
import subprocess
from pathlib import Path

# Paths
DATA_ROOT = Path('/tmp/percepiano_vnet_84dim')
CHECKPOINT_ROOT = Path('/tmp/checkpoints/percepiano_original')
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_84dim'

# Create directories
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone. `rclone listremotes` prints one configured remote per line
# ("name:"), so compare whole lines: a plain substring test would also accept
# a remote such as "mygdrive:" and only fail later when "gdrive:" is used.
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
configured_remotes = {line.strip() for line in result.stdout.splitlines()}

if 'gdrive:' in configured_remotes:
    print("rclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
else:
    print("rclone 'gdrive' remote: NOT CONFIGURED")
    print("Run 'rclone config' in terminal to set up Google Drive")
    RCLONE_AVAILABLE = False

if not RCLONE_AVAILABLE:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

# Download preprocessed data. Stop on a non-zero exit code so later cells never
# run against a missing or partially-synced DATA_ROOT.
print("\nDownloading preprocessed VirtuosoNet features from Google Drive...")
copy_result = subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    capture_output=False
)
if copy_result.returncode != 0:
    raise RuntimeError(
        f"rclone copy from {GDRIVE_DATA_PATH} failed with exit code {copy_result.returncode}"
    )

# Verify data
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

total_samples = 0
for split in ['train', 'val', 'test']:
    split_dir = DATA_ROOT / split
    if split_dir.exists():
        count = len(list(split_dir.glob('*.pkl')))
        total_samples += count
        print(f"  {split}: {count} samples")
    else:
        print(f"  {split}: MISSING!")

print(f"  Total: {total_samples} samples")

stat_file = DATA_ROOT / 'stat.pkl'
print(f"  stat.pkl: {'present' if stat_file.exists() else 'MISSING!'}")

## Step 4: Create `data_stats` (required by the PercePiano model constructor)

In [None]:
import pickle
import numpy as np

# Statistics written by our preprocessing; optional, an empty dict is used if absent.
# NOTE(review): pickle.load can execute arbitrary code - only load a stat.pkl you trust.
stat_path = DATA_ROOT / 'stat.pkl'
our_stats = {}
if not stat_path.exists():
    print("No stat.pkl found - will create minimal stats")
else:
    with stat_path.open('rb') as stat_fh:
        our_stats = pickle.load(stat_fh)
    print("Loaded our stats")
    print(f"Keys: {list(our_stats.keys())}")

# data_stats is the second argument VirtuosoNetMultiLevel is constructed with (Step 6).
# key_to_dim['input'] is deliberately empty. Per our reading of PercePiano's
# MixEmbedder, it is only consulted when use_continuos_feature_only=True, and the
# embedder otherwise sizes itself from net_param.input_size - TODO confirm.
data_stats = dict(
    key_to_dim={'input': {}},
    stats=our_stats,
    graph_keys=[],
)

print(f"\ndata_stats created with keys: {list(data_stats.keys())}")
print(f"key_to_dim: {data_stats['key_to_dim']}")

## Step 5: Create Dataset Adapter

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence, pad_sequence
import pickle
import json
from pathlib import Path

class PercePianoDataset(Dataset):
    """Adapter to load our preprocessed data for the original PercePiano model.

    Reads every ``*.pkl`` file in ``<data_dir>/<split>/``. Each file is a pickled
    dict, and ``__getitem__`` reads these keys from it:

    * ``'input'``: per-note feature rows (converted to a float32 tensor),
    * ``'note_location'``: dict with ``'beat'``, ``'measure'``, ``'voice'`` and
      ``'section'`` per-note index sequences,
    * ``'labels'``: target scores; only the first 19 entries are used.

    Args:
        data_dir: Root directory containing one sub-directory per split.
        split: Name of the split sub-directory (e.g. ``'train'``, ``'val'``).
        max_notes: Samples with more notes are truncated to their first
            ``max_notes`` notes (``None`` disables truncation).

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at
    data you produced or otherwise trust.
    """

    def __init__(self, data_dir, split='train', max_notes=5000):
        self.data_dir = Path(data_dir) / split
        self.max_notes = max_notes

        # Load all pkl files (stat.pkl is skipped in case it sits in the split dir)
        self.files = sorted([f for f in self.data_dir.glob('*.pkl') if f.name != 'stat.pkl'])
        print(f"Loaded {len(self.files)} samples from {split}")

        # Optional external label file: looked for next to data_dir, then in the
        # PercePiano checkout. NOTE(review): self.label_map is kept for reference
        # only - __getitem__ uses the labels embedded in each sample file.
        label_path = Path(data_dir).parent / 'label_2round_mean_reg_19_with0_rm_highstd0.json'
        if not label_path.exists():
            # Try PercePiano location
            label_path = PERCEPIANO_PATH.parent.parent / 'label_2round_mean_reg_19_with0_rm_highstd0.json'

        if label_path.exists():
            with open(label_path) as f:
                self.label_map = json.load(f)
            print(f"Loaded {len(self.label_map)} labels")
        else:
            self.label_map = None
            print("No external label file - using embedded labels")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(x, note_locations, labels)`` for sample ``idx``.

        ``x`` is a float32 tensor of at most ``max_notes`` rows. ``note_locations``
        maps each of ``'beat'``, ``'measure'``, ``'voice'`` and ``'section'`` to a
        long tensor cut to the same number of notes. ``labels`` is a float32
        tensor of the first 19 label values.

        Raises:
            KeyError: if the sample has no ``'labels'`` entry. Substituting zeros
                would silently train the regressor toward a constant target.
        """
        path = self.files[idx]
        with open(path, 'rb') as f:
            data = pickle.load(f)

        # Input features, truncated to the first max_notes notes
        x = torch.tensor(data['input'], dtype=torch.float32)[:self.max_notes]
        n_notes = len(x)

        # Note locations, cut to the same notes as x
        locations = data['note_location']
        note_locations = {
            key: torch.tensor(locations[key][:n_notes], dtype=torch.long)
            for key in ('beat', 'measure', 'voice', 'section')
        }

        # Labels (19 dimensions)
        if 'labels' not in data:
            raise KeyError(
                f"{path.name} has no 'labels' entry; refusing to use placeholder zero targets"
            )
        labels = torch.tensor(data['labels'][:19], dtype=torch.float32)

        return x, note_locations, labels


def collate_fn(batch):
    """Collate ``(x, note_locations, labels)`` samples for the PercePiano model.

    Samples are reordered longest-first, because ``pack_sequence(enforce_sorted=True)``
    requires descending lengths. Equal lengths keep their original relative order.
    All three outputs use that same order:

    * a ``PackedSequence`` of the feature tensors,
    * a dict of zero-padded ``(batch, max_len)`` note-location tensors for
      ``'beat'``, ``'measure'``, ``'voice'`` and ``'section'``,
    * the stacked label tensor.
    """
    order = sorted(range(len(batch)), key=lambda i: len(batch[i][0]), reverse=True)
    ordered = [batch[i] for i in order]

    packed_x = pack_sequence([sample[0] for sample in ordered], enforce_sorted=True)

    padded_locations = {
        key: pad_sequence([sample[1][key] for sample in ordered], batch_first=True)
        for key in ('beat', 'measure', 'voice', 'section')
    }

    stacked_labels = torch.stack([sample[2] for sample in ordered])

    return packed_x, padded_locations, stacked_labels


# Build the train/val datasets (each constructor prints how many samples it found).
train_ds, val_ds = (PercePianoDataset(DATA_ROOT, split_name) for split_name in ('train', 'val'))

for caption, split_ds in (("\nTrain samples", train_ds), ("Val samples", val_ds)):
    print(f"{caption}: {len(split_ds)}")

## Step 6: Initialize Original PercePiano Model

In [None]:
# Make the model's input width agree with the features actually on disk:
# the width of the first training sample overrides the config value if they differ.
probe_x, _, _ = train_ds[0]
data_dim = probe_x.shape[1]
print(f"Data input size: {data_dim}")
print(f"Config input size: {net_param.input_size}")

if data_dim != net_param.input_size:
    print(f"Updating config input_size from {net_param.input_size} to {data_dim}")
    net_param.input_size = data_dim

# graph_keys must be set on net_param (required by model); no graph keys are used here.
net_param.graph_keys = []

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the original PercePiano multi-level model and move it to the chosen device.
model = VirtuosoNetMultiLevel(net_param, data_stats, multi_level="total_note_cat").to(device)

param_count = sum(tensor.numel() for tensor in model.parameters())
print("\nModel: VirtuosoNetMultiLevel")
print(f"Parameters: {param_count:,}")
print(f"Device: {device}")

## Step 7: Training Loop

In [None]:
from sklearn.metrics import r2_score
from tqdm import tqdm
import time

# Hyperparameters (intended to match the paper's setup)
BATCH_SIZE = 8
LR = 2.5e-5
WEIGHT_DECAY = 1e-5
MAX_EPOCHS = 200
PATIENCE = 40
GRAD_CLIP = 2.0

# Both loaders share everything except shuffling; num_workers=0 loads in the main process.
shared_loader_opts = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
train_loader = DataLoader(train_ds, shuffle=True, **shared_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **shared_loader_opts)

# Adam plus StepLR. train_epoch calls scheduler.step() after every batch, so the
# LR is multiplied by 0.98 every 3000 optimizer steps (not epochs).
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3000, gamma=0.98)
criterion = torch.nn.MSELoss()
sigmoid = torch.nn.Sigmoid()  # maps logits into (0, 1) before the MSE loss

print("Training config:")
for label, value in (
    ("Batch size", BATCH_SIZE),
    ("LR", LR),
    ("Max epochs", MAX_EPOCHS),
    ("Patience", PATIENCE),
    ("Train batches", len(train_loader)),
    ("Val batches", len(val_loader)),
):
    print(f"  {label}: {value}")

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    Uses the notebook globals ``sigmoid`` (applied to the logits before the loss),
    ``GRAD_CLIP`` (gradient-norm ceiling) and ``scheduler`` (stepped after every
    batch, not once per epoch).
    """
    model.train()
    batch_losses = []

    for packed_x, locations, targets in loader:
        packed_x = packed_x.to(device)
        locations = {name: tensor.to(device) for name, tensor in locations.items()}
        targets = targets.to(device)

        optimizer.zero_grad()

        # NOTE(review): assumes the last element of the model output is the
        # total_note_cat logits - confirm against VirtuosoNetMultiLevel.forward.
        predictions = sigmoid(model(packed_x, None, None, locations)[-1])
        loss = criterion(predictions, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        scheduler.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(loader)


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns ``(mean per-batch loss, r2)``. ``r2`` is computed by sklearn's
    ``r2_score`` on all batches' labels and sigmoid predictions stacked row-wise,
    using its default multi-output averaging. Uses the notebook global ``sigmoid``.
    """
    import numpy as np

    model.eval()
    running_loss = 0
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for packed_x, locations, targets in loader:
            packed_x = packed_x.to(device)
            locations = {name: tensor.to(device) for name, tensor in locations.items()}
            targets = targets.to(device)

            predictions = sigmoid(model(packed_x, None, None, locations)[-1])
            running_loss += criterion(predictions, targets).item()

            pred_chunks.append(predictions.cpu().numpy())
            label_chunks.append(targets.cpu().numpy())

    r2 = r2_score(np.vstack(label_chunks), np.vstack(pred_chunks))

    return running_loss / len(loader), r2

In [None]:
# Training loop: train, validate, checkpoint whenever val R2 improves, and stop
# early after PATIENCE consecutive epochs without improvement.
rule = "=" * 70
print(rule)
print("TRAINING WITH ORIGINAL PERCEPIANO MODEL")
print(rule)
print("\nTarget: R2 = 0.397 (Paper SOTA)")
print(rule + "\n")

best_r2, best_epoch, epochs_since_best = float('-inf'), 0, 0

for epoch in range(MAX_EPOCHS):
    epoch_start = time.time()
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_r2 = validate(model, val_loader, criterion, device)
    elapsed = time.time() - epoch_start

    # Strict '>' comparison: a NaN R2 never counts as an improvement.
    improved = val_r2 > best_r2
    if not improved:
        epochs_since_best += 1
    else:
        best_r2, best_epoch, epochs_since_best = val_r2, epoch, 0
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'r2': val_r2,
                'optimizer': optimizer.state_dict(),
            },
            CHECKPOINT_ROOT / 'best.pt',
        )

    print(f"Epoch {epoch:3d} | train_loss: {train_loss:.4f} | val_loss: {val_loss:.4f} | "
          f"val_r2: {val_r2:+.4f} | time: {elapsed:.1f}s{' *best*' if improved else ''}")

    if epochs_since_best >= PATIENCE:
        print(f"\nEarly stopping at epoch {epoch} (patience={PATIENCE})")
        break

print("\n" + rule)
print("TRAINING COMPLETE")
print(rule)
print(f"\nBest R2: {best_r2:+.4f} (epoch {best_epoch})")
print("Target: R2 = 0.397 (Paper SOTA)")

# The first band whose floor best_r2 reaches picks the verdict; NaN falls through to [ISSUE].
verdict_bands = (
    (0.35, "[SUCCESS] Approaching SOTA!"),
    (0.30, "[GOOD] Strong performance"),
    (0.20, "[PARTIAL] Reasonable but below target"),
)
print(next((message for floor, message in verdict_bands if best_r2 >= floor),
           "[ISSUE] Below expected - investigate"))

## Step 8: Analysis

In [None]:
# Per-dimension R2 analysis on the validation split, using the best checkpoint.
import numpy as np
from sklearn.metrics import r2_score

# Display names for the model's output columns, in output order.
# NOTE(review): assumed to follow PercePiano's 19-label order - confirm against the label file.
DIMENSIONS = [
    'timing', 'articulation_length', 'articulation_touch',
    'pedal_amount', 'pedal_clarity', 'timbre_variety', 'timbre_depth',
    'timbre_brightness', 'timbre_loudness', 'sophistication',
    'dynamic_range', 'tempo', 'space', 'balance', 'drama',
    'mood_valence', 'mood_energy', 'mood_imagination', 'interpretation'
]

# Analyse the best checkpoint rather than the final in-memory weights. The
# training loop keeps going for up to PATIENCE epochs after the last
# improvement, so `model` usually no longer holds the best-R2 parameters.
best_ckpt_path = CHECKPOINT_ROOT / 'best.pt'
if best_ckpt_path.exists():
    # weights_only=False: this file was written by the training cell (trusted), and it
    # also stores non-tensor metadata (epoch, R2) that weights-only loading may reject.
    best_ckpt = torch.load(best_ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(best_ckpt['state_dict'])
    print(f"Loaded best checkpoint (epoch {best_ckpt['epoch']}, val R2 {best_ckpt['r2']:+.4f})")
else:
    print(f"WARNING: {best_ckpt_path} not found - analysing current in-memory weights")

# Get all predictions
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch_x, note_locations, labels in val_loader:
        batch_x = batch_x.to(device)
        note_locations = {k: v.to(device) for k, v in note_locations.items()}

        outputs = model(batch_x, None, None, note_locations)
        preds = sigmoid(outputs[-1])

        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.numpy())

all_preds = np.vstack(all_preds)
all_labels = np.vstack(all_labels)

print("="*60)
print("PER-DIMENSION R2")
print("="*60)
print(f"\n{'Dimension':<25} {'R2':>10}")
print("-"*35)

# Score only as many dimensions as the model actually outputs.
dim_r2s = []
for i, dim in enumerate(DIMENSIONS[:all_preds.shape[1]]):
    r2 = r2_score(all_labels[:, i], all_preds[:, i])
    dim_r2s.append((dim, r2))

# Sort by R2
dim_r2s.sort(key=lambda x: x[1], reverse=True)
for dim, r2 in dim_r2s:
    status = "[OK]" if r2 >= 0.2 else "[LOW]" if r2 >= 0 else "[NEG]"
    print(f"{dim:<25} {r2:>+10.4f} {status}")

positive = sum(1 for _, r2 in dim_r2s if r2 > 0)
print(f"\nPositive R2: {positive}/{len(dim_r2s)}")