In [276]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder, MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer, make_column_transformer, make_column_selector
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import jaccard_score, hamming_loss, multilabel_confusion_matrix, ConfusionMatrixDisplay
from sklearn.impute import SimpleImputer, KNNImputer
import numpy as np

In [219]:
data = pd.read_parquet("data/features.csv").set_index("track_id")

In [220]:
data.head()

Unnamed: 0_level_0,danceability,duration,energy,key,loudness,mode,tempo,year,genres
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
TRBGGLA128F149C2EB,0.0,262.3473,0.0,11,-7.673,0,89.912,0,"[uk garage, hip hop, gangster rap, rap, electr..."
TRBGGRU12903CAAA2D,0.0,159.9473,0.0,0,-32.535,1,72.357,0,"[country rock, western swing, gypsy jazz, sing..."
TRBGGOS128F9307FC5,0.0,250.77506,0.0,10,-6.188,0,177.963,2002,"[neo soul, hip hop, vocal jazz, progressive ho..."
TRBGGTE128F424ECBC,0.0,159.26812,0.0,0,-10.599,0,80.945,0,"[ccm, christian rock, gospel, blues, pop, soul..."
TRBGGOT128F932DC65,0.0,171.98975,0.0,2,-8.002,1,181.002,0,"[progressive house, electronic, latin]"


## Preprocessing

Remove outliers and incorrect data, filter by genres.

In [242]:
def get_top_genres(df: pd.DataFrame, k=50) -> tuple[list[str], list[int]]:
    genre_counts = df["genres"].explode().value_counts()
    genre_names = list(genre_counts.index)
    genre_counts = list(genre_counts)
    return genre_names[:k], genre_counts[:k]

In [243]:
top_genre_names, top_genre_counts = get_top_genres(data)

In [244]:
data["genres_filtered"] = data["genres"].apply(lambda genres: [genre for genre in genres if genre in top_genre_names])

In [245]:
data_preprocessed = data[data["genres_filtered"].map(len) > 0]
data_preprocessed.head()

Unnamed: 0_level_0,danceability,duration,energy,key,loudness,mode,tempo,year,genres,genres_filtered
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
TRBGGLA128F149C2EB,0.0,262.3473,0.0,11,-7.673,0,89.912,0,"[uk garage, hip hop, gangster rap, rap, electr...","[hip hop, rap, electronic, house]"
TRBGGRU12903CAAA2D,0.0,159.9473,0.0,0,-32.535,1,72.357,0,"[country rock, western swing, gypsy jazz, sing...","[country rock, singer songwriter, country, pun..."
TRBGGOS128F9307FC5,0.0,250.77506,0.0,10,-6.188,0,177.963,2002,"[neo soul, hip hop, vocal jazz, progressive ho...","[hip hop, chill out, blues rock, country rock,..."
TRBGGTE128F424ECBC,0.0,159.26812,0.0,0,-10.599,0,80.945,0,"[ccm, christian rock, gospel, blues, pop, soul...","[blues, pop, soul, rock, folk]"
TRBGGOT128F932DC65,0.0,171.98975,0.0,2,-8.002,1,181.002,0,"[progressive house, electronic, latin]","[electronic, latin]"


In [246]:
data_preprocessed = data_preprocessed[["duration", "key", "loudness", "mode", "tempo", "year", "genres_filtered"]].rename(columns={"genres_filtered": "genres"})

In [247]:
data_preprocessed.head()

Unnamed: 0_level_0,duration,key,loudness,mode,tempo,year,genres
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
TRBGGLA128F149C2EB,262.3473,11,-7.673,0,89.912,0,"[hip hop, rap, electronic, house]"
TRBGGRU12903CAAA2D,159.9473,0,-32.535,1,72.357,0,"[country rock, singer songwriter, country, pun..."
TRBGGOS128F9307FC5,250.77506,10,-6.188,0,177.963,2002,"[hip hop, chill out, blues rock, country rock,..."
TRBGGTE128F424ECBC,159.26812,0,-10.599,0,80.945,0,"[blues, pop, soul, rock, folk]"
TRBGGOT128F932DC65,171.98975,2,-8.002,1,181.002,0,"[electronic, latin]"


In [248]:
data_preprocessed.describe()

Unnamed: 0,duration,key,loudness,mode,tempo,year
count,9593.0,9593.0,9593.0,9593.0,9593.0,9593.0
mean,238.916854,5.284061,-10.426201,0.692171,123.134846,966.01647
std,113.060176,3.555396,5.381839,0.461619,35.064185,998.152209
min,1.04444,0.0,-51.643,0.0,0.0,0.0
25%,176.97914,2.0,-13.078,0.0,97.039,0.0
50%,223.4771,5.0,-9.325,1.0,120.253,0.0
75%,276.4273,8.0,-6.5,1.0,144.076,2000.0
max,1819.76771,11.0,0.566,1.0,262.828,2010.0


### Fix outliers

In [252]:
data_preprocessed["year"] = data_preprocessed["year"].replace(0, np.nan)

In [272]:
valid_tempo_min = 70
valid_tempo_max = 180
def fix_tempo(tempo_val: float) -> float:
    if tempo_val == 0:
        return np.nan
    elif tempo_val > valid_tempo_max:
        return tempo_val / 2
    elif tempo_val < valid_tempo_min:
        return tempo_val * 2
    return tempo_val
