In [196]:
import os
import sys
from tqdm.auto import tqdm
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import time

from IPython.display import clear_output

sys.path.append('src')
from train_test import train_test_split

from metrics import compute_metrics

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
DATA_PATH = './data.nosync/'

In [4]:
# Load data files
train = pd.read_csv(DATA_PATH + 'train.csv')
valid = pd.read_csv(DATA_PATH + 'valid.csv')
test = pd.read_csv(DATA_PATH + 'test.csv')
lastfm_360_behav = pd.read_csv(DATA_PATH + 'behav-360k-processed.csv')
lastfm_360_demo = pd.read_csv(DATA_PATH + 'demo-360k-processed.csv')
lastfm_360_demo = lastfm_360_demo.set_index('user_email').astype(bool).reset_index()
test_users = np.load(DATA_PATH + 'test_users.npy')

In [5]:
# Rescale data
train_plays_max = train['log_plays'].max()
train['log_plays'] = train['log_plays'] / train_plays_max
valid['log_plays'] = valid['log_plays'] / train_plays_max
test['log_plays'] = test['log_plays'] / train_plays_max

In [14]:
train

Unnamed: 0,user_email,artist_id,log_plays
0,184649,460,0.000000
1,255115,2495,0.424939
2,113122,250,0.272390
3,249162,2043,0.541921
4,255027,603,0.160624
...,...,...,...
14049672,21192,3149,0.355721
14049673,219208,12864,0.357251
14049674,223801,28,0.297401
14049675,21038,29798,0.312301


In [197]:
def get_users_for_artist(df, artist_id:int):
    return df[df['artist_id'] == artist_id]['user_email'].values

In [198]:
def get_correlation_list(df, user_email: str, selected_users:list, filter_artists=True, verbose=False):
    
    # Filter the behavioral data for the listed users
    selected_behav = df[(df['user_email'].isin(selected_users)) | (df['user_email'] == user_email)]
    
    # Make sure there is no duplicated rows for given user and artist pair
    selected_behav = selected_behav.groupby(['user_email', 'artist_id'], as_index=False).sum()
    
    # Compute the number of artists listened in the selected group
    selected_artists = selected_behav['artist_id'].value_counts()#.index.values
    if verbose: print(f"Total artists: {len(selected_artists.index.values)}")
    
    # Filter the artist to reduce the dimention
    if filter_artists:
        # Remove all artist that have been listened by less than 1% of the user list. 
        artist_threshold = len(selected_users) / 100
        selected_artists = selected_artists[selected_artists > artist_threshold].index.values
        selected_behav = selected_behav[selected_behav['artist_id'].isin(selected_artists)]
        
    # Build the user raking dataset (user_email index and one column per artist)
    user_scores = selected_behav.pivot(index='user_email', columns='artist_id', values='log_plays').fillna(0)
    
    # Compute correlation between users
    correlation_list = user_scores.corrwith(user_scores.loc[user_email], axis=1)
    
    return correlation_list, selected_behav

In [199]:
def compute_prediction(data, user_email: str, artist_id: int, selected_corr: np.array, selected_behav: pd.DataFrame, verbose=False):
    # Get the behav data for the remaining users
    selected_behav = selected_behav[selected_behav['user_email'].isin(list(selected_corr.index.values))]
    #final_behav = selected_behav
    # Compute the normalized average number of play
    avg_log_plays = selected_behav.groupby('user_email')['log_plays'].mean()
    
    # Get the log play for a given artist
    artist_log_plays = selected_behav[(selected_behav['artist_id'] == artist_id)]\
                                    [['user_email', 'log_plays']].set_index('user_email')['log_plays']
    
    # Compute the plays variance from the similar users
    var_pred = (selected_corr * (artist_log_plays - avg_log_plays)).sum() / selected_corr.sum()
    # Compute user total number of play and average number of play
    user_log_plays = data[data['user_email'] == user_email]['log_plays']
    
    if verbose: print(f'user plays: {user_log_plays}, mean: {user_log_plays.mean()}')
    
    # Compute the estimated number of play for new artist (average + normalized var * number of plays)
    predicted_plays = user_log_plays.mean() + var_pred 
    
    return predicted_plays

