In [61]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from tqdm import tqdm
import librosa
import numpy as np
import miditoolkit
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, average_precision_score, accuracy_score
import random
import pretty_midi

from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile
from glob import glob
# Note: ChatGPT was used to help generate some of the functions in this notebook.

In [30]:
# Processing the midi files
# glob's result order depends on the file system, so sort it: every later step
# that iterates `midi_files` then sees the same order on every run and machine.
midi_files = sorted(glob('nes_midis/*'))
print(len(midi_files))

# REMI tokenizer configured with num_velocities=1, trained on the collected
# files with a target vocabulary size of 2000.
config = TokenizerConfig(num_velocities=1)
tokenizer = REMI(config)
tokenizer.train(vocab_size=2000, files_paths=midi_files)


2000





In [45]:
# Count how many tracks use each instrument name across the corpus and collect
# the files that could not be processed.
instruments = {}
bad_files = []
bad_file_errors = {}  # file path -> short description of why it was skipped

for file in midi_files:
    try:
        midi = pretty_midi.PrettyMIDI(file)
        # Resolve all names before touching the counts so that a file failing
        # part-way through contributes nothing to `instruments`.
        track_names = [
            pretty_midi.program_to_instrument_name(instrument.program)
            for instrument in midi.instruments
        ]
    except Exception as e:
        # Best-effort: skip the file, but remember the reason instead of
        # discarding it silently.
        bad_files.append(file)
        bad_file_errors[file] = f"{type(e).__name__}: {e}"
        continue
    for name in track_names:
        instruments[name] = instruments.get(name, 0) + 1

if bad_files:
    print(f"Skipped {len(bad_files)} unreadable file(s) out of {len(midi_files)}")

# (name, track count) pairs, most frequent instrument first.
sorted_instruments = sorted(instruments.items(), key=lambda x: x[1], reverse=True)

# Set membership keeps the filter linear in the number of files.
bad_file_set = set(bad_files)
midi_files = [file for file in midi_files if file not in bad_file_set]




In [43]:
# Condense the instrument types: keep only the names of the first 20 entries
# of `sorted_instruments` (sorted by descending track count).
useful_instruments = {entry[0] for entry in sorted_instruments[:20]}

def extract_notes(midi_file, allowed_instruments=None):
    """Return the pitches of all notes played by an allowed instrument.

    Args:
        midi_file: Path of the MIDI file; passed straight to
            ``pretty_midi.PrettyMIDI``.
        allowed_instruments: Optional collection of instrument names to keep.
            When ``None`` (the default) the module-level ``useful_instruments``
            is used, which is what the function did before this parameter
            existed.

    Returns:
        list: ``note.pitch`` for every kept note, grouped track by track in the
        order of ``midi.instruments`` (not merged by onset time).

    Raises:
        Any exception ``pretty_midi.PrettyMIDI`` raises while loading the file.
    """
    if allowed_instruments is None:
        allowed_instruments = useful_instruments

    midi = pretty_midi.PrettyMIDI(midi_file)
    notes = []
    for instrument in midi.instruments:
        # NOTE(review): tracks flagged `is_drum` are not filtered here; on such
        # tracks `pitch` selects a percussion sound rather than a musical
        # pitch - confirm whether they should be skipped.
        instrument_name = pretty_midi.program_to_instrument_name(instrument.program)
        if instrument_name in allowed_instruments:
            notes.extend(note.pitch for note in instrument.notes)
    return notes

#extract_notes(midi_files[1])

In [60]:
# Token mapping
# Keep each file's pitch sequence separate so that training windows never
# straddle the boundary between two unrelated songs.
pitches_per_file = [extract_notes(file) for file in midi_files]

# Flat list of every pitch, in file order.
all_pitches = [pitch for file_pitches in pitches_per_file for pitch in file_pitches]
vocab = sorted(set(all_pitches))

pitch2idx = {p: i for i, p in enumerate(vocab)}
idx2pitch = {i: p for i, p in enumerate(vocab)}

# Every pitch mapped to its vocabulary index (same order as `all_pitches`).
encoded = [pitch2idx[p] for p in all_pitches]

# Training examples: X[k] is a window of `sequence_length` consecutive pitch
# indices and Y[k] is the index that follows that window in the same file.
sequence_length = 100
X = []
Y = []
for file_pitches in pitches_per_file:
    file_encoded = [pitch2idx[p] for p in file_pitches]
    # A file with at most `sequence_length` notes yields no window.
    for i in range(len(file_encoded) - sequence_length):
        X.append(file_encoded[i:i + sequence_length])
        Y.append(file_encoded[i + sequence_length])



In [62]:
# Convert the window list and the target list into int64 tensors.
x, y = (torch.tensor(values, dtype=torch.long) for values in (X, Y))

# Pair each window with its target and serve shuffled mini-batches of 64.
dataset = TensorDataset(x, y)
loader = DataLoader(dataset, shuffle=True, batch_size=64)


In [63]:
class PitchLSTM(nn.Module):
    """Next-token predictor: embedding -> stacked LSTM -> linear projection.

    Args:
        vocab_size: Number of distinct tokens; sizes both the embedding table
            and the output layer.
        embed_dim: Width of the token embeddings.
        hidden_dim: Hidden size of each LSTM layer.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Score the next token for each sequence in the batch.

        Args:
            x: Integer index tensor of shape (batch, seq_len).
            hidden: Optional LSTM state passed through to ``self.lstm``.

        Returns:
            tuple: ``(logits, hidden)`` where ``logits`` has shape
            (batch, vocab_size) and ``hidden`` is the state returned by the
            LSTM.
        """
        embedded = self.embed(x)
        sequence_out, hidden = self.lstm(embedded, hidden)
        # Only the final time step feeds the output layer.
        final_step = sequence_out[:, -1, :]
        logits = self.fc(final_step)
        return logits, hidden


In [64]:
# Train the next-pitch model with Adam and cross-entropy loss.
model = PitchLSTM(vocab_size=len(vocab))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

num_epochs = 10  # increase as needed
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    samples_seen = 0
    for batch_x, batch_y in loader:
        optimizer.zero_grad()
        out, _ = model(batch_x)
        loss = criterion(out, batch_y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch does not skew the mean.
        batch_size = batch_y.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
    # Report the mean loss over the whole epoch rather than the loss of
    # whichever mini-batch happened to come last.
    epoch_loss = running_loss / max(samples_seen, 1)
    print(f"Epoch {epoch} | Loss: {epoch_loss:.4f}")




KeyboardInterrupt: 