## Data preparation

In [6]:
import os
import torch
import miditoolkit
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import Dataset
import numpy as np
import random

from torch import optim
import torch.nn as nn
import math
from torch.utils.data import DataLoader

In [7]:
# --- Transpose chord label ---
def transpose_chord_label(chord, shift):
    import music21
    try:
        root, kind = chord.split(':')
        symbol = music21.harmony.ChordSymbol(root)
        transposed = symbol.transpose(shift)
        return f"{transposed.root().name}:{kind}"
    except:
        return chord

# --- Extract melody and chord with transposition ---
def process_song(folder, sequence_length=32, transpose_range=range(-3, 4)):
    song_id = os.path.basename(folder)
    midi_path = os.path.join(folder, f"{song_id}.mid")
    chord_path = os.path.join(folder, "chord_midi.txt")

    if not os.path.exists(midi_path) or not os.path.exists(chord_path):
        return [], [], []

    midi = miditoolkit.MidiFile(midi_path)
    ticks_per_beat = midi.ticks_per_beat
    melody_track = next((inst for inst in midi.instruments if not inst.is_drum), None)
    if melody_track is None or len(melody_track.notes) == 0:
        return [], [], []

    melody_track.notes.sort(key=lambda x: x.start)
    max_tick = midi.max_tick

    # Build base melody by beat
    melody_seq = []
    for i in range(0, max_tick, ticks_per_beat):
        notes = [n.pitch for n in melody_track.notes if i <= n.start < i + ticks_per_beat]
        melody_seq.append(notes[0] if notes else 0)

    # Load chord sequence
    with open(chord_path) as f:
        chord_seq = [line.strip() for line in f.readlines()]

    min_len = min(len(melody_seq), len(chord_seq))
    melody_seq = melody_seq[:min_len]
    chord_seq = chord_seq[:min_len]

    all_melody, all_chord, chord_vocab = [], [], set()
    for shift in transpose_range:
        melody_shifted = [(p + shift if p > 0 else 0) for p in melody_seq]
        chord_shifted = [transpose_chord_label(c, shift) for c in chord_seq]
        chord_vocab.update(chord_shifted)

        for i in range(0, min_len - sequence_length + 1, sequence_length):
            all_melody.append(torch.tensor(melody_shifted[i:i+sequence_length]))
            all_chord.append(chord_shifted[i:i+sequence_length])

    return all_melody, all_chord, chord_vocab


In [8]:

# --- Main loop ---
DATASET_DIR = "data/POP909"
OUTPUT_PATH = "data/processed_transposed_2.pt"
SEQUENCE_LENGTH = 64

melody_samples, chord_samples = [], []
chord_vocab = set()

for song_id in tqdm(range(1, 910)):
    folder = os.path.join(DATASET_DIR, f"{song_id:03d}")
    melody_seq, chord_seq, vocab = process_song(folder, SEQUENCE_LENGTH)
    melody_samples.extend(melody_seq)
    chord_samples.extend(chord_seq)
    chord_vocab.update(vocab)

# Build vocab
chord_list = sorted(chord_vocab)
chord2idx = {ch: i + 1 for i, ch in enumerate(chord_list)}  # 0 for unknown
idx2chord = {i: ch for ch, i in chord2idx.items()}

# Encode chords
encoded_chords = [
    torch.tensor([chord2idx.get(c, 0) for c in seq], dtype=torch.long)
    for seq in chord_samples
]

# Save
Path("data").mkdir(parents=True, exist_ok=True)
torch.save({
    'chord_sequences': torch.stack(encoded_chords),
    'melody_sequences': torch.stack(melody_samples),
    'chord2idx': chord2idx,
    'idx2chord': idx2chord,
}, OUTPUT_PATH)

print(f"\n Saved {len(encoded_chords)} samples to {OUTPUT_PATH}")


100%|██████████| 909/909 [01:12<00:00, 12.48it/s]



 Saved 10416 samples to data/processed_transposed_2.pt


In [9]:

class ChordMelodyDataset(Dataset):
    def __init__(self, data_path, split='train', val_ratio=0.1, seed=42):
        data = torch.load(data_path)
        self.chords = data['chord_sequences']  # [N, L]
        self.melodies = data['melody_sequences']  # [N, L]

        assert len(self.chords) == len(self.melodies)

        # Set split
        N = len(self.chords)
        indices = list(range(N))
        random.seed(seed)
        random.shuffle(indices)

        val_size = int(N * val_ratio)
        if split == 'train':
            split_ids = indices[val_size:]
        elif split == 'val':
            split_ids = indices[:val_size]
        else:
            raise ValueError("split must be 'train' or 'val'")

        self.chord = self.chords[split_ids]
        self.melody = self.melodies[split_ids]

    def __len__(self):
        return len(self.chord)

    def __getitem__(self, idx):
        return self.chord[idx], self.melody[idx]


