Build dataset


In [33]:
from typing import TypedDict, Type, Any, Callable

import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from spark.config import views
from spark.create_session import create_session

from IPython.display import display

from fitter import Fitter, get_common_distributions

In [34]:
VIEWS = views("stolen")
spark = create_session()

for view, file in VIEWS.items():
    df = spark.read.json(file)
    df.createOrReplaceTempView(view)

                                                                                

In [35]:
sessions_full = spark.sql(f"SELECT * FROM sessions").toPandas()
sessions = spark.sql(f"SELECT DISTINCT user_id, track_id FROM sessions WHERE event_type like 'LIKE' order by user_id, track_id").toPandas()
tracks = spark.sql(
    f"SELECT DISTINCT id, id_artist, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence FROM tracks ").toPandas()

                                                                                

In [36]:
display(sessions_full)
display(sessions)
display(tracks)

Unnamed: 0,event_type,session_id,timestamp,track_id,user_id
0,PLAY,124,2020-04-17T16:43:09,5EmL6IbswQGhfH9AX7ezWd,101
1,LIKE,124,2020-04-17T16:43:55.237000,5EmL6IbswQGhfH9AX7ezWd,101
2,PLAY,124,2020-04-17T16:45:44.733000,67ov0nL5eR7zdx0JfXDqro,101
3,SKIP,124,2020-04-17T16:48:26.836000,67ov0nL5eR7zdx0JfXDqro,101
4,ADVERTISEMENT,124,2020-04-17T16:48:26.836000,,101
...,...,...,...,...,...
10191757,LIKE,269652,2022-12-19T01:03:08.170000,0mW8oH0PZejrFsaNi1Ud8i,20100
10191758,PLAY,269652,2022-12-19T01:05:49.228000,3tOS5wMzYU5tiVztKKKHih,20100
10191759,SKIP,269652,2022-12-19T01:07:15.125000,3tOS5wMzYU5tiVztKKKHih,20100
10191760,PLAY,269652,2022-12-19T01:07:15.125000,5fWqmzQHxmfBVxf2k6JlX4,20100


Unnamed: 0,user_id,track_id
0,101,08dATSKGXhGASauLBtCoO8
1,101,09jtIFItoNKnC86zlzBZ29
2,101,0FDjpGYB8iVHXZiWY7E4AM
3,101,0KAaslGdPc5I6WxmKe3whe
4,101,0NWPxcsf5vdjdiFUI8NgkP
...,...,...
1402213,20100,7r1i1TZUGZQDxR5QHX4Mmx
1402214,20100,7svwP4tC0UYJbCkiCo6Itz
1402215,20100,7tVQg3ov9G0CnXTzqmZVsZ
1402216,20100,7yC5SaMeZJfvFL6lICCulP


Unnamed: 0,id,id_artist,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,1d3KXNYriNnjSdBcTBeam8,3TVifQ5FPcIzzcYSUuJkp9,0.648000,0.301,408000,0.268,0.000000,2,0.0638,-17.810,13,1977,0.0638,177.468,0.275
1,79pNwy5Q95QfeGUyGNXx11,3BCJyAgxvYyeIjQyoBU8XL,0.237000,0.462,336000,0.343,0.021400,9,0.0935,-15.477,10,1983,0.0287,103.597,0.342
2,6OoGZwmtpHuG6FIVkGjfKW,25uGmqV7NJt81bSYlEMKB0,0.092400,0.456,213467,0.932,0.004540,2,0.3220,-7.088,15,1979,0.0804,185.954,0.603
3,6aLmvz0CPeCNHCXK2H5QIC,3vbKDsSS70ZX9D2OcvbZmS,0.303000,0.611,343360,0.597,0.008200,0,0.0605,-7.864,56,1999,0.0764,147.507,0.385
4,63tInNaRRc44t4vxCQ65JA,5ujwfg1AKpM7CGPZhHxs22,0.798000,0.331,247987,0.293,0.000058,2,0.1030,-14.062,12,1980,0.0288,105.270,0.231
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
129643,7bwPQVVelqsawx9pA5gAwF,1owt6WYWjy94FlqNcj1x4U,0.907000,0.303,224160,0.292,0.841000,7,0.1760,-10.707,18,2004,0.0308,64.724,0.277
129644,3Ad4Vd2MLBBpApduPsydXk,32vWCbZh0xZ4o9gkz4PsEU,0.845000,0.536,165533,0.274,0.000089,0,0.1180,-10.758,40,1974,0.0281,86.915,0.462
129645,4z5eRXrNKYAhjtuJA49REb,5vngPClqofybhPERIqQMYd,0.423000,0.793,164133,0.544,0.000210,6,0.0830,-11.928,49,1992,0.0265,105.537,0.942
129646,0JiY190vktuhSGN6aqJdrt,1KCSPY1glIKqW2TotWuXOR,0.000329,0.534,215160,0.870,0.000000,11,0.2410,-3.078,78,2008,0.0425,126.019,0.462


