In [1]:
from functools import reduce
import pickle
import os

import pandas as pd
import numpy as np

from sklearn.metrics import ndcg_score, average_precision_score
from surprise import Dataset, Reader, SVD

from clearml import Task

from evaluation import generate_recommendations, evaluate_recommendations

from dotenv import load_dotenv
load_dotenv()

True

In [2]:
data_path = 'data/'
train_data = pd.read_parquet(data_path + 'ratings_train_ts.pq')
test_data = pd.read_parquet(data_path + 'ratings_test_ts.pq')
groups2 = pd.read_parquet(data_path + 'groups2.pq')
groups3 = pd.read_parquet(data_path + 'groups3.pq')
groups4 = pd.read_parquet(data_path + 'groups4.pq')
groups5 = pd.read_parquet(data_path + 'groups5.pq')
groups6 = pd.read_parquet(data_path + 'groups6.pq')
groups7 = pd.read_parquet(data_path + 'groups7.pq')

groups_list = [f'group{i}' for i in range(2, 8)]
for i, group in enumerate([groups2, groups3, groups4, groups5, groups6, groups7]):
    test_data = test_data.merge(group, on='userId').rename(columns={'group': f'group{i+2}'})
del groups2, groups3, groups4, groups5, groups6, groups7
test_data

Unnamed: 0,userId,movieId,rating,timestamp,group2,group3,group4,group5,group6,group7
0,1,682,4.0,1453904108,14465,39625,6774,4424,23830,7737
1,1,694,4.0,1453904111,14465,39625,6774,4424,23830,7737
2,1,650,4.0,1453904123,14465,39625,6774,4424,23830,7737
3,1,693,3.0,1453904126,14465,39625,6774,4424,23830,7737
4,1,799,4.0,1453904195,14465,39625,6774,4424,23830,7737
...,...,...,...,...,...,...,...,...,...,...
2346390,128664,100,4.0,1240953576,39073,10745,27588,28627,6527,689
2346391,128664,14,4.5,1240953606,39073,10745,27588,28627,6527,689
2346392,128664,193,4.5,1240953609,39073,10745,27588,28627,6527,689
2346393,128664,198,5.0,1240953613,39073,10745,27588,28627,6527,689


In [3]:
task = Task.init(
    project_name = 'MoviesGRS_MFDP', 
    task_name = 'SVDRecommender',
    tags = ['SVD', 'Evaluation', 'TimeSeriesSplit']
)

ClearML Task: created new task id=f574c4445d294458bb883180eeaa54d8
2023-05-30 12:10:09,478 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/f3cb8157bfe7443abdc531a44bb15332/experiments/f574c4445d294458bb883180eeaa54d8/output/log


In [4]:
if os.path.exists(f"svd_trained_ts.pkl"):
    with open("svd_trained_ts.pkl", "rb") as f:
        svd = pickle.load(f)
else:
    min_rating = 1
    max_rating = 5
    
    reader = Reader(rating_scale=(min_rating, max_rating))
    surprise_train_dataset = Dataset.load_from_df(train_data[['userId', 'movieId', 'rating']], reader)
    trainset = surprise_train_dataset.build_full_trainset()
    
    svd = SVD(n_factors=17, n_epochs=30)
    svd.fit(trainset)
    
    with open("svd_trained_ts.pkl", "wb") as f:
        pickle.dump(svd, f)
    
    del trainset

In [5]:
movie_ids = train_data.movieId.unique()

In [6]:
unwatched = (
    train_data
    .groupby(by='userId')
    .agg({'movieId': list})
    .reset_index()
)
unwatched['unwatched'] = (
    unwatched.movieId
    .apply(
        lambda x: movie_ids[
            np.isin(movie_ids, x, invert=True)
        ]
    )
)

unwatched = (
    test_data
    .merge(unwatched[['userId', 'unwatched']], on='userId')
    [['userId', *groups_list, 'unwatched']]
    .groupby(by=['userId', *groups_list])
    .unwatched
    .first()
    .reset_index()
)
unwatched

