In [1]:
%reload_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
import warnings

import pandas as pd
import numpy as np
import pickle

import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import AgglomerativeClustering, KMeans, DBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
from sklearn import datasets, metrics

import seaborn as sns
import matplotlib.pyplot as plt

<IPython.core.display.Javascript object>

In [3]:
def print_vif(x):
    """Print the variance inflation factor (VIF) of every column of ``x``.

    A constant column is added via ``sm.add_constant`` before the VIFs are
    computed, so the printed series has one entry per column of the augmented
    frame.

    :param x: features to check for multicollinearity; assumed to be a pandas.DataFrame
    :return: None -- the VIFs are printed as a pandas Series indexed by column name
    """
    # Meant to hide numpy's FutureWarning about .ptp; note that every warning
    # raised inside this block is suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        design = sm.add_constant(x)

    # Pull the array out once and score each column position against the rest.
    matrix = design.values
    vif_by_column = pd.Series(
        [variance_inflation_factor(matrix, pos) for pos in range(design.shape[1])],
        index=design.columns,
    )

    print("VIF results\n-------------------------------")
    print(vif_by_column)
    print("-------------------------------\n")

<IPython.core.display.Javascript object>

In [4]:
df = pd.read_csv("../data/features_3_sec.csv")

# Split each filename on "." once and reuse the pieces: the first piece becomes
# the genre, the first two pieces re-joined with "." become the song name.
name_parts = df["filename"].str.split(".")
df["genre"] = name_parts.str.get(0)
df["songname"] = name_parts.str.slice(stop=2).str.join(".")

<IPython.core.display.Javascript object>

In [5]:
# Feature columns selected from `df` to build the clustering input
# (consumed by `X = df[keep_cols]`).
# Only MFCC mean/variance columns are active; the commented-out entries at the
# top of the list are candidate features that are currently excluded.
# NOTE(review): the reason for each exclusion is not recorded here -- presumably
# they were dropped after a multicollinearity (VIF) check; confirm.
keep_cols = [
    #     "chroma_stft_mean",
    #     "chroma_stft_var",
    #     "rms_var",
    #     "zero_crossing_rate_mean",
    #     "zero_crossing_rate_var",
    #     "harmony_mean",
    #     "harmony_var",
    #     "perceptr_mean",
    #     "tempo",
    "mfcc1_mean",
    "mfcc2_mean",
    "mfcc2_var",
    "mfcc3_mean",
    "mfcc3_var",
    "mfcc4_mean",
    "mfcc4_var",
    "mfcc5_var",
    "mfcc6_mean",
    "mfcc6_var",
    "mfcc7_mean",
    "mfcc8_mean",
    "mfcc8_var",
    "mfcc9_mean",
    "mfcc9_var",
    "mfcc10_var",
    "mfcc12_mean",
    "mfcc12_var",
    "mfcc13_mean",
    "mfcc15_mean",
    "mfcc15_var",
    "mfcc16_mean",
    "mfcc16_var",
    "mfcc17_mean",
    "mfcc18_mean",
    "mfcc19_mean",
    "mfcc19_var",
    "mfcc20_mean",
    "mfcc20_var",
]

<IPython.core.display.Javascript object>

In [6]:
# Feature matrix: only the selected columns.  Labels: the genre column.
X = df.loc[:, keep_cols]
y = df.loc[:, "genre"]

<IPython.core.display.Javascript object>

In [7]:
# Copy of X in which every "*_var" column is replaced by its natural log;
# all other columns are carried over unchanged.
# NOTE(review): X_logged is not referenced by the later cells visible here
# (the scaler is fitted on X) -- confirm whether that is intended.
X_logged = X.assign(
    **{name: np.log(X[name]) for name in X.columns if name.endswith("_var")}
)

<IPython.core.display.Javascript object>

In [8]:
# Fit the scaler on the feature matrix, then standardise that same matrix.
# The fitted scaler is kept so results can be mapped back to original units.
# NOTE(review): this scales the raw X, not X_logged -- confirm that is intended.
scaler = StandardScaler().fit(X)
std_X = scaler.transform(X)

<IPython.core.display.Javascript object>

In [14]:
# Fixed seed for the mixture's random initialisation, so the cluster labels
# (and every table derived from them) are identical on each Restart & Run All.
RANDOM_STATE = 42

# Fit a 10-component Gaussian mixture on the standardised features and assign
# each row to its most likely component.
clst = GaussianMixture(n_components=10, random_state=RANDOM_STATE)
# Alternative estimators tried during exploration:
# clst = DBSCAN(eps=3.5, min_samples=11, n_jobs=-1)
# clst = AgglomerativeClustering(n_clusters=10, affinity="cosine", linkage="complete")
# clst = KMeans(n_clusters=10, n_jobs=-1,)
clusters = clst.fit_predict(std_X)

# Store the label next to the original rows so clusters can be inspected by row.
df["cluster"] = clusters

<IPython.core.display.Javascript object>

In [15]:
# Number of rows assigned to each cluster label, most populated first.
cluster_sizes = pd.Series(clusters).value_counts()
cluster_sizes

2    1700
7    1625
9    1295
1    1163
3    1016
5    1014
4     863
8     609
0     471
6     234
dtype: int64

<IPython.core.display.Javascript object>

In [18]:
# Map the fitted component means back through the scaler so they read in the
# original feature units, then lay them out as features (rows) x clusters (columns).
# (For an estimator exposing `cluster_centers_` instead of `means_`, pass
# `clst.cluster_centers_` to `inverse_transform` here.)
component_means = scaler.inverse_transform(clst.means_)
means_df = pd.DataFrame(component_means, columns=X.columns).T