In [37]:
# d = spark.sql(
#     """
#     SELECT s.user_id, s.track_id, s.weight, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence
#     FROM (
#         select user_id, track_id, sum(event_weight) as weight
#         from (
#             SELECT user_id, track_id, CASE WHEN event_type like 'like' THEN 1 ELSE 0.02 END as event_weight
#             FROM sessions
#             WHERE event_type like 'like' or event_type like 'play'
#             ) 
#         group by user_id, track_id
#     ) s
#     inner join tracks t on s.track_id = t.id
#     order by s.user_id, t.id
#     """).toPandas()
d = spark.sql(
    """
    SELECT s.user_id, s.track_id, s.weight, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence
    FROM (
        select user_id, track_id, sum(event_weight) as weight
        from (
            SELECT user_id, track_id, 1 as event_weight
            FROM sessions
            WHERE event_type like 'LIKE'
            ) 
        group by user_id, track_id
    ) s
    inner join tracks t on s.track_id = t.id
    order by s.user_id, t.id
    """).toPandas()
d

                                                                                

Unnamed: 0,user_id,track_id,weight,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,101,08dATSKGXhGASauLBtCoO8,1,0.882000,0.799,166693,0.559,0.381000,0,0.0996,-4.033,1,1944,0.0353,106.396,0.696
1,101,09jtIFItoNKnC86zlzBZ29,1,0.000480,0.509,254840,0.796,0.460000,2,0.1610,-11.728,34,1983,0.0359,143.664,0.652
2,101,0FDjpGYB8iVHXZiWY7E4AM,1,0.011000,0.503,214107,0.864,0.003820,5,0.5990,-6.829,34,1978,0.1620,154.979,0.551
3,101,0KAaslGdPc5I6WxmKe3whe,1,0.000101,0.169,314560,0.949,0.618000,7,0.2970,-10.596,41,1990,0.1140,159.678,0.147
4,101,0NWPxcsf5vdjdiFUI8NgkP,1,0.006030,0.346,210160,0.768,0.380000,9,0.0244,-5.695,72,1967,0.0377,169.492,0.532
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1402213,20100,7r1i1TZUGZQDxR5QHX4Mmx,1,0.002940,0.570,329979,0.777,0.181000,7,0.0477,-6.537,47,1987,0.0270,107.982,0.487
1402214,20100,7svwP4tC0UYJbCkiCo6Itz,1,0.062300,0.457,213000,0.915,0.000003,3,0.0778,-5.876,19,1976,0.0808,126.695,0.446
1402215,20100,7tVQg3ov9G0CnXTzqmZVsZ,1,0.242000,0.247,671760,0.742,0.370000,2,0.9680,-12.678,24,1978,0.0571,143.533,0.466
1402216,20100,7yC5SaMeZJfvFL6lICCulP,1,0.150000,0.417,176000,0.667,0.000000,4,0.2170,-6.404,21,1978,0.0367,131.796,0.556


In [38]:
from scipy import stats

users = d['user_id']
items = d['track_id']


# _features = pd.DataFrame(tracks['track_id'], columns=['track_id']).merge(tracks, left_on="track_id", right_on="id", how="inner", validate="1:1" )


features = pd.concat([stats.zscore(tracks.drop(['id', 'id_artist'], axis=1)), tracks['id_artist']], axis=1)
features


# _features

