In [1]:
import librosa
import pandas as pd
import numpy as np

In [2]:
import lightgbm as lgb
from sklearn.metrics import mean_absolute_error

# Tempo prediction with LightGBM

Due to problems implementing CatBoost and XGBoost in the final app, I have decided to switch to LightGBM

In [3]:
%run ../Utilities/Utilities.ipynb

Using TensorFlow backend.


In [4]:
data = pd.read_pickle("../Data/Guitar/bpm-data.pkl")

In [5]:
X_train, y_train, X_val, y_val, X_test, y_test = split_data(data, "bpm")

In [6]:
lgb_train = lgb.Dataset(X_train, y_train)

In [7]:
lgb_eval = lgb.Dataset(X_val, y_val, reference=lgb_train)

In [8]:
params = {
    'boosting_type': 'gbdt',
    'objective': 'regression',
    'metric': 'mae',
    'max_depth': 10,
    'learning_rate': 0.01,
    'verbose': 0
}

In [9]:
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=3000,
                valid_sets=lgb_eval,
                early_stopping_rounds=5)

[1]	valid_0's l1: 26.3828
Training until validation scores don't improve for 5 rounds.
[2]	valid_0's l1: 26.2356
[3]	valid_0's l1: 26.0893
[4]	valid_0's l1: 25.9432
[5]	valid_0's l1: 25.7988
[6]	valid_0's l1: 25.6531
[7]	valid_0's l1: 25.5122
[8]	valid_0's l1: 25.3705
[9]	valid_0's l1: 25.2328
[10]	valid_0's l1: 25.0952
[11]	valid_0's l1: 24.9597
[12]	valid_0's l1: 24.8242
[13]	valid_0's l1: 24.6902
[14]	valid_0's l1: 24.5598
[15]	valid_0's l1: 24.4277
[16]	valid_0's l1: 24.2982
[17]	valid_0's l1: 24.1715
[18]	valid_0's l1: 24.0451
[19]	valid_0's l1: 23.9223
[20]	valid_0's l1: 23.7969
[21]	valid_0's l1: 23.6732
[22]	valid_0's l1: 23.5536
[23]	valid_0's l1: 23.4325
[24]	valid_0's l1: 23.3137
[25]	valid_0's l1: 23.1953
[26]	valid_0's l1: 23.0771
[27]	valid_0's l1: 22.9589
[28]	valid_0's l1: 22.8469
[29]	valid_0's l1: 22.7334
[30]	valid_0's l1: 22.6219
[31]	valid_0's l1: 22.5094
[32]	valid_0's l1: 22.3985
[33]	valid_0's l1: 22.2882
[34]	valid_0's l1: 22.1777
[35]	valid_0's l1: 22.0695
[36

[296]	valid_0's l1: 11.4117
[297]	valid_0's l1: 11.4018
[298]	valid_0's l1: 11.3918
[299]	valid_0's l1: 11.3813
[300]	valid_0's l1: 11.3712
[301]	valid_0's l1: 11.3606
[302]	valid_0's l1: 11.3498
[303]	valid_0's l1: 11.3401
[304]	valid_0's l1: 11.3303
[305]	valid_0's l1: 11.3203
[306]	valid_0's l1: 11.3112
[307]	valid_0's l1: 11.3014
[308]	valid_0's l1: 11.2915
[309]	valid_0's l1: 11.2816
[310]	valid_0's l1: 11.2716
[311]	valid_0's l1: 11.2629
[312]	valid_0's l1: 11.2539
[313]	valid_0's l1: 11.244
[314]	valid_0's l1: 11.2343
[315]	valid_0's l1: 11.2244
[316]	valid_0's l1: 11.2155
[317]	valid_0's l1: 11.206
[318]	valid_0's l1: 11.1969
[319]	valid_0's l1: 11.1887
[320]	valid_0's l1: 11.1797
[321]	valid_0's l1: 11.171
[322]	valid_0's l1: 11.1621
[323]	valid_0's l1: 11.1529
[324]	valid_0's l1: 11.144
[325]	valid_0's l1: 11.1345
[326]	valid_0's l1: 11.1258
[327]	valid_0's l1: 11.1177
[328]	valid_0's l1: 11.11
[329]	valid_0's l1: 11.1013
[330]	valid_0's l1: 11.0932
[331]	valid_0's l1: 11.085

[589]	valid_0's l1: 10.1498
[590]	valid_0's l1: 10.1469
[591]	valid_0's l1: 10.1447
[592]	valid_0's l1: 10.1417
[593]	valid_0's l1: 10.14
[594]	valid_0's l1: 10.1375
[595]	valid_0's l1: 10.1357
[596]	valid_0's l1: 10.1341
[597]	valid_0's l1: 10.1314
[598]	valid_0's l1: 10.1281
[599]	valid_0's l1: 10.1264
[600]	valid_0's l1: 10.1241
[601]	valid_0's l1: 10.1216
[602]	valid_0's l1: 10.1189
[603]	valid_0's l1: 10.1172
[604]	valid_0's l1: 10.1142
[605]	valid_0's l1: 10.1119
[606]	valid_0's l1: 10.1091
[607]	valid_0's l1: 10.1056
[608]	valid_0's l1: 10.1032
[609]	valid_0's l1: 10.0999
[610]	valid_0's l1: 10.0977
[611]	valid_0's l1: 10.0962
[612]	valid_0's l1: 10.0935
[613]	valid_0's l1: 10.0909
[614]	valid_0's l1: 10.0878
[615]	valid_0's l1: 10.085
[616]	valid_0's l1: 10.0826
[617]	valid_0's l1: 10.0812
[618]	valid_0's l1: 10.0785
[619]	valid_0's l1: 10.0759
[620]	valid_0's l1: 10.0735
[621]	valid_0's l1: 10.0705
[622]	valid_0's l1: 10.0678
[623]	valid_0's l1: 10.0653
[624]	valid_0's l1: 10.

[882]	valid_0's l1: 9.62272
[883]	valid_0's l1: 9.621
[884]	valid_0's l1: 9.62008
[885]	valid_0's l1: 9.61912
[886]	valid_0's l1: 9.61865
[887]	valid_0's l1: 9.61766
[888]	valid_0's l1: 9.61606
[889]	valid_0's l1: 9.61504
[890]	valid_0's l1: 9.61457
[891]	valid_0's l1: 9.6131
[892]	valid_0's l1: 9.61251
[893]	valid_0's l1: 9.61208
[894]	valid_0's l1: 9.61125
[895]	valid_0's l1: 9.61081
[896]	valid_0's l1: 9.60928
[897]	valid_0's l1: 9.60795
[898]	valid_0's l1: 9.60636
[899]	valid_0's l1: 9.60547
[900]	valid_0's l1: 9.60398
[901]	valid_0's l1: 9.60269
[902]	valid_0's l1: 9.60133
[903]	valid_0's l1: 9.6001
[904]	valid_0's l1: 9.59854
[905]	valid_0's l1: 9.59735
[906]	valid_0's l1: 9.59671
[907]	valid_0's l1: 9.59578
[908]	valid_0's l1: 9.59454
[909]	valid_0's l1: 9.59366
[910]	valid_0's l1: 9.59283
[911]	valid_0's l1: 9.5913
[912]	valid_0's l1: 9.58979
[913]	valid_0's l1: 9.58939
[914]	valid_0's l1: 9.58809
[915]	valid_0's l1: 9.58724
[916]	valid_0's l1: 9.58684
[917]	valid_0's l1: 9.586

[1170]	valid_0's l1: 9.38665
[1171]	valid_0's l1: 9.38552
[1172]	valid_0's l1: 9.38493
[1173]	valid_0's l1: 9.3839
[1174]	valid_0's l1: 9.38334
[1175]	valid_0's l1: 9.38276
[1176]	valid_0's l1: 9.38174
[1177]	valid_0's l1: 9.38104
[1178]	valid_0's l1: 9.38082
[1179]	valid_0's l1: 9.38035
[1180]	valid_0's l1: 9.37936
[1181]	valid_0's l1: 9.37918
[1182]	valid_0's l1: 9.37817
[1183]	valid_0's l1: 9.37782
[1184]	valid_0's l1: 9.37715
[1185]	valid_0's l1: 9.37644
[1186]	valid_0's l1: 9.37596
[1187]	valid_0's l1: 9.37531
[1188]	valid_0's l1: 9.37436
[1189]	valid_0's l1: 9.37385
[1190]	valid_0's l1: 9.37314
[1191]	valid_0's l1: 9.37266
[1192]	valid_0's l1: 9.37166
[1193]	valid_0's l1: 9.37116
[1194]	valid_0's l1: 9.37049
[1195]	valid_0's l1: 9.36994
[1196]	valid_0's l1: 9.36935
[1197]	valid_0's l1: 9.36873
[1198]	valid_0's l1: 9.36783
[1199]	valid_0's l1: 9.36762
[1200]	valid_0's l1: 9.36728
[1201]	valid_0's l1: 9.36666
[1202]	valid_0's l1: 9.36613
[1203]	valid_0's l1: 9.36557
[1204]	valid_0'

[1453]	valid_0's l1: 9.23365
[1454]	valid_0's l1: 9.23341
[1455]	valid_0's l1: 9.23312
[1456]	valid_0's l1: 9.23257
[1457]	valid_0's l1: 9.23244
[1458]	valid_0's l1: 9.23177
[1459]	valid_0's l1: 9.23165
[1460]	valid_0's l1: 9.23136
[1461]	valid_0's l1: 9.23092
[1462]	valid_0's l1: 9.23038
[1463]	valid_0's l1: 9.23002
[1464]	valid_0's l1: 9.22959
[1465]	valid_0's l1: 9.2285
[1466]	valid_0's l1: 9.22833
[1467]	valid_0's l1: 9.22794
[1468]	valid_0's l1: 9.22745
[1469]	valid_0's l1: 9.22734
[1470]	valid_0's l1: 9.22717
[1471]	valid_0's l1: 9.22687
[1472]	valid_0's l1: 9.22659
[1473]	valid_0's l1: 9.22608
[1474]	valid_0's l1: 9.22576
[1475]	valid_0's l1: 9.22545
[1476]	valid_0's l1: 9.22501
[1477]	valid_0's l1: 9.22456
[1478]	valid_0's l1: 9.22418
[1479]	valid_0's l1: 9.2237
[1480]	valid_0's l1: 9.22336
[1481]	valid_0's l1: 9.22283
[1482]	valid_0's l1: 9.22251
[1483]	valid_0's l1: 9.22218
[1484]	valid_0's l1: 9.22178
[1485]	valid_0's l1: 9.22123
[1486]	valid_0's l1: 9.22107
[1487]	valid_0's

[1736]	valid_0's l1: 9.13193
[1737]	valid_0's l1: 9.1316
[1738]	valid_0's l1: 9.13139
[1739]	valid_0's l1: 9.13112
[1740]	valid_0's l1: 9.1309
[1741]	valid_0's l1: 9.13039
[1742]	valid_0's l1: 9.12991
[1743]	valid_0's l1: 9.12956
[1744]	valid_0's l1: 9.1293
[1745]	valid_0's l1: 9.12891
[1746]	valid_0's l1: 9.12869
[1747]	valid_0's l1: 9.12826
[1748]	valid_0's l1: 9.12802
[1749]	valid_0's l1: 9.12749
[1750]	valid_0's l1: 9.12717
[1751]	valid_0's l1: 9.12695
[1752]	valid_0's l1: 9.12649
[1753]	valid_0's l1: 9.12622
[1754]	valid_0's l1: 9.126
[1755]	valid_0's l1: 9.12573
[1756]	valid_0's l1: 9.12555
[1757]	valid_0's l1: 9.12567
[1758]	valid_0's l1: 9.12539
[1759]	valid_0's l1: 9.12496
[1760]	valid_0's l1: 9.12481
[1761]	valid_0's l1: 9.12449
[1762]	valid_0's l1: 9.1242
[1763]	valid_0's l1: 9.12394
[1764]	valid_0's l1: 9.12365
[1765]	valid_0's l1: 9.12365
[1766]	valid_0's l1: 9.12338
[1767]	valid_0's l1: 9.12301
[1768]	valid_0's l1: 9.12257
[1769]	valid_0's l1: 9.12216
[1770]	valid_0's l1:

[2019]	valid_0's l1: 9.06278
[2020]	valid_0's l1: 9.06251
[2021]	valid_0's l1: 9.06227
[2022]	valid_0's l1: 9.06203
[2023]	valid_0's l1: 9.06203
[2024]	valid_0's l1: 9.06165
[2025]	valid_0's l1: 9.06141
[2026]	valid_0's l1: 9.06118
[2027]	valid_0's l1: 9.06097
[2028]	valid_0's l1: 9.06102
[2029]	valid_0's l1: 9.06138
[2030]	valid_0's l1: 9.06119
[2031]	valid_0's l1: 9.06132
[2032]	valid_0's l1: 9.06081
[2033]	valid_0's l1: 9.06041
[2034]	valid_0's l1: 9.06055
[2035]	valid_0's l1: 9.0603
[2036]	valid_0's l1: 9.06013
[2037]	valid_0's l1: 9.05991
[2038]	valid_0's l1: 9.05968
[2039]	valid_0's l1: 9.05946
[2040]	valid_0's l1: 9.059
[2041]	valid_0's l1: 9.05935
[2042]	valid_0's l1: 9.05903
[2043]	valid_0's l1: 9.05855
[2044]	valid_0's l1: 9.05796
[2045]	valid_0's l1: 9.05775
[2046]	valid_0's l1: 9.05807
[2047]	valid_0's l1: 9.05787
[2048]	valid_0's l1: 9.05756
[2049]	valid_0's l1: 9.058
[2050]	valid_0's l1: 9.05752
[2051]	valid_0's l1: 9.05721
[2052]	valid_0's l1: 9.05669
[2053]	valid_0's l1

[2302]	valid_0's l1: 9.01129
[2303]	valid_0's l1: 9.01114
[2304]	valid_0's l1: 9.01118
[2305]	valid_0's l1: 9.01105
[2306]	valid_0's l1: 9.01079
[2307]	valid_0's l1: 9.01069
[2308]	valid_0's l1: 9.01044
[2309]	valid_0's l1: 9.01032
[2310]	valid_0's l1: 9.01009
[2311]	valid_0's l1: 9.00991
[2312]	valid_0's l1: 9.00991
[2313]	valid_0's l1: 9.00948
[2314]	valid_0's l1: 9.00918
[2315]	valid_0's l1: 9.00908
[2316]	valid_0's l1: 9.00921
[2317]	valid_0's l1: 9.00887
[2318]	valid_0's l1: 9.00849
[2319]	valid_0's l1: 9.00837
[2320]	valid_0's l1: 9.00803
[2321]	valid_0's l1: 9.00817
[2322]	valid_0's l1: 9.0083
[2323]	valid_0's l1: 9.00819
[2324]	valid_0's l1: 9.00788
[2325]	valid_0's l1: 9.00769
[2326]	valid_0's l1: 9.00768
[2327]	valid_0's l1: 9.00739
[2328]	valid_0's l1: 9.00728
[2329]	valid_0's l1: 9.00718
[2330]	valid_0's l1: 9.0071
[2331]	valid_0's l1: 9.00678
[2332]	valid_0's l1: 9.00656
[2333]	valid_0's l1: 9.00636
[2334]	valid_0's l1: 9.00615
[2335]	valid_0's l1: 9.006
[2336]	valid_0's l

Checking in case there is overfitting

In [11]:
mean_absolute_error(y_train, gbm.predict(X_train, num_iteration=gbm.best_iteration))

8.284560688510672

In [12]:
mean_absolute_error(y_val, gbm.predict(X_val, num_iteration=gbm.best_iteration))

8.982210734006422

Let's try a real prediction

In [13]:
gbm.predict(pd.DataFrame({
    "d0": [0.475],
    "d1": [0.28125],
    "d2": [0.5125],
    "d3": [0.21875],
    "d4": [0.28125],
    "d5": [0.275],
    "d6": [0.45625],
    "d7": [0.15625]
}), num_iteration=gbm.best_iteration)

array([121.75564984])

Saving and testing the loading process

In [14]:
gbm.save_model('tempo_lgbm.txt')

<lightgbm.basic.Booster at 0x17a6401630>

In [15]:
loaded = lgb.Booster(model_file='tempo_lgbm.txt')

In [16]:
loaded.predict(pd.DataFrame({
    "d0": [0.475],
    "d1": [0.28125],
    "d2": [0.5125],
    "d3": [0.21875],
    "d4": [0.28125],
    "d5": [0.275],
    "d6": [0.45625],
    "d7": [0.15625]
}), num_iteration=gbm.best_iteration)

array([121.75564984])