## Model Architecture

In [10]:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class ChordMelodyTransformerV2(nn.Module):
    def __init__(self, vocab_size=129, chord_vocab_size=128, d_model=256, nhead=8,
                 num_encoder_layers=4, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1):
        super().__init__()

        self.chord_embedding = nn.Embedding(chord_vocab_size, d_model)
        self.melody_embedding = nn.Embedding(vocab_size, d_model)  # +1 for EOS
        self.pos_encoder = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.output_layer = nn.Linear(d_model, vocab_size)

    def generate_mask(self, size, device):
        mask = torch.triu(torch.ones(size, size), diagonal=1).to(device)
        mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
        return mask

    def forward(self, chord_seq, melody_input):
        # chord_seq: [B, L], melody_input: [B, L]
        device = chord_seq.device
        chord_emb = self.pos_encoder(self.chord_embedding(chord_seq))
        melody_emb = self.pos_encoder(self.melody_embedding(melody_input))

        memory = self.encoder(chord_emb)
        tgt_mask = self.generate_mask(melody_input.size(1), device)
        output = self.decoder(melody_emb, memory, tgt_mask=tgt_mask)

        return self.output_layer(output)  # [B, L, vocab_size]

    def generate(self, chord_seq, max_length=32, temperature=1.0, start_token=60, eos_token=128):
        self.eval()
        device = chord_seq.device
        memory = self.encoder(self.pos_encoder(self.chord_embedding(chord_seq)))

        melody = [start_token]
        for _ in range(max_length - 1):
            inp = torch.tensor([melody], dtype=torch.long, device=device)
            emb = self.pos_encoder(self.melody_embedding(inp))
            mask = self.generate_mask(inp.size(1), device)
            out = self.decoder(emb, memory, tgt_mask=mask)
            logits = self.output_layer(out[:, -1, :]) / temperature
            probs = torch.softmax(logits, dim=-1)
            next_note = torch.multinomial(probs, 1).item()
            if next_note == eos_token:
                break
            melody.append(next_note)

        return melody[1:]  # drop start token


## Model Training