Unnamed: 0,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence,id_artist
0,0.686576,-1.658531,1.577080,-1.218303,-0.373482,-0.921565,-0.803961,-1.786179,-0.973348,-0.647710,-0.188174,1.953557,-1.141989,3TVifQ5FPcIzzcYSUuJkp9
1,-0.537912,-0.646673,0.944397,-0.908329,-0.281353,1.067707,-0.645053,-1.270175,-1.148502,-0.348766,-0.399408,-0.537601,-0.876726,3BCJyAgxvYyeIjQyoBU8XL
2,-0.968717,-0.684382,-0.132333,1.525999,-0.353937,-0.921565,0.577525,0.585271,-0.856578,-0.548062,-0.088274,2.239731,0.156612,25uGmqV7NJt81bSYlEMKB0
3,-0.341279,0.289767,1.009071,0.141449,-0.338180,-1.489929,-0.821618,0.413639,1.537200,0.448419,-0.112346,0.943179,-0.706483,3vbKDsSS70ZX9D2OcvbZmS
4,1.133469,-1.469986,0.171003,-1.114978,-0.373232,-0.921565,-0.594223,-0.957211,-1.031733,-0.498238,-0.398806,-0.481183,-1.316192,5ujwfg1AKpM7CGPZhHxs22
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
129643,1.458212,-1.645962,-0.038371,-1.119111,3.247084,0.499343,-0.203641,-0.215165,-0.681423,0.697539,-0.386770,-1.848519,-1.134071,1owt6WYWjy94FlqNcj1x4U
129644,1.273496,-0.181596,-0.553542,-1.193505,-0.373098,-1.489929,-0.513967,-0.226445,0.603043,-0.797183,-0.403019,-1.100170,-0.401628,32vWCbZh0xZ4o9gkz4PsEU
129645,0.016236,1.433607,-0.565844,-0.077599,-0.372578,0.215162,-0.701232,-0.485221,1.128506,0.099650,-0.412648,-0.472178,1.498763,5vngPClqofybhPERIqQMYd
129646,-1.243023,-0.194165,-0.117456,1.269754,-0.373482,1.636070,0.144139,1.472188,2.821667,0.896835,-0.316359,0.218538,-0.401628,1KCSPY1glIKqW2TotWuXOR


In [61]:
from lightfm.data import Dataset

# assert d['track_id'].unique().shape[0] == features.shape[0]

dataset = Dataset()
dataset.fit(
    users=users,
    items=items
)

dataset.fit_partial(
    items=tracks['id'],
    item_features=tracks.drop('id', axis=1)
)

num_users, num_items = dataset.interactions_shape()
print('Num users: {}, num_items {}.'.format(num_users, num_items))

Num users: 20000, num_items 129648.


In [62]:
(interactions, weights) = dataset.build_interactions(d[['user_id', 'track_id']].apply(tuple, axis=1))

print(repr(interactions))

<20000x129648 sparse matrix of type '<class 'numpy.int32'>'
	with 1402218 stored elements in COOrdinate format>


In [63]:
from lightfm.cross_validation import random_train_test_split

(train, test) = random_train_test_split(interactions)


In [64]:
from lightfm import LightFM
from lightfm.evaluation import auc_score, precision_at_k, recall_at_k, reciprocal_rank


In [65]:
feature_names = d.drop(['user_id', 'track_id', 'weight'], axis=1).columns

item_features = dataset.build_item_features(
    ((i, feature_names) for i in items), 
    normalize=False)
print(repr(item_features))

<129648x129662 sparse matrix of type '<class 'numpy.float32'>'
	with 268852 stored elements in Compressed Sparse Row format>


In [66]:
# warp, lr=0.01, no_comp=3-, aplha=1e-6, 5 epochs

model_h = LightFM(
    loss='warp',
    learning_rate=0.003,
    # item_alpha=1e-6,
    # user_alpha=1e-6,
    no_components=30,
)
model_h.fit(
    interactions=train,
    item_features=item_features,
    epochs=5,
    num_threads=12,
    verbose=True)


Epoch: 100%|██████████| 5/5 [00:13<00:00,  2.77s/it]


<lightfm.lightfm.LightFM at 0x7fe3f51b4130>

