# Model usage example

In [1]:
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

import pandas as pd
import numpy as np

from spark.config import views
from spark.create_session import create_session

from IPython.display import display


Create Spark session and load the data. Files should be stored in the `data/{version}` directory under project root.

In [2]:
VIEWS = views("v4")
spark = create_session()

for view, file in VIEWS.items():
    df = spark.read.json(file)
    df.createOrReplaceTempView(view)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/01/19 23:54:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/01/19 23:54:17 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
                                                                                

Query for users, tracks and sessions. Sessions query is performed here only to preview the data, a more complex sessions query will be performed later on.

In [3]:
sessions_full = spark.sql(f"SELECT * FROM sessions").toPandas()
sessions = spark.sql(
    f"SELECT DISTINCT user_id, track_id FROM sessions WHERE event_type like 'like' order by user_id, track_id").toPandas()
_tracks = spark.sql(
    f"SELECT DISTINCT id, id_artist, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence FROM tracks ").toPandas()
users = spark.sql(f"SELECT user_id FROM users").toPandas()

                                                                                

In [4]:
display(sessions_full)
display(sessions)
display(_tracks)
display(users)

Unnamed: 0,event_type,session_id,timestamp,track_id,user_id
0,play,124,2023-05-30T23:37:17,0qbV4e18lWrTTxlswVoLbI,101
1,advertisment,124,2023-05-30T23:44:37.941000,,101
2,like,124,2023-05-30T23:44:37.941000,0qbV4e18lWrTTxlswVoLbI,101
3,skip,124,2023-05-30T23:44:46.212000,0qbV4e18lWrTTxlswVoLbI,101
4,play,124,2023-05-30T23:44:51.212000,3pyTksNccLM1jRvzQ4zTke,101
...,...,...,...,...,...
12340426,play,878516,2022-12-27T11:20:54.236000,2nxzEMUmIhSJeGBaXgh30x,5100
12340427,play,878516,2022-12-27T11:26:11.409000,19K3lUMJmOdeuOBTrbLm19,5100
12340428,play,878516,2022-12-27T11:30:27.182000,1PhLYngBKbeDtdmDzCg3Pb,5100
12340429,like,878516,2022-12-27T11:32:22.571000,1PhLYngBKbeDtdmDzCg3Pb,5100


Unnamed: 0,user_id,track_id
0,101,01q4ccXbvPlCwZ1fPiFaeM
1,101,02ePjHjIiszSYqeLykvpTN
2,101,02ppMPbg1OtEdHgoPqoqju
3,101,03EnOL1O9EKi9CFNmPyrCm
4,101,0534jmQ0dYChW5MSzYXNVr
...,...,...
2568073,5100,7xD7MvjAdZkx1YICschIuI
2568074,5100,7xa3dJQMzBVzsrZ81tNcHP
2568075,5100,7yYvvOB7CuzdVldb6zOk1m
2568076,5100,7ygpwy2qP3NbrxVkHvUhXY


