In [None]:
import numpy as np
from music21 import converter, instrument, note, chord, meter, duration, stream
import music21
import os
from tqdm import tqdm
from music21 import common

In [None]:
def _roll_to_events(roll):
    """Collapse a time-major (t, 130) roll into a list of ``(pitch, n_steps)`` events.

    ``pitch`` is a MIDI number for a note or None for a rest, and ``n_steps``
    counts 16th-note time steps. Within one step, an onset (columns 0-127)
    wins over a hold (column 128), which wins over a rest (column 129). If
    several onsets share a step, the highest pitch is kept.

    A hold or an empty step extends the previous event. When there is no
    previous event yet, it starts a rest so the decoded timeline keeps the
    roll's length.
    """
    hold, rest = 128, 129
    events = []  # [pitch or None, n_steps]; lists so the length can grow in place
    for step in roll:
        active = np.flatnonzero(step)
        onsets = active[active < hold]
        if onsets.size:
            events.append([int(onsets[-1]), 1])
        elif step[rest] and not step[hold]:
            # Merge consecutive rest steps into one rest.
            if events and events[-1][0] is None:
                events[-1][1] += 1
            else:
                events.append([None, 1])
        elif events:
            # Hold (or nothing marked): the current note/rest keeps going.
            events[-1][1] += 1
        else:
            events.append([None, 1])
    return [(pitch, steps) for pitch, steps in events]


def decode_midi(data, header=None):
    """Decode a piano roll produced by ``encode_midi`` back into a music21 stream.

    Each time step is one 16th note (0.25 quarter lengths). Columns 0-127
    mark a note onset at that MIDI pitch. Column 128 means "keep holding the
    previous note", and column 129 marks a rest.

    Parameters
    ----------
    data : numpy.ndarray of bool, shape (t, 130) or (130, t)
        The roll. ``encode_midi`` returns the time-major ``(t, 130)`` layout.
        The transposed ``(130, t)`` layout is also accepted. A square
        ``(130, 130)`` input is read as time-major.
    header : list or None
        Non-note music21 elements, such as the ``header`` list filled by
        ``encode_midi``. All but the first element are appended before the
        notes.

    Returns
    -------
    music21.stream.Stream
        Notes and rests appended back to back, so their offsets follow the
        timing of the roll.

    Raises
    ------
    ValueError
        If ``data`` is not 2-D or has no axis of length 130.
    """
    roll = np.asarray(data)
    if roll.ndim != 2 or 130 not in roll.shape:
        raise ValueError(
            'expected a 2-D roll with a 130-wide axis, got shape {}'.format(roll.shape))
    if roll.shape[1] != 130:  # (130, t) layout: make it time-major
        roll = roll.T

    strm = stream.Stream()
    # TODO: right now I'm skipping the instrument because it adds like 10
    # empty measures for some reason
    for element in (header or [])[1:]:
        strm.append(element)

    for pitch, steps in _roll_to_events(roll):
        if pitch is None:
            element = note.Rest()
        else:
            element = note.Note()
            element.pitch.midi = pitch
        element.duration = duration.Duration(steps * 0.25)
        strm.append(element)

    return strm

In [None]:
from collections import defaultdict
def insert_note(arr, idx, pitch, note_off):
    """Mark one note in a 2-D boolean roll indexed ``[time step, column]``, in place.

    Sets the onset cell ``arr[idx, pitch]``. If the note lasts past its onset
    step (``note_off > idx``), it also sets the hold column (-2) on steps
    ``idx + 1`` through ``note_off`` inclusive. Returns None.
    """
    arr[idx, pitch] = True
    sustained = note_off > idx
    if not sustained:
        # Guard matters: a negative note_off would otherwise make the slice
        # below wrap around from the end of the array.
        return
    arr[idx + 1:note_off + 1, -2] = True

def beat_to_total_beat(offset_16th, beat, single_note_length):
    """Convert a 1-based beat inside a measure into an absolute step index.

    ``offset_16th`` is the step at which the measure starts. ``beat`` is a
    1-based beat number, and ``single_note_length`` is the length of one beat
    in steps. The sum is truncated with ``int``.
    """
    steps_into_measure = (beat - 1) * single_note_length
    return int(offset_16th + steps_into_measure)