In [67]:
# train_auc_h = auc_score(model_h, train, 
#                         item_features=item_features,
#                         num_threads=12).mean()
# test_auc_h = auc_score(model_h, test, 
#                        train_interactions=train, 
#                        item_features=item_features, 
#                        num_threads=12).mean()
# 
# train_precision_h = precision_at_k(model_h, train, k=10,
#                                    item_features=item_features, 
#                                    num_threads=12).mean()
# test_precision_h = precision_at_k(model_h, test, k=10, 
#                                   train_interactions=train,
#                                   item_features=item_features, 
#                                   num_threads=12).mean()
# 
# train_recall_h = recall_at_k(model_h, train, k=10, 
#                              item_features=item_features,
#                              num_threads=12).mean()
# test_recall_h = recall_at_k(model_h, test, k=10, 
#                             train_interactions=train,
#                             item_features=item_features, 
#                             num_threads=12).mean()
# 
# train_reciprocal_rank_h = reciprocal_rank(model_h, train,
#                                           item_features=item_features,
#                                           num_threads=12).mean()
# test_reciprocal_rank_h = reciprocal_rank(model_h, test, 
#                                          train_interactions=train,
#                                          item_features=item_features,
#                                          num_threads=12).mean()
# 
# print('AUC: train %.6f, test %.6f.' % (train_auc_h, test_auc_h))
# print('Precision: train %.6f, test %.6f.' % (train_precision_h, test_precision_h))
# print('Recall: train %.6f, test %.6f.' % (train_recall_h, test_recall_h))
# print('Reciprocal rank: train %.6f, test %.6f.' % (train_reciprocal_rank_h, test_reciprocal_rank_h))


In [68]:
predicted_scores = model_h.predict(46, np.arange(interactions.shape[1]), 
                                   item_features=item_features, 
                                   num_threads=12)
predicted_scores

array([34548.434  , 50649.91   , 22206.781  , ...,  -547.18317,
        -283.6529 ,  -275.74194], dtype=float32)

In [69]:
indices = np.argpartition(predicted_scores, -1000)[-1000:]
indices

