**Synphony**

Deep Learning Final Project - MSDS Spring Module 2 - 2025

Aditi Puttur & Emma Juan

# 1. Data Preprocessing

In [3]:
import pandas as pd
import numpy as np

import os
import json

from tqdm import tqdm

import re
import unicodedata

import warnings
warnings.filterwarnings("ignore")

from miditok import REMI, TokenizerConfig, TokSequence
from miditoolkit import MidiFile
from symusic import Score

os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from typing import Optional

## Loading the data

### LMD: Midi Files

In [None]:
# Open and read the JSON file
with open('/home/emmajuansalazar/deep-learning-project-MSDS/data/data/LMD/md5_to_paths.json', 'r') as file:
    md5_to_paths = json.load(file)

In [None]:
md5_to_paths

In [None]:
lmd_catalog = []

for dirpath, dirnames, filenames in os.walk('data/LMD/lmd_matched'):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.mid'):
            lmd_catalog.append(full_path)

In [None]:
lmd_catalog.sort()
lmd_catalog

In [None]:
len(lmd_catalog)

In [None]:
lmd_catalog_all = {'path': [],
                   'MSD_name': [],
                   'LMD_name': []}

lmd_catalog_all['path'] = lmd_catalog
lmd_catalog_all['MSD_name'] = [path.split('/')[-2] for path in lmd_catalog]
lmd_catalog_all['LMD_name'] = [path.split('/')[-1].split('.')[-2] for path in lmd_catalog]

lmd_df = pd.DataFrame(lmd_catalog_all)
lmd_df

In [None]:
lmd_df["MSD_name"].nunique()

### LMD-matched metadata (MillionSongDataset): The Metadata

In [None]:
import hdf5_getters

In [None]:
msd_catalog = []
titles = []
artists = []
releases = []
years = []

for dirpath, dirnames, filenames in tqdm(os.walk('data/LMD-matched-MSD')):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.h5'):

            # Append the path to the list
            msd_catalog.append(full_path)

            # Get the metadata
            h5 = hdf5_getters.open_h5_file_read(full_path)
            titles.append(hdf5_getters.get_title(h5))
            artists.append(hdf5_getters.get_artist_name(h5))
            releases.append(hdf5_getters.get_release(h5))
            years.append(hdf5_getters.get_year(h5))
            # danceability = hdf5_getters.get_danceability(h5)
            # get_energy = hdf5_getters.get_energy(h5)


In [None]:
msd_catalog

In [None]:
len(msd_catalog)

In [None]:
len(msd_catalog) == lmd_df["MSD_name"].nunique()

In [None]:
titles[:5]

In [None]:
artists[:5]

In [None]:
years[:5]

In [None]:
titles = [title.decode('utf-8') for title in titles]
artists = [artist.decode('utf-8') for artist in artists]

In [None]:
msd_catalog_all = {'path': [],
                   'MSD_name': [],
                   'title': [],
                   'artist': [],
                   'year': []}

msd_catalog_all['path'] = msd_catalog
msd_catalog_all['title'] = titles
msd_catalog_all['artist'] = artists
msd_catalog_all['year'] = years
msd_catalog_all['MSD_name'] = [path.split('/')[-1].split('.')[-2] for path in msd_catalog]

msd_df = pd.DataFrame(msd_catalog_all)
msd_df

In [None]:
msd_df.info()

### tagtraum: Adding Genre Tags

In [None]:
tagtraum = {'MSD_name': [],
            'genre': []}

with open("data/tagtraum/msd_tagtraum_cd2c.cls", "r") as file:
    lines = file.readlines()
    for line in lines:
        if not line.startswith('#'):
            track, genre = line.strip().split('\t')
            tagtraum['MSD_name'].append(track)
            tagtraum['genre'].append(genre)

In [None]:
tagtraum_df = pd.DataFrame(tagtraum)
tagtraum_df

In [None]:
tagtraum_df["genre"].unique()

