In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras
import tensorflow as tf
import pandas as pd
import librosa
from tqdm import tqdm_notebook as tqdm
import pickle
import IPython.display as ipd
import librosa.display
from catboost import CatBoostClassifier
from catboost import Pool
%matplotlib inline

Using TensorFlow backend.


In [2]:
def create_xy(df, target_column):
    return df.drop(target_column, axis=1), df[target_column]

def split_data(df, target_column):
    X_train, y_train = create_xy(df.sample(round(0.8*df.shape[0])), target_column)
    df = df.drop(X_train.index)
    X_val, y_val = create_xy(df.sample(round(0.5*df.shape[0])), target_column)
    df = df.drop(X_val.index)
    X_test, y_test = create_xy(df, target_column)
    return X_train, y_train, X_val, y_val, X_test, y_test

In [3]:
data = pd.read_pickle("../Data/Piano/data-piano-xl.pkl")

In [4]:
X_train, y_train, X_val, y_val, X_test, y_test = split_data(data.drop("Piano", axis=1), "Note")

## Catboost training

In [6]:
train_dataset = Pool(data=X_train,
                     label=y_train)

In [7]:
eval_dataset = Pool(data=X_val,
                    label=y_val)

In [8]:
model = CatBoostClassifier(iterations=1000,
                           learning_rate=0.05,
                           depth=4,
                           loss_function='MultiClass', 
                           task_type="GPU")

In [9]:
model.fit(train_dataset)

0:	learn: -3.5460775	total: 551ms	remaining: 9m 10s
1:	learn: -3.5343896	total: 1.03s	remaining: 8m 35s
2:	learn: -3.5260779	total: 1.51s	remaining: 8m 23s
3:	learn: -3.5148230	total: 2.09s	remaining: 8m 41s
4:	learn: -3.5064013	total: 2.64s	remaining: 8m 44s
5:	learn: -3.4990252	total: 3.13s	remaining: 8m 38s
6:	learn: -3.4920246	total: 3.58s	remaining: 8m 28s
7:	learn: -3.4864250	total: 4.1s	remaining: 8m 28s
8:	learn: -3.4771975	total: 4.57s	remaining: 8m 22s
9:	learn: -3.4648905	total: 5.12s	remaining: 8m 27s
10:	learn: -3.4514533	total: 5.64s	remaining: 8m 27s
11:	learn: -3.4432281	total: 6.19s	remaining: 8m 30s
12:	learn: -3.4356214	total: 6.74s	remaining: 8m 31s
13:	learn: -3.4311276	total: 7.22s	remaining: 8m 28s
14:	learn: -3.4247456	total: 7.74s	remaining: 8m 28s
15:	learn: -3.4191106	total: 8.28s	remaining: 8m 29s
16:	learn: -3.4142436	total: 8.83s	remaining: 8m 30s
17:	learn: -3.4086680	total: 9.39s	remaining: 8m 32s
18:	learn: -3.4012593	total: 9.91s	remaining: 8m 31s
19:	

154:	learn: -1.7804628	total: 1m 32s	remaining: 8m 22s
155:	learn: -1.7764773	total: 1m 32s	remaining: 8m 21s
156:	learn: -1.7603477	total: 1m 33s	remaining: 8m 21s
157:	learn: -1.7505469	total: 1m 34s	remaining: 8m 21s
158:	learn: -1.7327060	total: 1m 34s	remaining: 8m 21s
159:	learn: -1.7219463	total: 1m 35s	remaining: 8m 20s
160:	learn: -1.7066165	total: 1m 36s	remaining: 8m 20s
161:	learn: -1.6875478	total: 1m 36s	remaining: 8m 20s
162:	learn: -1.6691663	total: 1m 37s	remaining: 8m 20s
163:	learn: -1.6521919	total: 1m 38s	remaining: 8m 19s
164:	learn: -1.6335714	total: 1m 38s	remaining: 8m 19s
165:	learn: -1.6174911	total: 1m 39s	remaining: 8m 19s
166:	learn: -1.6123102	total: 1m 40s	remaining: 8m 18s
167:	learn: -1.5954667	total: 1m 40s	remaining: 8m 18s
168:	learn: -1.5873587	total: 1m 41s	remaining: 8m 18s
169:	learn: -1.5711868	total: 1m 42s	remaining: 8m 18s
170:	learn: -1.5685485	total: 1m 42s	remaining: 8m 17s
171:	learn: -1.5545367	total: 1m 43s	remaining: 8m 17s
172:	learn

