In [1]:
from functools import reduce
import pickle
import os

import pandas as pd
import numpy as np

from sklearn.metrics import ndcg_score, average_precision_score
from sklearn.neighbors import NearestNeighbors
from sentence_transformers import SentenceTransformer

from clearml import Task

from evaluation import generate_recommendations, evaluate_recommendations

from dotenv import load_dotenv
load_dotenv()

True

In [2]:
task = Task.init(
    project_name = 'MoviesGRS_MFDP', 
    task_name = 'PlotEmbeddingsRecommender',
    tags = ['PlotEmbeddings', 'Evaluation']
)

ClearML Task: created new task id=079be21daaf94523b8145b51bb9bcf40
2023-05-30 00:34:18,288 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/f3cb8157bfe7443abdc531a44bb15332/experiments/079be21daaf94523b8145b51bb9bcf40/output/log


In [3]:
data_path = 'data/'
train_data = pd.read_parquet(data_path + 'ratings_train.pq')
test_data = pd.read_parquet(data_path + 'ratings_test.pq')
movies_data = pd.read_parquet(data_path + 'movies_train.pq')
groups5 = pd.read_parquet(data_path + 'groups5.pq')
groups6 = pd.read_parquet(data_path + 'groups6.pq')
groups7 = pd.read_parquet(data_path + 'groups7.pq')

for i, group in enumerate([groups5, groups6, groups7]):
    test_data = test_data.merge(group, on='userId').rename(columns={'group': f'group{i+5}'})
del groups5, groups6, groups7
test_data

Unnamed: 0,userId,movieId,rating,group5,group6,group7
0,41988,790,4.0,31361,14281,11298
1,41988,524,3.0,31361,14281,11298
2,41988,608,4.0,31361,14281,11298
3,41988,695,3.0,31361,14281,11298
4,41988,566,4.0,31361,14281,11298
...,...,...,...,...,...,...
3596663,7343,164,5.0,17403,2054,4080
3596664,7343,193,4.0,17403,2054,4080
3596665,7343,253,4.0,17403,2054,4080
3596666,7343,483,2.0,17403,2054,4080


In [4]:
movie_ids = train_data.movieId.unique()

In [5]:
unwatched = (
    train_data
    .groupby(by='userId')
    .agg({'movieId': list})
    .reset_index()
)
unwatched['unwatched'] = (
    unwatched.movieId
    .apply(
        lambda x: movie_ids[
            np.isin(movie_ids, x, invert=True)
        ]
    )
)

unwatched = (
    test_data
    .merge(unwatched[['userId', 'unwatched']], on='userId')
    [['userId', *[f'group{i}' for i in range(5, 8)], 'unwatched']]
    .groupby(by=['userId', *[f'group{i}' for i in range(5, 8)]])
    .unwatched
    .first()
    .reset_index()
)
unwatched

Unnamed: 0,userId,group5,group6,group7,unwatched
0,1,4424,23830,7737,"[459, 310, 18, 108, 57, 228, 16, 291, 15, 191,..."
1,2,18784,2598,13437,"[459, 18, 108, 57, 228, 16, 786, 291, 15, 191,..."
2,3,29761,20828,14279,"[459, 310, 519, 18, 108, 57, 228, 786, 291, 15..."
3,4,12577,17824,10523,"[459, 310, 519, 18, 108, 57, 786, 291, 15, 191..."
4,5,22262,607,10989,"[459, 310, 519, 18, 108, 57, 228, 16, 786, 291..."
...,...,...,...,...,...
128586,128660,32028,27049,10083,"[459, 310, 519, 18, 108, 57, 228, 16, 786, 291..."
128587,128661,22934,22798,20087,"[459, 310, 18, 108, 57, 228, 16, 291, 15, 191,..."
128588,128662,23328,26324,7171,"[459, 519, 18, 108, 57, 228, 16, 786, 291, 15,..."
128589,128663,13396,19445,7989,"[459, 310, 18, 108, 57, 228, 16, 786, 291, 15,..."


