# 02: Symbolic Encoder Training (S1, S2, S2-hetero, S3)

Train three symbolic encoder strategies, plus a heterogeneous-graph variant of S2, on Thunder Compute:
- **S1**: BERT-style Transformer on REMI tokens (TransformerSymbolicEncoder)
- **S2**: GNN on score graph (GNNSymbolicEncoder)
- **S2-hetero**: Heterogeneous GNN on score graph (GNNHeteroSymbolicEncoder)
- **S3**: 1D-CNN + Transformer on continuous features (ContinuousSymbolicEncoder)

Each model has a self-supervised pretraining stage followed by a supervised pairwise-ranking finetuning stage.

---

## 1. Setup

In [None]:
# Environment setup: cuBLAS workspace config and rclone installation.
import subprocess
import sys
import os
# PyTorch requires this cuBLAS workspace setting for deterministic CUDA matmuls
# (see PyTorch reproducibility notes). It is set here, before torch is imported in a
# later cell; presumably training enables deterministic mode — confirm.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Install rclone, used later in this notebook to sync data/checkpoints with Google Drive.
# NOTE(review): the `|| echo "rclone installed"` branch runs whenever grep matches nothing,
# so this prints "rclone installed" even if the installer failed — check the output if a
# later `rclone sync` fails.
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
!git clone https://github.com/Jai-Dhiman/crescendAI.git /workspace/crescendai
%cd /workspace/crescendai/model

!curl -LsSf https://astral.sh/uv/install.sh | sh
!uv sync

!rclone sync gdrive:crescendai_data/model_improvement/data ./data --progress

from pathlib import Path

IS_REMOTE = os.environ.get('THUNDER_COMPUTE', False)
if IS_REMOTE:
    DATA_DIR = Path('/workspace/crescendai/model/data')
    CHECKPOINT_DIR = Path('/workspace/crescendai/model/checkpoints/model_improvement')
else:
    DATA_DIR = Path('../data')
    CHECKPOINT_DIR = Path('../checkpoints/model_improvement')

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
import json
import numpy as np
import torch
import pytorch_lightning as pl
from functools import partial
from torch.utils.data import DataLoader

sys.path.insert(0, 'src')

from model_improvement.symbolic_encoders import (
    TransformerSymbolicEncoder,
    GNNSymbolicEncoder,
    GNNHeteroSymbolicEncoder,
    ContinuousSymbolicEncoder,
)
from model_improvement.tokenizer import PianoTokenizer, extract_continuous_features
from model_improvement.data import (
    MIDIPretrainingDataset,
    PairedPerformanceDataset,
    ScoreGraphPretrainingDataset,
    ContinuousPretrainDataset,
    HeteroPretrainDataset,
    graph_pair_collate_fn,
    symbolic_collate_fn,
    continuous_collate_fn,
    hetero_graph_collate_fn,
    multi_task_collate_fn,
)
from model_improvement.graph import midi_to_graph, midi_to_hetero_graph
from model_improvement.training import train_model, upload_checkpoint

## 2. Load Data

In [None]:
cache_dir = DATA_DIR / 'percepiano_cache'

# PercePiano cache metadata: segment labels, cross-validation folds (each with
# 'train'/'val' key lists), and the piece -> segment-keys mapping.
labels = json.loads((cache_dir / 'labels.json').read_text())
folds = json.loads((cache_dir / 'folds.json').read_text())
piece_to_keys = json.loads((cache_dir / 'piece_mapping.json').read_text())

midi_dir = DATA_DIR / 'percepiano_midi'

print(f'Loaded {len(labels)} labeled segments')
print(f'Folds: {len(folds)}')
for fold_num, split in enumerate(folds):
    n_train, n_val = len(split['train']), len(split['val'])
    print(f'  Fold {fold_num}: {n_train} train, {n_val} val')

## 3. Prepare Symbolic Inputs

In [None]:
print('Tokenizing MIDI files for S1...')
tokenizer = PianoTokenizer(max_seq_len=2048)

