In [9]:
#!pip install pretty_midi  # Uncomment to install pretty_midi
import pretty_midi
import numpy as np
import os

In [2]:
def get_tones(midi_path: str, fs: float):
    """
    Get a dictionary mapping (time)frame to the notes played at that time in a song at midi_path.

    Parameters
    ----------
    midi_path : str
        Path of the MIDI file to load with pretty_midi.
    fs : float
        Sampling frequency handed to ``get_piano_roll``; together with the
        duration of the song it decides the total number of frames.

    Returns
    -------
    dict
        Frame index -> 1-D ``np.uint8`` array with the MIDI numbers of the
        notes sounding in that frame, in ascending order. Frames in which
        nothing is playing have no entry. Empty when the file has no
        instrument track or no sounding note.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    if not midi_data.instruments:
        # No instrument track at all: nothing to sample (avoids an IndexError below).
        return {}
    piano_data = midi_data.instruments[0]  # First track, assumed to be the piano.

    # Piano roll is a 2d array where 1st dimension is notes (128 possible notes, see MIDI numbers)
    # and the other dimension is frame number. A cell is greater than 0 while that note is playing.
    piano_roll = piano_data.get_piano_roll(fs=fs)

    # Scan the transposed roll so the hits come out ordered by frame first and
    # by note second; every frame's notes then form one contiguous run.
    frame_idx, note_idx = np.nonzero(piano_roll.T > 0)
    if frame_idx.size == 0:
        return {}

    # frames: distinct frames in which music is playing; starts: where each run begins.
    frames, starts = np.unique(frame_idx, return_index=True)

    # Cut the note list at every run boundary instead of re-searching the whole
    # index array once per frame.
    notes_per_frame = np.split(note_idx.astype(np.uint8), starts[1:])

    return dict(zip(frames, notes_per_frame))  # Frame -> notes played in that frame.

In [30]:
def _frame_token(frame_notes, frame):
    """Encode the notes of one frame as a comma-joined string; 'e' when the frame has no entry."""
    return ','.join(str(note) for note in frame_notes.get(frame, ['e']))


def get_samples(frame_notes: dict[int, list[int]], seq_len):
    """
    Get list of samples from given frame_notes dictionary as a tuple of (input, output).

    Every frame after the first sounding frame, up to and including the last
    sounding frame, is used once as a prediction target. The input of a sample
    is the seq_len frames immediately preceding its target; positions that
    would fall before the first sounding frame are filled with 'e'. Silent
    frames inside the song are encoded with the same 'e' token.

    Parameters
    ----------
    frame_notes : dict
        Frame index -> notes playing in that frame (as returned by get_tones).
    seq_len : int
        Number of tokens in each input sequence.

    Returns
    -------
    (ip, op)
        ip is a list of input sequences (lists of string tokens) and op is the
        list of target tokens; ip[k] is followed in the song by op[k].
        Both are empty when frame_notes has fewer than two distinct frames.
    """
    if len(frame_notes) == 0:
        return [], []

    start_frame, end_frame = min(frame_notes), max(frame_notes)

    ip, op = [], []
    # offset = distance of the target frame from the first sounding frame.
    # Running it up to the last sounding frame keeps the final windows that a
    # tighter bound would silently drop.
    for offset in range(1, end_frame - start_frame + 1):
        target = start_frame + offset
        n_real = min(offset, seq_len)  # Frames that can actually be taken from the song.

        ip_sample = ['e'] * (seq_len - n_real)  # Left padding while the window is not full yet.
        ip_sample.extend(_frame_token(frame_notes, frame) for frame in range(target - n_real, target))

        ip.append(ip_sample)
        op.append(_frame_token(frame_notes, target))

    return ip, op
        

In [49]:
def _load_dir_samples(dir_path, fs, seq_len):
    """Collect the (input, output) samples of every .midi file directly inside dir_path."""
    inputs, outputs = [], []

    # os.listdir returns names in arbitrary order; sort them so the order of
    # the produced samples is the same on every run and platform.
    for filename in sorted(os.listdir(dir_path)):
        if not filename.endswith(".midi"):
            continue
        print("Processing {}".format(filename))
        tones = get_tones(os.path.join(dir_path, filename), fs)
        file_inputs, file_outputs = get_samples(tones, seq_len)
        inputs += file_inputs
        outputs += file_outputs

    return inputs, outputs


def get_train_test(path_train: str, path_test: str, fs: float, seq_len: int):
    """
    Load training and testing samples using all .midi files stored in
    directories given by path_train and path_test respectively.

    Parameters
    ----------
    path_train, path_test : str
        Directories whose ``.midi`` files are read (files with any other
        extension are ignored; sub-directories are not descended into).
    fs : float
        Sampling frequency forwarded to get_tones.
    seq_len : int
        Input sequence length forwarded to get_samples.

    Returns
    -------
    (X_train, X_test, y_train, y_test)
        Lists of input sequences and of their target tokens, concatenated
        over the files of each directory in sorted file-name order.
    """
    X_train, y_train = _load_dir_samples(path_train, fs, seq_len)
    X_test, y_test = _load_dir_samples(path_test, fs, seq_len)

    return X_train, X_test, y_train, y_test

In [50]:
# Dataset configuration: directories that hold the .midi files of each split.
train_dir = "data/train"
test_dir = "data/test"
# Sampling frequency (fs) used when each song is converted to a piano roll.
f_sample = 30
# Number of tokens in every input sequence produced by get_samples.
seq_length = 50

# Build the samples for both splits; X_* are lists of input sequences and
# y_* the token that follows each sequence.
X_train, X_test, y_train, y_test = get_train_test(train_dir, test_dir, f_sample, seq_length)

# Correctness checkpoint
# print(len(X_train), len(X_test), len(y_train), len(y_test))


Processing MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi
Processing MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav.midi
Processing MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_08_Track08_wav.midi
Processing MIDI-Unprocessed_SMF_05_R1_2004_01_ORIG_MID--AUDIO_05_R1_2004_02_Track02_wav.midi
Processing MIDI-Unprocessed_SMF_05_R1_2004_01_ORIG_MID--AUDIO_05_R1_2004_03_Track03_wav.midi
Processing MIDI-Unprocessed_SMF_07_R1_2004_01_ORIG_MID--AUDIO_07_R1_2004_02_Track02_wav.midi
Processing MIDI-Unprocessed_SMF_07_R1_2004_01_ORIG_MID--AUDIO_07_R1_2004_04_Track04_wav.midi
Processing MIDI-Unprocessed_SMF_07_R1_2004_01_ORIG_MID--AUDIO_07_R1_2004_06_Track06_wav.midi
Processing MIDI-Unprocessed_SMF_12_01_2004_01-05_ORIG_MID--AUDIO_12_R1_2004_03_Track03_wav--1.midi
Processing MIDI-Unprocessed_SMF_12_01_2004_01-05_ORIG_MID--AUDIO_12_R1_2004_07_Track07_wav.midi
Processing MIDI-Unprocessed_SMF_12_01_2004_01-05_ORI

In [53]:
# Correctness checkpoint.
# print(X_train[0], y_train[0], X_train[1], y_train[1])

['e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', '71'] 71 ['e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', 'e', '71', '71'] 71
