In [1]:
from miditoolkit.midi import parser as mid_parser  
from miditoolkit.midi import containers as ct
from numpy import array, linspace
from sklearn.neighbors import KernelDensity
from matplotlib.pyplot import plot
from scipy.signal import argrelextrema
from scipy.ndimage import gaussian_filter1d
import numpy as np
from miditoolkit.pianoroll import parser as pr_parser
from miditoolkit.pianoroll import utils
import matplotlib.pyplot as plt

In [2]:
# helper functions
def read_midi(path):
    """Load the pitched (non-drum) notes of a MIDI file.

    Parameters
    ----------
    path : path of the MIDI file, passed straight to ``mid_parser.MidiFile``.

    Returns
    -------
    (notes, ticks_per_beat)
        ``notes`` is one flat list holding the note objects of every
        instrument whose ``is_drum`` flag is false, ordered by ``start``.
        ``ticks_per_beat`` is read from the parsed file.
    """
    midi = mid_parser.MidiFile(path)

    # Flatten all non-drum tracks into a single list of notes.
    notes = [
        note
        for instrument in midi.instruments
        if not instrument.is_drum
        for note in instrument.notes
    ]

    # Order by start time; list.sort is stable, so notes sharing a start
    # keep their track order.
    notes.sort(key=lambda note: note.start)
    return notes, midi.ticks_per_beat
    
def write_midi(notes,path='out.mid',tick_per_beat=480):
    """Write ``notes`` to a MIDI file containing a single instrument.

    Parameters
    ----------
    notes : iterable of objects exposing ``start``, ``end``, ``pitch`` and
        ``velocity``.
    path : destination file, default ``'out.mid'``.
    tick_per_beat : stored as the output file's ``ticks_per_beat``.

    Raises
    ------
    AssertionError
        If a note has a falsy velocity (e.g. 0); nothing is written then.
    """
    track = ct.Instrument(program=0,is_drum=False,name='kept notes')
    for src in notes:
        assert src.velocity
        # Copy the four fields into a fresh Note rather than reusing `src`.
        copied = ct.Note(start=src.start,end=src.end,pitch=src.pitch,velocity=src.velocity)
        track.notes.append(copied)

    midi = mid_parser.MidiFile()
    midi.ticks_per_beat = tick_per_beat
    midi.instruments = [track]
    midi.dump(path)

In [3]:
def mergeIntervals(arr):
    """Merge overlapping or touching ``[start, end]`` intervals.

    Parameters
    ----------
    arr : list of indexable pairs whose element 0 is the start and element 1
        the end. The list is sorted IN PLACE by start (side effect kept for
        backward compatibility).

    Returns
    -------
    list of ``[start, end]`` lists, ordered by start. The inner lists are
    newly created, so editing the result never touches the input intervals.
    An empty input yields an empty list.

    Notes
    -----
    An interval opens a new group only when its start is strictly greater
    than the end of the current group, so ``[1, 2]`` and ``[2, 3]`` merge
    into ``[1, 3]``. No sentinel values are used, hence any comparable
    endpoints (including very negative numbers) are handled.
    """
    # Sorting based on the increasing order of the start intervals
    arr.sort(key=lambda x: x[0])

    # `merged[-1]` is always the group currently being extended.
    merged = []
    for interval in arr:
        start, end = interval[0], interval[1]
        if merged and start <= merged[-1][1]:
            # Overlaps or touches the current group: only ever grow its end.
            if end > merged[-1][1]:
                merged[-1][1] = end
        else:
            merged.append([start, end])
    return merged

In [4]:
def interval_histogram(notep):
    """Map each pitch to the merged time spans during which it sounds.

    Parameters
    ----------
    notep : iterable of objects exposing ``pitch``, ``start`` and ``end``.

    Returns
    -------
    dict
        ``{pitch: [[start, end], ...]}``. Pitches appear in order of first
        occurrence. For a pitch with several notes the spans are combined
        with ``mergeIntervals``; a pitch with one note keeps that single
        ``[start, end]`` span.
    """
    # First pass: collect the raw spans of every pitch.
    spans_by_pitch = dict()
    for note in notep:
        spans_by_pitch.setdefault(note.pitch, []).append([note.start, note.end])

    # Second pass: merge once per pitch instead of re-sorting and re-merging
    # the whole list after every appended note.
    return {
        pitch: mergeIntervals(spans) if len(spans) > 1 else spans
        for pitch, spans in spans_by_pitch.items()
    }