Unnamed: 0,userId,group2,group3,group4,group5,group6,group7,unwatched
0,1,14465,39625,6774,4424,23830,7737,"[381, 92, 107, 85, 220, 234, 358, 27, 319, 347..."
1,2,10418,7556,11348,18784,2598,13437,"[381, 548, 92, 107, 85, 220, 257, 234, 358, 53..."
2,3,53801,35849,12281,29761,20828,14279,"[381, 548, 107, 85, 220, 257, 234, 358, 535, 2..."
3,4,3990,24019,11784,12577,17824,10523,"[381, 548, 107, 85, 220, 257, 234, 358, 535, 2..."
4,5,39404,9579,25927,22262,607,10989,"[381, 548, 92, 107, 85, 257, 234, 358, 535, 27..."
...,...,...,...,...,...,...,...,...
128583,128660,22670,42046,12206,32028,27049,10083,"[381, 548, 92, 107, 85, 220, 257, 234, 358, 53..."
128584,128661,30264,32218,16183,22934,22798,20087,"[381, 92, 107, 85, 220, 257, 234, 358, 535, 27..."
128585,128662,24301,22982,420,23328,26324,7171,"[381, 548, 92, 107, 85, 220, 257, 234, 358, 53..."
128586,128663,55355,18586,11297,13396,19445,7989,"[381, 548, 92, 107, 85, 220, 257, 234, 358, 53..."


In [7]:
users_watch_history_test: pd.DataFrame = (
    test_data
    .sort_values(by='rating', ascending=False)
    .groupby(by='userId')
    .agg(
        {
        **{g: 'first' for g in groups_list},
        'movieId': list,
        'rating': list
        }
    )
    .reset_index()
)
users_watch_history_test['movieId'] = users_watch_history_test.movieId.apply(np.array)
users_watch_history_test['rating'] = users_watch_history_test.rating.apply(np.array)

## Average user

In [8]:
def recommend(row):
    movie_pseudorating = svd.bi[row["unwatched"]] + (
        svd.qi[row["unwatched"]] @ np.mean(svd.pu[row["userId"]], axis=0)
    )
    top_movies = row["unwatched"][np.argsort(-movie_pseudorating)][:10]
    return top_movies

In [9]:
recommends = generate_recommendations(recommend, users_watch_history_test, unwatched, groups_list)
recommends

