In [48]:
import numpy as np
import pandas as pd
from numpy import array
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split,cross_val_score,cross_validate,GridSearchCV
from sklearn.metrics import mean_squared_error, multilabel_confusion_matrix, classification_report

In [45]:
df = pd.read_csv('./SpotifyFeatures_new.csv')

# Drop columns
dropped_columns = ['genre', 'artist_name', 'track_name','track_id', 'mode', 'duration_ms', 'tempo', 'energy', 'acousticness']
df.drop(columns=dropped_columns, inplace=True)

# Normalize data
X, y = df.drop(columns='popularity'), df['popularity']
X = (X-X.min())/(X.max()-X.min())

y = ((y-1)//20)+1

print(y)
print(X)

0         1
1         1
2         1
3         1
4         1
         ..
169679    5
169680    5
169681    5
169682    5
169683    5
Name: popularity, Length: 169684, dtype: int64
        danceability  instrumentalness  liveness  loudness  speechiness  \
0           0.541895          0.029930  0.095251  0.808082     0.024661   
1           0.503272          0.621622  0.369907  0.745076     0.302498   
2           0.268319          0.283283  0.509254  0.730983     0.014712   
3           0.477524          0.004755  0.113427  0.778349     0.004551   
4           0.423882          0.376376  0.097271  0.701589     0.016617   
...              ...               ...       ...       ...          ...   
169679      0.754318          0.000000  0.061222  0.834202     0.025826   
169680      0.729643          0.000002  0.048196  0.886194     0.039268   
169681      0.717841          0.000000  0.097271  0.839256     0.073561   
169682      0.832636          0.000002  0.092222  0.801730     0.164903

In [49]:
gridSearch = GridSearchCV(MLPClassifier(solver = 'sgd', random_state=1), {
    'activation':["logistic","relu"],
    'hidden_layer_sizes': [(12,3),(16,8),(16,12,8)],
    'learning_rate_init':[0.2],
    'max_iter': [400]
}, cv=10, return_train_score=False)
test_result = gridSearch.fit(X, y)
print(test_result)
print(test_result.best_score_)
print(test_result.best_params_)

GridSearchCV(cv=10, estimator=MLPClassifier(random_state=1, solver='sgd'),
             param_grid={'activation': ['logistic', 'relu'],
                         'hidden_layer_sizes': [(12, 3), (16, 8), (16, 12, 8)],
                         'learning_rate_init': [0.2], 'max_iter': [400]})
0.5123882735356031
{'activation': 'logistic', 'hidden_layer_sizes': (16, 12, 8), 'learning_rate_init': 0.2, 'max_iter': 400}


In [50]:
tuning_df = pd.DataFrame(gridSearch.cv_results_)
tuning_df

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_activation,param_hidden_layer_sizes,param_learning_rate_init,param_max_iter,params,split0_test_score,...,split3_test_score,split4_test_score,split5_test_score,split6_test_score,split7_test_score,split8_test_score,split9_test_score,mean_test_score,std_test_score,rank_test_score
0,36.676519,7.99364,0.008597,0.001751,logistic,"(12, 3)",0.2,400,"{'activation': 'logistic', 'hidden_layer_sizes...",0.41788,...,0.528434,0.52764,0.5376,0.523927,0.520686,0.510844,0.483557,0.502476,0.034875,5
1,55.003835,8.181963,0.011493,0.002865,logistic,"(16, 8)",0.2,400,"{'activation': 'logistic', 'hidden_layer_sizes...",0.422771,...,0.527668,0.534889,0.555929,0.52487,0.512966,0.504243,0.491337,0.507179,0.035022,2
2,65.791935,9.862308,0.01168,0.000903,logistic,"(16, 12, 8)",0.2,400,"{'activation': 'logistic', 'hidden_layer_sizes...",0.432023,...,0.53639,0.539663,0.557874,0.534536,0.527228,0.513496,0.49334,0.512388,0.035935,1
3,23.597936,20.946998,0.007681,0.002012,relu,"(12, 3)",0.2,400,"{'activation': 'relu', 'hidden_layer_sizes': (...",0.416524,...,0.523013,0.528701,0.524399,0.523868,0.513142,0.515618,0.495521,0.498651,0.036489,6
4,19.846614,5.213739,0.008592,0.001568,relu,"(16, 8)",0.2,400,"{'activation': 'relu', 'hidden_layer_sizes': (...",0.423596,...,0.530791,0.527405,0.527935,0.525165,0.52599,0.505717,0.490865,0.503784,0.033529,4
5,26.713309,8.445587,0.01049,0.001854,relu,"(16, 12, 8)",0.2,400,"{'activation': 'relu', 'hidden_layer_sizes': (...",0.415699,...,0.526077,0.527169,0.551332,0.531942,0.518211,0.509017,0.492751,0.505016,0.036449,3