In [200]:
def get_artist_prediction(data, user_email:str, artist_id:int, corr_threshold=0.2, 
                          filter_demo=False, filter_artist=True, verbose=False):
    start = time.time()
    
    # Select users that have already listened to the given artist 
    selected_users = get_users_for_artist(data, artist_id)
    if verbose: print(f"Selected users: {len(selected_users)}")
    
    # Compute correlation with selected users
    start_corr = time.time()
    correlation_list, selected_behav = get_correlation_list(data, user_email, selected_users, filter_artist)
    end_corr = time.time()
    if verbose: print(f"Correlation list computation time: {end_corr - start_corr}")
    
    # Compute prediction
    selected_corr = correlation_list[correlation_list > corr_threshold]
    selected_corr = selected_corr.drop(user_email)
    
    
    if verbose: print(f"Selected users (after correlation): {len(selected_corr)}")
    
    prediction = compute_prediction(data, user_email, artist_id, selected_corr, selected_behav)
    end = time.time()
    
    if verbose: print(f"Total time for prediction: {end - start}")
    return prediction
    

In [201]:
seen_artists = train.append(valid).drop(columns='log_plays')
seen_artists = seen_artists[seen_artists['user_email'].isin(test_users)]
seen_artists = seen_artists.groupby('artist_id').first().index.values

In [281]:
def pred_func(data, user_id, artists: list):
    preds = []
    for a in artists:
        pred = get_artist_prediction(data, user_id, a)
        if np.isnan(pred):
            pred = 0
        preds.append(pred)
    return preds

In [256]:
def top_100_artists(data):
    total_log_plays = data.groupby('artist_id').sum()
    ranked_artist_scores = total_log_plays.sort_values(['log_plays'], ascending=False)['log_plays']
    return ranked_artist_scores[:100].index.values

In [257]:
top_100 = top_100_artists(train)

In [258]:
def get_ratings(train, full_train, artists):
    ratings = {}
    top_100 = set(top_100_artists(train))
    set_artists = set(artists)
    for user, user_df in tqdm(train.groupby('user_email')):
        new_artists = np.array(list(top_100.intersection(set_artists - set(user_df['artist_id']))))
        ratings[user] = [new_artists, pred_func(full_train, user, new_artists)]
        print(ratings[user])
    return ratings

In [259]:
train_subset = train[train['user_email'].isin(test_users[:100])]

In [260]:
pred_ratings = get_ratings(train_subset, train,seen_artists)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=88.0), HTML(value='')))

KeyboardInterrupt: 

In [None]:

# Evaluate the model 
from sklearn.metrics import mean_squared_error

def rmse(y_true, y_pred):
    return mean_squared_error(y_actual, y_predicted, squared=False)

In [282]:
test_small = test[test['user_email'].isin(test_users[:100])]

In [None]:
preds = {}
true = {}
for (user, user_df) in tqdm(test_small.groupby('user_email')):
    artists = user_df['artist_id'].values
    preds[user] = [artists, pred_func(train, user, artists)]
    print(preds[user])
    true[user] = [artists, user_df['log_plays'].values]

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[array([ 2020, 13632,  2702, 10711,   245]), [0.33557867316222717, 0.3526058388926968, 0.3539373601870344, 0.3402115569069563, 0.35614043450511984]]


  del sys.path[0]


