In [2]:
from utils import *

In [86]:
import pickle
import pandas as pd
from scipy.sparse import csr_matrix
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score
from sklearn import metrics
import random
import timeit

In [4]:
with open( "train_test_data.pkl", "rb") as f:
    X_train, y_train, X_test, y_test = pickle.load(f)

In [5]:
X_train = pd.DataFrame(X_train.values(),index = X_train.keys()) # convert from dict to pd DataFrame
X_train['tracks'] = X_train.values.tolist() #stack all columns into one
X_train = X_train[['tracks']]

In [7]:
# create one hot encoding for all the train X_all_onehot.shapetracks
X_train_onehot = pd.get_dummies(pd.DataFrame(X_train['tracks'].values.tolist()), prefix='', prefix_sep='').groupby(axis=1, level=0).max()
X_train_onehot = X_train_onehot.set_index(X_train.index)
X_train_onehot.shape #7473 101217

(7473, 101217)

In [9]:
tracks = list(X_train_onehot.columns) #all the tracks appear in the X_train

In [10]:
track_playlist_data = csr_matrix(X_train_onehot) #prepare the data matrix for the implicit package

In [12]:
X_train.head()

Unnamed: 0,tracks
323_381,"[spotify:track:6WhzFzROw3aq3rPWjgYlxr, spotify..."
323_507,"[spotify:track:7DD1ojeTUwnW65g5QuZw7X, spotify..."
323_543,"[spotify:track:61LtVmmkGr8P9I2tSPvdpf, spotify..."
323_570,"[spotify:track:0r9knJtQ6VVpX324mtkLcX, spotify..."
323_682,"[spotify:track:5vIu19A3EEdHgFM4Cba6F4, spotify..."


In [66]:
def overlap_score(tracks, pred_tracks, test_size = 10):
    ''' Computes the overlap score for tracks and pred_tracks. 
        returns #overlap'''
    return sum([a in tracks for a in pred_tracks])

def avg_overlap(true_dict, pred):
    ''' Returns the accuracy score given true_label and pred'''
    avg_overlap = np.mean([overlap_score(a, b) for a,b in zip(true_dict.values(), pred)])    
    return avg_overlap

def dcg(predictions, labels, test_size = 100):
    ''' Calculates the discounted cumulative gain for prediction and labels. 
        Inputs:
            Prediction: list of predictions 
            labels: list of actual labels
            test_size: size of each of the two sets'''
    zero_one_label = [predictions[i] in labels for i in range(len(predictions))]
    zero_one_label = [zero_one_label[i]/np.log2(i+2) for i in range(len(zero_one_label))]
    return np.sum(zero_one_label)

In [17]:
#create cv index
cv = 2
l = X_train.shape[0]
def chunks(l, n):
    # For item i in a range that is a length of l,
    for i in range(0, l, l//n):
        # Create an index range for l of n items:
        linspace = np.linspace(i,i+(l//n),(l//n))
        index = [int(j) for j in list(linspace)]
        yield index     
train_cv_ind = list(chunks(l,cv)) # create a list of 2 indices for the cv

In [97]:
import implicit
# choosing factor size and iteration size
iters = range(1,20,6)
facts = range(50,250,50)
overlap_val_iter = {}
overlap_train_iter = {}

for fac in facts:
    for it in iters:
        # the following two lists will store the overlap score for each fold
        overlap_val_cv = []
        overlap_train_cv = []
        for cv_ind in range(cv):
            ind = [j for j in range(cv) if j != cv_ind]
            # save the training indices to total_ind
            total_ind = np.concatenate([train_cv_ind[k] for k in ind] )
            X = track_playlist_data[total_ind]
            # save the validation indices to val_ind
            val_ind = train_cv_ind[cv_ind]
            X_val = track_playlist_data[val_ind]
            
            # fit the training data
            model_mf = implicit.als.AlternatingLeastSquares(factors = fac,iterations = it)
            model_mf.fit(X.T)
            
            prec = []
            # loop for each playlist in the training data and calculate the overlap score between the true tracks and predictions
            for pl in range(len(total_ind)):
                playlistid = pl
                # use the model_mf to recommend 10 tracks for playlistid 
                recommendations = model_mf.recommend(playlistid, X,filter_already_liked_items = True)
                playlist_pred = [tracks[i] for i in list(zip(*recommendations))[0]]
                # get the dictionary key for the playlist with id total_ind[pl] from y_train
                key = list(y_train.keys())[total_ind[pl]]
                # get the true tracks from y_train
                playlist_true = y_train[key]
                prec.append(overlap_score(playlist_true,playlist_pred)) 
            overlap_train_cv.append(np.mean(prec))
            
            
            
            model_val = implicit.als.AlternatingLeastSquares(factors = fac,iterations = it)
            model_val.fit(X_val.T)
            # rank the tracks for the validation playlists by multiplying the user_factors from the validation data
            # and the item_factors from the training data
            recommendations_val = model_val.user_factors@model_mf.item_factors.T
            prec_val = []
            
            # loop for each playlist in the validation data and calculate the overlap score between the true tracks and predictions
            for pl in range(len(val_ind)):
                playlistid = pl
                # recommend the top 10 tracks that have the highest score for the playlist with id pl
                recommendations = recommendations_val[pl].argsort()[-10:][::-1]
                playlist_pred = [tracks[i] for i in recommendations]
                # get the dictionary key for the playlist with id val_ind[pl] from y_train
                key = list(y_train.keys())[val_ind[pl]]
                # get the true tracks from y_val (y_train[key])
                playlist_true = y_train[key]
                prec_val.append(overlap_score(playlist_true,playlist_pred))
            overlap_val_cv.append(np.mean(prec_val))

        overlap_train_iter[(fac,it)] = np.mean(overlap_train_cv)
        overlap_val_iter[(fac,it)] = np.mean(overlap_val_cv)
        print(overlap_train_iter)
        print(overlap_val_iter)
        


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


{(50, 1): 0.14079229122055675}
{(50, 1): 0.04630620985010707}


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466}


HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006}


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984}


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439}


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906}


HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741}


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394}


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844, (150, 1): 0.21761241970021414}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394, (150, 1): 0.04470021413276232}


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844, (150, 1): 0.21761241970021414, (150, 7): 0.6535064239828694}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394, (150, 1): 0.04470021413276232, (150, 7): 0.05058886509635974}


HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844, (150, 1): 0.21761241970021414, (150, 7): 0.6535064239828694, (150, 13): 0.6403907922912205}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394, (150, 1): 0.04470021413276232, (150, 7): 0.05058886509635974, (150, 13): 0.048313704496788006}


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844, (150, 1): 0.21761241970021414, (150, 7): 0.6535064239828694, (150, 13): 0.6403907922912205, (150, 19): 0.6429336188436832}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394, (150, 1): 0.04470021413276232, (150, 7): 0.05058886509635974, (150, 13): 0.048313704496788006, (150, 19): 0.05085653104925054}


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844, (150, 1): 0.21761241970021414, (150, 7): 0.6535064239828694, (150, 13): 0.6403907922912205, (150, 19): 0.6429336188436832, (200, 1): 0.24745717344753748}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394, (150, 1): 0.04470021413276232, (150, 7): 0.05058886509635974, (150, 13): 0.048313704496788006, (150, 19): 0.05085653104925054, (200, 1): 0.03894539614561028}


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, max=7), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844, (150, 1): 0.21761241970021414, (150, 7): 0.6535064239828694, (150, 13): 0.6403907922912205, (150, 19): 0.6429336188436832, (200, 1): 0.24745717344753748, (200, 7): 0.6498929336188437}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394, (150, 1): 0.04470021413276232, (150, 7): 0.05058886509635974, (150, 13): 0.048313704496788006, (150, 19): 0.05085653104925054, (200, 1): 0.03894539614561028, (200, 7): 0.04349571734475374}


HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844, (150, 1): 0.21761241970021414, (150, 7): 0.6535064239828694, (150, 13): 0.6403907922912205, (150, 19): 0.6429336188436832, (200, 1): 0.24745717344753748, (200, 7): 0.6498929336188437, (200, 13): 0.6383832976445396}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394, (150, 1): 0.04470021413276232, (150, 7): 0.05058886509635974, (150, 13): 0.048313704496788006, (150, 19): 0.05085653104925054, (200, 1): 0.03894539614561028, (200, 7): 0.04349571734475374, (200, 13): 0.051659528907922914}


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


{(50, 1): 0.14079229122055675, (50, 7): 0.6179068522483939, (50, 13): 0.6270074946466809, (50, 19): 0.6208511777301927, (100, 1): 0.1876338329764454, (100, 7): 0.646948608137045, (100, 13): 0.644271948608137, (100, 19): 0.6399892933618844, (150, 1): 0.21761241970021414, (150, 7): 0.6535064239828694, (150, 13): 0.6403907922912205, (150, 19): 0.6429336188436832, (200, 1): 0.24745717344753748, (200, 7): 0.6498929336188437, (200, 13): 0.6383832976445396, (200, 19): 0.6402569593147751}
{(50, 1): 0.04630620985010707, (50, 7): 0.046440042826552466, (50, 13): 0.055406852248394006, (50, 19): 0.048982869379014984, (100, 1): 0.04202355460385439, (100, 7): 0.06049250535331906, (100, 13): 0.04135438972162741, (100, 19): 0.0555406852248394, (150, 1): 0.04470021413276232, (150, 7): 0.05058886509635974, (150, 13): 0.048313704496788006, (150, 19): 0.05085653104925054, (200, 1): 0.03894539614561028, (200, 7): 0.04349571734475374, (200, 13): 0.051659528907922914, (200, 19): 0.04590471092077088}


In [103]:
# import time to calculate the run time for this algorithm
start = timeit.default_timer()

model = implicit.als.AlternatingLeastSquares(factors=50, iterations = 13)

# train the model on a sparse matrix of item/user/confidence weights
model.fit(track_playlist_data.T)


r_pred = []
overlap = []
dcg_score = []

for ind in range(len(X_train)):
    playlistid = ind
    playlist_track = track_playlist_data.T
    # recommend 10 tracks for the playlist with id playlistid
    recommendations = model.recommend(playlistid, playlist_track,filter_already_liked_items = True)
    # match the index of the track with the track names
    playlist_pred = [tracks[i] for i in list(zip(*recommendations))[0]]
    # get the dictionary key for the playlistid from y_train
    key = list(y_train.keys())[ind]
    # get the true tracks from y_train
    playlist_true = y_train[key]

    overlap.append(overlap_score(playlist_true,playlist_pred))
    dcg_score.append(dcg(playlist_pred,playlist_true))

print('Overall overlap score of ALS: ', np.mean(overlap))
print('Overall dcg of ALS: ', np.mean(dcg_score))

stop = timeit.default_timer()

print('Time(s): ', stop - start) 

HBox(children=(IntProgress(value=0, max=13), HTML(value='')))


Overall overlap score of ALS:  0.4902984076006958
Overall dcg of ALS:  0.23334265807457652
Time(s):  18.53966950299946
