# Music genre classification based on MIDI



## Brief introduction

The aim of the project is to classify pieces of music into one of several genres. The music representation we work with is MIDI, a format that encodes music as a series of chronological events (such as playing a certain note), separated into tracks and channels that correspond to instruments. From these MIDIs we extract features relevant to distinguishing genres, and we apply classifiers to the extracted information.

## A few details

### Dataset

The dataset we prepared for this project is based on: the [Lakh](https://colinraffel.com/projects/lmd/) dataset, which contains MIDIs of songs; the [Million Song Dataset](http://millionsongdataset.com/), which establishes quasi-universal IDs of pieces of music (meaning IDs respected by several other datasets) and provides metadata about them; the [MSD Tagtraum](https://www.tagtraum.com/msd_genre_datasets.html) dataset, which provides genres for Million Song Dataset tracks; and the [Acousticbrainz Genre](https://mtg.github.io/acousticbrainz-genre-dataset/) dataset, which provides genres for MusicBrainz IDs that we map to Million Song Dataset songs. We combined these sets so that each MIDI is matched to the ID and the genre of the song it encodes.

The actual working (filtered) dataset consists of ~1200 MIDIs, ~170 for each of the 7 genres: Electronic, Metal, RnB, Country, Jazz, Rock, Pop. The train:test:validation split is 60:20:20.
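A per-genre (stratified) 60:20:20 split can be sketched as follows. This is an illustration under our own assumptions, not necessarily the exact procedure we used: the helper name `stratified_split` and the seed are hypothetical. Each genre is shuffled and cut separately, so every split keeps the genre balance; with 170 tracks per genre the cut gives 102/34/34 tracks per genre.

```python
import random
from collections import defaultdict


def stratified_split(items, labels, ratios=(0.6, 0.2, 0.2), seed=0):
    """Split items into train/test/validation lists of (item, label) pairs.

    Hypothetical helper: each label (genre) is shuffled independently and cut
    according to `ratios`, so every split keeps the same genre proportions.
    """
    by_label = defaultdict(list)
    for item, label in zip(items, labels):
        by_label[label].append(item)

    rng = random.Random(seed)  # fixed seed -> reproducible split
    train, test, valid = [], [], []
    for label, group in by_label.items():
        rng.shuffle(group)
        n_train = round(len(group) * ratios[0])
        n_test = round(len(group) * ratios[1])
        train += [(x, label) for x in group[:n_train]]
        test += [(x, label) for x in group[n_train:n_train + n_test]]
        valid += [(x, label) for x in group[n_train + n_test:]]  # the remainder
    return train, test, valid


genres = ["electronic", "metal", "rnb", "country", "jazz", "rock", "pop"]
labels = [g for g in genres for _ in range(170)]
train, test, valid = stratified_split(range(len(labels)), labels)
print(len(train), len(test), len(valid))  # 714 238 238
```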

### Features

In order to predict the genre of a song, we need to extract features from its MIDI. The musical traits we decided to capture are:
- pitch
- velocity of notes (their "intensity")
- duration of and space between notes
- polyphony (how many notes are being played simultaneously)
- used instruments
- melodic intervals (musical distance between consecutive notes)
- shape of melodies
- time signature
- tempo
- balance of registers (low, mid, high)
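To make the "melodic intervals" trait concrete, here is a simplified sketch of a melodic interval histogram over a single melodic line given as MIDI pitch numbers. Our actual values were computed with music21, which works on parsed scores rather than a flat pitch list, so treat this only as an illustration. The 49 bins mirror the `Melodic_Interval_Histogram_0` to `Melodic_Interval_Histogram_48` columns in our table; ignoring wider intervals is an assumption of this sketch.

```python
def melodic_interval_histogram(pitches, n_bins=49):
    """Normalized histogram of absolute semitone distances between consecutive notes."""
    counts = [0] * n_bins
    for prev, curr in zip(pitches, pitches[1:]):
        interval = abs(curr - prev)  # e.g. C4 (60) -> E4 (64) is 4 semitones
        if interval < n_bins:  # intervals wider than the last bin are ignored here
            counts[interval] += 1
    total = sum(counts)
    return [c / total for c in counts] if total else counts


hist = melodic_interval_histogram([60, 62, 64, 60, 60])
print(hist[0], hist[2], hist[4])  # 0.25 0.5 0.25
```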

After extracting features for all the used tracks, we analyzed the entropy of each feature column, which gave us a rough idea of each feature's importance. However, simply keeping the columns with the most information and leaving out the rest was not straightforward. To guide the choice of the feature subset used for classification, we ran a genetic algorithm whose individuals were bit encodings of subsets of all features and whose fitness function was the classifier's performance with the encoded subset.
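The bit-encoding idea can be sketched with a small hand-written genetic algorithm. The notebook itself uses the pygad library; the loop below is a plain-Python stand-in, not pygad's API, and `select_features` and its operators (truncation selection, single-point crossover, bit-flip mutation) are our illustrative choices. `fitness` is a hypothetical callable that, in the real setting, would train a classifier on the columns whose bit is 1 and return its test-split score; it must also cope with an all-zero mask.

```python
import random


def select_features(n_features, fitness, population=20, generations=30,
                    mutation_rate=0.05, seed=0):
    """Evolve bit masks (1 = feature used) and return the best one found."""
    rng = random.Random(seed)
    pop = [[rng.randint(0, 1) for _ in range(n_features)]
           for _ in range(population)]
    for _ in range(generations):
        # Truncation selection: the better half survives unchanged (elitism).
        parents = sorted(pop, key=fitness, reverse=True)[:population // 2]
        children = []
        while len(parents) + len(children) < population:
            a, b = rng.sample(parents, 2)
            cut = rng.randrange(1, n_features)  # single-point crossover
            child = a[:cut] + b[cut:]
            # Bit-flip mutation: toggle each bit with a small probability.
            child = [bit ^ 1 if rng.random() < mutation_rate else bit
                     for bit in child]
            children.append(child)
        pop = parents + children
    return max(pop, key=fitness)


# Toy fitness: how many bits agree with a known "ideal" subset.
ideal = [1, 0, 1, 1, 0, 0, 1, 0, 1, 1]
best = select_features(len(ideal), lambda m: sum(x == y for x, y in zip(m, ideal)))
print(best)
```

Because `fitness` is re-evaluated every generation, with a real classifier it is worth caching scores per mask (for example keyed by `tuple(mask)`).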

As an implementation note: we implemented some feature extractors ourselves, while the less trivial features were computed with the [music21](https://github.com/cuthbertLab/music21) library.

### Used classifiers

We tried k-nearest neighbors (KNN), logistic regression (LR), a support vector machine (SVM), and neural networks (NNs). The results are discussed in the *Conclusion* section.
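For reference, the KNN decision rule is a majority vote among the k closest training tracks. The sketch below uses Euclidean distance and uniform votes, which correspond to the defaults of scikit-learn's `KNeighborsClassifier` (`weights='uniform'`, Minkowski metric with `p=2`), though tie-breaking may differ. The helper name is hypothetical, the default `k=26` is the neighbor count we report in the results, and the sketch assumes features are on comparable scales, since a raw column like `tempo` would otherwise dominate the distance.

```python
import math
from collections import Counter


def knn_predict(train_X, train_y, x, k=26):
    """Predict a label for `x` by majority vote among its k nearest neighbors."""
    # Indices of training rows sorted by Euclidean distance to x.
    order = sorted(range(len(train_X)), key=lambda i: math.dist(train_X[i], x))
    votes = Counter(train_y[i] for i in order[:k])
    return votes.most_common(1)[0][0]


train_X = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
train_y = ["jazz", "jazz", "jazz", "metal", "metal", "metal"]
print(knn_predict(train_X, train_y, (0.5, 0.5), k=3))  # jazz
```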

## Conclusion

### Results

The above-mentioned genetic algorithm helped us choose the following features to use in classification:
- pitch_avg
- pitch_std
- pitch_min
- pitch_max
- velocity_avg
- velocity_std
- polyphony_avg
- polyphony_std
- polyphony_max
- Brass_Fraction
- Electric_Instrument_Fraction
- Violin_Fraction
- Melodic_Interval_Histogram_2
- Melodic_Interval_Histogram_3
- Melodic_Interval_Histogram_5
- Melodic_Interval_Histogram_7
- Melodic_Interval_Histogram_9
- Melodic_Interval_Histogram_11
- Melodic_Interval_Histogram_12
- Melodic_Interval_Histogram_14
- Melodic_Interval_Histogram_15
- Melodic_Interval_Histogram_16
- Melodic_Interval_Histogram_21
- Melodic_Interval_Histogram_22
- Melodic_Interval_Histogram_23
- Melodic_Interval_Histogram_24
- Melodic_Interval_Histogram_25
- Melodic_Interval_Histogram_26
- Melodic_Interval_Histogram_27
- Melodic_Interval_Histogram_28
- Melodic_Interval_Histogram_29
- Melodic_Interval_Histogram_32
- Melodic_Interval_Histogram_35
- Melodic_Interval_Histogram_37
- Melodic_Interval_Histogram_40
- Melodic_Interval_Histogram_44
- Melodic_Interval_Histogram_45
- Melodic_Interval_Histogram_47
- Melodic_Interval_Histogram_48
- Pitch_Class_Distribution_0
- Pitch_Class_Distribution_2
- Pitch_Class_Distribution_3
- Pitch_Class_Distribution_8
- Pitch_Class_Distribution_11
- Direction_of_Motion
- Repeated_Notes
- Importance_of_Bass_Register
- Importance_of_Middle_Register
- Duration_of_Melodic_Arcs
- tempo
- resolution
- ts_numerator

The selection was based on the classifiers' performance on the test split of the dataset.

![Progression of GA](https://i.imgur.com/NMX2mJW.png)

After choosing the features, we evaluated the classifiers on the validation split using those features and obtained the following accuracies:

| Classifier | Validation accuracy |
|------------|---------------------|
| SVM        | 34.87%              |
| KNN        | 39.92%              |
| LR         | 33.61%              |
| NNs        | 31.51%              |

Here, SVM is scikit-learn's `SVC`, KNN uses 26 nearest neighbors, LR is scikit-learn's default `LogisticRegression`, and the neural network is `MLPClassifier(alpha=1e-05, hidden_layer_sizes=(10, 10), random_state=1)`. Below are the confusion matrices for each classifier; rows are the true labels and columns are the predicted ones.

| SVM: true ↓ / predicted → | electronic | pop | rock | rnb | jazz | metal | country |
|---------------------------|------------|-----|------|-----|------|-------|---------|
| electronic | 15 |  1 |  7 |  3 |  4 |  3 |   1 |
| pop        | 17 |  4 |  5 |  5 |  0 |  1 |   2 |
| rock       |  8 |  5 |  4 |  5 |  2 |  4 |   6 |
| rnb        |  8 |  6 |  4 | 12 |  2 |  0 |   2 |
| jazz       |  9 |  5 |  0 |  6 | 12 |  0 |   2 |
| metal      |  8 |  1 |  2 |  0 |  1 | 22 |   0 |
| country    |  5 |  5 |  3 |  6 |  1 |  0 |  14 |

<br />

| KNN: true ↓ / predicted → | electronic | pop | rock | rnb | jazz | metal | country |
|---------------------------|------------|-----|------|-----|------|-------|---------|
| electronic |  9 |  2 |  3 |  1 |  7 |  8 |   4 |
| pop        |  5 |  3 |  5 |  4 |  4 |  2 |  11 |
| rock       |  0 |  4 |  6 |  5 |  3 |  7 |   9 |
| rnb        |  1 |  6 |  4 |  8 |  5 |  1 |   9 |
| jazz       |  0 |  0 |  1 |  6 | 21 |  2 |   4 |
| metal      |  1 |  0 |  1 |  2 |  3 | 26 |   1 |
| country    |  0 |  3 |  4 |  3 |  2 |  0 |  22 |

<br />

| LR: true ↓ / predicted → | electronic | pop | rock | rnb | jazz | metal | country |
|--------------------------|------------|-----|------|-----|------|-------|---------|
| electronic | 19 |  0 |  4 |  0 |  5 |  4 |   2 |
| pop        | 22 |  2 |  3 |  5 |  0 |  1 |   1 |
| rock       | 10 |  5 |  4 |  4 |  3 |  5 |   3 |
| rnb        | 14 |  4 |  4 |  9 |  2 |  0 |   1 |
| jazz       |  9 |  3 |  1 |  6 | 14 |  0 |   1 |
| metal      |  8 |  0 |  2 |  0 |  2 | 22 |   0 |
| country    | 11 |  5 |  5 |  1 |  2 |  0 |  10 |

<br />

| NNs: true ↓ / predicted → | electronic | pop | rock | rnb | jazz | metal | country |
|---------------------------|------------|-----|------|-----|------|-------|---------|
| electronic |  9 |  1 |  1 |  3 |  8 |  9 |   3 |
| pop        | 14 |  1 |  2 |  3 |  5 |  4 |   5 |
| rock       |  6 |  2 |  2 |  4 |  5 |  7 |   8 |
| rnb        | 12 |  2 |  4 |  3 |  7 |  1 |   5 |
| jazz       |  1 |  1 |  1 |  7 | 20 |  3 |   1 |
| metal      |  3 |  0 |  1 |  0 |  2 | 28 |   0 |
| country    |  6 |  4 |  3 |  2 |  6 |  1 |  12 |
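The accuracies in the results table can be recomputed from these matrices: accuracy is the sum of the diagonal divided by the total count. Every row sums to 34, so the validation split holds 238 tracks (34 per genre). The short sketch below does this for the SVM matrix (values copied from the table) and also derives per-genre recall; the helper names are ours.

```python
GENRES = ["electronic", "pop", "rock", "rnb", "jazz", "metal", "country"]

# SVM confusion matrix from the table above (rows: true, columns: predicted).
SVM_CM = [
    [15, 1, 7, 3, 4, 3, 1],
    [17, 4, 5, 5, 0, 1, 2],
    [8, 5, 4, 5, 2, 4, 6],
    [8, 6, 4, 12, 2, 0, 2],
    [9, 5, 0, 6, 12, 0, 2],
    [8, 1, 2, 0, 1, 22, 0],
    [5, 5, 3, 6, 1, 0, 14],
]


def accuracy(cm):
    """Fraction of tracks on the diagonal (predicted genre == true genre)."""
    correct = sum(cm[i][i] for i in range(len(cm)))
    return correct / sum(map(sum, cm))


def recall_per_genre(cm, labels):
    """Share of each true genre's tracks that were classified correctly."""
    return {label: cm[i][i] / sum(cm[i]) for i, label in enumerate(labels)}


print(f"{accuracy(SVM_CM):.2%}")  # 34.87%
recalls = recall_per_genre(SVM_CM, GENRES)
print(max(recalls, key=recalls.get))  # metal
```

Applied to the other three matrices, the same computation gives 95/238, 80/238 and 75/238 correct, i.e. the KNN, LR and NN accuracies. For the SVM, metal is recognized most reliably (22 of 34), while pop and rock are each recognized only 4 times out of 34.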

### Lessons learned

- It seems that timbre (what the music actually sounds like), along with other audio features, is a significant factor in making genres distinguishable. MIDI doesn't really carry that information, except perhaps through the choice of instruments.
- The features we extracted probably lack the discriminative power needed for high accuracy on a dataset consisting only of MIDIs. Many possible features rely on a (possibly large) amount of music theory, but they are very tricky to extract.
- There might be genre-specific patterns that are not detectable in such a small dataset.
- Preparing the data is quite a challenge.


# Setup

```python
# Download the database
!gdown https://drive.google.com/uc?id=1xPAla-T739rMoIANtaEWDJLfG03EL28J
```

```
Downloading...
From: https://drive.google.com/uc?id=1xPAla-T739rMoIANtaEWDJLfG03EL28J
To: /content/db.sqlite
100% 15.2M/15.2M [00:00<00:00, 24.3MB/s]
```


```python
# Load the sql IPython extension
%load_ext sql
```

```python
# Load the DB
%%sql
sqlite:///db.sqlite
```

```
'Connected: @db.sqlite'
```

```python
# Download external dependencies
!pip install pygad
```

```
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pygad
  Downloading pygad-2.18.3-py3-none-any.whl (56 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.4/56.4 KB 1.0 MB/s eta 0:00:00
Installing collected packages: pygad
Successfully installed pygad-2.18.3
```


```python
# Imports
import functools
from math import ceil
import numpy as np
import pandas as pd
import pygad
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
import sqlite3
```

# Preparing the data

```python
# Some records of the database (with features already computed)
%%sql
SELECT * FROM used_tracks LIMIT 10;
```

 * sqlite:///db.sqlite
Done.


track_id,song_id,mb_track_id,artist,title,genre,pitch_avg,pitch_std,pitch_min,pitch_max,duration_avg,duration_std,duration_min,duration_max,total_duration,velocity_avg,velocity_std,velocity_min,velocity_max,polyphony_avg,polyphony_std,polyphony_max,tempo,resolution,ts_numerator,ts_denominator,Acoustic_Guitar_Fraction,Brass_Fraction,Electric_Guitar_Fraction,Electric_Instrument_Fraction,Orchestral_Strings_Fraction,Saxophone_Fraction,String_Ensemble_Fraction,String_Keyboard_Fraction,Violin_Fraction,Woodwinds_Fraction,Amount_of_Arpeggiation,Chromatic_Motion,Average_Melodic_Interval,Direction_of_Motion,Changes_of_Meter,Importance_of_High_Register,Duration_of_Melodic_Arcs,Importance_of_Middle_Register,Size_of_Melodic_Arcs,Importance_of_Bass_Register,Melodic_Interval_Histogram_0,Melodic_Interval_Histogram_1,Melodic_Interval_Histogram_2,Melodic_Interval_Histogram_3,Melodic_Interval_Histogram_4,Melodic_Interval_Histogram_5,Melodic_Interval_Histogram_6,Melodic_Interval_Histogram_7,Melodic_Interval_Histogram_8,Melodic_Interval_Histogram_9,Melodic_Interval_Histogram_10,Melodic_Interval_Histogram_11,Melodic_Interval_Histogram_12,Melodic_Interval_Histogram_13,Melodic_Interval_Histogram_14,Melodic_Interval_Histogram_15,Melodic_Interval_Histogram_16,Melodic_Interval_Histogram_17,Melodic_Interval_Histogram_18,Melodic_Interval_Histogram_19,Melodic_Interval_Histogram_20,Melodic_Interval_Histogram_21,Melodic_Interval_Histogram_22,Melodic_Interval_Histogram_23,Melodic_Interval_Histogram_24,Melodic_Interval_Histogram_25,Melodic_Interval_Histogram_26,Melodic_Interval_Histogram_27,Melodic_Interval_Histogram_28,Melodic_Interval_Histogram_29,Melodic_Interval_Histogram_30,Melodic_Interval_Histogram_31,Melodic_Interval_Histogram_32,Melodic_Interval_Histogram_33,Melodic_Interval_Histogram_34,Melodic_Interval_Histogram_35,Melodic_Interval_Histogram_36,Melodic_Interval_Histogram_37,Melodic_Interval_Histogram_38,Melodic_Interval_Histogram_39,Melodic_Interval_Histogram_40,Melodic_Interval_Histogram_41,Melodic_Interval_Histogram_42,Melodic_Interval_Histogram_43,Melodic_Interval_Histogram_44,Melodic_Interval_Histogram_45,Melodic_Interval_Histogram_46,Melodic_Interval_Histogram_47,Melodic_Interval_Histogram_48,Note_Density,Pitch_Class_Distribution_0,Pitch_Class_Distribution_1,Pitch_Class_Distribution_2,Pitch_Class_Distribution_3,Pitch_Class_Distribution_4,Pitch_Class_Distribution_5,Pitch_Class_Distribution_6,Pitch_Class_Distribution_7,Pitch_Class_Distribution_8,Pitch_Class_Distribution_9,Pitch_Class_Distribution_10,Pitch_Class_Distribution_11,Repeated_Notes,Stepwise_Motion,Staccato_Incidence,Variability_of_Time_Between_Attacks,Average_Time_Between_Attacks
TRADQXE128F1466AB9,SOBKKRK12AF729F24D,ae34f5aa-6f47-46bd-9f60-ad2b88229288,Queensryche,Lady Jane,metal,51.69824561403509,13.334778315052844,28.0,96.0,0.5822110142436557,0.5990183802908198,0.0073529416666531,7.058824000000428,240.6540968333321,103.18853801169593,17.32573408561033,9.0,127.0,10.342446352016218,3.92413898003798,20.0,135.1848045106509,120.0,4.0,4.0,0.8571428571428571,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5555555555555556,0.1111111111111111,5.555555555555555,0.4984326018808777,0.0,0.0844155844155844,1.380952380952381,0.4545454545454545,16.917748917748916,0.461038961038961,0.3333333333333333,0.1111111111111111,0.074074074074074,0.037037037037037,0.037037037037037,0.037037037037037,0.037037037037037,0.0,0.074074074074074,0.0,0.0,0.074074074074074,0.0,0.0,0.037037037037037,0.074074074074074,0.0,0.0,0.0,0.037037037037037,0.0,0.0,0.0,0.0,0.037037037037037,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.635655261971022,0.2513914656771799,0.1484230055658627,0.0037105751391465,0.0287569573283859,0.0296846011131725,0.2189239332096475,0.0,0.0955473098330241,0.1131725417439703,0.0,0.1103896103896103,0.0,0.3333333333333333,0.1851851851851851,0.0,0.3658416526557677,0.3775458477109903
TRAFMVN128F147CBCE,SOQDMXT12A6D4F8255,,Metallica,Fade To Black,metal,48.84875515909999,12.562241184975733,28.0,86.0,0.3043759812461242,0.3410797027211689,0.0,6.533338233333268,389.59076838333,122.54200505924643,17.57611401791315,50.0,127.0,5.867981030141541,2.148225101551445,11.0,226.4940293560418,120.0,4.0,4.0,0.2278400676962132,0.0357520626189972,0.2523799449968267,0.4480643114025809,0.0,0.0,0.0,0.3659826528453564,0.0,0.0,0.5436936936936937,0.0945945945945946,3.76981981981982,0.4700776610324349,0.0,0.0660038079119949,1.6458646616541353,0.414216204781045,8.778947368421052,0.5197799873069601,0.2594594594594595,0.0945945945945946,0.1936936936936936,0.0675675675675675,0.0594594594594594,0.0824324324324324,0.0067567567567567,0.0788288288288288,0.0229729729729729,0.0306306306306306,0.018018018018018,0.0063063063063063,0.045045045045045,0.0045045045045045,0.0031531531531531,0.0072072072072072,0.0018018018018018,0.0049549549549549,0.0,0.0063063063063063,0.0,0.0013513513513513,0.0013513513513513,0.0,0.0022522522522522,0.0,0.0,0.0009009009009009,0.0,0.0004504504504504,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,13.075907590759076,0.2143008250475989,0.0133276919822297,0.0871588745504548,0.1461815104717579,0.0050772159932303,0.2117622170509837,0.0004231013327691,0.1224878358366828,0.0539454199280727,0.0376560186164586,0.1076792891897609,0.0,0.2594594594594595,0.2882882882882883,0.0096191597958382,0.0863488410924513,0.2056780428873452
TRAHRRN12903CF8AAE,SOVHXTB12A67ADC863,0d7a1164-6d7c-4c3a-a353-acf59b6c458a,Lacuna Coil,Heaven's A Lie,metal,51.74526212158504,15.97494767012445,24.0,86.0,0.3684021650246166,0.5454794249458541,0.1071427499999941,13.714272000001074,284.5711439999972,93.05463942899335,5.759984643894395,76.0,95.0,5.260542168674702,2.8438070012344365,9.0,255.6869003305996,480.0,4.0,4.0,0.0,0.0,0.4211087420042644,0.5835110163468372,0.0,0.0,0.1428571428571428,0.3415067519545131,0.0,0.0,0.5786802030456852,0.1586294416243654,2.142766497461929,0.4430740037950664,1.0,0.1321961620469083,1.4638888888888888,0.4680170575692964,5.654166666666667,0.3997867803837953,0.5063451776649746,0.1586294416243654,0.1510152284263959,0.0348984771573604,0.0177664974619289,0.0450507614213197,0.0,0.0107868020304568,0.0025380710659898,0.0171319796954314,0.008248730964467,0.0006345177664974,0.0,0.0,0.0012690355329949,0.0,0.0,0.0,0.0,0.0203045685279187,0.0114213197969543,0.0,0.0114213197969543,0.0025380710659898,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.559738955823294,0.2746979388770433,0.0,0.1311300639658848,0.2125088841506752,0.0,0.1474769012082445,0.0,0.0991471215351812,0.083866382373845,0.0071073205401563,0.0440653873489694,0.0,0.5063451776649747,0.3096446700507614,0.0,0.0695169528486194,0.2311984191561083
TRAKAPI128F428CB43,SOGVQKE12A8C141409,,Regurgitate,Skull of Shit and Sludge,metal,45.23943661971831,13.830426539848723,35.0,78.0,0.1351928246003565,0.1301023272976221,0.0,0.7727265000000116,52.31812950000031,111.40845070422536,13.336994471894515,100.0,127.0,1.480451781059948,1.1481523574243624,4.0,220.0002200002196,384.0,4.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0625,0.0,0.0,0.4893617021276595,0.0,20.93617021276596,0.5,0.0,0.3333333333333333,1.0434782608695652,0.0,41.0,0.6666666666666666,0.4893617021276595,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5106382978723404,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.2723948811700185,0.6666666666666666,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.0,0.0,0.0,0.0,0.4893617021276595,0.0,0.0,2.547582846812804,0.752754820936639
TRAKOJW128F92CC883,SOXEBRP12A8C144763,,SONATA ARCTICA,UnOpened,metal,51.387076923076926,17.15016086249535,25.0,90.0,0.248698207932667,0.4066322671313571,0.0295275000000003,7.559040000000095,207.7542596875005,109.56261538461538,13.983977180156906,50.0,127.0,7.772220710542193,2.6623110887037,17.0,250.6414394846589,384.0,4.0,4.0,0.0,0.0,0.1645907473309608,0.4375,0.0,0.0,0.0097864768683274,0.2237544483985765,0.0,0.0,0.6535849056603774,0.1347169811320754,1.421132075471698,0.4642137096774194,0.0,0.2646797153024911,1.7207285342584562,0.2693505338078292,6.883781439722463,0.4659697508896797,0.5452830188679245,0.1347169811320754,0.1569811320754716,0.0460377358490566,0.0347169811320754,0.0373584905660377,0.000377358490566,0.0196226415094339,0.0033962264150943,0.0056603773584905,0.0037735849056603,0.000754716981132,0.0026415094339622,0.0,0.0011320754716981,0.000377358490566,0.000377358490566,0.000377358490566,0.000377358490566,0.0,0.0015094339622641,0.0,0.000754716981132,0.000377358490566,0.000377358490566,0.0,0.0,0.000754716981132,0.0,0.000377358490566,0.0,0.0,0.0,0.000377358490566,0.0,0.0,0.000377358490566,0.0,0.0,0.000377358490566,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000377358490566,0.000377358490566,23.535370565507552,0.2012900355871886,0.0011120996441281,0.1487989323843416,0.0,0.1419039145907473,0.1154359430604982,0.0026690391459074,0.1072064056939501,0.0124555160142348,0.172153024911032,0.0237989323843416,0.0731761565836298,0.5452830188679245,0.2916981132075472,0.0,0.3231591195005412,0.1429393609110633
TRAKQIH128EF343180,SOOGSWQ12A67AE0DF2,a1d28b74-ace6-4c08-be5f-8c4b373bb98b,Apocalyptica,Somewhere Around Nothing,metal,48.00936855911561,8.964522454601605,36.0,87.0,0.1770835817199579,0.2974747071616422,0.0197368124999997,5.052624000000236,237.47332799999828,104.29960652051714,13.406501944678652,63.0,127.0,3.9821309840425534,1.595871245863869,8.0,190.0002850004268,480.0,4.0,4.0,0.0,0.0,0.0,0.0,0.8108108108108109,0.0,0.0,0.5334620334620335,0.0035392535392535,0.0,0.5723076923076923,0.1846153846153846,1.3871794871794871,0.5198863636363636,0.0,0.0186615186615186,2.3259911894273126,0.3468468468468468,6.187224669603524,0.6344916344916345,0.4887179487179487,0.1846153846153846,0.1635897435897436,0.0384615384615384,0.0235897435897435,0.0548717948717948,0.0174358974358974,0.0082051282051282,0.001025641025641,0.003076923076923,0.0082051282051282,0.0041025641025641,0.0005128205128205,0.0015384615384615,0.0005128205128205,0.0,0.0005128205128205,0.0,0.0005128205128205,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0005128205128205,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,16.098503236245953,0.1721364221364221,0.055019305019305,0.1357786357786357,0.0186615186615186,0.016087516087516,0.1074646074646074,0.0057915057915057,0.1605534105534105,0.0424710424710424,0.1618404118404118,0.1222651222651222,0.0019305019305019,0.4887179487179487,0.3482051282051282,0.4867510145619479,0.2232881889870385,0.1532922132815679
TRARIBY128F4273287,SOTKFLY12AF72A73C3,923264d0-da62-4d17-b730-f93e7eb2c6ed,System of a Down,ATWA,metal,47.586771224002725,12.049465707667832,27.0,82.0,0.3440039611054418,0.2153346523016468,0.1724137499999756,1.5517237500000365,176.5516800000007,95.763723150358,7.970309294927095,63.0,127.0,5.718750000000001,2.411925125807184,11.0,170.44902220409227,480.0,4.0,4.0,0.2153772683858643,0.0,0.5152817574021012,0.6986628462273161,0.0,0.0,0.0,0.0,0.0,0.0,0.7235401459854015,0.1131386861313868,2.572992700729927,0.4235880398671096,0.0,0.0305635148042024,1.6358695652173914,0.332378223495702,7.975543478260869,0.6370582617000955,0.5036496350364963,0.1131386861313868,0.1450729927007299,0.0757299270072992,0.0310218978102189,0.0091240875912408,0.0,0.0036496350364963,0.0063868613138686,0.0018248175182481,0.0018248175182481,0.0,0.0045620437956204,0.0,0.0,0.0647810218978102,0.0383211678832116,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000912408759124,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.841015625,0.4703915950334288,0.0152817574021012,0.0888252148997134,0.1208213944603629,0.0,0.0668576886341929,0.0,0.1361031518624642,0.0343839541547277,0.0076408787010506,0.0596943648519579,0.0,0.5036496350364964,0.2582116788321167,0.0,0.1302751870826168,0.2710875331564987
TRAULTS128F931290B,SOLDLXG12AB0185707,,Nightwish,Beauty And The Beast,metal,50.387487126673534,13.873154880279518,28.0,86.0,0.3595005983649361,0.6276168135008527,0.0,18.66664800000004,374.8755742500154,95.0,0.0,95.0,95.0,7.448638342432625,2.529376353985134,12.0,200.9371738087447,480.0,4.0,4.0,0.0,0.0,0.255980007140307,0.3786147804355587,0.0,0.0,0.1960014280614066,0.2702606212067119,0.0,0.0,0.5593578369243768,0.1309674693705112,2.724123362906633,0.4664607237422771,1.0,0.0451624419850053,1.3329411764705883,0.4573366654766155,5.705294117647059,0.4975008925383791,0.3738910012674271,0.1309674693705111,0.1550485847063794,0.0215462610899873,0.0485847063793831,0.0608365019011406,0.0215462610899873,0.0882974228981833,0.0261934938741022,0.0426700464723278,0.0168990283058724,0.0008449514152936,0.0092944655682298,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0012674271229404,0.0,0.0008449514152936,0.0,0.0004224757076468,0.0,0.0,0.0004224757076468,0.0,0.0,0.0,0.0,0.0004224757076468,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,11.39350532540651,0.2043912888254195,0.0071403070332024,0.1033559443056051,0.0403427347375937,0.1212067118886112,0.0994287754373438,0.0053552302749018,0.1099607283113174,0.026597643698679,0.1579792931096037,0.0940735451624419,0.0301677972152802,0.3738910012674271,0.2860160540768905,0.0185928586520177,0.1817571500375245,0.2101890153689405
TRAVGNH128F428CBA9,SOBARGP12A58A7B1D9,,Amorphis,Into Hiding,metal,41.99720223820943,8.861691063869118,25.0,74.0,0.3522640439338595,0.56956367300427,0.0,14.769216000000526,224.1622169999981,94.39708233413268,4.154538641937754,57.0,95.0,7.873822141656904,2.5469811389940857,15.0,235.53378579557813,480.0,4.0,4.0,0.0,0.0,0.4717980679339358,0.7154876908694298,0.0,0.0,0.0202555313181676,0.3443440324088501,0.0,0.0,0.6906636670416197,0.1293588301462317,1.7463442069741282,0.5246753246753246,0.0,0.0056092240573387,1.7824074074074074,0.0956684325334995,7.796296296296297,0.8987223434091617,0.5433070866141733,0.1293588301462317,0.125421822272216,0.062429696287964,0.0236220472440944,0.0185601799775028,0.0101237345331833,0.0224971878515185,0.0061867266591676,0.0050618672665916,0.0106861642294713,0.0011248593925759,0.0258717660292463,0.0056242969628796,0.0005624296962879,0.0011248593925759,0.0,0.0067491563554555,0.0,0.0,0.0011248593925759,0.0,0.0,0.0,0.0005624296962879,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,14.293858267716551,0.3948270489248987,0.0445621688999688,0.0751012776565908,0.0607665939545029,0.0327204736678092,0.0526643814272358,0.0320972265503272,0.1860392645684013,0.0358367092552196,0.015269554378311,0.0635712059831723,0.0065440947335618,0.5433070866141733,0.2547806524184477,0.0,0.1897364769308918,0.1760992316948993
TRAWHYS128F42BA12C,SOJWQFO12A8C13E769,,Metallica,To Live Is To Die,metal,54.28195782348347,9.333136143474428,40.0,88.0,0.3355808308871575,0.3281758688408543,0.0,6.545447999999908,437.5632899999984,93.45118458734704,5.371661811251354,57.0,95.0,2.945781789504096,1.4725309054339222,8.0,201.24610288544557,480.0,4.0,4.0,0.1012230028623471,0.0,0.3994275305750716,0.3994275305750716,0.0041634139994795,0.0,0.0,0.0,0.0,0.0,0.5320261437908497,0.1254901960784313,4.882352941176471,0.4934782608695652,1.0,0.0439760603695029,1.5248618784530388,0.4275305750715586,7.466298342541436,0.5284933645589384,0.1790849673202614,0.1254901960784313,0.1620915032679738,0.134640522875817,0.049673202614379,0.0692810457516339,0.0026143790849673,0.0444444444444444,0.0261437908496732,0.0196078431372549,0.0196078431372549,0.0039215686274509,0.0457516339869281,0.0,0.0196078431372549,0.0444444444444444,0.0104575163398692,0.0156862745098039,0.0026143790849673,0.0065359477124183,0.0,0.0013071895424836,0.0091503267973856,0.0,0.0065359477124183,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0013071895424836,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.4336408210753735,0.1665365599791829,0.0039032006245121,0.1295862607338017,0.0078064012490242,0.1530054644808743,0.0304449648711943,0.0749414519906323,0.1298464741087691,0.0041634139994795,0.1397345823575331,0.0104085349986989,0.1496226906062971,0.1790849673202614,0.2875816993464052,0.0,0.1798306809217316,0.2958457526080478
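To feed these columns to the classifiers, the rows can be pulled into a feature matrix and a label vector. A minimal sketch using only the standard `sqlite3` module; `load_features` is a hypothetical helper (the notebook also imports pandas, which could do the same), and the demo runs on an in-memory table with a few of the `used_tracks` column names and values rounded from the last record above.

```python
import sqlite3


def load_features(con, columns, table="used_tracks"):
    """Return (X, y): the selected feature columns and the genre of every track."""
    # Column and table names come from our own feature list, not from user input,
    # so quoting them as SQL identifiers is enough here.
    col_list = ", ".join(f'"{c}"' for c in columns)
    rows = con.execute(f'SELECT {col_list}, genre FROM "{table}"').fetchall()
    X = [list(row[:-1]) for row in rows]
    y = [row[-1] for row in rows]
    return X, y


# Demo on an in-memory table that reuses a few used_tracks column names.
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE used_tracks (track_id TEXT, genre TEXT, pitch_avg REAL, tempo REAL)")
con.execute("INSERT INTO used_tracks VALUES ('TRAWHYS128F42BA12C', 'metal', 54.28, 201.25)")
X, y = load_features(con, ["pitch_avg", "tempo"])
print(X, y)  # [[54.28, 201.25]] ['metal']
```

Against the real database this would be called as `load_features(sqlite3.connect("db.sqlite"), ["pitch_avg", "pitch_std"])`, with the column list taken from the features chosen by the genetic algorithm.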


Below is the code that formed the core of the database preparation. Note that running it may not reproduce the exact database state we used, because we also executed a few statements manually to clean up the data.

## Database preparation code

```python
# init_db.py

from collections import defaultdict as dd, namedtuple
from pathlib import Path
import sqlite3
import os
import json

Entry = namedtuple('Entry', ['track_id', 'song_id', 'mb_track_id', 'artist', 'title', 'genre'], defaults=(None, None, None))

genres = dd(lambda: None)
genres_mbid = dd(lambda: [None])
labels = {}
entries = []

def find_mb_track_id(data):
    for song in data['response']['songs']:  # usually there's only one song
        for attr in song['tracks']:
            if attr['catalog'] == 'musicbrainz':
                return attr['foreign_id'].removeprefix('musicbrainz:track:')

def song_to_mbid(song):
    # mappings of echonest id to musicbrainz track ids
    # https://drive.google.com/file/d/1AZctGV7WysvsAaDCPWM1GVBvgaFz2Dys/view
    json_path = os.path.join('millionsongdataset_echonest/', song[2:4], song + '.json')
    if os.path.isfile(json_path):
        with open(json_path, 'r') as json_file:
            data = json.load(json_file)
            return find_mb_track_id(data)

# https://www.tagtraum.com/genres/msd_tagtraum_cd2c.cls.zip
with open('msd_tagtraum_cd2c.cls', 'r') as dataset:
    for line in dataset:
        if '#' not in line:
            track, genre = line.strip().split('\t')
            genres[track] = genre

# http://hog.ee.columbia.edu/craffel/lmd/lmd_matched.tar.gz
for root, dirs, files in os.walk('./lmd_matched/'):
    if len(files):
        track = os.path.basename(os.path.normpath(root))
        labels[track] = genres[track]

# https://mtg.github.io/acousticbrainz-genre-dataset/
for f in Path('.').glob('acousticbrainz-mediaeval-*.tsv'):
    with open(f, 'r') as tsv_file:
        tsv_file.readline() # skip header
        for line in tsv_file:
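            # drop empty cells and subgenres (marked with '---'); use a separate
            # name so the Tagtraum `genres` dict above isn't shadowed
            track_id, _, *mb_genres = filter(lambda x: x != '' and '---' not in x, (el.strip() for el in line.split('\t')))
            genres_mbid[track_id] = mb_genres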

# http://millionsongdataset.com/sites/default/files/AdditionalFiles/unique_tracks.txt
with open('unique_tracks.txt', 'r') as metadata:
    for line in metadata:
        track, song, artist, title = (el.strip() for el in line.split('<SEP>'))
        if track in labels.keys():
            mbid = song_to_mbid(song)
            genre = labels[track]
            if genre is None:
                # no Tagtraum label: fall back to the first AcousticBrainz genre
                genre = genres_mbid[mbid][0]
            entries.append(Entry(track, song, mbid, artist, title, genre.lower() if genre is not None else genre))

con = sqlite3.connect('db.sqlite')
cur = con.cursor()
cur.execute("DROP TABLE IF EXISTS tracks")
cur.execute("CREATE TABLE tracks(track_id TEXT PRIMARY KEY, song_id TEXT, mb_track_id TEXT, artist TEXT, title TEXT, genre TEXT)")
cur.executemany("INSERT INTO tracks VALUES(?, ?, ?, ?, ?, ?)", entries)
con.commit()

labeled = cur.execute("SELECT COUNT(*) FROM tracks WHERE genre IS NOT NULL")
n_labeled, = labeled.fetchone()
unlabeled = cur.execute("SELECT COUNT(*) FROM tracks WHERE genre IS NULL")
n_unlabeled, = unlabeled.fetchone()
n_tracks = n_labeled + n_unlabeled

print(f'no. of tracks:\t\t{n_tracks}')
print(f'no. of labeled:\t\t{n_labeled}')
print(f'no. of unlabeled:\t{n_unlabeled}')
print(f'amt of labeled:\t\t{round(100 * n_labeled / n_tracks, 2)}%')
```


```python
# feature_extractor.py

from collections import defaultdict as dd
from functools import cache
import mido
import pretty_midi
import music21
import heapq

DEFAULT_TEMPO = 500000

class MidiLibParser:
    @staticmethod
    @cache
    def parse(filename):
        return None

class MidoParser(MidiLibParser):
    @staticmethod
    @cache
    def parse(filename):
        return mido.MidiFile(filename)

class PrettyMidiParser(MidiLibParser):
    @staticmethod
    @cache
    def parse(filename):
        return pretty_midi.PrettyMIDI(filename)

class Music21Parser(MidiLibParser):
    @staticmethod
    # @cache
    def parse(filename):
        return music21.converter.parseFile(filename, format='midi')

class FeatureExtractor:
    def __init__(self):
        pass

    def features(self):
        return ()

    def extract(self, midi):
        return ()

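    def _compute_stats(self, data):
        """Return (mean, std, min, max) of an iterable in a single pass."""
        ssum = 0
        sqrs_ssum = 0
        mmax = float('-inf')
        mmin = float('inf')
        n = 0
        for d in data:
            ssum += d
            sqrs_ssum += d ** 2
            mmax = max(mmax, d)
            mmin = min(mmin, d)
            n += 1
        # note: raises ZeroDivisionError for empty input (e.g. a MIDI without notes)
        return (ssum / n, self.__compute_std(ssum, sqrs_ssum, n), mmin, mmax)

    def __compute_std(self, total, sqrs_total, n):
        # population std; clamp tiny negative variances caused by float rounding,
        # since a negative float ** 0.5 would yield a complex number
        return max((sqrs_total - total ** 2 / n) / n, 0) ** 0.5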
    
    def _to_abstime_inplace(self, messages):
        """Convert messages to absolute time."""
        now = 0
        for msg in messages:
            now += msg.time
            # mido messages reject setting unknown attributes, so stash the
            # original delta time through the instance dict instead
            vars(msg)['orig_time'] = msg.time
            msg.time = now

    def _merge_tracks(self, tracks):
        for track in tracks:
            self._to_abstime_inplace(track)
            # don't worry about sorting - messages in each track will be sorted
            # with respect to msg.time

        now = 0
        accum = 0
        for msg in heapq.merge(*tracks, key=lambda msg: msg.time):
            delta = msg.time - now
            now = msg.time
            msg.time = delta

            if msg.type == 'end_of_track':
                accum += msg.time
            else:
                if accum:
                    delta2 = accum + msg.time
                    msg.time = delta2
                    yield msg
                    accum = 0
                else:
                    yield msg

        # keep the magic invariant of having orig_time field
        eot = mido.MetaMessage('end_of_track', time=accum)
        vars(eot)['orig_time'] = accum
        yield eot

    def fast_iterator(self, midi):
        # The tracks of type 2 files are not in sync, so they can
        # not be played back like this.
        if midi.type == 2:
            raise TypeError("can't merge tracks in type 2 (asynchronous) file")

        tempo = DEFAULT_TEMPO
        for msg in self._merge_tracks(midi.tracks):
            # Convert message time from absolute time
            # in ticks to relative time in seconds.
            if msg.time > 0:
                delta = mido.tick2second(msg.time, midi.ticks_per_beat, tempo)
            else:
                delta = 0

            msg.time = delta
            yield msg
            msg.time = msg.orig_time

            if msg.type == 'set_tempo':
                tempo = msg.tempo


class PitchStatsExtractor(FeatureExtractor, MidoParser):
    def features(self):
        return ('pitch_avg', 'pitch_std', 'pitch_min', 'pitch_max')

    def extract(self, midi):
        return self._compute_stats((msg.note for msg in self.fast_iterator(midi) if msg.type == 'note_on' and msg.velocity > 0))


class DurationStatsExtractor(FeatureExtractor, MidoParser):
    def features(self):
        return ('duration_avg', 'duration_std', 'duration_min', 'duration_max', 'total_duration')
    
    def extract(self, midi):
        total_duration = 0
        start_times = dd(lambda: None)
        durations = []
        for msg in self.fast_iterator(midi):
            total_duration += msg.time
            if msg.type not in ['note_on', 'note_off']:
                continue
            note_id = (msg.channel, msg.note)
            if msg.type == 'note_on' and msg.velocity > 0:
                if start_times[note_id] is not None:  # note is re-triggered while still sounding
                    durations.append(total_duration - start_times[note_id])
                start_times[note_id] = total_duration
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                # compare with None explicitly: a note starting at time 0.0 would be falsy
                if start_times[note_id] is not None:
                    durations.append(total_duration - start_times[note_id])
                    start_times[note_id] = None

        return self._compute_stats(durations) + (total_duration,)


class DynamicsStatsExtractor(FeatureExtractor, MidoParser):
    def features(self):
        return ('velocity_avg', 'velocity_std', 'velocity_min', 'velocity_max')

    def extract(self, midi):
        return self._compute_stats((msg.velocity for msg in self.fast_iterator(midi) if msg.type == 'note_on' and msg.velocity > 0))


class PolyphonyStatsExtractor(FeatureExtractor, MidoParser):
    # no min: level 0 is always recorded (the first message is seen with no notes active),
    # so the min would always be 0
    def features(self):
        return ('polyphony_avg', 'polyphony_std', 'polyphony_max')

    def extract(self, midi):
        active_notes = set()
        timers_per_amount_of_notes = dd(lambda: 0)
        for msg in self.fast_iterator(midi):
            timers_per_amount_of_notes[len(active_notes)] += msg.time
            if msg.type not in ['note_on', 'note_off']:
                continue
            note_id = (msg.channel, msg.note)
            if msg.type == 'note_on' and msg.velocity > 0:
                if note_id not in active_notes:
                    active_notes.add(note_id)
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                if note_id in active_notes:
                    active_notes.remove(note_id)

        time_sum = 0
        mean = 0
        mmax = 0
        for n, t in timers_per_amount_of_notes.items():
            time_sum += t
            mean += n * t
            mmax = max(mmax, n)
        mean /= time_sum

        std = 0
        for n, t in timers_per_amount_of_notes.items():
            std += (n - mean) ** 2 * t
        std /= time_sum
        std **= 0.5

        return (mean, std, mmax)

class BasicStatsExtractor(FeatureExtractor, PrettyMidiParser):
    def features(self):
        return ('tempo', 'resolution', 'ts_numerator', 'ts_denominator')

    def extract(self, midi):
        tempo = midi.estimate_tempo()
        resolution = midi.resolution
        ts_changes = midi.time_signature_changes
        # default is 4/4
        ts_num = 4
        ts_den = 4
        if len(ts_changes) > 0:
            ts_num = ts_changes[0].numerator
            ts_den = ts_changes[0].denominator

        return tempo, resolution, ts_num, ts_den

# Python will come to fear my programming powers!
def m21(cls):
    """
        Dynamically create FeatureExtractor from m21 feature extractor
    """
    def __init__(self):
        self.extractor = cls()

    def features(self):
        return tuple(self.extractor.getAttributeLabels())

    def extract(self, midi):
        self.extractor.setData(midi)
        return tuple(self.extractor.extract().vector)

    return type(cls.__name__, (FeatureExtractor, Music21Parser), {
        "__init__": __init__,
        "features": features,
        "extract": extract
    })

```
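
`fast_iterator` turns each merged delta from ticks into seconds with `mido.tick2second`, starting from `DEFAULT_TEMPO` (500000 µs per beat in mido, i.e. 120 BPM) and switching tempo after every `set_tempo` message. The sketch below reproduces that arithmetic without mido; `ticks_to_seconds` and `deltas_to_seconds` are hypothetical helpers for illustration, and the `(delta_ticks, new_tempo_or_None)` event format is an assumption, not the project's data structure:

```python
def ticks_to_seconds(ticks, ticks_per_beat, tempo_us_per_beat=500_000):
    """Same formula as mido.tick2second: ticks * tempo / (ticks_per_beat * 10**6)."""
    return ticks * tempo_us_per_beat / (ticks_per_beat * 1_000_000)


def deltas_to_seconds(events, ticks_per_beat):
    """Convert (delta_ticks, new_tempo_or_None) pairs to deltas in seconds.

    As in fast_iterator, a tempo change only affects the deltas that follow it.
    """
    tempo = 500_000  # MIDI default tempo: 120 BPM
    out = []
    for delta_ticks, new_tempo in events:
        out.append(ticks_to_seconds(delta_ticks, ticks_per_beat, tempo))
        if new_tempo is not None:
            tempo = new_tempo
    return out


# 480 ticks per beat: one beat at 120 BPM, a switch to 60 BPM, then one more beat
print(deltas_to_seconds([(480, None), (0, 1_000_000), (480, None)], ticks_per_beat=480))
# [0.5, 0.0, 1.0]
```

The same number of ticks therefore lasts twice as long after the tempo halves, which is why the duration and polyphony statistics are computed from these second-based deltas rather than from raw ticks.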

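`PolyphonyStatsExtractor` weights every polyphony level by how long it lasts: with t_n the total time during which exactly n notes sound, the mean is sum(n * t_n) / sum(t_n) and the standard deviation is the square root of sum((n - mean)^2 * t_n) / sum(t_n). A standalone sketch of the same computation on a hand-made timeline; the input format (a dict from polyphony level to seconds, mirroring `timers_per_amount_of_notes`) is an assumption made for illustration:

```python
from math import sqrt


def weighted_polyphony_stats(time_per_level):
    """time_per_level maps 'number of sounding notes' -> total seconds at that level."""
    total = sum(time_per_level.values())
    mean = sum(n * t for n, t in time_per_level.items()) / total
    var = sum((n - mean) ** 2 * t for n, t in time_per_level.items()) / total
    return mean, sqrt(var), max(time_per_level)  # max over the levels seen


# 1 s of silence, 2 s of a single note, 1 s of a three-note chord
mean, std, peak = weighted_polyphony_stats({0: 1.0, 1: 2.0, 3: 1.0})
print(mean, round(std, 4), peak)  # 1.25 1.0897 3
```
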
```python
# extract_features.py

import mido
import sqlite3
import os
from tqdm import tqdm
from feature_extractor import *
from argparse import ArgumentParser
from music21 import features

basepath = 'lmd_matched/'

feature_extractors = [
    PitchStatsExtractor,
    DurationStatsExtractor,
    DynamicsStatsExtractor,
    PolyphonyStatsExtractor,
    BasicStatsExtractor,
]

feature_extractors += [m21(cls) for cls in features.jSymbolic.featureExtractors]
feature_extractors += [m21(cls) for cls in features.native.featureExtractors]

# TRRSDBS12903CEC331 - unpleasant example

def name2path(name):
    a, b, c = name[2:5]  # Lakh layout: <id[2]>/<id[3]>/<id[4]>/<id>/
    return os.path.join(basepath, a, b, c, name)

def run_extractors(extractors):
    con = sqlite3.connect('db.sqlite')
    cur = con.cursor()

    extractor_objs = [ex() for ex in extractors]

    for extractor in extractor_objs:
        for feature in extractor.features():
            try:
                cur.execute(f'ALTER TABLE tracks ADD COLUMN {feature} REAL')
            except sqlite3.OperationalError:
                pass  # ignore if adding column failed - already present

    # keep the DISTINCT_GENRES most frequent genres and take equally many tracks from each:
    # as many as the least frequent of them has
    DISTINCT_GENRES = 7
    genres_counts = sorted(cur.execute("SELECT genre, count(*) FROM tracks WHERE genre IS NOT NULL GROUP BY genre").fetchall(), key=lambda p: p[1])
    no_of_records_per_genre = genres_counts[-DISTINCT_GENRES][1]
    genres = [genre for genre, _ in genres_counts[-DISTINCT_GENRES:]]

    all_tracks = []
    for genre in genres:
        tracks = cur.execute('SELECT track_id FROM tracks WHERE genre = ? ORDER BY track_id LIMIT ?',
                             (genre, no_of_records_per_genre))
        all_tracks.extend(tracks.fetchall())

    for (track,) in tqdm(all_tracks):
        track_path = name2path(track)
        for file in os.listdir(track_path):
            filepath = os.path.join(track_path, file)
            try:
                for extractor in extractor_objs:
                    midi = extractor.parse(filepath) # this is cached so don't worry about performance
                    extracted = extractor.extract(midi)
                    for feature, val in zip(extractor.features(), extracted):
                        # values are bound as parameters; column names cannot be, hence the f-string
                        cur.execute(f'UPDATE tracks SET {feature} = ? WHERE track_id = ?', (float(val), track))
            except Exception as e:
                # if parsing file/extraction failed for some reason, try another midi file or skip this song
                print(f'Warning:\tprocessing {filepath} failed')
                print(f'Exception:\t{str(e)}')
                continue

            break  # stop at the first MIDI file processed successfully: one file per track
        con.commit()

def main():
    extractor_dict = {ex.__name__: ex for ex in feature_extractors}

    parser = ArgumentParser(description='Run MIDI feature extractors')
    parser.add_argument('--extractors', '-e', nargs='+', default=list(extractor_dict),
        help='Space-separated list of feature extractor class names to run on the dataset')
    args = parser.parse_args()

    # validate the names up front, so that a KeyError raised during extraction is not misreported
    unknown = [name for name in args.extractors if name not in extractor_dict]
    if unknown:
        raise SystemExit(f'Error: no extractor named {", ".join(unknown)}')

    run_extractors([extractor_dict[name] for name in args.extractors])

if __name__ == "__main__":
    main()

```
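
`name2path` relies on the directory layout of LMD-matched, where the folder of a Million Song Dataset track ID is nested under three single-letter directories taken from the 3rd, 4th and 5th characters of the ID. A quick check of that mapping on the track ID mentioned in the script (a copy of the function above; the printed path assumes POSIX separators):

```python
import os

basepath = 'lmd_matched/'


def name2path(name):
    a, b, c = name[2:5]  # 3rd to 5th characters of the MSD track ID
    return os.path.join(basepath, a, b, c, name)


print(name2path('TRRSDBS12903CEC331'))
# lmd_matched/R/S/D/TRRSDBS12903CEC331
```

Each such folder may hold several MIDI versions of the same song, which is why `run_extractors` lists the directory and stops at the first file that can be processed.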

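Both `run_extractors` and `get_dataset` balance the classes the same way: they keep the `DISTINCT_GENRES` most frequent genres and take, from each of them, as many tracks as the least frequent of the kept genres has. The sketch below reproduces that rule; `pick_balanced_genres` is a hypothetical helper and the counts are made up, not taken from the dataset:

```python
def pick_balanced_genres(genre_counts, n_genres):
    """genre_counts: iterable of (genre, count) pairs, as returned by the GROUP BY query."""
    ranked = sorted(genre_counts, key=lambda p: p[1])  # ascending by count
    kept = ranked[-n_genres:]                           # the n_genres most frequent genres
    per_genre = kept[0][1]                              # count of the smallest kept genre
    return [genre for genre, _ in kept], per_genre


counts = [('blues', 40), ('jazz', 300), ('rock', 900), ('metal', 250), ('folk', 90)]
print(pick_balanced_genres(counts, 3))
# (['metal', 'jazz', 'rock'], 250)
```

Capping every genre at the size of the smallest kept one discards data from the big genres, but it keeps the chance baseline at exactly 1 / `DISTINCT_GENRES`.
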
# Classifying MIDIs

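```python
# Util functions and vars

import functools
import sqlite3
from math import ceil

import numpy as np
import pandas as pd
import pygad
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC

DISTINCT_GENRES = 7  # how many of the most frequent genres to use

def split_train_test_valid(data):
    # 60:20:20 split; ceil keeps the cut points integral
    tr = ceil(0.6 * len(data))
    tst = ceil(0.8 * len(data))
    return data[:tr], data[tr:tst], data[tst:]

def split_X_y(data):
    # genre is the first column, the features follow
    return np.array([r[1:] for r in data]), np.array([r[0] for r in data])

def vstack(arrays):
    return np.concatenate(arrays, axis=0)

def classify_on_data(train_X, train_y, valid_X, valid_y, classifier):
    # each split is scaled by its own column maxima
    classifier.fit(train_X / train_X.max(axis=0), train_y)
    preds = classifier.predict(valid_X / valid_X.max(axis=0))
    perf = (preds == valid_y).mean()  # accuracy
    return perf
```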
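
`classify_on_data` divides every split by its own column maxima, so the same raw value can map to different scaled values in training and validation. A common alternative is to compute the scaling on the training split only and reuse it for the other splits. Below is a sketch of that variant with plain NumPy (max-abs scaling, all-zero columns left unscaled); `fit_max_abs` and `classify_with_train_scaling` are hypothetical names, and this is not the code that produced the numbers in this notebook:

```python
import numpy as np


def fit_max_abs(train_X):
    """Per-column scale computed on the training split only."""
    scale = np.abs(train_X).max(axis=0).astype(float)
    scale[scale == 0] = 1.0  # avoid 0/0 for all-zero columns
    return scale


def classify_with_train_scaling(train_X, train_y, valid_X, valid_y, classifier):
    scale = fit_max_abs(train_X)
    classifier.fit(train_X / scale, train_y)
    preds = classifier.predict(valid_X / scale)  # same transformation as in training
    return (preds == valid_y).mean()
```

It has the same signature as `classify_on_data`, so it could be passed to `functools.partial` in its place. scikit-learn's `MaxAbsScaler` inside a `Pipeline` does the same fit-on-train bookkeeping.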


```python
# Genetic algorithm described at the top
# Helps in finding a good subset of features: each candidate is a bit mask over `features`

def solution_to_features(solution, features):
    return [f for f, v in zip(features, solution) if v]

def callback_gen(ga_instance):
    print("Generation : ", ga_instance.generations_completed)
    solution, solution_fitness, solution_idx = ga_instance.best_solution()
    print("Fitness of the best solution :", solution_fitness)
    print(f"Parameters of the best solution : {solution}")

def run_ga(features, train_X, train_y, test_X, test_y, classifier):
    cache = {}  # fitness of masks that were already evaluated

    # Two-argument signature of the PyGAD 2.x API;
    # PyGAD 3.x expects fitness_func(ga_instance, solution, solution_idx).
    def fitness_func(solution, solution_idx):
        tsolution = tuple(solution)
        if tsolution not in cache:
            mask = [i for i, v in enumerate(solution) if v]
            if not mask:
                cache[tsolution] = 0  # empty subset: nothing to classify with
            else:
                # fitness = accuracy on the test split using only the selected columns
                cache[tsolution] = classify_on_data(train_X[:, mask], train_y,
                                                    test_X[:, mask], test_y, classifier)
        return cache[tsolution]

    ga_instance = pygad.GA(num_generations=1000000,  # effectively unbounded; stop_criteria ends the run
                           num_parents_mating=4,
                           fitness_func=fitness_func,
                           sol_per_pop=12,
                           num_genes=len(features),
                           init_range_low=0,
                           init_range_high=1,
                           gene_space=[0, 1],  # one bit per feature: used or not
                           gene_type=int,
                           parent_selection_type="sss",  # steady-state selection
                           keep_parents=1,  # has no effect while keep_elitism is non-zero
                           crossover_type="two_points",
                           mutation_type="random",
                           mutation_percent_genes=15,
                           on_generation=callback_gen,
                           stop_criteria="reach_0.50",  # stop once the test accuracy reaches 50%
                           keep_elitism=2)

    ga_instance.run()
    solution, solution_fitness, solution_idx = ga_instance.best_solution()
    return solution_to_features(solution, features)
```
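
In the GA, a candidate solution holds one bit per entry of `features`; `fitness_func` turns it into column indices and `solution_to_features` into feature names. The same decoding on a toy example with five features (the names are borrowed from `all_features`, the data is made up):

```python
import numpy as np

features = ['pitch_avg', 'pitch_std', 'tempo', 'Note_Density', 'Violin_Fraction']
X = np.arange(15, dtype=float).reshape(3, 5)  # 3 tracks x 5 features
solution = [1, 0, 1, 1, 0]                    # one bit per feature

mask = [i for i, v in enumerate(solution) if v]       # column indices, as in fitness_func
names = [f for f, v in zip(features, solution) if v]  # names, as in solution_to_features

print(mask)              # [0, 2, 3]
print(names)             # ['pitch_avg', 'tempo', 'Note_Density']
print(X[:, mask].shape)  # (3, 3)
```

With the 102 entries of `all_features` there are 2^102 possible subsets, far too many to evaluate exhaustively, which is what makes a heuristic search such as the GA attractive here.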


```python
# Wrappers for running classifiers

def do_knn(classify_func):
    best = 0
    for k in range(1, 101):
        perf = classify_func(KNeighborsClassifier(n_neighbors=k))
        print(f"{k}\t{100 * perf:.2f}%")
        best = max(best, perf)
    return best

def do_svm(classify_func):
    perf = classify_func(SVC())
    print(f"svm: \t{100 * perf:.2f}%")
    return perf

def do_nns(classify_func):
    nns = [
        MLPClassifier(solver='adam', alpha=1e-5, hidden_layer_sizes=(10, 10), random_state=1),
        MLPClassifier(solver='adam', alpha=1e-5, hidden_layer_sizes=(100, 100), random_state=1),
    ]

    best = 0
    for cl in nns:
        perf = classify_func(cl)
        print(f"{cl}\t{100 * perf:.2f}%")
        best = max(best, perf)
    return best

def do_lr(classify_func):
    perf = classify_func(LogisticRegression(max_iter=1000))
    print(f"logit:\t{100 * perf:.2f}%")
    return perf
```


```python
# Datasets setup helper

def get_dataset(cursor, used_features):
    train_Xs, train_ys = [], []
    test_Xs,  test_ys  = [], []
    valid_Xs, valid_ys = [], []
    genres_counts = sorted(cursor.execute("SELECT genre, count(*) FROM used_tracks WHERE genre IS NOT NULL AND pitch_avg IS NOT NULL GROUP BY genre").fetchall(), key=lambda p: p[1])
    # keep the DISTINCT_GENRES most frequent genres and take equally many tracks from each
    # (as many as the least frequent of them has), so that the classes are balanced
    no_of_records_per_genre = genres_counts[-DISTINCT_GENRES][1]
    genres = [genre for genre, _ in genres_counts[-DISTINCT_GENRES:]]

    features = ", ".join(used_features)
    for genre in genres:
        records = cursor.execute(
            f"SELECT genre, {features} FROM used_tracks WHERE genre = ? AND pitch_avg IS NOT NULL ORDER BY track_id LIMIT ?",
            (genre, no_of_records_per_genre)).fetchall()
        # split each genre separately so that every split stays balanced
        train, test, valid = split_train_test_valid(records)
        train_part_X, train_part_y = split_X_y(train)
        test_part_X, test_part_y = split_X_y(test)
        valid_part_X, valid_part_y = split_X_y(valid)
        train_Xs.append(train_part_X)
        train_ys.append(train_part_y)
        test_Xs.append(test_part_X)
        test_ys.append(test_part_y)
        valid_Xs.append(valid_part_X)
        valid_ys.append(valid_part_y)

    train_X, train_y = vstack(train_Xs), vstack(train_ys)
    test_X,  test_y  = vstack(test_Xs),  vstack(test_ys)
    valid_X, valid_y = vstack(valid_Xs), vstack(valid_ys)

    return train_X, train_y, test_X, test_y, valid_X, valid_y
```
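
Because `split_train_test_valid` rounds both cut points up, the per-genre split is only approximately 60:20:20. For 170 tracks per genre (roughly the size of the filtered dataset) it gives exactly 102/34/34, i.e. 7 × 34 = 238 validation tracks, which matches the row sums of the confusion matrices further down. A quick check, using a copy of the function on placeholder records:

```python
from math import ceil


def split_train_test_valid(data):
    tr = ceil(0.6 * len(data))
    tst = ceil(0.8 * len(data))
    return data[:tr], data[tr:tst], data[tst:]


for n in (170, 172):
    train, test, valid = split_train_test_valid(list(range(n)))
    print(n, len(train), len(test), len(valid))
# 170 102 34 34
# 172 104 34 34
```
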


```python
# order matters: the GA bit mask indexes into this array
all_features = np.asarray(
    ['pitch_avg', 'pitch_std', 'pitch_min', 'pitch_max',
     'duration_avg', 'duration_std', 'duration_min', 'duration_max',
     'velocity_avg', 'velocity_std', 'velocity_min', 'velocity_max',
     'polyphony_avg', 'polyphony_std', 'polyphony_max',
     'Brass_Fraction', 'Electric_Guitar_Fraction', 'Electric_Instrument_Fraction',
     'Saxophone_Fraction', 'Violin_Fraction', 'Woodwinds_Fraction']
    + [f'Melodic_Interval_Histogram_{i}' for i in range(49)]
    + [f'Pitch_Class_Distribution_{i}' for i in range(12)]
    + ['Average_Time_Between_Attacks', 'Variability_of_Time_Between_Attacks', 'Staccato_Incidence',
       'Stepwise_Motion', 'Chromatic_Motion', 'Direction_of_Motion', 'Repeated_Notes', 'Note_Density',
       'Importance_of_Bass_Register', 'Importance_of_Middle_Register', 'Importance_of_High_Register',
       'Size_of_Melodic_Arcs', 'Duration_of_Melodic_Arcs',
       'total_duration', 'tempo', 'resolution', 'ts_numerator', 'ts_denominator',
       'Average_Melodic_Interval', 'Amount_of_Arpeggiation'])
```

```python
# Dataset setup

RUN_GA = False  # change here to control whether or not to run GA
DB = "db.sqlite"
CON = sqlite3.connect(DB)
CUR = CON.cursor()

train_X, train_y, test_X, test_y, valid_X, valid_y = get_dataset(CUR, all_features)
```


```python
def entropy(series):
    # Shannon entropy (in bits) of the column's value distribution; missing values are ignored
    probs = series.value_counts() / series.count()
    log_probs = np.log2(probs)
    return -np.sum(probs * log_probs)

df = pd.DataFrame(vstack([train_X, test_X, valid_X]), columns=all_features)
df.apply(entropy).sort_values(ascending=False).head(10)
```

```
polyphony_avg                          10.180496
duration_avg                           10.180496
duration_std                           10.180496
total_duration                         10.180496
Average_Time_Between_Attacks           10.180496
Variability_of_Time_Between_Attacks    10.180496
Size_of_Melodic_Arcs                   10.180496
Note_Density                           10.180496
polyphony_std                          10.180496
Average_Melodic_Interval               10.178854
dtype: float64
```
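
The `entropy` helper is the Shannon entropy, in bits, of a column's empirical value distribution. For a column in which all N values are distinct it equals log2(N), which is most likely why several continuous features tie at the top of the list above: the measure reflects how many distinct values a feature takes rather than how well it separates the genres. A dependency-free sketch of the same formula (equivalent to the pandas version on columns without missing values):

```python
from collections import Counter
from math import log2


def entropy(values):
    """Shannon entropy, in bits, of the empirical distribution of `values`."""
    counts = Counter(values)
    n = len(values)
    return sum(-(c / n) * log2(c / n) for c in counts.values())


print(entropy(['a', 'a', 'b', 'b']))  # 1.0   two equally likely values
print(entropy(list(range(1024))))     # 10.0  1024 distinct values: log2(1024)
print(entropy([4, 4, 4, 4]))          # 0.0   a constant column carries no information
```
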

```python
# Classification

if RUN_GA:
    best_features = run_ga(all_features, train_X, train_y, test_X, test_y, SVC())
else:
    # bit mask from an earlier GA run: one bit per entry of all_features
    best_features = solution_to_features([1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1,
        1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1,
        0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0], all_features)
train_X, train_y, test_X, test_y, valid_X, valid_y = get_dataset(CUR, best_features)

# train on the training split, score on the validation split
classify = functools.partial(classify_on_data, train_X, train_y, valid_X, valid_y)

best_knn = do_knn(classify)
best_svm = do_svm(classify)
best_nn  = do_nns(classify)
best_lr  = do_lr(classify)
best = max(best_knn, best_svm, best_nn, best_lr)

print(f"expected random's performance:\t\t{100 / DISTINCT_GENRES:.2f}%")
print(f"best performance on validation set:\t{100 * best:.2f}%")
print(f"diff:\t\t\t\t\t{100 * (best - 1 / DISTINCT_GENRES):.2f}pp")
```


```
1	29.41%
2	31.09%
3	29.41%
4	31.51%
5	33.61%
6	34.03%
7	32.77%
8	33.19%
9	34.45%
10	35.29%
11	34.45%
12	34.03%
13	33.61%
14	34.03%
15	34.87%
16	36.13%
17	37.82%
18	34.87%
19	33.19%
20	34.87%
21	36.97%
22	37.39%
23	38.24%
24	38.24%
25	38.24%
26	39.92%
27	37.82%
28	37.39%
29	37.39%
30	36.55%
31	35.71%
32	36.55%
33	34.45%
34	34.03%
35	34.03%
36	35.29%
37	34.87%
38	35.71%
39	33.61%
40	33.61%
41	33.19%
42	33.61%
43	35.71%
44	35.29%
45	34.87%
46	34.87%
47	34.03%
48	36.55%
49	35.71%
50	34.87%
51	34.87%
52	35.29%
53	36.13%
54	36.13%
55	36.13%
56	35.71%
57	35.29%
58	35.71%
59	35.29%
60	34.87%
61	35.71%
62	35.29%
63	35.29%
64	35.29%
65	35.71%
66	35.29%
67	35.71%
68	36.13%
69	36.55%
70	36.55%
71	36.13%
72	36.55%
73	35.71%
74	36.97%
75	36.55%
76	35.71%
77	36.13%
78	35.29%
79	36.55%
80	36.55%
81	36.13%
82	35.71%
83	35.29%
84	34.87%
85	34.87%
86	34.87%
87	36.55%
88	36.13%
89	35.29%
90	36.55%
91	35.71%
92	35.71%
93	36.13%
94	35.29%
95	35.71%
96	34.87%
97	36.13%
98	34.45%
99	35.29%
100	34.87%
svm: 	34
MLPClassifier(alpha=1e-05, hidden_layer_sizes=(10, 10), random_state=1)	31.51%
MLPClassifier(alpha=1e-05, hidden_layer_sizes=(100, 100), random_state=1)	27.31%
logit:	33.61%
expected random's performance:		14.29%
best performance on validation set:	39.92%
diff:					25.63pp
```




```python
# Compute confusion matrices (rows: true genre, columns: predicted genre, both in the order of cats)

def compute_confusion_matrix(train_X, train_y, valid_X, valid_y, cats, classifier):
    classifier.fit(train_X / train_X.max(axis=0), train_y)
    preds = classifier.predict(valid_X / valid_X.max(axis=0))
    return confusion_matrix(valid_y, preds, labels=cats)

cats = ['electronic', 'pop', 'rock', 'rnb', 'jazz', 'metal', 'country']
print(f"SVM:\n{compute_confusion_matrix(train_X, train_y, valid_X, valid_y, cats, SVC())}", end="\n\n")
print(f"KNN:\n{compute_confusion_matrix(train_X, train_y, valid_X, valid_y, cats, KNeighborsClassifier(n_neighbors=26))}", end="\n\n")
print(f"LR:\n{compute_confusion_matrix(train_X, train_y, valid_X, valid_y, cats, LogisticRegression(max_iter=1000))}", end="\n\n")
print(f"NN:\n{compute_confusion_matrix(train_X, train_y, valid_X, valid_y, cats, MLPClassifier(solver='adam', alpha=1e-5, hidden_layer_sizes=(10, 10), random_state=1))}", end="\n\n")
```

```
SVM:
[[15  1  7  3  4  3  1]
 [17  4  5  5  0  1  2]
 [ 8  5  4  5  2  4  6]
 [ 8  6  4 12  2  0  2]
 [ 9  5  0  6 12  0  2]
 [ 8  1  2  0  1 22  0]
 [ 5  5  3  6  1  0 14]]

KNN:
[[ 9  2  3  1  7  8  4]
 [ 5  3  5  4  4  2 11]
 [ 0  4  6  5  3  7  9]
 [ 1  6  4  8  5  1  9]
 [ 0  0  1  6 21  2  4]
 [ 1  0  1  2  3 26  1]
 [ 0  3  4  3  2  0 22]]

LR:
[[19  0  4  0  5  4  2]
 [22  2  3  5  0  1  1]
 [10  5  4  4  3  5  3]
 [14  4  4  9  2  0  1]
 [ 9  3  1  6 14  0  1]
 [ 8  0  2  0  2 22  0]
 [11  5  5  1  2  0 10]]

NN:
[[ 9  1  1  3  8  9  3]
 [14  1  2  3  5  4  5]
 [ 6  2  2  4  5  7  8]
 [12  2  4  3  7  1  5]
 [ 1  1  1  7 20  3  1]
 [ 3  0  1  0  2 28  0]
 [ 6  4  3  2  6  1 12]]
```
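
In these matrices each row is a true genre and each column a predicted one, both in the order of `cats`, so the diagonal holds the correctly classified tracks and every row sums to the 34 validation tracks of its genre. As a consistency check, the KNN matrix (k = 26) has 95 tracks on its diagonal out of 238, i.e. 95/238 ≈ 39.92%, the best validation score reported above. A short sketch computing overall accuracy and per-genre recall from such a matrix, with the numbers copied from the KNN output:

```python
cats = ['electronic', 'pop', 'rock', 'rnb', 'jazz', 'metal', 'country']

# KNN (k = 26) confusion matrix from the output above: rows = true, columns = predicted
knn = [
    [9, 2, 3, 1, 7, 8, 4],
    [5, 3, 5, 4, 4, 2, 11],
    [0, 4, 6, 5, 3, 7, 9],
    [1, 6, 4, 8, 5, 1, 9],
    [0, 0, 1, 6, 21, 2, 4],
    [1, 0, 1, 2, 3, 26, 1],
    [0, 3, 4, 3, 2, 0, 22],
]

correct = sum(knn[i][i] for i in range(len(knn)))  # diagonal: correctly classified tracks
total = sum(map(sum, knn))
print(f"accuracy: {correct}/{total} = {100 * correct / total:.2f}%")  # accuracy: 95/238 = 39.92%

# recall per genre: share of that genre's tracks that were classified correctly
for i, (genre, row) in enumerate(zip(cats, knn)):
    print(f"{genre:<10} {row[i] / sum(row):.2f}")
```

Read this way, the KNN recognizes metal, country and jazz far more reliably than pop, whose tracks are mostly assigned to other genres.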



