# Moonbeam Quickstart (Google Colab, free GPU)

This notebook runs **end-to-end inference** with the pretrained **Moonbeam 309M** checkpoint and writes `out.mid` with **no dataset and no finetuning**.

## 1) Runtime setup
In Colab: `Runtime -> Change runtime type -> GPU`.

In [None]:
import os
import subprocess

print(subprocess.check_output(["nvidia-smi"], text=True))

## 2) Clone repo and install dependencies (exact README commands)

In [None]:
import os

REPO_URL = "https://github.com/guozixunnicolas/Moonbeam-MIDI-Foundation-Model.git"
REPO_DIR = "/content/Moonbeam-MIDI-Foundation-Model"

if not os.path.isdir(REPO_DIR):
    !git clone $REPO_URL $REPO_DIR
else:
    print("Repo already exists; syncing to latest origin/main (fallback origin/master).")
    !git -C $REPO_DIR fetch origin
    !git -C $REPO_DIR reset --hard origin/main || git -C $REPO_DIR reset --hard origin/master

%cd /content/Moonbeam-MIDI-Foundation-Model
!pip install .
!pip install src/llama_recipes/transformers_minimal/.


## 3) Download pretrained checkpoint from Hugging Face

In [None]:
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="guozixunnicolas/moonbeam-midi-foundation-model",
    filename="moonbeam_309M.pt",
    local_dir="checkpoints/pretrained",
)
print("checkpoint:", ckpt_path)

## 4) Resolve config + tokenizer paths used by repo
- Model config: `src/llama_recipes/configs/model_config.json`
- Tokenizer: search for `tokenizer.model` in repo, fallback to benchmark tokenizer path.

In [None]:
from pathlib import Path
import json
import subprocess
import torch

repo_root = Path.cwd()
primary_model_config_path = repo_root / "src/llama_recipes/configs/model_config.json"
small_model_config_path = repo_root / "src/llama_recipes/configs/model_config_small.json"
assert primary_model_config_path.exists(), f"Missing model config: {primary_model_config_path}"

# Search for tokenizer.model in repo.
search = subprocess.run(
    ["bash", "-lc", "rg --files | rg 'tokenizer\.model$'"],
    cwd=repo_root,
    text=True,
    capture_output=True,
    check=False,
)
found = [line.strip() for line in search.stdout.splitlines() if line.strip()]
print("tokenizer.model candidates:", found)

if found:
    tokenizer_path = repo_root / found[0]
else:
    tokenizer_path = repo_root / "recipes/benchmarks/inference_throughput/tokenizer/tokenizer.model"

assert tokenizer_path.exists(), f"Missing tokenizer file: {tokenizer_path}"

# Detect which config matches checkpoint tensor shapes (309M checkpoint expects *_small config).
resolved_model_config_path = primary_model_config_path
if 'ckpt_path' in globals() and small_model_config_path.exists():
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state = checkpoint.get('model_state_dict') if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint else checkpoint
    if isinstance(state, dict):
        norm_key = 'model.norm.weight'
        ckpt_hidden_size = state.get(norm_key).shape[0] if norm_key in state else None
        if ckpt_hidden_size is not None:
            with open(primary_model_config_path) as f:
                primary_hidden = json.load(f).get('hidden_size')
            with open(small_model_config_path) as f:
                small_hidden = json.load(f).get('hidden_size')
            if ckpt_hidden_size == small_hidden and ckpt_hidden_size != primary_hidden:
                resolved_model_config_path = small_model_config_path

print("using model_config_path:", resolved_model_config_path)
print("using tokenizer_path:", tokenizer_path)


## 5) Add dataset-free inference entrypoint (SOS-only prompt)
This avoids the existing CSV + `.npy` prompt requirement.

In [None]:
# Ensure we are in the cloned repo (some Colab workflows can change cwd).
%cd /content/Moonbeam-MIDI-Foundation-Model

from pathlib import Path
import textwrap

