In [15]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np


In [16]:
merged_data_path = '../data/merged_data.jsonl'
data = pd.read_json(merged_data_path, lines=True)

data = data.drop(
    columns=["release_date", "key", "loudness", "explicit",
             "danceability", "energy", "speechiness",
             "acousticness", "instrumentalness", "liveness", "valence", "tempo"])

# data = data.drop(
#     columns=["release_date", "genres", "favourite_genres", "name"])

data.head(5)

Unnamed: 0,popularity,duration_ms,favourite_genres,name,genres,skipped,number_of_matching_genres
0,34,247707,"[permanent wave, mandopop, funk]",T. Rex,"[album rock, art rock, classic rock, folk rock...",False,0
1,34,247707,"[filmi, regional mexican, folk]",T. Rex,"[album rock, art rock, classic rock, folk rock...",False,0
2,34,247707,"[psychedelic rock, country rock, rock en espanol]",T. Rex,"[album rock, art rock, classic rock, folk rock...",False,1
3,35,140067,"[psychedelic rock, country rock, rock en espanol]",T. Rex,"[album rock, art rock, classic rock, folk rock...",False,1
4,35,140067,"[psychedelic rock, country rock, rock en espanol]",T. Rex,"[album rock, art rock, classic rock, folk rock...",False,1


In [17]:
unique_favourite_genres = set()
for genres in data['favourite_genres']:
  for genre in genres:
    unique_favourite_genres.add(genre)

unique_genres = set()
for genres in data['genres']:
  for genre in genres:
    unique_genres.add(genre)

# show all common genres between all favourite_genres and genres
print("unique_favourite_genres", len(unique_favourite_genres))
print("unique_genres", len(unique_genres))
common_genres = unique_favourite_genres.intersection(unique_genres)
print("common genres", common_genres)

# remove genres that are not in favourite_genres
data['genres'] = data['genres'].apply(
    lambda x: [genre for genre in x if genre in common_genres])

data.head()

# TODO remove ??? moze niepotrzebne

unique_favourite_genres 46
unique_genres 1766
common genres {'ranchera', 'soul', 'rock', 'vocal jazz', 'new wave pop', 'lounge', 'psychedelic rock', 'permanent wave', 'latin rock', 'country rock', 'new romantic', 'tropical', 'singer-songwriter', 'regional mexican', 'funk', 'folk', 'rock en espanol', 'album rock', 'mandopop', 'pop rock', 'mellow gold', 'alternative metal', 'metal', 'brill building pop', 'adult standards', 'classic rock', 'blues rock', 'europop', 'j-pop', 'turkish pop', 'alternative rock', 'soft rock', 'latin', 'hoerspiel', 'pop', 'c-pop', 'art rock', 'filmi', 'dance pop', 'mpb', 'latin alternative', 'new wave', 'quiet storm', 'hard rock', 'latin pop', 'motown'}


Unnamed: 0,popularity,duration_ms,favourite_genres,name,genres,skipped,number_of_matching_genres
0,34,247707,"[permanent wave, mandopop, funk]",T. Rex,"[album rock, art rock, classic rock, psychedel...",False,0
1,34,247707,"[filmi, regional mexican, folk]",T. Rex,"[album rock, art rock, classic rock, psychedel...",False,0
2,34,247707,"[psychedelic rock, country rock, rock en espanol]",T. Rex,"[album rock, art rock, classic rock, psychedel...",False,1
3,35,140067,"[psychedelic rock, country rock, rock en espanol]",T. Rex,"[album rock, art rock, classic rock, psychedel...",False,1
4,35,140067,"[psychedelic rock, country rock, rock en espanol]",T. Rex,"[album rock, art rock, classic rock, psychedel...",False,1


In [18]:
# Combine genres and favourite_genres
all_genres = list(data['favourite_genres'] + data['genres'])

# One-hot encode the genres
mlb = MultiLabelBinarizer()
mlb.fit(all_genres)

encoded_all_genres = mlb.fit_transform(all_genres)