Unnamed: 0,id,id_artist,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,4y8icjzu6fZP503Mg31Tpn,5V0MlUE1Bft0mbLlND7FJz,0.87800,0.184,224227,0.292,0.000258,7,0.3460,-12.246,55,1956,0.0312,75.630,0.1430
1,4hHbeIIKO5Y5uLyIEbY9Gn,1Mxqyy3pSjf8kZZL4QVxS0,0.84500,0.574,199093,0.338,0.000000,6,0.1650,-11.376,69,1958,0.0420,67.008,0.4930
2,0KSHmjK7OFtGocvbo7NZNO,6kACVPfCOnqzgfEF5ryl0x,0.80300,0.754,107280,0.484,0.000009,10,0.0989,-9.196,52,1959,0.0609,82.761,0.8010
3,30V02AmDYMRvrHE4L8cZAo,22bE4uQ6baNwSHPVcDxLCe,0.17800,0.578,133320,0.429,0.000060,2,0.1710,-9.601,57,1965,0.0268,108.738,0.0889
4,2CQRYn5cTD2B9a1ONjhTN2,3oDbviiivRWhXwIE8hxkVV,0.72500,0.396,145800,0.686,0.000001,5,0.1700,-5.607,57,1966,0.0341,124.318,0.7170
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22407,3atrFJhUzDm5xiutqfEPkP,74XFHRwlV6OrjEM0A2NCMF,0.00576,0.489,218333,0.611,0.000008,6,0.0917,-5.675,52,2007,0.0337,140.956,0.1860
22408,3wYzd4dzWknPdgwNaPGQvg,7JthQ6zwNzfxRfIEjp6wUs,0.06960,0.607,191474,0.605,0.000000,3,0.4540,-6.050,65,2021,0.1590,83.913,0.6650
22409,6UIcN1tiiGdd7oMMzNvyaP,7lbSsjYACZHn1MSDXPxNF2,0.06010,0.479,297502,0.478,0.000007,5,0.0813,-8.884,53,2020,0.1100,100.040,0.1170
22410,62i2LI6iDmAHbE9H9wa99z,7k73EtZwoPs516ZxE72KsO,0.38500,0.470,216970,0.738,0.000621,8,0.0656,-4.641,52,2017,0.0398,97.973,0.4880


Unnamed: 0,user_id
0,101
1,102
2,103
3,104
4,105
...,...
4995,5096
4996,5097
4997,5098
4998,5099


Standardize all numeric features in `tracks` ($\mu=0$, $\sigma=1$).

In [5]:
from scipy import stats

# normalize track features
tracks = pd.concat([_tracks[['id', 'id_artist']], _tracks.drop(['id', 'id_artist'], axis=1).apply(stats.zscore)],
                   axis=1)
display(tracks)

Unnamed: 0,id,id_artist,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,4y8icjzu6fZP503Mg31Tpn,5V0MlUE1Bft0mbLlND7FJz,2.210183,-2.622956,-0.081994,-1.698778,-0.261491,0.483649,0.910200,-1.350924,-0.792191,-3.602614,-0.579005,-1.562809,-1.533923
1,4hHbeIIKO5Y5uLyIEbY9Gn,1Mxqyy3pSjf8kZZL4QVxS0,2.090663,-0.158580,-0.430623,-1.479284,-0.263179,0.202650,-0.167568,-1.118180,0.947358,-3.458942,-0.458085,-1.853949,-0.102263
2,0KSHmjK7OFtGocvbo7NZNO,6kACVPfCOnqzgfEF5ryl0x,1.938547,0.978824,-1.704140,-0.782628,-0.263120,1.326645,-0.561162,-0.534983,-1.164951,-3.387106,-0.246475,-1.322015,1.157598
3,30V02AmDYMRvrHE4L8cZAo,22bE4uQ6baNwSHPVcDxLCe,-0.325086,-0.133305,-1.342945,-1.045067,-0.262786,-0.921346,-0.131841,-0.643329,-0.543684,-2.956090,-0.628269,-0.444846,-1.755216
4,2CQRYn5cTD2B9a1ONjhTN2,3oDbviiivRWhXwIE8hxkVV,1.656046,-1.283347,-1.169838,0.181238,-0.263172,-0.078349,-0.137796,0.425153,-0.543684,-2.884254,-0.546536,0.081246,0.814000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22407,3atrFJhUzDm5xiutqfEPkP,74XFHRwlV6OrjEM0A2NCMF,-0.948906,-0.695688,-0.163749,-0.176633,-0.263129,0.202650,-0.604035,0.406961,-1.164951,0.061025,-0.551015,0.643063,-1.358033
22408,3wYzd4dzWknPdgwNaPGQvg,7JthQ6zwNzfxRfIEjp6wUs,-0.717690,0.049944,-0.536304,-0.205263,-0.263179,-0.640347,1.553288,0.306641,0.450344,1.066730,0.851883,-1.283116,0.601296
22409,6UIcN1tiiGdd7oMMzNvyaP,7lbSsjYACZHn1MSDXPxNF2,-0.752097,-0.758877,0.934387,-0.811258,-0.263132,-0.078349,-0.665962,-0.451516,-1.040697,0.994894,0.303264,-0.738553,-1.640275
22410,62i2LI6iDmAHbE9H9wa99z,7k73EtZwoPs516ZxE72KsO,0.424630,-0.815747,-0.182655,0.429362,-0.259116,0.764647,-0.759448,0.683579,-1.164951,0.779385,-0.482717,-0.808350,-0.122715