In [6]:
users_watch_history_test: pd.DataFrame = (
    test_data
    .sort_values(by='rating', ascending=False)
    .groupby(by='userId')
    .agg(
        {
        **{f'group{i}': 'first' for i in range(5, 8)},
        'movieId': list,
        'rating': list
        }
    )
    .reset_index()
)
users_watch_history_test['movieId'] = users_watch_history_test.movieId.apply(np.array)
users_watch_history_test['rating'] = users_watch_history_test.rating.apply(np.array)
users_watch_history_test

Unnamed: 0,userId,group5,group6,group7,movieId,rating
0,1,4424,23830,7737,"[384, 613, 181, 16, 114, 533, 584, 131, 572, 5...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, ..."
1,2,18784,2598,13437,"[801, 609, 647, 540, 748, 548, 689, 353, 779, ...","[5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, 4.5, 4.5, ..."
2,3,29761,20828,14279,"[230, 62, 51, 80, 100, 144, 50, 55, 141, 96, 8...","[5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, ..."
3,4,12577,17824,10523,"[96, 54, 135, 79, 239, 80, 74, 242, 244, 216, ...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, ..."
4,5,22262,607,10989,"[45, 175, 75, 225, 191, 217, 133, 67, 98, 127,...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, ..."
...,...,...,...,...,...,...
128586,128660,32028,27049,10083,"[130, 68, 100, 59, 204, 461, 334, 549, 284, 73...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, ..."
128587,128661,22934,22798,20087,"[714, 653, 42, 791, 15, 638, 513, 354, 90, 516...","[4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, ..."
128588,128662,23328,26324,7171,"[251, 134, 167, 186, 68, 154, 157, 82, 496, 495]","[5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 3.0, ..."
128589,128663,13396,19445,7989,"[656, 524, 695, 664, 715, 674, 488, 398, 630, ...","[5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, 4.0, 4.0, ..."


In [7]:
movies_data = movies_data[['movieId', 'plot']]
movies_data

Unnamed: 0,movieId,plot
0,0,"A group of living toys, who assume lifelessnes..."
1,1,"In 1986, MI6 agents James Bond and Alec Trevel..."
2,2,Popular Democratic President Andrew Shepherd p...
5,3,"In 1973, sports handicapper and Mafia associat..."
6,4,"When Mr. Dashwood dies, his wife and three dau..."
...,...,...
4152,797,"In 1991, the brainwashed super-soldier James ""..."
4180,798,"In 1877, bounty hunter and Union veteran Major..."
4255,799,"In 2035, the crew of the Ares III mission to M..."
4266,800,Within the mind of a young girl named Riley ar...


In [8]:
BigBadModel = SentenceTransformer('paraphrase-distilroberta-base-v1')

movies_data['embedding'] = movies_data['plot'].apply(BigBadModel.encode)
movies_data

Unnamed: 0,movieId,plot,embedding
0,0,"A group of living toys, who assume lifelessnes...","[-0.1895642, 0.39579535, 0.101590574, -0.13271..."
1,1,"In 1986, MI6 agents James Bond and Alec Trevel...","[0.152432, 0.22731969, 0.35638723, 0.57620895,..."
2,2,Popular Democratic President Andrew Shepherd p...,"[-0.4052123, 0.51854575, -0.09895681, -0.24726..."
5,3,"In 1973, sports handicapper and Mafia associat...","[0.006261606, -0.026158923, 0.09536111, 0.2737..."
6,4,"When Mr. Dashwood dies, his wife and three dau...","[-0.30718213, 0.70415014, 0.5588496, 0.6722348..."
...,...,...,...
4152,797,"In 1991, the brainwashed super-soldier James ""...","[-0.0211845, 0.45003378, 0.21097228, -0.355851..."
4180,798,"In 1877, bounty hunter and Union veteran Major...","[-0.02784963, 0.123887956, 0.19217172, 0.37382..."
4255,799,"In 2035, the crew of the Ares III mission to M...","[-0.2835912, -0.13103433, 0.009648345, 0.14046..."
4266,800,Within the mind of a young girl named Riley ar...,"[-0.10373472, 0.5093446, 0.3733033, 0.48482883..."


In [9]:
users_data = train_data.groupby('userId').agg({col: list for col in ['movieId', 'rating']}).reset_index()
users_data['movieId'] = users_data['movieId'].apply(np.array)
users_data['rating'] = users_data['rating'].apply(np.array)
users_data

