In [28]:
#!pip install pretty_midi  # Uncomment to install prett_midi
import pretty_midi
import numpy as np
import os
import torch
import random
from torch.utils.data._utils.collate import default_convert

In [2]:
def get_tones(midi_path: str, fs: float):
    """
    Get a dictionary mapping (time)frame to a (list of) notes played at that time in a song at midi_path.
    """
    
    
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    piano_data = midi_data.instruments[0]  # Select piano instrument.
    
    # Piano roll is a 2d list where 1st dimension is notes (128 possible notes, see MIDI numbers) and
    # other dimension is frame number. Total number of frames is decided by duration of the song and
    # sampling frequency, fs. (piano_roll[i][j] is velocity?)
    piano_roll = piano_data.get_piano_roll(fs = fs)
    
    # If a note is playing, its velocity would be greater than 0.
    indices = np.where(piano_roll > 0)  # List of 2 lists, one for each dimension of piano_roll.
    # If indices[0][k] = i and indices[1][k] = j, then piano_roll[i][j] > 0 i.e. ith note is playing in jth frame.
    frames = np.unique(indices[1])  # Frames in which music is playing. 
    
    frame_notes = {}  # Frame -> List of notes played in that frame. 
    
    for frame in frames:
        frame_notes[frame] = indices[0][np.where(indices[1] == frame)].astype(np.uint8)  # List of notes played in frame.
    
    return frame_notes

In [3]:
def get_sequences(frame_notes: dict[int, list[int]], seq_len):
    """
    Get list of samples from given frame_notes dictionary as a tuple of (input, output).
    """
    
    if len(frame_notes.keys()) == 0:
        return [], []
    
    seq_len += 1  # We will use last note in the sequences as the output.
    
    sequences = []
    start_frame, end_frame = min(frame_notes.keys()), max(frame_notes.keys())
    exit = False
    
    ## Adding [e e e ... start].....[start....start+seq_len]
    for n_empty in range(seq_len - 1, -1, -1):   
        sample = ['e'] * n_empty
        for frame in range(start_frame, start_frame + seq_len - n_empty):
            if frame >= end_frame:  # We cannot create any more test cases.
                exit = True
                break
            sample_notes = frame_notes.get(frame, ['e'])
            sample.append(','.join(str(note) for note in sample_notes))    
        
        if exit: 
            break
        
        sequences.append(sample)
            
    if exit:
        return sequences
    
    ## Adding [start+1....start+seq_len+1]...[end-seq_len-1....end-1]
    
    for begin in range(start_frame + 1, end_frame - seq_len - 1):
        sample = []
        for frame in range(begin, begin + seq_len):
            if frame >= end_frame:
                exit = True
                break
            sample_notes = frame_notes.get(frame, ['e'])
            sample.append(','.join(str(note) for note in sample_notes))
            
        if exit:
            break
            
        sequences.append(sample)
        
    return sequences        

In [4]:
def get_train_test(path_train: str, path_test: str, fs: float, seq_len: int):
    """
    Load training and testing samples using all .midi files stored in
    directories given by path_train and path_test respectively.  
    """
    
    train, test = [], []
    
    for filename in os.listdir(path_train):
        if filename.endswith(".midi"):
            print("Processing {}".format(filename))
            tones = get_tones(os.path.join(path_train, filename), fs)
            samples = get_sequences(tones, seq_len)
            train += samples
            
    for filename in os.listdir(path_test):
        if filename.endswith(".midi"):
            print("Processing {}".format(filename))
            tones = get_tones(os.path.join(path_test, filename), fs)
            samples = get_sequences(tones, seq_len)
            test += samples
            
    return train, test

In [5]:
def tokenizer(samples):
    tones2tokens = {}
    tokens2tones = {}
    cnt = 0

    for sample in samples:
        for j in sample:
            if j not in tones2tokens.keys():
                tones2tokens[j]=cnt
                tokens2tones[cnt]=j
                cnt+=1
                
    return tones2tokens, tokens2tones

In [6]:
def piano_roll_to_pretty_midi(piano_roll, fs=30, program=0):
    '''Convert a Piano Roll array into a PrettyMidi object
     with a single instrument.
    Parameters
    ----------
    piano_roll : np.ndarray, shape=(128,frames), dtype=int
        Piano roll of one instrument
    fs : int
        Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    program : int
        The program number of the instrument.
    Returns
    -------
    midi_object : pretty_midi.PrettyMIDI
        A pretty_midi.PrettyMIDI class instance describing
        the piano roll.
    '''
    notes, frames = piano_roll.shape
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=program)

    # pad 1 column of zeros so we can acknowledge inital and ending events
    piano_roll = np.pad(piano_roll, [(0, 0), (1, 1)], 'constant')

    # use changes in velocities to find note on / note off events
    velocity_changes = np.nonzero(np.diff(piano_roll).T)

    # keep track on velocities and note on times
    prev_velocities = np.zeros(notes, dtype=int)
    note_on_time = np.zeros(notes)

    for time, note in zip(*velocity_changes):
        # use time + 1 because of padding above
        velocity = piano_roll[note, time + 1]
        time = time / fs
        if velocity > 0:
            if prev_velocities[note] == 0:
                note_on_time[note] = time
                prev_velocities[note] = velocity
        else:
            pm_note = pretty_midi.Note(
                velocity=prev_velocities[note],
                pitch=note,
                start=note_on_time[note],
                end=time)
            instrument.notes.append(pm_note)
            prev_velocities[note] = 0
    pm.instruments.append(instrument)
    return pm