Query for sessions:
- Group user-track interactions (this can be used to create event weights later on)
- Exclude tracks from `sessions.jsonl` not present in `tracks.jsonl`
- Filter event type

In [6]:
d = spark.sql(
    """
    SELECT s.user_id, s.track_id, s.weight, acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, popularity, EXTRACT(year from `release_date`) as release_year, speechiness, tempo, valence
    FROM (
        select user_id, track_id, sum(event_weight) as weight
        from (
            SELECT user_id, track_id, 1 as event_weight
            FROM sessions
            WHERE event_type like 'like'
            ) 
        group by user_id, track_id
    ) s
    inner join tracks t on s.track_id = t.id
    order by s.user_id, t.id
    """).toPandas()
d

                                                                                

Unnamed: 0,user_id,track_id,weight,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,popularity,release_year,speechiness,tempo,valence
0,101,01q4ccXbvPlCwZ1fPiFaeM,1,0.126000,0.680,122760,0.626,0.000001,7,0.0995,-8.519,57,1987,0.0255,104.333,0.960
1,101,02ePjHjIiszSYqeLykvpTN,2,0.159000,0.487,220920,0.909,0.000001,9,0.3030,-3.484,54,1995,0.0463,125.418,0.510
2,101,02ppMPbg1OtEdHgoPqoqju,1,0.511000,0.523,290213,0.656,0.160000,4,0.0679,-7.441,63,2007,0.0262,104.271,0.214
3,101,03EnOL1O9EKi9CFNmPyrCm,1,0.497000,0.555,355640,0.460,0.000000,9,0.1030,-7.032,69,2005,0.0277,67.027,0.146
4,101,0534jmQ0dYChW5MSzYXNVr,1,0.019400,0.590,265960,0.711,0.000000,0,0.1220,-7.589,52,2011,0.0247,143.937,0.847
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2568073,5100,7xD7MvjAdZkx1YICschIuI,1,0.663000,0.544,224667,0.884,0.000007,8,0.3810,-8.166,53,1976,0.1440,127.215,0.467
2568074,5100,7xa3dJQMzBVzsrZ81tNcHP,2,0.000094,0.319,203653,0.965,0.002210,7,0.1410,-4.354,52,1992,0.0646,127.721,0.624
2568075,5100,7yYvvOB7CuzdVldb6zOk1m,1,0.001240,0.659,288973,0.846,0.000000,7,0.0415,-7.579,52,1998,0.1110,92.803,0.595
2568076,5100,7ygpwy2qP3NbrxVkHvUhXY,1,0.000509,0.375,258773,0.893,0.000000,2,0.1870,-4.097,76,1995,0.0406,174.323,0.434


Create a LightFm `Dataset`.

In [7]:
from lightfm.data import Dataset

dataset = Dataset()
# all user and track ids
dataset.fit(
    users=users['user_id'],
    items=tracks['id']
)

dataset.fit_partial(
    items=tracks['id'],
    # must follow the same order
    item_features=tracks.drop('id', axis=1)
)

num_users, num_items = dataset.interactions_shape()
print('Num users: {}, num_items {}.'.format(num_users, num_items))

Num users: 5000, num_items 22412.


Build the interaction matrix from user-track tuples. Weights can also be supplied.

In [8]:
(interactions, weights) = dataset.build_interactions(d[['user_id', 'track_id']].apply(tuple, axis=1))

print(repr(interactions))

<5000x22412 sparse matrix of type '<class 'numpy.int32'>'
	with 2568078 stored elements in COOrdinate format>


Split the dataset 80:20 for training and testing.

