In [1]:
import json
import os
import re
from utils import *

In [2]:
# Training-set lookup tables returned by get_training_set_info():
#   track_uri_to_id / artist_uri_to_id are queried with .get(uri_id, -1) below,
#   all_track_uris is used for membership tests ("is this track in training?").
training_set_info = get_training_set_info()
track_uri_to_id, artist_uri_to_id, all_track_uris = training_set_info

In [3]:
def process_playlist_for_test(playlist, test_seeds_num, seed_playlists,
                              min_answers=30, max_answers=100):
    """Turn one MPD playlist into a test example and append it to ``seed_playlists``.

    Tracks whose URI id is not in ``all_track_uris`` are dropped first. Of the
    remaining tracks, the first ``test_seeds_num`` are the seeds and the rest
    are the held-out tracks to predict.

    Args:
        playlist: Dict with a ``'name'`` string and a ``'tracks'`` list. Each track
            dict has ``'track_uri'`` and ``'artist_uri'`` strings; the third
            ``':'``-separated field (e.g. ``'spotify:track:<id>'`` -> ``<id>``)
            is used as the lookup key.
        test_seeds_num: Number of leading (known) tracks given as seeds.
        seed_playlists: List the example is appended to; mutated in place.
        min_answers: Smallest accepted number of held-out tracks (default 30).
        max_answers: Largest accepted number of held-out tracks (default 100).

    Returns:
        None. When the playlist qualifies, appends
        ``[seeds_tracks, seeds_artists, title_char_indices, tracks_to_predict]``
        to ``seed_playlists``; otherwise leaves it unchanged.

    Notes:
        Reads the notebook globals ``all_track_uris``, ``track_uri_to_id`` and
        ``artist_uri_to_id``, plus ``normalize_name``, ``title_to_indices`` and
        ``MAX_TITLE_LEN`` (presumably provided by ``from utils import *``).
    """
    tracks = []
    artists = []
    for track in playlist['tracks']:
        track_uri = track['track_uri'].split(':')[2]
        artist_uri = track['artist_uri'].split(':')[2]
        # Ignore tracks that never appeared in the training set.
        if track_uri not in all_track_uris:
            continue
        # -1 marks an id that is missing from the lookup table.
        tracks.append(track_uri_to_id.get(track_uri, -1))
        artists.append(artist_uri_to_id.get(artist_uri, -1))

    if len(tracks) <= test_seeds_num:
        return
    # Counted before de-duplication, so repeated held-out tracks still count here.
    n_answers = len(tracks) - test_seeds_num

    # Skip playlists with fewer than min_answers or more than max_answers held-out tracks.
    if n_answers < min_answers or n_answers > max_answers:
        return

    # Seeds keep their order (and any repeats) but drop unknown (-1) ids.
    seeds_tracks = [track_id for track_id in tracks[:test_seeds_num] if track_id != -1]
    seeds_artists = [artist_id for artist_id in artists[:test_seeds_num] if artist_id != -1]

    # Answers exclude seed tracks and are de-duplicated in first-seen order.
    # -1 is never a seed and is never de-duplicated, so every unknown held-out
    # track is kept as its own -1 entry. Sets give O(1) membership checks
    # instead of rescanning the lists for every track.
    seed_set = set(seeds_tracks)
    answer_set = set()
    tracks_to_predict = []
    for track in tracks[test_seeds_num:]:
        if track in seed_set:
            continue
        if track == -1:
            tracks_to_predict.append(track)
        elif track not in answer_set:
            answer_set.add(track)
            tracks_to_predict.append(track)

    # Encode the normalized playlist name with the title helpers.
    name = normalize_name(playlist['name'])
    title_char_indices = title_to_indices(name, MAX_TITLE_LEN)

    seed_playlists.append([seeds_tracks, seeds_artists, title_char_indices, tracks_to_predict])

In [4]:
# Build one test file per seed count. Each MPD slice is read and parsed once,
# and every playlist is offered to all seed counts in that same pass (the
# per-seed-count playlist order is unchanged), instead of re-listing and
# re-parsing the whole directory once per seed count.
test_seeds_nums = [1, 5, 10, 25, 100]

# Loop-invariant: the sorted list of slice files to read.
slice_filenames = sorted(
    filename for filename in os.listdir(test_fullpaths)
    if filename.startswith("mpd.slice.") and filename.endswith(".json")
)

playlists_by_seeds = {test_seeds_num: [] for test_seeds_num in test_seeds_nums}
for filename in slice_filenames:
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(os.path.join(test_fullpaths, filename), encoding="utf-8") as slice_file:
        mpd_slice = json.load(slice_file)
    # Add each playlist to every seed count's test set it qualifies for.
    for playlist in mpd_slice["playlists"]:
        for test_seeds_num in test_seeds_nums:
            process_playlist_for_test(playlist, test_seeds_num, playlists_by_seeds[test_seeds_num])

# Create the output directory once (including missing parents).
os.makedirs(test_data_dir, exist_ok=True)
for test_seeds_num in test_seeds_nums:
    print('Creating test set for %d starting tracks' % test_seeds_num)
    playlists = playlists_by_seeds[test_seeds_num]
    file_data = {'playlists': playlists}
    name = 'test-' + str(test_seeds_num)
    with open(os.path.join(test_data_dir, name), 'w') as make_file:
        json.dump(file_data, make_file, indent="\t")
    print("Number of playlists: %d" % len(playlists))

Creating test set for 1 starting tracks
Number of playlists: 23568
Creating test set for 5 starting tracks
Number of playlists: 21850
Creating test set for 10 starting tracks
Number of playlists: 19959
Creating test set for 25 starting tracks
Number of playlists: 15236
Creating test set for 100 starting tracks
Number of playlists: 5058
