Build dataset


In [101]:
from typing import TypedDict, Type, Any, Callable

import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from spark.config import views
from spark.create_session import create_session

from IPython.display import display

from fitter import Fitter, get_common_distributions

In [102]:
VIEWS = views("v3")
spark = create_session()

for view, file in VIEWS.items():
    df = spark.read.json(file)
    df.createOrReplaceTempView(view)

In [103]:
sessions_full = spark.sql(f"SELECT * FROM sessions").toPandas()
sessions = spark.sql(f"SELECT DISTINCT user_id, track_id FROM sessions WHERE event_type like 'like' order by user_id, track_id").toPandas()
tracks = spark.sql(
    f"SELECT DISTINCT id, id_artist, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence FROM tracks ").toPandas()

In [104]:
display(sessions_full)
display(sessions)
display(tracks)

Unnamed: 0,event_type,session_id,timestamp,track_id,user_id
0,play,124,2023-09-20T07:18:17,4flBNxpPLGVAdufOzRBIen,101
1,like,124,2023-09-20T07:18:30.069000,4flBNxpPLGVAdufOzRBIen,101
2,play,125,2023-04-27T00:20:57.181000,3uMYq07Kj5m564OQwdSCrD,101
3,play,126,2023-02-11T05:34:54.160000,2RChe0r2cMoyOvuKobZy44,101
4,advertisment,126,2023-02-11T05:34:56.685000,,101
...,...,...,...,...,...
580241,like,102015,2023-11-12T09:06:01.022000,4cLdpErILMO8Db8pQVAVcZ,1100
580242,play,102015,2023-11-12T09:09:17.564000,3j8ja2Hq824OaRqIENJPTH,1100
580243,like,102015,2023-11-12T09:12:19.739000,3j8ja2Hq824OaRqIENJPTH,1100
580244,skip,102015,2023-11-12T09:12:58.093000,3j8ja2Hq824OaRqIENJPTH,1100


Unnamed: 0,user_id,track_id
0,101,03LNdMgu3l3Ldc3QMl1bvZ
1,101,0BVCEJJFVsb8nrQGI11Dj2
2,101,0PCpQRd0hMAWjBLOmJdR7X
3,101,0bLOiofyBB62YU2cNnONJG
4,101,0iJfN2CqrX7O8hkzgAMMAf
...,...,...
80366,1100,7FdAQ7CXm9yS4DBtIKopLi
80367,1100,7a9aeLVkn7DIqFjbanKz0k
80368,1100,7an1exwMnfYRcdVQm0yDev
80369,1100,7cioKB5CHVzk09SOtTyn0T


Unnamed: 0,id,id_artist,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,4y8icjzu6fZP503Mg31Tpn,5V0MlUE1Bft0mbLlND7FJz,0.87800,0.184,224227,0.292,0.000258,7,0.3460,-12.246,55,1956,0.0312,75.630,0.1430
1,4hHbeIIKO5Y5uLyIEbY9Gn,1Mxqyy3pSjf8kZZL4QVxS0,0.84500,0.574,199093,0.338,0.000000,6,0.1650,-11.376,69,1958,0.0420,67.008,0.4930
2,0KSHmjK7OFtGocvbo7NZNO,6kACVPfCOnqzgfEF5ryl0x,0.80300,0.754,107280,0.484,0.000009,10,0.0989,-9.196,52,1959,0.0609,82.761,0.8010
3,30V02AmDYMRvrHE4L8cZAo,22bE4uQ6baNwSHPVcDxLCe,0.17800,0.578,133320,0.429,0.000060,2,0.1710,-9.601,57,1965,0.0268,108.738,0.0889
4,2CQRYn5cTD2B9a1ONjhTN2,3oDbviiivRWhXwIE8hxkVV,0.72500,0.396,145800,0.686,0.000001,5,0.1700,-5.607,57,1966,0.0341,124.318,0.7170
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22407,3atrFJhUzDm5xiutqfEPkP,74XFHRwlV6OrjEM0A2NCMF,0.00576,0.489,218333,0.611,0.000008,6,0.0917,-5.675,52,2007,0.0337,140.956,0.1860
22408,3wYzd4dzWknPdgwNaPGQvg,7JthQ6zwNzfxRfIEjp6wUs,0.06960,0.607,191474,0.605,0.000000,3,0.4540,-6.050,65,2021,0.1590,83.913,0.6650
22409,6UIcN1tiiGdd7oMMzNvyaP,7lbSsjYACZHn1MSDXPxNF2,0.06010,0.479,297502,0.478,0.000007,5,0.0813,-8.884,53,2020,0.1100,100.040,0.1170
22410,62i2LI6iDmAHbE9H9wa99z,7k73EtZwoPs516ZxE72KsO,0.38500,0.470,216970,0.738,0.000621,8,0.0656,-4.641,52,2017,0.0398,97.973,0.4880


