In [1]:
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC, LinearSVC
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from tqdm.notebook import tqdm
import time
import warnings
# NOTE(review): this silences *all* warnings for the whole kernel session,
# including any convergence warnings the saga LogisticRegression / MLP fits
# below may emit. Consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

In [2]:
# Execute the pre-processing notebook in this kernel's namespace. The cells below
# use pre_process, tracks, features_all, feature_sets, pd and ipd without defining
# or importing them here, so they presumably come from pre_proc.ipynb (TODO confirm).
%run pre_proc.ipynb

In [3]:
# Single seed shared by every estimator that consumes `random_state`
# (saga data shuffling, LinearSVC coordinate descent, tree / boosting feature
# permutation, MLP weight init), so re-running the notebook reproduces the scores.
RANDOM_STATE = 42

# Candidate classifiers, keyed by the column name used in the results table.
models = {
    'LR': LogisticRegression(solver='saga', random_state=RANDOM_STATE),
    'SVClinear': LinearSVC(random_state=RANDOM_STATE),
    'SVCrbf': SVC(kernel='rbf'),
    'kNN': KNeighborsClassifier(n_neighbors=200),
    'MLP': MLPClassifier(random_state=RANDOM_STATE, max_iter=200),
    'DT': DecisionTreeClassifier(max_depth=5, random_state=RANDOM_STATE),
    'NB': GaussianNB(),
    'GBC': GradientBoostingClassifier(random_state=RANDOM_STATE),
    'LDA': LinearDiscriminantAnalysis(),
    'QDA': QuadraticDiscriminantAnalysis(),
}

In [4]:
def run_models(models, features):
    """Fit every classifier on every feature set and collect test accuracies.

    Parameters
    ----------
    models : dict[str, estimator]
        Column name -> classifier exposing ``fit`` / ``score``. Each estimator
        object is re-fitted in place for every feature set, so after this
        returns it holds the fit from the last feature set.
    features : dict[str, object]
        Row name -> feature-set spec, passed as the third argument to
        ``pre_process`` (presumably a column selection; see pre_proc.ipynb).

    Returns
    -------
    pd.DataFrame
        One row per feature set. Column ``'dim'`` is the number of input
        features after pre-processing, followed by one accuracy column per
        model in ``models`` order.
    """
    columns = ['dim', *models]
    scores = pd.DataFrame(columns=columns, index=list(features))
    for fset_name, fset in tqdm(features.items(), desc='features'):
        # pre_process, tracks and features_all are expected from pre_proc.ipynb.
        y_train, y_val, y_test, X_train, X_val, X_test = pre_process(tracks, features_all, fset, False)
        scores.loc[fset_name, 'dim'] = X_train.shape[1]
        for clf_name, clf in models.items():
            clf.fit(X_train, y_train)
            # NOTE(review): scored on the test split; X_val / y_val are unused.
            # If these numbers drive model selection, the validation split is
            # the appropriate one.
            scores.loc[fset_name, clf_name] = clf.score(X_test, y_test)
    # The frame starts with object dtype columns; give callers real numeric dtypes.
    return scores.infer_objects()

def format_scores(scores):
    def highlight(s):
        is_max = s == max(s[1:])
        return ['background-color: yellow' if v else '' for v in is_max]
    scores = scores.style.apply(highlight, axis=1)
    return scores.format('{:.2%}', subset=pd.IndexSlice[:, scores.columns[1]:])

In [5]:
# Score every model on every feature set, then render the styled table.
scores = run_models(models, feature_sets)
styled_scores = format_scores(scores)
ipd.display(styled_scores)

features:   0%|          | 0/18 [00:00<?, ?it/s]

Unnamed: 0,dim,LR,SVClinear,SVCrbf,kNN,MLP,DT,NB,GBC,LDA,QDA
chroma_cens,84.0,38.87%,39.14%,42.29%,37.50%,40.81%,35.68%,9.99%,39.56%,38.24%,24.64%
chroma_cqt,84.0,39.37%,40.34%,44.27%,40.03%,43.68%,35.45%,1.55%,42.05%,39.76%,3.58%
chroma_stft,84.0,41.97%,43.06%,48.31%,43.92%,49.01%,39.88%,4.20%,46.21%,43.53%,5.64%
mfcc,140.0,58.61%,57.09%,60.98%,54.99%,54.06%,45.82%,41.86%,57.68%,57.68%,48.39%
rmse,7.0,37.23%,37.35%,38.90%,38.52%,38.87%,38.63%,11.78%,40.30%,36.57%,15.04%
spectral_bandwidth,7.0,40.50%,40.58%,44.46%,45.39%,45.24%,42.91%,36.18%,43.26%,39.84%,34.16%
spectral_centroid,7.0,42.75%,42.17%,45.71%,45.36%,47.96%,42.67%,33.31%,46.75%,43.02%,36.11%
spectral_contrast,49.0,51.61%,48.97%,54.45%,49.55%,52.62%,43.53%,39.41%,52.27%,48.93%,41.78%
spectral_rolloff,7.0,41.70%,41.47%,47.53%,46.25%,47.69%,45.36%,28.49%,46.48%,41.51%,28.53%
tonnetz,42.0,40.11%,39.53%,42.25%,37.31%,41.86%,35.91%,22.31%,41.47%,38.98%,23.05%
