In [1]:
import sys

import numpy as np
import pandas as pd

# Add the directory two levels above this notebook to the import search path so
# that the local `helpers` package imported just below can be resolved
# (presumably that directory is the repository root -- confirm).
sys.path.append("../../")

from helpers.split import tag_label_feature_split, label_strings
from sklearn.utils.class_weight import compute_class_weight, compute_sample_weight

In [2]:
# Load the baseline dataset. NOTE: unpickling can execute arbitrary code, so
# only read pickle files that come from a trusted source.
df = pd.read_pickle("../../datasets/baseline_dataset.pickle")

# The split helper returns three parts; only the middle one (the one-hot
# labels) is needed here. Named placeholders are used instead of `_` so that
# IPython's `_` (last cell output) is not overwritten.
_unused_first, one_hot_labels, _unused_last = tag_label_feature_split(df)

In [3]:
# Display the one-hot label frame: one `genre_*` indicator column per genre.
one_hot_labels

Unnamed: 0,genre_blues,genre_classical,genre_country,genre_disco,genre_hiphop,genre_jazz,genre_metal,genre_pop,genre_reggae,genre_rock
2,0,0,0,0,0,0,0,1,0,0
6,0,0,0,0,0,0,0,0,0,1
7,0,0,0,0,0,0,0,1,0,0
10,0,0,0,0,0,0,0,0,0,1
12,0,1,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...
55202,0,0,0,0,0,0,0,1,0,0
55204,0,0,0,0,0,1,0,0,0,0
55205,0,0,0,0,0,1,0,0,0,0
55208,0,0,0,0,0,0,0,1,0,0


Get the DataFrame with Label Strings

In [4]:
# Turn the one-hot indicator columns into a frame with a single string
# `label` column (see the output below), then display it.
labels = label_strings(one_hot_labels)
labels

Unnamed: 0,label
2,genre_pop
6,genre_rock
7,genre_pop
10,genre_rock
12,genre_classical
...,...
55202,genre_pop
55204,genre_jazz
55205,genre_jazz
55208,genre_pop


Convert the label column to a NumPy array

In [5]:
# Pull the "label" column out as a 1-D NumPy array.
# The array is rebuilt from `one_hot_labels` rather than from `labels` itself,
# because this cell rebinds `labels` from a DataFrame to an ndarray; reading
# `labels.label` here would fail if the cell were run a second time.
# `.to_numpy()` replaces `Series.ravel()`, which is deprecated in recent pandas.
labels = label_strings(one_hot_labels)["label"].to_numpy()
labels

array(['genre_pop', 'genre_rock', 'genre_pop', ..., 'genre_jazz',
       'genre_pop', 'genre_jazz'], dtype=object)

Confirm the unique values look right

In [6]:
# Sanity check: list the distinct label values (np.unique returns them sorted).
np.unique(labels)

array(['genre_blues', 'genre_classical', 'genre_country', 'genre_disco',
       'genre_hiphop', 'genre_jazz', 'genre_metal', 'genre_pop',
       'genre_reggae', 'genre_rock'], dtype=object)

Estimate the class weights for our dataset

In [7]:
# Distinct genre labels in sorted order; the returned weights follow this order.
genre_classes = np.unique(labels)

# "balanced" weights each class inversely to its frequency in `labels`
# (n_samples / (n_classes * class_count), per the scikit-learn documentation).
weights = compute_class_weight(
    class_weight="balanced",
    classes=genre_classes,
    y=labels,
)
weights

array([4.56217949, 0.40468158, 5.2202934 , 5.91440443, 1.00428034,
       1.37570876, 2.56931408, 0.39960696, 2.16761421, 0.533775  ])

Build a dictionary mapping each genre label to its class weight

In [8]:
# Pair each sorted unique label with the weight at the same position.
{genre: weight for genre, weight in zip(np.unique(labels), weights)}

{'genre_blues': 4.562179487179487,
 'genre_classical': 0.40468157695223655,
 'genre_country': 5.220293398533007,
 'genre_disco': 5.914404432132964,
 'genre_hiphop': 1.0042803386641581,
 'genre_jazz': 1.375708762886598,
 'genre_metal': 2.5693140794223828,
 'genre_pop': 0.399606962380685,
 'genre_reggae': 2.1676142131979694,
 'genre_rock': 0.533775}