# Shade the table so high/low values stand out.
means_df.style.background_gradient()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
mfcc1_mean,-188.600049,-82.024766,-71.366736,-231.721935,-265.770374,-191.42259,-281.561888,-132.710863,-199.636631,-62.372077
mfcc2_mean,116.297142,68.840596,94.7747,125.827638,125.277241,116.031975,102.534101,106.944918,104.129194,75.208804
mfcc2_var,707.850758,773.773328,280.503804,306.152728,443.536615,924.418503,916.822,604.509869,1138.701373,491.268704
mfcc3_mean,-9.267046,8.48671,-27.43126,-13.148157,-16.084087,-3.560494,-15.986558,-16.085996,-4.459446,3.502528
mfcc3_var,308.968287,496.568735,201.517299,176.653671,296.471327,513.258246,380.47026,421.278438,886.905482,304.74109
mfcc4_mean,41.550305,20.54089,54.235434,34.069382,32.365939,36.700295,28.493284,46.26076,38.579519,24.340871
mfcc4_var,210.934678,280.277115,109.343554,70.060626,116.889503,231.28062,185.828461,205.234793,374.886238,162.811604
mfcc5_var,216.91077,192.03811,82.019529,54.018951,97.10024,175.034695,131.761723,181.948669,303.62459,108.984834
mfcc6_mean,15.695822,4.690492,27.793516,10.986423,10.865701,10.789884,8.118496,25.38126,14.087394,8.301044
mfcc6_var,121.584473,176.239795,66.739607,44.462123,72.724738,113.881953,118.299109,128.708401,210.069916,85.588211


<IPython.core.display.Javascript object>

In [19]:
# Cluster-vs-genre contingency table.  Left as the cell's last expression so it
# renders as one rich table (print() wraps it into several plain-text chunks),
# and the row axis is labelled instead of the default "row_0".
pd.crosstab(clusters, y, rownames=["cluster"])

genre  blues  classical  country  disco  hiphop  jazz  metal  pop  reggae  \
row_0                                                                       
0        156         15      148      3       6    13     19   17      46   
1          1          1       94    108     190    31      0  494     200   
2        235          2       74    209     108     6    763    0      16   
3         46        399       44     17      10   377     49   15       3   
4         57        390       28     41      15   234     14   34      22   
5        107         61      145     57     141   143      5   52     220   
6         14         97       12      4       0    51      8   38       5   
7        248          1      242    260     238    30    111    0     274   
8        136         13       62     24     125    39      8   55     119   
9          0         19      148    276     165    76     23  295      95   

genre  rock  
row_0        
0        48  
1        44  
2       287  
3    

<IPython.core.display.Javascript object>

In [13]:
# Inspect a window of the rows assigned to cluster 6: positions 100-149 of
# that subset.
in_cluster = df["cluster"] == 6
df.loc[in_cluster].iloc[100:150]

Unnamed: 0,filename,length,chroma_stft_mean,chroma_stft_var,rms_mean,rms_var,spectral_centroid_mean,spectral_centroid_var,spectral_bandwidth_mean,spectral_bandwidth_var,...,mfcc18_mean,mfcc18_var,mfcc19_mean,mfcc19_var,mfcc20_mean,mfcc20_var,label,genre,songname,cluster
2370,country.00037.5.wav,66149,0.330232,0.082878,0.091321,0.000867,2364.242382,927143.0,2503.685955,181798.257431,...,-4.462796,25.490288,4.11252,44.979614,0.206159,38.576035,country,country,country.00037,2
2372,country.00037.7.wav,66149,0.384266,0.083647,0.077156,0.000841,2213.470078,758336.1,2356.729081,261155.543811,...,-4.639517,43.122768,-3.565802,48.842167,-2.060452,40.609848,country,country,country.00037,2
2395,country.00040.0.wav,66149,0.401978,0.088664,0.071637,0.001592,2793.773731,530241.5,2738.594643,154334.410749,...,-3.400161,29.142855,-3.017127,23.504848,-3.261614,39.045944,country,country,country.00040,2
2409,country.00041.4.wav,66149,0.401592,0.084197,0.116052,0.00058,2467.51152,699593.5,2927.146752,211648.885202,...,-8.206468,45.874363,-0.443196,31.986464,-2.576834,20.449451,country,country,country.00041,2
2410,country.00041.5.wav,66149,0.372751,0.082503,0.104552,0.000404,2599.347765,982827.2,2923.235637,289494.654567,...,-4.757769,30.37245,-1.339501,22.866653,-3.17878,18.819704,country,country,country.00041,2
2411,country.00041.6.wav,66149,0.388089,0.086669,0.117804,0.000643,2401.524568,737235.5,2790.481373,265956.976715,...,-5.286162,57.990452,0.054635,26.924347,-5.056355,35.412838,country,country,country.00041,2
2412,country.00041.7.wav,66149,0.365553,0.08429,0.119748,0.000578,2737.491673,637976.5,2858.141108,174252.353387,...,-9.512616,27.974392,-0.052922,34.233505,-3.119357,33.602215,country,country,country.00041,2
2437,country.00044.2.wav,66149,0.414964,0.084031,0.079983,0.000479,1071.823463,450511.8,1818.877496,443946.044007,...,3.12154,46.518963,1.733237,49.823795,-3.950626,28.564705,country,country,country.00044,2
2438,country.00044.3.wav,66149,0.291293,0.088386,0.149195,0.001361,2579.336949,628827.9,2714.346339,110427.695082,...,-5.624926,37.167961,5.237616,29.307474,-6.440866,70.347389,country,country,country.00044,2
2440,country.00044.5.wav,66149,0.349176,0.080785,0.111846,0.000666,1685.386899,136714.9,2360.604816,98695.241095,...,-0.406148,42.409367,9.223216,83.824142,-3.889648,68.344269,country,country,country.00044,2


<IPython.core.display.Javascript object>