# key -> token sequence for every labeled segment that has a MIDI file on disk.
token_sequences = {}
missing_midi = 0
for key in labels:
    midi_path = midi_dir / f'{key}.mid'
    if not midi_path.exists():
        # Report rather than silently drop: these segments are absent from S1 training.
        missing_midi += 1
        continue
    token_sequences[key] = tokenizer.encode(midi_path)

print(f'Tokenized {len(token_sequences)} MIDI files')
if missing_midi:
    print(f'Missing MIDI file for {missing_midi} labeled segments')
print(f'Vocab size: {tokenizer.vocab_size}')
# np.mean([]) warns and returns nan, so only compute the average when something was tokenized.
if token_sequences:
    mean_len = np.mean([len(t) for t in token_sequences.values()])
    print(f'Average sequence length: {mean_len:.0f}')
else:
    print('Average sequence length: n/a (no MIDI files tokenized)')

In [None]:
print('Extracting continuous features for S3...')

# key -> float32 feature tensor for every labeled segment with a MIDI file on disk.
continuous_features = {}
for key in labels:
    path = midi_dir / f'{key}.mid'
    if not path.exists():
        continue
    raw = extract_continuous_features(path, frame_rate=50)
    continuous_features[key] = torch.from_numpy(raw).float()

print(f'Extracted features for {len(continuous_features)} files')
if continuous_features:
    sample = next(iter(continuous_features.values()))
    print(f'Feature shape: {sample.shape} (T, D={sample.shape[1]})')

In [None]:
print('Building score graphs for S2...')

# key -> homogeneous score graph (S2) and key -> heterogeneous score graph (S2-hetero).
score_graphs = {}
hetero_graphs = {}
graph_failures = {}  # key -> ValueError message for MIDI files that could not be converted
for key in labels:
    midi_path = midi_dir / f'{key}.mid'
    if not midi_path.exists():
        continue
    try:
        graph = midi_to_graph(midi_path)
        hetero_graph = midi_to_hetero_graph(midi_path)
    except ValueError as err:
        graph_failures[key] = str(err)
        continue
    # Store both graphs only after both conversions succeed. This keeps score_graphs
    # and hetero_graphs on the same key set, so S2-hetero (which reads both dicts)
    # cannot miss a key, and S2 / S2-hetero cover the same segments.
    score_graphs[key] = graph
    hetero_graphs[key] = hetero_graph

print(f'Built graphs for {len(score_graphs)} MIDI files')
if graph_failures:
    print(f'Skipped {len(graph_failures)} MIDI files whose graph construction raised ValueError')
if score_graphs:
    sample = next(iter(score_graphs.values()))
    print(f'Sample graph: {sample.x.shape[0]} nodes, {sample.edge_index.shape[1]} edges')
    print(f'Node features: {sample.x.shape[1]}')
    edge_types = sample.edge_type.unique().tolist()
    type_names = {0: 'onset', 1: 'during', 2: 'follow', 3: 'silence'}
    for t in edge_types:
        count = (sample.edge_type == t).sum().item()
        print(f'  {type_names.get(t, t)}: {count} edges')

## 4. Training Utilities

In [None]:
# Training helpers (train_model, upload_checkpoint) come from model_improvement.training.
# train_model receives the checkpoint directory explicitly via its checkpoint_dir argument.
print('Checkpoint dir:', CHECKPOINT_DIR)

## 5. Train S1: Transformer on REMI Tokens

In [None]:
# Hyperparameters for S1, passed straight to TransformerSymbolicEncoder.
S1_CONFIG = dict(
    vocab_size=tokenizer.vocab_size + 1,  # one extra id reserved for the mask token
    d_model=512,
    nhead=8,
    num_layers=6,
    hidden_dim=512,
    num_labels=19,
    max_epochs=200,
)

PRETRAIN_EPOCHS = 50
FINETUNE_EPOCHS = 150
CHECKPOINT_REMOTE = 'gdrive:crescendai_data/model_improvement/checkpoints'

