In [6]:
#!pip install pretty_midi  # Uncomment to install prett_midi
import pretty_midi
import numpy as np

In [31]:
def get_tones(midi_path: str, fs: float):
    """
    Get a dictionary mapping (time)frame to a (list of) notes played at that time in a song at midi_path.
    """
    
    
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    piano_data = midi_data.instruments[0]  # Select piano instrument.
    
    # Piano roll is a 2d list where 1st dimension is notes (128 possible notes, see MIDI numbers) and
    # other dimension is frame number. Total number of frames is decided by duration of the song and
    # sampling frequency, fs. (piano_roll[i][j] is velocity?)
    piano_roll = piano_data.get_piano_roll(fs = fs)
    
    # If a note is playing, its velocity would be greater than 0.
    indices = np.where(piano_roll > 0)  # List of 2 lists, one for each dimension of piano_roll.
    # If indices[0][k] = i and indices[1][k] = j, then piano_roll[i][j] > 0 i.e. ith note is playing in jth frame.
    
    frames = np.unique(indices[1])  # Frames in which music is playing. 
    
    frame_notes = {}  # Frame -> List of nodes played in that frame.
    
    for frame in frames:
        frame_notes[frame] = indices[0][np.where(indices[1] == frame)].astype(np.uint8)  # List of notes played in frame.
    
    return frame_notes

In [32]:
def get_train_test(frame_notes: dict[int, list[int]], seq_len):
    """
    Get list of training and testing samples from given frame_notes dictionary.
    """
    
    if len(frame_notes.keys()) == 0:
        return [], []
    
    
    train, test = [], []
    start_frame, end_frame = min(frame_notes.keys()), max(frame_notes.keys())
    exit = False
    ## Adding [e e e ... start].....[start....start+seq_len]
    for n_empty in range(seq_len - 1, -1, -1):   
        train_sample = ['e'] * n_empty
        for frame in range(start_frame, start_frame + seq_len - n_empty):
            if frame >= end_frame:
                # We cannot create any more test cases.
                exit = True
                break
            train_sample_notes = frame_notes.get(frame, ['e'])
            train_sample.append(','.join(str(note) for note in train_sample_notes))    
        
        if exit: 
            break
        
        train.append(train_sample)
        
        test_sample_notes = frame_notes.get(start_frame + seq_len - n_empty, ['e'])
        test.append(','.join(str(note) for note in test_sample_notes))
            
    if exit:
        return train, test
    
    ## Adding [start+1....start+seq_len+1]...[end-seq_len-1....end-1]
    
    for begin in range(start_frame + 1, end_frame - seq_len - 1):
        train_sample = []
        for frame in range(begin, begin + seq_len):
            if frame >= end_frame:
                exit = True
                break
            train_sample_notes = frame_notes.get(frame, ['e'])
            train_sample.append(','.join(str(note) for note in train_sample_notes))
            
        if exit:
            break
            
        train.append(train_sample)
        
        test_sample_notes = frame_notes.get(begin + seq_len, ['e'])
        test.append(','.join(str(note) for note in test_sample_notes))
        
    return train, test
        

In [33]:
gt = get_tones('data/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav.midi', fs = 30)
tr, ts = get_train_test(gt, 5)
print(len(tr), len(ts))

7990 7990