In [105]:
# d = spark.sql(
#     """
#     SELECT s.user_id, s.track_id, s.weight, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence
#     FROM (
#         select user_id, track_id, sum(event_weight) as weight
#         from (
#             SELECT user_id, track_id, CASE WHEN event_type like 'like' THEN 1 ELSE 0.02 END as event_weight
#             FROM sessions
#             WHERE event_type like 'like' or event_type like 'play'
#             ) 
#         group by user_id, track_id
#     ) s
#     inner join tracks t on s.track_id = t.id
#     order by s.user_id, t.id
#     """).toPandas()
d = spark.sql(
    """
    SELECT s.user_id, s.track_id, s.weight, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence
    FROM (
        select user_id, track_id, sum(event_weight) as weight
        from (
            SELECT user_id, track_id, 1 as event_weight
            FROM sessions
            WHERE event_type like 'like'
            ) 
        group by user_id, track_id
    ) s
    inner join tracks t on s.track_id = t.id
    order by s.user_id, t.id
    """).toPandas()
d

Unnamed: 0,user_id,track_id,weight,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,101,03LNdMgu3l3Ldc3QMl1bvZ,1,0.54500,0.674,182549,0.6370,0.000002,0,0.0483,-3.400,52,2014,0.0330,139.867,0.755
1,101,0BVCEJJFVsb8nrQGI11Dj2,1,0.00522,0.656,235253,0.9230,0.000000,2,0.3260,-3.541,56,2011,0.0339,139.930,0.652
2,101,0PCpQRd0hMAWjBLOmJdR7X,1,0.02280,0.481,166040,0.8820,0.000220,2,0.1520,-5.068,51,1999,0.0407,141.721,0.718
3,101,0bLOiofyBB62YU2cNnONJG,1,0.32900,0.457,404933,0.4280,0.000002,1,0.1690,-9.796,55,1975,0.0283,129.980,0.151
4,101,0iJfN2CqrX7O8hkzgAMMAf,1,0.76700,0.222,355634,0.0555,0.923000,0,0.1180,-22.118,53,1994,0.0321,98.941,0.033
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
80366,1100,7FdAQ7CXm9yS4DBtIKopLi,1,0.88100,0.444,250880,0.2660,0.000000,9,0.1460,-9.001,56,2004,0.0366,130.430,0.404
80367,1100,7a9aeLVkn7DIqFjbanKz0k,1,0.31900,0.145,271867,0.4520,0.920000,2,0.0753,-13.560,55,2007,0.0448,82.297,0.146
80368,1100,7an1exwMnfYRcdVQm0yDev,1,0.21700,0.418,239013,0.4820,0.000000,5,0.1230,-5.769,54,2006,0.0266,175.558,0.261
80369,1100,7cioKB5CHVzk09SOtTyn0T,1,0.10300,0.436,369040,0.5340,0.006430,1,0.5070,-9.416,59,2014,0.0773,122.822,0.325


In [106]:
from scipy import stats

users = d['user_id']
items = d['track_id']


# _features = pd.DataFrame(tracks['track_id'], columns=['track_id']).merge(tracks, left_on="track_id", right_on="id", how="inner", validate="1:1" )


features = pd.concat([stats.zscore(tracks.drop(['id', 'id_artist'], axis=1)), tracks['id_artist']], axis=1)
features


# _features

