# Train with Original PercePiano Code

Use the ORIGINAL PercePiano model with properly preprocessed data.

## Data
- Preprocessed with the PercePiano authors' `m2pf_dataset_compositionfold.py` script
- 101-dimensional features with proper `key_to_dim` mapping
- 4-fold CV structure (fold0-fold3)

## Target: R2 = 0.397 (Paper SOTA)

## Step 1: Environment Setup

In [None]:
# Report whether a CUDA device is visible and, if so, its name and memory.
import torch

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_total_bytes / 1e9:.1f} GB")

In [None]:
# Install rclone
# Pipes the upstream install script into `sudo bash` and keeps only the output
# lines that contain "successfully" or "already".
# NOTE(review): the `|| echo "rclone installed"` fallback runs whenever grep
# matches nothing, so that message alone does not confirm the install worked.
# The `rclone listremotes` call in the download step is the real check.
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Clone PercePiano and install dependencies
# This cell (1) clones the repository if it is absent, (2) installs the extra
# pip packages, (3) registers a numpy compatibility stub, and (4) prepends the
# PercePiano source directories to sys.path. Order matters: the stub and the
# path entries must exist before any PercePiano module is imported.
import os
import sys
from pathlib import Path
from types import ModuleType

# Clone PercePiano repository
# The clone target in the git command is hard-coded and must stay in sync with
# PERCEPIANO_ROOT.
PERCEPIANO_ROOT = Path('/tmp/PercePiano')
if not PERCEPIANO_ROOT.exists():
    print("Cloning PercePiano repository...")
    !git clone https://github.com/JonghoKimSNU/PercePiano.git /tmp/PercePiano
else:
    print(f"PercePiano already present at {PERCEPIANO_ROOT}")

# Directory that is put on sys.path below.
PERCEPIANO_PATH = PERCEPIANO_ROOT / 'virtuoso' / 'virtuoso'

# Install dependencies (keep numpy 2.0, we'll patch compatibility)
!pip install omegaconf tqdm --quiet

# Patch numpy 2.0 compatibility for PercePiano
# PercePiano imports 'from numpy.lib.arraysetops import isin' which was removed in numpy 2.0
# When the attribute is missing, a stand-in module exposing only `isin` is
# registered both in sys.modules and as np.lib.arraysetops so that import works.
import numpy as np
if not hasattr(np.lib, 'arraysetops'):
    arraysetops = ModuleType('numpy.lib.arraysetops')
    arraysetops.isin = np.isin
    sys.modules['numpy.lib.arraysetops'] = arraysetops
    np.lib.arraysetops = arraysetops
    print("Patched numpy.lib.arraysetops for numpy 2.0 compatibility")

# Add to Python path (virtuoso first, then pyScoreParser)
# Each insert(0, ...) prepends, so the second call ends up first on sys.path.
# Re-running this cell prepends the same entries again.
sys.path.insert(0, str(PERCEPIANO_PATH / 'pyScoreParser'))
sys.path.insert(0, str(PERCEPIANO_PATH))

print(f"\nnumpy version: {np.__version__}")
print(f"PercePiano path: {PERCEPIANO_PATH}")

## Step 2: Download Data

In [None]:
import subprocess
from pathlib import Path

# Paths - using original preprocessed data
# Local working directories and the rclone remotes they are copied from.
DATA_ROOT = Path('/tmp/percepiano_original')
CHECKPOINT_ROOT = Path('/tmp/checkpoints/percepiano_original')
LABEL_ROOT = Path('/tmp/percepiano_labels')
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_original'
GDRIVE_LABEL_PATH = 'gdrive:crescendai_data/percepiano_labels'

# Create directories
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)
LABEL_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone: the 'gdrive' remote must already be configured.
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
if 'gdrive:' not in result.stdout:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

print("rclone 'gdrive' remote: CONFIGURED")

# Download preprocessed data
# check=True turns a non-zero rclone exit status into CalledProcessError, so a
# failed transfer stops the notebook here instead of surfacing later as a
# missing-file error.
print("\nDownloading original PercePiano preprocessed data...")
subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    capture_output=False,
    check=True,
)

# Download label files
print("\nDownloading label files...")
subprocess.run(
    ['rclone', 'copy', GDRIVE_LABEL_PATH, str(LABEL_ROOT), '--progress'],
    capture_output=False,
    check=True,
)