def find_intersect(note,intervals):
    """Tell whether ``note`` overlaps at least one of ``intervals``.

    Parameters
    ----------
    note : object exposing ``start`` and ``end``.
    intervals : iterable of indexable ``[start, end]`` pairs.

    Returns
    -------
    bool
        True as soon as one interval shares a span of non-zero length with
        the note; False otherwise, including for an empty ``intervals``.
        Spans that merely touch at an endpoint do not count.
    """
    # Two spans overlap when the later start lies strictly before the
    # earlier end.
    return any(
        max(note.start, span[0]) < min(note.end, span[1])
        for span in intervals
    )
    

In [6]:
import os
# Batch scan: for every piece folder under the dataset directory, report the
# fraction of orchestra notes whose pitch is sounding in the piano part at an
# overlapping time, and flag pieces where that fraction is below one half.
paired_data=[]
data_path='./'
paired=['LOP']
extracted_o_count=0
root_path = os.walk(os.path.join(data_path, paired[0]))
# NOTE(review): os.walk descends recursively, so a sub-directory nested inside
# a piece folder would also be treated as a piece -- confirm the layout is flat.
for root, directories, files in root_path:
    for d in directories:
        print(d)
        piece_dir = os.path.join(root, d)
        noteo,tpbo = read_midi(os.path.join(piece_dir, "orchestra.mid"))
        notep,tpbp = read_midi(os.path.join(piece_dir, "piano.mid"))
        # Tick values are only comparable when both files share a resolution.
        # Explicit test instead of try/assert/bare-except, which would be
        # skipped entirely when Python runs with assertions disabled.
        if tpbo != tpbp:
            print("GG", d)
            write_midi(noteo, os.path.join(piece_dir, "dontuse.mid"), tpbo)
            continue
        histp = interval_histogram(notep)
        kept_notes = []
        for note in noteo:
            pitch_spans = histp.get(note.pitch)
            # Keep the note only if the piano plays the same pitch at an
            # overlapping time.
            if pitch_spans is not None and find_intersect(note, pitch_spans):
                kept_notes.append(note)
        # An orchestra file without any non-drum note counts as ratio 0.0
        # rather than raising ZeroDivisionError.
        kept_ratio = len(kept_notes)/len(noteo) if noteo else 0.0
        print("Kept notes:" ,kept_ratio)
        if kept_ratio < 0.5:
            # Output currently disabled:
            # write_midi(noteo, os.path.join(piece_dir, "dontuse.mid"), tpbo)
            print(d, "DONTUSE")
        kept_notes = sorted(kept_notes,key=lambda x: x.start)
        # Output currently disabled:
        # write_midi(kept_notes, os.path.join(piece_dir, "labels.mid"), tpbo)
        

bouliane-0
Kept notes: 0.9914965986394558
bouliane-1
Kept notes: 0.7278828553996339
bouliane-10
Kept notes: 0.6909090909090909
bouliane-11
Kept notes: 0.974025974025974
bouliane-12
Kept notes: 0.9799599198396793
bouliane-13
Kept notes: 0.8341968911917098
bouliane-14
Kept notes: 0.9975609756097561
bouliane-15
Kept notes: 0.7927461139896373
bouliane-16
Kept notes: 0.845861084681256
bouliane-17
Kept notes: 0.8616684266103485
bouliane-18
Kept notes: 0.8928571428571429
bouliane-19
Kept notes: 0.675032175032175
bouliane-2
Kept notes: 0.8666666666666667
bouliane-20
Kept notes: 0.9939789262418465
bouliane-21
Kept notes: 0.8392857142857143
bouliane-22
Kept notes: 1.0
bouliane-23
Kept notes: 0.9846153846153847
bouliane-24
Kept notes: 0.9346016646848989
bouliane-25
Kept notes: 0.7194656488549618
bouliane-26
Kept notes: 0.8148148148148148
bouliane-27
Kept notes: 0.9938271604938271
bouliane-28
Kept notes: 0.9911616161616161
bouliane-29
Kept notes: 0.9333333333333333
bouliane-3
Kept notes: 0.9016152

In [9]:
# Single-piece run: keep the orchestra notes whose pitch is sounding in the
# piano part at an overlapping time, report the kept fraction, and write the
# kept notes to labels.mid.
noteo, tpbo = read_midi("./LOP/hand_picked_Spotify-33/orchestra.mid")
notep, tpbp = read_midi("./LOP/hand_picked_Spotify-33/piano.mid")
# Both files must share a tick resolution for start/end values to be comparable.
assert tpbo==tpbp
histp = interval_histogram(notep)
kept_notes = []
for note in noteo:
    # Pitches the piano never plays have no entry in the histogram.
    pitch_spans = histp.get(note.pitch)
    if pitch_spans is not None and find_intersect(note, pitch_spans):
        kept_notes.append(note)
print("Kept notes:" ,len(kept_notes)/len(noteo))
kept_notes.sort(key=lambda kept: kept.start)
write_midi(kept_notes,"labels.mid",tpbo)

Kept notes: 0.5470363444182453
