## Attempt at "vibe-coding" this project

In [None]:
# =========================================
# 0) Setup (install + imports + GPU check)
# =========================================
!pip -q install miditoolkit tqdm

import os, glob, zipfile, random, math
from collections import defaultdict
from dataclasses import dataclass

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import miditoolkit

# Pick the compute device first (GPU when available), then report the runtime.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("torch:", torch.__version__)
print("cuda available?", torch.cuda.is_available())
print("device:", device)

# =========================================
# 1) Upload ZIP of MIDI files -> midi_paths
# =========================================
from google.colab import files

uploaded = files.upload()  # upload your zip of 10 MIDIs (no-metronome)
if not uploaded:
    # next(iter(...)) below would otherwise fail with a bare StopIteration
    raise RuntimeError("No file uploaded - please upload a .zip of MIDI files.")

zip_name = next(iter(uploaded.keys()))
DATA_DIR = "/content/midis"
os.makedirs(DATA_DIR, exist_ok=True)

with zipfile.ZipFile(zip_name, "r") as z:
    z.extractall(DATA_DIR)


def _is_midi_file(path):
    """Return True for regular .mid/.midi files (any letter case).

    macOS archive metadata (anything under ``__MACOSX`` or named ``._*``) is
    skipped defensively, since those entries are not real MIDI files.
    """
    name = os.path.basename(path)
    if name.startswith("._") or "__MACOSX" in path.split(os.sep):
        return False
    return os.path.isfile(path) and name.lower().endswith((".mid", ".midi"))


# One recursive scan, filtered case-insensitively so ".MID"/".MIDI" are found too;
# the result is sorted by full path.
midi_paths = sorted(
    p for p in glob.glob(os.path.join(DATA_DIR, "**", "*"), recursive=True) if _is_midi_file(p)
)

print("Found MIDI files:", len(midi_paths))
print("Example paths:", midi_paths[:3])
if not midi_paths:
    raise RuntimeError(f"No .mid/.midi files found under {DATA_DIR}")

# =========================================
# 2) Quick scan: tracks/drums/polyphony stats
# =========================================
def max_distinct_pitches_same_start(m):
    """Return the largest number of distinct pitches that share one onset tick.

    Only the first instrument (``m.instruments[0]``) is inspected. Returns 0
    when that instrument has no notes.
    """
    pitches_at_onset = defaultdict(set)
    for note in m.instruments[0].notes:
        pitches_at_onset[note.start].add(note.pitch)
    if not pitches_at_onset:
        return 0
    return max(map(len, pitches_at_onset.values()))

def max_simultaneous_notes(notes):
    """Return the peak number of notes sounding at the same time.

    Sweeps sorted (time, +1 onset / -1 release) events. At equal times a
    release (-1) sorts before an onset (+1), so a note ending exactly where
    another starts is not counted as an overlap. Returns 0 for no notes.
    """
    timeline = sorted(
        [(note.start, 1) for note in notes] + [(note.end, -1) for note in notes]
    )
    active = peak = 0
    for _, delta in timeline:
        active += delta
        if active > peak:
            peak = active
    return peak

# One row of polyphony statistics per file (stats use the first track only).
print("\nname | tracks | drums | notes | same_start | max_poly")
for p in midi_paths:
    m = miditoolkit.MidiFile(p)
    inst = m.instruments[0]
    print(
        f"{os.path.basename(p)[:28]:28}"
        f" | {len(m.instruments):6d}"
        f" | {str(any(i.is_drum for i in m.instruments)):5s} "
        f"| {len(inst.notes):5d}"
        f" | {max_distinct_pitches_same_start(m):10d}"
        f" | {max_simultaneous_notes(inst.notes):8d}"
    )

# =========================================
# 3) Clean to monophonic + tokenize (BAR + POS + REST + NOTE + DUR)
# =========================================
STEPS_PER_BEAT = 12
BAR_STEPS = 4 * STEPS_PER_BEAT  # assume 4/4
MAX_DUR = 48
MAX_REST = 48


