In [4]:
!ls -lh maestro-v3.0.0.zip

-rw-r--r-- 1 vkomma 100018700 7.0M Jun  2 03:24 maestro-v3.0.0.zip


In [9]:
!pip install pretty_midi

Defaulting to user installation because normal site-packages is not writeable


In [1]:
import os
import pretty_midi
import torch
from tqdm import tqdm

SEQ_LEN = 256
VOCAB_SIZE = 128

def extract_melody_harmony_from_midi(midi_file, seq_len=SEQ_LEN):
    try:
        midi = pretty_midi.PrettyMIDI(midi_file)
        notes = []
        for inst in midi.instruments:
            if not inst.is_drum:
                notes.extend(inst.notes)
        if len(notes) < seq_len:
            return None

        notes.sort(key=lambda n: n.start)
        start_time = notes[0].start
        end_time = notes[-1].end
        time_step = (end_time - start_time) / seq_len
        if time_step == 0:
            return None

        melody = [0] * seq_len
        harmony = [[0] * VOCAB_SIZE for _ in range(seq_len)]

        for note in notes:
            idx = int((note.start - start_time) / time_step)
            if 0 <= idx < seq_len and 0 <= note.pitch < VOCAB_SIZE:
                harmony[idx][note.pitch] = 1

        for i in range(seq_len):
            active = [p for p, v in enumerate(harmony[i]) if v]
            melody[i] = max(active) if active else 0

        return melody, harmony
    except:
        return None

def extract_from_folder(folder):
    pairs = []
    for root, _, files in os.walk(folder):
        for file in tqdm(files):
            if file.endswith(".mid") or file.endswith(".midi"):
                result = extract_melody_harmony_from_midi(os.path.join(root, file))
                if result:
                    pairs.append(result)
    print(f"✅ Extracted {len(pairs)} melody-harmony pairs")
    return pairs


In [2]:
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pretty_midi

# ----- CONFIG -----
VOCAB_SIZE = 128
HIDDEN_SIZE = 256
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 0.001
SEQ_LEN = 256

# ----- DATASET -----
class MaestroDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        melody, harmony = self.pairs[idx]
        return torch.tensor(melody, dtype=torch.long), torch.tensor(harmony, dtype=torch.float32)

# ----- MODEL -----
class ChordLSTM(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, vocab_size)

    def forward(self, x):
        x = self.embed(x)
        out, _ = self.lstm(x)
        logits = self.fc(out)
        return torch.sigmoid(logits)

# ----- TRAINING -----
def train(model, dataloader, epochs=EPOCHS):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.BCELoss()

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        print(f"\n🎵 Epoch {epoch}/{epochs}")
        for melody, harmony in dataloader:
            melody, harmony = melody.to(device), harmony.to(device)
            pred = model(melody)
            loss = loss_fn(pred, harmony)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        print(f"✅ Avg Loss: {total_loss / len(dataloader):.4f}")

# ----- GENERATION -----
def generate_chords(model, melody_seq):
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        inp = melody_seq.unsqueeze(0).to(device)
        pred = model(inp)
        chords = (pred.squeeze(0) > 0.5).int().tolist()
        return chords

# ----- MIDI OUTPUT -----
def save_midi_with_chords(melody_seq, harmony_matrix, filename="final_output.mid"):
    midi = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=0)
    time = 0.0

    for i, melody_note in enumerate(melody_seq):
        # alternate tempo: fast then slow (0.125 or 0.375)
        dur = 0.125 if i % 4 < 2 else 0.375
        end = time + dur

        if melody_note > 0:
            piano.notes.append(pretty_midi.Note(velocity=100, pitch=melody_note, start=time, end=end))

        active = [(p, v) for p, v in enumerate(harmony_matrix[i]) if v]
        if active:
            top_pitches = sorted([p for p, _ in active])[:2]
            for j, pitch in enumerate(top_pitches):
                piano.notes.append(pretty_midi.Note(velocity=80, pitch=pitch, start=time + j*0.05, end=end))

        time = end

    midi.instruments.append(piano)
    midi.write(filename)
    print(f"🎼 Saved: {filename}")


In [3]:
pairs = extract_from_folder("maestro/maestro-v3.0.0")
dataset = MaestroDataset(pairs)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = ChordLSTM(VOCAB_SIZE, HIDDEN_SIZE)
train(model, loader)

for i in range(6):
    melody_sample, _ = random.choice(dataset)
    harmony = generate_chords(model, melody_sample)
    save_midi_with_chords(melody_sample.tolist(), harmony, f"final_song_{i+1}.mid")


100%|██████████| 5/5 [00:00<00:00, 44620.26it/s]
100%|██████████| 115/115 [00:30<00:00,  3.79it/s]
100%|██████████| 129/129 [00:16<00:00,  7.75it/s]
100%|██████████| 93/93 [00:35<00:00,  2.61it/s]
100%|██████████| 1/1 [00:00<00:00, 26214.40it/s]
100%|██████████| 163/163 [00:19<00:00,  8.35it/s]
100%|██████████| 147/147 [00:19<00:00,  7.50it/s]
100%|██████████| 127/127 [00:16<00:00,  7.76it/s]
100%|██████████| 105/105 [00:28<00:00,  3.63it/s]
100%|██████████| 125/125 [00:31<00:00,  3.98it/s]
100%|██████████| 140/140 [00:20<00:00,  6.73it/s]
100%|██████████| 132/132 [00:28<00:00,  4.66it/s]


✅ Extracted 1275 melody-harmony pairs

🎵 Epoch 1/10
✅ Avg Loss: 0.2429

🎵 Epoch 2/10
✅ Avg Loss: 0.1717

🎵 Epoch 3/10
✅ Avg Loss: 0.1637

🎵 Epoch 4/10
✅ Avg Loss: 0.1597

🎵 Epoch 5/10
✅ Avg Loss: 0.1571

🎵 Epoch 6/10
✅ Avg Loss: 0.1551

🎵 Epoch 7/10
✅ Avg Loss: 0.1536

🎵 Epoch 8/10
✅ Avg Loss: 0.1520

🎵 Epoch 9/10
✅ Avg Loss: 0.1510

🎵 Epoch 10/10
✅ Avg Loss: 0.1496
🎼 Saved: final_song_1.mid
🎼 Saved: final_song_2.mid
🎼 Saved: final_song_3.mid
🎼 Saved: final_song_4.mid
🎼 Saved: final_song_5.mid
🎼 Saved: final_song_6.mid