In [9]:
from lightfm.cross_validation import random_train_test_split

(train, test) = random_train_test_split(interactions)

In [10]:
from lightfm import LightFM
from lightfm.evaluation import auc_score, precision_at_k, recall_at_k, reciprocal_rank

Build an item feature matrix. This is a LightFm specific step. As we already passed feature values, we only need to supply item-feature name list tuples. LightFM uses a normalization mechanism for feature values which scales them to ensure all weights in every row sum up to 1 - we probably do not want this to happen. 

In [11]:
feature_names = tracks.drop(['id'], axis=1).columns

item_features = dataset.build_item_features(
    # track id + list of feature name pairs
    ((i, feature_names) for i in tracks['id']),
    normalize=False)
print(repr(item_features))

<22412x22426 sparse matrix of type '<class 'numpy.float32'>'
	with 336180 stored elements in Compressed Sparse Row format>


Train the model.

In [12]:
model_h = LightFM(
    loss='warp',
    learning_rate=0.034,
    item_alpha=1e-3,
    user_alpha=1e-3,
    no_components=60,
)
model_h.fit(
    interactions=train,
    item_features=item_features,
    epochs=10,
    num_threads=12,
    verbose=True)

Epoch: 100%|██████████| 10/10 [04:05<00:00, 24.51s/it]


<lightfm.lightfm.LightFM at 0x7f1f1ddf9240>

Evaluate metrics for single-user predictions. Train set is also included for sanity-checking (plus, to compare how well the model is able to generalize).

In [13]:
print("0/8 done...", end="")
train_auc_h = auc_score(model_h, train,
                        item_features=item_features,
                        num_threads=12).mean()
print("\r1/8 done...", end="")
test_auc_h = auc_score(model_h, test,
                       train_interactions=train,
                       item_features=item_features,
                       num_threads=12).mean()
print("\r2/8 done...", end="")

train_precision_h = precision_at_k(model_h, train, k=10,
                                   item_features=item_features,
                                   num_threads=12).mean()
print("\r3/8 done...", end="")
test_precision_h = precision_at_k(model_h, test, k=10,
                                  train_interactions=train,
                                  item_features=item_features,
                                  num_threads=12).mean()
print("\r4/8 done...", end="")

train_recall_h = recall_at_k(model_h, train, k=10,
                             item_features=item_features,
                             num_threads=12).mean()
print("\r5/8 done...", end="")
test_recall_h = recall_at_k(model_h, test, k=10,
                            train_interactions=train,
                            item_features=item_features,
                            num_threads=12).mean()
print("\r6/8 done...", end="")

train_reciprocal_rank_h = reciprocal_rank(model_h, train,
                                          item_features=item_features,
                                          num_threads=12).mean()
print("\r7/8 done...", end="")
test_reciprocal_rank_h = reciprocal_rank(model_h, test,
                                         train_interactions=train,
                                         item_features=item_features,
                                         num_threads=12).mean()
print("\r8/8 done...")

print('AUC: train %.6f, test %.6f.' % (train_auc_h, test_auc_h))
print('Precision: train %.6f, test %.6f.' % (train_precision_h, test_precision_h))
print('Recall: train %.6f, test %.6f.' % (train_recall_h, test_recall_h))
print('Reciprocal rank: train %.6f, test %.6f.' % (train_reciprocal_rank_h, test_reciprocal_rank_h))

8/8 done...
AUC: train 0.937602, test 0.934772.
Precision: train 0.350780, test 0.233520.
Recall: train 0.024163, test 0.096696.
Reciprocal rank: train 0.539575, test 0.346861.


Process raw scores returned by the model to get predicted track ids for both single-user and multi-user predictions.


**NOTE: these methods work a little different from recommender protocol**