Unnamed: 0,userId,movieId,rating
0,0,"[592, 406, 460, 27, 322, 90, 461, 474, 213, 50...","[5.0, 4.5, 3.0, 4.5, 4.5, 5.0, 3.0, 3.5, 5.0, ..."
1,1,"[531, 626, 801, 729, 281, 589, 600, 701, 587, ...","[3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.5, ..."
2,2,"[167, 445, 135, 175, 723, 170, 754, 166, 51, 4...","[4.0, 4.0, 4.0, 3.0, 5.0, 3.5, 0.5, 4.5, 3.5, ..."
3,3,"[92, 145, 114, 68, 26, 122, 20, 214, 148, 149,...","[3.0, 3.0, 5.0, 4.0, 5.0, 5.0, 4.0, 5.0, 5.0, ..."
4,4,"[243, 197, 59, 234, 115, 251, 82, 55, 230, 256...","[5.0, 4.0, 1.0, 5.0, 2.0, 5.0, 3.0, 5.0, 5.0, ..."
...,...,...,...
128660,128660,"[248, 497, 494, 443, 265, 527, 407, 394, 122, ...","[4.0, 4.0, 5.0, 5.0, 5.0, 3.0, 4.0, 5.0, 4.0, ..."
128661,128661,"[550, 684, 367, 597, 721, 575, 11, 313, 593, 7...","[4.5, 3.0, 2.5, 4.0, 2.0, 4.5, 4.0, 2.0, 3.5, ..."
128662,128662,"[31, 112, 200, 91, 172, 174, 310, 169, 22, 162...","[5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 5.0, 5.0, 4.0, ..."
128663,128663,"[608, 504, 590, 711, 665, 627, 677, 589, 569, ...","[4.0, 4.0, 4.5, 5.0, 3.0, 4.0, 5.0, 4.5, 4.5, ..."


In [10]:
users_data['embedding'] = users_data.apply(
    lambda row:
    np.mean(movies_data[movies_data.movieId.isin(row['movieId'])].embedding * row['rating'], axis=0),
    axis=1
)
users_data = users_data[['userId', 'embedding']]
users_data

Unnamed: 0,userId,embedding
0,0,"[-0.28446156, 0.93299216, 0.5715234, 0.64569, ..."
1,1,"[-0.30260557, 0.8409, 0.59351224, 0.42042017, ..."
2,2,"[-0.2816124, 0.6678724, 0.59545726, 0.37789539..."
3,3,"[-0.3935022, 0.86840427, 0.4956644, 0.38940465..."
4,4,"[-0.066473536, 0.85598683, 0.58889, 0.50519514..."
...,...,...
128660,128660,"[-0.4618346, 1.0923915, 0.19005956, 0.64650476..."
128661,128661,"[-0.17767367, 1.058262, 0.6296759, 0.50269556,..."
128662,128662,"[-0.34363356, 0.28235617, 0.73458624, 1.049384..."
128663,128663,"[-0.44386974, 0.7305702, 0.53965324, 0.4440984..."


## Average user

In [11]:
def recommend(row):
    unwatched_films = movies_data[
        movies_data.movieId.isin(row['unwatched'])
    ].reset_index(drop=True)
    avg_user = np.mean(
        users_data[
            users_data.userId.isin(row['userId'])
        ].embedding,
        axis=0
    )


    nbrs = NearestNeighbors(n_neighbors=10)
    nbrs.fit(
        np.stack(
            unwatched_films.embedding.values
        )
    )
    _, top_movies_ind = nbrs.kneighbors([avg_user])
    top_movies = unwatched_films.loc[top_movies_ind[0], 'movieId'].values
    return top_movies

In [12]:
recommends = generate_recommendations(recommend, users_watch_history_test, unwatched)
recommends

ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start