Unnamed: 0,userId,group2,group3,group4,group5,group6,group7,movieId,rating,group2_rec,group3_rec,group4_rec,group5_rec,group6_rec,group7_rec
0,1,14465,39625,6774,4424,23830,7737,"[613, 176, 734, 114, 270, 485, 352, 201, 571, ...","[5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, 4.5, 4.5, ...","[245, 240, 83, 110, 107, 109, 303, 92, 238, 216]","[85, 65, 248, 309, 25, 241, 216, 52, 250, 597]","[83, 240, 71, 245, 110, 92, 238, 309, 86, 18]","[240, 243, 62, 92, 245, 83, 126, 308, 341, 246]","[83, 240, 62, 110, 309, 596, 103, 341, 377, 92]","[83, 596, 377, 110, 240, 360, 241, 25, 309, 92]"
1,5104,48420,12472,25594,30892,8245,7737,"[468, 69, 162, 290, 203, 173, 289, 135, 150, 3...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, ...","[237, 244, 596, 240, 575, 83, 476, 110, 216, 245]","[596, 377, 83, 514, 240, 735, 126, 216, 306, 359]","[596, 83, 244, 240, 514, 237, 377, 476, 460, 691]","[83, 237, 596, 240, 110, 476, 377, 516, 623, 216]","[244, 245, 518, 216, 246, 98, 596, 575, 691, 358]","[83, 596, 377, 110, 240, 360, 241, 25, 309, 92]"
2,109704,31819,26528,10792,15512,26177,7737,"[711, 685, 459, 433, 602, 477, 181, 147, 596, ...","[5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, ...","[83, 240, 516, 237, 245, 377, 25, 18, 110, 623]","[240, 245, 83, 341, 237, 103, 110, 62, 244, 92]","[245, 240, 370, 216, 244, 237, 110, 516, 282, ...","[85, 314, 83, 62, 25, 209, 251, 255, 248, 740]","[83, 596, 237, 240, 216, 110, 244, 18, 516, 307]","[83, 596, 377, 110, 240, 360, 241, 25, 309, 92]"
3,128307,463,31895,6380,19099,24828,7737,"[209, 99, 77, 34, 223, 25, 73, 56, 83, 41, 28,...","[5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, ...","[18, 83, 458, 362, 320, 91, 354, 420, 255, 44]","[18, 65, 362, 83, 418, 667, 420, 615, 360, 461]","[86, 65, 85, 63, 71, 92, 320, 255, 52, 40]","[83, 18, 65, 85, 47, 1, 91, 92, 740, 377]","[83, 18, 65, 362, 47, 740, 62, 85, 571, 40]","[83, 596, 377, 110, 240, 360, 241, 25, 309, 92]"
4,32850,15425,9692,2595,281,7953,7737,"[689, 100, 516, 540]","[5.0, 3.0, 2.5, 2.0]","[83, 596, 71, 237, 18, 370, 110, 244, 623, 239]","[309, 65, 307, 83, 1, 216, 596, 18, 100, 25]","[83, 596, 377, 240, 106, 103, 623, 239, 693, 516]","[83, 237, 25, 773, 85, 73, 358, 18, 253, 123]","[83, 244, 240, 476, 18, 237, 110, 216, 106, 377]","[83, 596, 377, 110, 240, 360, 241, 25, 309, 92]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
128583,98075,55317,35155,28186,15593,19307,16497,"[573, 625, 524, 520, 786, 580, 600]","[4.0, 4.0, 4.0, 4.0, 3.5, 3.5, 3.0]","[71, 83, 241, 110, 238, 596, 240, 735, 237, 245]","[250, 370, 71, 241, 282, 640, 238, 377, 373, 245]","[62, 309, 573, 552, 341, 245, 248, 377, 126, 282]","[83, 65, 86, 18, 62, 358, 38, 85, 92, 47]","[62, 71, 596, 308, 47, 377, 38, 246, 52, 740]","[83, 71, 240, 245, 237, 18, 62, 110, 238, 85]"
128584,1764,6389,1952,29114,1967,6059,16497,"[584, 90, 23, 481, 462, 176, 181, 329, 128]","[5.0, 5.0, 5.0, 4.5, 4.5, 4.0, 4.0, 4.0, 3.5]","[65, 86, 83, 1, 18, 237, 71, 377, 82, 239]","[83, 85, 240, 92, 62, 52, 99, 65, 86, 1]","[83, 18, 596, 237, 71, 358, 47, 476, 239, 110]","[83, 309, 18, 237, 476, 377, 597, 241, 25, 86]","[83, 18, 237, 240, 596, 62, 377, 243, 92, 110]","[83, 71, 240, 245, 237, 18, 62, 110, 238, 85]"
128585,18641,5515,7461,18994,10646,4255,20065,"[394, 516, 622, 294, 128, 70, 28, 531, 513, 63...","[5.0, 5.0, 5.0, 4.5, 4.0, 4.0, 4.0, 3.0, 3.0, ...","[377, 83, 239, 241, 370, 71, 362, 596, 623, 250]","[83, 71, 240, 377, 358, 47, 596, 110, 282, 98]","[241, 26, 377, 71, 370, 83, 248, 282, 239, 251]","[83, 18, 377, 596, 239, 370, 623, 241, 238, 735]","[596, 18, 240, 377, 110, 62, 370, 516, 623, 241]","[83, 377, 240, 18, 62, 623, 85, 309, 596, 86]"
128586,46254,44498,8562,3795,15078,21003,20065,"[736, 762, 544, 334, 682, 607, 562, 475, 750, ...","[4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.0, 4.0, 4.0, ...","[83, 18, 86, 596, 377, 243, 693, 47, 20, 92]","[83, 91, 62, 92, 240, 238, 245, 99, 65, 85]","[83, 18, 99, 596, 237, 243, 92, 240, 244, 126]","[83, 65, 18, 1, 25, 85, 216, 240, 110, 62]","[83, 245, 240, 110, 623, 65, 92, 596, 370, 98]","[83, 377, 240, 18, 62, 623, 85, 309, 596, 86]"