# Verify data: count sample pickles per fold/split (stat.pkl is not a sample).
# Missing directories are reported explicitly rather than skipped silently.
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)
for fold in range(4):
    fold_path = DATA_ROOT / f'fold{fold}'
    if not fold_path.exists():
        print(f"  fold{fold}: MISSING")
        continue
    for split in ['train', 'valid', 'test']:
        split_path = fold_path / split
        if not split_path.exists():
            print(f"  fold{fold}/{split}: MISSING")
            continue
        count = len([f for f in split_path.glob('*.pkl') if f.name != 'stat.pkl'])
        print(f"  fold{fold}/{split}: {count} samples")

# Verify labels: the label JSON is required by the dataset step.
label_file = LABEL_ROOT / 'label_2round_mean_reg_19_with0_rm_highstd0.json'
if label_file.exists():
    print(f"\nLabel file: {label_file.name} [OK]")
else:
    raise RuntimeError(f"Label file not found: {label_file}")

## Step 3: Load Data Stats

In [None]:
import pickle

# Use fold0 for training
FOLD = 0
FOLD_PATH = DATA_ROOT / f'fold{FOLD}'

# Load stats from train split
# NOTE: pickle executes arbitrary code on load; only use stat files you trust.
stat_path = FOLD_PATH / 'train' / 'stat.pkl'
with open(stat_path, 'rb') as f:
    data_stats = pickle.load(f)

print(f"Loaded stats from fold{FOLD}")
print(f"Keys: {list(data_stats.keys())}")
print(f"Input keys: {len(data_stats.get('input_keys', []))} features")
print(f"key_to_dim['input']: {len(data_stats.get('key_to_dim', {}).get('input', {}))} entries")

# Get input dimension from key_to_dim
# The largest second element across the key_to_dim['input'] values is used as
# the input width. max_dim is always bound (None when the mapping is missing or
# empty) so the config cell that reads it cannot fail with NameError. The
# model-initialisation cell replaces the configured size with the width of a
# real sample whenever the two differ.
input_key_to_dim = data_stats.get('key_to_dim', {}).get('input', {})
if input_key_to_dim:
    max_dim = max(v[1] for v in input_key_to_dim.values())
    print(f"\nInput dimension: {max_dim}")
else:
    max_dim = None
    print("\nWARNING: key_to_dim['input'] is missing or empty; input dimension unknown")

## Step 4: Import Original PercePiano Model

In [None]:
# Import original PercePiano components
# model_m2pf is not an installed package; presumably it is resolved through the
# PercePiano directories prepended to sys.path during environment setup, so
# that cell must have run first.
from model_m2pf import VirtuosoNetMultiLevel, VirtuosoNetSingle
from omegaconf import OmegaConf
import yaml

# Printing the class object confirms which module the model was loaded from.
print("Successfully imported original PercePiano models!")
print(f"  VirtuosoNetMultiLevel: {VirtuosoNetMultiLevel}")

In [None]:
# Read the network hyperparameters from the YAML config shipped in the
# PercePiano checkout (one level above PERCEPIANO_PATH).
CONFIG_PATH = PERCEPIANO_PATH.parent.joinpath(
    'ymls', 'shared', 'label19', 'han_measnote_nomask_bigger256.yml'
)