305:	learn: -0.6292977	total: 3m 10s	remaining: 7m 11s
306:	learn: -0.6281025	total: 3m 11s	remaining: 7m 11s
307:	learn: -0.6232932	total: 3m 11s	remaining: 7m 10s
308:	learn: -0.6190615	total: 3m 12s	remaining: 7m 10s
309:	learn: -0.6170954	total: 3m 13s	remaining: 7m 9s
310:	learn: -0.6109608	total: 3m 13s	remaining: 7m 9s
311:	learn: -0.6063612	total: 3m 14s	remaining: 7m 8s
312:	learn: -0.6033177	total: 3m 14s	remaining: 7m 7s
313:	learn: -0.6017907	total: 3m 15s	remaining: 7m 7s
314:	learn: -0.6010138	total: 3m 16s	remaining: 7m 6s
315:	learn: -0.5981977	total: 3m 16s	remaining: 7m 5s
316:	learn: -0.5930614	total: 3m 17s	remaining: 7m 5s
317:	learn: -0.5884566	total: 3m 18s	remaining: 7m 4s
318:	learn: -0.5845358	total: 3m 18s	remaining: 7m 4s
319:	learn: -0.5836945	total: 3m 19s	remaining: 7m 3s
320:	learn: -0.5808256	total: 3m 20s	remaining: 7m 3s
321:	learn: -0.5773446	total: 3m 20s	remaining: 7m 2s
322:	learn: -0.5720339	total: 3m 21s	remaining: 7m 2s
323:	learn: -0.5709564	t

456:	learn: -0.3657395	total: 4m 45s	remaining: 5m 38s
457:	learn: -0.3638416	total: 4m 45s	remaining: 5m 38s
458:	learn: -0.3620817	total: 4m 46s	remaining: 5m 37s
459:	learn: -0.3617335	total: 4m 47s	remaining: 5m 37s
460:	learn: -0.3595645	total: 4m 47s	remaining: 5m 36s
461:	learn: -0.3590113	total: 4m 48s	remaining: 5m 35s
462:	learn: -0.3579370	total: 4m 48s	remaining: 5m 35s
463:	learn: -0.3575242	total: 4m 49s	remaining: 5m 34s
464:	learn: -0.3553050	total: 4m 50s	remaining: 5m 34s
465:	learn: -0.3552095	total: 4m 50s	remaining: 5m 33s
466:	learn: -0.3537277	total: 4m 51s	remaining: 5m 32s
467:	learn: -0.3535971	total: 4m 52s	remaining: 5m 32s
468:	learn: -0.3520101	total: 4m 52s	remaining: 5m 31s
469:	learn: -0.3505654	total: 4m 53s	remaining: 5m 30s
470:	learn: -0.3497477	total: 4m 53s	remaining: 5m 30s
471:	learn: -0.3488822	total: 4m 54s	remaining: 5m 29s
472:	learn: -0.3474848	total: 4m 55s	remaining: 5m 28s
473:	learn: -0.3472149	total: 4m 55s	remaining: 5m 28s
474:	learn

607:	learn: -0.2541069	total: 6m 17s	remaining: 4m 3s
608:	learn: -0.2527547	total: 6m 18s	remaining: 4m 3s
609:	learn: -0.2524732	total: 6m 19s	remaining: 4m 2s
610:	learn: -0.2508993	total: 6m 19s	remaining: 4m 1s
611:	learn: -0.2507620	total: 6m 20s	remaining: 4m 1s
612:	learn: -0.2501389	total: 6m 21s	remaining: 4m
613:	learn: -0.2488759	total: 6m 21s	remaining: 3m 59s
614:	learn: -0.2487190	total: 6m 22s	remaining: 3m 59s
615:	learn: -0.2477705	total: 6m 22s	remaining: 3m 58s
616:	learn: -0.2463252	total: 6m 23s	remaining: 3m 58s
617:	learn: -0.2461922	total: 6m 24s	remaining: 3m 57s
618:	learn: -0.2445940	total: 6m 24s	remaining: 3m 56s
619:	learn: -0.2445087	total: 6m 25s	remaining: 3m 56s
620:	learn: -0.2443048	total: 6m 26s	remaining: 3m 55s
621:	learn: -0.2440567	total: 6m 26s	remaining: 3m 55s
622:	learn: -0.2433236	total: 6m 27s	remaining: 3m 54s
623:	learn: -0.2430443	total: 6m 27s	remaining: 3m 53s
624:	learn: -0.2414248	total: 6m 28s	remaining: 3m 53s
625:	learn: -0.2403