In [14]:
def predict_single(user: int, number: int) -> pd.Series:
    """
    Recommends `number` of tracks to user of ID `user`.
    
    Args:
        user (int): user ID
        number (int): number of recommendations to make
        
    Returns:
         pd.Series: Predicted tracks. The series is indexed with track indices in item list passed when training model, series values are track ID's. 
    """
    predicted_scores = model_h.predict(user, np.arange(interactions.shape[1]),
                                       item_features=item_features,
                                       num_threads=12)

    return pd.concat([tracks['id'], pd.Series(predicted_scores)], axis=1).nlargest(number, 0).sort_values(by=0,
                                                                                                          ascending=False)[
        'id']


def predict_multiple(users: list[int], number: int) -> pd.Series:
    """
    Recommend `number` of tracks to users of ID's `users`.
        
    Args:
        users (list[int]): user ID's (at least two)
        number (int): number of recommendations to make
        
    Returns:
         pd.Series: Predicted tracks. The series is indexed with track ID's (NOTE - different from `predict_single`), series values are mean ranks of items across single-user prediction for all users. The series is ordered by mean ranks.
         
    Raises:
        ValueError: If 0 or 1 user ID supplied.
    """
    if len(users) <= 1:
        raise ValueError("To make single-user predictions, use `predict_single`")

    predictions = [predict_single(u, tracks.size) for u in users]
    common = frozenset(predictions[0]).intersection(*(frozenset(p) for p in predictions[1:]))

    indexed_by_rank = (s.reset_index(drop=True).loc[s.reset_index().id.isin(common)].sort_values() for s in predictions)

    return pd.concat([pd.Series(s.index.values, index=s) for s in indexed_by_rank], axis=1).mean('columns').nsmallest(
        number)


predict_multiple([123, 124, 125, 126, 127, 128], 20)


id
2iUmqdfGZcHIhS3b9E9EWq    1792.833333
0K1KOCeJBj3lpDYxEX9qP2    1827.666667
3UH4JIDuP83866Y43bbo4k    1919.333333
562oeuAN5GH85iE7FQVKSB    2014.666667
5aV4HUW9RFOB0aXq0Ud9s0    2065.000000
7iPlcFvOMOzt6v0QvcAueZ    2094.333333
6HguG9HRb1Ke1bhihfE4m8    2115.000000
1fBl642IhJOE5U319Gy2Go    2122.000000
6C88rHxXBlpcgtBY3HAF0E    2177.166667
2TgxCUZdHFkPEVmFge1OSd    2222.666667
4rmIfFUZhhi9sS5IYtpkXw    2236.000000
5WSdMcWTKRdN1QYVJHJWxz    2271.666667
06cCziAHtDg6pcsidZHu03    2326.833333
6WmIyn2fx1PKQ0XDpYj4VR    2329.333333
3TgMcrV32NUKjEG2ujn9eh    2343.500000
5It7Lw1E8VCSyS6AQv6nPm    2353.833333
4w0joPRhKJsrzq1euxkpZ2    2361.833333
7h6lpVuSGPW6RNjDXKpYDh    2373.000000
0l4DTppOxy7NUaEcwXuOb6    2397.500000
67ispalOAo3jnZWYSTTfZk    2427.500000
dtype: float64

Make some test predictions for single users. They should not overlap too much.

In [15]:
n = 2000
r1 = predict_single(123, n)
r2 = predict_single(124, n)
print(r1)
print(r2)
overlap = set(r1).intersection(set(r2))
print(overlap)
print(f"{len(overlap)}/{n} ({len(overlap) * 100 / n:.2f}%) items overlaping")

11333    2mt1IqcFyY1zmYZT8Q3xw9
13775    2xLMifQCjDGFmkHkpNLD9h
12090    3aQem4jVGdhtg116TmJnHz
10938    7qEHsqek33rTcFNT9PFqLf
12380    1XXimziG1uhM0eDNCZCrUl
                  ...          
11315    2lnzGkdtDj5mtlcOW2yRtG
21921    0Dv4DdNy1xnhNcCd9YVoiH
11179    1QFh8OH1e78dGd3VyJZCAC
3578     2APxH5XbOLDTmW9X5yi6Hw
3234     6GcuA4J9ruyClBizBd4m5E
Name: id, Length: 2000, dtype: object
2371     30RKGpKHqr7ytcTljw436k
6994     7As8h8LJTMIritB8QwSmqr
9958     1brpdmqkx3kSxyqzqXfW7J
8252     6F5RvdxWcLOJW7LAGuc24r
10704    3vhKrSxe3fRuS5Ogis76VO
                  ...          
