In [44]:
import glob
import pretty_midi
import matplotlib.pyplot as plt
import pandas as pd
import collections
import numpy as np
import os

In [49]:
class MidiHandler:
    def __init__(self):
        self.files = []
        self.notes = []

    def load_files(self, path):
        files = glob.glob(path)
        self.files.extend(files)
        print(f"Found {len(files)} files in {path}")

    def get_midi_notes(self, midi_file):
        pm = pretty_midi.PrettyMIDI(midi_file)
        instrument = pm.instruments[0]
        notes = collections.defaultdict(list)

        sorted_notes = sorted(instrument.notes, key=lambda note: note.start)
        prev_start = sorted_notes[0].start

        for note in sorted_notes:
            start = note.start
            end = note.end
            notes["pitch"].append(note.pitch)
            notes["start"].append(start)
            notes["end"].append(end)
            notes["step"].append(start - prev_start)
            notes["duration"].append(end - start)
            prev_start = start

        return pd.DataFrame({name: np.array(value) for name, value in notes.items()})

    def plot_piano_roll(self, sample):
        notes = self.get_midi_notes(sample)

        plt.figure(figsize=(20, 4))
        plot_pitch = np.stack([notes["pitch"], notes["pitch"]], axis=0)
        plot_start_stop = np.stack([notes["start"], notes["end"]], axis=0)
        plt.plot(plot_start_stop, plot_pitch, color="b", marker=".")
        plt.xlabel("Time [s]")
        plt.ylabel("Pitch")
        plt.title(sample)

    def get_notes_from_samples(self, range_from=0, range_to=None):
        all_notes = []
        file_range = self.files[range_from:range_to]
        num_files = len(file_range)
        errors = []
        for index, file in enumerate(file_range):
            try:
                notes = self.get_midi_notes(file)
                all_notes.append(notes)
                print(f"Loaded {index + 1}/{num_files}: {file}")
            except Exception as e:
                errors.append(file)
                print(f"Error loading {file}: {e}")
        
        if errors:
            print(f"Errors: {errors}")

        self.notes = pd.concat(all_notes)
        return self.notes

In [50]:
handler = MidiHandler()
handler.load_files("../data/maestro/**/*.mid*")
handler.load_files("../data/Cymatics/*.mid*")

print(f"Number of Samples: {len(handler.files)}")

Found 1282 files in ../data/maestro/**/*.mid*
Found 559 files in ../data/Cymatics/*.mid*
Number of Smaples: 1841


In [47]:
# handler.plot_piano_roll(handler.files[0])

In [48]:
all_notes = handler.get_notes_from_samples()
key_order = ["pitch", "step", "duration"]
train_notes = np.stack([all_notes[key] for key in key_order], axis=1)

np.save("../data/notes.npy", train_notes)

Loaded 1/1841: ../data/maestro/2015/MIDI-Unprocessed_R1_D2-13-20_mid--AUDIO-from_mp3_17_R1_2015_wav--4.midi
Loaded 2/1841: ../data/maestro/2015/MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_12_R1_2015_wav--1.midi
Loaded 3/1841: ../data/maestro/2015/MIDI-Unprocessed_R1_D1-1-8_mid--AUDIO-from_mp3_08_R1_2015_wav--3.midi
Loaded 4/1841: ../data/maestro/2015/MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_10_R1_2015_wav--2.midi
Loaded 5/1841: ../data/maestro/2015/MIDI-Unprocessed_R1_D1-1-8_mid--AUDIO-from_mp3_01_R1_2015_wav--3.midi
Loaded 6/1841: ../data/maestro/2015/MIDI-Unprocessed_R2_D1-2-3-6-7-8-11_mid--AUDIO-from_mp3_07_R2_2015_wav--1.midi
Loaded 7/1841: ../data/maestro/2015/MIDI-Unprocessed_R1_D1-1-8_mid--AUDIO-from_mp3_08_R1_2015_wav--2.midi
Loaded 8/1841: ../data/maestro/2015/MIDI-Unprocessed_R1_D1-1-8_mid--AUDIO-from_mp3_04_R1_2015_wav--3.midi
Loaded 9/1841: ../data/maestro/2015/MIDI-Unprocessed_R2_D2-12-13-15_mid--AUDIO-from_mp3_12_R2_2015_wav--2.midi
Loaded 10/1841: ../data/maes