def _render_roll(onsets, t):
    """Render ``{onset step: (pitch, last step)}`` into a (t, 130) bool roll, or None if empty.

    For each note, the pitch column is set at the onset step. The hold
    column (-2) is set on the following steps up to the note's last step.
    Every step not covered by a note gets the rest column (-1). A note is
    clipped so it ends before the next onset and before step ``t``.
    """
    if not onsets:
        return None
    hold, rest = -2, -1
    roll = np.zeros((t, 130), dtype=bool)
    starts = sorted(onsets)
    roll[:starts[0], rest] = True  # silence before the first note
    for i, start in enumerate(starts):
        pitch, last = onsets[start]
        # Zero-length (e.g. grace) notes still occupy their onset step;
        # without this a rest would be painted over the onset.
        last = max(last, start)
        # Next onset, or the end of the roll for the final note.
        stop = starts[i + 1] if i + 1 < len(starts) else t
        last = min(last, stop - 1)
        roll[start, pitch] = True
        roll[start + 1:last + 1, hold] = True
        roll[last + 1:stop, rest] = True
    return roll


def encode_midi(measureMap, t=32, measure_length=16, single_note_length=4, header=None):
    """Encode a monophonic line from a music21 measure map as a boolean piano roll.

    The roll has one row per 16th-note step and 130 columns. Columns 0-127
    mark a note onset at that MIDI pitch, column 128 (-2) marks that the
    previous note is still held, and column 129 (-1) marks a rest.

    Parameters
    ----------
    measureMap : dict
        Result of music21 ``measureOffsetMap()``: measure offset (in quarter
        lengths) -> list of measures. Only the first measure at each offset
        is read. Only ``note.Note`` elements become notes, and only the first
        note starting on a given step is kept.
    t : int
        Number of 16th-note steps in the output. Material starting at or
        after step ``t`` is dropped, and the last note is clipped to fit.
    measure_length, single_note_length :
        Accepted for backward compatibility and not used. Positions come
        from music21 offsets, which are always in quarter lengths, so a
        fixed factor of 4 converts them to 16ths for every time signature.
    header : list or None
        If a list, every non-Note element seen before the first note is
        appended to it (e.g. for passing to ``decode_midi``).

    Returns
    -------
    numpy.ndarray of bool with shape (t, 130), or None if no note starts
    before step ``t``.
    """
    onsets = {}  # onset step -> (midi pitch, last sounding step)
    fill_header = True
    # Sorted so the early `break` below is valid regardless of dict order.
    for offset in sorted(measureMap):
        measure_start = offset * 4  # quarter lengths -> 16ths
        if measure_start >= t:
            break
        for element in measureMap[offset][0]:
            if isinstance(element, note.Note):
                fill_header = False
                # element.offset is relative to the measure being iterated.
                onset = int(measure_start + element.offset * 4)
                if onset < t and onset not in onsets:
                    length = int(element.duration.quarterLength * 4)
                    onsets[onset] = (int(element.pitch.midi), onset + length - 1)
            elif fill_header and header is not None:
                header.append(element)

    return _render_roll(onsets, t)

def load_single(path, measures=2, desired_instrument='Guitar'):
    """Parse one MIDI file and encode the start of one part as a piano roll.

    Parameters
    ----------
    path : str
        MIDI file to parse with music21.
    measures : int
        Number of 16-step bars to encode. The roll always has
        ``measures * 16`` steps, so every example has the same shape
        regardless of time signature.
    desired_instrument : str
        Part to encode, matched against ``str(part.getInstrument())``. Falls
        back to the first part when no part matches.

    Returns
    -------
    Whatever ``encode_midi`` returns for the chosen part: a boolean roll, or
    None when no notes were found.
    """
    midi = converter.parse(path)
    ts = midi.getTimeSignatures()[0]
    num, denom = ts.numerator, ts.denominator

    # Length of one notated beat unit and of a whole measure, in 16ths:
    # x/4 -> 4, x/8 -> 2, x/16 -> 1, x/2 -> 8, ... (any power-of-two denominator).
    single_note_length = 16 // denom if 16 % denom == 0 else 16 / denom
    measure_length = num * single_note_length

    score = instrument.partitionByInstrument(midi)
    if score:  # file has instrument parts
        part = next(
            (p for p in score.parts if str(p.getInstrument()) == desired_instrument),
            None,
        )
        if part is None:
            part = score.parts[0]
        part.makeMeasures(inPlace=True)
        measureMap = part.measureOffsetMap()
    else:  # file has notes in a flat structure
        measureMap = midi.flat.measureOffsetMap()

    return encode_midi(measureMap, t=measures * 16, measure_length=measure_length,
                       single_note_length=single_note_length)

