Build dataset


In [4]:
from typing import TypedDict, Type, Any, Callable

import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from spark.config import views
from spark.create_session import create_session

from IPython.display import display

from fitter import Fitter, get_common_distributions

In [5]:
VIEWS = views("stolen")
spark = create_session()

for view, file in VIEWS.items():
    df = spark.read.json(file)
    df.createOrReplaceTempView(view)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/12/27 17:52:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/12/27 17:52:30 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
                                                                                

In [6]:
sessions_full = spark.sql(f"SELECT * FROM sessions").toPandas()
sessions = spark.sql(f"SELECT DISTINCT user_id, track_id FROM sessions WHERE event_type like 'LIKE' order by user_id, track_id").toPandas()
tracks = spark.sql(
    f"SELECT DISTINCT id, id_artist, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence FROM tracks ").toPandas()

                                                                                

In [7]:
display(sessions_full)
display(sessions)
display(tracks)

Unnamed: 0,event_type,session_id,timestamp,track_id,user_id
0,PLAY,124,2020-04-17T16:43:09,5EmL6IbswQGhfH9AX7ezWd,101
1,LIKE,124,2020-04-17T16:43:55.237000,5EmL6IbswQGhfH9AX7ezWd,101
2,PLAY,124,2020-04-17T16:45:44.733000,67ov0nL5eR7zdx0JfXDqro,101
3,SKIP,124,2020-04-17T16:48:26.836000,67ov0nL5eR7zdx0JfXDqro,101
4,ADVERTISEMENT,124,2020-04-17T16:48:26.836000,,101
...,...,...,...,...,...
10191757,LIKE,269652,2022-12-19T01:03:08.170000,0mW8oH0PZejrFsaNi1Ud8i,20100
10191758,PLAY,269652,2022-12-19T01:05:49.228000,3tOS5wMzYU5tiVztKKKHih,20100
10191759,SKIP,269652,2022-12-19T01:07:15.125000,3tOS5wMzYU5tiVztKKKHih,20100
10191760,PLAY,269652,2022-12-19T01:07:15.125000,5fWqmzQHxmfBVxf2k6JlX4,20100


Unnamed: 0,user_id,track_id
0,101,08dATSKGXhGASauLBtCoO8
1,101,09jtIFItoNKnC86zlzBZ29
2,101,0FDjpGYB8iVHXZiWY7E4AM
3,101,0KAaslGdPc5I6WxmKe3whe
4,101,0NWPxcsf5vdjdiFUI8NgkP
...,...,...
1402213,20100,7r1i1TZUGZQDxR5QHX4Mmx
1402214,20100,7svwP4tC0UYJbCkiCo6Itz
1402215,20100,7tVQg3ov9G0CnXTzqmZVsZ
1402216,20100,7yC5SaMeZJfvFL6lICCulP


Unnamed: 0,id,id_artist,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,1d3KXNYriNnjSdBcTBeam8,3TVifQ5FPcIzzcYSUuJkp9,0.648000,0.301,408000,0.268,0.000000,2,0.0638,-17.810,13,1977,0.0638,177.468,0.275
1,79pNwy5Q95QfeGUyGNXx11,3BCJyAgxvYyeIjQyoBU8XL,0.237000,0.462,336000,0.343,0.021400,9,0.0935,-15.477,10,1983,0.0287,103.597,0.342
2,6OoGZwmtpHuG6FIVkGjfKW,25uGmqV7NJt81bSYlEMKB0,0.092400,0.456,213467,0.932,0.004540,2,0.3220,-7.088,15,1979,0.0804,185.954,0.603
3,6aLmvz0CPeCNHCXK2H5QIC,3vbKDsSS70ZX9D2OcvbZmS,0.303000,0.611,343360,0.597,0.008200,0,0.0605,-7.864,56,1999,0.0764,147.507,0.385
4,63tInNaRRc44t4vxCQ65JA,5ujwfg1AKpM7CGPZhHxs22,0.798000,0.331,247987,0.293,0.000058,2,0.1030,-14.062,12,1980,0.0288,105.270,0.231
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
129643,7bwPQVVelqsawx9pA5gAwF,1owt6WYWjy94FlqNcj1x4U,0.907000,0.303,224160,0.292,0.841000,7,0.1760,-10.707,18,2004,0.0308,64.724,0.277
129644,3Ad4Vd2MLBBpApduPsydXk,32vWCbZh0xZ4o9gkz4PsEU,0.845000,0.536,165533,0.274,0.000089,0,0.1180,-10.758,40,1974,0.0281,86.915,0.462
129645,4z5eRXrNKYAhjtuJA49REb,5vngPClqofybhPERIqQMYd,0.423000,0.793,164133,0.544,0.000210,6,0.0830,-11.928,49,1992,0.0265,105.537,0.942
129646,0JiY190vktuhSGN6aqJdrt,1KCSPY1glIKqW2TotWuXOR,0.000329,0.534,215160,0.870,0.000000,11,0.2410,-3.078,78,2008,0.0425,126.019,0.462