with open(CONFIG_PATH, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Only the 'nn_params' section is turned into the model configuration.
net_param = OmegaConf.create(config['nn_params'])

# Overrides: take the input width from the data stats and clear graph_keys.
net_param.input_size = max_dim
net_param.graph_keys = []

print("SOTA Configuration:")
print(f"  input_size: {net_param.input_size}")
print(f"  hidden_size: {net_param.encoder.size}")
print(f"  layers: note={net_param.note.layer}, voice={net_param.voice.layer}, beat={net_param.beat.layer}, measure={net_param.measure.layer}")
print(f"  attention_heads: {net_param.num_attention_head}")
print(f"  dropout: {net_param.drop_out}")

## Step 5: Create Dataset

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence, pad_sequence
import pickle
import json
import re
from pathlib import Path

# Load perceptual labels from JSON
# PERCEPTUAL_LABELS is looked up by label key and sliced per sample by the
# dataset class. The encoding is given explicitly so decoding does not depend
# on the machine's locale.
label_file = LABEL_ROOT / 'label_2round_mean_reg_19_with0_rm_highstd0.json'
with open(label_file, encoding='utf-8') as f:
    PERCEPTUAL_LABELS = json.load(f)
print(f"Loaded {len(PERCEPTUAL_LABELS)} perceptual labels")


def extract_label_key(filename):
    """Extract label key from pkl filename.

    Strips a trailing ``.pkl`` and then a trailing ``.mid`` extension, and a
    leading ``all_2rounds_`` prefix if present. Only the ends of the name are
    touched, so a piece name that happens to contain ``.mid`` or ``.pkl`` in
    the middle is left intact.

    Example: all_2rounds_Beethoven_WoO80_thema_8bars_11_1.mid.pkl
          -> Beethoven_WoO80_thema_8bars_11_1

    Args:
        filename: File name (not a full path) of a preprocessed sample.

    Returns:
        The key under which the sample's labels are stored.
    """
    name = filename
    # Order matters: files are named '<key>.mid.pkl', so '.pkl' comes off first.
    for suffix in ('.pkl', '.mid'):
        if name.endswith(suffix):
            name = name[:-len(suffix)]

    prefix = 'all_2rounds_'
    if name.startswith(prefix):
        name = name[len(prefix):]
    return name


class PercePianoDataset(Dataset):
    """Dataset over the preprocessed PercePiano pickles of one split.

    Only files whose label key is present in ``PERCEPTUAL_LABELS`` are kept.
    Each item is ``(x, note_locations, labels)``.
    """

    def __init__(self, data_dir, split='train', max_notes=5000):
        """Index the sample pickles under ``data_dir / split``.

        Args:
            data_dir: Fold directory containing the split sub-directories.
            split: Name of the split sub-directory.
            max_notes: Samples longer than this are cut to their first
                ``max_notes`` rows when loaded.
        """
        self.data_dir = Path(data_dir) / split
        self.max_notes = max_notes

        self.files = []
        self.labels_cache = {}
        missing = 0

        # stat.pkl holds dataset statistics, not a sample, so it is excluded.
        candidates = sorted(
            p for p in self.data_dir.glob('*.pkl') if p.name != 'stat.pkl'
        )
        for pkl_path in candidates:
            key = extract_label_key(pkl_path.name)
            if key not in PERCEPTUAL_LABELS:
                missing += 1
                continue
            self.files.append(pkl_path)
            # Keep the first 19 values only (the 20th is the pianist ID).
            self.labels_cache[pkl_path.name] = PERCEPTUAL_LABELS[key][:19]

        print(f"Loaded {len(self.files)} samples from {split} ({missing} missing labels)")

    def __len__(self):
        """Return the number of samples that have labels."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk.

        Returns:
            A tuple ``(x, note_locations, labels)``: the float32 input
            features, a dict of long tensors for 'beat', 'measure', 'voice'
            and 'section' cut to the same length as ``x``, and the float32
            label vector.
        """
        sample_path = self.files[idx]
        with open(sample_path, 'rb') as handle:
            data = pickle.load(handle)

        x = torch.tensor(data['input'], dtype=torch.float32)

        # Cap the sequence length at max_notes.
        keep = min(len(x), self.max_notes)
        x = x[:keep]

        # Location indices are cut to match the (possibly truncated) features.
        raw_locations = data['note_location']
        note_locations = {
            level: torch.tensor(raw_locations[level][:len(x)], dtype=torch.long)
            for level in ('beat', 'measure', 'voice', 'section')
        }

        labels = torch.tensor(self.labels_cache[sample_path.name], dtype=torch.float32)

        return x, note_locations, labels


def collate_fn(batch):
    """Combine dataset items into one batch for the PercePiano model.

    Items are reordered longest-first, the features are packed into a
    ``PackedSequence`` and the per-note location indices are zero-padded to
    the longest sequence. Labels are stacked in the same order.

    Args:
        batch: Sequence of ``(x, note_locations, labels)`` tuples.

    Returns:
        ``(packed_features, padded_note_locations, stacked_labels)``.
    """
    features, locations, targets = zip(*batch)

    # pack_sequence(enforce_sorted=True) needs sequences in descending length.
    order = sorted(
        range(len(features)), key=lambda i: len(features[i]), reverse=True
    )

    packed = pack_sequence([features[i] for i in order], enforce_sorted=True)

    padded_locations = {
        level: pad_sequence([locations[i][level] for i in order], batch_first=True)
        for level in ('beat', 'measure', 'voice', 'section')
    }

    stacked_targets = torch.stack([targets[i] for i in order])

    return packed, padded_locations, stacked_targets


# Build the train and validation datasets for the selected fold.
train_ds = PercePianoDataset(FOLD_PATH, split='train')
val_ds = PercePianoDataset(FOLD_PATH, split='valid')

print(f"\nTrain samples: {len(train_ds)}")
print(f"Val samples: {len(val_ds)}")

## Step 6: Initialize Model

In [None]:
# Check the configured input width against the feature width of a real sample;
# the data wins if they disagree.
sample_x = train_ds[0][0]
actual_input_size = sample_x.shape[1]
print(f"Data input size: {actual_input_size}")
print(f"Config input size: {net_param.input_size}")

if net_param.input_size != actual_input_size:
    print(f"Updating config input_size to {actual_input_size}")
    net_param.input_size = actual_input_size

# Build the network and move it to the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = VirtuosoNetMultiLevel(net_param, data_stats, multi_level="total_note_cat").to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"\nModel: VirtuosoNetMultiLevel")
print(f"Parameters: {n_params:,}")
print(f"Device: {device}")

## Step 7: Training

In [None]:
import time

def r2_score(y_true, y_pred):
    """Return the coefficient of determination of ``y_pred`` against ``y_true``.

    Implemented with NumPy only (avoids the scikit-learn numpy 2.0 dependency).
    When ``y_true`` has zero variance the score is defined here as 0.0.
    """
    residual = y_true - y_pred
    ss_res = np.sum(residual ** 2)

    centered = y_true - np.mean(y_true)
    ss_tot = np.sum(centered ** 2)

    return 0.0 if ss_tot == 0 else 1 - (ss_res / ss_tot)

def r2_score_multioutput(y_true, y_pred):
    """Return the R2 score averaged over output columns.

    One-dimensional input, or 2-D input with a single column, is flattened
    and scored as a single output. Otherwise each column is scored with
    ``r2_score`` and the unweighted mean of those scores is returned.
    """
    single_output = y_true.ndim <= 1 or y_true.shape[1] == 1
    if single_output:
        return r2_score(y_true.ravel(), y_pred.ravel())

    per_column = [
        r2_score(y_true[:, col], y_pred[:, col])
        for col in range(y_true.shape[1])
    ]
    return np.mean(per_column)

# Hyperparameters. These follow the paper, except PATIENCE, which is reduced
# from 40 to save compute time.
BATCH_SIZE = 8
LR = 2.5e-5
WEIGHT_DECAY = 1e-5
MAX_EPOCHS = 200
PATIENCE = 20
GRAD_CLIP = 2.0

# Data loaders: both share batch size and collate function; only the training
# loader shuffles.
loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# Optimiser, LR schedule, loss, and the output activation used by the
# training and validation functions.
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3000, gamma=0.98)
criterion = torch.nn.MSELoss()
sigmoid = torch.nn.Sigmoid()