## Creating our dataset: MIDI + Metadata + Genres

### Midi + Metadata

**Each track (MSD_name -> track_id) has one metadata file, and different MIDI files (LMD_name -> midi_id) associated with it.**

In [None]:
len(lmd_df), len(msd_df)

In [None]:
lmd_df["MSD_name"].nunique(), len(msd_df)

In [None]:
dataset = lmd_df.merge(msd_df, how="inner", on="MSD_name", suffixes=('_lmd', '_msd'))
dataset = dataset.rename(columns={"path_lmd": "midi_filepath",
                                  "path_msd": "metadata_filepath",
                                  "MSD_name": "track_id",
                                  "LMD_name": "midi_id"})
dataset = dataset[["track_id", "midi_id", "midi_filepath",
                   "title", "artist", "year"]]
dataset

In [None]:
grouped_dataset = dataset.groupby('track_id').first().reset_index()
grouped_dataset = grouped_dataset[['track_id', 'midi_id', 'midi_filepath']]
grouped_dataset = grouped_dataset.merge(
    dataset[
        ['track_id', "title", "artist", "year"]
    ].drop_duplicates(), on='track_id', how='left' )
grouped_dataset = grouped_dataset[["track_id", "midi_id", "midi_filepath",
                                   "title", "artist", "year"]]
grouped_dataset

### Adding the genre tags

In [None]:
dataset = dataset.merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
dataset = dataset.drop(columns=["MSD_name"])
dataset

In [None]:
grouped_dataset = grouped_dataset.merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
grouped_dataset = grouped_dataset.drop(columns=["MSD_name"])
grouped_dataset

## Sluggifying our parameters

In [None]:
genres = dataset["genre"].unique()
artists = dataset["artist"].unique()
years = dataset["year"].unique()

In [None]:
def slug(text: str) -> str:
    """Return an ALL_CAPS alnum/underscore version of `text`."""
    # 1) strip accents → ascii
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode()
    # 2) replace non‑alnum with underscore
    text = re.sub(r"[^\w]+", "_", text)
    # 3) collapse multiple underscores and upper‑case
    return re.sub(r"_+", "_", text).strip("_").upper()

In [None]:
genres_slugged = np.array([slug(genre) for genre in genres])
artists_slugged = np.array([slug(artist) for artist in artists])
years = np.array([int(year) for year in years if not pd.isna(year)])

In [None]:
genres = pd.DataFrame({
    'genre': genres,
    'slugged_genre': genres_slugged
})

artists = pd.DataFrame({
    'artist': artists,
    'slugged_artist': artists_slugged
})

years = pd.DataFrame({
    'year': years
})

In [None]:
genres = genres.sort_values(by='genre')
artists = artists.sort_values(by='artist')
years = years.sort_values(by='year')

In [None]:
dataset["slugged_genre"] = dataset["genre"].map(genres.set_index('genre')['slugged_genre'])
dataset["slugged_artist"] = dataset["artist"].map(artists.set_index('artist')['slugged_artist'])

grouped_dataset["slugged_genre"] = grouped_dataset["genre"].map(genres.set_index('genre')['slugged_genre'])
grouped_dataset["slugged_artist"] = grouped_dataset["artist"].map(artists.set_index('artist')['slugged_artist'])

## Saving our data

### Saving the metadata datasets

In [None]:
dataset.to_csv("data/metadata.csv", index=False)

In [None]:
grouped_dataset.to_csv("data/grouped_metadata.csv", index=False)

### Saving the different parameters to csvs

In [None]:
genres.to_csv("data/genres.csv", index=False)
artists.to_csv("data/artists.csv", index=False)
years.to_csv("data/years.csv", index=False)

# 2. Model Implementation

In [20]:
dataset = pd.read_csv("/home/emmajuansalazar/deep-learning-project-MSDS/data/data/metadata.csv")
grouped_dataset = pd.read_csv("/home/emmajuansalazar/deep-learning-project-MSDS/data/data/grouped_metadata.csv")