Unnamed: 0,userId,group5,group6,group7,movieId,rating,group5_rec,group6_rec,group7_rec
0,1,4424,23830,7737,"[384, 613, 181, 16, 114, 533, 584, 131, 572, 5...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, ...","[315, 517, 408, 417, 204, 287, 57, 325, 513, 130]","[204, 417, 57, 517, 652, 167, 74, 130, 248, 496]","[204, 517, 57, 417, 652, 337, 95, 361, 283, 611]"
1,128307,19099,24828,7737,"[99, 142, 33, 244, 97, 80, 181, 223, 176, 157,...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, ...","[517, 417, 693, 652, 749, 669, 376, 283, 337, ...","[57, 517, 749, 417, 461, 693, 669, 337, 95, 408]","[204, 517, 57, 417, 652, 337, 95, 361, 283, 611]"
2,5104,30892,8245,7737,"[198, 8, 85, 290, 206, 172, 208, 112, 184, 88,...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, ...","[517, 204, 461, 571, 669, 74, 417, 408, 596, 652]","[417, 408, 693, 337, 462, 361, 325, 596, 163, ...","[204, 517, 57, 417, 652, 337, 95, 361, 283, 611]"
3,47993,16791,1281,7737,"[90, 784, 394, 59, 734, 779, 324, 193, 328, 82...","[5.0, 5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, 4.5, ...","[461, 204, 517, 417, 337, 287, 163, 693, 652, 57]","[693, 204, 517, 57, 74, 417, 95, 337, 287, 167]","[204, 517, 57, 417, 652, 337, 95, 361, 283, 611]"
4,32850,281,7953,7737,"[384, 730, 55, 394, 548, 90, 793]","[5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 3.5]","[204, 517, 693, 417, 315, 461, 669, 325, 361, ...","[517, 693, 204, 417, 57, 669, 408, 741, 462, 361]","[204, 517, 57, 417, 652, 337, 95, 361, 283, 611]"
...,...,...,...,...,...,...,...,...,...
128586,90835,5409,12500,20074,"[59, 572, 561, 544, 529, 301, 326, 354, 531]","[5.0, 4.0, 4.0, 3.5, 3.5, 3.0, 2.5, 2.5, 2.0]","[517, 315, 462, 669, 325, 57, 259, 596, 95, 597]","[315, 417, 517, 204, 596, 669, 57, 434, 163, 95]","[315, 517, 287, 462, 669, 417, 57, 259, 337, 361]"
128587,96812,26072,7682,20074,"[332, 367, 267, 100, 274, 329, 259, 153, 344, ...","[5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, ...","[517, 204, 315, 417, 669, 693, 130, 715, 408, ...","[517, 417, 337, 130, 571, 408, 283, 315, 462, ...","[315, 517, 287, 462, 669, 417, 57, 259, 337, 361]"
128588,65482,20574,10362,12742,"[90, 135, 175, 51, 177, 188, 220]","[5.0, 4.0, 4.0, 4.0, 3.0, 2.0, 1.0]","[693, 204, 531, 669, 57, 496, 361, 749, 326, 517]","[693, 669, 496, 400, 361, 517, 74, 564, 715, 417]","[400, 693, 496, 277, 167, 236, 669, 221, 631, ..."
128589,37798,10175,5664,12742,"[59, 22, 64, 96, 97, 26, 14, 79]","[5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0]","[517, 57, 693, 361, 417, 749, 287, 337, 163, 259]","[749, 57, 236, 669, 400, 95, 741, 517, 417, 283]","[400, 693, 496, 277, 167, 236, 669, 221, 631, ..."


In [13]:
avg_user_results = evaluate_recommendations(recommends)
avg_user_results

Unnamed: 0,MAP,NDCG
group5,0.019639,0.165259
group6,0.018148,0.155253
group7,0.016835,0.14664


## Group sum

In [14]:
def recommend(row):
    unwatched_films = movies_data[
        movies_data.movieId.isin(row['unwatched'])
    ].reset_index(drop=True)
    users = np.stack(
        users_data[
            users_data.userId.isin(row['userId'])
        ].embedding.values
    )

    nbrs = NearestNeighbors(n_neighbors=10)
    nbrs.fit(
        np.stack(
            unwatched_films.embedding.values
        )
    )
    probs = {}
    prob, ind = nbrs.kneighbors(users)
    for p, ix in zip(prob.flatten(), ind.flatten()):
        probs[ix] = probs.get(ix, 0) + p
    top_movies_ind = sorted(probs.keys(), key=lambda x: -probs[x])[:10]
    top_movies = unwatched_films.loc[top_movies_ind, 'movieId'].values
    return top_movies

