In [1]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from tqdm import tqdm
import librosa
import numpy as np
import miditoolkit
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, average_precision_score, accuracy_score
import random
import pretty_midi
import math
from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile
from glob import glob
# used chatgpt to help me generate some functions

  import pkg_resources


In [2]:
# Processing the midi files
midi_files = glob('nes_midis/*')
print(len(midi_files))

config = TokenizerConfig(num_velocities=1)
tokenizer = REMI(config)
tokenizer.train(vocab_size = 2000, files_paths=midi_files)


2000


In [3]:
print("CUDA available:", torch.cuda.is_available())
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU found")
print("Torch CUDA version:", torch.version.cuda)
instruments = {}
bad_files = []

for file in midi_files:
    try:
        midi = pretty_midi.PrettyMIDI(file)
        for instrument in midi.instruments:
            name = pretty_midi.program_to_instrument_name(instrument.program)
            instruments[name] = instruments.get(name, 0) + 1
    except Exception as e:
        bad_files.append(file)

sorted_instruments = sorted(instruments.items(), key=lambda x: x[1], reverse=True)
midi_files = [file for file in midi_files if file not in bad_files]


CUDA available: True
Device name: NVIDIA GeForce GTX 1660 SUPER
Torch CUDA version: 11.8




In [4]:
# Using the top 20 instruments to condense the instrument types
useful_instruments = set(name for name, _ in sorted_instruments[:20]) 

# extracts only the notes where the instruments are useful
def extract_note_sequence(midi_path):
    midi = pretty_midi.PrettyMIDI(midi_path)
    notes = []
    for instrument in midi.instruments:
        if not instrument.is_drum:
            notes += [note.pitch for note in instrument.notes]
    return notes

#extract_notes(midi_files[1])

In [5]:

class MIDIDataset(Dataset):
    def __init__(self, midi_dir, vocab, seq_len=128):
        self.data = []
        self.vocab = vocab
        self.seq_len = seq_len
        self.pitch2idx = {p: i for i, p in enumerate(vocab)}

        for file in midi_dir:
            print(f"Processing file: {file}")
            notes = extract_note_sequence(file)
            encoded = [self.pitch2idx[n] for n in notes if n in self.pitch2idx]

            for i in range(0, len(encoded) - seq_len):
                x = encoded[i:i + seq_len]
                y = encoded[i+seq_len]
                self.data.append((torch.tensor(x), torch.tensor(y)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]



In [7]:
pitch_vocab = sorted(set(p for f in midi_files for p in extract_note_sequence(f)))
pitch2idx = {p: i for i, p in enumerate(pitch_vocab)}
idx2pitch = {i: p for i, p in enumerate(pitch_vocab)}
# vocab = sorted(vocab)

dataset = MIDIDataset(midi_dir=midi_files, vocab=pitch_vocab)
print(f"Dataset size: {len(dataset)}") 
# loader = DataLoader(dataset[:1000], batch_size=32, shuffle=True, num_workers=0, pin_memory=True)


Processing file: nes_midis\001_10-Yard_Fight-Kick_Off.mid
Processing file: nes_midis\002_1943.mid
Processing file: nes_midis\003_1943sab.mid
Processing file: nes_midis\004_1943-lev1.mid
Processing file: nes_midis\005_43pbos1.mid
Processing file: nes_midis\006_43pbos12.mid
Processing file: nes_midis\007_1943-lev3.mid
Processing file: nes_midis\008_1943-Lev3Win.mid
Processing file: nes_midis\009_1943lost.mid
Processing file: nes_midis\010_1943won.mid
Processing file: nes_midis\011_1943boss.mid
Processing file: nes_midis\012_1943boss1.mid
Processing file: nes_midis\013_1943BossWin.mid
Processing file: nes_midis\014_1999.mid
Processing file: nes_midis\015_3D_Worldrunner_Bonus.mid
Processing file: nes_midis\016_3D_Worldrunner_Boss.mid
Processing file: nes_midis\017_3D_Worldrunner_Main.mid
Processing file: nes_midis\018_720.mid
Processing file: nes_midis\019_Egypt.mid
Processing file: nes_midis\020_8_Eyes_-_Enterance.mid
Processing file: nes_midis\021_8_Eyes.mid
Processing file: nes_midis\02

In [8]:
class PitchLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        x = self.embed(x)
        out, hidden = self.lstm(x, hidden)
        out = self.fc(out[:, -1, :])  # use the last output for prediction
        return out, hidden


In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
# model = PitchLSTM(vocab_size=len(vocab))
model = PitchLSTM(vocab_size=len(pitch_vocab))
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(15):
    total_loss = 0
    num_batches = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        out, _ = model(batch_x)
        loss = criterion(out, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    avg_loss = total_loss / num_batches
    perplexity = math.exp(avg_loss)

    print(f"Epoch {epoch + 1} | Loss: {loss.item():.4f} | Perplexity: {perplexity:.2f}")






cuda
Epoch 1 | Loss: 1.4516 | Perplexity: 5.67
Epoch 2 | Loss: 1.3688 | Perplexity: 4.04
Epoch 3 | Loss: 1.1196 | Perplexity: 3.76
Epoch 4 | Loss: 1.1282 | Perplexity: 3.69
Epoch 5 | Loss: 1.2404 | Perplexity: 3.69
Epoch 6 | Loss: 1.3552 | Perplexity: 3.72
Epoch 7 | Loss: 1.1930 | Perplexity: 3.78
Epoch 8 | Loss: 1.1841 | Perplexity: 3.84
Epoch 9 | Loss: 1.1899 | Perplexity: 3.93
Epoch 10 | Loss: 1.4265 | Perplexity: 4.06
Epoch 11 | Loss: 1.2719 | Perplexity: 4.21
Epoch 12 | Loss: 1.2421 | Perplexity: 4.48
Epoch 13 | Loss: 2.2194 | Perplexity: 4.96
Epoch 14 | Loss: 1.9732 | Perplexity: 6.68
Epoch 15 | Loss: 2.5185 | Perplexity: 8.04


In [13]:
def generate_sequence(model, start_seq, length, device):
    model.eval()
    generated = start_seq[:]
    input_seq = torch.tensor(start_seq, dtype=torch.long).unsqueeze(0).to(device)

    hidden = None
    for _ in range(length):
        with torch.no_grad():
            out, hidden = model(input_seq, hidden)
            prob = torch.softmax(out, dim=-1)
            next_token = torch.multinomial(prob, num_samples=1).item()

        generated.append(next_token)
        input_seq = torch.tensor(generated[-len(start_seq):], dtype=torch.long).unsqueeze(0).to(device)

    return generated


In [14]:
def sequence_to_midi(token_sequence, idx2pitch, output_path="generated.mid", duration=0.5):
    midi = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)
    time = 0

    for token in token_sequence:
        duration = random.choice([0.5, 0.75, 1.0])
        pitch = idx2pitch[token]
        note = pretty_midi.Note(velocity=100, pitch=pitch, start=time, end=time + duration)
        instrument.notes.append(note)
        time += duration

    midi.instruments.append(instrument)
    midi.write(output_path)


In [17]:
start_seq = [random.choice(range(len(pitch_vocab))) for _ in range(32)]
generated_tokens = generate_sequence(model, start_seq, length=200, device=device)
sequence_to_midi(generated_tokens, idx2pitch, output_path="my_song.mid")
