In [1]:
from catboost import CatBoostClassifier
from catboost import Pool
import librosa
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm
%matplotlib notebook

In [2]:
data = pd.read_pickle("../Data/Piano/data-piano-xl.pkl")

In [3]:
def create_xy(df, target_column):
    return df.drop(target_column, axis=1), df[target_column]

def split_data(df, target_column):
    X_train, y_train = create_xy(df.sample(round(0.8*df.shape[0])), target_column)
    df = df.drop(X_train.index)
    X_val, y_val = create_xy(df.sample(round(0.5*df.shape[0])), target_column)
    df = df.drop(X_val.index)
    X_test, y_test = create_xy(df, target_column)
    return X_train, y_train, X_val, y_val, X_test, y_test

In [13]:
X_train, y_train, X_val, y_val, X_test, y_test = split_data(data.drop("Piano", axis=1), "Note")

In [14]:
train_dataset = Pool(data=X_train,
                     label=y_train)

In [15]:
eval_dataset = Pool(data=X_val,
                    label=y_val)

In [22]:
model = CatBoostClassifier(iterations=1000,
                           learning_rate=0.05,
                           depth=4,
                           loss_function='MultiClass', 
                           task_type="GPU")

In [23]:
model.fit(train_dataset)

0:	learn: -3.5456634	total: 550ms	remaining: 9m 9s
1:	learn: -3.5364383	total: 1.04s	remaining: 8m 37s
2:	learn: -3.5292859	total: 1.55s	remaining: 8m 35s
3:	learn: -3.5215194	total: 2.06s	remaining: 8m 34s
4:	learn: -3.5137966	total: 2.58s	remaining: 8m 32s
5:	learn: -3.5068514	total: 3.06s	remaining: 8m 27s
6:	learn: -3.4997749	total: 3.54s	remaining: 8m 22s
7:	learn: -3.4933939	total: 4s	remaining: 8m 15s
8:	learn: -3.4845176	total: 4.58s	remaining: 8m 23s
9:	learn: -3.4788070	total: 5.08s	remaining: 8m 23s
10:	learn: -3.4727349	total: 5.57s	remaining: 8m 20s
11:	learn: -3.4681153	total: 6.08s	remaining: 8m 21s
12:	learn: -3.4582023	total: 6.63s	remaining: 8m 23s
13:	learn: -3.4506276	total: 7.19s	remaining: 8m 26s
14:	learn: -3.4449345	total: 7.71s	remaining: 8m 26s
15:	learn: -3.4388580	total: 8.26s	remaining: 8m 27s
16:	learn: -3.4346930	total: 8.74s	remaining: 8m 25s
17:	learn: -3.4304380	total: 9.2s	remaining: 8m 21s
18:	learn: -3.4250342	total: 9.71s	remaining: 8m 21s
19:	lear

154:	learn: -1.8397333	total: 1m 33s	remaining: 8m 30s
155:	learn: -1.8182460	total: 1m 34s	remaining: 8m 30s
156:	learn: -1.7975579	total: 1m 34s	remaining: 8m 30s
157:	learn: -1.7781468	total: 1m 35s	remaining: 8m 29s
158:	learn: -1.7663158	total: 1m 36s	remaining: 8m 29s
159:	learn: -1.7579821	total: 1m 36s	remaining: 8m 28s
160:	learn: -1.7492987	total: 1m 37s	remaining: 8m 28s
161:	learn: -1.7371277	total: 1m 38s	remaining: 8m 28s
162:	learn: -1.7265446	total: 1m 38s	remaining: 8m 28s
163:	learn: -1.7153571	total: 1m 39s	remaining: 8m 27s
164:	learn: -1.6979240	total: 1m 40s	remaining: 8m 27s
165:	learn: -1.6802737	total: 1m 40s	remaining: 8m 27s
166:	learn: -1.6673212	total: 1m 41s	remaining: 8m 26s
167:	learn: -1.6533085	total: 1m 42s	remaining: 8m 26s
168:	learn: -1.6367397	total: 1m 43s	remaining: 8m 26s
169:	learn: -1.6308774	total: 1m 43s	remaining: 8m 26s
170:	learn: -1.6178242	total: 1m 44s	remaining: 8m 25s
171:	learn: -1.6083545	total: 1m 45s	remaining: 8m 25s
172:	learn