In [8]:
# d = spark.sql(
#     """
#     SELECT s.user_id, s.track_id, s.weight, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence
#     FROM (
#         select user_id, track_id, sum(event_weight) as weight
#         from (
#             SELECT user_id, track_id, CASE WHEN event_type like 'like' THEN 1 ELSE 0.02 END as event_weight
#             FROM sessions
#             WHERE event_type like 'like' or event_type like 'play'
#             ) 
#         group by user_id, track_id
#     ) s
#     inner join tracks t on s.track_id = t.id
#     order by s.user_id, t.id
#     """).toPandas()
d = spark.sql(
    """
    SELECT s.user_id, s.track_id, s.weight, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence
    FROM (
        select user_id, track_id, sum(event_weight) as weight
        from (
            SELECT user_id, track_id, 1 as event_weight
            FROM sessions
            WHERE event_type like 'LIKE'
            ) 
        group by user_id, track_id
    ) s
    inner join tracks t on s.track_id = t.id
    order by s.user_id, t.id
    """).toPandas()
d

                                                                                

Unnamed: 0,user_id,track_id,weight,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,101,08dATSKGXhGASauLBtCoO8,1,0.882000,0.799,166693,0.559,0.381000,0,0.0996,-4.033,1,1944,0.0353,106.396,0.696
1,101,09jtIFItoNKnC86zlzBZ29,1,0.000480,0.509,254840,0.796,0.460000,2,0.1610,-11.728,34,1983,0.0359,143.664,0.652
2,101,0FDjpGYB8iVHXZiWY7E4AM,1,0.011000,0.503,214107,0.864,0.003820,5,0.5990,-6.829,34,1978,0.1620,154.979,0.551
3,101,0KAaslGdPc5I6WxmKe3whe,1,0.000101,0.169,314560,0.949,0.618000,7,0.2970,-10.596,41,1990,0.1140,159.678,0.147
4,101,0NWPxcsf5vdjdiFUI8NgkP,1,0.006030,0.346,210160,0.768,0.380000,9,0.0244,-5.695,72,1967,0.0377,169.492,0.532
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1402213,20100,7r1i1TZUGZQDxR5QHX4Mmx,1,0.002940,0.570,329979,0.777,0.181000,7,0.0477,-6.537,47,1987,0.0270,107.982,0.487
1402214,20100,7svwP4tC0UYJbCkiCo6Itz,1,0.062300,0.457,213000,0.915,0.000003,3,0.0778,-5.876,19,1976,0.0808,126.695,0.446
1402215,20100,7tVQg3ov9G0CnXTzqmZVsZ,1,0.242000,0.247,671760,0.742,0.370000,2,0.9680,-12.678,24,1978,0.0571,143.533,0.466
1402216,20100,7yC5SaMeZJfvFL6lICCulP,1,0.150000,0.417,176000,0.667,0.000000,4,0.2170,-6.404,21,1978,0.0367,131.796,0.556


In [197]:
from scipy import stats

users = d['user_id']
items = d['track_id']


# _features = pd.DataFrame(tracks['track_id'], columns=['track_id']).merge(tracks, left_on="track_id", right_on="id", how="inner", validate="1:1" )



# _features

pandas.core.series.Series

In [285]:
from lightfm.data import Dataset

# assert d['track_id'].unique().shape[0] == features.shape[0]

dataset = Dataset()
dataset.fit(
    users=users,
    items=tracks['id']
)