entrypoint = Path("recipes/inference/custom_music_generation/unconditional_from_scratch.py")
entrypoint.parent.mkdir(parents=True, exist_ok=True)
entrypoint.write_text(textwrap.dedent('''from pathlib import Path
from typing import Optional

import fire
import torch
from transformers import LlamaConfig, LlamaForCausalLM

from generation import MusicLlama
from llama_recipes.datasets.music_tokenizer import MusicTokenizer


def _normalize_checkpoint_state_dict(checkpoint: dict) -> dict:
    if not isinstance(checkpoint, dict):
        raise ValueError("Checkpoint must be a dict-like object.")

    if "model_state_dict" in checkpoint and isinstance(checkpoint["model_state_dict"], dict):
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint and isinstance(checkpoint["state_dict"], dict):
        state_dict = checkpoint["state_dict"]
    elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint

    normalized = {}
    for k, v in state_dict.items():
        normalized[k[7:] if k.startswith("module.") else k] = v
    return normalized


def _resolve_model_config_path(ckpt_path: str, model_config_path: str) -> str:
    model_config = Path(model_config_path)
    small_config = model_config.with_name("model_config_small.json")
    if not small_config.exists():
        return str(model_config)

    try:
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except TypeError:
        checkpoint = torch.load(ckpt_path, map_location="cpu")

    state_dict = _normalize_checkpoint_state_dict(checkpoint)
    ckpt_hidden_size = state_dict.get("model.norm.weight").shape[0] if "model.norm.weight" in state_dict else None
    if ckpt_hidden_size is None:
        return str(model_config)

    cfg_hidden = LlamaConfig.from_pretrained(str(model_config)).hidden_size
    small_hidden = LlamaConfig.from_pretrained(str(small_config)).hidden_size
    if ckpt_hidden_size == small_hidden and ckpt_hidden_size != cfg_hidden:
        print(f"[info] Switching model config to checkpoint-compatible file: {small_config}")
        return str(small_config)
    return str(model_config)


def _build_music_llama(
    ckpt_path: str,
    model_config_path: str,
    seed: int,
) -> MusicLlama:
    model_config_path = _resolve_model_config_path(ckpt_path, model_config_path)
    config = LlamaConfig.from_pretrained(model_config_path)
    model = LlamaForCausalLM(config)

    try:
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except TypeError:
        checkpoint = torch.load(ckpt_path, map_location="cpu")

    state_dict = _normalize_checkpoint_state_dict(checkpoint)
    model_state = model.state_dict()
    filtered_state = {
        k: v for k, v in state_dict.items() if k in model_state and getattr(v, "shape", None) == model_state[k].shape
    }
    skipped = len(state_dict) - len(filtered_state)
    missing, unexpected = model.load_state_dict(filtered_state, strict=False)
    print(f"[info] Loaded keys: {len(filtered_state)} | skipped: {skipped} | missing: {len(missing)} | unexpected: {len(unexpected)}")

    if torch.cuda.is_available():
        model = model.to("cuda")
    model.eval()

    tokenizer = MusicTokenizer(
        timeshift_vocab_size=config.onset_vocab_size,
        dur_vocab_size=config.dur_vocab_size,
        octave_vocab_size=config.octave_vocab_size,
        pitch_class_vocab_size=config.pitch_class_vocab_size,
        instrument_vocab_size=config.instrument_vocab_size,
        velocity_vocab_size=config.velocity_vocab_size,
    )

    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        model = model.to(torch.bfloat16)

    return MusicLlama(model, tokenizer, config)


def main(
    ckpt_path: str,
    model_config_path: str = "src/llama_recipes/configs/model_config.json",
    tokenizer_path: str = "recipes/benchmarks/inference_throughput/tokenizer/tokenizer.model",
    output_midi_path: str = "out.mid",
    temperature: float = 0.9,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_gen_len: int = 256,
    seed: int = 42,
    finetuned_PEFT_weight_path: Optional[str] = None,
) -> None:
    """Generate a MIDI file from an SOS-only prompt (no dataset required)."""
    del tokenizer_path, max_seq_len, finetuned_PEFT_weight_path

    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA GPU is required for this Colab quickstart. In Colab, set Runtime -> Change runtime type -> GPU.")
    torch.cuda.manual_seed_all(seed)

    generator = _build_music_llama(
        ckpt_path=ckpt_path,
        model_config_path=model_config_path,
        seed=seed,
    )

    sos_prompt = [generator.tokenizer.sos_token_compound]
    result = generator.music_completion(
        prompt_tokens=[sos_prompt],
        temperature=temperature,
        top_p=top_p,
        max_gen_len=max_gen_len,
    )[0]

    output_path = Path(output_midi_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    result["generation"]["content"].save(str(output_path))
    print(f"Saved MIDI to: {output_path.resolve()}")


if __name__ == "__main__":
    fire.Fire(main)
'''))
print(f"Wrote/updated entrypoint: {entrypoint}")
assert "_build_music_llama" in entrypoint.read_text(), "Entrypoint write failed; missing expected loader helper."