Unnamed: 0,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence,id_artist
0,2.210183,-2.622956,-0.081994,-1.698778,-0.261491,0.483649,0.910200,-1.350924,-0.792191,-3.602614,-0.579005,-1.562809,-1.533923,5V0MlUE1Bft0mbLlND7FJz
1,2.090663,-0.158580,-0.430623,-1.479284,-0.263179,0.202650,-0.167568,-1.118180,0.947358,-3.458942,-0.458085,-1.853949,-0.102263,1Mxqyy3pSjf8kZZL4QVxS0
2,1.938547,0.978824,-1.704140,-0.782628,-0.263120,1.326645,-0.561162,-0.534983,-1.164951,-3.387106,-0.246475,-1.322015,1.157598,6kACVPfCOnqzgfEF5ryl0x
3,-0.325086,-0.133305,-1.342945,-1.045067,-0.262786,-0.921346,-0.131841,-0.643329,-0.543684,-2.956090,-0.628269,-0.444846,-1.755216,22bE4uQ6baNwSHPVcDxLCe
4,1.656046,-1.283347,-1.169838,0.181238,-0.263172,-0.078349,-0.137796,0.425153,-0.543684,-2.884254,-0.546536,0.081246,0.814000,3oDbviiivRWhXwIE8hxkVV
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22407,-0.948906,-0.695688,-0.163749,-0.176633,-0.263129,0.202650,-0.604035,0.406961,-1.164951,0.061025,-0.551015,0.643063,-1.358033,74XFHRwlV6OrjEM0A2NCMF
22408,-0.717690,0.049944,-0.536304,-0.205263,-0.263179,-0.640347,1.553288,0.306641,0.450344,1.066730,0.851883,-1.283116,0.601296,7JthQ6zwNzfxRfIEjp6wUs
22409,-0.752097,-0.758877,0.934387,-0.811258,-0.263132,-0.078349,-0.665962,-0.451516,-1.040697,0.994894,0.303264,-0.738553,-1.640275,7lbSsjYACZHn1MSDXPxNF2
22410,0.424630,-0.815747,-0.182655,0.429362,-0.259116,0.764647,-0.759448,0.683579,-1.164951,0.779385,-0.482717,-0.808350,-0.122715,7k73EtZwoPs516ZxE72KsO


In [107]:
from lightfm.data import Dataset

# assert d['track_id'].unique().shape[0] == features.shape[0]

dataset = Dataset()
dataset.fit(
    users=users,
    items=items
)

dataset.fit_partial(
    items=tracks['id'],
    item_features=tracks.drop('id', axis=1)
)

num_users, num_items = dataset.interactions_shape()
print('Num users: {}, num_items {}.'.format(num_users, num_items))

Num users: 1000, num_items 22412.


In [108]:
(interactions, weights) = dataset.build_interactions(d[['user_id', 'track_id']].apply(tuple, axis=1))

print(repr(interactions))

<1000x22412 sparse matrix of type '<class 'numpy.int32'>'
	with 80371 stored elements in COOrdinate format>


In [109]:
from lightfm.cross_validation import random_train_test_split

(train, test) = random_train_test_split(interactions)


In [110]:
from lightfm import LightFM
from lightfm.evaluation import auc_score, precision_at_k, recall_at_k, reciprocal_rank


In [111]:
feature_names = d.drop(['user_id', 'track_id', 'weight'], axis=1).columns

item_features = dataset.build_item_features(
    ((i, feature_names) for i in items), 
    normalize=False)
print(repr(item_features))

<22412x22426 sparse matrix of type '<class 'numpy.float32'>'
	with 305708 stored elements in Compressed Sparse Row format>


In [112]:
# warp, lr=0.01, no_comp=3-, aplha=1e-6, 5 epochs

model_h = LightFM(
    loss='warp',
    learning_rate=0.003,
    # item_alpha=1e-6,
    # user_alpha=1e-6,
    no_components=30,
)
model_h.fit(
    interactions=train,
    item_features=item_features,
    epochs=5,
    num_threads=12,
    verbose=True)


Epoch: 100%|██████████| 5/5 [00:02<00:00,  2.02it/s]


<lightfm.lightfm.LightFM at 0x7fd2413dd330>

In [113]:
train_auc_h = auc_score(model_h, train, 
                        item_features=item_features,
                        num_threads=12).mean()
test_auc_h = auc_score(model_h, test, 
                       train_interactions=train, 
                       item_features=item_features, 
                       num_threads=12).mean()