dataset.fit_partial(
    items=tracks['id'],
    item_features=tracks.drop('id', axis=1)
)

num_users, num_items = dataset.interactions_shape()
print('Num users: {}, num_items {}.'.format(num_users, num_items))

Num users: 20000, num_items 129648.


In [312]:
(interactions, weights) = dataset.build_interactions(d[['user_id', 'track_id']].apply(tuple, axis=1))

print(repr(interactions))

<20000x129648 sparse matrix of type '<class 'numpy.int32'>'
	with 1402218 stored elements in COOrdinate format>


In [12]:
from lightfm.cross_validation import random_train_test_split

(train, test) = random_train_test_split(interactions)


In [13]:
from lightfm import LightFM
from lightfm.evaluation import auc_score, precision_at_k, recall_at_k, reciprocal_rank


In [313]:
feature_names = d.drop(['user_id', 'track_id', 'weight'], axis=1).columns

item_features = dataset.build_item_features(
    ((i, feature_names) for i in items), 
    normalize=False)
print(repr(item_features))

<129648x129662 sparse matrix of type '<class 'numpy.float32'>'
	with 268852 stored elements in Compressed Sparse Row format>


In [16]:
# warp, lr=0.01, no_comp=3-, aplha=1e-6, 5 epochs

model_h = LightFM(
    loss='warp',
    learning_rate=0.003,
    # item_alpha=1e-6,
    # user_alpha=1e-6,
    no_components=30,
)
model_h.fit(
    interactions=train,
    item_features=item_features,
    epochs=5,
    num_threads=12,
    verbose=True)


Epoch: 100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


<lightfm.lightfm.LightFM at 0x7fbf9e1784f0>

In [14]:
train_auc_h = auc_score(model_h, train, 
                        item_features=item_features,
                        num_threads=12).mean()
test_auc_h = auc_score(model_h, test, 
                       train_interactions=train, 
                       item_features=item_features, 
                       num_threads=12).mean()

train_precision_h = precision_at_k(model_h, train, k=10,
                                   item_features=item_features, 
                                   num_threads=12).mean()
test_precision_h = precision_at_k(model_h, test, k=10, 
                                  train_interactions=train,
                                  item_features=item_features, 
                                  num_threads=12).mean()

train_recall_h = recall_at_k(model_h, train, k=10, 
                             item_features=item_features,
                             num_threads=12).mean()
test_recall_h = recall_at_k(model_h, test, k=10, 
                            train_interactions=train,
                            item_features=item_features, 
                            num_threads=12).mean()

train_reciprocal_rank_h = reciprocal_rank(model_h, train,
                                          item_features=item_features,
                                          num_threads=12).mean()
test_reciprocal_rank_h = reciprocal_rank(model_h, test, 
                                         train_interactions=train,
                                         item_features=item_features,
                                         num_threads=12).mean()

print('AUC: train %.6f, test %.6f.' % (train_auc_h, test_auc_h))
print('Precision: train %.6f, test %.6f.' % (train_precision_h, test_precision_h))
print('Recall: train %.6f, test %.6f.' % (train_recall_h, test_recall_h))
print('Reciprocal rank: train %.6f, test %.6f.' % (train_reciprocal_rank_h, test_reciprocal_rank_h))


AUC: train 0.899586, test 0.899762.
Precision: train 0.159950, test 0.069711.
Recall: train 0.042400, test 0.070126.
Reciprocal rank: train 0.252052, test 0.147722.


In [266]:
predicted_scores = model_h.predict(46, np.arange(interactions.shape[1]), 
                                   item_features=item_features, 
                                   num_threads=12)
predicted_scores

array([ 2.1368147e+03,  3.2872285e+03,  1.3767339e+03, ...,
       -1.8519495e+01,  1.1457134e-02, -1.4957775e-02], dtype=float32)

In [267]:
indices = np.argpartition(predicted_scores, -1000)[-1000:]
indices

