In [2]:
import pretty_midi
import numpy as np
import joblib
import glob
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style('white')
sns.set_context('notebook', font_scale=1.5)
import matplotlib.gridspec
import collections
import os
import math

from sklearn import datasets 
from sklearn.metrics import confusion_matrix 
from sklearn.model_selection import train_test_split 


In [3]:
def plot_data(x, y): 
    fig = plt.figure(num=None, figsize=(10, 6), dpi=80, facecolor='w', edgecolor='k')
    plt.scatter(x, y, marker='_')
    plt.xlabel('Time (ms)')
    plt.ylabel('MIDI Pitch')
#     plt.ylim(top=128)
    plt.show()

In [4]:
def compute_statistics(midi_file):
    """
    Given a path to a MIDI file, compute a dictionary of statistics about it
    
    Parameters
    ----------
    midi_file : str
        Path to a MIDI file.
    
    Returns
    -------
    statistics : dict
        Dictionary reporting the values for different events in the file.
    """
    # Some MIDI files will raise Exceptions on loading, if they are invalid.
    # We just skip those.
    try:
        pm = pretty_midi.PrettyMIDI(midi_file)
        # Extract informative events from the MIDI file
        return {'n_instruments': len(pm.instruments),
                'program_numbers': [i.program for i in pm.instruments if not i.is_drum],
                'key_numbers': [k.key_number for k in pm.key_signature_changes],
                'tempos': list(pm.get_tempo_changes()[1]),
                'time_signature_changes': pm.time_signature_changes,
                'end_time': pm.get_end_time(),
                'lyrics': [l.text for l in pm.lyrics],
                'path': midi_file}
    # Silently ignore exceptions for a clean presentation (sorry Python!)
    except Exception as e:
        pass

In [5]:
# Compute statistics about every file in our collection in parallel using joblib
# We do things in parallel because there are tons so it would otherwise take too long!
statistics = joblib.Parallel(n_jobs=10, verbose=0)(
    joblib.delayed(compute_statistics)(midi_file)
    for midi_file in glob.glob(os.path.join('..', 'lmd_aligned', '*', '*', '*', '*', '*.mid')))
# When an error occurred, None will be returned; filter those out.
statistics = [s for s in statistics if s is not None]

### Find midi files that have at least instrument with program numbers 0, 25, 33, 48 and store them in pm_array
 0 (“Acoustic Grand Piano”), 48 (“String Ensemble 1”), 33 (“Electric Bass (finger)”), and 25 (“Acoustic Guitar (steel)”)
(Reference: https://nbviewer.jupyter.org/github/craffel/midi-ground-truth/blob/master/Statistics.ipynb)

In [1]:
mixture_midis = []
for s in statistics: 
    if (0 in s['program_numbers']) and (25 in s['program_numbers']) and (33 in s['program_numbers']) and (48 in s['program_numbers']): 
         mixture_midis.append(s['path'])



NameError: name 'statistics' is not defined

In [6]:
num_midis = len(mixture_midis)
pm_array = []

for i in range(num_midis): 
    pm_array = np.append(pm_array, pretty_midi.PrettyMIDI(mixture_midis[i]))


In [7]:
T = 0.010   # Timestep (s)
t_end = 60 # Duration (s)
num_notes = 128
num_instruments = 4
num_timeslices = int(t_end/T)
data = np.zeros((num_midis, num_timeslices, num_notes * num_instruments)) # rows: timesteps, cols: notes, instruments (check wiki)

our_instruments = [0, 25, 48, 33]  # Program numbers -- predict last instrument 


for midi_index in range(num_midis): 
    for instrument in pm_array[midi_index].instruments: 
        if (instrument.program in our_instruments): 
            index = our_instruments.index(instrument.program)
            for note in instrument.notes:
                data[midi_index, math.floor(note.start/T):math.floor(note.end/T), num_notes * index + note.pitch] = 1   #data[math.floor(note.start/T):math.floor(note.end/T)+1, note.pitch] = 1

# import sys
# np.set_printoptions(threshold=sys.maxsize)


# # Print piano roll scatter plots
# output_data = np.argwhere(data[0,:,0:num_notes]>0)
# plot_data(output_data[:,0], output_data[:,1])

# output_data = np.argwhere(data[1,:,:]>0)
# plot_data(output_data[:,0], output_data[:,1])

# output_data = np.argwhere(data[2,:,:]>0)
# plot_data(output_data[:,0], output_data[:,1])


### Reshape data 

In [None]:
all_data = data.reshape(num_midis*num_timeslices, num_notes * num_instruments)

# output_data = np.argwhere(all_data>0)
# plot_data(output_data[:,0], output_data[:,1])


In [9]:
X = all_data[:,0:num_notes*3]
y = all_data[:,num_notes*3:num_notes*4] # predict last instrument

test_size = 0.30
n = X.shape[0]
mid_index = math.floor(n*(1-test_size))


X_train = X[0:mid_index,:]
y_train = y[0:mid_index,:]
X_test = X[mid_index:-1,:]
y_test = y[mid_index:-1,:]

# dividing X, y into train and test data 
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.30) 

y_output = np.argwhere(y_train>0)
y_train_mod = np.zeros(y_train.shape[0])
y_train_mod.fill(-1) # -1 means no note played

for element in y_output: 
    y_train_mod[element[0]] = element[1] # for a given timestamp, the last note listed in y_output will be kept
    

print(X.shape)
print(y.shape)
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

(4590000, 384)
(4590000, 128)
(3213000, 384)
(3213000, 128)
(1376999, 384)
(1376999, 128)


In [None]:
# training a linear SVM classifier 
from sklearn.svm import SVC
svclassifier = SVC(kernel='rbf')
svclassifier.fit(X_train, y_train_mod)

In [None]:
y_pred = svclassifier.predict(X_test)

# output_X_test = np.argwhere(X_test>0)
# plot_data(output_X_test[:,0], output_X_test[:,1])

# output_y_test = np.argwhere(y_test>0)
# plot_data(output_y_test[:,0], output_y_test[:,1])

# # output_y_all = np.argwhere(y>0)
# # plot_data(output_y_all[:,0], output_y_all[:,1])

# # output_y_pred = np.argwhere(y_pred>0)
# # plot_data(output_y_pred[:,0], output_y_pred[:,1])

# fig = plt.figure(num=None, figsize=(10, 6), dpi=80, facecolor='w', edgecolor='k')
# plt.plot(y_pred, '_', color='red')



# # fig = plt.figure(num=None, figsize=(10, 6), dpi=80, facecolor='w', edgecolor='k')
# plt.plot(y_pred)
# # plt.xlabel('Time (ms)')
# # plt.ylabel('MIDI Pitch')
# # plt.ylim(top=128)
# plt.show()

In [248]:
y_pred

array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 40., 40., 40.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., 40., 40., 40., 40., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., 43., 43., 43., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 40., 40., 40.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., 43., 43., 43