# 00: Multi-Dataset Data Preparation

One-time setup: download external MIDI datasets, validate, preprocess, and cache for pretraining.

**Datasets:**
- MAESTRO v3 (~1,276 MIDI files, ~81MB)
- ATEPP v1.2 (~11,674 MIDI files, ~212MB)
- ASAP (~1,067 performances, if already cached)
- PercePiano (~1,202 MIDI files, already in data/)

**Output:** `data/pretrain_cache/` containing tokenized sequences, score graphs, and continuous features.

Run this notebook once, either locally (M4 Mac) or on a cloud instance; the final step syncs the results to GDrive.

---

## 1. Setup

In [None]:
import subprocess
import sys
import os

# Remote bootstrap: when THUNDER_COMPUTE is set, install the tooling, fetch the
# repo, and pull the shared data directory from GDrive. Nothing here runs on a
# local machine.
#
# The flag is parsed explicitly so that values such as "0" or "false" disable
# remote mode instead of counting as truthy non-empty strings.
IS_REMOTE = os.environ.get('THUNDER_COMPUTE', '').strip().lower() not in ('', '0', 'false', 'no', 'off')
if IS_REMOTE:
    import shutil

    repo_dir = '/workspace/crescendai'

    # Every step is guarded so the cell can be re-run on the same instance:
    # an existing tool or checkout is reused rather than reinstalled/recloned.
    # `pipefail` makes a failed download fail the step instead of piping an
    # empty script into the shell and reporting success.
    if shutil.which('rclone') is None:
        subprocess.run(
            ['bash', '-c', 'set -o pipefail; curl -fsSL https://rclone.org/install.sh | sudo bash'],
            check=True,
        )

    if not os.path.isdir(os.path.join(repo_dir, '.git')):
        subprocess.run(
            ['git', 'clone', 'https://github.com/Jai-Dhiman/crescendAI.git', repo_dir],
            check=True,
        )
    os.chdir('/workspace/crescendai/model')

    if shutil.which('uv') is None:
        subprocess.run(
            ['bash', '-c', 'set -o pipefail; curl -LsSf https://astral.sh/uv/install.sh | sh'],
            check=True,
        )
        # NOTE(review): the uv installer typically places the binary in
        # ~/.local/bin, which an already-running kernel may not have on PATH --
        # confirm the location on the target image.
        user_bin = os.path.expanduser('~/.local/bin')
        if shutil.which('uv') is None and os.path.isdir(user_bin):
            os.environ['PATH'] = user_bin + os.pathsep + os.environ.get('PATH', '')

    subprocess.run(['uv', 'sync'], check=True)
    subprocess.run(['rclone', 'sync', 'gdrive:crescendai_data/model_improvement/data', './data', '--progress'], check=True)

In [None]:
from pathlib import Path

# Data root: the checkout under /workspace on the remote box, a path relative
# to the current working directory otherwise.
DATA_DIR = Path('/workspace/crescendai/model/data') if IS_REMOTE else Path('../data')

# Put the project sources at the front of the import path so the
# `model_improvement` package below can be resolved. Both candidates are
# relative paths, so they depend on the current working directory.
sys.path.insert(0, '../../model/src' if not IS_REMOTE else 'src')

from model_improvement.datasets import (
    load_all_midi_files,
    load_maestro_midi_files,
    load_atepp_midi_files,
    load_asap_midi_files,
    load_percepiano_midi_files,
)
from model_improvement.preprocessing import (
    preprocess_all,
    preprocess_tokens,
    preprocess_graphs,
    preprocess_continuous_features,
    merge_graph_shards,
    merge_feature_shards,
)

# Report where the data root resolved to and whether it is present on disk.
print(
    f'DATA_DIR: {DATA_DIR.resolve()}',
    f'Exists: {DATA_DIR.exists()}',
    sep='\n',
)

## 2. Download External Datasets

In [None]:
import zipfile

# Fetch the MAESTRO v3 MIDI-only archive into DATA_DIR/maestro_cache, unless a
# previous run already left .midi files there.
maestro_dir = DATA_DIR / 'maestro_cache'

if maestro_dir.exists() and any(maestro_dir.glob('**/*.midi')):
    print(f'MAESTRO already downloaded at {maestro_dir}')
