 # <b><i> Testing LightFM </i> </b>

 # > Import

In [1]:
import os
import pickle
from datetime import datetime

import numpy as np
from lightfm import LightFM
from lightfm.cross_validation import random_train_test_split
from lightfm.evaluation import precision_at_k
from scipy.sparse import identity
from sklearn.model_selection import train_test_split
from tqdm import tqdm

 # > Config

In [2]:
# Identifier of the training run whose artefacts are evaluated in this notebook.
TEST_CODE = "1561003029.019894"
# Epoch of the checkpoint to load from that run.
CHOSEN_EPOCH = 600

# Everything belonging to one run lives under its own log directory.
LOG_DIR = "../log/{}".format(TEST_CODE)
MODEL_PATH = "{}/models/epoch_{}".format(LOG_DIR, CHOSEN_EPOCH)
LOG_PATH = "{}/log.txt".format(LOG_DIR)

 # > Preparation

 ## >> Load dataset

In [3]:
ratings_pivot_csr_filename = "../data/intersect-20m/ratings.csr"

# Fixed seed so the train/test partition is identical on every run.
SPLIT_SEED = 42

# NOTE: unpickling can execute arbitrary code; only load files you trust.
# The context manager closes the file handle as soon as loading finishes.
with open(ratings_pivot_csr_filename, 'rb') as ratings_file:
    ratings_pivot = pickle.load(ratings_file)

train, test = random_train_test_split(
    ratings_pivot,
    test_percentage=0.2,
    random_state=np.random.RandomState(SPLIT_SEED),
)

# CSR copies of both splits, used for per-user row lookups during evaluation.
train_csr = train.tocsr()
test_csr = test.tocsr()
# Row / column coordinates of the non-zero entries of the test split.
test_user, test_item = test.nonzero()

 ## >> Models

In [4]:
# NOTE: unpickling can execute arbitrary code; only load checkpoints you trust.
# The context manager closes the file handle as soon as loading finishes.
with open(MODEL_PATH, 'rb') as model_file:
    model = pickle.load(model_file)

 ## >> User & item features

In [5]:
# One indicator feature per user / per item (identity matrices sized from the
# train split). They are built once as float32 CSR.
# NOTE(review): LightFM is understood to convert feature matrices to float32 CSR
# inside every predict() call, so building them in that form up front should
# avoid repeating the conversion per user - confirm against the installed version.
user_identity = identity(train.shape[0], dtype=np.float32, format="csr")
item_identity = identity(train.shape[1], dtype=np.float32, format="csr")

 # > Evaluation

 ## >> Evaluation function

In [6]:
def get_top_suggestion(sample_user, k):
    """Return the ``k`` highest-scored items for one user.

    Scores every item column of ``test`` with the module-level ``model``
    (using ``user_identity`` / ``item_identity`` as features) and ranks them.

    Parameters
    ----------
    sample_user : int
        Row index of the user to score.
    k : int
        Number of leading entries to keep (applied as a slice bound).

    Returns
    -------
    list of (score, item_index) tuples, ordered by score descending; ties are
    ordered by item index descending.
    """
    # Item ids are simply the column positions of the interaction matrix.
    item_ids = np.arange(test.shape[1], dtype=np.int32)

    scores = np.asarray(
        model.predict(
            user_ids=sample_user,
            item_ids=item_ids,
            user_features=user_identity,
            item_features=item_identity,
        )
    )

    # lexsort orders by the last key first: ascending score, then ascending
    # item id. Reversing yields descending score with descending item id on
    # ties, which is the order produced by sorting (score, item) tuples in
    # reverse. Doing it in NumPy avoids building and sorting one Python tuple
    # per item on every call.
    ranking = np.lexsort((item_ids, scores))[::-1][:k]

    return [(scores[position], int(position)) for position in ranking]


def get_top_truth(sample_user, k):
    """Return the ``k`` highest-rated items of one user.

    Ratings are collected from both ``test_csr`` and ``train_csr``.

    NOTE(review): because the train split is included, items the model may
    have been fitted on also count as ground truth - confirm this is intended.

    Parameters
    ----------
    sample_user : int
        Row index of the user.
    k : int
        Number of leading entries to keep (applied as a slice bound).

    Returns
    -------
    list of (rating, item_index) tuples, ordered by rating descending; ties are
    ordered by item index descending.
    """
    truth = []

    for split in (test_csr, train_csr):
        # Row slicing yields a fresh one-row matrix, so canonicalising it in
        # place cannot affect the shared split.
        user_row = split[sample_user]
        user_row.sum_duplicates()

        # Read only the stored entries instead of densifying the whole row;
        # explicitly stored zeros are skipped, as nonzero() would skip them.
        rated = user_row.data != 0
        truth.extend(zip(user_row.data[rated].tolist(), user_row.indices[rated]))

    truth.sort(reverse=True)

    return truth[:k]

In [7]:
def _get_intersect_pred_truth(prediction, truth, k):
    pred_item_set = {x[1] for x in prediction[:k]}
    truth_item_set = {x[1] for x in truth[:k]}

    return pred_item_set.intersection(truth_item_set)

In [8]:
def check_precision(prediction, truth, k=10):
    """Score one user's suggestions against their ground truth.

    Returns a tuple ``(hits, score)`` where ``hits`` is the set of item ids
    present in the first ``k`` entries of both lists and ``score`` is the
    number of hits divided by ``len(truth)`` when the truth list is non-empty
    and no longer than ``k``, and by ``k`` otherwise.
    """
    hits = _get_intersect_pred_truth(prediction, truth, k)

    n_truth = len(truth)
    if 0 < n_truth <= k:
        denominator = n_truth
    else:
        denominator = k

    return hits, len(hits) / denominator

 ## >> Run Evaluation