def parse(filename):
    """Parse ``filename`` with music21's ``converter.parse`` and return the result.

    Any exception raised by music21 propagates to the caller.
    """
    parsed = converter.parse(filename)
    return parsed

def load_data(folder, measures=2, instrument='Guitar'):
    """Encode every file in ``folder`` with ``load_single``, in parallel.

    Parameters
    ----------
    folder : str
        Directory of MIDI files. A trailing separator is optional.
    measures : int
        Passed to ``load_single``.
    instrument : str
        Desired instrument name passed to ``load_single``. Inside this
        function the name shadows the music21 ``instrument`` module.

    Returns
    -------
    numpy.ndarray
        The successful encodings (those that are ndarrays) stacked with
        ``np.array``. Files whose encoding is None are dropped.
    """
    # Sorted so the dataset order is reproducible across runs and platforms.
    jobs = [[os.path.join(folder, f), measures, instrument]
            for f in sorted(os.listdir(folder))]
    results = common.runParallel(jobs, load_single, updateFunction=True, unpackIterable=True)
    return np.array([result for result in results if isinstance(result, np.ndarray)])



In [None]:
parent = 'data/train/'
# Smoke-test load_single on every listed training file after the first 14.
remaining = os.listdir(parent)[14:]
for filename in tqdm(remaining):
    path = os.path.join(parent, filename)
    data = load_single(path)

In [None]:
from timeit import default_timer as timer

path = 'data/train/'
start = timer()
train_x = load_data(path)
end = timer()

# The average is over every file in the folder, including ones that failed
# to encode. Guard against an empty folder, which would divide by zero.
total_time = end - start
n_files = len(os.listdir(path))
avg_time = total_time / n_files if n_files else 0.0
print('total time: {}, avg time per song: {}'.format(total_time, avg_time))

In [None]:
# Save the stacked training rolls (built with load_data's default of 2 measures).
np.save(file='data/training_data_2bar.npy', arr=train_x)

In [None]:
# Number of songs that produced an encoding.
train_x.shape[0]

In [None]:
# Round-trip one song through encode_midi / decode_midi and write the result.
# NOTE(review): `measureMap` is not created by any earlier cell (it is local
# to load_single), so this cell previously depended on leftover kernel state.
# It is rebuilt here with the same part-selection logic as load_single. The
# source path below is inferred from the output filename — confirm.
source_midi = converter.parse('data/Sin_City.mid')
parts = instrument.partitionByInstrument(source_midi)
if parts:
    part = next((p for p in parts.parts if str(p.getInstrument()) == 'Guitar'),
                parts.parts[0])
    part.makeMeasures(inPlace=True)
    measureMap = part.measureOffsetMap()
else:
    measureMap = source_midi.flat.measureOffsetMap()

header = []
measures = 16
data = encode_midi(measureMap, t=measures*16, header=header)
if data is None:
    raise ValueError('no notes found in the first {} measures'.format(measures))
strm = decode_midi(data, header)
# NOTE(review): music21's MIDI export appears to take ticks-per-quarter from
# its own defaults rather than from a Stream attribute; this may be a no-op — verify.
strm.ticksPerQuarterNote = 1024
strm.makeRests(fillGaps=True, inPlace=True)
strm.write('midi', fp='data/Sin_City_decoded.mid')