data_preprocessed["tempo"] = data_preprocessed["tempo"].map(fix_tempo)

In [273]:
data_preprocessed

Unnamed: 0_level_0,duration,key,loudness,mode,tempo,year,genres
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
TRBGGLA128F149C2EB,262.34730,11,-7.673,0,89.912,,"[hip hop, rap, electronic, house]"
TRBGGRU12903CAAA2D,159.94730,0,-32.535,1,72.357,,"[country rock, singer songwriter, country, pun..."
TRBGGOS128F9307FC5,250.77506,10,-6.188,0,177.963,2002.0,"[hip hop, chill out, blues rock, country rock,..."
TRBGGTE128F424ECBC,159.26812,0,-10.599,0,80.945,,"[blues, pop, soul, rock, folk]"
TRBGGOT128F932DC65,171.98975,2,-8.002,1,90.501,,"[electronic, latin]"
...,...,...,...,...,...,...,...
TRAMMAF128F93051B0,161.85424,7,-4.665,1,130.862,2001.0,"[electronica, disco, funk, soundtrack, ska, el..."
TRAMMTK128F4279CD2,56.71138,11,-25.063,0,94.842,,"[classical, folk]"
TRAMMDO128E0781A5D,328.72444,11,-17.428,0,88.928,1984.0,"[pop rock, new wave, disco, electronic, downte..."
TRAMMDM128F425EAA9,287.50322,6,-6.748,1,85.964,2007.0,"[hip hop, reggae, rap, funk, jazz, pop, rock, ..."


In [274]:
data_preprocessed.describe()

Unnamed: 0,duration,key,loudness,mode,tempo,year
count,9593.0,9593.0,9593.0,9593.0,9570.0,4640.0
mean,238.916854,5.284061,-10.426201,0.692171,118.49933,1997.197414
std,113.060176,3.555396,5.381839,0.461619,25.82657,11.700133
min,1.04444,0.0,-51.643,0.0,32.516,1926.0
25%,176.97914,2.0,-13.078,0.0,96.71325,1993.0
50%,223.4771,5.0,-9.325,1.0,116.826,2001.0
75%,276.4273,8.0,-6.5,1.0,136.845,2005.0
max,1819.76771,11.0,0.566,1.0,179.999,2010.0


## Feature extraction

In [279]:
FEATURE_COLS = ["duration", "key", "loudness", "mode", "tempo", "year"]
NUMERICAL_COLS = ["duration", "loudness", "tempo", "year"]
BINARY_COLS = ["mode"]
CATEGORICAL_COLS = ["key"]
LABEL_COL = "genres"

In [278]:
imputer = KNNImputer(n_neighbors=2, weights="distance")

In [280]:
imputer.fit(data_preprocessed[FEATURE_COLS])

In [283]:
imputed_data = imputer.transform(data_preprocessed[FEATURE_COLS])

In [287]:
imputed_data_df = pd.DataFrame(imputed_data, columns=imputer.get_feature_names_out())
imputed_data_df[LABEL_COL] = data_preprocessed[LABEL_COL]

In [288]:
features = data_preprocessed[FEATURE_COLS]
labels = data_preprocessed[LABEL_COL]

In [289]:
SEED = 42

In [290]:
train_data, test_data = train_test_split(data_preprocessed, test_size=0.2, random_state=SEED)


In [291]:
print(len(train_data))
print(len(test_data))

7674
1919


## Label encoding

In [294]:
mlb = MultiLabelBinarizer()
y_train = mlb.fit_transform(train_data[LABEL_COL])
print(y_train.shape)

(7674, 50)


## Baseline

Predict top 5 most popular genres for all samples

In [295]:
top_5_encoded = mlb.transform([top_genre_names[:5]])[0]
print(top_5_encoded)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0
 0 1 0 0 0 0 0 0 0 0 0 0 0]


In [296]:
y_pred = [top_5_encoded] * len(test_data)

In [297]:
y_true = mlb.transform(test_data['genres'])

In [298]:
jaccard_score(y_true, y_pred, average="samples")

np.float64(0.1995286374138877)

In [299]:
hamming_loss(y_true, y_pred)

0.1907764460656592

## Random Forest Classifier

In [300]:
rfc = RandomForestClassifier(random_state=42)

In [314]:
# Transform columns of pandas dataframe to features
ct = make_column_transformer(
    (MinMaxScaler(), NUMERICAL_COLS),
    # (OneHotEncoder(), CATEGORICAL_COLS),
    ('passthrough', BINARY_COLS),
    remainder='drop'
)

In [315]:
X_train = ct.fit_transform(train_data)

In [316]:
y_train = mlb.fit_transform(train_data[LABEL_COL])

In [317]:
rfc.fit(X_train, y_train)

In [318]:
## Evaluate
X_test = ct.fit_transform(test_data)
y_test = mlb.transform(test_data[LABEL_COL])

In [319]:
rfc.score(X_test, y_test)

0.0005211047420531526

In [320]:
y_pred = rfc.predict(X_test)

In [321]:
y_pred

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [322]:
jaccard_score(y_true, y_pred, average="samples")

np.float64(0.16742605693455878)

In [323]:
hamming_loss(y_true, y_pred)

0.18427305888483586

In [None]:
mlb.inverse_transform(y_pred)

In [None]:
mlb.inverse_transform(np.array(y_true))