else:
    import shutil

    print('Downloading MAESTRO v3 MIDI-only (~81MB)...')
    # curl -o does not create missing parent directories.
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    maestro_zip = DATA_DIR / 'maestro-v3.0.0-midi.zip'
    # --fail makes curl exit non-zero on an HTTP error instead of saving the
    # error page as the "zip" and failing later with a confusing BadZipFile.
    subprocess.run([
        'curl', '--fail', '-L', '-o', str(maestro_zip),
        'https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip'
    ], check=True)

    print('Extracting...')
    with zipfile.ZipFile(maestro_zip, 'r') as zf:
        zf.extractall(DATA_DIR)

    # The archive is expected to unpack into maestro-v3.0.0/; move it to
    # maestro_cache, replacing any stale (MIDI-less) directory of that name.
    extracted_dir = DATA_DIR / 'maestro-v3.0.0'
    if extracted_dir.exists():
        if maestro_dir.exists():
            shutil.rmtree(maestro_dir)
        extracted_dir.rename(maestro_dir)
    elif not maestro_dir.exists():
        # Fail loudly on an unexpected archive layout, matching the ATEPP cell,
        # rather than reporting a successful extraction with nothing in it.
        raise FileNotFoundError(
            f'MAESTRO extraction did not produce expected directory. '
            f'Contents of {DATA_DIR}: {list(DATA_DIR.iterdir())}'
        )

    maestro_zip.unlink()
    print(f'MAESTRO extracted to {maestro_dir}')

# Reported on both paths; counted lazily so no throwaway list is built.
print(f'  MIDI files: {sum(1 for _ in maestro_dir.glob("**/*.midi"))}')

In [None]:
# Fetch the ATEPP v1.2 archive into DATA_DIR/atepp_cache, unless a previous run
# already left .mid files there.
atepp_dir = DATA_DIR / 'atepp_cache'

if atepp_dir.exists() and any(atepp_dir.glob('**/*.mid')):
    print(f'ATEPP already downloaded at {atepp_dir}')
else:
    import shutil

    print('Downloading ATEPP v1.2 (~212MB)...')
    # curl -o does not create missing parent directories.
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    atepp_zip = DATA_DIR / 'ATEPP-1.2.zip'
    # --fail makes curl exit non-zero on an HTTP error instead of saving the
    # error page as the "zip" and failing later with a confusing BadZipFile.
    subprocess.run([
        'curl', '--fail', '-L', '-o', str(atepp_zip),
        'https://zenodo.org/records/14997880/files/ATEPP-1.2.zip'
    ], check=True)

    print('Extracting...')
    with zipfile.ZipFile(atepp_zip, 'r') as zf:
        zf.extractall(DATA_DIR)

    # The top-level folder name inside the archive is not fixed; try the known
    # candidates in order and move the first one found to atepp_cache.
    for candidate in ['ATEPP-1.2', 'ATEPP']:
        extracted_dir = DATA_DIR / candidate
        if extracted_dir.exists():
            if atepp_dir.exists():
                shutil.rmtree(atepp_dir)
            extracted_dir.rename(atepp_dir)
            break
    else:
        # No candidate matched: acceptable only if extraction wrote straight
        # into the target directory.
        if not atepp_dir.exists():
            raise FileNotFoundError(
                f'ATEPP extraction did not produce expected directory. '
                f'Contents of {DATA_DIR}: {list(DATA_DIR.iterdir())}'
            )

    atepp_zip.unlink()
    print(f'ATEPP extracted to {atepp_dir}')

# Reported on both paths; counted lazily so no throwaway list is built.
print(f'  MIDI files: {sum(1 for _ in atepp_dir.glob("**/*.mid"))}')

## 3. Validate Sources

In [None]:
from collections import Counter

# Load every MIDI entry found under DATA_DIR and tally entries per `source`.
entries = load_all_midi_files(DATA_DIR)
source_counts = Counter(e.source for e in entries)

print(f'\nValidation:')
for source, count in sorted(source_counts.items()):
    print(f'  {source}: {count}')
print(f'  Total: {len(entries)}')

# Sanity checks: a source that is present must exceed its minimum file count.
# A source that is absent altogether is reported rather than silently skipped,
# so the closing message never claims more than was actually checked.
min_expected = {
    'maestro': ('MAESTRO', 1000),
    'atepp': ('ATEPP', 5000),
}
unchecked = []
for source, (label, floor) in min_expected.items():
    if source not in source_counts:
        unchecked.append(source)
        print(f'  WARNING: no {source} entries found; count check skipped')
        continue
    # Raised explicitly (not `assert`) so the check survives `python -O`.
    if source_counts[source] <= floor:
        raise ValueError(f'Expected >{floor} {label} files, got {source_counts[source]}')