print('Training S1: Transformer on REMI Tokens')
print('=' * 50)

s1_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # Keys of this fold's segments that produced a token sequence.
    train_keys = [k for k in fold['train'] if k in token_sequences]
    val_keys = [k for k in fold['val'] if k in token_sequences]

    # Stage 1: self-supervised masked-token pretraining.
    print('  Pretrain: Masked token prediction')
    mlm_opts = dict(max_seq_len=2048, mask_prob=0.15, vocab_size=tokenizer.vocab_size)
    pretrain_ds = MIDIPretrainingDataset(
        token_sequences=[token_sequences[k] for k in train_keys], **mlm_opts
    )
    pretrain_val_ds = MIDIPretrainingDataset(
        token_sequences=[token_sequences[k] for k in val_keys], **mlm_opts
    )

    pt_loader_opts = dict(batch_size=8, num_workers=2)
    pretrain_loader = DataLoader(pretrain_ds, shuffle=True, **pt_loader_opts)
    pretrain_val_loader = DataLoader(pretrain_val_ds, shuffle=False, **pt_loader_opts)

    model = TransformerSymbolicEncoder(**S1_CONFIG, stage='pretrain')
    trainer_pt = train_model(
        model, pretrain_loader, pretrain_val_loader,
        'S1_pretrain', fold_idx,
        checkpoint_dir=CHECKPOINT_DIR,
        max_epochs=PRETRAIN_EPOCHS,
        monitor='val_mlm_loss',
        upload_remote=CHECKPOINT_REMOTE,
    )

    # Stage 2: supervised pairwise-ranking finetuning of the same model instance.
    print('  Finetune: Pairwise ranking')
    model.stage = 'finetune'

    pair_opts = dict(cache_dir=cache_dir, labels=labels, piece_to_keys=piece_to_keys)
    finetune_ds = PairedPerformanceDataset(keys=train_keys, **pair_opts)
    finetune_val_ds = PairedPerformanceDataset(keys=val_keys, **pair_opts)

    ft_loader_opts = dict(
        batch_size=8,
        num_workers=2,
        collate_fn=partial(symbolic_collate_fn, token_sequences=token_sequences),
    )
    finetune_loader = DataLoader(finetune_ds, shuffle=True, **ft_loader_opts)
    finetune_val_loader = DataLoader(finetune_val_ds, shuffle=False, **ft_loader_opts)

    trainer_ft = train_model(
        model, finetune_loader, finetune_val_loader,
        'S1', fold_idx,
        checkpoint_dir=CHECKPOINT_DIR,
        max_epochs=FINETUNE_EPOCHS,
        upload_remote=CHECKPOINT_REMOTE,
    )
    s1_trainers.append(trainer_ft)

    # NOTE(review): callback_metrics holds the most recently logged values; they are the
    # best-epoch values only if train_model re-validates the best checkpoint — confirm.
    metrics = trainer_ft.callback_metrics
    best_val = metrics.get('val_loss', float('inf'))
    best_acc = metrics.get('val_pairwise_acc', 0.0)
    print(f'  Best val_loss={best_val:.4f}, val_pairwise_acc={best_acc:.4f}')

## 6. Train S2: GNN on Score Graph

In [None]:
# Hyperparameters for S2, passed straight to GNNSymbolicEncoder.
S2_CONFIG = dict(
    node_features=6,
    hidden_dim=512,
    num_layers=4,
    num_labels=19,
    max_epochs=200,
)

PRETRAIN_EPOCHS_S2 = 50
FINETUNE_EPOCHS_S2 = 150
CHECKPOINT_REMOTE = 'gdrive:crescendai_data/model_improvement/checkpoints'

print('Training S2: GNN on Score Graph')
print('=' * 50)

from torch_geometric.loader import DataLoader as PyGDataLoader

