# Chord → Melody Mini‑Pipeline

**Assumptions**
* You have pre‑tokenised chord / melody pairs stored as `.npz` files under `cached_tokens/`, each containing:
  * `arr_0` – chord token sequence (1‑D uint16)
  * `arr_1` – melody token sequence (1‑D uint16)

If `cached_tokens/` is empty the notebook will **fabricate a tiny dummy dataset** so you can still run end‑to‑end, but training quality will be useless.  Replace with real data for meaningful results.

## 0  Install required libraries

In [26]:
!pip install --quiet torch tqdm numpy pretty_midi
!pip install pyfluidsynth



## 1  Imports, constants, helpers

In [2]:
from pathlib import Path
import numpy as np, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pretty_midi as pm

# ------------ constants ------------
PAD     = 0                 # also used as <bos>/<eos> for simplicity
BATCH   = 64
EPOCHS  = 5
D_MODEL = 256

CACHE_DIR = Path('cached_tokens')
CACHE_DIR.mkdir(exist_ok=True)   # ensure folder exists
files = sorted(CACHE_DIR.glob('*.npz'))

# ------------ fabricate dummy data if none present ------------
if not files:
    print('cached_tokens is empty – creating 512 fake examples for demo…')
    V = 60       # token vocab size (C majors etc.)
    for i in range(512):
        L = np.random.randint(16, 64)
        chords  = np.random.randint(1, V//2,  size=L,  dtype=np.uint16)
        melody  = np.random.randint(V//2, V,  size=L,  dtype=np.uint16)
        np.savez_compressed(CACHE_DIR/f'dummy_{i}.npz', chords, melody)
    files = sorted(CACHE_DIR.glob('*.npz'))

# derive vocab size from first file
with np.load(files[0]) as z:
    VOCAB = int(max(z['arr_0'].max(), z['arr_1'].max())) + 1
print(f'{len(files)} npz pairs found   vocab={VOCAB}')

2301 npz pairs found   vocab=196


## 2  Dataset & DataLoader

In [21]:
class Tokset(Dataset):
    def __init__(self, file_list): self.fs = file_list
    def __len__(self): return len(self.fs)
    def __getitem__(self, idx):
        with np.load(self.fs[idx]) as d:
            x = torch.from_numpy(d['arr_0'].astype(np.int32))   # ← cast
            y = torch.from_numpy(d['arr_1'].astype(np.int32))
        return x, y


# keep only first 512 frames – good for a baseline
MAXLEN = 256
BATCH = 16

def collate(batch):
    xs, ys = zip(*batch)
    xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=PAD)[:, :MAXLEN]
    ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=PAD)[:, :MAXLEN]
    return xs, ys[:, :-1], ys[:, 1:]


loader = DataLoader(train_ds, batch_size=32, num_workers=4, collate_fn=collate)


train_ds = Tokset(files)
loader   = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                      collate_fn=collate, num_workers=0)
print(f'Dataloader ready  batches per epoch = {len(loader)}')

Dataloader ready  batches per epoch = 144


## 3  Tiny Transformer model

In [22]:
class ChordMelody(nn.Module):
    def __init__(self, vocab, d=D_MODEL):
        super().__init__()
        self.enc = nn.Embedding(vocab, d, padding_idx=PAD)
        self.dec = nn.Embedding(vocab, d, padding_idx=PAD)
        self.tf  = nn.Transformer(d_model=d, nhead=4,
                                  num_encoder_layers=3, num_decoder_layers=3,
                                  batch_first=True)
        self.out = nn.Linear(d, vocab)
    def forward(self, src, tgt):
        src_mask = src.eq(PAD); tgt_mask = tgt.eq(PAD)
        y = self.tf(self.enc(src), self.dec(tgt),
                    src_key_padding_mask=src_mask,
                    tgt_key_padding_mask=tgt_mask)
        return self.out(y)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model  = ChordMelody(VOCAB).to(device)
optim  = torch.optim.AdamW(model.parameters(), lr=3e-4)
lossF  = nn.CrossEntropyLoss(ignore_index=PAD)

next(iter(loader))             # time this — should be <1 s
print(src.shape)               # from training loop


torch.Size([64, 256])


## 4  Train

In [23]:
for ep in range(EPOCHS):
    model.train()
    pbar = tqdm(loader, desc=f'Epoch {ep+1}/{EPOCHS}')
    for src, tgt_in, tgt_out in tqdm(loader):
        src, tgt_in, tgt_out = [t.to(device).long()   # ← cast to int64
                                for t in (src, tgt_in, tgt_out)]

        optim.zero_grad()
        logits = model(src, tgt_in)
        loss   = lossF(
            logits.reshape(-1, logits.size(-1)),
            tgt_out.reshape(-1)
        )
        loss.backward(); optim.step()
        pbar.set_postfix(loss=f"{loss.item():.2f}")


Epoch 1/5:   0%|          | 0/144 [00:26<?, ?it/s]
100%|██████████| 144/144 [05:33<00:00,  2.32s/it], loss=1.62]
Epoch 1/5:   0%|          | 0/144 [05:33<?, ?it/s, loss=1.62]
100%|██████████| 144/144 [05:52<00:00,  2.45s/it]
Epoch 2/5:   0%|          | 0/144 [05:52<?, ?it/s, loss=1.61]
100%|██████████| 144/144 [05:43<00:00,  2.39s/it], loss=1.85]
Epoch 3/5:   0%|          | 0/144 [05:43<?, ?it/s, loss=1.85]
100%|██████████| 144/144 [05:39<00:00,  2.36s/it]
Epoch 4/5:   0%|          | 0/144 [05:39<?, ?it/s, loss=1.44]
100%|██████████| 144/144 [05:47<00:00,  2.41s/it], loss=2.18]


## 5  Generate a melody for one chord sequence

In [28]:
!pip install --quiet pyfluidsynth

In [32]:

import pretty_midi as pm, torch, numpy as np
from IPython.display import FileLink

# --- 1. Generate melody tokens (greedy) -----------------------
model.eval()
src, _ = train_ds[0]                        # chord sequence only
src = src.unsqueeze(0).to(device)
tgt = torch.tensor([[PAD]], device=device)  # <bos>

with torch.no_grad():
    for _ in range(src.size(1)):            # limit to chord length
        logits = model(src, tgt)
        next_tok = logits[:, -1].argmax(-1, keepdim=True)
        tgt = torch.cat([tgt, next_tok], 1)
        if next_tok.item() == PAD: break

mel_tokens = tgt.squeeze(0).cpu().tolist()[1:-1]

# --- 2. Decode tokens → PrettyMIDI ----------------------------
hop = 0.5
midi = pm.PrettyMIDI()
inst = pm.Instrument(program=0, name="melody")
time = 0.0
for tok in mel_tokens:
    if tok != PAD:
        inst.notes.append(pm.Note(velocity=90, pitch=int(tok),
                                  start=time, end=time + hop * 0.9))
    time += hop
midi.instruments.append(inst)
midi_path = "generated_melody.mid"
midi.write(midi_path)

# --- 3. Provide a download link -------------------------------
print("✓ wrote", midi_path)
FileLink(midi_path)


✓ wrote generated_melody.mid