In [15]:
recommends = generate_recommendations(recommend, users_watch_history_test, unwatched)
recommends

Unnamed: 0,userId,group5,group6,group7,movieId,rating,group5_rec,group6_rec,group7_rec
0,1,4424,23830,7737,"[384, 613, 181, 16, 114, 533, 584, 131, 572, 5...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, ...","[417, 517, 315, 325, 287, 408, 204, 274, 252, ...","[204, 315, 130, 517, 496, 57, 510, 611, 385, 274]","[417, 517, 204, 57, 361, 337, 652, 95, 163, 611]"
1,128307,19099,24828,7737,"[99, 142, 33, 244, 97, 80, 181, 223, 176, 157,...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, ...","[693, 517, 652, 417, 669, 376, 283, 749, 531, ...","[461, 749, 669, 693, 517, 57, 417, 283, 337, 596]","[417, 517, 204, 57, 361, 337, 652, 95, 163, 611]"
2,5104,30892,8245,7737,"[198, 8, 85, 290, 206, 172, 208, 112, 184, 88,...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, ...","[517, 571, 417, 204, 596, 693, 669, 408, 461, 74]","[337, 417, 715, 361, 693, 325, 408, 596, 259, ...","[417, 517, 204, 57, 361, 337, 652, 95, 163, 611]"
3,47993,16791,1281,7737,"[90, 784, 394, 59, 734, 779, 324, 193, 328, 82...","[5.0, 5.0, 5.0, 5.0, 5.0, 4.5, 4.5, 4.5, 4.5, ...","[417, 517, 461, 337, 204, 287, 163, 167, 57, 693]","[57, 204, 517, 693, 167, 74, 417, 95, 283, 287]","[417, 517, 204, 57, 361, 337, 652, 95, 163, 611]"
4,32850,281,7953,7737,"[384, 730, 55, 394, 548, 90, 793]","[5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 3.5]","[693, 204, 517, 741, 417, 57, 315, 669, 325, 461]","[517, 417, 693, 204, 259, 315, 408, 669, 462, ...","[417, 517, 204, 57, 361, 337, 652, 95, 163, 611]"
...,...,...,...,...,...,...,...,...,...
128586,90835,5409,12500,20074,"[59, 572, 561, 544, 529, 301, 326, 354, 531]","[5.0, 4.0, 4.0, 3.5, 3.5, 3.0, 2.5, 2.5, 2.0]","[517, 57, 462, 669, 259, 325, 741, 95, 315, 693]","[315, 417, 57, 204, 517, 130, 669, 95, 513, 259]","[417, 517, 57, 462, 669, 259, 315, 287, 434, 66]"
128587,96812,26072,7682,20074,"[332, 367, 267, 100, 274, 329, 259, 153, 344, ...","[5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, ...","[315, 517, 513, 567, 715, 531, 204, 287, 720, ...","[517, 571, 130, 408, 337, 417, 315, 259, 564, ...","[417, 517, 57, 462, 669, 259, 315, 287, 434, 66]"
128588,65482,20574,10362,12742,"[90, 135, 175, 51, 177, 188, 220]","[5.0, 4.0, 4.0, 4.0, 3.0, 2.0, 1.0]","[693, 417, 204, 172, 326, 749, 46, 313, 652, 517]","[693, 74, 496, 517, 669, 361, 597, 652, 461, 417]","[693, 95, 236, 57, 669, 167, 517, 283, 749, 204]"
128589,37798,10175,5664,12742,"[59, 22, 64, 96, 97, 26, 14, 79]","[5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0]","[361, 693, 517, 259, 749, 57, 95, 287, 150, 513]","[749, 236, 95, 741, 150, 283, 57, 517, 669, 573]","[693, 95, 236, 57, 669, 167, 517, 283, 749, 204]"


In [16]:
group_sum_results = evaluate_recommendations(recommends)
group_sum_results

Unnamed: 0,MAP,NDCG
group5,0.018896,0.160253
group6,0.017626,0.151317
group7,0.016484,0.143722


In [17]:
task.upload_artifact('avg_user_metrics', avg_user_results)
task.upload_artifact('group_sum_metrics', group_sum_results)

True

In [18]:
task.close()