s2_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # Keys of this fold's segments that produced a score graph.
    train_keys = [k for k in fold['train'] if k in score_graphs]
    val_keys = [k for k in fold['val'] if k in score_graphs]

    # Stage 1: self-supervised link-prediction pretraining, batched with the PyG loader.
    print('  Pretrain: Link prediction')
    pretrain_ds = ScoreGraphPretrainingDataset(
        [score_graphs[k] for k in train_keys], mask_fraction=0.15
    )
    pretrain_val_ds = ScoreGraphPretrainingDataset(
        [score_graphs[k] for k in val_keys], mask_fraction=0.15
    )

    pyg_opts = dict(batch_size=8, num_workers=2)
    pretrain_loader = PyGDataLoader(pretrain_ds, shuffle=True, **pyg_opts)
    pretrain_val_loader = PyGDataLoader(pretrain_val_ds, shuffle=False, **pyg_opts)

    model = GNNSymbolicEncoder(**S2_CONFIG, stage='pretrain')
    trainer_pt = train_model(
        model, pretrain_loader, pretrain_val_loader,
        'S2_pretrain', fold_idx,
        checkpoint_dir=CHECKPOINT_DIR,
        max_epochs=PRETRAIN_EPOCHS_S2,
        monitor='val_link_loss',
        upload_remote=CHECKPOINT_REMOTE,
    )

    # Stage 2: supervised pairwise-ranking finetuning of the same model instance.
    print('  Finetune: Pairwise ranking')
    model.stage = 'finetune'

    pair_opts = dict(cache_dir=cache_dir, labels=labels, piece_to_keys=piece_to_keys)
    finetune_ds = PairedPerformanceDataset(keys=train_keys, **pair_opts)
    finetune_val_ds = PairedPerformanceDataset(keys=val_keys, **pair_opts)

    ft_loader_opts = dict(
        batch_size=8,
        num_workers=2,
        collate_fn=partial(graph_pair_collate_fn, graphs=score_graphs),
    )
    finetune_loader = DataLoader(finetune_ds, shuffle=True, **ft_loader_opts)
    finetune_val_loader = DataLoader(finetune_val_ds, shuffle=False, **ft_loader_opts)

    trainer_ft = train_model(
        model, finetune_loader, finetune_val_loader,
        'S2', fold_idx,
        checkpoint_dir=CHECKPOINT_DIR,
        max_epochs=FINETUNE_EPOCHS_S2,
        upload_remote=CHECKPOINT_REMOTE,
    )
    s2_trainers.append(trainer_ft)

    # NOTE(review): callback_metrics holds the most recently logged values; they are the
    # best-epoch values only if train_model re-validates the best checkpoint — confirm.
    metrics = trainer_ft.callback_metrics
    best_val = metrics.get('val_loss', float('inf'))
    best_acc = metrics.get('val_pairwise_acc', 0.0)
    print(f'  Best val_loss={best_val:.4f}, val_pairwise_acc={best_acc:.4f}')

## 6b. Train S2-hetero: Heterogeneous GNN on Score Graph

In [None]:
# Hyperparameters for S2-hetero, passed straight to GNNHeteroSymbolicEncoder.
S2H_CONFIG = dict(
    node_features=6,
    hidden_dim=512,
    num_layers=3,
    num_labels=19,
    max_epochs=200,
)

PRETRAIN_EPOCHS_S2H = 50
FINETUNE_EPOCHS_S2H = 150
CHECKPOINT_REMOTE = 'gdrive:crescendai_data/model_improvement/checkpoints'

print('Training S2-hetero: Heterogeneous GNN on Score Graph')
print('=' * 50)