if unchecked:
    print(f'\nSource validations passed (not checked: {", ".join(unchecked)}).')
else:
    print('\nAll source validations passed.')

## 4. Preprocess

Each stage skips automatically if the final output file already exists.
Graph and feature pipelines use shard-based processing to keep memory bounded (~200 entries per shard).
Legacy `.partial` checkpoints are migrated to shard format automatically.

In [None]:
# Root of the preprocessing cache; each stage writes into its own subfolder.
pretrain_cache = DATA_DIR / 'pretrain_cache'
pretrain_cache.mkdir(parents=True, exist_ok=True)

graph_dir = pretrain_cache / 'graphs'

# Stage a -- tokenization. Per the notes above, this is skipped when
# all_tokens.pt is already present.
tokens = preprocess_tokens(
    entries,
    pretrain_cache / 'tokens' / 'all_tokens.pt',
)

# Stage b -- graph building; returns the plain and the hetero graph
# collections. Described above as shard-based and resumable.
graphs, hetero = preprocess_graphs(
    entries,
    graph_dir / 'all_graphs.pt',
    graph_dir / 'all_hetero_graphs.pt',
)

# Stage c -- continuous features, also described above as shard-based.
features = preprocess_continuous_features(
    entries,
    pretrain_cache / 'features' / 'all_features.pt',
)

# One-line summary of how many items each stage produced.
print(f'\nTokens: {len(tokens)}, Graphs: {len(graphs)}, Hetero: {len(hetero)}, Features: {len(features)}')

## 5. Upload to GDrive

In [None]:
# Mirror the preprocessed cache to GDrive with `rclone sync`.
print('Uploading pretrain_cache to GDrive...')
cache_sync_cmd = [
    'rclone', 'sync',
    str(pretrain_cache),
    'gdrive:crescendai_data/model_improvement/data/pretrain_cache',
    '--progress',
]
subprocess.run(cache_sync_cmd, check=True)

# Mirror the raw MIDI folders too, so other machines can skip the downloads.
# A folder that is not present locally is passed over.
for subdir in ('maestro_cache', 'atepp_cache'):
    raw_dir = DATA_DIR / subdir
    if not raw_dir.exists():
        continue
    print(f'Uploading {subdir}...')
    subprocess.run(
        [
            'rclone', 'sync',
            str(raw_dir),
            f'gdrive:crescendai_data/model_improvement/data/{subdir}',
            '--progress',
        ],
        check=True,
    )

print('Upload complete.')

## 6. Spot-check

In [None]:
import torch
import random

# Reload the cached artifacts from disk onto the CPU.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only use it on cache files that come from a trusted source.
tokens = torch.load(pretrain_cache / 'tokens' / 'all_tokens.pt', map_location='cpu', weights_only=False)
graphs = torch.load(pretrain_cache / 'graphs' / 'all_graphs.pt', map_location='cpu', weights_only=False)
features = torch.load(pretrain_cache / 'features' / 'all_features.pt', map_location='cpu', weights_only=False)

print(f'Tokens: {len(tokens)} entries')
print(f'Graphs: {len(graphs)} entries')
print(f'Features: {len(features)} entries')

# For each source, pick up to five random token keys and show what the three
# caches hold for them. Keys are matched on their "<source>__" prefix.
for source in ('asap', 'maestro', 'atepp', 'percepiano'):
    prefix = f'{source}__'
    source_keys = [k for k in tokens if k.startswith(prefix)]

    if source_keys:
        picked = random.sample(source_keys, min(5, len(source_keys)))
        print(f'\n{source} samples ({len(source_keys)} total):')
        for key in picked:
            # A key absent from a cache is shown as 'MISSING' instead of raising.
            if key in tokens:
                tok_len = len(tokens[key])
            else:
                tok_len = 'MISSING'
            if key in graphs:
                node_count = graphs[key].x.shape[0]
            else:
                node_count = 'MISSING'
            if key in features:
                feat_shape = tuple(features[key].shape)
            else:
                feat_shape = 'MISSING'
            print(f'  {key}: tokens={tok_len}, nodes={node_count}, features={feat_shape}')
    else:
        print(f'\n{source}: no entries')