757:	learn: -0.1740032	total: 7m 50s	remaining: 2m 30s
758:	learn: -0.1732044	total: 7m 51s	remaining: 2m 29s
759:	learn: -0.1731610	total: 7m 51s	remaining: 2m 29s
760:	learn: -0.1724896	total: 7m 52s	remaining: 2m 28s
761:	learn: -0.1724770	total: 7m 53s	remaining: 2m 27s
762:	learn: -0.1723272	total: 7m 53s	remaining: 2m 27s
763:	learn: -0.1722975	total: 7m 54s	remaining: 2m 26s
764:	learn: -0.1721967	total: 7m 54s	remaining: 2m 25s
765:	learn: -0.1721565	total: 7m 55s	remaining: 2m 25s
766:	learn: -0.1720221	total: 7m 56s	remaining: 2m 24s
767:	learn: -0.1714353	total: 7m 56s	remaining: 2m 23s
768:	learn: -0.1709148	total: 7m 57s	remaining: 2m 23s
769:	learn: -0.1707609	total: 7m 57s	remaining: 2m 22s
770:	learn: -0.1698286	total: 7m 58s	remaining: 2m 22s
771:	learn: -0.1692885	total: 7m 59s	remaining: 2m 21s
772:	learn: -0.1689912	total: 7m 59s	remaining: 2m 20s
773:	learn: -0.1685947	total: 8m	remaining: 2m 20s
774:	learn: -0.1680380	total: 8m 1s	remaining: 2m 19s
775:	learn: -0.

908:	learn: -0.1288188	total: 9m 23s	remaining: 56.4s
909:	learn: -0.1283816	total: 9m 23s	remaining: 55.8s
910:	learn: -0.1283389	total: 9m 24s	remaining: 55.1s
911:	learn: -0.1278716	total: 9m 25s	remaining: 54.5s
912:	learn: -0.1274263	total: 9m 25s	remaining: 53.9s
913:	learn: -0.1273584	total: 9m 26s	remaining: 53.3s
914:	learn: -0.1268039	total: 9m 27s	remaining: 52.7s
915:	learn: -0.1262990	total: 9m 27s	remaining: 52.1s
916:	learn: -0.1261914	total: 9m 28s	remaining: 51.4s
917:	learn: -0.1261762	total: 9m 28s	remaining: 50.8s
918:	learn: -0.1258576	total: 9m 29s	remaining: 50.2s
919:	learn: -0.1255897	total: 9m 30s	remaining: 49.6s
920:	learn: -0.1251079	total: 9m 30s	remaining: 49s
921:	learn: -0.1248392	total: 9m 31s	remaining: 48.3s
922:	learn: -0.1247724	total: 9m 31s	remaining: 47.7s
923:	learn: -0.1247009	total: 9m 32s	remaining: 47.1s
924:	learn: -0.1246865	total: 9m 33s	remaining: 46.5s
925:	learn: -0.1242960	total: 9m 33s	remaining: 45.8s
926:	learn: -0.1242695	total: 

<catboost.core.CatBoostClassifier at 0xdb0efdf940>

In [10]:
preds_class = model.predict(eval_dataset)
preds_class = preds_class.reshape((preds_class.shape[0],))
(preds_class == y_val).mean()

0.8741721854304636

## NN training