print(f"Training config:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  LR: {LR}")
print(f"  Patience: {PATIENCE}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Uses the notebook-level ``sigmoid``, ``scheduler`` and ``GRAD_CLIP``.
    The learning-rate scheduler is stepped once per batch, not once per epoch.
    """
    model.train()
    running_loss = 0

    for packed_x, locations, targets in loader:
        packed_x = packed_x.to(device)
        locations = {name: tensor.to(device) for name, tensor in locations.items()}
        targets = targets.to(device)

        optimizer.zero_grad()

        # The last element of the model output holds the logits; they are
        # squashed with a sigmoid before the loss is computed.
        predictions = sigmoid(model(packed_x, None, None, locations)[-1])
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()

        # Clip the gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        scheduler.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Uses the notebook-level ``sigmoid`` and ``r2_score_multioutput``.

    Returns:
        ``(mean_batch_loss, r2)``, where ``r2`` is computed over all
        predictions and labels collected across the loader.
    """
    model.eval()
    running_loss = 0
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for packed_x, locations, targets in loader:
            packed_x = packed_x.to(device)
            locations = {name: tensor.to(device) for name, tensor in locations.items()}
            targets = targets.to(device)

            predictions = sigmoid(model(packed_x, None, None, locations)[-1])
            running_loss += criterion(predictions, targets).item()

            pred_chunks.append(predictions.cpu().numpy())
            label_chunks.append(targets.cpu().numpy())

    # Stack the per-batch arrays row-wise before scoring.
    stacked_preds = np.vstack(pred_chunks)
    stacked_labels = np.vstack(label_chunks)
    r2 = r2_score_multioutput(stacked_labels, stacked_preds)

    return running_loss / len(loader), r2

In [None]:
# Training loop
# Trains for up to MAX_EPOCHS with early stopping on validation R2: whenever
# the validation R2 improves, the weights are saved to CHECKPOINT_ROOT/best.pt;
# training stops after PATIENCE consecutive epochs without improvement.
print("="*70)
print("TRAINING WITH ORIGINAL PERCEPIANO")
print("="*70)
print(f"Fold: {FOLD}")
print(f"Target: R2 = 0.397 (Paper SOTA)")
print("="*70 + "\n")

best_r2 = -float('inf')
best_epoch = 0
patience_counter = 0

for epoch in range(MAX_EPOCHS):
    start = time.time()

    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_r2 = validate(model, val_loader, criterion, device)

    elapsed = time.time() - start

    is_best = val_r2 > best_r2
    if is_best:
        best_r2 = val_r2
        best_epoch = epoch
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            # Stored as a plain Python float: a NumPy scalar in the checkpoint
            # makes torch.load fail under its weights_only=True default.
            'r2': float(val_r2),
        }, CHECKPOINT_ROOT / 'best.pt')
    else:
        patience_counter += 1

    marker = " *best*" if is_best else ""
    print(f"Epoch {epoch:3d} | train: {train_loss:.4f} | val: {val_loss:.4f} | r2: {val_r2:+.4f} | {elapsed:.1f}s{marker}")

    if patience_counter >= PATIENCE:
        print(f"\nEarly stopping at epoch {epoch}")
        break