305:	learn: -0.6272521	total: 3m 15s	remaining: 7m 22s
306:	learn: -0.6224585	total: 3m 15s	remaining: 7m 21s
307:	learn: -0.6175700	total: 3m 16s	remaining: 7m 21s
308:	learn: -0.6128550	total: 3m 17s	remaining: 7m 20s
309:	learn: -0.6089168	total: 3m 17s	remaining: 7m 20s
310:	learn: -0.6034818	total: 3m 18s	remaining: 7m 19s
311:	learn: -0.6009573	total: 3m 19s	remaining: 7m 19s
312:	learn: -0.5998525	total: 3m 19s	remaining: 7m 18s
313:	learn: -0.5995944	total: 3m 20s	remaining: 7m 17s
314:	learn: -0.5946884	total: 3m 21s	remaining: 7m 17s
315:	learn: -0.5940116	total: 3m 21s	remaining: 7m 16s
316:	learn: -0.5938940	total: 3m 22s	remaining: 7m 15s
317:	learn: -0.5893530	total: 3m 22s	remaining: 7m 15s
318:	learn: -0.5891022	total: 3m 23s	remaining: 7m 14s
319:	learn: -0.5850193	total: 3m 24s	remaining: 7m 13s
320:	learn: -0.5837106	total: 3m 24s	remaining: 7m 13s
321:	learn: -0.5835365	total: 3m 25s	remaining: 7m 12s
322:	learn: -0.5832787	total: 3m 26s	remaining: 7m 12s
323:	learn

456:	learn: -0.3659520	total: 4m 51s	remaining: 5m 46s
457:	learn: -0.3639346	total: 4m 52s	remaining: 5m 45s
458:	learn: -0.3634500	total: 4m 52s	remaining: 5m 44s
459:	learn: -0.3612483	total: 4m 53s	remaining: 5m 44s
460:	learn: -0.3588925	total: 4m 54s	remaining: 5m 43s
461:	learn: -0.3588236	total: 4m 54s	remaining: 5m 43s
462:	learn: -0.3584945	total: 4m 55s	remaining: 5m 42s
463:	learn: -0.3584106	total: 4m 55s	remaining: 5m 41s
464:	learn: -0.3578879	total: 4m 56s	remaining: 5m 40s
465:	learn: -0.3577779	total: 4m 56s	remaining: 5m 40s
466:	learn: -0.3574080	total: 4m 57s	remaining: 5m 39s
467:	learn: -0.3550822	total: 4m 58s	remaining: 5m 38s
468:	learn: -0.3534880	total: 4m 58s	remaining: 5m 38s
469:	learn: -0.3531254	total: 4m 59s	remaining: 5m 37s
470:	learn: -0.3513135	total: 5m	remaining: 5m 37s
471:	learn: -0.3511682	total: 5m	remaining: 5m 36s
472:	learn: -0.3491792	total: 5m 1s	remaining: 5m 35s
473:	learn: -0.3469431	total: 5m 2s	remaining: 5m 35s
474:	learn: -0.34664

607:	learn: -0.2515014	total: 6m 25s	remaining: 4m 8s
608:	learn: -0.2499264	total: 6m 25s	remaining: 4m 7s
609:	learn: -0.2498742	total: 6m 26s	remaining: 4m 7s
610:	learn: -0.2497247	total: 6m 27s	remaining: 4m 6s
611:	learn: -0.2494377	total: 6m 27s	remaining: 4m 5s
612:	learn: -0.2485074	total: 6m 28s	remaining: 4m 5s
613:	learn: -0.2483535	total: 6m 29s	remaining: 4m 4s
614:	learn: -0.2482759	total: 6m 29s	remaining: 4m 3s
615:	learn: -0.2478273	total: 6m 30s	remaining: 4m 3s
616:	learn: -0.2473604	total: 6m 30s	remaining: 4m 2s
617:	learn: -0.2463655	total: 6m 31s	remaining: 4m 2s
618:	learn: -0.2462702	total: 6m 32s	remaining: 4m 1s
619:	learn: -0.2450037	total: 6m 32s	remaining: 4m
620:	learn: -0.2447392	total: 6m 33s	remaining: 4m
621:	learn: -0.2445870	total: 6m 34s	remaining: 3m 59s
622:	learn: -0.2443488	total: 6m 34s	remaining: 3m 58s
623:	learn: -0.2431913	total: 6m 35s	remaining: 3m 58s
624:	learn: -0.2417730	total: 6m 36s	remaining: 3m 57s
625:	learn: -0.2409962	total: 