!python recipes/inference/custom_music_generation/unconditional_from_scratch.py --help


## 6) Generate MIDI from scratch (no dataset)

In [None]:
%cd /content/Moonbeam-MIDI-Foundation-Model

from pathlib import Path
import torch
from transformers import LlamaConfig, LlamaForCausalLM

import sys

# Avoid conflict with external `recipes` package (e.g., torchtune) by importing local generation.py directly.
sys.path.insert(0, str(Path("recipes/inference/custom_music_generation").resolve()))

from generation import MusicLlama
from llama_recipes.datasets.music_tokenizer import MusicTokenizer


def _normalize_checkpoint_state_dict(checkpoint: dict) -> dict:
    if not isinstance(checkpoint, dict):
        raise ValueError("Checkpoint must be a dict-like object.")
    if "model_state_dict" in checkpoint and isinstance(checkpoint["model_state_dict"], dict):
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint and isinstance(checkpoint["state_dict"], dict):
        state_dict = checkpoint["state_dict"]
    elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    return {k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()}


if not torch.cuda.is_available():
    raise RuntimeError("CUDA GPU is required for this Colab quickstart. In Colab, set Runtime -> Change runtime type -> GPU.")

cfg_path = str(resolved_model_config_path)
config = LlamaConfig.from_pretrained(cfg_path)
model = LlamaForCausalLM(config)

try:
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
except TypeError:
    checkpoint = torch.load(ckpt_path, map_location="cpu")

state_dict = _normalize_checkpoint_state_dict(checkpoint)
model_state = model.state_dict()
filtered_state = {k: v for k, v in state_dict.items() if k in model_state and getattr(v, "shape", None) == model_state[k].shape}
missing, unexpected = model.load_state_dict(filtered_state, strict=False)
print(f"[info] Loaded keys: {len(filtered_state)} | missing: {len(missing)} | unexpected: {len(unexpected)}")

model = model.to("cuda")
if torch.cuda.is_bf16_supported():
    model = model.to(torch.bfloat16)
model.eval()

tokenizer = MusicTokenizer(
    timeshift_vocab_size=config.onset_vocab_size,
    dur_vocab_size=config.dur_vocab_size,
    octave_vocab_size=config.octave_vocab_size,
    pitch_class_vocab_size=config.pitch_class_vocab_size,
    instrument_vocab_size=config.instrument_vocab_size,
    velocity_vocab_size=config.velocity_vocab_size,
)

generator = MusicLlama(model, tokenizer, config)
sos_prompt = [generator.tokenizer.sos_token_compound]
result = generator.music_completion(
    prompt_tokens=[sos_prompt],
    temperature=0.9,
    top_p=0.95,
    max_gen_len=256,
)[0]

out_path = Path("out.mid")
out_path.parent.mkdir(parents=True, exist_ok=True)
result["generation"]["content"].save(str(out_path))
print(f"Saved MIDI to: {out_path.resolve()}")


## 7) Verify output and (optional) render to audio preview

In [None]:
from pathlib import Path

out_path = Path("out.mid")
assert out_path.exists() and out_path.stat().st_size > 0, "out.mid was not created"
print("âœ… Generated:", out_path.resolve(), "size:", out_path.stat().st_size, "bytes")

In [None]:
# Optional audio preview if dependencies are available.
# If synthesis backends are unavailable in Colab, this cell may be skipped.

!pip install pretty_midi midi2audio

import pretty_midi
from IPython.display import Audio, display

midi = pretty_midi.PrettyMIDI("out.mid")
# Attempt software synthesis (requires fluidsynth backend in runtime)
audio = midi.synthesize(fs=16000)
display(Audio(audio, rate=16000))

## Notes on checkpoint compatibility
`MusicLlama.build()` in this fork was updated to accept multiple checkpoint layouts:
- `{"model_state_dict": ...}`
- `{"state_dict": ...}`
- `{"model": ...}`
- or a raw state dict

It also strips `module.` prefixes when present.