def ticks_to_steps(ticks, tpb, spb=STEPS_PER_BEAT):
    """Convert a MIDI tick position to the nearest grid step.

    ``tpb`` is the file's ticks per beat; ``spb`` is the grid resolution in
    steps per beat.
    """
    return int(round((ticks / tpb) * spb))


def clean_to_monophonic(notes, tpb, spb=STEPS_PER_BEAT):
    """Reduce (possibly polyphonic) notes to a single non-overlapping line.

    1) quantize start/end to the step grid (zero-length notes become 1 step)
    2) keep one note per onset step: longest duration, ties -> highest pitch
    3) trim each note so it ends no later than the next kept onset

    Notes need ``.start``, ``.end`` (ticks) and ``.pitch`` attributes.
    Returns a list of ``(start_step, end_step, pitch)`` sorted by start.
    """
    if not notes:
        return []

    # 1) quantize and group by onset step
    by_start = defaultdict(list)
    for n in notes:
        s = ticks_to_steps(n.start, tpb, spb)
        e = max(ticks_to_steps(n.end, tpb, spb), s + 1)
        by_start[s].append((s, e, int(n.pitch)))

    # 2) one survivor per onset
    kept = [max(group, key=lambda x: (x[1] - x[0], x[2])) for group in by_start.values()]
    kept.sort(key=lambda x: (x[0], x[2]))

    # 3) cut the previous note at the next onset so the line never overlaps
    mono = []
    for s, e, p in kept:
        if mono and s < mono[-1][1]:
            ps, _, pp = mono[-1]
            mono[-1] = (ps, s, pp)
            if s <= ps:  # defensive: trimmed note would be empty
                mono.pop()
        mono.append((s, e, p))

    return [x for x in mono if x[1] > x[0]]


def tokens_from_mono_with_pos(mono, max_rest=MAX_REST, max_dur=MAX_DUR, add_bar=True):
    """Serialize a monophonic line into BAR / REST / POS / NOTE / DUR tokens.

    Each note becomes ``[REST_k ...] POS_<step within beat> NOTE_<pitch>
    DUR_<steps> [DUR_<steps> ...]``. Gaps are emitted as REST chunks of at
    most ``max_rest`` steps, and durations as DUR chunks of at most ``max_dur``
    steps (a longer value becomes several tokens that sum to it).

    With ``add_bar``, a ``BAR`` token is emitted whenever the time cursor sits
    on a bar line (a multiple of ``BAR_STEPS``) at the start of a rest chunk
    or at a note onset. Rests are split at bar lines so that no bar line
    reached between notes is skipped, and an onset BAR is placed after the
    rest that reaches it. Bar lines falling inside a sounding note are not
    marked.

    ``mono`` must be sorted, non-overlapping ``(start_step, end_step, pitch)``
    tuples, as returned by ``clean_to_monophonic``.
    """
    def emit(out, prefix, v, cap):
        # split v into cap-sized chunks followed by the remainder
        while v > cap:
            out.append(f"{prefix}_{cap}")
            v -= cap
        out.append(f"{prefix}_{v}")

    toks = []
    cur = 0
    for s, e, p in mono:
        # advance time to the onset; with bars, never let a rest cross a bar line
        while cur < s:
            step = min(s - cur, max_rest)
            if add_bar:
                if cur % BAR_STEPS == 0:
                    toks.append("BAR")
                step = min(step, BAR_STEPS - cur % BAR_STEPS)
            toks.append(f"REST_{step}")
            cur += step

        # bar line exactly at the onset (emitted after the rests that reach it)
        if add_bar and s % BAR_STEPS == 0:
            toks.append("BAR")

        # beat-position anchor
        toks.append(f"POS_{s % STEPS_PER_BEAT}")
        toks.append(f"NOTE_{p}")
        emit(toks, "DUR", max(1, e - s), max_dur)
        cur = e

    return toks

