In [1]:
import os

midi_count = 0

for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        if filename.lower().endswith('.mid'):
            midi_count += 1

print(f"Total number of MIDI files: {midi_count}")


Total number of MIDI files: 3896


In [1]:
# 📦 1. Install dependencies
!pip install music21 torch pretty_midi tqdm --quiet


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m47.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 MB[0m [31m6.3 MB/s[0m eta [3

In [2]:
# 📚 2. Import libraries
import os
import music21
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pretty_midi
import pickle
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import random
from tqdm import tqdm


In [3]:
# 🔍 3. Recursively preprocess a subset of Nintendo MIDI files
midi_dir = Path("/kaggle/input/nintendo-midis/Nintendo")
chord_vocab, melody_vocab = {}, {}
chord_seqs, melody_seqs = [], []

max_len = 64  # Increased length for longer melodies
max_files = 500  # Limit to avoid timeout
midi_files = list(midi_dir.rglob("*.mid"))
random.shuffle(midi_files)


def note_to_int(note):
    return int(note.pitch.midi)

def chord_to_label(chord):
    root = chord.root().name
    quality = chord.quality
    return f"{root}_{quality}"

for file in tqdm(midi_files[:max_files], desc="Processing MIDI files"):
    try:
        score = music21.converter.parse(file)
        if isinstance(score, music21.stream.Opus):
            if len(score.scores) == 0:
                continue
            score = score.scores[0]

        chords = score.chordify().flatten().getElementsByClass('Chord')
        melody = score.parts[0].flatten().getElementsByClass('Note')

        chord_seq = []
        for c in chords[:max_len]:
            label = chord_to_label(c)
            if label not in chord_vocab:
                chord_vocab[label] = len(chord_vocab)
            chord_seq.append(chord_vocab[label])

        melody_seq = []
        for n in melody[:max_len]:
            midi = note_to_int(n)
            if midi not in melody_vocab:
                melody_vocab[midi] = len(melody_vocab)
            melody_seq.append(melody_vocab[midi])

        if len(chord_seq) == len(melody_seq) == max_len:
            chord_seqs.append(chord_seq)
            melody_seqs.append(melody_seq)
    except Exception as e:
        print(f"Failed on {file}: {e}")

with open("processed.pkl", "wb") as f:
    pickle.dump({
        "chord_seqs": chord_seqs,
        "melody_seqs": melody_seqs,
        "chord_vocab": chord_vocab,
        "melody_vocab": melody_vocab
    }, f)

print(f"✅ Saved {len(chord_seqs)} Nintendo MIDI sequences.")


Processing MIDI files: 100%|██████████| 500/500 [08:03<00:00,  1.03it/s]

✅ Saved 424 Nintendo MIDI sequences.





In [29]:
# 🧠 4. Define dataset and LSTM model
class ChordMelodyDataset(Dataset):
    def __init__(self, X, Y, seq_len):
        self.X = X
        self.Y = Y
        self.seq_len = seq_len

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.X[idx][:self.seq_len]),
            torch.tensor(self.Y[idx][:self.seq_len])
        )

class ChordToMelodyLSTM(nn.Module):
    def __init__(self, chord_vocab, melody_vocab, hidden_dim=128):
        super().__init__()
        self.chord_embed = nn.Embedding(chord_vocab, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim*2, num_heads=2, batch_first=True)
        self.out = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim*2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim*2, melody_vocab)
        )

    def forward(self, chords):
        x = self.chord_embed(chords)
        h, _ = self.lstm(x)
        attn_output, _ = self.attn(h, h, h)
        return self.out(attn_output)

In [30]:
# 🚀 5. Train model
with open("processed.pkl", "rb") as f:
    data = pickle.load(f)

chord_seqs = data["chord_seqs"]
melody_seqs = data["melody_seqs"]
chord_vocab_size = len(data["chord_vocab"])
melody_vocab_size = len(data["melody_vocab"])

seq_len = 64  # Match new max_len
# Filter out sequences that don't match the required seq_len
filtered_pairs = [
    (c, m) for c, m in zip(chord_seqs, melody_seqs)
    if len(c) == seq_len and len(m) == seq_len
]
print(len(filtered_pairs))
chord_seqs, melody_seqs = zip(*filtered_pairs) if filtered_pairs else ([], [])