print("\n" + "="*70)
print(f"Best R2: {best_r2:+.4f} (epoch {best_epoch})")
print(f"Target: R2 = 0.397")
print("="*70)

## Step 8: Analysis

In [None]:
# Per-dimension R2
# Names of the 19 label dimensions, in label-vector order.
DIMENSIONS = [
    'timing', 'articulation_length', 'articulation_touch',
    'pedal_amount', 'pedal_clarity', 'timbre_variety', 'timbre_depth',
    'timbre_brightness', 'timbre_loudness', 'sophistication',
    'dynamic_range', 'tempo', 'space', 'balance', 'drama',
    'mood_valence', 'mood_energy', 'mood_imagination', 'interpretation'
]

# Analyse the best checkpoint, not whatever weights the last epoch left behind:
# after early stopping the in-memory model is up to PATIENCE epochs past the
# best one, so its scores would not match the reported best R2.
best_ckpt_path = CHECKPOINT_ROOT / 'best.pt'
if best_ckpt_path.exists():
    # weights_only=False: the file is written by this notebook, and its
    # metadata may contain a NumPy scalar that the restricted loader rejects.
    checkpoint = torch.load(best_ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['state_dict'])
    print(f"Loaded best checkpoint from epoch {checkpoint['epoch']}")
else:
    print("WARNING: best checkpoint not found; analysing current model weights")

model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for batch_x, note_locations, labels in val_loader:
        batch_x = batch_x.to(device)
        note_locations = {k: v.to(device) for k, v in note_locations.items()}

        outputs = model(batch_x, None, None, note_locations)
        preds = sigmoid(outputs[-1])

        all_preds.append(preds.cpu().numpy())
        # Labels were never moved to the device, so no .cpu() is needed.
        all_labels.append(labels.numpy())

all_preds = np.vstack(all_preds)
all_labels = np.vstack(all_labels)

print("PER-DIMENSION R2")
print("="*40)
# Score each named dimension that has a prediction column (zip stops at the
# shorter of the two), then list them from best to worst.
dim_r2s = [
    (dim, r2_score(all_labels[:, i], all_preds[:, i]))
    for i, dim in zip(range(all_preds.shape[1]), DIMENSIONS)
]

dim_r2s.sort(key=lambda x: x[1], reverse=True)
for dim, r2 in dim_r2s:
    status = "[OK]" if r2 >= 0.2 else "[LOW]" if r2 >= 0 else "[NEG]"
    print(f"{dim:<25} {r2:>+.4f} {status}")

print(f"\nPositive R2: {sum(1 for _, r2 in dim_r2s if r2 > 0)}/{len(dim_r2s)}")