# Build per-file tokens
# Quantize each file's first track to a monophonic line, then tokenize it with BAR markers.
per_file_tokens, per_file_mono = [], []
for p in midi_paths:
    m = miditoolkit.MidiFile(p)
    inst = m.instruments[0]
    mono = clean_to_monophonic(inst.notes, m.ticks_per_beat, spb=STEPS_PER_BEAT)
    per_file_mono.append(mono)
    toks = tokens_from_mono_with_pos(mono, add_bar=True)
    per_file_tokens.append(toks)

print("\nToken stats:")
print("Files:", len(per_file_tokens))
print("Total tokens:", sum(map(len, per_file_tokens)))
print("Example first 60 tokens:", per_file_tokens[0][:60])

# =========================================
# 4) Build vocab + encode
# =========================================
PAD, BOS, EOS = "<PAD>", "<BOS>", "<EOS>"

# Vocabulary layout (list order defines the ids): specials + BAR, beat
# positions, all 128 MIDI pitches, then DUR and REST lengths 1..MAX.
vocab = [
    PAD, BOS, EOS, "BAR",
    *(f"POS_{i}" for i in range(STEPS_PER_BEAT)),
    *(f"NOTE_{p}" for p in range(128)),
    *(f"DUR_{d}" for d in range(1, MAX_DUR + 1)),
    *(f"REST_{r}" for r in range(1, MAX_REST + 1)),
]

stoi = {tok: idx for idx, tok in enumerate(vocab)}
itos = dict(enumerate(vocab))
vocab_size = len(vocab)
print("\nVocab size:", vocab_size)


def encode(tokens):
    """Map token strings to ids (raises KeyError on out-of-vocabulary tokens)."""
    return [stoi[tok] for tok in tokens]


def decode(ids):
    """Map ids back to token strings."""
    return [itos[i] for i in ids]


encoded_files = list(map(encode, per_file_tokens))

# =========================================
# 5) Train/Val split by FILE (no leakage)
# =========================================
# Deterministic split by FILE so no window of a validation piece leaks into training.
rng = random.Random(42)
idxs = list(range(len(encoded_files)))
rng.shuffle(idxs)