genres = pd.read_csv("/home/emmajuansalazar/deep-learning-project-MSDS/data/data/genres.csv")
titles = pd.read_csv("/home/emmajuansalazar/deep-learning-project-MSDS/data/data/titles.csv")
artists = pd.read_csv("/home/emmajuansalazar/deep-learning-project-MSDS/data/data/artists.csv")
years = pd.read_csv("/home/emmajuansalazar/deep-learning-project-MSDS/data/data/years.csv")

In [21]:
genres_slugged = genres["slugged_genre"].values
artists_slugged = artists["slugged_artist"].values
years_vals = years["year"].values

In [22]:
# Config whith which the model was trained
# MAX_TOKENS = 128
# BATCH_SIZE = 1

# D_MODEL    = 128
# N_LAYERS   = 1
# N_HEADS    = 1

# New config to try
MAX_TOKENS = 512
BATCH_SIZE = 2

D_MODEL = 512
N_LAYERS = 6
N_HEADS = 8

## Tokenization

### Defining the tokenizer

In [23]:
# config = TokenizerConfig(
#     pitch_range=(21, 108),           # A0–C8
#     beat_res={(0, 4): 8, (4, 8): 4}, # finer grid in 1st half‑bar
#     num_velocities=32,
#     use_rests=True,
#     rest_range=(2, 8),               # long rests allowed
#     use_tempos=True,
#     use_chords=False,
#     use_time_signatures=False,
#     # you can still add / remove special tokens later with
#     # tokenizer.add_to_vocab([...])
# )
config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)

tokenizer = REMI(config)

### Adding our special tokens

In [24]:
special_toks = \
    [f"<GENRE_{g}>"  for g in genres_slugged] + \
        [f"<ARTIST_{a}>" for a in artists_slugged] + \
            [f"<YEAR_{y}>"   for y in years_vals]  + \
                ["<EOS>", "<PAD>"]

for tok in special_toks:
    tokenizer.add_to_vocab(tok)

### Tokenizing: Storing each track as a numpy int32 array.

In [25]:
tokenizing = False

In [26]:
# ─── 1. Helpers ──────────────────────────────────────────────────────────
def build_prefix(genre, artist, year, tokenizer):
    """Convert metadata row → list[int] conditioning tokens."""
    genre_tok  = f"<GENRE_{genre}>"
    artist_tok = f"<ARTIST_{artist}>"
    year_tok   = f"<YEAR_{year}>"

    # NOTE: use tokenizer.vocab[...]  (or .token_to_id(...))
    return [
        tokenizer.vocab[genre_tok],
        tokenizer.vocab[artist_tok],
        tokenizer.vocab[year_tok],
    ]

# ─── 3. Output directory -------------------------------------------------
out_dir = "data/tokens/train"

# ─── 4. Iterate files ----------------------------------------------------
if tokenizing:
    rows, _ = grouped_dataset.shape
    for row in tqdm(range(1000)):
        try:
            # 4.0. Get row
            row = grouped_dataset.iloc[row]

            # 4.1. Get MIDI filepath
            midi_path = row["midi_filepath"]

            # 4.2. Get the track ID
            track_id = row["track_id"]

            # 4a. Build CONDITIONING prefix
            genre = row["slugged_genre"]
            artist = row["slugged_artist"]
            year = row["year"]
            prefix_ids = build_prefix(genre, artist, year, tokenizer)          # list[int]

            # 4b. Encode MIDI to tokens
            midi = Score(midi_path)
            midi_tokens = tokenizer(midi)                 # list[int]

            # 4c. Concatenate prefix + midi + <EOS>
            seq_ids = prefix_ids + midi_tokens.ids + [tokenizer.vocab["<EOS>"]]

            # 4d. Save as int32 .npy
            np.save(f"{out_dir}/{track_id}.npy", np.array(seq_ids, dtype=np.int32))
        except Exception as e:
            print(f"Error processing {midi_path}: {e}")
            continue