In [7]:
def create_midi(op_filename, generated_tokens, token2notes, fs, note_velocity = 100):
    """
    From a given list of generated_tokens, create a playable .midi file.
    """
    generated_notes = generated_tokens
    #generated_notes = [token2notes[token] for token in generated_tokens]
    piano_roll = np.zeros((128, len(generated_tokens)), dtype = np.int16)
    
    for frame, notes in enumerate(generated_notes):
        if notes != 'e':
            notes = notes.split(',')
            for note in notes:
                piano_roll[int(note)][frame] = note_velocity
        
    midi_data = piano_roll_to_pretty_midi(piano_roll, fs=fs)
    
#     for note in midi_data.instruments[0].notes:
#         note.velocity = 100
        
    midi_data.write(op_filename)

In [8]:
def get_batches(train, batch_size):
    random.shuffle(train)
    
    for i in range(0, len(train), batch_size):
        batch = train[i:i+batch_size]
        yield batch


In [9]:
def fit(samples, notes2tokens):
    tokenized_samples = []

    for sequence in train:
        sample = []
        for tone in sequence:
            sample.append(notes2tokens[tone])
    
        tokenized_samples.append(sample)
        
    return tokenized_samples


In [10]:
train_dir = "data/train"
test_dir = "data/test"
f_sample = 30
seq_length = 50

train, test = get_train_test(train_dir, test_dir, f_sample, seq_length)
notes2tokens, tokens2tones = tokenizer(train)
train_tokenized = fit(train, notes2tokens)

Processing MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi
Processing MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav.midi
Processing MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_10_Track10_wav.midi


In [25]:
print(len(train_tokenized[0]))
len(train_tokenized)

51


37013

In [35]:
## Example usage for iterating over a complete training set once (i.e. one epoch).
for batch in torch.utils.data.DataLoader(train_tokenized, batch_size = 16, shuffle = True, collate_fn=default_convert):
    print(len(batch), len(batch[0]), batch[0])
    break

16 51 [224, 224, 224, 795, 795, 795, 795, 795, 795, 795, 795, 795, 795, 795, 0, 0, 389, 0, 0, 0, 224, 0, 0, 0, 0, 1, 1, 1, 1, 1, 13, 13, 1, 1, 1, 270, 1, 1, 1, 41, 0, 0, 0, 0, 183, 417, 417, 116, 1515, 41, 417]


In [12]:
temp_seq = []  # This should be assigned to final output sequence.
for batch in get_batches(train, 16):
    for seq in batch:
        ip, op = seq[:-1], seq[-1]
        print(ip, op)
        temp_seq += ip
    break
        
create_midi("open_at_your_own_risk.midi", temp_seq, None, 30, note_velocity = 100)

['70', '70', '58,70', '58,79', 'e', 'e', 'e', '49,76', '49,76', 'e', 'e', 'e', '51,69', '51,69', '69', '69', '69', '54,69', '69', '69', '69', '69', '69', '57,78', 'e', 'e', 'e', '60,75', 'e', 'e', 'e', 'e', '59,68', '68', '68', '68', '68', '52,68', '52,68', '52,68', '52,68', '52,68,77', '52,68,77', '52,68,77', '52,68,77', 'e', '59,75', 'e', 'e', 'e'] e
['47,76', '47,76', '47,76', '40,71,76', '40,76', '40,76', '43,76', '43,76', '43,76', '76', '45,72,76', '76', '76', '47,76', '47,76', '47,76', '76', '43,74,76', '43,74,76', '72,76', '47,76', '47,71,76', '47,76', '45,72,76', '45,76', '76', '47,76', '47,76', '47,76', '76', '43,74,76', '43,74,76', '72,74', '72', '47,71', '47', '72', '45', '45', 'e', '48', '48', '48', '42,81', '42', '42', '42', '42,45', '45', 'e'] 47,71,79
['52,79', '52,79', '52,79', '52,79', '52,79', '42,52,71,79', '42,52,71,79', '42,71,79', '42,71,79', '42,71,79', '42,71,79', '42,71,79', '42,71,79', '42,71,79', '42,71,79', '42,71,79', '42,45,69,79', '42,45,69,79', '42,69,79