[array([ 1162, 41587,  3702, 17090, 17228,  5695]), [0.41079650615218655, 0, 0.4173157132237809, 0.45793798410229236, 0.409280277834704, 0.4034456657319745]]
[array([  434,  4276,  1150,   677,  4511, 39918]), [0.5887299651827055, 0.5681708310717594, 0.5565459109966965, 0.56261092393758, 0.549132315634738, 0.5396215452303871]]
[array([2700, 1475,  247, 1408, 3211,  264]), [0.20625935284124935, 0.21713938655364914, 0.22595995092604582, 0.2203302808267392, 0.24282669429611822, 0.22210110950643897]]
[array([14578,  5012,  1680,   731,    62]), [0.25515684892290424, 0.28066864944893455, 0.28149018575695417, 0.28201935049612004, 0.3143343348328163]]
[array([ 4297, 19540,   268,  5416,  3524,    63]), [0.4150727586775414, 0.39603835470462184, 0.4052052533275835, 0.40152985421877785, 0.39247197969342434, 0.4232941552859127]]
[array([   62,   478, 10235,   731,  1156,   350]), [0.38766746645149036, 0.4123297082050091, 0.3632991936415807, 0.3672074262490962, 0.3555363220571916, 0.38022766894572

  del sys.path[0]
  del sys.path[0]


[array([ 3279, 32332, 37497,   437,  3542, 22272]), [0.1801519596305229, 0, 0, 0.19619419957691464, 0.186591703778344, 0.15966698226250203]]


  del sys.path[0]


[array([1030,  707, 1489, 4280, 8936]), [0.44746841747115496, 0.4610224878207954, 0.4537138708542892, 0.4647559823894839, 0]]
[array([  294, 29013,   266,   277,  5894,   992]), [0.3509502554412475, 0.33713806094609694, 0.37132602376760765, 0.36388077289784526, 0.318597564068521, 0.3595991992100879]]


  del sys.path[0]


[array([  214,  1697,  3890, 27795,   866,  1146]), [0.3115022518475853, 0.280081808187756, 0.2816646190768906, 0, 0.29111626460461176, 0.27333741607261475]]
[array([ 9131,  4444,  8175,  1724,  3139, 15406]), [0.4305561934328985, 0.4203100925985402, 0.4285673908021932, 0.4063258995380558, 0.42294019925572734, 0.39067132525257187]]
[array([ 1752, 50664,  6208, 13349, 27333]), [0.25746992760626664, 0.2392826119440298, 0.26781419773312376, 0.26986433845080987, 0.24935233827859465]]
[array([ 2868, 10448,   110, 30013,   860]), [0.4824005699824487, 0.4936224637795958, 0.5073319592625137, 0.5279852344357008, 0.5088132149595549]]
[array([13860,  3607,  1365,  1159,  1006]), [0.38258338563571553, 0.3415506185261424, 0.35031934378531926, 0.35101454139763183, 0.3557663535030946]]


  del sys.path[0]


[array([   75,   995,   101, 47252, 11836,    41]), [0.4398451212857618, 0.45492196825551995, 0.4273041878556559, 0, 0.36482032109063844, 0.4398992035750043]]
[array([ 6362,  6139, 16393,   603, 16395,  1856]), [0.31538894807174095, 0.28229494596570043, 0.3004328891090694, 0.30173308548308525, 0.309893174495249, 0.29058668623617523]]


  del sys.path[0]


[array([ 1216,   232, 11166, 81686,  4167,  4580]), [0.37865788184397786, 0.3345712539105695, 0.3592372114816173, 0, 0.3529098159927095, 0.3447528584443116]]
[array([ 1557, 17521, 33329,  3449,  7679,   281]), [0.35175671498400807, 0.3607855229282484, 0.31448314995164434, 0.3312475295774174, 0.3508358992575746, 0.3878722118434714]]
[array([ 3699, 10872,  3770,   615,   438,  4280]), [0.2552585541738054, 0.20229598835710116, 0.19331549692876124, 0.20790995042663707, 0.19687160208029375, 0.22830041556037883]]
[array([3406, 1471, 1039,  216, 1568,  935, 1020]), [0.2607659251143226, 0.3037663921467966, 0.3145907314675277, 0.3489305930644253, 0.2947329459473239, 0.3157820626556623, 0.3274044602025563]]


  del sys.path[0]


[array([21147,    37,   744, 10479, 33401,  1388]), [0, 0.2858112795537742, 0.2767906788070121, 0.2657300337860527, 0.2576302143996166, 0.2857159319947901]]
[array([ 973, 1079,  992, 1505, 6609]), [0.3078182621519149, 0.3052015949849172, 0.28865129724047695, 0.3370826208914147, 0.283426827398461]]
[array([10177,   615,    35,  1030, 35701]), [0.4669734230803209, 0.4773428045641841, 0.4937813642284559, 0.4777013646681013, 0.4586656271186784]]


  del sys.path[0]
  del sys.path[0]


[array([35101,  3824,  4031, 35088, 11337]), [0, 0.24261296233348206, 0.33685892261901673, 0, 0.32262302697494]]
[array([ 216,   41, 1258,  766, 5100]), [0.4836227224489852, 0.45525834252703834, 0.4421409963565404, 0.4779288282420952, 0.4319651069417784]]
[array([3006, 5702,  918,  351, 2082, 1194]), [0.1443378342861864, 0.13623053368502816, 0.14634588040187255, 0.12191278974954275, 0.1343380209776645, 0.12493411911523394]]
[array([ 685,  548,  672, 7614,  202, 1691]), [0.5011682326644756, 0.4929420443781066, 0.5139047667429122, 0.5099569439338443, 0.4967643675426448, 0.4986103913202966]]
[array([ 460,  801,  743, 4718, 6657]), [0.37780019851268076, 0.3391301601019103, 0.39744335839890904, 0.38592139194052744, 0.38082135558266705]]


  del sys.path[0]


[array([ 776, 2352, 2464, 6604,  238, 1107]), [0.35912756886532227, 0, 0.35395516088388546, 0.3328037709819609, 0.3442149097994035, 0.32279539510709765]]
[array([ 7092,  6014, 46897, 31381,   397]), [0.4890641312785021, 0.5021307625172824, 0.46712062958255546, 0.4915085737307861, 0.490052421133929]]
[array([ 1017,   515,  2241, 10062,  2496]), [0.3201909691142809, 0.35653532743335803, 0.33915201589048893, 0.30344810058393346, 0.3358836193320043]]
[array([ 296, 2263,  750, 4356,  284]), [0.2632055449899626, 0.2667166567359136, 0.26027755276020187, 0.2744518287472383, 0.28327088459895305]]
[array([13965,  2725, 13588,  3032,    77]), [0.14968160206218448, 0.14701240386477976, 0.17345663147058582, 0.14078552464234897, 0.1319161687892766]]
[array([  806, 62279,   182,  7450,   854]), [0.4616687514668313, 0.45721961352890866, 0.47582242019386745, 0.4833447935355394, 0.4944373976811003]]
[array([ 155, 1194, 2586,  181, 1939]), [0.4175908830725699, 0.3935433294778104, 0.37463547865898633, 0.4

  del sys.path[0]
  del sys.path[0]


[array([ 1600,    79, 17431, 48445,  1617]), [0.5398081056001186, 0.4131733907468504, 0, 0, 0.4623403946703123]]
[array([ 5750, 41930,  2821, 32282,  5539]), [0.20548253888545365, 0.17285366570761523, 0.22688366215037992, 0.13999150464178822, 0.19255568047157343]]
[array([15563,   458,  7944,  1251,  2713]), [0.3574770221242297, 0.4272838047848167, 0.4266852204596914, 0.3953197055993509, 0.41555276218405257]]
[array([ 985, 5097,  767, 1659, 1328]), [0.4388162641683566, 0.44038051569824094, 0.44093154522684375, 0.4389035923933254, 0.4540088878678769]]


  del sys.path[0]


[array([ 2108, 25184, 19010, 30012,  5819,  9156]), [0.31466723388012713, 0, 0.3194998509248527, 0.2956736955777176, 0.33037781575305425, 0.35583315532410753]]


  del sys.path[0]


[array([ 1012,  6942,  1811,   774,  2097, 40108]), [0.3065682909695881, 0.28369274136093675, 0.2923502766405795, 0.3064909440644123, 0.3172574847297061, 0]]
[array([10912,   476,  1145,  2235,  3939]), [0.4869822500866279, 0.5038142213653043, 0.50224395619086, 0.5042651676971521, 0.4822837411866266]]
[array([ 772, 1566, 9855,  527, 5476, 7344]), [0.42119990217090636, 0.37570155898514807, 0.3831445266180567, 0.41806242395444515, 0.3998141850036753, 0.38960503206146485]]


  del sys.path[0]


[array([ 5730, 65380,    62,  2220,   338]), [0.3180673077340827, 0, 0.3052915956961477, 0.2900909074560216, 0.28988770820196386]]


  del sys.path[0]


[array([31249,   485,  5856,   785,    96]), [0, 0.4187587717492163, 0.4387693700181148, 0.42362109052184416, 0.4342219030651387]]
[array([22576,   284,  2305,  7499,   902]), [0.3303715438036328, 0.37505378377788734, 0.3604372829186331, 0.3394205989049121, 0.3461191108613913]]
[array([ 266,  146, 7499, 3969, 1777]), [0.40293756533973857, 0.3855736908898231, 0.3807424569179767, 0.3773060534216601, 0.40766587245035724]]


  del sys.path[0]


[array([  515,  8196,   434, 13985,   147,   767,   991,  4425]), [0.34563073762806956, 0.3255769737223689, 0.35008454498149494, 0, 0.34022273106101186, 0.3406997902257336, 0.35197062355096886, 0.34454459484227395]]
[array([ 2504, 13847, 11607,   668, 13321]), [0.4319718696942552, 0.40461598135013355, 0.40839847958005265, 0.40023663123517617, 0.4128614444547198]]
[array([ 8677, 13687,  1026,    37,  2085]), [0.42837435143484864, 0.4218622269932526, 0.45210661154153425, 0.4473577373564083, 0.4408697437144281]]
[array([2082,  519, 5438,  769, 5819]), [0.3867394640801265, 0.38092639628338165, 0.38092840419893, 0.37922806857182223, 0.41275184814573174]]
[array([  68,  267, 2384, 7253,  515]), [0.40053132592871443, 0.4115629230030689, 0.40646777738391354, 0.44440692033523244, 0.4010000173746704]]


  del sys.path[0]
  del sys.path[0]
  del sys.path[0]


[array([14004, 37569,  9415,  1334, 54240]), [0, 0, 0.2400725199714062, 0.23539797232178225, 0]]
[array([ 276, 6652, 3475, 4300, 8476]), [0.2835619066665021, 0.2770021587505764, 0.28835193662096104, 0.30145863340045265, 0.28899707255812473]]
[array([7178, 2199, 2017, 4175,  254]), [0.3076370131051102, 0.3042793353431105, 0.3106227585665354, 0.3021487114161308, 0.31881772543454606]]


  del sys.path[0]


[array([26982, 57995,   183,  2713,  2828]), [0.4086278784726727, 0, 0.4559841714374938, 0.43819021491794224, 0.4475901311842687]]
[array([32602,  8972,  3013,  8973,  2211]), [0.40417282388369663, 0.3808261641302852, 0.4421924819558013, 0.3788850726394601, 0.3700767214056275]]


  del sys.path[0]


[array([  5389, 100768,    603,   4520,    354]), [0.3557091396407178, 0, 0.3594096423436699, 0.35588051212183314, 0.41853821385240597]]
[array([  493,  1367, 26431,  3006,   236]), [0.4174169974740251, 0.40334059243912485, 0.36068494963674885, 0.40095413569523447, 0.4252497643900991]]


  del sys.path[0]


[array([ 1157,  8828,  1203,   264,  4235, 42197]), [0.4501969456847036, 0.4295431869256831, 0.46020382380660946, 0.4268855740534707, 0.43840713214217664, 0]]
[array([7487,    5,  146,  688, 2433]), [0.35366343756731994, 0.3179596112085471, 0.33981957384286965, 0.3157935806335297, 0.30811320280857]]


  del sys.path[0]
  del sys.path[0]
  del sys.path[0]


[array([ 4589,   562, 37801,  3729,  2141,    62]), [0, 0, 0, 0.37916744534982205, 0.35459036063837035, 0.3515481411234305]]
[array([4810, 2017, 2817,  709, 1041]), [0.39389307658930595, 0.4083326811869151, 0.41881406579894537, 0.4150984078793912, 0.41304167352297555]]
[array([   13,  1689, 20972,   885,   909]), [0.33193376719876017, 0.32740824076523356, 0.3636474872291857, 0.3508392094333046, 0.3698822015247538]]


  del sys.path[0]


[array([2697,   55, 5706, 2119,  552]), [0.35899472189907117, 0.40098454025268965, 0, 0.3999859878215652, 0.4004257010459638]]
