# Moonbeam Quickstart (Google Colab, free GPU)

This notebook runs **end-to-end inference** with the pretrained **Moonbeam 309M** checkpoint and writes `out.mid` with **no dataset and no finetuning**.

## 1) Runtime setup
In Colab: `Runtime -> Change runtime type -> GPU`.

In [None]:
import os
import shutil
import subprocess

# Every generation cell below moves the model to "cuda", so fail early with an
# actionable message when no NVIDIA driver/GPU is attached to the runtime
# (otherwise check_output raises a bare FileNotFoundError).
if shutil.which("nvidia-smi") is None:
    raise RuntimeError(
        "nvidia-smi not found: no GPU runtime is attached. "
        "In Colab use Runtime -> Change runtime type -> GPU, then rerun."
    )
print(subprocess.check_output(["nvidia-smi"], text=True))

## 2) Clone repo and install dependencies (exact README commands)

In [None]:
import os

# Upstream repository and the fixed Colab location it is cloned into.
REPO_URL = "https://github.com/guozixunnicolas/Moonbeam-MIDI-Foundation-Model.git"
REPO_DIR = "/content/Moonbeam-MIDI-Foundation-Model"

if not os.path.isdir(REPO_DIR):
    # First run: fresh clone.
    !git clone $REPO_URL $REPO_DIR
else:
    # Re-run: `reset --hard` discards ALL local modifications in the clone
    # (including the generation.py patch that section 5 re-applies) and syncs to
    # origin/main, falling back to origin/master if main does not exist.
    print("Repo already exists; syncing to latest origin/main (fallback origin/master).")
    !git -C $REPO_DIR fetch origin
    !git -C $REPO_DIR reset --hard origin/main || git -C $REPO_DIR reset --hard origin/master

# NOTE: this literal path duplicates REPO_DIR above; keep the two in sync.
%cd /content/Moonbeam-MIDI-Foundation-Model
# Install the repo package, then its bundled `transformers_minimal` package.
!pip install .
!pip install src/llama_recipes/transformers_minimal/.


## 3) Download pretrained checkpoint from Hugging Face

In [None]:
from huggingface_hub import hf_hub_download

# Pretrained Moonbeam 309M weights on the Hugging Face Hub. They are stored
# under checkpoints/pretrained relative to the current working directory (the
# repo root after the clone cell's %cd); hf_hub_download returns the local path.
_hf_download_kwargs = dict(
    repo_id="guozixunnicolas/moonbeam-midi-foundation-model",
    filename="moonbeam_309M.pt",
    local_dir="checkpoints/pretrained",
)
ckpt_path = hf_hub_download(**_hf_download_kwargs)
print("checkpoint:", ckpt_path)

## 3b) (Optional) Upload a LoRA adapter for generation
If you trained a LoRA separately, upload a `.zip` containing `adapter_config.json` plus the adapter weights, or set `UPLOADED_LORA_ZIP` to the path of a zip that is already on disk.


In [None]:
# Optional LoRA upload controls
USE_UPLOADED_LORA = False  #@param {type:"boolean"}
UPLOADED_LORA_ZIP = ""  #@param {type:"string"}
FINETUNED_PEFT_WEIGHT_PATH = None

if USE_UPLOADED_LORA:
    from google.colab import files
    import zipfile
    from pathlib import Path

    if UPLOADED_LORA_ZIP.strip():
        zip_path = Path(UPLOADED_LORA_ZIP)
    else:
        uploaded = files.upload()
        if not uploaded:
            raise RuntimeError("No LoRA zip uploaded.")
        zip_path = Path(next(iter(uploaded.keys())))

    out_dir = Path("uploaded_lora")
    out_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(out_dir)

    # find adapter root folder
    candidates = [d for d in [out_dir, *out_dir.rglob('*')] if d.is_dir() and (d / 'adapter_config.json').exists()]
    if not candidates:
        raise RuntimeError("Could not find adapter_config.json in uploaded LoRA zip.")
    FINETUNED_PEFT_WEIGHT_PATH = str(candidates[0].resolve())
    print(f"Using uploaded LoRA adapter: {FINETUNED_PEFT_WEIGHT_PATH}")
else:
    print("LoRA disabled. Set USE_UPLOADED_LORA=True to upload/apply adapter.")


## 4) Resolve config + tokenizer paths used by repo
- Model config: `src/llama_recipes/configs/model_config.json` (switched to `model_config_small.json` when that config's `hidden_size` matches the checkpoint)
- Tokenizer: search for `tokenizer.model` in repo, fallback to benchmark tokenizer path.

In [None]:
from pathlib import Path
import gc
import json
import subprocess
import torch

repo_root = Path.cwd()
primary_model_config_path = repo_root / "src/llama_recipes/configs/model_config.json"
small_model_config_path = repo_root / "src/llama_recipes/configs/model_config_small.json"
assert primary_model_config_path.exists(), f"Missing model config: {primary_model_config_path}"

FALLBACK_TOKENIZER_RELPATH = "recipes/benchmarks/inference_throughput/tokenizer/tokenizer.model"


def _find_tokenizer_models(root):
    """Return sorted, root-relative POSIX paths of every ``tokenizer.model`` file.

    Pure-Python replacement for the former ``rg --files | rg ...`` pipeline.
    ripgrep is not guaranteed to exist in the runtime, and when it was missing
    that search silently found nothing. Paths under ``.git`` are ignored, and
    the result is sorted so the first candidate is deterministic.
    """
    root = Path(root)
    hits = []
    for path in root.rglob("tokenizer.model"):
        rel = path.relative_to(root)
        if ".git" not in rel.parts and path.is_file():
            hits.append(rel.as_posix())
    return sorted(hits)


def _config_hidden_size(config_path):
    """Return ``hidden_size`` from a JSON model config, or ``None`` if absent."""
    with open(config_path) as f:
        return json.load(f).get("hidden_size")


def _checkpoint_hidden_size(path):
    """Return ``model.norm.weight.shape[0]`` from a checkpoint, or ``None``.

    Unwraps the same container keys as the strict loader in section 6
    (``model_state_dict`` / ``state_dict`` / ``model``) and tolerates a
    ``module.`` key prefix. The checkpoint is only referenced inside this
    function, so it becomes collectable on return instead of lingering in
    notebook globals.
    """
    try:
        checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:  # torch versions without the weights_only argument
        checkpoint = torch.load(path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        return None
    state = checkpoint
    for wrapper_key in ("model_state_dict", "state_dict", "model"):
        if isinstance(checkpoint.get(wrapper_key), dict):
            state = checkpoint[wrapper_key]
            break
    for norm_key in ("model.norm.weight", "module.model.norm.weight"):
        if norm_key in state:
            return state[norm_key].shape[0]
    return None


found = _find_tokenizer_models(repo_root)
print("tokenizer.model candidates:", found)
tokenizer_path = repo_root / (found[0] if found else FALLBACK_TOKENIZER_RELPATH)
assert tokenizer_path.exists(), f"Missing tokenizer file: {tokenizer_path}"

# Detect which config matches checkpoint tensor shapes (309M checkpoint expects *_small config).
resolved_model_config_path = primary_model_config_path
if "ckpt_path" in globals() and small_model_config_path.exists():
    ckpt_hidden_size = _checkpoint_hidden_size(ckpt_path)
    gc.collect()  # reclaim the temporarily loaded checkpoint right away
    if ckpt_hidden_size is not None:
        primary_hidden = _config_hidden_size(primary_model_config_path)
        small_hidden = _config_hidden_size(small_model_config_path)
        # Only switch when the small config matches and the primary one does not.
        if ckpt_hidden_size == small_hidden and ckpt_hidden_size != primary_hidden:
            resolved_model_config_path = small_model_config_path

print("using model_config_path:", resolved_model_config_path)
print("using tokenizer_path:", tokenizer_path)


## 5) Add dataset-free inference entrypoint (SOS-only prompt)
This avoids the existing CSV + `.npy` prompt requirement.

In [None]:
# Ensure we are in the cloned repo (some Colab workflows can change cwd).
%cd /content/Moonbeam-MIDI-Foundation-Model

from pathlib import Path
import textwrap

# Dataset-free CLI entrypoint (SOS-only prompt) and the repo's generation module
# that it imports MusicLlama from.
entrypoint = Path("recipes/inference/custom_music_generation/unconditional_from_scratch.py")
generation_impl = Path("recipes/inference/custom_music_generation/generation.py")

# Some upstream/older clones do not include this helper entrypoint yet.
# If missing, bootstrap a compatible script so the notebook remains runnable.
# An existing entrypoint is never overwritten.
# The script does `from generation import MusicLlama`; that resolves because
# running a script by path puts the script's own directory on sys.path.
# textwrap.dedent is effectively a no-op: the script lines start at column 0.
# NOTE(review): the script's --model_config_path defaults to model_config.json,
# while section 4 may resolve model_config_small.json (its comment says the
# 309M checkpoint expects the *_small config); pass the path explicitly when
# running this script for real.
if not entrypoint.exists():
    print(f"[info] Missing {entrypoint}; creating compatibility entrypoint.")
    entrypoint.parent.mkdir(parents=True, exist_ok=True)
    entrypoint.write_text(textwrap.dedent('''
from pathlib import Path
from typing import Optional

import fire
import torch
from transformers import LlamaConfig, LlamaForCausalLM

from generation import MusicLlama
from llama_recipes.datasets.music_tokenizer import MusicTokenizer


def _normalize_checkpoint_state_dict(checkpoint: dict) -> dict:
    if not isinstance(checkpoint, dict):
        raise ValueError("Checkpoint must be a dict-like object.")
    if "model_state_dict" in checkpoint and isinstance(checkpoint["model_state_dict"], dict):
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint and isinstance(checkpoint["state_dict"], dict):
        state_dict = checkpoint["state_dict"]
    elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    return {k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()}


def _build_music_llama(ckpt_path: str, model_config_path: str) -> MusicLlama:
    config = LlamaConfig.from_pretrained(model_config_path)
    model = LlamaForCausalLM(config)
    try:
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except TypeError:
        checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = _normalize_checkpoint_state_dict(checkpoint)
    model_state = model.state_dict()
    missing = sorted(set(model_state.keys()) - set(state_dict.keys()))
    unexpected = sorted(set(state_dict.keys()) - set(model_state.keys()))
    shape_mismatch = sorted([k for k in (set(state_dict.keys()) & set(model_state.keys())) if getattr(state_dict[k], "shape", None) != model_state[k].shape])
    if missing or unexpected or shape_mismatch:
        raise RuntimeError(
            f"Checkpoint/config mismatch (strict mode): missing={len(missing)}, unexpected={len(unexpected)}, shape_mismatch={len(shape_mismatch)}"
        )
    model.load_state_dict(state_dict, strict=True)
    model = model.to("cuda")
    if torch.cuda.is_bf16_supported():
        model = model.to(torch.bfloat16)
    model.eval()
    tokenizer = MusicTokenizer(
        timeshift_vocab_size=config.onset_vocab_size,
        dur_vocab_size=config.dur_vocab_size,
        octave_vocab_size=config.octave_vocab_size,
        pitch_class_vocab_size=config.pitch_class_vocab_size,
        instrument_vocab_size=config.instrument_vocab_size,
        velocity_vocab_size=config.velocity_vocab_size,
    )
    return MusicLlama(model, tokenizer, config)


def main(
    ckpt_path: str,
    model_config_path: str = "src/llama_recipes/configs/model_config.json",
    output_midi_path: str = "out.mid",
    temperature: float = 1.0,
    top_p: float = 0.95,
    max_gen_len: int = 512,
    seed: int = 42,
    num_samples: int = 1,
) -> None:
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA GPU is required.")
    generator = _build_music_llama(ckpt_path=ckpt_path, model_config_path=model_config_path)
    output_path = Path(output_midi_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    for i in range(max(1, int(num_samples))):
        sample_seed = int(seed) + i
        torch.manual_seed(sample_seed)
        torch.cuda.manual_seed_all(sample_seed)
        out = generator.music_completion(
            prompt_tokens=[[generator.tokenizer.sos_token_compound]],
            temperature=float(temperature),
            top_p=float(top_p),
            max_gen_len=int(max_gen_len),
        )[0]["generation"]["tokens"]
        valid = []
        for row in out:
            if len(row) == 6 and all(isinstance(x, (int, float)) for x in row):
                valid.append([int(x) for x in row])
        if not valid:
            raise RuntimeError("No valid generated rows.")
        path = output_path if int(num_samples) == 1 else output_path.with_name(f"{output_path.stem}_{i+1}{output_path.suffix}")
        generator.tokenizer.compound_to_midi(valid).save(str(path))
        print(f"Saved {path.resolve()} with {len(valid)} tokens")


if __name__ == "__main__":
    fire.Fire(main)
'''))

# Fail fast if bootstrapping failed or the clone lacks the generation module.
assert entrypoint.exists(), f"Failed to create entrypoint: {entrypoint}"
assert generation_impl.exists(), f"Missing generation implementation: {generation_impl}"
print(f"Using repo entrypoint: {entrypoint.resolve()}")
print(f"Using generator implementation: {generation_impl.resolve()}")

# Auto-heal older copies in Colab if EOS handling is stale, then verify.
# Generation should stop only when the full EOS compound token is produced
# (torch.all over attributes), not when any single attribute matches (torch.any).
# The replace() matches the exact text ", dim = -1"; if generation.py formats the
# call differently the patch is a no-op and the assert below fails.
gen_src = generation_impl.read_text()
if "torch.any(eos_conditions_all_attr" in gen_src and "torch.all(eos_conditions_all_attr" not in gen_src:
    print("[info] Patching EOS handling in generation.py to require full EOS compound token...")
    gen_src = gen_src.replace("torch.any(eos_conditions_all_attr, dim = -1)", "torch.all(eos_conditions_all_attr, dim = -1)")
    generation_impl.write_text(gen_src)

gen_src = generation_impl.read_text()
assert "torch.all(eos_conditions_all_attr" in gen_src, "generation.py is missing full-EOS handling fix."

# Smoke test: the entrypoint must import cleanly and print its CLI help.
!python recipes/inference/custom_music_generation/unconditional_from_scratch.py --help


## 6) Recommended default flow: MIDI continuation (prompted generation)
Continuation is typically more structured and musically coherent than pure from-scratch sampling.


### Quality presets
Choose a preset to adjust sampling only (no weight changes).


In [None]:
QUALITY_PRESET = "balanced"  #@param ["conservative","balanced","creative"]
USE_RANDOM_SEED = False  #@param {type:"boolean"}
BASE_SEED_DEFAULT = 42

# Sampling presets: softmax temperature, nucleus top_p, and the maximum number
# of generated tokens for from-scratch ("scratch_len") and continuation
# ("cont_len") runs. Only sampling changes; model weights are untouched.
PRESETS = dict(
    conservative=dict(temperature=0.85, top_p=0.90, scratch_len=384, cont_len=256),
    balanced=dict(temperature=0.95, top_p=0.93, scratch_len=512, cont_len=384),
    creative=dict(temperature=1.10, top_p=0.98, scratch_len=640, cont_len=512),
)
preset = PRESETS[QUALITY_PRESET]

if USE_RANDOM_SEED:
    # Draw a fresh base seed so each run of the notebook samples new music.
    import random
    BASE_SEED_DEFAULT = random.randint(0, 2**31 - 1)

_preset_summary = " ".join(
    f"{label}={value}"
    for label, value in (
        ("preset", QUALITY_PRESET),
        ("temp", preset["temperature"]),
        ("top_p", preset["top_p"]),
        ("scratch_len", preset["scratch_len"]),
        ("cont_len", preset["cont_len"]),
        ("base_seed", BASE_SEED_DEFAULT),
    )
)
print("Selected quality preset:")
print(f"  {_preset_summary}")


### Optional: choose how many songs to generate
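These controls apply to optional from-scratch generation and only take effect when `ADVANCED_RUN_SCRATCH=True`. Set `NUM_GENERATIONS` to 1, 2, 3, 4, etc., then rerun the generator cell in section 6. For continuations, use `CONT_NUM_GENERATIONS` in section 6b instead.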


In [None]:
# Optional advanced mode: from-scratch generation controls
# These are read by the generator cell below only when ADVANCED_RUN_SCRATCH is True.
# Defaults come from the quality-preset cell (BASE_SEED_DEFAULT, preset) when it
# has run, otherwise from the literal fallbacks. They are evaluated when THIS
# cell runs, so rerun it after changing QUALITY_PRESET.
# Sample i is seeded with BASE_SEED + VARIATION_OFFSET + i.
ADVANCED_RUN_SCRATCH = False  #@param {type:"boolean"}
NUM_GENERATIONS = 1  #@param {type:"integer"}
BASE_SEED = int(globals().get("BASE_SEED_DEFAULT", 42))  #@param {type:"integer"}
VARIATION_OFFSET = 0  #@param {type:"integer"}
MAX_GEN_LEN = int(globals().get("preset", {}).get("scratch_len", 512))  #@param {type:"integer"}
TEMPERATURE = float(globals().get("preset", {}).get("temperature", 0.95))  #@param {type:"number"}
TOP_P = float(globals().get("preset", {}).get("top_p", 0.93))  #@param {type:"number"}
print(f"Scratch advanced mode -> enabled={ADVANCED_RUN_SCRATCH}, num={NUM_GENERATIONS}, base_seed={BASE_SEED}, max_gen_len={MAX_GEN_LEN}, temperature={TEMPERATURE}, top_p={TOP_P}")


In [None]:
%cd /content/Moonbeam-MIDI-Foundation-Model

from pathlib import Path
import gc, io, contextlib, sys
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from peft import PeftModel

# Make recipes/inference/custom_music_generation importable so `generation`
# (which defines MusicLlama) resolves as a top-level module.
sys.path.insert(0, str(Path("recipes/inference/custom_music_generation").resolve()))
from generation import MusicLlama
from llama_recipes.datasets.music_tokenizer import MusicTokenizer

# Monkeypatch MusicTokenizer.convert_from_language_tokens so that, for tensor
# inputs, the result is moved to the input's device (non-tensors pass through).
# The original method is stashed on the class only once, so re-running this cell
# wraps the real implementation instead of stacking wrappers.
# NOTE(review): presumably works around the original returning tensors on a
# different device (e.g. CPU) than CUDA inputs during generation — confirm.
if not hasattr(MusicTokenizer, "_orig_convert_from_language_tokens"):
    MusicTokenizer._orig_convert_from_language_tokens = MusicTokenizer.convert_from_language_tokens
def _convert_from_language_tokens_on_device(self, inp):
    out = MusicTokenizer._orig_convert_from_language_tokens(self, inp)
    return out.to(inp.device) if torch.is_tensor(inp) else out
MusicTokenizer.convert_from_language_tokens = _convert_from_language_tokens_on_device

def _normalize_checkpoint_state_dict(checkpoint: dict) -> dict:
    if not isinstance(checkpoint, dict):
        raise ValueError("Checkpoint must be dict-like.")
    if "model_state_dict" in checkpoint and isinstance(checkpoint["model_state_dict"], dict):
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint and isinstance(checkpoint["state_dict"], dict):
        state_dict = checkpoint["state_dict"]
    elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    return {k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()}

def _strict_load_or_fail(model, state_dict):
    model_state = model.state_dict()
    missing = sorted(set(model_state.keys()) - set(state_dict.keys()))
    unexpected = sorted(set(state_dict.keys()) - set(model_state.keys()))
    shape_mismatch = sorted([
        k for k in (set(state_dict.keys()) & set(model_state.keys()))
        if getattr(state_dict[k], "shape", None) != model_state[k].shape
    ])
    if missing or unexpected or shape_mismatch:
        def _p(keys):
            return ", ".join(keys[:8]) + (" ..." if len(keys) > 8 else "")
        raise RuntimeError(
            "Checkpoint/config mismatch detected (strict mode). "
            f"missing={len(missing)} ({_p(missing)}), "
            f"unexpected={len(unexpected)} ({_p(unexpected)}), "
            f"shape_mismatch={len(shape_mismatch)} ({_p(shape_mismatch)}). "
            "Use matching checkpoint + config (or model_config_small.json for small checkpoints)."
        )
    model.load_state_dict(state_dict, strict=True)

def _sanitize_tokens(tokenizer, rows):
    # Minimal validation only: drop invalid rows, do not rewrite values.
    cleaned, last_onset = [], -1
    max_onset = max(0, tokenizer.timeshift_vocab_size - 3)
    max_dur = max(0, tokenizer.dur_vocab_size - 3)
    max_oct = max(0, tokenizer.octave_vocab_size - 3)
    max_pitch = max(0, tokenizer.pitch_class_vocab_size - 3)
    max_instr = max(0, tokenizer.instrument_vocab_size - 3)
    max_vel = max(0, tokenizer.velocity_vocab_size - 3)
    for row in rows:
        if len(row) != 6:
            continue
        onset, duration, octave, pitch, instrument, velocity = [int(x) for x in row]
        if onset < 0 or onset > max_onset:
            continue
        if duration <= 0 or duration > max_dur:
            continue
        if onset < last_onset:
            continue
        if not (0 <= octave <= max_oct):
            continue
        if not (0 <= pitch <= max_pitch):
            continue
        if not (0 <= instrument <= max_instr):
            continue
        if not (0 <= velocity <= max_vel):
            continue
        cleaned.append([onset, duration, octave, pitch, instrument, velocity])
        last_onset = onset
    return cleaned

# Drop objects from a previous run of this cell, plus the temporary checkpoint
# copies the config-detection cell may have left in globals (`state`,
# `normalized_state`), so their memory is reclaimed before a new model is built.
for var_name in ["generator", "model", "checkpoint", "state_dict", "tokenizer", "state", "normalized_state"]:
    if var_name in globals():
        del globals()[var_name]
gc.collect()
if not torch.cuda.is_available():
    raise RuntimeError("CUDA GPU is required. In Colab set Runtime -> GPU.")
torch.cuda.empty_cache()

# Build the architecture from the config resolved in section 4, then load the
# pretrained weights strictly (any key/shape mismatch aborts with a summary).
config = LlamaConfig.from_pretrained(str(resolved_model_config_path))
model = LlamaForCausalLM(config)
try:
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
except TypeError:  # torch versions without the weights_only argument
    checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = _normalize_checkpoint_state_dict(checkpoint)
_strict_load_or_fail(model, state_dict)
print(f"[info] Strict load successful: {len(state_dict)} tensors")
# load_state_dict copied the weights into `model`; release the CPU checkpoint
# instead of keeping a second full copy referenced from globals all session.
del checkpoint, state_dict
gc.collect()

# Optional LoRA from section 3b; globals().get keeps this cell runnable even if
# the LoRA cell was skipped.
finetuned_peft_weight_path = globals().get("FINETUNED_PEFT_WEIGHT_PATH")
if finetuned_peft_weight_path:
    print(f"Applying LoRA adapter from: {finetuned_peft_weight_path}")
    model = PeftModel.from_pretrained(model, finetuned_peft_weight_path, is_trainable=False)

model = model.to("cuda")
# NOTE(review): recent torch versions may report bf16 as supported on GPUs that
# only emulate it (e.g. pre-Ampere such as T4); confirm dtype/speed on the runtime.
if torch.cuda.is_bf16_supported():
    model = model.to(torch.bfloat16)
model.eval()

# Construct the tokenizer with any console output it produces suppressed.
with contextlib.redirect_stdout(io.StringIO()):
    tokenizer = MusicTokenizer(
        timeshift_vocab_size=config.onset_vocab_size,
        dur_vocab_size=config.dur_vocab_size,
        octave_vocab_size=config.octave_vocab_size,
        pitch_class_vocab_size=config.pitch_class_vocab_size,
        instrument_vocab_size=config.instrument_vocab_size,
        velocity_vocab_size=config.velocity_vocab_size,
    )
generator = MusicLlama(model, tokenizer, config)
print("Generator ready. Recommended next step: run continuation cells below.")

if bool(globals().get("ADVANCED_RUN_SCRATCH", False)):
    # Optional from-scratch sampling: every sample is prompted with SOS only.
    num_generations = max(1, int(globals().get("NUM_GENERATIONS", 1)))
    base_seed = int(globals().get("BASE_SEED", 42))
    variation_offset = int(globals().get("VARIATION_OFFSET", 0))
    max_gen_len = int(globals().get("MAX_GEN_LEN", 512))
    temperature = float(globals().get("TEMPERATURE", 0.95))
    top_p = float(globals().get("TOP_P", 0.93))
    all_outputs = []
    for i in range(num_generations):
        # Seed each sample individually so any single output is reproducible.
        sample_seed = base_seed + variation_offset + i
        torch.manual_seed(sample_seed)
        torch.cuda.manual_seed_all(sample_seed)
        sos_prompt = [generator.tokenizer.sos_token_compound]
        result = generator.music_completion(
            prompt_tokens=[sos_prompt],
            temperature=temperature,
            top_p=top_p,
            max_gen_len=max_gen_len,
        )[0]
        sanitized_tokens = _sanitize_tokens(generator.tokenizer, result["generation"]["tokens"])
        if not sanitized_tokens:
            raise RuntimeError(f"No valid generated tokens remained after validation for sample {i}.")
        out_path = Path("out.mid") if num_generations == 1 else Path(f"out_{i+1}.mid")
        generator.tokenizer.compound_to_midi(sanitized_tokens).save(str(out_path))
        print(f"Saved MIDI to: {out_path.resolve()} | tokens={len(sanitized_tokens)} | seed={sample_seed}")
        all_outputs.append(out_path)
    print(f"Done. Generated {len(all_outputs)} from-scratch file(s).")
else:
    print("From-scratch generation skipped (ADVANCED_RUN_SCRATCH=False).")


## 6b) Continue from an uploaded MIDI (default path)
Upload a MIDI, then generate continuations using the selected quality preset.


In [None]:
# Continuation controls (default flow)
# Defaults come from the quality-preset cell (BASE_SEED_DEFAULT, preset) when it
# has run; rerun this cell after changing QUALITY_PRESET.
# Continuation i is seeded with CONT_BASE_SEED + CONT_VARIATION_OFFSET + i.
# CONT_PROMPT_MAX_TOKENS only applies when CONT_USE_FULL_PROMPT is False: then
# only the last CONT_PROMPT_MAX_TOKENS tokens of the uploaded MIDI are used.
CONT_NUM_GENERATIONS = 2  #@param {type:"integer"}
CONT_BASE_SEED = int(globals().get("BASE_SEED_DEFAULT", 123))  #@param {type:"integer"}
CONT_VARIATION_OFFSET = 0  #@param {type:"integer"}
CONT_MAX_GEN_LEN = int(globals().get("preset", {}).get("cont_len", 384))  #@param {type:"integer"}
CONT_TEMPERATURE = float(globals().get("preset", {}).get("temperature", 0.95))  #@param {type:"number"}
CONT_TOP_P = float(globals().get("preset", {}).get("top_p", 0.93))  #@param {type:"number"}
CONT_USE_FULL_PROMPT = True  #@param {type:"boolean"}
CONT_PROMPT_MAX_TOKENS = 256  #@param {type:"integer"}
print(f"Continuation -> num={CONT_NUM_GENERATIONS}, base_seed={CONT_BASE_SEED}, max_gen_len={CONT_MAX_GEN_LEN}, temperature={CONT_TEMPERATURE}, top_p={CONT_TOP_P}, use_full_prompt={CONT_USE_FULL_PROMPT}")


In [None]:
%cd /content/Moonbeam-MIDI-Foundation-Model

from pathlib import Path
from google.colab import files
import torch

uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file uploaded. Upload a .mid file.")

upload_name = next(iter(uploaded.keys()))
input_midi_path = Path(upload_name)
print(f"Uploaded: {input_midi_path}")

prompt_tokens = generator.tokenizer.midi_to_compound(str(input_midi_path))
if not prompt_tokens:
    raise RuntimeError("Uploaded MIDI produced an empty token list.")

use_full_prompt = bool(globals().get("CONT_USE_FULL_PROMPT", True))
prompt_max = max(1, int(globals().get("CONT_PROMPT_MAX_TOKENS", 256)))
if use_full_prompt:
    prompt_tokens_for_gen = prompt_tokens
else:
    prompt_tokens_for_gen = prompt_tokens[-prompt_max:]

num_generations = max(1, int(globals().get("CONT_NUM_GENERATIONS", 1)))
base_seed = int(globals().get("CONT_BASE_SEED", 123))
variation_offset = int(globals().get("CONT_VARIATION_OFFSET", 0))
max_gen_len = int(globals().get("CONT_MAX_GEN_LEN", 384))
temperature = float(globals().get("CONT_TEMPERATURE", 0.9))
top_p = float(globals().get("CONT_TOP_P", 0.95))

continuation_outputs = []
for i in range(num_generations):
    sample_seed = base_seed + variation_offset + i
    torch.manual_seed(sample_seed)
    torch.cuda.manual_seed_all(sample_seed)

    result = generator.music_completion(
        prompt_tokens=[prompt_tokens_for_gen],
        temperature=temperature,
        top_p=top_p,
        max_gen_len=max_gen_len,
    )[0]

    sanitized_tokens = _sanitize_tokens(generator.tokenizer, result["generation"]["tokens"])
    if not sanitized_tokens:
        raise RuntimeError(f"No valid continuation tokens remained after sanitization for sample {i}.")

    out_path = Path(f"cont_{i+1}.mid")
    generator.tokenizer.compound_to_midi(sanitized_tokens).save(str(out_path))
    print(f"Saved continuation: {out_path.resolve()} | tokens={len(sanitized_tokens)} | seed={sample_seed}")
    continuation_outputs.append(out_path)

print(f"Used prompt tokens: {len(prompt_tokens_for_gen)} / original {len(prompt_tokens)}")
print(f"Done. Generated {len(continuation_outputs)} continuation file(s).")


## 7) Verify output and (optional) render to audio preview

In [None]:
from pathlib import Path

# Collect every MIDI written by the from-scratch (out*.mid) and continuation
# (cont_*.mid) cells in the current working directory.
outputs = sorted(Path('.').glob('out*.mid')) + sorted(Path('.').glob('cont_*.mid'))
assert outputs, "No MIDI outputs were created"
sizes = {out_path: out_path.stat().st_size for out_path in outputs}
for out_path, size in sizes.items():
    assert size > 0, f"{out_path} is empty"
# U+2705 (check mark); the previous literal was this emoji's UTF-8 bytes
# mis-decoded into mojibake ("âœ…").
print("\u2705 Generated files:")
for out_path, size in sizes.items():
    print(" -", out_path.resolve(), "size:", size, "bytes")


In [None]:
# Optional audio preview if dependencies are available.
# If synthesis backends are unavailable in Colab, this cell may be skipped.

!pip install pretty_midi midi2audio

from pathlib import Path
import pretty_midi
from IPython.display import Audio, display

midi_files = sorted(Path('.').glob('out*.mid')) + sorted(Path('.').glob('cont_*.mid'))
assert midi_files, "No out*.mid/cont_*.mid files found. Run generation cell first."

for midi_path in midi_files:
    print(f"Rendering: {midi_path}")
    midi = pretty_midi.PrettyMIDI(str(midi_path))
    # Attempt software synthesis (requires fluidsynth backend in runtime)
    audio = midi.synthesize(fs=16000)
    display(Audio(audio, rate=16000))


## Notes on checkpoint compatibility
`MusicLlama.build()` in this fork was updated to accept multiple checkpoint layouts, and this notebook's own loader (`_normalize_checkpoint_state_dict`) accepts the same ones:
- `{"model_state_dict": ...}`
- `{"state_dict": ...}`
- `{"model": ...}`
- or a raw state dict

It also strips `module.` prefixes when present.