s2h_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # Stage 1: self-supervised link-prediction pretraining on heterogeneous graphs.
    print('  Pretrain: Link prediction (heterogeneous)')
    # NOTE(review): unlike the finetune stage, the unfiltered fold key lists are passed here;
    # presumably HeteroPretrainDataset skips keys that have no graph — confirm.
    pretrain_ds = HeteroPretrainDataset(fold['train'], score_graphs, hetero_graphs)
    pretrain_val_ds = HeteroPretrainDataset(fold['val'], score_graphs, hetero_graphs)

    # One dataset item per step, loaded in the main process.
    single_item_opts = dict(batch_size=1, num_workers=0)
    pretrain_loader = DataLoader(pretrain_ds, shuffle=True, **single_item_opts)
    pretrain_val_loader = DataLoader(pretrain_val_ds, shuffle=False, **single_item_opts)

    model = GNNHeteroSymbolicEncoder(**S2H_CONFIG, stage='pretrain')
    trainer_pt = train_model(
        model, pretrain_loader, pretrain_val_loader,
        'S2H_pretrain', fold_idx,
        checkpoint_dir=CHECKPOINT_DIR,
        max_epochs=PRETRAIN_EPOCHS_S2H,
        monitor='val_link_loss',
        upload_remote=CHECKPOINT_REMOTE,
    )

    # Stage 2: supervised pairwise-ranking finetuning of the same model instance.
    print('  Finetune: Pairwise ranking (heterogeneous)')
    model.stage = 'finetune'

    train_keys = [k for k in fold['train'] if k in hetero_graphs]
    val_keys = [k for k in fold['val'] if k in hetero_graphs]
    pair_opts = dict(cache_dir=cache_dir, labels=labels, piece_to_keys=piece_to_keys)
    finetune_ds = PairedPerformanceDataset(keys=train_keys, **pair_opts)
    finetune_val_ds = PairedPerformanceDataset(keys=val_keys, **pair_opts)

    ft_loader_opts = dict(
        batch_size=8,
        num_workers=2,
        collate_fn=partial(hetero_graph_collate_fn, hetero_graphs=hetero_graphs),
    )
    finetune_loader = DataLoader(finetune_ds, shuffle=True, **ft_loader_opts)
    finetune_val_loader = DataLoader(finetune_val_ds, shuffle=False, **ft_loader_opts)

    trainer_ft = train_model(
        model, finetune_loader, finetune_val_loader,
        'S2H', fold_idx,
        checkpoint_dir=CHECKPOINT_DIR,
        max_epochs=FINETUNE_EPOCHS_S2H,
        upload_remote=CHECKPOINT_REMOTE,
    )
    s2h_trainers.append(trainer_ft)

    # NOTE(review): callback_metrics holds the most recently logged values; they are the
    # best-epoch values only if train_model re-validates the best checkpoint — confirm.
    metrics = trainer_ft.callback_metrics
    best_val = metrics.get('val_loss', float('inf'))
    best_acc = metrics.get('val_pairwise_acc', 0.0)
    print(f'  Best val_loss={best_val:.4f}, val_pairwise_acc={best_acc:.4f}')

## 7. Train S3: Continuous Feature Encoder

In [None]:
# Hyperparameters for S3, passed straight to ContinuousSymbolicEncoder.
S3_CONFIG = dict(
    input_channels=5,  # pitch, velocity, density, pedal, IOI
    hidden_dim=512,
    num_labels=19,
    max_epochs=200,
)

PRETRAIN_EPOCHS_S3 = 50
FINETUNE_EPOCHS_S3 = 150
CHECKPOINT_REMOTE = 'gdrive:crescendai_data/model_improvement/checkpoints'

print('Training S3: Continuous Feature Encoder')
print('=' * 50)