val_n = max(1, len(idxs)//5)  # ~20% val (2 files if you have 10)
val_ids = sorted(idxs[:val_n])
val_set = set(val_ids)  # build the membership set once, not per element
train_ids = [i for i in idxs if i not in val_set]
if not train_ids:
    print("WARNING: no training files left after the split (need at least 2 files).")

print("\nTrain ids:", train_ids)
print("Val ids:", val_ids)
print("Train files:", [os.path.basename(midi_paths[i]) for i in train_ids])
print("Val files:", [os.path.basename(midi_paths[i]) for i in val_ids])

# =========================================
# 6) Window Dataset + (optional) transpose augmentation
# =========================================
# Window geometry for the training dataset.
block_size = 256
stride = 64

# Transposition bounds (MIDI pitch numbers) used by augmentation; narrow them
# to the pitch range actually observed in the data for safer transposition.
NOTE_MIN, NOTE_MAX = 39, 100

# NOTE_0..NOTE_127 are added to the vocab as one consecutive run of ids.
NOTE0, NOTE127 = stoi["NOTE_0"], stoi["NOTE_127"]


def is_note_id(tid):
    """True if token id ``tid`` is one of the NOTE_<pitch> tokens."""
    return NOTE0 <= tid <= NOTE127


def note_pitch_from_id(tid):
    """MIDI pitch encoded by NOTE token id ``tid``."""
    return tid - NOTE0


def note_id_from_pitch(p):
    """Token id of ``NOTE_<p>``."""
    return NOTE0 + p

class JazzWindowDataset(Dataset):
    """Next-token-prediction windows cut from per-file token-id sequences.

    Each file is framed once as ``[BOS] + tokens + [EOS]``, so BOS/EOS only
    appear at the true start/end of a piece. Windows of ``block_size + 2`` ids
    start every ``stride`` ids, plus one final window aligned to the end of
    the file so the tail (and EOS) is always covered. A window yields
    ``x = window[:-1]`` and ``y = window[1:]``, both of length
    ``block_size + 1``, so every target is the real next token. A file shorter
    than one window yields a single window right-padded with PAD, which the
    loss ignores via ``ignore_index``.

    Args:
        encoded_files: list of per-file token-id lists (without BOS/EOS).
        file_ids: indices into ``encoded_files`` to draw windows from.
        block_size: ``x``/``y`` length is ``block_size + 1``.
        stride: step between consecutive window starts.
        augment: if True, randomly transpose the NOTE tokens of each window
            within [NOTE_MIN, NOTE_MAX].
    """

    def __init__(self, encoded_files, file_ids, block_size=256, stride=64, augment=False):
        self.encoded_files = encoded_files
        self.file_ids = file_ids
        self.block_size = block_size
        self.augment = augment
        self.windows = []

        span = block_size + 2
        self._framed = {
            fid: [stoi[BOS]] + list(encoded_files[fid]) + [stoi[EOS]] for fid in file_ids
        }
        for fid in file_ids:
            last = max(0, len(self._framed[fid]) - span)
            starts = list(range(0, last + 1, stride))
            if starts[-1] != last:
                starts.append(last)  # tail window so the end of the piece is trained on
            self.windows.extend((fid, s) for s in starts)

        print(f"Built {len(self.windows)} windows from {len(file_ids)} files. augment={augment}")

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        fid, s = self.windows[idx]
        span = self.block_size + 2
        chunk = self._framed[fid][s:s + span]
        if len(chunk) < span:  # short file: pad; PAD targets are ignored by the loss
            chunk = chunk + [stoi[PAD]] * (span - len(chunk))

        # transpose augmentation: only affect NOTE tokens
        if self.augment:
            pitches = [note_pitch_from_id(t) for t in chunk if is_note_id(t)]
            if pitches:
                down = NOTE_MIN - min(pitches)  # smallest shift keeping lowest note >= NOTE_MIN
                up = NOTE_MAX - max(pitches)    # largest shift keeping highest note <= NOTE_MAX
                if down <= up:  # otherwise the window's range is too wide to fit
                    shift = random.randint(down, up)
                    if shift != 0:
                        chunk = [
                            note_id_from_pitch(note_pitch_from_id(t) + shift) if is_note_id(t) else t
                            for t in chunk
                        ]

        x = torch.tensor(chunk[:-1], dtype=torch.long)  # len = block_size+1
        y = torch.tensor(chunk[1:], dtype=torch.long)   # len = block_size+1
        return x, y

train_ds = JazzWindowDataset(encoded_files, train_ids, block_size=block_size, stride=stride, augment=True)
val_ds   = JazzWindowDataset(encoded_files, val_ids,   block_size=block_size, stride=stride, augment=False)

batch_size = 64 if device == "cuda" else 16
# drop_last would discard *every* batch if there are fewer windows than batch_size
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          drop_last=len(train_ds) >= batch_size)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, drop_last=False)
if len(train_ds) == 0 or len(val_ds) == 0:
    print("WARNING: empty train or val dataset - check the split and block_size.")