# Split encoded_genres into genres and favourite_genres
# encoded_favourite_genres = encoded_all_genres[:, :len(data['favourite_genres'][0])]
# encoded_genres = encoded_all_genres[:, len(data['favourite_genres'][0]):]

encoded_favourite_genres = mlb.transform(data['favourite_genres'])
encoded_genres = mlb.transform(data['genres'])

popularity_normalized = data['popularity'].values.reshape(-1, 1)
popularity_normalized = (popularity_normalized - np.min(popularity_normalized)) / (
      np.max(popularity_normalized) - np.min(popularity_normalized))

duration_ms_normalized = data['duration_ms'].values.reshape(-1, 1)
duration_ms_normalized = (duration_ms_normalized - np.min(duration_ms_normalized)) / (
      np.max(duration_ms_normalized) - np.min(duration_ms_normalized))

number_of_matching_genres_normalized = data['number_of_matching_genres'].values.reshape(-1, 1)
number_of_matching_genres_normalized = (number_of_matching_genres_normalized - np.min(
  number_of_matching_genres_normalized)) / (np.max(number_of_matching_genres_normalized) - np.min(
  number_of_matching_genres_normalized))

# create data frame from data genres, data favourite_genres, encoded genres, encoded favourite_genres
df = pd.DataFrame(
    data={'popularity': data['popularity'],
          'popularity_normalized': popularity_normalized.reshape(-1),
          'duration_ms': data['duration_ms'],
          'duration_ms_normalized': duration_ms_normalized.reshape(-1),
          'number_of_matching_genres': data['number_of_matching_genres'],
          'number_of_matching_genres_normalized': number_of_matching_genres_normalized.reshape(-1),
          'skipped': data['skipped']})

df.head(5)

Unnamed: 0,popularity,popularity_normalized,duration_ms,duration_ms_normalized,number_of_matching_genres,number_of_matching_genres_normalized,skipped
0,34,0.354167,247707,0.092836,0,0.0,False
1,34,0.354167,247707,0.092836,0,0.0,False
2,34,0.354167,247707,0.092836,1,0.333333,False
3,35,0.364583,140067,0.046724,1,0.333333,False
4,35,0.364583,140067,0.046724,1,0.333333,False


In [19]:
# Concatenate the one-hot encoded columns
X = np.concatenate([encoded_favourite_genres, encoded_genres], axis=1)
# X = np.concatenate([popularity_normalized, duration_ms_normalized, number_of_matching_genres_normalized], axis=1)
# X = data.drop(columns=['skipped']).values

# Extract the labels
y = data['skipped'].astype(int).values

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [20]:
rf_model = LogisticRegression()
rf_model.fit(X_train, y_train)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [21]:
print("TEST")
y_pred = rf_model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Confusion matrix:\n", confusion_matrix(y_test, y_pred))
print("Classification report:\n", classification_report(y_test, y_pred))

print("TRAIN")
y_pred = rf_model.predict(X_train)
print("Accuracy:", accuracy_score(y_train, y_pred))
print("Confusion matrix:\n", confusion_matrix(y_train, y_pred))
print("Classification report:\n", classification_report(y_train, y_pred))


TEST
Accuracy: 0.6868993424380374
Confusion matrix:
 [[1080  181]
 [ 438  278]]
Classification report:
               precision    recall  f1-score   support

           0       0.71      0.86      0.78      1261
           1       0.61      0.39      0.47       716

    accuracy                           0.69      1977
   macro avg       0.66      0.62      0.63      1977
weighted avg       0.67      0.69      0.67      1977

TRAIN
Accuracy: 0.701417004048583
Confusion matrix:
 [[4342  657]
 [1703 1202]]
Classification report:
               precision    recall  f1-score   support

           0       0.72      0.87      0.79      4999
           1       0.65      0.41      0.50      2905

    accuracy                           0.70      7904
   macro avg       0.68      0.64      0.65      7904
weighted avg       0.69      0.70      0.68      7904