array([ 5649,  4000,  7698,  2238,  8073,  3496,   849,  4594,  1736,
        1128,  5371,  6404,   920,  3309,  4568,  3280,   361,  3245,
        3222,  7409,  6815,  8598,  4129,  5906,  2123,  3164,  6200,
         537,  2626,   238,  1485,  7949,  7353,  1073,  2651,  7878,
         196,  1862,  1505,  5508,  1518,  3958,  9049,  2704,  2713,
        1623,  1787,    66,  2886,  7753,   797,  2768,  1299,  9443,
        1961,  1408,  2406,  4683,  4682,  7249,  1194,  4687,  6270,
        1869,  7239,  2196,  7229,   796,   795,  3972,  4694,  1205,
        4696,  7223,  6263,  3990,  1190,  4679,  4703,  1213,  4004,
        1214,  2190,  4713,  7303,  2212,  7181,  7307,  2213,  4677,
        2214,  1856,   772,  7314,  7316,  1853,  3948,   767,  6312,
         825,  7333,  6211,  2221,  7334,  4014,  2222,  1226,  7344,
        4728,  7178,  1228,  7173,  7171,  7366,  1842,  7370,  7371,
         748,  3932,  2229,  7375,  1841,  7170,  4663,  7383,  2185,
        1836,   742,

In [268]:
import pandas as pd

In [269]:
predicted = tracks.loc[indices]
predicted

Unnamed: 0,id,id_artist,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
5649,1Ga880sLVmlPm53x2AQEvo,1luOe8HkZQ7zwuaO2wuJqI,0.000758,0.386,357760,0.9300,0.001870,11,0.1360,-3.633,33,2012,0.1020,150.073,0.266
4000,21Zn4iMPBEuGtukbb1KpGz,0nJUwPwC9Ti4vvuJ0q3MfT,0.907000,0.533,357400,0.2000,0.010500,0,0.2330,-17.792,37,1991,0.0710,149.007,0.490
7698,3jrQvfuHv7dTHAYH14wCWP,3pBvCynt1s5QElgsbyFdTg,0.012000,0.666,312427,0.6840,0.000005,8,0.0737,-10.247,26,2005,0.0255,113.913,0.856
2238,75d5rdmbGxE0JRXYc6EJR8,2GaayiIs1kcyNqRXQuzp35,0.914000,0.663,169640,0.4370,0.857000,9,0.1330,-17.023,21,1960,0.0328,124.445,0.703
8073,6vIPZobhXhE19i9EJlRj0w,380DW51qbu5pSP8crFRIII,0.193000,0.432,292813,0.6820,0.000256,6,0.0651,-5.331,29,2002,0.0276,176.237,0.507
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5370,387oNtAhlKTbeswo91MRAy,66aVJ4ocSmKcdrRUCXR8j0,0.123000,0.704,198547,0.6950,0.035900,7,0.5330,-8.313,42,1993,0.0556,97.561,0.598
3,6aLmvz0CPeCNHCXK2H5QIC,3vbKDsSS70ZX9D2OcvbZmS,0.303000,0.611,343360,0.5970,0.008200,0,0.0605,-7.864,56,1999,0.0764,147.507,0.385
2802,5B1W3KrkWlwSJlSufX5AT7,6l40OFJhuTbHQ9V12evc9K,0.965000,0.627,252000,0.1060,0.782000,2,0.7490,-19.705,7,1954,0.0495,101.195,0.394
1575,4XWItTtGvpXeCFBduYg0F6,3XYnsz11kFrkx8F3upvnE5,0.966000,0.561,408267,0.0422,0.899000,3,0.1250,-18.982,18,1956,0.0547,123.394,0.128


In [270]:
user_history = d[d.user_id == 651]

In [271]:
set(user_history['track_id'])

{'0AIT0YZoeDXmi4jHygJhwb',
 '0BCIAYgnjkZ8Qu6jLNEALY',
 '0CoPoNohPHCjvggWHlxgk5',
 '0L9v0qr6ufK4Q1sVEchWYa',
 '0LrwgdLsFaWh9VXIjBRe8t',
 '0S90LE5Z8FOdbui3tLak6t',
 '0bHD1nLe7Nhw55ZGJ92332',
 '0hB7p4rUwkpVyNifHxTXXT',
 '0htsqdzg5LTve6NOS8fNjT',
 '0iA2xKtyWBiefuWFK5EcFP',
 '0j4srnVsqW8qXpZ5zlwzoI',
 '0mYC5tSzybVIm0VJNSdVUe',
 '0mxvBDxJNhrMEsfzAaegBJ',
 '0phT9WYvak5mt0lLe9nuKP',
 '0px1HGSfngr3nt5eOaZH4O',
 '0xcxK6blLdkzsRiw5kjqwx',
 '0y1kmUIVEt959sqpkrN6FP',
 '14H2fE4WEdkEem0kjVeODT',
 '172xKQCZJiwzaYUzs0idwP',
 '1DCeBMEXMOZr3OEGIbViiV',
 '1EGa3uFmvkGChrZeIKxGwB',
 '1PBHPRCdKWoALYc1lO6nDB',
 '1VVVFRDykk4pLQMGssCJWp',
 '1YrhoxYtegw7IDv7JHJTHl',
 '1cTv334fPuMrOq8HUmbvlg',
 '1e5JKmlFG3uSBYXDJLEc7H',
 '1fmym2sCaeWsWcEP3inQSQ',
 '1guIEbEw6v69ubNOuH9Aug',
 '1j59mXbJYgG28eLmE53jMG',
 '1k5Jr3Ir3tXJLrI5ZsJ0fo',
 '1lDIEArY9svbiOAfTx7iAs',
 '1oYeGnAbyODhUHO4jOMW9v',
 '20CvyqKfRsnFVChsAKYA9f',
 '20HhDn0i0STNmkjgFqmvO5',
 '2365Q0PQlSAzy9C7HwKZht',
 '23zzKOubAzAu8xlQdkVvZ2',
 '2A3sgyp07d7ksgjyBHYSUt',
 

In [272]:
len([i for i in predicted if i in user_history['track_id']])

0

In [273]:
set(predicted)

{'acousticness',
 'danceability',
 'duration_ms',
 'energy',
 'id',
 'id_artist',
 'instrumentalness',
 'key',
 'liveness',
 'loudness',
 'popularity',
 'release_year',
 'speechiness',
 'tempo',
 'valence'}

In [274]:
set(predicted).intersection(set(user_history['track_id']))

set()

In [275]:
print(len(set(predicted).intersection(set(user_history['track_id']))) / 1000)

0.0


In [276]:
coords = np.take(model_h.get_item_representations()[1], indices, axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

2.3959072

In [277]:
coords = np.take(model_h.get_item_representations()[1], np.random.random_integers(low=0, high=len(model_h.get_item_representations()[1]), size=1000), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

  coords = np.take(model_h.get_item_representations()[1], np.random.random_integers(low=0, high=len(model_h.get_item_representations()[1]), size=1000), axis=0)


4.5423965

In [278]:
def predict_single(user: int, number: int):
    predicted_scores = model_h.predict(user, np.arange(interactions.shape[1]), 
                                   item_features=item_features, 
                                   num_threads=12)
    indices = np.argpartition(predicted_scores, -number)[-number:]
    return tracks.loc[indices]['id']


def predict_multiple(users: list[int], number: int):
    if len(users) == 1:
        return frozenset(predict_single(users[0], number))
    
    predictions = [frozenset(predict_single(u, number)) for u in users]
    return predictions[0].intersection(*predictions[1:])

res= predict_multiple([123, 124], 100)
print(len(res), list(res)[:20])

set(predict_single(123, 100)).difference(set(predict_single(124, 100)))

100 ['1TvQRwofUzq2U4Q9jbbXXi', '55Xrxpi5WADlaefkGB6M96', '4rS66iKCLms3yumMYACKwJ', '0lVkcHdJQ29VEoz0Rs4Owr', '0bzTbaANQ33XXQG3d4C7Ho', '27awWR6uOBSBYlDPLKyt46', '4kYKr9MQu23Y936UZTw3op', '1QUjZcrx5VRZM23y5iu72w', '3FMLd0xrILFklTmWpY6TsY', '2rCacchqI6MrR1E7toRZrK', '6XoaOnQ4PzEGtXQ8eWMkNE', '2o8U5FPafHmHuVGd3suyv7', '2hSdHOOBLXNkh7zmmgoAn3', '1jhc9VMo0Lq0VQ4hsX3KMH', '3aone8W3MeRtNOXY2c6J9P', '38JSatkJEuhIKODWyJhZxq', '3yRNGKhJwZ8u0VqQWwp1Kz', '6SyLK7iMViKdsYWuOWX21L', '1Oy571ZnyuXnnMlKjzLD1D', '2XGCCel9b78Gkkfu5c3j2g']


set()

In [314]:
model_1 = LightFM(
    loss='warp',
    learning_rate=0.003,
    # item_alpha=1e-6,
    # user_alpha=1e-6,
    no_components=30,
)
model_1.fit(
    interactions=interactions,
    item_features=item_features,
    epochs=5,
    num_threads=12,
    verbose=True)

Epoch: 100%|██████████| 5/5 [00:11<00:00,  2.30s/it]


<lightfm.lightfm.LightFM at 0x7fbecc6cf1f0>

In [315]:
def predict_single1(user: int, number: int):
    predicted_scores = model_1.predict(user, np.arange(interactions.shape[1]), 
                                   item_features=item_features, 
                                   num_threads=12)
    indices = np.argpartition(predicted_scores, -number)[-number:]
    return tracks.loc[indices]['id']


def predict_multiple1(users: list[int], number: int):
    def predict_single_indices(user: int, number: int):
        predicted_scores = model_1.predict(user, np.arange(interactions.shape[1]), 
                                       item_features=item_features, 
                                       num_threads=12)
        return np.argpartition(predicted_scores, -number)[-number:]
    
    if len(users) == 1:
        return frozenset(predict_single_indices(users[0], number))
    
    predictions = [frozenset(predict_single_indices(u, number)) for u in users]
    return predictions[0].intersection(*predictions[1:])

res= predict_multiple1([123, 124, 125], 44100)

print(res)
print(len(res))

frozenset({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,

In [316]:
predict_single1(123, 10)

134    6Fe1dz0yOdi7hVALU0T13t
310    5cWSpmTfItDuppACMrE3I9
303    5LGfqAUio6YGSv04xLijU1
342    7HuQOzc3ZEUfSYoNbmu7Bp
157    55Xrxpi5WADlaefkGB6M96
108    3FMLd0xrILFklTmWpY6TsY
345    030r88NUQtL1YYiHMnU35O
312    0FMxHUwc6l3ew3MC2kDKLJ
111    19aFpTwTMCv2RWVZWcamoI
344    2lgOmaRRwZCCL4l7h9UD3D
Name: id, dtype: object

In [317]:
predict_single1(124, 10)

134    6Fe1dz0yOdi7hVALU0T13t
310    5cWSpmTfItDuppACMrE3I9
303    5LGfqAUio6YGSv04xLijU1
342    7HuQOzc3ZEUfSYoNbmu7Bp
157    55Xrxpi5WADlaefkGB6M96
345    030r88NUQtL1YYiHMnU35O
312    0FMxHUwc6l3ew3MC2kDKLJ
108    3FMLd0xrILFklTmWpY6TsY
344    2lgOmaRRwZCCL4l7h9UD3D
111    19aFpTwTMCv2RWVZWcamoI
Name: id, dtype: object

In [318]:
coords = np.take(model_1.get_item_representations()[1], list(res), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

1.4206786

In [319]:
coords = np.take(model_1.get_item_representations()[1], np.random.randint(low=0, high=len(model_1.get_item_representations()[1]), size=100), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

12.770634

In [320]:
filtered = pd.concat(
    (d.loc[d['user_id'] == 123], d.loc[d['user_id'] == 124], d.loc[d['user_id'] == 125])).drop_duplicates(subset=['track_id'])

tracks_extended = pd.merge(tracks, filtered, how='inner', left_on='id', right_on='track_id', validate="1:1")
random_indices = tracks[tracks.id.isin(tracks_extended['track_id'])].index
random_indices

Index([   192,    676,    853,   2235,   3982,   4716,   5015,   6324,   6402,
         7244,
       ...
       121547, 121760, 121986, 122985, 123020, 123891, 126264, 126548, 126648,
       129342],
      dtype='int64', length=127)

In [321]:
coords = np.take(model_1.get_item_representations()[1], list(random_indices), axis=0)
center = np.sum(coords, axis=0) / coords.shape[0]
mean_dist = np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))
mean_dist

12.530965