758:	learn: -0.1787209	total: 7m 58s	remaining: 2m 32s
759:	learn: -0.1779175	total: 7m 59s	remaining: 2m 31s
760:	learn: -0.1778289	total: 8m	remaining: 2m 30s
761:	learn: -0.1777835	total: 8m	remaining: 2m 30s
762:	learn: -0.1776400	total: 8m 1s	remaining: 2m 29s
763:	learn: -0.1776261	total: 8m 1s	remaining: 2m 28s
764:	learn: -0.1774962	total: 8m 2s	remaining: 2m 28s
765:	learn: -0.1767036	total: 8m 3s	remaining: 2m 27s
766:	learn: -0.1759856	total: 8m 3s	remaining: 2m 27s
767:	learn: -0.1758803	total: 8m 4s	remaining: 2m 26s
768:	learn: -0.1758611	total: 8m 5s	remaining: 2m 25s
769:	learn: -0.1757948	total: 8m 5s	remaining: 2m 25s
770:	learn: -0.1750797	total: 8m 6s	remaining: 2m 24s
771:	learn: -0.1749439	total: 8m 6s	remaining: 2m 23s
772:	learn: -0.1749130	total: 8m 7s	remaining: 2m 23s
773:	learn: -0.1743987	total: 8m 7s	remaining: 2m 22s
774:	learn: -0.1736110	total: 8m 8s	remaining: 2m 21s
775:	learn: -0.1727393	total: 8m 9s	remaining: 2m 21s
776:	learn: -0.1720787	total: 8m

909:	learn: -0.1288353	total: 9m 32s	remaining: 56.7s
910:	learn: -0.1285164	total: 9m 33s	remaining: 56s
911:	learn: -0.1285007	total: 9m 34s	remaining: 55.4s
912:	learn: -0.1280261	total: 9m 34s	remaining: 54.8s
913:	learn: -0.1276331	total: 9m 35s	remaining: 54.1s
914:	learn: -0.1276109	total: 9m 36s	remaining: 53.5s
915:	learn: -0.1272980	total: 9m 36s	remaining: 52.9s
916:	learn: -0.1272557	total: 9m 37s	remaining: 52.2s
917:	learn: -0.1267244	total: 9m 37s	remaining: 51.6s
918:	learn: -0.1266318	total: 9m 38s	remaining: 51s
919:	learn: -0.1262717	total: 9m 39s	remaining: 50.4s
920:	learn: -0.1258664	total: 9m 39s	remaining: 49.7s
921:	learn: -0.1257844	total: 9m 40s	remaining: 49.1s
922:	learn: -0.1256353	total: 9m 41s	remaining: 48.5s
923:	learn: -0.1251455	total: 9m 41s	remaining: 47.8s
924:	learn: -0.1249721	total: 9m 42s	remaining: 47.2s
925:	learn: -0.1248718	total: 9m 42s	remaining: 46.6s
926:	learn: -0.1247528	total: 9m 43s	remaining: 46s
927:	learn: -0.1246866	total: 9m 4

<catboost.core.CatBoostClassifier at 0x5c90071240>

In [24]:
preds_class = model.predict(eval_dataset)

In [25]:
preds_class = preds_class.reshape((preds_class.shape[0],))

In [26]:
for i,j in zip(preds_class, y_val):
    print(i,j)