# =========================================
# 7) GPT-style Transformer (decoder-only via causal mask)
# =========================================
class GPTMini(nn.Module):
    """Decoder-only (GPT-style) language model over the token vocabulary.

    Stacks ``nn.TransformerEncoder`` layers and passes a causal mask, so
    position t only attends to positions <= t. Learned absolute position
    embeddings limit inputs to ``max_len`` tokens. Post-norm layers
    (``norm_first=False``) are followed by a final LayerNorm and a linear
    head to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6, dropout=0.15, max_len=block_size+2):
        super().__init__()
        self.max_len = max_len
        # NOTE: submodules are created in a fixed order (embeddings, encoder,
        # norm, head); keep it so parameter init and state_dict keys match.
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)

        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            norm_first=False,  # original note: avoids that nested-tensor warning sometimes
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(layer_template, num_layers=n_layers)
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        """Map a (B, T) id tensor to (B, T, vocab_size) next-token logits."""
        _, seq_len = idx.shape
        if seq_len > self.max_len:
            raise ValueError(f"T={seq_len} > max_len={self.max_len}")

        positions = torch.arange(seq_len, device=idx.device)[None, :]
        hidden = self.drop(self.tok_emb(idx) + self.pos_emb(positions))

        # True strictly above the diagonal = future position, may not be attended
        future = torch.ones(seq_len, seq_len, device=idx.device).triu(diagonal=1).bool()
        hidden = self.encoder(hidden, mask=future)

        return self.head(self.ln(hidden))


model = GPTMini(vocab_size).to(device)
print("\nModel on:", next(model.parameters()).device)

# =========================================
# 8) Train loop + eval + checkpoint
# =========================================
def run_eval(model, loader):
    """Return the mean per-token cross-entropy of ``model`` over ``loader``.

    PAD targets are excluded from both the loss sum and the token count, so
    the result is the exact average over real tokens regardless of batch
    sizes or padding. Leaves the model in eval mode. Returns 0.0 for an
    empty loader.
    """
    model.eval()
    pad_id = stoi[PAD]
    ce_sum = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")
    total_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            total_loss += ce_sum(logits.reshape(-1, vocab_size), y.reshape(-1)).item()
            total_tokens += int((y != pad_id).sum().item())
    return total_loss / max(1, total_tokens)

# AdamW + gradient clipping; keep the checkpoint with the lowest validation loss.
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
ce = nn.CrossEntropyLoss(ignore_index=stoi[PAD])

os.makedirs("/content/ckpt", exist_ok=True)
best_val = float("inf")

EPOCHS = 50
for epoch in range(1, EPOCHS + 1):
    model.train()
    running = []

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        loss = ce(logits.reshape(-1, vocab_size), y.reshape(-1))

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        running.append(loss.item())

    val_loss = run_eval(model, val_loader)
    train_loss = float(np.mean(running)) if running else float("nan")
    print(f"epoch {epoch:03d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

    if val_loss < best_val:
        best_val = val_loss
        torch.save(
            {
                "model": model.state_dict(),
                "opt": opt.state_dict(),
                "epoch": epoch,
                "val_loss": val_loss,
                "stoi": stoi,
                "itos": itos,
                "config": {"STEPS_PER_BEAT": STEPS_PER_BEAT, "block_size": block_size},
            },
            "/content/ckpt/best.pt",
        )
        print("  saved best checkpoint")

# =========================================
# 9) Compact generation (grammar masks) + MIDI writer
# =========================================
def build_type_masks(vocab, stoi, device):
    """Build one boolean mask over the vocabulary per token family.

    Keys: "NOTE", "DUR", "REST", "POS" (prefix families), "EOS" and "BAR"
    (single tokens). PAD and BOS are cleared in every mask so they can
    never be sampled.
    """
    banned = torch.zeros(len(vocab), dtype=torch.bool, device=device)
    banned[stoi[PAD]] = True
    banned[stoi[BOS]] = True
    sampleable = ~banned

    def family(prefix):
        return torch.tensor([tok.startswith(prefix) for tok in vocab], device=device) & sampleable

    def single(tok):
        one = torch.zeros(len(vocab), dtype=torch.bool, device=device)
        one[stoi[tok]] = True
        return one & sampleable

    return {
        "NOTE": family("NOTE_"),
        "DUR": family("DUR_"),
        "REST": family("REST_"),
        "POS": family("POS_"),
        "EOS": single(EOS),
        "BAR": single("BAR"),
    }

masks = build_type_masks(vocab, stoi, device=device)

def topk_sample(logits, k=20, temperature=0.9):
    """Sample one index from a 1-D logits vector with temperature and top-k filtering.

    Args:
        logits: 1-D tensor of unnormalized scores (already-masked entries may
            hold large negative values such as -1e9).
        k: keep only the k highest-scoring entries. ``None``, ``k <= 0`` or
            ``k >= logits.numel()`` disables filtering.
        temperature: divides the logits; clamped to 1e-6, so 0 is ~greedy.

    Returns:
        The sampled index as a Python int.
    """
    logits = logits / max(temperature, 1e-6)
    if k is not None and 0 < k < logits.numel():
        kth_best = torch.topk(logits, k).values[-1]
        # entries tied with the k-th value survive too
        logits = logits.masked_fill(logits < kth_best, -1e9)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()

def generate_tokens(model, max_new_tokens=900, temperature=0.9, top_k=20):
    """Autoregressively sample tokens, constrained by the tokenizer's grammar.

    Grammar (based on how we tokenized):
      BOS  -> BAR or POS
      BAR  -> REST or POS
      REST -> REST or BAR or POS
      POS  -> NOTE
      NOTE -> DUR
      DUR  -> REST or BAR or POS or EOS
              (and DUR again after DUR_<MAX_DUR>, because durations longer
              than MAX_DUR are tokenized as chained DUR tokens)

    The model is fed at most the last ``block_size + 1`` ids. Sampling runs
    under ``torch.no_grad()`` and stops early at EOS. Returns the token
    strings, starting with BOS.
    """
    after_dur = masks["REST"] | masks["BAR"] | masks["POS"] | masks["EOS"]
    allowed_by_prev = {
        "BOS":  (masks["BAR"] | masks["POS"]),
        "BAR":  (masks["REST"] | masks["POS"]),
        "REST": (masks["REST"] | masks["BAR"] | masks["POS"]),
        "POS":  masks["NOTE"],
        "NOTE": masks["DUR"],
        "DUR":  after_dur,
        "DUR_MAX": after_dur | masks["DUR"],  # a capped DUR may be continued
    }
    fallback = masks["REST"] | masks["POS"] | masks["BAR"] | masks["EOS"]
    max_dur_tok = f"DUR_{MAX_DUR}"

    model.eval()
    ids = [stoi[BOS]]
    prev_type = "BOS"

    with torch.no_grad():  # inference only: don't build an autograd graph each step
        for _ in range(max_new_tokens):
            x = torch.tensor(ids[-(block_size+1):], device=device).unsqueeze(0)
            logits = model(x)[0, -1]  # (V,)

            allowed = allowed_by_prev.get(prev_type, fallback)
            logits = logits.masked_fill(~allowed, -1e9)

            nxt = topk_sample(logits, k=top_k, temperature=temperature)
            ids.append(nxt)

            tok = itos[nxt]
            if tok == EOS:
                break
            if tok == max_dur_tok:
                prev_type = "DUR_MAX"
            elif tok == "BAR":
                prev_type = "BAR"
            else:
                prev_type = tok.split("_")[0]  # POS / NOTE / DUR / REST

    return [itos[i] for i in ids]

def tokens_to_midi(tokens, out_path="generated.mid", tempo=140, steps_per_beat=STEPS_PER_BEAT):
    tpb = 480
    ticks_per_step = tpb // steps_per_beat

    midi = miditoolkit.MidiFile(ticks_per_beat=tpb)
    midi.tempo_changes = [miditoolkit.TempoChange(tempo, 0)]
    inst = miditoolkit.Instrument(program=56, is_drum=False, name="Trumpet")

    t = 0
    pending_pitch = None

    for tok in tokens:
        if tok in (PAD, BOS):
            continue
        if tok == EOS:
            break
        if tok == "BAR" or tok.startswith("POS_"):
            continue

        typ, val = tok.split("_")
        val = int(val)

        if typ == "REST":
            t += val
            pending_pitch = None
        elif typ == "NOTE":
            pending_pitch = val
        elif typ == "DUR" and pending_pitch is not None:
            start = t * ticks_per_step
            end


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25htorch: 2.9.0+cu126
cuda available? True
device: cuda


Saving 10 Jazz Etudes Midi.zip to 10 Jazz Etudes Midi.zip
Found MIDI files: 10
Example paths: ['/content/midis/10 Jazz Etudes Midi/AnyConv.com__Etude #1 - No Met- Slow.midi', '/content/midis/10 Jazz Etudes Midi/AnyConv.com__Etude #10 - No Met- Slow.midi', '/content/midis/10 Jazz Etudes Midi/AnyConv.com__Etude #2 - No Met- Slow.midi']

name | tracks | drums | notes | same_start | max_poly
AnyConv.com__Etude #1 - No M |      1 | False |  1388 |          9 |       12
AnyConv.com__Etude #10 - No  |      1 | False |  1246 |          7 |       11
AnyConv.com__Etude #2 - No M |      1 | False |  1649 |          8 |       13
AnyConv.com__Etude #3- No Me |      1 | False |  1433 |          8 |       14
AnyConv.com__Etude #4- No Me |      1 | False |  1312 |          8 |       12
AnyConv.com__Etude #5- No Me |      1 | False |  1157 |          7 |       13
AnyConv.com__Etude #6- No Me |      1 | False |  1284 |          8 |       12
AnyConv.com__Etude #7- No Me |      1 | False |  1404 |        

# Load Best Checkpoint


In [None]:
# Restore the best-validation weights written by the training loop and switch to eval mode.
ckpt = torch.load("/content/ckpt/best.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.to(device).eval()
print("Loaded best checkpoint from epoch:", ckpt["epoch"], "val_loss:", ckpt["val_loss"])


Loaded best checkpoint from epoch: 49 val_loss: 2.720655679702759


In [None]:
def tokens_to_midi(tokens, out_path="generated.mid", tempo=140, steps_per_beat=STEPS_PER_BEAT):
    """Render a token sequence to a single-track MIDI file and return ``out_path``.

    Time advances by REST_k and by each note's DUR tokens. BAR and POS tokens
    are positional hints only and are skipped. A DUR that directly follows a
    DUR (how the tokenizer writes durations longer than MAX_DUR) extends the
    previous note instead of being dropped. A DUR with no NOTE before it is
    ignored. Processing stops at EOS.
    """
    tpb = 480

    def to_ticks(step):
        # round each absolute position: stays within half a tick even when
        # steps_per_beat does not divide tpb (floor division would drift)
        return int(round(step * tpb / steps_per_beat))

    midi = miditoolkit.MidiFile(ticks_per_beat=tpb)
    midi.tempo_changes = [miditoolkit.TempoChange(tempo, 0)]
    inst = miditoolkit.Instrument(program=56, is_drum=False, name="Trumpet")

    t = 0                 # current time, in grid steps
    pending_pitch = None  # pitch from NOTE_x waiting for its first DUR
    last_note = None      # note that an immediately following DUR extends

    for tok in tokens:
        if tok in (PAD, BOS):
            continue
        if tok == EOS:
            break
        if tok == "BAR":
            continue
        if tok.startswith("POS_"):
            last_note = None  # a POS always starts a new note
            continue

        typ, val = tok.split("_")
        val = int(val)

        if typ == "REST":
            t += val
            pending_pitch = None
            last_note = None
        elif typ == "NOTE":
            pending_pitch = val
            last_note = None
        elif typ == "DUR":
            if pending_pitch is not None:
                last_note = miditoolkit.Note(
                    velocity=90, pitch=pending_pitch, start=to_ticks(t), end=to_ticks(t + val)
                )
                inst.notes.append(last_note)
                t += val
                pending_pitch = None
            elif last_note is not None:
                # chained DUR: continue the long note and keep the clock in sync
                t += val
                last_note.end = to_ticks(t)

    midi.instruments.append(inst)
    midi.dump(out_path)
    return out_path


In [None]:
# (tag, temperature, top_k) sweeps, from conservative to more adventurous.
settings = [
    ("A", 0.75, 10),
    ("B", 0.85, 15),
    ("C", 0.95, 25),
]

for tag, temp, k in settings:
    gen = generate_tokens(model, max_new_tokens=1400, temperature=temp, top_k=k)
    path = tokens_to_midi(gen, out_path=f"gen_{tag}_t{temp}_k{k}.mid", tempo=140)
    print("wrote", path, "| bars:", gen.count("BAR"))


wrote gen_A_t0.75_k10.mid | bars: 0
wrote gen_B_t0.85_k15.mid | bars: 1
wrote gen_C_t0.95_k25.mid | bars: 1