In [11]:
def train():
    dataset_path = 'data/processed_transposed_2.pt'
    batch_size = 128
    num_epochs = 100
    lr = 2.5e-4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load dataset
    train_data = ChordMelodyDataset(dataset_path, split='train')
    val_data = ChordMelodyDataset(dataset_path, split='val')
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size)

    chord_vocab_size = train_data.chord.max().item() + 1
    model = ChordMelodyTransformerV2(vocab_size=129, chord_vocab_size=chord_vocab_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    best_val = float('inf')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for chords, melody in tqdm(train_loader, desc=f"[Train] Epoch {epoch+1}"):
            chords = chords.to(device)
            inp = melody[:, :-1].to(device)
            tgt = melody[:, 1:].to(device)
            out = model(chords, inp)
            loss = criterion(out.view(-1, 129), tgt.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Train Loss: {total_loss/len(train_loader):.4f}")

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for chords, melody in val_loader:
                chords = chords.to(device)
                inp = melody[:, :-1].to(device)
                tgt = melody[:, 1:].to(device)
                out = model(chords, inp)
                loss = criterion(out.view(-1, 129), tgt.reshape(-1))
                val_loss += loss.item()

        avg_val = val_loss / len(val_loader)
        print(f"Val Loss: {avg_val:.4f}")

        if avg_val < best_val:
            best_val = avg_val
            torch.save(model.state_dict(), 'models/best_model_2.pt')
            print(" Saved best model")



In [12]:
train()

[Train] Epoch 1: 100%|██████████| 74/74 [00:18<00:00,  4.05it/s]


Train Loss: 2.9903
Val Loss: 2.3171
 Saved best model


[Train] Epoch 2: 100%|██████████| 74/74 [00:17<00:00,  4.13it/s]


Train Loss: 2.2385
Val Loss: 2.1843
 Saved best model


[Train] Epoch 3: 100%|██████████| 74/74 [00:18<00:00,  4.04it/s]


Train Loss: 2.1555
Val Loss: 2.1348
 Saved best model


[Train] Epoch 4: 100%|██████████| 74/74 [00:19<00:00,  3.75it/s]


Train Loss: 2.1252
Val Loss: 2.1284
 Saved best model


[Train] Epoch 5: 100%|██████████| 74/74 [00:21<00:00,  3.52it/s]


Train Loss: 2.0950
Val Loss: 2.0836
 Saved best model


[Train] Epoch 6: 100%|██████████| 74/74 [00:21<00:00,  3.48it/s]


Train Loss: 2.0638
Val Loss: 2.0620
 Saved best model


[Train] Epoch 7: 100%|██████████| 74/74 [00:20<00:00,  3.57it/s]


Train Loss: 2.0447
Val Loss: 2.0500
 Saved best model


[Train] Epoch 8: 100%|██████████| 74/74 [00:20<00:00,  3.64it/s]


Train Loss: 2.0252
Val Loss: 2.0362
 Saved best model


[Train] Epoch 9: 100%|██████████| 74/74 [00:19<00:00,  3.73it/s]


Train Loss: 2.0068
Val Loss: 2.0288
 Saved best model


[Train] Epoch 10: 100%|██████████| 74/74 [00:22<00:00,  3.35it/s]


Train Loss: 1.9810
Val Loss: 2.0056
 Saved best model


[Train] Epoch 11: 100%|██████████| 74/74 [00:20<00:00,  3.58it/s]


Train Loss: 1.9537
Val Loss: 1.9840
 Saved best model


[Train] Epoch 12: 100%|██████████| 74/74 [00:21<00:00,  3.42it/s]


Train Loss: 1.9223
Val Loss: 1.9561
 Saved best model


[Train] Epoch 13: 100%|██████████| 74/74 [00:22<00:00,  3.29it/s]


Train Loss: 1.8908
Val Loss: 1.9550
 Saved best model


[Train] Epoch 14: 100%|██████████| 74/74 [00:23<00:00,  3.21it/s]


Train Loss: 1.8680
Val Loss: 1.9177
 Saved best model


[Train] Epoch 15: 100%|██████████| 74/74 [00:22<00:00,  3.25it/s]


Train Loss: 1.8379
Val Loss: 1.9105
 Saved best model


[Train] Epoch 16: 100%|██████████| 74/74 [00:22<00:00,  3.35it/s]


Train Loss: 1.8051
Val Loss: 1.8862
 Saved best model


[Train] Epoch 17: 100%|██████████| 74/74 [00:23<00:00,  3.14it/s]


Train Loss: 1.7724
Val Loss: 1.8467
 Saved best model


[Train] Epoch 18: 100%|██████████| 74/74 [00:23<00:00,  3.16it/s]


Train Loss: 1.7299
Val Loss: 1.8424
 Saved best model


[Train] Epoch 19: 100%|██████████| 74/74 [00:23<00:00,  3.09it/s]


Train Loss: 1.6802
Val Loss: 1.7803
 Saved best model


[Train] Epoch 20: 100%|██████████| 74/74 [00:23<00:00,  3.21it/s]


Train Loss: 1.6133
Val Loss: 1.7381
 Saved best model


[Train] Epoch 21: 100%|██████████| 74/74 [00:24<00:00,  2.97it/s]


Train Loss: 1.5367
Val Loss: 1.6438
 Saved best model


[Train] Epoch 22: 100%|██████████| 74/74 [00:24<00:00,  3.05it/s]


Train Loss: 1.4400
Val Loss: 1.5695
 Saved best model


[Train] Epoch 23: 100%|██████████| 74/74 [00:23<00:00,  3.11it/s]


Train Loss: 1.3480
Val Loss: 1.4573
 Saved best model


[Train] Epoch 24: 100%|██████████| 74/74 [00:23<00:00,  3.12it/s]


Train Loss: 1.2340
Val Loss: 1.3298
 Saved best model


[Train] Epoch 25: 100%|██████████| 74/74 [00:23<00:00,  3.12it/s]


Train Loss: 1.1193
Val Loss: 1.2289
 Saved best model


[Train] Epoch 26: 100%|██████████| 74/74 [00:24<00:00,  2.96it/s]


Train Loss: 1.0015
Val Loss: 1.0920
 Saved best model


[Train] Epoch 27: 100%|██████████| 74/74 [00:23<00:00,  3.18it/s]


Train Loss: 0.8975
Val Loss: 0.9803
 Saved best model


[Train] Epoch 28: 100%|██████████| 74/74 [00:23<00:00,  3.16it/s]


Train Loss: 0.7935
Val Loss: 0.8656
 Saved best model


[Train] Epoch 29: 100%|██████████| 74/74 [00:22<00:00,  3.25it/s]


Train Loss: 0.7025
Val Loss: 0.7672
 Saved best model


[Train] Epoch 30: 100%|██████████| 74/74 [00:23<00:00,  3.21it/s]


Train Loss: 0.6277
Val Loss: 0.6595
 Saved best model


[Train] Epoch 31: 100%|██████████| 74/74 [00:23<00:00,  3.15it/s]


Train Loss: 0.5513
Val Loss: 0.5891
 Saved best model


[Train] Epoch 32: 100%|██████████| 74/74 [00:23<00:00,  3.11it/s]


Train Loss: 0.4956
Val Loss: 0.4977
 Saved best model


[Train] Epoch 33: 100%|██████████| 74/74 [00:23<00:00,  3.22it/s]


Train Loss: 0.4349
Val Loss: 0.4463
 Saved best model


[Train] Epoch 34: 100%|██████████| 74/74 [00:22<00:00,  3.24it/s]


Train Loss: 0.3946
Val Loss: 0.3955
 Saved best model


[Train] Epoch 35: 100%|██████████| 74/74 [00:22<00:00,  3.26it/s]


Train Loss: 0.3516
Val Loss: 0.3574
 Saved best model


[Train] Epoch 36: 100%|██████████| 74/74 [00:23<00:00,  3.19it/s]


Train Loss: 0.3158
Val Loss: 0.3089
 Saved best model


[Train] Epoch 37: 100%|██████████| 74/74 [00:22<00:00,  3.36it/s]


Train Loss: 0.2894
Val Loss: 0.2821
 Saved best model


[Train] Epoch 38: 100%|██████████| 74/74 [00:23<00:00,  3.14it/s]


Train Loss: 0.2629
Val Loss: 0.2498
 Saved best model


[Train] Epoch 39: 100%|██████████| 74/74 [00:23<00:00,  3.21it/s]


Train Loss: 0.2418
Val Loss: 0.2376
 Saved best model


[Train] Epoch 40: 100%|██████████| 74/74 [00:23<00:00,  3.17it/s]


Train Loss: 0.2264
Val Loss: 0.2168
 Saved best model


[Train] Epoch 41: 100%|██████████| 74/74 [00:24<00:00,  3.00it/s]


Train Loss: 0.2084
Val Loss: 0.1921
 Saved best model


[Train] Epoch 42: 100%|██████████| 74/74 [00:24<00:00,  3.06it/s]


Train Loss: 0.1967
Val Loss: 0.1899
 Saved best model


[Train] Epoch 43: 100%|██████████| 74/74 [00:22<00:00,  3.29it/s]


Train Loss: 0.1841
Val Loss: 0.1722
 Saved best model


[Train] Epoch 44: 100%|██████████| 74/74 [00:20<00:00,  3.60it/s]


Train Loss: 0.1731
Val Loss: 0.1590
 Saved best model


[Train] Epoch 45: 100%|██████████| 74/74 [00:21<00:00,  3.50it/s]


Train Loss: 0.1617
Val Loss: 0.1540
 Saved best model


[Train] Epoch 46: 100%|██████████| 74/74 [00:20<00:00,  3.55it/s]


Train Loss: 0.1552
Val Loss: 0.1420
 Saved best model


[Train] Epoch 47: 100%|██████████| 74/74 [00:21<00:00,  3.44it/s]


Train Loss: 0.1463
Val Loss: 0.1374
 Saved best model


[Train] Epoch 48: 100%|██████████| 74/74 [00:22<00:00,  3.24it/s]


Train Loss: 0.1406
Val Loss: 0.1329
 Saved best model


[Train] Epoch 49: 100%|██████████| 74/74 [00:22<00:00,  3.34it/s]


Train Loss: 0.1333
Val Loss: 0.1247
 Saved best model


[Train] Epoch 50: 100%|██████████| 74/74 [00:21<00:00,  3.38it/s]


Train Loss: 0.1290
Val Loss: 0.1240
 Saved best model


[Train] Epoch 51: 100%|██████████| 74/74 [00:21<00:00,  3.42it/s]


Train Loss: 0.1253
Val Loss: 0.1190
 Saved best model


[Train] Epoch 52: 100%|██████████| 74/74 [00:21<00:00,  3.49it/s]


Train Loss: 0.1225
Val Loss: 0.1156
 Saved best model


[Train] Epoch 53: 100%|██████████| 74/74 [00:21<00:00,  3.39it/s]


Train Loss: 0.1168
Val Loss: 0.1108
 Saved best model


[Train] Epoch 54: 100%|██████████| 74/74 [00:21<00:00,  3.40it/s]


Train Loss: 0.1153
Val Loss: 0.1178


[Train] Epoch 55: 100%|██████████| 74/74 [00:21<00:00,  3.48it/s]


Train Loss: 0.1096
Val Loss: 0.1096
 Saved best model


[Train] Epoch 56: 100%|██████████| 74/74 [00:21<00:00,  3.38it/s]


Train Loss: 0.1043
Val Loss: 0.1038
 Saved best model


[Train] Epoch 57: 100%|██████████| 74/74 [00:21<00:00,  3.47it/s]


Train Loss: 0.1042
Val Loss: 0.1010
 Saved best model


[Train] Epoch 58: 100%|██████████| 74/74 [00:21<00:00,  3.47it/s]


Train Loss: 0.1049
Val Loss: 0.1022


[Train] Epoch 59: 100%|██████████| 74/74 [00:20<00:00,  3.62it/s]


Train Loss: 0.1030
Val Loss: 0.1023


[Train] Epoch 60: 100%|██████████| 74/74 [00:21<00:00,  3.41it/s]


Train Loss: 0.0984
Val Loss: 0.0978
 Saved best model


[Train] Epoch 61: 100%|██████████| 74/74 [00:22<00:00,  3.36it/s]


Train Loss: 0.0956
Val Loss: 0.0934
 Saved best model


[Train] Epoch 62: 100%|██████████| 74/74 [00:21<00:00,  3.44it/s]


Train Loss: 0.0944
Val Loss: 0.0948


[Train] Epoch 63: 100%|██████████| 74/74 [00:21<00:00,  3.47it/s]


Train Loss: 0.0937
Val Loss: 0.0943


[Train] Epoch 64: 100%|██████████| 74/74 [00:22<00:00,  3.36it/s]


Train Loss: 0.0907
Val Loss: 0.0949


[Train] Epoch 65: 100%|██████████| 74/74 [00:21<00:00,  3.43it/s]


Train Loss: 0.0904
Val Loss: 0.0932
 Saved best model


[Train] Epoch 66: 100%|██████████| 74/74 [00:21<00:00,  3.51it/s]


Train Loss: 0.0899
Val Loss: 0.0918
 Saved best model


[Train] Epoch 67: 100%|██████████| 74/74 [00:20<00:00,  3.54it/s]


Train Loss: 0.0879
Val Loss: 0.0927


[Train] Epoch 68: 100%|██████████| 74/74 [00:20<00:00,  3.53it/s]


Train Loss: 0.0870
Val Loss: 0.0906
 Saved best model


[Train] Epoch 69: 100%|██████████| 74/74 [00:20<00:00,  3.69it/s]


Train Loss: 0.0858
Val Loss: 0.0895
 Saved best model


[Train] Epoch 70: 100%|██████████| 74/74 [00:20<00:00,  3.55it/s]


Train Loss: 0.0828
Val Loss: 0.0887
 Saved best model


[Train] Epoch 71: 100%|██████████| 74/74 [00:22<00:00,  3.36it/s]


Train Loss: 0.0848
Val Loss: 0.0894


[Train] Epoch 72: 100%|██████████| 74/74 [00:21<00:00,  3.44it/s]


Train Loss: 0.0839
Val Loss: 0.0925


[Train] Epoch 73: 100%|██████████| 74/74 [00:21<00:00,  3.52it/s]


Train Loss: 0.0819
Val Loss: 0.0883
 Saved best model


[Train] Epoch 74: 100%|██████████| 74/74 [00:21<00:00,  3.45it/s]


Train Loss: 0.0807
Val Loss: 0.0861
 Saved best model


[Train] Epoch 75: 100%|██████████| 74/74 [00:20<00:00,  3.53it/s]


Train Loss: 0.0800
Val Loss: 0.0867


[Train] Epoch 76: 100%|██████████| 74/74 [00:22<00:00,  3.32it/s]


Train Loss: 0.0785
Val Loss: 0.0880


[Train] Epoch 77: 100%|██████████| 74/74 [00:21<00:00,  3.47it/s]


Train Loss: 0.0787
Val Loss: 0.0857
 Saved best model


[Train] Epoch 78: 100%|██████████| 74/74 [00:21<00:00,  3.48it/s]


Train Loss: 0.0757
Val Loss: 0.0848
 Saved best model


[Train] Epoch 79: 100%|██████████| 74/74 [00:21<00:00,  3.45it/s]


Train Loss: 0.0757
Val Loss: 0.0878


[Train] Epoch 80: 100%|██████████| 74/74 [00:21<00:00,  3.38it/s]


Train Loss: 0.0759
Val Loss: 0.0870


[Train] Epoch 81: 100%|██████████| 74/74 [00:21<00:00,  3.45it/s]


Train Loss: 0.0768
Val Loss: 0.0864


[Train] Epoch 82: 100%|██████████| 74/74 [00:21<00:00,  3.48it/s]


Train Loss: 0.0770
Val Loss: 0.0892


[Train] Epoch 83: 100%|██████████| 74/74 [00:21<00:00,  3.41it/s]


Train Loss: 0.0755
Val Loss: 0.0895


[Train] Epoch 84: 100%|██████████| 74/74 [00:21<00:00,  3.48it/s]


Train Loss: 0.0729
Val Loss: 0.0887


[Train] Epoch 85: 100%|██████████| 74/74 [00:22<00:00,  3.32it/s]


Train Loss: 0.0722
Val Loss: 0.0856


[Train] Epoch 86: 100%|██████████| 74/74 [00:21<00:00,  3.52it/s]


Train Loss: 0.0732
Val Loss: 0.0874


[Train] Epoch 87: 100%|██████████| 74/74 [00:21<00:00,  3.39it/s]


Train Loss: 0.0755
Val Loss: 0.0880


[Train] Epoch 88: 100%|██████████| 74/74 [00:21<00:00,  3.49it/s]


Train Loss: 0.0736
Val Loss: 0.0881


[Train] Epoch 89: 100%|██████████| 74/74 [00:21<00:00,  3.50it/s]


Train Loss: 0.0734
Val Loss: 0.0880


[Train] Epoch 90: 100%|██████████| 74/74 [00:21<00:00,  3.40it/s]


Train Loss: 0.0735
Val Loss: 0.0844
 Saved best model


[Train] Epoch 91: 100%|██████████| 74/74 [00:21<00:00,  3.38it/s]


Train Loss: 0.0714
Val Loss: 0.0860


[Train] Epoch 92: 100%|██████████| 74/74 [00:21<00:00,  3.43it/s]


Train Loss: 0.0713
Val Loss: 0.0850


[Train] Epoch 93: 100%|██████████| 74/74 [00:21<00:00,  3.43it/s]


Train Loss: 0.0708
Val Loss: 0.0827
 Saved best model


[Train] Epoch 94: 100%|██████████| 74/74 [00:20<00:00,  3.59it/s]


Train Loss: 0.0724
Val Loss: 0.0835


[Train] Epoch 95: 100%|██████████| 74/74 [00:21<00:00,  3.39it/s]


Train Loss: 0.0695
Val Loss: 0.0819
 Saved best model


[Train] Epoch 96: 100%|██████████| 74/74 [00:21<00:00,  3.38it/s]


Train Loss: 0.0686
Val Loss: 0.0845


[Train] Epoch 97: 100%|██████████| 74/74 [00:20<00:00,  3.55it/s]


Train Loss: 0.0696
Val Loss: 0.0820


[Train] Epoch 98: 100%|██████████| 74/74 [00:20<00:00,  3.54it/s]


Train Loss: 0.0668
Val Loss: 0.0848


[Train] Epoch 99: 100%|██████████| 74/74 [00:21<00:00,  3.43it/s]


Train Loss: 0.0689
Val Loss: 0.0881


[Train] Epoch 100: 100%|██████████| 74/74 [00:22<00:00,  3.36it/s]


Train Loss: 0.0676
Val Loss: 0.0834


## Generate Melody

In [None]:


def save_melody_only(melody_ids, path, ticks_per_beat=480):
    midi = miditoolkit.MidiFile(ticks_per_beat=ticks_per_beat)
    inst = miditoolkit.Instrument(program=0, is_drum=False, name="Melody")
    for i, pitch in enumerate(melody_ids):
        if pitch > 0 and pitch < 128:
            inst.notes.append(miditoolkit.Note(velocity=80, pitch=pitch,
                                               start=i*ticks_per_beat, end=(i+1)*ticks_per_beat))
    midi.instruments.append(inst)
    midi.dump(path)
    print(f" Saved melody-only MIDI to {path}")

def save_combined_midi(melody_ids, chord_labels, path, ticks_per_beat=480):
    midi = miditoolkit.MidiFile(ticks_per_beat=ticks_per_beat)
    inst = miditoolkit.Instrument(program=0, is_drum=False, name="Melody")
    for i, pitch in enumerate(melody_ids):
        if pitch > 0 and pitch < 128:
            inst.notes.append(miditoolkit.Note(velocity=80, pitch=pitch,
                                               start=i*ticks_per_beat, end=(i+1)*ticks_per_beat))
    midi.instruments.append(inst)

    # Chord as markers
    for i, chord in enumerate(chord_labels):
        midi.markers.append(miditoolkit.Marker(text=chord, time=i*ticks_per_beat))

    midi.dump(path)
    print(f"Saved combined chord+melody MIDI to {path}")

def generate(chords, 
             model_path='models/best_model_2.pt',
             data_path='data/processed_transposed_2.pt', 
             output='output/melody.mid',
             max_length=64,
             temperature=1.2,
             start_pitch=60,
             combine=False):

    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load vocab
    data = torch.load(data_path)
    chord2idx = data['chord2idx']
    chord_vocab_size = len(chord2idx) + 1

    # Map chords
    chord_ids = [chord2idx.get(c, 0) for c in chords]
    chord_tensor = torch.tensor([chord_ids], dtype=torch.long, device=device)

    # Load model
    model = ChordMelodyTransformerV2(vocab_size=129, chord_vocab_size=chord_vocab_size)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Generate melody
    melody = [start_pitch]
    with torch.no_grad():
        for _ in range(max_length - 1):
            mel_tensor = torch.tensor([melody], dtype=torch.long, device=device)
            memory = model.encoder(model.pos_encoder(model.chord_embedding(chord_tensor)))
            tgt_emb = model.pos_encoder(model.melody_embedding(mel_tensor))
            tgt_mask = model.generate_mask(mel_tensor.size(1), device=device)
            out = model.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            logits = model.output_layer(out[:, -1, :]) / temperature
            probs = torch.softmax(logits, dim=-1)
            next_note = torch.multinomial(probs, 1).item()
            if next_note == 128:
                break
            melody.append(next_note)
    melody = melody[1:]  # remove start token

    # Save
    Path(output).parent.mkdir(exist_ok=True, parents=True)
    if combine:
        save_combined_midi(melody, chords, output)
    else:
        save_melody_only(melody, output)
    
    return melody  

In [None]:
# basic usage
melody = generate(['C:maj', 'F:maj', 'G:maj', 'C:maj'])

# with custom parameters
melody = generate(
    chords=['Am:min', 'F:maj', 'C:maj', 'G:maj'],
    max_length=64,
    temperature=1.5,
    start_pitch=65,
    combine=True
)


 Saved melody-only MIDI to output/melody.mid
Saved combined chord+melody MIDI to output/melody.mid


In [17]:

from IPython.display import Audio, display

# Classic Pop Progression - Based on C major with complex harmony
pop_progression = [
    # Intro (8 bars)
    'C:maj', 'Am:min', 'F:maj', 'G:maj',
    'C:maj7', 'Am:min7', 'Dm:min7', 'G:7',
    
    # Verse A - with modal interchange (8 bars)  
    'C:maj', 'Am:min', 'F:maj', 'G:maj',
    'Em:min', 'Am:min', 'Dm:min7', 'G:7',
    
    # Verse B - adding secondary dominants (8 bars)
    'F:maj', 'C:maj/3', 'Am:min', 'Em:min',
    'F:maj', 'G:maj', 'Em:min7', 'Am:min',
    
    # Outro - return to tonic (8 bars)
    'F:maj', 'G:maj', 'C:maj', 'Am:min',
    'F:maj', 'G:maj', 'C:maj', 'C:maj'
]

# Jazz Style - Complex ii-V-I progressions
jazz_progression = [
    # Classic ii-V-I in different keys (8 bars)
    'Dm:min7', 'G:7', 'C:maj7', 'C:maj7',
    'Am:min7', 'D:7', 'G:maj7', 'G:maj7',
    
    # Adding substitute chords and modal interchange (8 bars)
    'Em:min7', 'A:7', 'Dm:min7', 'G:7',
    'C:maj7', 'A:7', 'Dm:min7', 'G:7',
    
    # Turnaround chord progression (8 bars)
    'C:maj7', 'Am:min7', 'Dm:min7', 'G:7',
    'Em:min7', 'Eb:dim7', 'Dm:min7', 'G:7',
    
    # Ending (8 bars)
    'C:maj7', 'F:maj7', 'Em:min7', 'Am:min7',
    'Dm:min7', 'G:7', 'C:maj7', 'C:maj7'
]

# Rock Style - Power chords and open chords
rock_progression = [
    # Classic rock progression (8 bars)
    'C:5', 'F:5', 'G:5', 'C:5',
    'Am:min', 'F:maj', 'C:maj', 'G:maj',
    
    # Minor tonality (8 bars)
    'Am:min', 'F:maj', 'C:maj', 'G:maj',
    'Am:min', 'Dm:min', 'G:maj', 'C:maj',
    
    # Power section (8 bars)
    'F:maj', 'G:maj', 'Am:min', 'Am:min',
    'F:maj', 'G:maj', 'C:maj', 'C:maj',
    
    # Strong ending (8 bars)
    'C:5', 'G:5', 'Am:min', 'F:maj',
    'C:5', 'G:5', 'C:5', 'C:5'
]

# Classical Style - Based on functional harmony
classical_progression = [
    # Theme presentation (8 bars)
    'C:maj', 'G:maj/3', 'Am:min', 'Em:min/3',
    'F:maj', 'C:maj/5', 'G:7', 'C:maj',
    
    # Development - modulation to dominant (8 bars)
    'G:maj', 'D:maj/3', 'Em:min', 'Bm:min/3',
    'C:maj', 'G:maj/5', 'D:7', 'G:maj',
    
    # Recapitulation - return to tonic (8 bars)
    'C:maj', 'Am:min', 'Dm:min', 'G:7',
    'Em:min', 'Am:min', 'Dm:min', 'G:7',
    
    # Coda (8 bars)
    'C:maj', 'F:maj', 'C:maj/5', 'G:7',
    'C:maj', 'F:maj', 'G:7', 'C:maj'
]

# Modern Pop - Complex harmony and borrowed chords
modern_pop_progression = [
    # Modern colors (8 bars)
    'C:maj9', 'Am:min11', 'F:maj7', 'G:sus4',
    'Em:min7', 'Am:min7', 'Dm:min9', 'G:7',
    
    # Modal interchange and borrowed chords (8 bars)
    'C:maj', 'Ab:maj', 'F:maj', 'G:maj',
    'Am:min', 'Fm:min', 'C:maj', 'G:maj',
    
    # Emotional climax (8 bars)
    'F:maj', 'Am:min', 'G:maj', 'Em:min',
    'F:maj', 'Am:min', 'G:sus4', 'G:maj',
    
    # Fadeout ending (8 bars)
    'C:maj', 'Am:min', 'F:maj', 'G:maj',
    'C:maj', 'F:maj', 'C:maj', 'C:maj'
]

# Neo-Soul Style - Complex extended chords
neo_soul_progression = [
    # Complex harmony (8 bars)
    'C:maj7', 'Am:min9', 'Dm:min11', 'G:13',
    'Em:min7', 'A:7', 'Dm:min7', 'G:7',
    
    # Color chords (8 bars)
    'C:maj9', 'F:maj7', 'Bb:maj7', 'Am:min7',
    'Dm:min9', 'G:13', 'Em:min7', 'Am:min7',
    
    # Groove section (8 bars)
    'F:maj9', 'Em:min7', 'Am:min9', 'Dm:min7',
    'G:13', 'C:maj7', 'Am:min9', 'Dm:min7',
    
    # Ending (8 bars)
    'G:13', 'C:maj7', 'Am:min9', 'F:maj9',
    'Em:min7', 'Am:min7', 'Dm:min7', 'C:maj7'
]


In [22]:
print("=== Classic Pop Style ===")
melody1 = generate(chords=pop_progression, output="output/pop_melody.mid", max_length=32)

=== Classic Pop Style ===
 Saved melody-only MIDI to output/pop_melody.mid


In [20]:
print("\n=== Jazz Style ===")
melody2 = generate(chords=jazz_progression, output="output/jazz_melody.mid", max_length=32, temperature=1.3)


=== Jazz Style ===
 Saved melody-only MIDI to output/jazz_melody.mid


In [21]:
print("\n=== Rock Style ===")
melody3 = generate(chords=rock_progression, output="output/rock_melody.mid", max_length=32, temperature=1.1)


=== Rock Style ===
 Saved melody-only MIDI to output/rock_melody.mid


In [23]:
print("\n=== Classical Style ===")
melody4 = generate(chords=classical_progression, output="output/classical_melody.mid", max_length=32, temperature=0.9)


=== Classical Style ===
 Saved melody-only MIDI to output/classical_melody.mid


In [24]:
print("\n=== Modern Pop Style ===")
melody5 = generate(chords=modern_pop_progression, output="output/modern_melody.mid", max_length=32, temperature=1.4)


=== Modern Pop Style ===
 Saved melody-only MIDI to output/modern_melody.mid


In [None]:
print("\n=== Neo-Soul Style ===")
melody6 = generate(chords=neo_soul_progression, output="output/neo_soul_melody.mid", max_length=32, temperature=1.2)


=== Neo-Soul Style ===
 Saved melody-only MIDI to output/neo_soul_melody.mid


In [27]:
# Quick test with simple progression
print("\n=== Quick Test ===")
simple_chords = ['C:maj', 'F:maj', 'G:maj', 'Am:min'] * 8
test_melody = generate(chords=simple_chords, output="output/test_melody.mid", max_length=32)
display(Audio("output/test_melody.mid"))


=== Quick Test ===
 Saved melody-only MIDI to output/test_melody.mid


In [28]:

print(f"\nGenerated melodies with lengths: {[len(m) for m in [melody1, melody2, melody3, melody4, melody5, melody6, test_melody]]}") 


Generated melodies with lengths: [31, 31, 31, 31, 31, 31, 31]