dataset = ChordMelodyDataset(chord_seqs, melody_seqs, seq_len)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = ChordToMelodyLSTM(chord_vocab_size, melody_vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(200):
    total_loss = 0
    for chords, melody in loader:
        logits = model(chords)
        loss = criterion(logits.view(-1, melody_vocab_size), melody.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}: Loss = {total_loss:.4f}")

torch.save(model.state_dict(), "lstm_model.pt")
print("✅ Model saved!")

424
Epoch 0: Loss = 59.1663
Epoch 1: Loss = 55.8764
Epoch 2: Loss = 54.6162
Epoch 3: Loss = 53.4292
Epoch 4: Loss = 52.7463
Epoch 5: Loss = 52.1071
Epoch 6: Loss = 51.8705
Epoch 7: Loss = 51.2955
Epoch 8: Loss = 50.8954
Epoch 9: Loss = 50.4553
Epoch 10: Loss = 50.0312
Epoch 11: Loss = 49.8846
Epoch 12: Loss = 49.3610
Epoch 13: Loss = 48.6861
Epoch 14: Loss = 48.4739
Epoch 15: Loss = 47.8643
Epoch 16: Loss = 47.1622
Epoch 17: Loss = 46.9429
Epoch 18: Loss = 46.0541
Epoch 19: Loss = 46.0170
Epoch 20: Loss = 44.7660
Epoch 21: Loss = 44.4209
Epoch 22: Loss = 44.0563
Epoch 23: Loss = 43.5063
Epoch 24: Loss = 42.7152
Epoch 25: Loss = 42.1174
Epoch 26: Loss = 41.8610
Epoch 27: Loss = 41.6130
Epoch 28: Loss = 41.0867
Epoch 29: Loss = 40.5823
Epoch 30: Loss = 39.9292
Epoch 31: Loss = 39.4418
Epoch 32: Loss = 39.1415
Epoch 33: Loss = 39.1981
Epoch 34: Loss = 38.5473
Epoch 35: Loss = 37.7939
Epoch 36: Loss = 37.9304
Epoch 37: Loss = 37.7158
Epoch 38: Loss = 37.1043
Epoch 39: Loss = 37.2727
Epoch 

In [31]:
# 🎵 6. Generate a MIDI file from a melody

def generate_midi(melody_ids, melody_vocab, filename="generated.mid"):
    inv_vocab = {v: k for k, v in melody_vocab.items()}
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    for i, idx in enumerate(melody_ids):
        pitch = inv_vocab.get(idx, 60)
        note = pretty_midi.Note(
            velocity=100, pitch=pitch, start=i*0.5, end=(i+1)*0.5
        )
        inst.notes.append(note)
    pm.instruments.append(inst)
    pm.write(filename)


In [32]:
# 🎼 7. Generate new melody from custom chord input

def sample_logits(logits, temperature=1.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

def generate_from_model(model, chord_sequence, melody_vocab, device="cpu", temperature=1.0):
    model.eval()
    with torch.no_grad():
        chords = torch.tensor(chord_sequence).unsqueeze(0).to(device)
        logits = model(chords).squeeze(0)
        sampled = [sample_logits(logits[i], temperature).item() for i in range(logits.size(0))]
        return sampled

# ✅ Load model for inference
model.load_state_dict(torch.load("lstm_model.pt"))
model.eval()

# # 🆕 Example: generate a longer melody from a longer chord progression
# nintendo_intro = [data["chord_vocab"].get(label) for label in [
#     "C_major", "A_minor", "F_major", "G_major",
#     "E_minor", "D_minor", "C_major", "A_minor",
#     "F_major", "G_major", "C_major", "A_minor",
#     "F_major", "G_major", "E_minor", "D_minor",
# ]]

nintendo_intro = [data["chord_vocab"].get(label) for label in [
    "C_major", "F_major", "G_major", "C_major"
]]



nintendo_intro = [c if c is not None else 0 for c in nintendo_intro] + [0] * (64 - 16)

generated_ids = generate_from_model(model, nintendo_intro, data["melody_vocab"], temperature=0.8)
generate_midi(generated_ids, data["melody_vocab"], "generate_4.mid")
