# Moonbeam LoRA Style Training (Colab)
Train a style LoRA (e.g., Chopin/Liszt/Bach) from uploaded MIDI files, then export a zip for `Moonbeam_Quickstart.ipynb`.


## 1) Runtime setup
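Use a **GPU** runtime (in Colab: Runtime → Change runtime type → GPU). The cell below raises an error if CUDA is not available.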


In [None]:
import torch
# Probe for a CUDA device once, report the result, and stop the notebook
# immediately when none is visible.
gpu_ready = torch.cuda.is_available()
print('CUDA available:', gpu_ready)
if not gpu_ready:
    raise RuntimeError('Please switch Colab runtime to GPU.')


## 2) Clone repo and install dependencies (exact README commands)


In [None]:
import os
from pathlib import Path

repo_dir = Path('/content/Moonbeam-MIDI-Foundation-Model')
if repo_dir.exists():
    os.chdir(repo_dir)
    os.system('git fetch origin --prune')
    if os.system('git reset --hard origin/main') != 0:
        os.system('git reset --hard origin/master')
else:
    os.system('git clone https://github.com/guozixunnicolas/Moonbeam-MIDI-Foundation-Model /content/Moonbeam-MIDI-Foundation-Model')

%cd /content/Moonbeam-MIDI-Foundation-Model
!pip install .
!pip install src/llama_recipes/transformers_minimal/.
!pip install huggingface_hub pandas mido


## 3) Download pretrained checkpoint


In [None]:
from huggingface_hub import hf_hub_download

CKPT_FILENAME = 'moonbeam_309M.pt'  #@param ['moonbeam_309M.pt', 'moonbeam_839M.pt']

# Request the selected checkpoint file from the model repository on the
# Hugging Face Hub. ckpt_path is reused by the training cell as the value of
# --trained_checkpoint_path.
download_request = {
    'repo_id': 'guozixunnicolas/moonbeam-midi-foundation-model',
    'filename': CKPT_FILENAME,
}
ckpt_path = hf_hub_download(**download_request)
print(f'Checkpoint: {ckpt_path}')


## 4) Upload style MIDI zip and preprocess (absolute /content paths)


In [None]:
from pathlib import Path
from google.colab import files
import zipfile
import random
import json
import numpy as np
import pandas as pd

from transformers import LlamaConfig
from llama_recipes.datasets.music_tokenizer import MusicTokenizer

import shutil  # local import: used to clear stale outputs; not in this cell's import header

STYLE_NAME = 'chopin'  #@param {type:"string"}
TRAIN_RATIO = 0.9  #@param {type:"number"}

# Dataset layout for one style: extracted MIDI files, tokenized .npy arrays,
# per-file metadata JSONs split into train/test, and a CSV listing the split.
BASE_DATA_DIR = Path('/content/processed_datasets/unconditional')
STYLE_DATA_DIR = BASE_DATA_DIR / STYLE_NAME
RAW_DIR = STYLE_DATA_DIR / 'raw_midis'
PROCESSED_DIR = STYLE_DATA_DIR / 'processed'
TRAIN_JSON_DIR = STYLE_DATA_DIR / 'train'
TEST_JSON_DIR = STYLE_DATA_DIR / 'test'
SPLIT_CSV = STYLE_DATA_DIR / f'{STYLE_NAME}_split.csv'

for d in [RAW_DIR, PROCESSED_DIR, TRAIN_JSON_DIR, TEST_JSON_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Ask the user for a zip archive; only the first uploaded file is used.
uploaded = files.upload()
if not uploaded:
    raise RuntimeError('Please upload a zip containing MIDI files.')
zip_name = next(iter(uploaded.keys()))
if not zipfile.is_zipfile(zip_name):
    raise RuntimeError(f'Uploaded file is not a valid zip archive: {zip_name}')

with zipfile.ZipFile(zip_name, 'r') as zf:
    zf.extractall(RAW_DIR)

# RAW_DIR is never cleared, so MIDI files from earlier uploads for the same
# style are picked up again here.
midi_files = sorted([p for p in RAW_DIR.rglob('*') if p.suffix.lower() in {'.mid', '.midi'}])
if not midi_files:
    raise RuntimeError('No MIDI files found in uploaded zip.')

# Everything derived from RAW_DIR is rebuilt from scratch. Output names depend
# on the position of each file in midi_files, so leftovers from an earlier run
# could otherwise describe a different piece or sit in the wrong split folder.
for d in [PROCESSED_DIR, TRAIN_JSON_DIR, TEST_JSON_DIR]:
    shutil.rmtree(d, ignore_errors=True)
    d.mkdir(parents=True, exist_ok=True)
if SPLIT_CSV.exists():
    SPLIT_CSV.unlink()

cfg = LlamaConfig.from_pretrained('/content/Moonbeam-MIDI-Foundation-Model/src/llama_recipes/configs/model_config.json')
tokenizer = MusicTokenizer(
    timeshift_vocab_size=cfg.onset_vocab_size,
    dur_vocab_size=cfg.dur_vocab_size,
    octave_vocab_size=cfg.octave_vocab_size,
    pitch_class_vocab_size=cfg.pitch_class_vocab_size,
    instrument_vocab_size=cfg.instrument_vocab_size,
    velocity_vocab_size=cfg.velocity_vocab_size,
)

# Tokenize every MIDI file. A failure on one file is reported and skipped so a
# single bad file does not abort the whole dataset.
rows = []
for idx, midi_path in enumerate(midi_files):
    try:
        tokens = tokenizer.midi_to_compound(str(midi_path))
        # NOTE(review): int16 assumes every token value fits in 16 bits —
        # confirm against the vocab sizes in model_config.json.
        arr = np.asarray(tokens, dtype=np.int16)
        if arr.size == 0:
            print(f'[skip] {midi_path.name}: tokenizer produced no tokens')
            continue
        out_name = f'{STYLE_NAME}_{idx:05d}.npy'
        np.save(PROCESSED_DIR / out_name, arr)
    except Exception as e:
        print(f'[skip] {midi_path.name}: {e}')
        continue
    rows.append({'file_base_name': out_name, 'token_length': len(tokens)})
valid = len(rows)

if len(rows) < 2:
    raise RuntimeError('Need at least 2 valid MIDI files after preprocessing.')

# Deterministic shuffle with a private RNG, so the notebook's global random
# state is left untouched.
random.Random(42).shuffle(rows)
# Clamp the split point so both train and test receive at least one file,
# whatever TRAIN_RATIO is.
split_idx = max(1, min(len(rows) - 1, int(len(rows) * float(TRAIN_RATIO))))
for i, r in enumerate(rows):
    r['split'] = 'train' if i < split_idx else 'test'

pd.DataFrame([{'file_base_name': r['file_base_name'], 'split': r['split']} for r in rows]).to_csv(SPLIT_CSV, index=False)

# One small metadata JSON per file, stored under train/ or test/; the
# token-length analyzer cell reads the ones in train/.
for r in rows:
    target_dir = TRAIN_JSON_DIR if r['split'] == 'train' else TEST_JSON_DIR
    meta = {
        'file_base_name': r['file_base_name'],
        'token_length': r['token_length'],
        'split': r['split'],
    }
    (target_dir / f"{Path(r['file_base_name']).stem}.json").write_text(json.dumps(meta))

print('Valid preprocessed files:', valid)
print('Processed dataset dir:', STYLE_DATA_DIR.resolve())
print('Split CSV:', SPLIT_CSV.resolve())


## 4b) Dataset token-length analyzer


In [None]:
from pathlib import Path
import json
import numpy as np

STYLE_NAME = 'chopin'  #@param {type:"string"}
train_json_dir = Path(f'/content/processed_datasets/unconditional/{STYLE_NAME}/train')


def _gather_token_lengths(meta_paths):
    """Return the token count found in each metadata JSON, skipping files that fail."""
    counts = []
    for meta_path in meta_paths:
        try:
            record = json.loads(meta_path.read_text())
            count = record.get('token_length')
            # No explicit length stored: fall back to counting an inline token list.
            if count is None and isinstance(record.get('tokens'), list):
                count = len(record['tokens'])
            if count is not None:
                counts.append(int(count))
        except Exception as e:
            print(f'[skip] {meta_path.name}: {e}')
    return counts


def _recommend_context_length(p95_value):
    """Map the 95th-percentile token length to the context length printed below."""
    for ceiling, context in ((1000, 1024), (2000, 2048)):
        if p95_value < ceiling:
            return context
    return 4096


# None marks a missing folder; an empty list marks a folder without JSON files.
json_files = sorted(train_json_dir.glob('*.json')) if train_json_dir.exists() else None

if json_files is None:
    print(f'[info] Dataset folder missing: {train_json_dir}')
elif not json_files:
    print(f'[info] No JSON files found under: {train_json_dir}')
else:
    lengths = _gather_token_lengths(json_files)
    if lengths:
        arr = np.array(lengths)
        p50, p75, p90, p95 = np.percentile(arr, [50, 75, 90, 95])
        print(f'total files: {len(arr)}')
        print(f'min/max/avg: {arr.min()} / {arr.max()} / {arr.mean():.2f}')
        print(f'p50/p75/p90/p95: {p50:.1f} / {p75:.1f} / {p90:.1f} / {p95:.1f}')

        rec = _recommend_context_length(p95)
        print(f'Recommended context_length: {rec}')
    else:
        print('[info] No valid token lengths found in JSON files.')


## 5) Run LoRA finetuning (with sanity checks, absolute paths, and clear status)


In [None]:
import subprocess
from pathlib import Path

import os  # local import: used to pass an adjusted environment to the training process

STYLE_NAME = 'chopin'  #@param {type:"string"}
NUM_EPOCHS = 5  #@param {type:"integer"}
BATCH_SIZE = 1  #@param {type:"integer"}
LR = 0.0003  #@param {type:"number"}
CONTEXT_LENGTH = 1024  #@param {type:"integer"}

style_root = Path(f'/content/processed_datasets/unconditional/{STYLE_NAME}')
split_csv = style_root / f'{STYLE_NAME}_split.csv'
out_dir = Path(f'/content/checkpoints/finetuned_checkpoints/{STYLE_NAME}_lora')

# ckpt_path is set by the checkpoint-download cell and is lost when the
# runtime restarts; give a clear instruction instead of a bare NameError.
if 'ckpt_path' not in globals():
    raise RuntimeError('ckpt_path is not defined. Run the "Download pretrained checkpoint" cell first.')
if not Path(ckpt_path).exists():
    raise FileNotFoundError(f'Missing pretrained checkpoint: {ckpt_path}')

# Sanity checks on the dataset produced by the preprocessing cell.
if not style_root.exists():
    raise FileNotFoundError(f'Missing dataset dir: {style_root}')
if not (style_root / 'processed').exists():
    raise FileNotFoundError(f"Missing processed dir: {style_root / 'processed'}")
if not split_csv.exists():
    raise FileNotFoundError(f'Missing split CSV: {split_csv}')

# Create the output folder only once the inputs are known to be present.
out_dir.mkdir(parents=True, exist_ok=True)

# Single-node, single-process torchrun launch of the repository's finetuning script.
args = [
    'torchrun', '--nnodes', '1', '--nproc_per_node', '1',
    '/content/Moonbeam-MIDI-Foundation-Model/recipes/finetuning/real_finetuning_uncon_gen.py',
    '--lr', str(LR),
    '--val_batch_size', '1',
    '--run_validation', 'True',
    '--validation_interval', '20',
    '--save_metrics', 'True',
    '--dist_checkpoint_root_folder', '/content/checkpoints/finetuned_checkpoints',
    '--dist_checkpoint_folder', f'{STYLE_NAME}_lora_ddp',
    '--trained_checkpoint_path', str(ckpt_path),
    '--ddp_config.pure_bf16', 'True',
    '--enable_ddp', 'True',
    '--use_peft', 'True',
    '--peft_method', 'lora',
    '--quantization', 'False',
    '--model_name', f'moonbeam_{STYLE_NAME}',
    '--dataset', 'lakhmidi_dataset',
    '--lakhmidi_dataset.data_dir', str(style_root),
    '--lakhmidi_dataset.csv_file', str(split_csv),
    '--output_dir', str(out_dir),
    '--batch_size_training', str(BATCH_SIZE),
    '--context_length', str(CONTEXT_LENGTH),
    '--num_epochs', str(NUM_EPOCHS),
    '--use_wandb', 'False',
    '--gamma', '0.99',
]

print('Launching training command:\n' + ' '.join(args))

# Stream the training output line by line so progress is visible while the job
# runs, instead of only after it exits. stderr is merged into stdout so the
# messages keep their order; PYTHONUNBUFFERED asks the child Python process
# not to hold output back in its own buffer.
unknown_hits = []
with subprocess.Popen(
    args,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    bufsize=1,
    env={**os.environ, 'PYTHONUNBUFFERED': '1'},
) as proc:
    for line in proc.stdout:
        print(line, end='', flush=True)
        if 'Warning: unknown parameter' in line:
            unknown_hits.append(line.rstrip('\n'))
# Leaving the with-block waits for the process, so returncode is set here.
returncode = proc.returncode

# An unknown-parameter warning means one of the flags above was not recognised
# by the training script; treat that as a failure.
if unknown_hits:
    raise RuntimeError('Unknown parameter warnings found:\n' + '\n'.join(unknown_hits))

if returncode != 0:
    raise RuntimeError(f'Training failed with return code {returncode}')

print('Training finished successfully.')
print('Output dir:', out_dir.resolve())


## 6) Package best LoRA adapter + print final training summary


In [None]:
from pathlib import Path
import json
import zipfile
from datetime import datetime
STYLE_NAME = 'chopin'  #@param {type:"string"}
out_dir = Path(f'/content/checkpoints/finetuned_checkpoints/{STYLE_NAME}_lora')


def _mtime(path):
    """Return the modification time of *path*."""
    return path.stat().st_mtime


# The most recently modified metrics file, or None when no metrics file exists.
metrics_files = sorted(out_dir.glob('metrics_data_*.json'), key=_mtime)
metrics_path = None
if metrics_files:
    metrics_path = metrics_files[-1]

# An adapter folder is any directory (out_dir itself included) that holds an
# adapter_config.json.
adapter_dirs = []
for candidate in (out_dir, *out_dir.rglob('*')):
    if candidate.is_dir() and (candidate / 'adapter_config.json').exists():
        adapter_dirs.append(candidate)
if not adapter_dirs:
    raise RuntimeError(f'No adapter folders with adapter_config.json found in {out_dir}')
def parse_folder_key(p: Path):
    """Parse a folder name of the form '<int>-<int>' into a sortable tuple.

    Any '.safetensors' text in the name is removed first. Names that do not
    match the pattern yield (-1, -1).
    """
    unparsable = (-1, -1)
    label = p.name.replace('.safetensors', '')
    first, dash, second = label.partition('-')
    if not dash:
        return unparsable
    try:
        return (int(first), int(second))
    except ValueError:
        return unparsable
import math  # local import: this cell's import header does not provide it

# Order adapter folders by the pair of integers parsed from their names, using
# the modification time as a tie-breaker.
adapter_dirs_sorted = sorted(adapter_dirs, key=lambda p: (parse_folder_key(p), p.stat().st_mtime))
selected = None
select_reason = ''
best_eval = None
last_completed_epoch = None

if metrics_path:
    # An unreadable or corrupt metrics file must not block packaging; fall
    # back to the most recent adapter in that case.
    try:
        metrics = json.loads(metrics_path.read_text())
    except (OSError, ValueError) as e:
        print(f'[warn] Could not read metrics file {metrics_path}: {e}')
        metrics = {}
    if not isinstance(metrics, dict):
        metrics = {}

    train_epoch_losses = metrics.get('train_epoch_loss') or []
    if train_epoch_losses:
        last_completed_epoch = len(train_epoch_losses)

    # Keep each loss paired with its ORIGINAL position in val_step_loss, so
    # skipping missing, non-numeric, NaN or infinite entries cannot shift the
    # index that is used to pick an adapter folder.
    scored = []
    for i, raw_loss in enumerate(metrics.get('val_step_loss') or []):
        try:
            loss = float(raw_loss)
        except (TypeError, ValueError):
            continue
        if math.isfinite(loss):
            scored.append((loss, i))
    if scored:
        # Ties on the loss resolve to the earliest index.
        best_eval, best_i = min(scored)
        # NOTE(review): this assumes the i-th validation entry corresponds to
        # the i-th adapter folder in sorted order — confirm against the
        # training script's checkpoint-saving policy.
        if best_i < len(adapter_dirs_sorted):
            selected = adapter_dirs_sorted[best_i]
            select_reason = f'best eval loss from metrics index {best_i} (loss={best_eval:.6f})'

if selected is None:
    selected = max(adapter_dirs, key=lambda p: p.stat().st_mtime)
    select_reason = 'most recent adapter folder (metrics missing or insufficient)'

# Both files must be present for the exported zip to be usable.
required_files = ['adapter_model.safetensors', 'adapter_config.json']
for rf in required_files:
    if not (selected / rf).exists():
        raise RuntimeError(f'Selected adapter folder missing required file: {selected/rf}')

ts = datetime.now().strftime('%Y%m%d_%H-%M')
zip_path = Path('/content') / f'{STYLE_NAME}_{ts}_lora_adapter.zip'
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
    # Files are stored flat at the archive root.
    for name in required_files:
        zf.write(selected / name, arcname=name)
    readme = selected / 'README.md'
    if readme.exists():
        zf.write(readme, arcname='README.md')

if last_completed_epoch is None:
    # No usable metrics: fall back to the largest first number found in the
    # adapter folder names, or 'N/A' when no name could be parsed.
    newest_key = max(parse_folder_key(p) for p in adapter_dirs_sorted)
    last_completed_epoch = newest_key[0] if newest_key[0] >= 0 else 'N/A'

print('=== Training Summary ===')
print('Selected adapter folder:', selected.resolve())
print('Selection reason:', select_reason)
print('Last completed epoch:', last_completed_epoch)
print('Best eval loss:', best_eval if best_eval is not None else 'N/A')
print('Metrics JSON path:', str(metrics_path.resolve()) if metrics_path else 'N/A (metrics file not found)')
print('Adapter zip path:', zip_path.resolve())
# Stopping reason heuristic
if metrics_path and best_eval is not None:
    print('Detected stop reason: normal completion (metrics/checkpoints present).')
else:
    print('Detected stop reason: unknown/partial run (check runtime logs for interruption).')