9920     48lQegoLqGAzaRLnMwK0mO
10108    1qCe9qukzqAwCs08AGfzJG
20630    45DElIx0dXqUH4A88yQFdE
12753    5X6HkkTe8mUwkHo3Lccr6E
2882     5e9TFTbltYBg2xThimr0rU
Name: id, Length: 2000, dtype: object
{'5Fli1xRi01bvCjsZvKWro0', '2FZ0yrA5aPClG5ZPBlV7n4', '2wmr2A1ixEx9qViaOMkkkW', '2CbtdkBeW9Znt4vXTOafAl', '30RKGpKHqr7ytcTljw436k', '7MtBapy0MnmBVsU3o0yrFz', '4PDNOmMej8sc1mkiZfvW9w', '0SUClY63fA1awioMFtMYeE', '7tvOve5Ikj

Sanity check - make sure the model does not return the same items for every user.

In [16]:
r1_ = predict_single(123, 10)
r2_ = predict_single(124, 10)
assert len(set(r1_).intersection(set(r2_))) <= 1
print("Sanity check PASS")

Sanity check PASS


Evaluate the model on multi-user predictions. To compare results, measure the mean distance to center of gravity predicted items' latent representations generated by the model.

In [17]:
N = 20
REPS = 100

res_m = predict_multiple([123, 124, 125, 126, 127, 128], N)
print(res_m)

# LightFM maps provided track ID's to its own internal mapping (indices in interaction matrix). To find latent vector representations for predicted items, we first need to remap ID's to internal indices.
(user_id_mapping, _, item_id_mapping, _) = dataset.mapping()


def mean_dist_from_cluster_center(items_ids, model: LightFM):
    item_indices = [item_id_mapping[i] for i in items_ids]
    coords = np.take(model.get_item_representations()[1], item_indices, axis=0)
    center = np.sum(coords, axis=0) / coords.shape[0]
    return np.average(np.apply_along_axis(lambda x: np.linalg.norm(center - x, ord=2), 1, coords))


dist_predicted = mean_dist_from_cluster_center(res_m.index, model_h)

# draw random samples from all tracks
random_items = tracks['id'].sample(N)
dist_random = sum([mean_dist_from_cluster_center(
    tracks['id'].sample(N),
    model_h) for _ in range(REPS)]) / REPS

# draw random samples from combined users' session history
dist_sampled = sum([mean_dist_from_cluster_center(
    pd.concat(
        (d.loc[d['user_id'] == i] for i in (123, 124, 125, 126, 127, 128))
    )['track_id'].sample(N),
    model_h) for _ in range(REPS)]) / REPS

display(pd.DataFrame(
    data={
        "method": ["Predicted by model", "Sampled from history", "Entirely random"],
        "mean distance": [dist_predicted, dist_sampled, dist_random]
    }
))

id
2iUmqdfGZcHIhS3b9E9EWq    1792.833333
0K1KOCeJBj3lpDYxEX9qP2    1827.666667
3UH4JIDuP83866Y43bbo4k    1919.333333
562oeuAN5GH85iE7FQVKSB    2014.666667
5aV4HUW9RFOB0aXq0Ud9s0    2065.000000
7iPlcFvOMOzt6v0QvcAueZ    2094.333333
6HguG9HRb1Ke1bhihfE4m8    2115.000000
1fBl642IhJOE5U319Gy2Go    2122.000000
6C88rHxXBlpcgtBY3HAF0E    2177.166667
2TgxCUZdHFkPEVmFge1OSd    2222.666667
4rmIfFUZhhi9sS5IYtpkXw    2236.000000
5WSdMcWTKRdN1QYVJHJWxz    2271.666667
06cCziAHtDg6pcsidZHu03    2326.833333
6WmIyn2fx1PKQ0XDpYj4VR    2329.333333
3TgMcrV32NUKjEG2ujn9eh    2343.500000
5It7Lw1E8VCSyS6AQv6nPm    2353.833333
4w0joPRhKJsrzq1euxkpZ2    2361.833333
7h6lpVuSGPW6RNjDXKpYDh    2373.000000
0l4DTppOxy7NUaEcwXuOb6    2397.500000
67ispalOAo3jnZWYSTTfZk    2427.500000
dtype: float64


Unnamed: 0,method,mean distance
0,Predicted by model,0.362724
1,Sampled from history,0.737306
2,Entirely random,0.495509


Note that entirely randomly chosen items may have better clustering, which is expected - they are spread relatively evenly throughout the entire space. What's important is whether model predictions are more clustered than random samples from users' session history - if so, the model managed to find a "common ground" between different music tastes.

# The same example using model classes

Both models match `model.model_protocol.Recommender` protocol - it can be used for type checking.

In [19]:
from model.base_model import BaseModel
from model.fm_model import FMModel

base_model = BaseModel(
    users_df=users,
    tracks_df=tracks,
    sessions_df=d
)

fm_model = FMModel(
    users_df=users,
    tracks_df=tracks,
    sessions_df=d,
    lightfm_model=model_h,
    interactions=interactions,
    item_features=item_features,
    user_id_mapping=user_id_mapping
)

In [20]:
base_model.predict_single(123, N)

0     4oY2T9ur7Ll5b2kpBlcWcb
1     0xVrusjXiWhqSQ5dPoiiWX
2     2NlmmAjGYrrjAp0MED5rGx
3     2fZQIJew3nkNe99s2PKzul
4     3PfIrDoz19wz7qK7tYeu62
5     5UgsZiYk1lkEobuPHmRtWm
6     06u5LrUpbosQlQ1QJFhPpG
7     7cNz65PfCatRXoX7QtqM2A
8     2rIqgWl4riTsgSeyYYM2cl
9     0i2xIlZCO0H3mSpxr2HzgK
10    7zMcNqs55Mxer82bvZFkpg
11    7AQjiRtIpr33P8UT98iveh
12    2wSQyp6VzUopSFBinRo1iD
13    47Slg6LuqLaX0VodpSCvPt
14    1x6jHJGczTBitBy06hIWgx
15    3ESSGgWzRf1xvP7G5hHMhB
16    2MfOcbtgz2yTsiznFmVZUN
17    0xMd5bcWTbyXS7wPrBtZA6
18    6RtzmszGXs32TRhv5zTKNM
19    4dvQg9sD8k9y4qiEURuj8v
Name: track_id, dtype: object

In [21]:
fm_model.predict_single(123, N)

0     5Lgcn7u07bHuqbOtXkN62u
1     5NijSs5dAwaIybq1GaRTIe
2     28GUjBGqZVcAV4PHSYzkj2
3     3ugAI8vlmz6y4T9BlKhytz
4     7mYvtEeBdMqRSyj1Qpv6my
5     3x4Zx0rPoW6kErSDK8SSnW
6     1MaI6NwdrqnE3mRzOYTpoo
7     4squZv12LD9M8ooJfoVgZS
8     3jagJCUbdqhDSPuxP8cAqF
9     5XqDJFVCyRTm5J7cIfRmR1
10    0At2qAoaVjIwWNAqrscXli
11    3ojTJaonfkL96iIWa47SU3
12    590ZheHFaC3JsZBwL8otpk
13    2jdw2tc29bqJwToyGvKgJm
14    53ETzLQQKFCzykRbqWb1ph
15    02UJ1sCanP94fS2MdsWafh
16    3BFrR2hemusNSuiuNWXUqO
17    2heln0PA9q54ngSUjsPnvg
18    2771LMNxwf62FTAdpJMQfM
19    1XXimziG1uhM0eDNCZCrUl
Name: id, dtype: object

In [22]:
res_base = base_model.predict_multiple((123, 124, 125, 126, 127, 128), N)
display(res_base)

0     5Eg4TsPcqNbIjd8ADMZosg
1     5FF6Pyuh5jW6ybKHiaiGMy
2     06RdYCp0UxsBtWsonHfSZz
3     0j1Ia2lQWrcXrQZI4AdJlk
4     0URCTFCWXOg3JQximz1u30
5     06UPCXzhIsXnceSXmKLMEY
6     5MbXzXGbqobR8xPVPs8OXA
7     7iAqvWLgZzXvH38lA06QZg
8     4k7P2LIQco8YCVbIZl1vZB
9     7npLlaPu9Mfno8hjk5OagD
10    24qoogZedlX1wdR4AbzTSB
11    1Bx0zEdVjkFlV27iKaePug
12    0yJi7eb2SosK5CsSnnqc5o
13    0Z5ok0QLLttAKsujOZYOXf
14    3nnG7AM9QopHVPEuLX3Khk
15    3fNaQSMgFrYTkTshRX7J1u
16    2ofOe2OaXFpZF5ETbsc7Qu
17    1tZcw7GtIqviL32bzaKdSo
18    7GkM8M2EC8OJblOxQxAR7t
19    1JIzFhI9Lt5FyslawmHCBi
Name: track_id, dtype: object

In [23]:
res_fm = fm_model.predict_multiple((123, 124, 125, 126, 127, 128), N)
display(res_fm)

0     562oeuAN5GH85iE7FQVKSB
1     5pbajJXEPdcoXQPXoAVR1t
2     4gs07VlJST4bdxGbBsXVue
3     6C88rHxXBlpcgtBY3HAF0E
4     3UH4JIDuP83866Y43bbo4k
5     2WwzQJt4hG7YC6x16ZTYFM
6     0tuyEYTaqLxE41yGHSsXjy
7     0RZyUsKfiC7MtiGKatCtGc
8     0HRshWRNAwQBROvxXqG3i9
9     13HVjjWUZFaWilh2QUJKsP
10    2iUmqdfGZcHIhS3b9E9EWq
11    3PSMcb1gU5A8DveqU2K4z2
12    5FPnjikbwlDMULCCCa6ZCJ
13    5WSdMcWTKRdN1QYVJHJWxz
14    5rwdhliMmo0aAQ08vU0AOZ
15    2Xs64pHlU29DTVMjWKyblt
16    3bXhtg6H8lOMWaLZttQF6F
17    77Y57qRJBvkGCUw9qs0qMg
18    1Kvbih7Ebm4bkPinpSottk
19    7MRn6wgG0ReDRNYV5wJeGX
dtype: object

In [24]:
dist_predicted1 = mean_dist_from_cluster_center(res_fm, fm_model.model)
dist_sampled1 = sum([mean_dist_from_cluster_center(
    pd.concat(
        (d.loc[d['user_id'] == i] for i in (123, 124, 125, 126, 127, 128))
    )['track_id'].sample(N),
    fm_model.model) for _ in range(REPS)]) / REPS

display(pd.DataFrame(
    data={
        "method": ["Predicted by model", "Sampled from history"],
        "mean distance": [dist_predicted1, dist_sampled1]
    }
))

Unnamed: 0,method,mean distance
0,Predicted by model,0.539144
1,Sampled from history,0.738088


## Model serialization

In [25]:
import pickle

In [26]:
mean_dist_from_cluster_center(res_fm, fm_model.model)

0.5391442

In [27]:
fm_model_deserialized = pickle.loads(pickle.dumps(fm_model))
res_fm_deserialized = fm_model_deserialized.predict_multiple((123, 124, 125, 126, 127, 128), N)
mean_dist_from_cluster_center(res_fm_deserialized, fm_model_deserialized.model)


0.5391442

In [28]:
print(auc_score(fm_model.model, test,
                train_interactions=train,
                item_features=item_features,
                num_threads=12).mean())
print(auc_score(fm_model_deserialized.model, test,
                train_interactions=train,
                item_features=item_features,
                num_threads=12).mean())

0.9347719
0.9347719