train_precision_h = precision_at_k(model_h, train, k=10,
                                   item_features=item_features, 
                                   num_threads=12).mean()
test_precision_h = precision_at_k(model_h, test, k=10, 
                                  train_interactions=train,
                                  item_features=item_features, 
                                  num_threads=12).mean()

train_recall_h = recall_at_k(model_h, train, k=10, 
                             item_features=item_features,
                             num_threads=12).mean()
test_recall_h = recall_at_k(model_h, test, k=10, 
                            train_interactions=train,
                            item_features=item_features, 
                            num_threads=12).mean()

train_reciprocal_rank_h = reciprocal_rank(model_h, train,
                                          item_features=item_features,
                                          num_threads=12).mean()
test_reciprocal_rank_h = reciprocal_rank(model_h, test, 
                                         train_interactions=train,
                                         item_features=item_features,
                                         num_threads=12).mean()

print('AUC: train %.6f, test %.6f.' % (train_auc_h, test_auc_h))
print('Precision: train %.6f, test %.6f.' % (train_precision_h, test_precision_h))
print('Recall: train %.6f, test %.6f.' % (train_recall_h, test_recall_h))
print('Reciprocal rank: train %.6f, test %.6f.' % (train_reciprocal_rank_h, test_reciprocal_rank_h))


AUC: train 0.653648, test 0.623724.
Precision: train 0.009900, test 0.002249.
Recall: train 0.001453, test 0.001139.
Reciprocal rank: train 0.045469, test 0.012844.


In [114]:
predicted_scores = model_h.predict(148, np.arange(interactions.shape[1]), 
                                   item_features=item_features, 
                                   num_threads=12)
predicted_scores

array([ 1.4685166 ,  1.5172684 ,  3.035224  , ..., -0.12276211,
       -0.12586237, -0.12728812], dtype=float32)

In [205]:
indices = np.argpartition(predicted_scores, -100)[-100:]
indices

array([ 4171,  3487,  6905,  4204,  2393,  5767,  1339, 10399,  2359,
       12417,   994, 10363, 14684,  3473,  3976,   567, 10634,  9556,
        8811,  2307, 11917,  1105,  1725,  1976,  2811,  8974,  6423,
        7093,  3272,  1629, 11450,  1449,   680,  6966,  1389, 17616,
        7223, 17900,   888,  2486,  2611,  5276,  6778,  3779,  5319,
        3516,  8056, 15591,  8944,  7856,  2209, 12642,  9194,  2309,
       12476,  2605,   496,  3744,  4341,  6695,  1193,  5555,  8868,
          51, 15623,  3563,  9265, 15174, 14531,   131,  2261,  3698,
        4796,  3327,  2260,  2458,  2074, 14994,   336, 16458,  5906,
        7080, 14244,  2247,  1876, 10106,  2700,  3575,  3330,  9129,
        1259, 13118,  2446,  3010,  3332,  1009,  9508,   255,  5172,
        2972])

In [206]:
predicted = pd.DataFrame(tracks).loc[indices]['id']
predicted

4171    2LBqCSwhJGcFQeTHMVGwy3
3487    6Db8IlZ7YY1pfIjJllejyH
6905    6QdwofpqDvvNxX88C9A0iQ
4204    0j46E9gFXDQND4AY4az8V3
2393    2Q99zPXVqCPN5RaZawvJWZ
                 ...          
1009    0EY1Z9UmZnYZyM7zHs6C0j
9508    2nvC4i2aMo4CzRjRflysah
255     6MgGapP3EPFm9kYUvYBTZR
5172    62oUgHSx7RPfgVWv1AYOpY
2972    6PPhp1qpAjLUxQr75vSD4H
Name: id, Length: 100, dtype: object

In [207]:
user_history = d[d.user_id == 249]

In [208]:
set(user_history['track_id'])