F#5 F#5
D4 D4
G#4 G#4
D4 D5
B3 B3
G#3 G#3
C#4 C#4
G#3 G#3
F5 F5
D4 D4
C#3 C#3
D#3 D#3
E3 E3
E4 E4
B4 B4
A4 A4
B4 B4
A4 A4
G#4 G#4
F#3 F#4
A5 A5
A3 A3
B4 B4
F#3 F#4
F#4 F#4
B4 B4
C5 C5
F5 F5
B5 B5
G#3 G#3
C4 C4
G5 G5
D3 D3
G#5 G#5
F5 F5
D#5 D#5
A5 A5
G#5 G#5
G4 G4
C#3 C#3
C4 C4
G#4 G#3
E4 E4
F#3 F#3
F5 F5
A#4 A#4
D#4 D#4
E5 E5
B3 B4
D5 D5
A4 A5
E5 E5
C4 C4
A4 A4
D#3 D#3
A3 A3
C5 C5
A4 A4
D#5 D#5
F#5 F#5
F#3 F#3
E3 E3
F5 F5
A#4 A#3
G#4 G#4
B4 B4
C#3 C#3
F#4 F#4
F#5 F#3
A#3 A#3
C4 C4
C4 C4
C#5 C#5
A#5 A#5
A3 A3
G#4 G#3
D3 D3
A5 A5
A#3 A#3
G#3 G#4
G#5 G#5
G4 G3
F4 F4
D#3 D#3
F#5 F#5
D3 D3
A5 A5
F4 F4
A4 A4
F#4 F#4
C5 C5
C4 C4
G4 G4
A#4 A#3
D5 D5
B5 B5
C#4 C#4
A4 A4
C4 C4
A4 A4
F5 F4
A4 A4
G3 G3
F5 F5
F3 F3
G#4 G#4
E5 E5
A4 A4
B3 B3
D#5 D#5
D#4 D#3
F3 F3
G3 G3
C5 C5
F4 G#4
G#5 G#5
B5 B5
A#4 A#4
D#4 D#4
D#4 G#3
D#5 D#5
D4 D4
F5 F5
F#4 F#3
E5 E5
G3 G3
C4 C4
A5 A5
G#3 G#3
E5 E5
G5 G5
A#3 A#3
D#3 D#3
C4 C4
G4 G5
F#3 F#3
F#3 F#3
A3 A3
D#3 D#3
F5 F5
E3 E3
A5 A5
E4 E4
F#4 F#4
F#5 F#5
A4 A4
C#5 C#5

In [27]:
(preds_class == y_val).mean()

0.8878205128205128

In [30]:
index_misclass = preds_class != y_val
index_octshift = abs(librosa.note_to_midi(preds_class) - librosa.note_to_midi(y_val)) == 12
print(round(len(preds_class[index_octshift])/len(preds_class[index_misclass])*100), 
          "% of the misclassified examples are actually octave misclassifications")

83 % of the misclassified examples are actually octave misclassifications


In [33]:
preds_class

array(['F#5', 'D4', 'G#4', 'D4', 'B3', 'G#3', 'C#4', 'G#3', 'F5', 'D4',
       'C#3', 'D#3', 'E3', 'E4', 'B4', 'A4', 'B4', 'A4', 'G#4', 'F#3',
       'A5', 'A3', 'B4', 'F#3', 'F#4', 'B4', 'C5', 'F5', 'B5', 'G#3',
       'C4', 'G5', 'D3', 'G#5', 'F5', 'D#5', 'A5', 'G#5', 'G4', 'C#3',
       'C4', 'G#4', 'E4', 'F#3', 'F5', 'A#4', 'D#4', 'E5', 'B3', 'D5',
       'A4', 'E5', 'C4', 'A4', 'D#3', 'A3', 'C5', 'A4', 'D#5', 'F#5',
       'F#3', 'E3', 'F5', 'A#4', 'G#4', 'B4', 'C#3', 'F#4', 'F#5', 'A#3',
       'C4', 'C4', 'C#5', 'A#5', 'A3', 'G#4', 'D3', 'A5', 'A#3', 'G#3',
       'G#5', 'G4', 'F4', 'D#3', 'F#5', 'D3', 'A5', 'F4', 'A4', 'F#4',
       'C5', 'C4', 'G4', 'A#4', 'D5', 'B5', 'C#4', 'A4', 'C4', 'A4', 'F5',
       'A4', 'G3', 'F5', 'F3', 'G#4', 'E5', 'A4', 'B3', 'D#5', 'D#4',
       'F3', 'G3', 'C5', 'F4', 'G#5', 'B5', 'A#4', 'D#4', 'D#4', 'D#5',
       'D4', 'F5', 'F#4', 'E5', 'G3', 'C4', 'A5', 'G#3', 'E5', 'G5',
       'A#3', 'D#3', 'C4', 'G4', 'F#3', 'F#3', 'A3', 'D#3', 'F5', 'E3',
