# Train

In [1]:
from Packages.spotifyclassifier import *
from Packages.run import *
from Packages.wrangling import *
from Packages.constants import LABEL_LIST
from Packages.serialize import save_classifier

### Prepare Training Data

The metadata, feature space, and genres are read in and stored in the `default_song_data` mapping, keyed by `track_id`.

In [2]:
# Location of the song data. This is a placeholder: replace it with a real path before running.
song_data_path = 'provide your data path'

# Read the songs and wrap the result in a FrozenMap.
# NOTE(review): the meaning of the second positional argument (True) is defined by
# read_data and is not visible from this notebook — confirm.
default_song_data = FrozenMap(read_data(song_data_path, True))
print('Congrats! Song data reading completed.')

HBox(children=(IntProgress(value=1, bar_style='info', description='Reading data from .csv...', max=1, style=Pr…


Congrats! Song data reading completed.


In [3]:
# Take the feature names from the first song's 'features' mapping and freeze them as a tuple.
# NOTE(review): assumes every song carries the same feature keys as the first one — confirm.
first_song = next(iter(default_song_data.values()))
default_feature_names = tuple(first_song['features'].keys())

In [4]:
# Sanity check: run the project's sanity_check helper on the loaded song data.
# The saved output below reports the number of likes and dislikes.
sanity_check(default_song_data)

 likes: 704 
 dislikes: 704


In [5]:
# Probe a range of cluster counts; the saved output below lists one score per count from 2 to 10.
# NOTE(review): presumably 10 is the largest cluster count tried and the printed score is a
# k-means fit measure used to choose NUM_CLUSTERS — confirm against test_cluster_size.
test_cluster_size(default_song_data, 10)

2 649.623252054139
3 560.8935890551633
4 498.55679547133707
5 467.2898260769377
6 442.46535321389223
7 421.4029874411187
8 401.4857572570418
9 386.61502344885093
10 372.6527993692656


Compute k-means clusters over the full song data (before the train/validation split) for the algorithms that use them.

In [6]:
# Cluster the songs. The call returns two values:
#   clustered_song_data - per-song data whose 'features' mapping is compared against the
#                         default one in the sanity check that follows.
#   songs_by_cluster    - a container indexed by cluster count (looked up with NUM_CLUSTERS
#                         when the training clusters are built).
# Both names are bound directly by the call. There are no empty-dict placeholders, so a failed
# call leaves them undefined instead of silently handing empty data to later cells.
clustered_song_data, songs_by_cluster = get_kmeans_clusters(default_song_data, NUM_CLUSTERS)

In [7]:
# Sanity check: the default and clustered views are expected to expose different feature keys.
# Identical key sets mean the default features were altered, so abort.
default_feature_keys = set(next(iter(default_song_data.values()))['features'].keys())
clustered_feature_keys = set(next(iter(clustered_song_data.values()))['features'].keys())

if default_feature_keys == clustered_feature_keys:
    raise ValueError('Default features messed up.')

In [8]:
# Split the default and clustered data into training and validation parts in a single call.
(
    default_training_data,
    default_validation_data,
    clustered_training_data,
    clustered_validation_data,
) = get_experiment_split(default_song_data, clustered_song_data)

# Build the training clusters from the clustered training data and the cluster
# assignment stored under the configured cluster count.
clusters_for_configured_k = songs_by_cluster[NUM_CLUSTERS]
training_clusters = get_train_clusters(clustered_training_data, clusters_for_configured_k)

### Train Classifiers

In [9]:
# Run the active-learning suite on the unclustered (default) split.
# The saved output below prints one accuracy trace per algorithm / sampling-strategy pair.
print('----\nUnclustered\n----')
active_unclustered_results = run_active_suite(
    default_song_data,
    default_training_data,
    default_validation_data,
    SUPPORTED_ALGS,
    AL_STRATS,
)

----
Unclustered
----
svc accs w/ random
50.53	65.12	49.46	70.81	49.46	49.46	49.46	70.46	49.46	71.88	49.46	70.81	49.46	71.17	49.46	69.03	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	58.36	69.03	59.78	58.36	55.87	57.29	55.51	51.24	51.24	54.8	53.38	57.29	59.78	63.34	60.14	62.27	68.68	62.63	60.14	68.32	67.61	67.61	64.41	61.92	67.61	67.25	67.25	62.27	67.25	67.61	70.1	68.68	69.75	69.03	66.19	64.05	62.98	61.92	62.98	65.48	62.63	64.41	63.34	63.34	64.41	64.76	64.05	65.12	66.19	66.19	64.41	64.41	64.41	64.76	65.12	63.7	62.98	62.63	63.34	64.05	66.19	66.19	68.68	70.1	67.97	69.75	70.1	67.97	67.97	69.39	69.03	67.61	69.03	
svc accs w/ uncertainty
49.46	58.0	50.53	50.53	50.53	65.48	50.53	70.1	49.46	62.98	50.53	67.97	52.66	64.05	60.49	66.54	66.19	65.83	70.1	64.76	69.39	54.8	71.53	65.48	71.17	65.48	71.88	59.78	72.59	62.27	72.24	62.27	52.31	61.92	72.24	63.34	55.87	67.25	71.88	66.54	71.17	67.25	70.46	70.81	71.53	68.68	66.9	71.17	65.83	71.53	71.17	67.97	61.2	55.16	60.49	69.03	61.



lsvc accs w/ random
49.46	57.29	50.88	55.87	58.71	54.8	63.34	63.34	59.43	61.2	60.49	69.03	66.19	67.25	67.25	69.75	69.75	69.75	69.75	69.39	69.75	70.46	71.53	71.17	69.39	67.97	70.81	71.17	71.17	69.03	73.3	71.53	71.53	71.53	71.53	71.53	71.53	71.17	71.17	71.17	70.46	69.75	69.03	69.03	69.03	70.46	68.68	68.68	70.1	72.59	71.17	71.17	71.88	72.59	72.95	72.95	72.95	72.95	72.24	72.95	71.53	71.17	72.95	72.59	70.46	72.24	71.88	71.17	71.17	72.95	72.24	71.88	72.59	72.59	72.59	72.24	71.53	71.88	71.17	71.88	72.24	72.24	73.66	72.95	73.66	73.66	73.66	74.02	74.02	74.02	75.8	75.44	76.15	75.44	75.44	75.44	75.08	73.66	73.66	73.3	73.66	
lsvc accs w/ uncertainty
49.46	48.75	49.11	49.82	47.33	50.17	56.93	58.0	58.71	63.7	65.48	67.97	69.03	66.54	71.17	71.88	72.59	73.3	73.3	74.37	73.66	73.3	72.95	69.75	64.05	65.12	63.7	67.61	67.61	65.83	67.25	65.12	64.76	68.32	66.9	66.54	67.61	68.68	67.25	68.68	68.32	67.61	69.75	67.61	69.75	70.46	70.81	70.46	71.88	73.3	75.08	74.02	71.88	72.59	71.53	71.17	71.88	71.88	70.81	72.59	71

In [10]:
# Run the same active-learning suite, this time on the clustered split.
# The full (default) song data is still passed as the first argument.
print('----\nClustered\n----')
active_clustered_results = run_active_suite(
    default_song_data,
    clustered_training_data,
    clustered_validation_data,
    SUPPORTED_ALGS,
    AL_STRATS,
)

----
Clustered
----
svc accs w/ random
50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	66.54	49.46	66.9	50.53	50.53	50.53	68.68	49.46	49.46	49.46	49.46	49.46	70.81	49.46	66.9	56.58	50.88	66.54	53.73	65.48	50.88	50.88	57.65	69.03	65.83	67.97	59.07	56.58	53.38	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.53	50.88	50.88	50.88	50.88	50.88	50.88	50.88	51.24	50.88	50.88	50.53	50.53	50.53	50.53	
svc accs w/ uncertainty
50.53	67.61	49.46	61.92	50.53	64.41	49.46	59.78	50.53	61.92	58.0	62.27	54.09	65.48	55.16	51.6	49.46	50.53	56.93	52.31	58.36	64.05	58.36	62.98	64.76	63.7	66.54	65.48	56.58	66.54	58.36	64.41	59.78	52.31	60.49	65.12	64.41	65.83	63.7	64.76	66.54	65.48	66.9	66.54	69.39	67.25	65.83	64.05	66.19	65.12	68.32	66.19	66.19	63.7	64.76	62.63	63

In [11]:
# Run the cluster-sampling suite on the clustered split, using the training clusters built earlier.
print('\n----\nClustered w/ Cluster Sampling\n----')
active_cluster_sampled_results = run_clusters_suite(
    default_song_data,
    clustered_training_data,
    clustered_validation_data,
    training_clusters,
)


----
Clustered w/ Cluster Sampling
----
sgd accs w/ random
50.53	64.05	51.95	55.51	50.53	58.71	62.63	61.92	71.88	71.88	71.88	71.88	68.68	68.68	68.68	59.43	70.81	70.81	55.51	71.53	50.53	50.53	71.53	73.3	73.3	73.3	
sgd accs w/ uncertainty
66.19	51.24	68.68	60.49	55.87	62.27	65.12	66.54	64.76	65.12	60.85	66.19	64.76	70.46	67.25	69.03	74.02	65.12	67.25	71.17	67.25	67.25	69.03	69.03	72.59	70.81	


In [12]:
# Get the classifier returned by get_highest_benchmark for the default and clustered splits.
# NOTE(review): presumably this is the best-scoring classifier on the validation data — confirm.
best_classifier = get_highest_benchmark(
    default_song_data, default_training_data, default_validation_data,
    clustered_training_data, clustered_validation_data,
)

# Persist the chosen classifier. Set a filename here instead of None if desired.
classifier_filename = None
save_classifier(best_classifier, classifier_filename)

print('Classifier saved successfully.')



Successfully saved classifier: ./classifiers\YOUR PKL FILE
Classifier saved successfully.