s3_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # Stage 1: self-supervised contrastive pretraining with a Gumbel codebook.
    print('  Pretrain: Contrastive + codebook quantization')

    # Keys of this fold's segments that produced continuous features; reused by both stages.
    train_keys = [k for k in fold['train'] if k in continuous_features]
    val_keys = [k for k in fold['val'] if k in continuous_features]

    pretrain_ds = ContinuousPretrainDataset(train_keys, continuous_features)
    pretrain_val_ds = ContinuousPretrainDataset(val_keys, continuous_features)

    pt_loader_opts = dict(batch_size=8, num_workers=2)
    pretrain_loader = DataLoader(pretrain_ds, shuffle=True, **pt_loader_opts)
    pretrain_val_loader = DataLoader(pretrain_val_ds, shuffle=False, **pt_loader_opts)

    model = ContinuousSymbolicEncoder(**S3_CONFIG, stage='pretrain')
    trainer_pt = train_model(
        model, pretrain_loader, pretrain_val_loader,
        'S3_pretrain', fold_idx,
        checkpoint_dir=CHECKPOINT_DIR,
        max_epochs=PRETRAIN_EPOCHS_S3,
        monitor='val_contrastive_loss',
        upload_remote=CHECKPOINT_REMOTE,
    )

    # Stage 2: supervised pairwise-ranking finetuning of the same model instance.
    print('  Finetune: Pairwise ranking')
    model.stage = 'finetune'

    pair_opts = dict(cache_dir=cache_dir, labels=labels, piece_to_keys=piece_to_keys)
    finetune_ds = PairedPerformanceDataset(
        keys=[k for k in fold['train'] if k in continuous_features], **pair_opts
    )
    finetune_val_ds = PairedPerformanceDataset(
        keys=[k for k in fold['val'] if k in continuous_features], **pair_opts
    )

    ft_loader_opts = dict(
        batch_size=8,
        num_workers=2,
        collate_fn=partial(continuous_collate_fn, features_dict=continuous_features),
    )
    finetune_loader = DataLoader(finetune_ds, shuffle=True, **ft_loader_opts)
    finetune_val_loader = DataLoader(finetune_val_ds, shuffle=False, **ft_loader_opts)

    trainer_ft = train_model(
        model, finetune_loader, finetune_val_loader,
        'S3', fold_idx,
        checkpoint_dir=CHECKPOINT_DIR,
        max_epochs=FINETUNE_EPOCHS_S3,
        upload_remote=CHECKPOINT_REMOTE,
    )
    s3_trainers.append(trainer_ft)

    # NOTE(review): callback_metrics holds the most recently logged values; they are the
    # best-epoch values only if train_model re-validates the best checkpoint — confirm.
    metrics = trainer_ft.callback_metrics
    best_val = metrics.get('val_loss', float('inf'))
    best_acc = metrics.get('val_pairwise_acc', 0.0)
    print(f'  Best val_loss={best_val:.4f}, val_pairwise_acc={best_acc:.4f}')

## 8. Training Summary

In [None]:
print('Symbolic Encoder Training Summary')
print('=' * 60)
print(f'{"Model":<12} {"Fold":<6} {"Val Loss":<12} {"Pairwise Acc":<14}')
print('-' * 60)

# (display name, per-fold finetune trainers) for every strategy trained above.
runs = (
    ('S1', s1_trainers),
    ('S2', s2_trainers),
    ('S2-hetero', s2h_trainers),
    ('S3', s3_trainers),
)
missing = float('nan')  # shown for metrics a trainer never logged
for label, fold_trainers in runs:
    for idx, tr in enumerate(fold_trainers):
        metrics = tr.callback_metrics
        loss = metrics.get('val_loss', missing)
        acc = metrics.get('val_pairwise_acc', missing)
        print(f'{label:<12} {idx:<6} {loss:<12.4f} {acc:<14.4f}')

print('\nCheckpoints saved to:', CHECKPOINT_DIR)
print('Checkpoints synced to Google Drive via rclone')

## 9. Upload Final Results

In [None]:
# Hand each existing local checkpoint directory (finetune, then pretrain, per strategy)
# to upload_checkpoint.
for base in ('S1', 'S2', 'S2H', 'S3'):
    for model_name in (base, f'{base}_pretrain'):
        ckpt_path = CHECKPOINT_DIR / model_name
        if not ckpt_path.exists():
            continue
        upload_checkpoint(ckpt_path, model_name)
        print(f'Uploaded {model_name} checkpoints')

print('\nAll symbolic encoder training complete.')
print('Run 04_symbolic_comparison.ipynb to evaluate and compare results.')