In [9]:
k_suggestion = 10
n_users = 5000

# Fixed seed so the same users are drawn on every run.
SAMPLE_SEED = 42
sampling_rng = np.random.RandomState(SAMPLE_SEED)

# The exclusive upper bound comes from the interaction matrix itself, so every
# sampled index is a valid row (a hard-coded bound can drift from the data).
# NOTE(review): the lower bound of 1 is kept as it was, so row 0 is never
# sampled - confirm whether row 0 holds a real user.
sample_user = sampling_rng.randint(1, train.shape[0], n_users)  # sampling (with replacement)
# sample_user = list(range(0, 15000))  # uncomment to use non sampling

# Per-user results, appended in iteration order for users that were scored.
suggested_items = []
truth_items = []
intersects = []
scores = []

# Running intersection / union of the suggested item sets across users;
# both stay None until the first user has been scored successfully.
all_intersect = None
all_union = None

for user in tqdm(sample_user):

    try:

        top_suggestions = get_top_suggestion(user, k_suggestion)
        top_suggested_items = {suggestion[1] for suggestion in top_suggestions}
        top_truth_items = get_top_truth(user, k_suggestion)

        intersect, score = check_precision(top_suggestions, top_truth_items, k=k_suggestion)

        suggested_items.append(top_suggested_items)
        truth_items.append(top_truth_items)
        intersects.append(intersect)
        scores.append(score)

        # intersection()/union() return new sets, so the per-user set stored
        # in suggested_items is never mutated.
        if all_intersect is None:
            all_intersect = top_suggested_items
        else:
            all_intersect = all_intersect.intersection(top_suggested_items)

        if all_union is None:
            all_union = top_suggested_items
        else:
            all_union = all_union.union(top_suggested_items)

    except Exception as e:
        # Best effort: report the failing user and carry on with the rest.
        print("error occur for {} : {}".format(user, e))


100%|██████████| 5000/5000 [02:41<00:00, 30.97it/s]


In [10]:
# Mean of the per-user scores collected by the evaluation loop.
print("Prec@k score:", np.mean(scores))

# Items suggested to every evaluated user, then items suggested to at least one.
for heading, item_set in (("\nintersect", all_intersect), ("\nunion", all_union)):
    print(heading)
    print(item_set, len(item_set))

# Distinct suggested items as a share of all suggestion slots requested.
print("\ndistinct rate")
total_slots = n_users * k_suggestion
print(len(all_union) / total_slots)

Prec@k score: 0.10204222222222223

intersect
set() 0

union
{15106, 4354, 11529, 3594, 5129, 14604, 11533, 8975, 11536, 13591, 2073, 1567, 5668, 550, 1321, 10796, 8750, 1839, 2606, 8497, 7476, 12342, 3135, 65, 13122, 8259, 8001, 8517, 10050, 4681, 3401, 1611, 3663, 3921, 13906, 594, 14932, 14162, 12378, 10330, 13405, 1886, 9056, 13152, 6247, 12394, 7027, 628, 8052, 9334, 10103, 7286, 8823, 8317, 9342, 12926, 7304, 12681, 2187, 5260, 911, 13712, 1168, 4499, 10646, 15258, 2972, 11677, 7074, 10659, 676, 1187, 13224, 6825, 2223, 433, 8115, 14334, 14774, 14522, 13243, 5058, 8389, 14025, 14795, 3788, 11981, 8918, 8662, 4827, 13276, 13277, 11230, 13535, 14560, 9441, 9443, 5348, 11238, 11497, 4586, 1003, 491, 13551, 13310, 9727} 106

distinct rate
0.00212


In [11]:
# Draw three users to inspect by hand. The exclusive upper bound comes from the
# interaction matrix so an out-of-range row index cannot be drawn (this cell
# has no error handling around the prediction call).
# NOTE(review): the lower bound of 1 is kept as it was - confirm whether row 0
# holds a real user.
sample_user = [np.random.randint(1, train.shape[0]) for _ in range(3)]

for user in sample_user:

    prediction = get_top_suggestion(user, 10)
    truth = get_top_truth(user, 10)

    print(user)
    print(prediction)                    # (score, item) pairs, best first
    print([x[1] for x in prediction])    # suggested item ids only
    print(truth)                         # (rating, item) pairs, best first
    print(check_precision(prediction, truth, 10))
    print("==================")

119468
[(2.7601897716522217, 5058), (2.7598536014556885, 10659), (2.462766170501709, 7304), (2.1554150581359863, 2073), (2.124174118041992, 13122), (2.118602991104126, 14025), (2.1142380237579346, 11536), (2.1131255626678467, 6825), (2.1125900745391846, 1321), (2.1121749877929688, 911)]
[5058, 10659, 7304, 2073, 13122, 14025, 11536, 6825, 1321, 911]
[(5.0, 14403), (5.0, 14162), (5.0, 12475), (5.0, 10659), (5.0, 5058), (5.0, 3401), (5.0, 3264), (5.0, 2903), (5.0, 2897), (5.0, 1392)]
({5058, 10659}, 0.2)
6491
[(1.918827772140503, 5058), (1.9182579517364502, 10659), (1.6212118864059448, 7304), (1.313458800315857, 2073), (1.283602237701416, 13122), (1.2776063680648804, 14025), (1.273353099822998, 11536), (1.2721185684204102, 6825), (1.2716999053955078, 1321), (1.2710808515548706, 10796)]
[5058, 10659, 7304, 2073, 13122, 14025, 11536, 6825, 1321, 10796]
[(5.0, 10659), (5.0, 8962), (5.0, 5058), (5.0, 2073), (5.0, 1168), (5.0, 340), (4.5, 12536), (4.5, 7304), (4.5, 6930), (4.5, 2633)]
({7304,