{'07BuyVse8pYAWd9DXD7B2D',
 '07qHCBBSRswIsfs0waYdjC',
 '0afhq8XCExXpqazXczTSve',
 '0gDyuX5rdHulQTUyrIdSR1',
 '1A9PKAFEHMWgXNLpgf7k4J',
 '1HKl3RJInVzf5ObVnM644j',
 '1HYnjKqSSHh1tdl2Hi57zH',
 '1HbyxkzMjXKIuyiZ5Xn4nB',
 '1N5p1acMejThPZZjmxQ3i5',
 '1bx6spmieE655BQvWdTYKA',
 '2D1hlMwWWXpkc3CZJ5U351',
 '2TdDRjNiF1HuRvnclprnce',
 '2Ygs64z9ywJGakuKU8tr6o',
 '2eLxEzMVfiQih0cJIWIowm',
 '2lD6AoA8qf2t4Dkf2TcmNK',
 '2mIrfke7vosXAEWfz6ucyo',
 '2sMBMXX43vpVOaqh6INXbV',
 '2vH8JFJKfkAgZs1GFlOzGh',
 '422Avi2VjqZKCVzstKmqog',
 '4O7oKy8YADaBrbehcPl0kE',
 '4sI8uN1G3PsoiNizkOqATO',
 '4sLlPACKA3OTNcJryGgWKR',
 '59nRuCsjHNdHNrs6BMj3fR',
 '5BY0p2EH4EznNZ0MFD9mjt',
 '5kQQ3eAIsg5DGbikSHQ8qG',
 '5sC0youR92ljJRn6VatElb',
 '5tNW2HIWbZLnaQLNsYfhj7',
 '6QpZSfLid1YZ6c01BvB5jH',
 '6VrLYoQKdhu1Jruei06t65',
 '6cvFwzez8ZbEWPTs5A0vAm',
 '6v2eEC9Nr9POe5xPUm8361',
 '74oWejX8MG6SEaJiz1334D',
 '7JiCaZ93B0hdj3XwFqwn4W',
 '7gpuC3rLKkI7PoJcEnSIO6'}

In [209]:
len([i for i in predicted if i in user_history['track_id']])

0

In [210]:
set(predicted).intersection(set(user_history['track_id']))

set()

In [211]:
print(len(set(predicted).intersection(set(user_history['track_id']))) / 1000)

0.0


In [212]:
coords = np.take(model_h.get_item_representations()[1], indices, axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

0.058694463

In [222]:
coords = np.take(model_h.get_item_representations()[1], np.random.random_integers(low=0, high=len(model_h.get_item_representations()[1]), size=100), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

  coords = np.take(model_h.get_item_representations()[1], np.random.random_integers(low=0, high=len(model_h.get_item_representations()[1]), size=100), axis=0)


0.05512355

In [223]:
def predict_single(user: int, number: int):
    predicted_scores = model_h.predict(user, np.arange(interactions.shape[1]), 
                                   item_features=item_features, 
                                   num_threads=12)
    indices = np.argpartition(predicted_scores, -number)[-number:]
    return tracks.loc[indices]['id']


def predict_multiple(users: list[int], number: int):
    if len(users) == 1:
        return frozenset(predict_single(users[0], number))
    
    predictions = [frozenset(predict_single(u, number)) for u in users]
    return predictions[0].intersection(*predictions[1:])

res= predict_multiple([123, 124], 100)
print(len(res), list(res)[:20])

99 ['17opN752ZQpNuoptelsNQ1', '32BTFbqhSvYKftE0e8a8d4', '1wVuPmvt6AWvTL5W2GJnzZ', '3xKToHlV3enPcQVpns4kUf', '0EY1Z9UmZnYZyM7zHs6C0j', '6PPhp1qpAjLUxQr75vSD4H', '6ovlMmTTp4fyvD9DBe1zo1', '0YkbYk24ODhuewb79zZZzM', '62oUgHSx7RPfgVWv1AYOpY', '3wCx8pwK7J1LVe13dFXjPm', '0faXHILILebCGnJBPU6KJJ', '2KOt2JrCB720UxIbyzweQo', '6LBnaXnvXFwn1PgLmpxTXM', '3VK8copXwqvAueMPhnBe6K', '32MQwnY0WUaNVT7i4W9d9F', '698eQRku24PIYPQPHItKlA', '5W3cjX2J3tjhG8zb6u0qHn', '3HGctlDltHdllOSTogGKhJ', '11Mq1tJwmFOig3BXxjqcQr', '48q0vSHcJdhK3IiXH8C5WJ']