array([  394,  2771,  3252,  3689,  2384,   716,  4920,  2517,  4890,
        6832,  6664,  6366,  2750,  2606,  2568,  1217,  6472,  4892,
        5972,  2332,  1084,  4170,  5419,   741,  1869,  4225,  4338,
        3817,  1343,  3815,  6716,  8577,  4340,   377,   376,  6717,
        3808,  6727,   631,  5163,  1214,   938,  3801,  3798,  1351,
        6741,  3794,  3820,  3791,  1213,  1355,   983,  3778,  1360,
        9063,  1364,   984,  1368,  1369,  3755,  5169,  3743,  4345,
        3741,   366,  3739,   633,  1888,  1376,  4329,  4358,  5199,
        3831,  1380,  5200,   641,   472,  3714,  5057,  1205,  3711,
        3707,  3701,  6784,  4366,  1387,  1388,   358,  3691,  3851,
        3682,  3678,  3665,  3853,  5152,  5201,  8721,   643,  4379,
        1395,   356,   385,  1398,   355,  3628,  5204,  1399,  4380,
        3622,  6146,  3620,  8509,  8733,  3616,   926,   351,  6060,
        3610,  8735,  1404,  3604,  1194,  3600,   347,  3594,  5217,
        6824,  6825,

In [70]:
import pandas as pd

In [71]:
predicted = tracks.loc[indices]
predicted

Unnamed: 0,id,id_artist,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
394,56vivmQy42zMxqd4P8g3CR,5RK6c1tyaKpwcDpbgCGNgj,0.046100,0.389,149133,0.9730,0.000001,10,0.2870,-1.734,39,2002,0.0606,174.080,0.891
2771,3MIOwa113ZHD5qNbgF7JRZ,3RYYha3CC7js2PHbcBHewt,0.077400,0.819,226933,0.7040,0.000000,1,0.0502,-5.164,36,2015,0.1490,97.981,0.917
3252,6xZ6yd2vcUaQAhPt7V1whk,2aaLAng2L2aWD2FClzwiep,0.000556,0.367,725400,0.9440,0.005140,5,0.0404,-3.411,35,1997,0.0803,121.973,0.171
3689,7mW2krCDyLwIXn8sgde9wX,4F7Q5NV6h5TSwCainz8S5A,0.973000,0.705,193171,0.0886,0.192000,0,0.1120,-15.026,6,2014,0.0705,86.795,0.683
2384,1lI9AnJCdfFZTVUL2nOjdM,1ATDcv6wTF2U42HPB4qEFz,0.388000,0.677,220009,0.8340,0.000000,11,0.2180,-7.633,39,2010,0.0373,89.985,0.839
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5083,0p6jS1waDG2L49Bi3yGeEz,4zef2ByBl6wZGFPvYWve6o,0.986000,0.537,192492,0.4560,0.005550,0,0.0457,-4.739,9,1944,0.0366,88.927,0.622
5758,2YplrdHMBoRdnHgMeHEwHm,70cRZdQywnSFp9pnc2WTCE,0.170000,0.440,185253,0.4660,0.000001,6,0.1180,-9.712,55,1966,0.0284,107.744,0.543
3,6aLmvz0CPeCNHCXK2H5QIC,3vbKDsSS70ZX9D2OcvbZmS,0.303000,0.611,343360,0.5970,0.008200,0,0.0605,-7.864,56,1999,0.0764,147.507,0.385
1892,3JyvSSU0VnlMUsQckyEVfX,4ZgQDCtRqZlhLswVS6MHN4,0.000568,0.497,229531,0.7010,0.000287,7,0.3520,-6.834,67,2019,0.1190,76.918,0.180


In [72]:
user_history = d[d.user_id == 651]

In [73]:
set(user_history['track_id'])

{'0AIT0YZoeDXmi4jHygJhwb',
 '0BCIAYgnjkZ8Qu6jLNEALY',
 '0CoPoNohPHCjvggWHlxgk5',
 '0L9v0qr6ufK4Q1sVEchWYa',
 '0LrwgdLsFaWh9VXIjBRe8t',
 '0S90LE5Z8FOdbui3tLak6t',
 '0bHD1nLe7Nhw55ZGJ92332',
 '0hB7p4rUwkpVyNifHxTXXT',
 '0htsqdzg5LTve6NOS8fNjT',
 '0iA2xKtyWBiefuWFK5EcFP',
 '0j4srnVsqW8qXpZ5zlwzoI',
 '0mYC5tSzybVIm0VJNSdVUe',
 '0mxvBDxJNhrMEsfzAaegBJ',
 '0phT9WYvak5mt0lLe9nuKP',
 '0px1HGSfngr3nt5eOaZH4O',
 '0xcxK6blLdkzsRiw5kjqwx',
 '0y1kmUIVEt959sqpkrN6FP',
 '14H2fE4WEdkEem0kjVeODT',
 '172xKQCZJiwzaYUzs0idwP',
 '1DCeBMEXMOZr3OEGIbViiV',
 '1EGa3uFmvkGChrZeIKxGwB',
 '1PBHPRCdKWoALYc1lO6nDB',
 '1VVVFRDykk4pLQMGssCJWp',
 '1YrhoxYtegw7IDv7JHJTHl',
 '1cTv334fPuMrOq8HUmbvlg',
 '1e5JKmlFG3uSBYXDJLEc7H',
 '1fmym2sCaeWsWcEP3inQSQ',
 '1guIEbEw6v69ubNOuH9Aug',
 '1j59mXbJYgG28eLmE53jMG',
 '1k5Jr3Ir3tXJLrI5ZsJ0fo',
 '1lDIEArY9svbiOAfTx7iAs',
 '1oYeGnAbyODhUHO4jOMW9v',
 '20CvyqKfRsnFVChsAKYA9f',
 '20HhDn0i0STNmkjgFqmvO5',
 '2365Q0PQlSAzy9C7HwKZht',
 '23zzKOubAzAu8xlQdkVvZ2',
 '2A3sgyp07d7ksgjyBHYSUt',
 

In [74]:
len([i for i in predicted if i in user_history['track_id']])

0

In [75]:
set(predicted)

{'acousticness',
 'danceability',
 'duration_ms',
 'energy',
 'id',
 'id_artist',
 'instrumentalness',
 'key',
 'liveness',
 'loudness',
 'popularity',
 'release_year',
 'speechiness',
 'tempo',
 'valence'}

In [76]:
set(predicted).intersection(set(user_history['track_id']))

set()

In [77]:
print(len(set(predicted).intersection(set(user_history['track_id']))) / 1000)

0.0


In [78]:
coords = np.take(model_h.get_item_representations()[1], indices, axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

51.141342

In [79]:
coords = np.take(model_h.get_item_representations()[1], np.random.random_integers(low=0, high=len(model_h.get_item_representations()[1]), size=1000), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

  coords = np.take(model_h.get_item_representations()[1], np.random.random_integers(low=0, high=len(model_h.get_item_representations()[1]), size=1000), axis=0)


9.556648

In [80]:
def predict_single(user: int, number: int):
    predicted_scores = model_h.predict(user, np.arange(interactions.shape[1]), 
                                   item_features=item_features, 
                                   num_threads=12)
    indices = np.argpartition(predicted_scores, -number)[-number:]
    return tracks.loc[indices]['id']


def predict_multiple(users: list[int], number: int):
    if len(users) == 1:
        return frozenset(predict_single(users[0], number))
    
    predictions = [frozenset(predict_single(u, number)) for u in users]
    return predictions[0].intersection(*predictions[1:])

res= predict_multiple([123, 124], 100)
print(len(res), list(res)[:20])

set(predict_single(123, 100)).difference(set(predict_single(124, 100)))

100 ['7gA8MI78dKel6xn4Hf1nyA', '4LI4YCdKH2fLjyhtbtHXrY', '0xxqV4H2Yjb88HMiB18ZFA', '7HuQOzc3ZEUfSYoNbmu7Bp', '25tNFGmNYOWWo07kKMHhGS', '4xjqY6K2leiRFxpDZu9otC', '5VlUbdBoRs0HmISYpEhwNY', '69812nIHYl6n7etd1KTRt0', '4D4AIVc9iUP3JHUCqIqJpO', '4rS66iKCLms3yumMYACKwJ', '0lVkcHdJQ29VEoz0Rs4Owr', '1F1kllba7qMoii2hswILYw', '3TNp5a0TwJsKckiH9QBIVl', '1Oy571ZnyuXnnMlKjzLD1D', '5nJM1brsQmshrruVvINRKi', '0752HvJ08FSISAhCEn8iWC', '0ICwVvXOLKKzSvs3QKJtYx', '5LfQF0Pe6FKIPGHrDHijMN', '245kBlrNZlkuKvwF9O8RCa', '7lbpqAD1ELBS3ifiEBO34O']


set()

In [88]:
print(predict_single(123, 10))
print(predict_single(124, 10))

303    5LGfqAUio6YGSv04xLijU1
342    7HuQOzc3ZEUfSYoNbmu7Bp
157    55Xrxpi5WADlaefkGB6M96
326    7gA8MI78dKel6xn4Hf1nyA
345    030r88NUQtL1YYiHMnU35O
332    5nJM1brsQmshrruVvINRKi
108    3FMLd0xrILFklTmWpY6TsY
111    19aFpTwTMCv2RWVZWcamoI
312    0FMxHUwc6l3ew3MC2kDKLJ
344    2lgOmaRRwZCCL4l7h9UD3D
Name: id, dtype: object
134    6Fe1dz0yOdi7hVALU0T13t
310    5cWSpmTfItDuppACMrE3I9
342    7HuQOzc3ZEUfSYoNbmu7Bp
303    5LGfqAUio6YGSv04xLijU1
157    55Xrxpi5WADlaefkGB6M96
345    030r88NUQtL1YYiHMnU35O
108    3FMLd0xrILFklTmWpY6TsY
312    0FMxHUwc6l3ew3MC2kDKLJ
344    2lgOmaRRwZCCL4l7h9UD3D
111    19aFpTwTMCv2RWVZWcamoI
Name: id, dtype: object


In [117]:
model_1 = LightFM(
    loss='warp',
    learning_rate=0.1,
    # item_alpha=1e-6,
    # user_alpha=1e-6,
    no_components=30,
)
model_1.fit(
    interactions=interactions,
    item_features=item_features,
    epochs=20,
    num_threads=12,
    verbose=True)

Epoch: 100%|██████████| 20/20 [01:32<00:00,  4.62s/it]


<lightfm.lightfm.LightFM at 0x7fe3f7ffb400>

In [153]:
def predict_single1(user: int, number: int):
    predicted_scores = model_1.predict(user, np.arange(interactions.shape[1]), 
                                   item_features=item_features, 
                                   num_threads=12)
    indices = np.argpartition(predicted_scores, -number)[-number:]
    return tracks.loc[indices]['id']


def predict_multiple1(users: list[int], number: int):
    def predict_single_indices(user: int, number: int):
        predicted_scores = model_1.predict(user, np.arange(interactions.shape[1]), 
                                       item_features=item_features, 
                                       num_threads=12)
        return np.argpartition(predicted_scores, -number)[-number:]
    
    if len(users) == 1:
        return frozenset(predict_single_indices(users[0], number))
    
    predictions = [frozenset(predict_single_indices(u, number)) for u in users]
    return predictions[0].intersection(*predictions[1:])

res= predict_multiple1([123, 124], 100000)

print(res)
print(len(res))

frozenset({0, 3, 5, 7, 10, 14, 17, 18, 21, 24, 26, 28, 33, 34, 35, 38, 39, 41, 42, 43, 44, 47, 49, 50, 51, 58, 59, 60, 63, 65, 66, 69, 70, 73, 75, 76, 78, 79, 80, 83, 84, 88, 90, 91, 92, 93, 98, 103, 105, 147, 160, 162, 165, 166, 169, 171, 176, 181, 182, 183, 184, 187, 188, 191, 194, 197, 198, 199, 201, 203, 206, 208, 209, 211, 213, 215, 218, 220, 222, 223, 225, 226, 228, 231, 234, 236, 237, 239, 240, 242, 243, 245, 246, 249, 251, 252, 253, 255, 256, 257, 258, 259, 261, 262, 264, 268, 269, 272, 273, 276, 277, 279, 280, 281, 283, 285, 286, 287, 289, 290, 292, 294, 296, 299, 301, 348, 350, 351, 352, 354, 359, 360, 362, 365, 368, 370, 371, 372, 375, 380, 382, 383, 385, 386, 387, 388, 389, 391, 392, 393, 394, 395, 397, 398, 402, 404, 405, 406, 409, 410, 411, 413, 421, 425, 426, 427, 429, 430, 431, 434, 437, 440, 445, 448, 449, 450, 452, 454, 455, 457, 459, 460, 461, 466, 467, 469, 470, 471, 473, 474, 475, 478, 480, 481, 482, 484, 491, 494, 495, 496, 497, 498, 501, 502, 503, 504, 505, 506, 

In [154]:
print(predict_single1(123, 10))
print(predict_single1(124, 10))

123408    6vkrqAU2dJEnIl1ao3o5MZ
58663     0aPgN3PJhimxMirIDfkizv
121433    2GJ2BXD803LJEBTgpbf3VJ
82909     0C9kbOQoKzJBhG2GGCl6Yw
42923     2gduBnzi43leGE8EzdaLfn
43543     2fItqNfO1O5iBI4Jg2EXyR
105576    0nCPl0QueSCyeEcWD99qpy
55340     6RRC0SYvyaDc4HYr1rHnjS
63035     2lGUCMcXm7SoOHPMqRZ2pC
46817     7L5oDq4OUFTw63xEquduUG
Name: id, dtype: object
11930     6Gyk7ZHfFWo3d8U7poUEPs
125905    6SNHVUVv1OcB19HbDDyvSj
49619     7EZ6Wb86O8XHCHyGhVZmTo
12058     5i2vqEeJiR5sJpH0BDhFWa
99931     0ySDQy0j7sT2Iu0YUOFjuu
41111     6qViB4rjq6SihbQZIbITnz
66878     4lPml9N28tsDhK9yf8Khjk
63442     795SWRJqgJRQd5MzkfDCPe
17640     7xTeKXVH4bW3Ioq0gXCnjG
112123    7ouKT7R4BgwvAtoyzR8gCc
Name: id, dtype: object


In [159]:
coords = np.take(model_1.get_item_representations()[1], list(res), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

1766116.9877091162

In [156]:
coords = np.take(model_1.get_item_representations()[1], np.random.random_integers(low=0, high=len(model_1.get_item_representations()[1]), size=100), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

  coords = np.take(model_1.get_item_representations()[1], np.random.random_integers(low=0, high=len(model_1.get_item_representations()[1]), size=100), axis=0)


68386904.0

In [157]:
filtered = pd.concat(
    (d.loc[d['user_id'] == 123], d.loc[d['user_id'] == 124], d.loc[d['user_id'] == 125])).drop_duplicates(subset=['track_id'])

tracks_extended = pd.merge(tracks, filtered, how='inner', left_on='id', right_on='track_id', validate="1:1")
random_indices = tracks[tracks.id.isin(tracks_extended['track_id'])].index
random_indices

Index([   192,    676,    853,   2235,   3982,   4716,   5015,   6324,   6402,
         7244,
       ...
       121547, 121760, 121986, 122985, 123020, 123891, 126264, 126548, 126648,
       129342],
      dtype='int64', length=127)

In [158]:
coords = np.take(model_1.get_item_representations()[1], list(random_indices), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

79403090.0