## The Model

In [27]:
class RelativePositionalEncoding(nn.Module):
    """
    Sinusoidal *relative‑style* positional encoding.
    The tensor it returns has the same shape as `x`
    so you can just add it:  x + pos(x)

    Args
    ----
    d_model : int            # embedding size
    max_len : int, optional  # maximum sequence length
    """
    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Create the (max_len, d_model) sinusoid table once
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)          # (L, D)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register as a buffer so it moves with .to(device)
        self.register_buffer("pe", pe)              # (L, D)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, shape (batch, seq_len, d_model)

        Returns
        -------
        pos : Tensor, same shape as `x`
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len {self.max_len}")
        # (1, L, D) – broadcast over batch dimension
        return self.pe[:seq_len].unsqueeze(0)


In [28]:
class TransformerDecoderBlock(nn.Module):
    """
    Decoder block that merges causal + pad masking into a (B×H, L, L) float mask,
    so no hidden bool→float blow-ups occur.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_len: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim   = d_model,
            num_heads   = n_heads,
            dropout     = dropout,
            batch_first = True,
        )
        self.ln1      = nn.LayerNorm(d_model)
        self.ln2      = nn.LayerNorm(d_model)
        self.ff       = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.dropout  = nn.Dropout(dropout)

        # Precompute float causal mask: 0 on/under diag, -inf above
        causal = torch.triu(
            torch.full((max_len, max_len), float("-inf")),
            diagonal=1
        )
        self.register_buffer("causal_mask", causal, persistent=False)

    def forward(
        self,
        x: torch.Tensor,            # (B, L, D)
        pad_mask: torch.Tensor=None  # (B, L), True=keep token, False=pad
    ) -> torch.Tensor:
        B, L, _ = x.shape
        H       = self.self_attn.num_heads
        device  = x.device
        dtype   = x.dtype

        # 1) slice the (L×L) causal mask
        causal = self.causal_mask[:L, :L]              # float32, (L, L)

        # 2) build a (B, L) float pad mask: 0 on tokens, -inf on pads
        if pad_mask is not None:
            pad_float = torch.zeros((B, L), device=device, dtype=dtype)
            pad_float = pad_float.masked_fill(~pad_mask, float("-inf"))
            # 3) expand pad_float to (B, L, L) and add causal
            #    pad_float.unsqueeze(1): (B, 1, L) → broadcast over src_len
            attn_batch = causal.unsqueeze(0) + pad_float.unsqueeze(1)  # (B, L, L)
        else:
            attn_batch = causal                               # (L, L)

        # 4) if we have a batch, repeat per-head to (B×H, L, L)
        if pad_mask is not None:
            # attn_batch: (B, L, L) → repeat each batch H times
            attn_mask = attn_batch.repeat_interleave(H, dim=0)  # (B*H, L, L)
        else:
            attn_mask = attn_batch   # 2D mask

        # 5) self-attention with ONLY attn_mask
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask
        )

        # 6) residual + norm + feed-forward + norm
        x = self.ln1(x + self.dropout(attn_out))
        x = self.ln2(x + self.dropout(self.ff(x)))
        return x


In [29]:
class Synphony(nn.Module):
    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = RelativePositionalEncoding(d_model, max_len=2048)
        self.blocks = nn.ModuleList([
            TransformerDecoderBlock(d_model, n_heads) for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        x = self.embed(x) + self.pos(x)
        for blk in self.blocks:
            x = blk(x, pad_mask)
        x = self.ln(x)
        return self.out(x)

## The Training Loop

In [34]:
from torch.utils.data import Dataset, DataLoader

import random
random.seed(42)  # For reproducibility

In [35]:
tok_paths = []

for dirpath, dirnames, filenames in os.walk('/home/emmajuansalazar/deep-learning-project-MSDS/data/data/tokens/train'):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.npy'):
            tok_paths.append(full_path)

In [37]:
split_index = int(len(tok_paths) * 0.8)  # 80% train, 20% test
random.shuffle(tok_paths)

train_paths = tok_paths[:split_index]
test_paths = tok_paths[split_index:]

In [38]:
# ─── 1. Dataset + collate ────────────────────────────────────────────────
class MidiTokenDataset(Dataset):
    def __init__(self, npy_paths):
        self.paths = npy_paths

    def __len__(self):               # number of songs in split
        return len(self.paths)

    def __getitem__(self, idx):      # returns 1‑D np.ndarray[int]
        return np.load(self.paths[idx]).astype(np.int64)

def collate_fn(batch, pad_id):
    B, L = len(batch), MAX_TOKENS
    x = torch.full((B, L), pad_id, dtype=torch.long)
    for i, seq in enumerate(batch):
        seq = torch.from_numpy(seq)
        if seq.numel() > L:
            start = torch.randint(0, seq.numel() - L + 1, (1,)).item()
            seq = seq[start : start + L]
        x[i, : seq.numel()] = seq
    pad_mask = ~x.eq(pad_id)
    return x, pad_mask


# ─── 2. DataLoaders ──────────────────────────────────────────────────────
PAD_ID = tokenizer.vocab['<PAD>']          # or use the ID you chose for <PAD>

train_ds = MidiTokenDataset(train_paths)
val_ds   = MidiTokenDataset(test_paths)

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)
val_loader   = DataLoader(
    val_ds,   batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)

# ─── 3. Model, optimiser, scheduler ─────────────────────────────────────
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = Synphony(
    vocab_size=len(tokenizer), d_model=D_MODEL,
    n_layers=N_LAYERS, n_heads=N_HEADS).to(device)

optim = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.OneCycleLR(
    optim, max_lr=1e-4,
    steps_per_epoch=len(train_loader), epochs=50
)


# ─── 4. Training loop ────────────────────────────────────────────────────
best_val_loss = float("inf")

for epoch in range(1, 51):                         # 50 epochs
    # ---- train ----------------------------------------------------------
    model.train()
    running_loss = 0.0

    for x, pad_mask in train_loader:          # pad_mask: (B, L)
        x, pad_mask = x.to(device), pad_mask.to(device)

        logits = model(x[:, :-1], pad_mask=pad_mask[:, :-1])

        loss   = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            x[:, 1:].reshape(-1),
            ignore_index=PAD_ID,
            label_smoothing=0.1
        )

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step(); sched.step()

        running_loss += loss.item()

    train_ppl = math.exp(running_loss / len(train_loader))

    # ---- validation -----------------------------------------------------
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for x, pad_mask in val_loader:             # pad_mask is (B, L)
            x, pad_mask = x.to(device), pad_mask.to(device)

            # exactly like in training
            logits  = model(x[:, :-1], pad_mask=pad_mask[:, :-1])
            val_loss += F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                x[:, 1:].reshape(-1),
                ignore_index=PAD_ID
            ).item()

    val_ppl = math.exp(val_loss / len(val_loader))
    print(f"val PPL {val_ppl:6.2f}")
    print(f"Epoch {epoch:02d} ▸ train PPL {train_ppl:6.2f} | val PPL {val_ppl:6.2f}")

    # ---- checkpoint -----------------------------------------------------
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "synphony_best.pt")
        print("  ✓ new best model saved")

print("Done!")


val PPL 217.33
Epoch 01 ▸ train PPL 712.98 | val PPL 217.33
  ✓ new best model saved
val PPL  59.08
Epoch 02 ▸ train PPL 164.65 | val PPL  59.08
  ✓ new best model saved
val PPL  27.55
Epoch 03 ▸ train PPL  68.55 | val PPL  27.55
  ✓ new best model saved
val PPL  18.54
Epoch 04 ▸ train PPL  43.19 | val PPL  18.54
  ✓ new best model saved
val PPL  15.08
Epoch 05 ▸ train PPL  34.58 | val PPL  15.08
  ✓ new best model saved
val PPL  13.51
Epoch 06 ▸ train PPL  30.45 | val PPL  13.51
  ✓ new best model saved
val PPL  12.04
Epoch 07 ▸ train PPL  27.86 | val PPL  12.04
  ✓ new best model saved
val PPL  11.63
Epoch 08 ▸ train PPL  26.18 | val PPL  11.63
  ✓ new best model saved
val PPL  10.65
Epoch 09 ▸ train PPL  24.46 | val PPL  10.65
  ✓ new best model saved
val PPL  10.41
Epoch 10 ▸ train PPL  23.24 | val PPL  10.41
  ✓ new best model saved
val PPL   9.30
Epoch 11 ▸ train PPL  21.85 | val PPL   9.30
  ✓ new best model saved
val PPL   8.54
Epoch 12 ▸ train PPL  19.81 | val PPL   8.54
  ✓ n

# 3. Model Inference

In [19]:
model.eval()

Synphony(
  (embed): Embedding(3429, 128)
  (pos): RelativePositionalEncoding()
  (blocks): ModuleList(
    (0): TransformerDecoderBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (out): Linear(in_features=128, out_features=3429, bias=True)
)

In [None]:
TEMPERATURE = 1.0
TOP_K = 8

# ─── 2. Helper for top-k filtering ───────────────────────────────────────
def top_k_logits(logits, k):
    v, _ = torch.topk(logits, k)
    threshold = v[-1]
    return torch.where(logits < threshold, torch.full_like(logits, -float("Inf")), logits)

# ─── 3. Autoregressive generation ────────────────────────────────────────
@torch.no_grad()
def generate(
        genre:str,
        artist:str,
        year:int,
        max_length:int = MAX_TOKENS
    ) -> list[int]:
    prefix = build_prefix(genre, artist, year, tokenizer)
    input_ids = torch.tensor([prefix], device=device)  # (1, P)
    pad_mask  = torch.ones_like(input_ids, dtype=torch.bool, device=device)

    for _ in tqdm(range(max_length - len(prefix))):
        logits = model(input_ids, pad_mask=pad_mask)
        next_logits = logits[0, -1, :]                  # (V,)
        next_logits = next_logits / TEMPERATURE
        next_logits = top_k_logits(next_logits, TOP_K)
        probs = F.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (1,)
        if next_id.item() == tokenizer.vocab["<EOS>"]:
            break

        # append and extend pad_mask
        input_ids = torch.cat([input_ids, next_id.unsqueeze(0)], dim=1)   # (1, L+1)
        pad_mask  = torch.ones_like(input_ids, dtype=torch.bool, device=device)

    return input_ids[0].tolist()

# ─── 4. Decode to MIDI & save ────────────────────────────────────────────
def tokens_to_midi(token_ids: list[int], out_path: str):
    """
    Drop the 3 metadata tokens + optional EOS, then decode the rest.
    """
    # 1) drop the first 3 prefix IDs (genre, artist, year)
    musical_ids = token_ids[3:]
    # 2) drop trailing <EOS> if present
    eos_id = tokenizer.vocab["<EOS>"]
    if len(musical_ids) > 0 and musical_ids[-1] == eos_id:
        musical_ids = musical_ids[:-1]

    # 3) decode only the musical tokens back to a PrettyMIDI
    pm = tokenizer(musical_ids)
    # 4) write out the .mid file
    pm.dump_midi(out_path)

In [6]:
# ─── 5. Run it! ───────────────────────────────────────────────────────────
# Example user inputs
genre_input  = "POP"
artist_input = "RICK_ASTLEY"
year_input   = 1987

gen_ids = generate(genre_input, artist_input, year_input)
out_file = "generated.mid"
tokens_to_midi(gen_ids, out_file)
print(f"🎹 Wrote MIDI to {out_file}")

NameError: name 'generate' is not defined