In [11]:
model = keras.Sequential([
    keras.layers.Dense(256, input_shape=(16000,)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(35, activation=tf.nn.softmax)
])

Instructions for updating:
Colocations handled automatically by placer.


In [12]:
model.compile(optimizer='adam', 
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [20]:
model.fit(X_train, librosa.note_to_midi(y_train)-49, epochs=15, verbose=2)

Epoch 1/15
 - 6s - loss: 2.6982 - acc: 0.4711
Epoch 2/15
 - 5s - loss: 1.7154 - acc: 0.7467
Epoch 3/15
 - 5s - loss: 1.3113 - acc: 0.8272
Epoch 4/15
 - 5s - loss: 1.0812 - acc: 0.8776
Epoch 5/15
 - 5s - loss: 0.9250 - acc: 0.9064
Epoch 6/15
 - 5s - loss: 0.8126 - acc: 0.9236
Epoch 7/15
 - 5s - loss: 0.7292 - acc: 0.9381
Epoch 8/15
 - 5s - loss: 0.6229 - acc: 0.9507
Epoch 9/15
 - 5s - loss: 0.5848 - acc: 0.9553
Epoch 10/15
 - 5s - loss: 0.5774 - acc: 0.9546
Epoch 11/15
 - 5s - loss: 0.5045 - acc: 0.9550
Epoch 12/15
 - 5s - loss: 0.5178 - acc: 0.9574
Epoch 13/15
 - 5s - loss: 0.4769 - acc: 0.9631
Epoch 14/15
 - 5s - loss: 0.5233 - acc: 0.9599
Epoch 15/15
 - 5s - loss: 0.5420 - acc: 0.9564


<keras.callbacks.History at 0xdb285e6630>

In [23]:
val_loss, val_acc = model.evaluate(X_val, librosa.note_to_midi(y_val)-49)



In [24]:
print(val_acc)

0.8410596014648084


In [25]:
NN_predictions = np.argmax(model.predict(X_val), axis=1)

In [None]:
NN_predictions = librosa.midi_to_note(NN_predictions+49)

## CQT

In [34]:
def cqt_pred(data, sr):
    preds = []
    for row in tqdm(data[["x{}".format(i) for i in range(16000)]].values):
        cqt = librosa.core.cqt(row, sr, n_bins=80)
        preds.append(librosa.amplitude_to_db(cqt, ref=np.max).mean(axis=1).argmax()+24)
    return preds

In [27]:
cqt_predictions = cqt_pred(X_val, 8000)

HBox(children=(IntProgress(value=0, max=302), HTML(value='')))






In [None]:
cqt_predictions = librosa.midi_to_note(cqt_predictions)

In [None]:
(cqt_predictions == y_val).mean()

## Ensemble

In [57]:
predictions = pd.DataFrame({"CQT": cqt_predictions, "NN": NN_predictions, "Catboost": preds_class})

In [102]:
predictions.sample()

Unnamed: 0,CQT,NN,Catboost
128,A#5,A#5,A#5


In [148]:
final = []
for index, row in predictions.drop("y_val", axis=1).iterrows():
    votes = {x: 0 for x in row.unique()}
    chromas = np.array([x[:-1] for x in row])
    for option in row.unique():
        if option in row["CQT"]:
            votes[option] += 0.7
        if option in row["NN"]:
            votes[option] += 1
        if option in row["Catboost"]:
            votes[option] += 1
        if (option == chromas).sum() == 1:
            votes[option] -= 1
    final.append(max(votes, key=lambda key: votes[key]))

In [149]:
("A" == np.array(["A", "B", "A"])).sum()

2

In [150]:
final = np.array(final)

In [151]:
(final == y_val).mean()

0.8543046357615894

In [137]:
index_misclass = final != y_val
index_octshift = abs(librosa.note_to_midi(final) - librosa.note_to_midi(y_val)) == 12
print(round(len(final[index_octshift])/len(final[index_misclass])*100), 
          "% of the misclassified examples are actually octave misclassifications")

80 % of the misclassified examples are actually octave misclassifications


In [120]:
predictions["y_val"] = y_val.reset_index(drop=True)

In [121]:
predictions

Unnamed: 0,CQT,NN,Catboost,y_val
0,C4,C4,C4,C4
1,G#3,G#4,G#4,G#4
2,A#4,A#4,A#4,A#4
3,A#4,A#4,A#4,A#4
4,A#1,G#5,G#5,G#5
5,D#5,D#5,D#5,D#5
6,C5,C5,C5,C5
7,D2,A5,D3,D3
8,A#1,G#4,G#5,G#4
9,D#4,D#4,D#4,D#4
