# 02: Symbolic Encoder Training (S1, S2, S3)

Train three symbolic encoder strategies on Thunder Compute:
- **S1**: BERT-style Transformer on REMI tokens (TransformerSymbolicEncoder)
- **S2**: GNN on score graph (GNNSymbolicEncoder)
- **S3**: 1D-CNN + Transformer on continuous features (ContinuousSymbolicEncoder)

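Each model has a self-supervised pretraining stage followed by a supervised pairwise-ranking finetuning stage. S2 is currently skipped: its training needs a MIDI-to-graph pipeline that has not been implemented yet (see Section 6).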

---

## 1. Setup

In [None]:
import subprocess
import sys
import os
# Must be set before any CUDA/cuBLAS work. PyTorch's deterministic mode, which
# train_model() enables via pl.Trainer(deterministic=True), requires this
# cuBLAS workspace setting on CUDA >= 10.2.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Install rclone, used below to pull data and push checkpoints to Google Drive.
# NOTE(review): `|| echo` runs whenever grep matches neither "successfully" nor
# "already" in the installer output -- including when the install failed -- so
# "rclone installed" is not proof of success; check with `rclone version`.
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Fetch the project and make the model package the working directory.
# NOTE(review): `git clone` fails if /workspace/crescendai already exists (e.g.
# when re-running this cell); the following commands still run.
!git clone https://github.com/Jai-Dhiman/crescendAI.git /workspace/crescendai
%cd /workspace/crescendai/model

# Install the uv package manager, then the project's dependencies.
# NOTE(review): `uv sync` installs into the project's own environment; confirm
# this notebook's kernel actually uses that environment.
!curl -LsSf https://astral.sh/uv/install.sh | sh
!uv sync

# Pull the training data from Google Drive into ./data (relative to the cwd set by %cd above).
!rclone sync gdrive:crescendai_data/model_improvement/data ./data --progress

from pathlib import Path


def env_flag(name):
    """Return whether environment variable ``name`` is set to a truthy value.

    Environment values are strings, so testing the raw value treats ``'0'`` or
    ``'false'`` as true. Here an unset or empty variable, and the usual false
    spellings (``0``, ``false``, ``no``, ``off``; case-insensitive), count as
    false. Any other non-empty value counts as true.
    """
    value = os.environ.get(name, '').strip().lower()
    return value not in ('', '0', 'false', 'no', 'off')


IS_REMOTE = env_flag('THUNDER_COMPUTE')
if IS_REMOTE:
    DATA_DIR = Path('/workspace/crescendai/model/data')
    CHECKPOINT_DIR = Path('/workspace/crescendai/model/checkpoints/model_improvement')
else:
    # Paths relative to the notebook's current working directory.
    # NOTE(review): the cell above `%cd`s into /workspace/crescendai/model and
    # syncs data to ./data, so '../data' only resolves correctly when that
    # cell is not run -- confirm the intended local layout.
    DATA_DIR = Path('../data')
    CHECKPOINT_DIR = Path('../checkpoints/model_improvement')

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
import json
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader

sys.path.insert(0, 'src')

from model_improvement.symbolic_encoders import (
    TransformerSymbolicEncoder,
    GNNSymbolicEncoder,
    ContinuousSymbolicEncoder,
)
from model_improvement.tokenizer import PianoTokenizer, extract_continuous_features
from model_improvement.data import MIDIPretrainingDataset, PairedPerformanceDataset, multi_task_collate_fn

## 2. Load Data

In [None]:
cache_dir = DATA_DIR / 'percepiano_cache'


def _load_json(path):
    """Read and parse a JSON file.

    JSON text is UTF-8 by specification, so the encoding is given explicitly
    rather than relying on the platform's locale default.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)


labels = _load_json(cache_dir / 'labels.json')
folds = _load_json(cache_dir / 'folds.json')
piece_to_keys = _load_json(cache_dir / 'piece_mapping.json')

midi_dir = DATA_DIR / 'percepiano_midi'

print(f'Loaded {len(labels)} labeled segments')
print(f'Folds: {len(folds)}')
# Each fold is read as a mapping with 'train' and 'val' key lists.
for i, fold in enumerate(folds):
    print(f'  Fold {i}: {len(fold["train"])} train, {len(fold["val"])} val')

## 3. Prepare Symbolic Inputs

In [None]:
print('Tokenizing MIDI files for S1...')
tokenizer = PianoTokenizer(max_seq_len=2048)

# Map segment key -> REMI token id list, for every labeled segment that has a MIDI file.
token_sequences = {}
missing_midi = []
for key in labels:
    midi_path = midi_dir / f'{key}.mid'
    if not midi_path.exists():
        # Segments without MIDI are excluded from S1 training; count them so the gap is visible.
        missing_midi.append(key)
        continue
    tokens = tokenizer.encode(midi_path)
    token_sequences[key] = tokens

print(f'Tokenized {len(token_sequences)} MIDI files')
if missing_midi:
    print(f'Skipped {len(missing_midi)} labeled segments with no MIDI file in {midi_dir}')
print(f'Vocab size: {tokenizer.vocab_size}')
if token_sequences:
    print(f'Average sequence length: {np.mean([len(t) for t in token_sequences.values()]):.0f}')
else:
    # np.mean of an empty list returns nan and emits a RuntimeWarning.
    print('Average sequence length: n/a (no MIDI files tokenized)')

In [None]:
print('Extracting continuous features for S3...')

# Map segment key -> float32 tensor of per-frame features (frames sampled at 50 Hz),
# for every labeled segment whose MIDI file exists.
continuous_features = {}
for key in labels:
    midi_path = midi_dir / f'{key}.mid'
    if not midi_path.exists():
        continue
    continuous_features[key] = torch.from_numpy(
        extract_continuous_features(midi_path, frame_rate=50)
    ).float()

print(f'Extracted features for {len(continuous_features)} files')
if continuous_features:
    # Show one example so the channel count D can be checked against S3_CONFIG.
    sample = next(iter(continuous_features.values()))
    print(f'Feature shape: {sample.shape} (T, D={sample.shape[1]})')

## 4. Training Utilities

In [None]:
def upload_checkpoint(local_path, remote_subdir):
    """Copy a local checkpoint file or directory to Google Drive with rclone.

    Uses ``rclone copy``, which adds and updates files at the destination but
    never deletes remote files that are missing locally.

    Args:
        local_path: Local file or directory to upload.
        remote_subdir: Destination path under
            ``crescendai_data/model_improvement/checkpoints`` on the ``gdrive`` remote.

    Raises:
        subprocess.CalledProcessError: if rclone exits with a non-zero status.
    """
    destination = f'gdrive:crescendai_data/model_improvement/checkpoints/{remote_subdir}'
    command = ['rclone', 'copy', str(local_path), destination, '--progress']
    subprocess.run(command, check=True)


def train_model(model, train_loader, val_loader, model_name, fold_idx, max_epochs=200, monitor='val_loss',
                mode='min', patience=20, seed=42):
    """Fit ``model`` with checkpointing and early stopping, then upload the checkpoints.

    Checkpoints are written to ``CHECKPOINT_DIR/<model_name>/fold_<fold_idx>``.
    That directory is copied to Google Drive with ``upload_checkpoint`` after training.

    When fitting ends, the weights of the best checkpoint (by ``monitor``) are
    loaded back into ``model``. Without this, ``model`` would keep the
    last-epoch weights, which after early stopping are up to ``patience``
    epochs past the best. Callers that keep training the same object
    (pretrain -> finetune) would then start from non-best weights.

    Args:
        model: LightningModule to train; updated in place.
        train_loader: Training DataLoader.
        val_loader: Validation DataLoader.
        model_name: Checkpoint subdirectory name, e.g. ``'S1_pretrain'``.
        fold_idx: Cross-validation fold index, used in the checkpoint path.
        max_epochs: Upper bound on training epochs.
        monitor: Logged metric used for checkpoint selection and early stopping.
        mode: ``'min'`` if lower ``monitor`` values are better, ``'max'`` otherwise.
        patience: Early-stopping patience in epochs.
        seed: Seed for ``pl.seed_everything``. ``deterministic=True`` only gives
            reproducible runs when the RNGs are seeded. Pass ``None`` to skip seeding.

    Returns:
        The fitted ``pl.Trainer``. ``trainer.checkpoint_callback.best_model_score``
        holds the best ``monitor`` value. ``trainer.callback_metrics`` holds the
        final epoch's values.
    """
    if seed is not None:
        # workers=True also gives each DataLoader worker a distinct, reproducible seed.
        pl.seed_everything(seed, workers=True)

    ckpt_dir = CHECKPOINT_DIR / model_name / f'fold_{fold_idx}'
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_cb = ModelCheckpoint(
        dirpath=str(ckpt_dir),
        filename='{epoch}-{' + monitor + ':.4f}',
        monitor=monitor,
        mode=mode,
        save_top_k=1,
    )
    callbacks = [
        checkpoint_cb,
        EarlyStopping(
            monitor=monitor,
            patience=patience,
            mode=mode,
        ),
    ]

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='auto',
        devices=1,
        callbacks=callbacks,
        enable_progress_bar=True,
        log_every_n_steps=10,
        deterministic=True,
    )

    trainer.fit(model, train_loader, val_loader)

    best_path = checkpoint_cb.best_model_path
    if best_path:
        # Lightning checkpoints contain more than tensors (loop/callback state,
        # hyperparameters), so full unpickling is required. The file was just
        # written by this process, so it is trusted.
        checkpoint = torch.load(best_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint['state_dict'])

    upload_checkpoint(ckpt_dir, f'{model_name}/fold_{fold_idx}')

    return trainer

## 5. Train S1: Transformer on REMI Tokens

In [None]:
S1_CONFIG = {
    'vocab_size': tokenizer.vocab_size + 1,  # +1 for mask token
    'd_model': 512,
    'nhead': 8,
    'num_layers': 6,
    'hidden_dim': 512,
    'num_labels': 19,
    'max_epochs': 200,
}

PRETRAIN_EPOCHS = 50
FINETUNE_EPOCHS = 150
# Truncation/padding length for token sequences; same value PianoTokenizer was built with.
S1_MAX_SEQ_LEN = 2048


def symbolic_collate(batch):
    """Collate paired-performance items into padded REMI token tensors for S1.

    Items are read for ``key_a``, ``key_b``, ``labels_a``, ``labels_b`` and
    ``piece_id``. Token lists come from the notebook-level ``token_sequences``.
    Each is truncated to ``S1_MAX_SEQ_LEN`` and right-padded with id 0
    (assumed to be the tokenizer's padding id -- TODO confirm against
    PianoTokenizer). Pairs where either performance has no tokens are dropped.

    Returns:
        dict with ``input_ids_a/b`` (long), ``mask_a/b`` (bool, True on real
        tokens), ``labels_a/b`` and ``piece_ids_a/b``, or ``None`` if every
        pair was dropped.
    """
    def _pad(tokens):
        tokens = tokens[:S1_MAX_SEQ_LEN]
        ids = torch.tensor(tokens + [0] * (S1_MAX_SEQ_LEN - len(tokens)), dtype=torch.long)
        valid = torch.zeros(S1_MAX_SEQ_LEN, dtype=torch.bool)
        valid[:len(tokens)] = True
        return ids, valid

    ids_a, ids_b, masks_a, masks_b = [], [], [], []
    labels_a_list, labels_b_list, piece_ids = [], [], []

    for item in batch:
        key_a, key_b = item['key_a'], item['key_b']
        if key_a not in token_sequences or key_b not in token_sequences:
            continue

        id_a, valid_a = _pad(token_sequences[key_a])
        id_b, valid_b = _pad(token_sequences[key_b])
        ids_a.append(id_a)
        ids_b.append(id_b)
        masks_a.append(valid_a)
        masks_b.append(valid_b)
        labels_a_list.append(item['labels_a'])
        labels_b_list.append(item['labels_b'])
        piece_ids.append(item['piece_id'])

    if not ids_a:
        return None

    # Both performances of a pair share one piece_id, so a and b get the same ids.
    return {
        'input_ids_a': torch.stack(ids_a),
        'input_ids_b': torch.stack(ids_b),
        'mask_a': torch.stack(masks_a),
        'mask_b': torch.stack(masks_b),
        'labels_a': torch.stack(labels_a_list),
        'labels_b': torch.stack(labels_b_list),
        'piece_ids_a': torch.tensor(piece_ids),
        'piece_ids_b': torch.tensor(piece_ids),
    }


print('Training S1: Transformer on REMI Tokens')
print('=' * 50)

s1_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # Only segments that were successfully tokenized can be used.
    train_keys = [k for k in fold['train'] if k in token_sequences]
    val_keys = [k for k in fold['val'] if k in token_sequences]

    # --- Pretrain: Masked Language Modeling ---
    print('  Pretrain: Masked token prediction')

    train_tokens = [token_sequences[k] for k in train_keys]
    val_tokens = [token_sequences[k] for k in val_keys]

    # vocab_size is the unextended tokenizer vocab; S1_CONFIG adds one extra id
    # for the mask token.
    pretrain_ds = MIDIPretrainingDataset(
        token_sequences=train_tokens,
        max_seq_len=S1_MAX_SEQ_LEN,
        mask_prob=0.15,
        vocab_size=tokenizer.vocab_size,
    )
    pretrain_val_ds = MIDIPretrainingDataset(
        token_sequences=val_tokens,
        max_seq_len=S1_MAX_SEQ_LEN,
        mask_prob=0.15,
        vocab_size=tokenizer.vocab_size,
    )

    pretrain_loader = DataLoader(pretrain_ds, batch_size=8, shuffle=True, num_workers=2)
    pretrain_val_loader = DataLoader(pretrain_val_ds, batch_size=8, shuffle=False, num_workers=2)

    model = TransformerSymbolicEncoder(**S1_CONFIG, stage='pretrain')
    trainer_pt = train_model(
        model, pretrain_loader, pretrain_val_loader,
        'S1_pretrain', fold_idx,
        max_epochs=PRETRAIN_EPOCHS,
        monitor='val_mlm_loss',
    )

    # --- Finetune: Pairwise Ranking ---
    print('  Finetune: Pairwise ranking')
    # The same module is reused, so finetuning starts from the pretrained weights.
    model.stage = 'finetune'

    finetune_ds = PairedPerformanceDataset(
        cache_dir=cache_dir,
        labels=labels,
        piece_to_keys=piece_to_keys,
        keys=train_keys,
    )
    finetune_val_ds = PairedPerformanceDataset(
        cache_dir=cache_dir,
        labels=labels,
        piece_to_keys=piece_to_keys,
        keys=val_keys,
    )

    finetune_loader = DataLoader(finetune_ds, batch_size=8, shuffle=True, collate_fn=symbolic_collate, num_workers=2)
    finetune_val_loader = DataLoader(finetune_val_ds, batch_size=8, shuffle=False, collate_fn=symbolic_collate, num_workers=2)

    trainer_ft = train_model(model, finetune_loader, finetune_val_loader, 'S1', fold_idx, max_epochs=FINETUNE_EPOCHS)
    s1_trainers.append(trainer_ft)

    # callback_metrics only holds the final epoch. The best monitored val_loss
    # is tracked by the ModelCheckpoint callback.
    best_score = trainer_ft.checkpoint_callback.best_model_score
    best_val = float('inf') if best_score is None else float(best_score)
    last_acc = float(trainer_ft.callback_metrics.get('val_pairwise_acc', float('nan')))
    print(f'  Best val_loss={best_val:.4f}, final-epoch val_pairwise_acc={last_acc:.4f}')

## 6. Train S2: GNN on Score Graph

In [None]:
S2_CONFIG = dict(
    node_features=6,
    hidden_dim=512,
    num_layers=4,
    num_labels=19,
    max_epochs=200,
)

PRETRAIN_EPOCHS_S2 = 50
FINETUNE_EPOCHS_S2 = 150

# S2 is not trained yet; this cell only prints a notice. Once graph data exists,
# training should follow the same fold-loop pattern as S1, using a PyG
# DataLoader and graph-structured batches.
#
# Intended data loading:
#   from torch_geometric.data import Data, Batch
#   from torch_geometric.loader import DataLoader as PyGDataLoader
#
#   One graph per MIDI file:
#     x          = [pitch, velocity, onset, duration, pedal, voice] per note (node features)
#     edge_index = temporal adjacency + harmonic intervals + voice grouping
#     batch_vec  = graph batch assignment
#
#   Pretrain: link prediction (pos_edges, neg_edges)
#   Finetune: pairwise ranking on graph embeddings

_s2_banner = (
    'Training S2: GNN on Score Graph',
    '=' * 50,
    'Note: S2 requires graph-structured MIDI data (node features + edge indices).',
    'Graph construction from MIDI is dataset-specific and must be prepared separately.',
    'See model_improvement/data.py for graph construction utilities.',
)
for _banner_line in _s2_banner:
    print(_banner_line)

s2_trainers = []
print('\nS2 training requires graph-structured data preparation.')
print('Skipping until MIDI-to-graph pipeline is implemented.')

## 7. Train S3: Continuous Feature Encoder

In [None]:
S3_CONFIG = {
    'input_channels': 5,  # pitch, velocity, density, pedal, IOI
    'hidden_dim': 512,
    'num_labels': 19,
    'max_epochs': 200,
}

PRETRAIN_EPOCHS_S3 = 50
FINETUNE_EPOCHS_S3 = 150
# Frames kept per performance: 2000 frames = 40 s at the 50 Hz frame rate
# passed to extract_continuous_features.
S3_MAX_LEN = 2000


def _pad_or_truncate(feat, max_len):
    """Truncate or zero-pad a ``[T, C]`` feature tensor to exactly ``max_len`` frames.

    Returns:
        ``(features, valid)``: ``features`` is ``[max_len, C]`` and ``valid`` is a
        ``[max_len]`` bool tensor that is True on real (non-padding) frames.
    """
    num_frames, num_channels = feat.shape
    if num_frames > max_len:
        feat = feat[:max_len]
        num_frames = max_len
    valid = torch.ones(max_len, dtype=torch.bool)
    if num_frames < max_len:
        padding = torch.zeros(max_len - num_frames, num_channels)
        feat = torch.cat([feat, padding], dim=0)
        valid[num_frames:] = False
    return feat, valid


class ContinuousPretrainDataset(torch.utils.data.Dataset):
    """Masked-frame pretraining samples for S3.

    Each sample holds the padded features, a validity mask, and a copy of the
    features in which each valid frame is zeroed with probability
    ``mask_prob``. The masking is re-drawn on every access.
    """

    def __init__(self, keys, features_dict, max_len=S3_MAX_LEN, mask_prob=0.15):
        self.keys = keys
        self.features = features_dict
        self.max_len = max_len
        self.mask_prob = mask_prob

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        feat, mask = _pad_or_truncate(self.features[self.keys[idx]], self.max_len)

        # Only real frames are eligible for masking; padding is never masked.
        masked_positions = mask & (torch.rand(self.max_len) < self.mask_prob)
        masked_feat = feat.clone()
        masked_feat[masked_positions] = 0.0

        return {
            'features': feat,
            'mask': mask,
            'masked_features': masked_feat,
            'masked_positions': masked_positions,
        }


def continuous_collate(batch):
    """Collate paired-performance items into padded continuous-feature tensors for S3.

    Items are read for ``key_a``, ``key_b``, ``labels_a``, ``labels_b`` and
    ``piece_id``. Features come from the notebook-level ``continuous_features``.
    Pairs where either performance has no features are dropped.

    Returns:
        dict with ``features_a/b`` (``[B, S3_MAX_LEN, C]``), ``mask_a/b``
        (bool, True on real frames), ``labels_a/b`` and ``piece_ids_a/b``, or
        ``None`` if every pair was dropped.
    """
    feats_a, feats_b, masks_a, masks_b = [], [], [], []
    labels_a_list, labels_b_list, piece_ids = [], [], []

    for item in batch:
        key_a, key_b = item['key_a'], item['key_b']
        if key_a not in continuous_features or key_b not in continuous_features:
            continue

        fa, ma = _pad_or_truncate(continuous_features[key_a], S3_MAX_LEN)
        fb, mb = _pad_or_truncate(continuous_features[key_b], S3_MAX_LEN)
        feats_a.append(fa)
        feats_b.append(fb)
        masks_a.append(ma)
        masks_b.append(mb)
        labels_a_list.append(item['labels_a'])
        labels_b_list.append(item['labels_b'])
        piece_ids.append(item['piece_id'])

    if not feats_a:
        return None

    # Both performances of a pair share one piece_id, so a and b get the same ids.
    return {
        'features_a': torch.stack(feats_a),
        'features_b': torch.stack(feats_b),
        'mask_a': torch.stack(masks_a),
        'mask_b': torch.stack(masks_b),
        'labels_a': torch.stack(labels_a_list),
        'labels_b': torch.stack(labels_b_list),
        'piece_ids_a': torch.tensor(piece_ids),
        'piece_ids_b': torch.tensor(piece_ids),
    }


print('Training S3: Continuous Feature Encoder')
print('=' * 50)

s3_trainers = []
for fold_idx, fold in enumerate(folds):
    print(f'\nFold {fold_idx}/{len(folds)-1}')

    # --- Pretrain: Contrastive with Gumbel codebook ---
    print('  Pretrain: Contrastive + codebook quantization')

    # Only segments with extracted features can be used.
    train_keys = [k for k in fold['train'] if k in continuous_features]
    val_keys = [k for k in fold['val'] if k in continuous_features]

    pretrain_ds = ContinuousPretrainDataset(train_keys, continuous_features)
    pretrain_val_ds = ContinuousPretrainDataset(val_keys, continuous_features)

    pretrain_loader = DataLoader(pretrain_ds, batch_size=8, shuffle=True, num_workers=2)
    pretrain_val_loader = DataLoader(pretrain_val_ds, batch_size=8, shuffle=False, num_workers=2)

    model = ContinuousSymbolicEncoder(**S3_CONFIG, stage='pretrain')
    trainer_pt = train_model(
        model, pretrain_loader, pretrain_val_loader,
        'S3_pretrain', fold_idx,
        max_epochs=PRETRAIN_EPOCHS_S3,
        monitor='val_contrastive_loss',
    )

    # --- Finetune: Pairwise ranking ---
    print('  Finetune: Pairwise ranking')
    # The same module is reused, so finetuning starts from the pretrained weights.
    model.stage = 'finetune'

    finetune_ds = PairedPerformanceDataset(
        cache_dir=cache_dir,
        labels=labels,
        piece_to_keys=piece_to_keys,
        keys=train_keys,
    )
    finetune_val_ds = PairedPerformanceDataset(
        cache_dir=cache_dir,
        labels=labels,
        piece_to_keys=piece_to_keys,
        keys=val_keys,
    )

    finetune_loader = DataLoader(finetune_ds, batch_size=8, shuffle=True, collate_fn=continuous_collate, num_workers=2)
    finetune_val_loader = DataLoader(finetune_val_ds, batch_size=8, shuffle=False, collate_fn=continuous_collate, num_workers=2)

    trainer_ft = train_model(model, finetune_loader, finetune_val_loader, 'S3', fold_idx, max_epochs=FINETUNE_EPOCHS_S3)
    s3_trainers.append(trainer_ft)

    # callback_metrics only holds the final epoch. The best monitored val_loss
    # is tracked by the ModelCheckpoint callback.
    best_score = trainer_ft.checkpoint_callback.best_model_score
    best_val = float('inf') if best_score is None else float(best_score)
    last_acc = float(trainer_ft.callback_metrics.get('val_pairwise_acc', float('nan')))
    print(f'  Best val_loss={best_val:.4f}, final-epoch val_pairwise_acc={last_acc:.4f}')

## 8. Training Summary

In [None]:
def _as_float(value):
    """Convert a logged metric (0-dim tensor, number or None) to a float; None becomes nan."""
    return float('nan') if value is None else float(value)


print('Symbolic Encoder Training Summary')
print('=' * 60)
print(f'{"Model":<10} {"Fold":<6} {"Best Val Loss":<15} {"Last Pair Acc":<14}')
print('-' * 60)

for name, trainers in [('S1', s1_trainers), ('S3', s3_trainers)]:
    for fold_idx, trainer in enumerate(trainers):
        # best_model_score is the best val_loss seen by ModelCheckpoint.
        # callback_metrics only holds final-epoch values, which after early
        # stopping are not the best ones.
        val_loss = _as_float(trainer.checkpoint_callback.best_model_score)
        val_acc = _as_float(trainer.callback_metrics.get('val_pairwise_acc'))
        print(f'{name:<10} {fold_idx:<6} {val_loss:<15.4f} {val_acc:<14.4f}')

if not s2_trainers:
    print('S2       (skipped - graph data not yet prepared)')

print('\nCheckpoints saved to:', CHECKPOINT_DIR)
print('Checkpoints synced to Google Drive via rclone')

## 9. Upload Final Results

In [None]:
# Re-copy every model's checkpoint tree to Google Drive. rclone copy only
# transfers new or changed files, so folds already uploaded by train_model
# are not sent again. S2 is listed but is skipped while it has no checkpoints.
for model_name in ('S1', 'S1_pretrain', 'S2', 'S3', 'S3_pretrain'):
    local = CHECKPOINT_DIR / model_name
    if not local.exists():
        continue
    upload_checkpoint(local, model_name)
    print(f'Uploaded {model_name} checkpoints')

print('\nAll symbolic encoder training complete.')
print('Run 04_symbolic_comparison.ipynb to evaluate and compare results.')