In [10]:
average_user_results = evaluate_recommendations(recommends, groups_list)
average_user_results

ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start


Unnamed: 0,MAP,NDCG
group2,0.013027,0.117886
group3,0.012804,0.115516
group4,0.012553,0.11302
group5,0.012593,0.113308
group6,0.012281,0.11105
group7,0.012091,0.109125


## Group Sum

In [11]:
def recommend(row):
    movie_pseudorating = svd.bi[row["unwatched"]] + (
        np.sum(svd.qi[row["unwatched"]] @ svd.pu[row["userId"]].T, axis=1)
    )
    movie_pseudorating.sum()
    top_movies = row["unwatched"][np.argsort(-movie_pseudorating)][:10]
    return top_movies

In [12]:
recommends = generate_recommendations(recommend, users_watch_history_test, unwatched, groups_list)
recommends

Unnamed: 0,userId,group2,group3,group4,group5,group6,group7,movieId,rating,group2_rec,group3_rec,group4_rec,group5_rec,group6_rec,group7_rec
0,1,14465,39625,6774,4424,23830,7737,"[613, 176, 734, 114, 270, 485, 352, 201, 571, ...","[5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, 4.5, 4.5, ...","[245, 107, 109, 303, 240, 110, 216, 439, 238, ...","[85, 248, 250, 52, 65, 309, 216, 25, 597, 241]","[107, 109, 245, 284, 240, 282, 303, 71, 439, 373]","[240, 245, 107, 243, 308, 283, 109, 92, 303, 62]","[240, 107, 341, 109, 284, 62, 439, 309, 110, 499]","[360, 606, 596, 492, 377, 416, 118, 282, 110, ..."
1,5104,48420,12472,25594,30892,8245,7737,"[468, 69, 162, 290, 203, 173, 289, 135, 150, 3...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, ...","[244, 575, 116, 306, 534, 629, 312, 460, 686, ...","[756, 577, 567, 138, 514, 588, 553, 306, 646, ...","[460, 514, 686, 606, 507, 596, 229, 691, 646, ...","[237, 476, 629, 596, 118, 216, 575, 433, 606, ...","[646, 686, 728, 553, 543, 244, 245, 216, 575, ...","[360, 606, 596, 492, 377, 416, 118, 282, 110, ..."
2,109704,31819,26528,10792,15512,26177,7737,"[711, 685, 459, 433, 602, 477, 181, 147, 596, ...","[5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, ...","[83, 516, 245, 240, 563, 25, 237, 623, 85, 110]","[240, 245, 341, 541, 373, 439, 222, 103, 344, ...","[245, 727, 573, 250, 534, 728, 553, 782, 370, ...","[314, 85, 209, 5, 255, 25, 52, 254, 320, 102]","[118, 216, 307, 534, 672, 679, 644, 244, 596, ...","[360, 606, 596, 492, 377, 416, 118, 282, 110, ..."
3,128307,463,31895,6380,19099,24828,7737,"[209, 99, 77, 34, 223, 25, 73, 56, 83, 41, 28,...","[5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, ...","[458, 354, 643, 362, 320, 255, 18, 44, 357, 420]","[667, 65, 418, 27, 18, 458, 398, 606, 362, 54]","[86, 85, 65, 63, 458, 320, 255, 102, 52, 760]","[85, 458, 255, 38, 42, 160, 65, 52, 47, 1]","[65, 362, 458, 47, 328, 18, 320, 83, 38, 667]","[360, 606, 596, 492, 377, 416, 118, 282, 110, ..."
4,32850,15425,9692,2595,281,7953,7737,"[689, 100, 516, 540]","[5.0, 3.0, 2.5, 2.0]","[646, 596, 71, 370, 244, 492, 237, 216, 83, 18]","[462, 307, 65, 309, 54, 1, 100, 607, 216, 299]","[646, 106, 596, 460, 118, 183, 534, 474, 83, 547]","[326, 207, 318, 773, 224, 199, 73, 587, 790, 160]","[646, 522, 460, 606, 686, 490, 244, 781, 476, ...","[360, 606, 596, 492, 377, 416, 118, 282, 110, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
128583,98075,55317,35155,28186,15593,19307,16497,"[573, 625, 524, 520, 786, 580, 600]","[4.0, 4.0, 4.0, 4.0, 3.5, 3.5, 3.0]","[241, 71, 250, 735, 85, 110, 238, 245, 282, 26]","[250, 370, 640, 282, 665, 637, 642, 373, 241, 71]","[595, 573, 424, 138, 710, 717, 567, 577, 542, ...","[65, 86, 38, 326, 52, 2, 85, 207, 358, 83]","[62, 52, 5, 38, 308, 739, 616, 393, 480, 301]","[245, 240, 71, 85, 238, 83, 110, 62, 439, 237]"
128584,1764,6389,1952,29114,1967,6059,16497,"[584, 90, 23, 481, 462, 176, 181, 329, 128]","[5.0, 5.0, 5.0, 4.5, 4.5, 4.0, 4.0, 4.0, 3.5]","[65, 86, 1, 18, 83, 82, 85, 75, 239, 25]","[52, 85, 485, 92, 240, 314, 99, 62, 446, 616]","[18, 83, 159, 47, 476, 358, 596, 536, 433, 696]","[597, 309, 476, 629, 4, 25, 241, 216, 253, 749]","[83, 82, 18, 85, 47, 52, 237, 243, 641, 62]","[245, 240, 71, 85, 238, 83, 110, 62, 439, 237]"
128585,18641,5515,7461,18994,10646,4255,20065,"[394, 516, 622, 294, 128, 70, 28, 531, 513, 63...","[5.0, 5.0, 5.0, 4.5, 4.0, 4.0, 4.0, 3.0, 3.0, ...","[250, 377, 370, 239, 241, 362, 492, 642, 741, ...","[250, 71, 282, 47, 358, 573, 640, 719, 670, 547]","[250, 26, 760, 598, 241, 670, 282, 251, 666, 248]","[250, 239, 642, 615, 377, 735, 370, 596, 420, ...","[547, 437, 646, 575, 229, 728, 541, 543, 573, ...","[250, 85, 373, 251, 377, 282, 623, 52, 47, 614]"
128586,46254,44498,8562,3795,15078,21003,20065,"[736, 762, 544, 334, 682, 607, 562, 475, 750, ...","[4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.0, 4.0, 4.0, ...","[18, 83, 86, 596, 362, 20, 47, 693, 377, 99]","[107, 109, 91, 458, 64, 303, 245, 251, 92, 85]","[646, 229, 522, 99, 101, 42, 83, 607, 160, 243]","[65, 1, 607, 85, 54, 42, 216, 25, 52, 30]","[245, 458, 288, 282, 65, 373, 183, 623, 110, 83]","[250, 85, 373, 251, 377, 282, 623, 52, 47, 614]"


In [13]:
group_sum_results = evaluate_recommendations(recommends, groups_list)
group_sum_results

Unnamed: 0,MAP,NDCG
group2,0.011667,0.109666
group3,0.010804,0.103025
group4,0.010118,0.097659
group5,0.010304,0.098881
group6,0.009949,0.096097
group7,0.009546,0.093057


In [14]:
task.upload_artifact('avg_user_metrics', average_user_results)
task.upload_artifact('group_sum_metrics', group_sum_results)

True

In [15]:
task.close()