## Устанавливаем необходимые библиотеки

In [None]:
pip install -q recbole ray kmeans_pytorch

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m32.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.6/62.6 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import warnings
# Silences every warning category for the rest of the kernel session, so
# deprecation and numerical warnings raised by later cells will not be shown.
warnings.filterwarnings('ignore')

In [None]:
import ast
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import time

from collections import Counter
from random import randint, random
from scipy.sparse import coo_matrix, hstack
from sklearn.metrics.pairwise import euclidean_distances, cosine_distances, cosine_similarity

import logging
from logging import getLogger
from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.model.sequential_recommender import GRU4Rec, Caser
from recbole.trainer import Trainer
from recbole.utils import init_seed, init_logger
from recbole.quick_start import run_recbole

import torch
from recbole.model.general_recommender.multivae import MultiVAE

## Загружаем данные

Код с семинара:

In [None]:
# Read the three pre-processed KION tables (interactions, users, items) from the
# working directory; each CSV becomes its own DataFrame.
interactions_df, users_df, items_df = map(
    pd.read_csv,
    (
        'interactions_processed_kion.csv',
        'users_processed_kion.csv',
        'items_processed_kion.csv',
    ),
)

In [None]:
# Parse the last-watch date and derive a Unix timestamp in whole seconds from it.
watch_dates = pd.to_datetime(interactions_df['last_watch_dt'], format="%Y-%m-%d")
interactions_df['t_dat'] = watch_dates
interactions_df['timestamp'] = watch_dates.values.astype(np.int64) // 10 ** 9

# Keep only the three interaction columns and rename them to the
# `<name>:<type>` headers written to the `.inter` file.
inter_columns = {
    'user_id': 'user_id:token',
    'item_id': 'item_id:token',
    'timestamp': 'timestamp:float',
}
df = interactions_df[list(inter_columns)].rename(columns=inter_columns)

In [None]:
# Create the dataset directory; `exist_ok=True` keeps the cell re-runnable
# (plain `mkdir` errors out when the directory is already there).
os.makedirs('recbox_data', exist_ok=True)

In [None]:
# Write the interactions as a tab-separated file without the pandas index column.
df.to_csv('recbox_data/recbox_data.inter', sep='\t', index=False)

In [None]:
# RecBole settings; the same dict is passed to every `run_recbole` call in this notebook.
parameter_dict = {
    'data_path': '',
    'USER_ID_FIELD': 'user_id',
    'ITEM_ID_FIELD': 'item_id',
    'TIME_FIELD': 'timestamp',
    # NOTE(review): 'GPU' is not a torch device string, yet `model.to(config["device"])`
    # works further down, so RecBole presumably overwrites this key itself - confirm and drop it.
    'device': 'GPU',
    # presumably keeps only users / items with at least 40 interactions - verify against RecBole docs
    'user_inter_num_interval': "[40,inf)",
    'item_inter_num_interval': "[40,inf)",
    'load_col': {'inter': ['user_id', 'item_id', 'timestamp']},
    # NOTE(review): newer RecBole releases may expect a differently named key for
    # disabling negative sampling - confirm this one is still honoured by the installed version.
    'neg_sampling': None,
    'epochs': 10,
    'eval_args': {
        # Looks like a 9:0:1 train/valid/test ratio, i.e. no validation part; the logged
        # results indeed report `best_valid_result: None` - TODO confirm this is intended.
        'split': {'RS': [9, 0, 1]},
        'group_by': 'user',
        'order': 'TO',
        'mode': 'full'}
}
config = Config(model='MultiVAE', dataset='recbox_data', config_dict=parameter_dict)

# init random seed
init_seed(config['seed'], config['reproducibility'])

# logger initialization
init_logger(config)
logger = getLogger()
# Create handlers
# Note: re-running this cell attaches one more stream handler to the same logger.
c_handler = logging.StreamHandler()
c_handler.setLevel(logging.INFO)
logger.addHandler(c_handler)

# write config info into log
# logger.info(config)



In [None]:
# Build the dataset object from `config` and log its summary.
dataset = create_dataset(config)
logger.info(dataset)

In [None]:
# dataset splitting
# Produces the train / validation / test data objects according to `config`.
train_data, valid_data, test_data = data_preparation(config, dataset)

## Обучаем модели

In [None]:
%%time
model_list = ["MultiVAE", "MultiDAE", "MacridVAE", "NeuMF", "RecVAE",
              "ItemKNN", "DMF", "ConvNCF", "LightGCN"]

for model_name in model_list:
    print(f"running {model_name}...")
    start = time.time()
    result = run_recbole(model=model_name,
                         dataset='recbox_data',
                         config_dict=parameter_dict)
    t = time.time() - start
    print(f"It took {t/60:.2f} mins")
    print(result)
    print("=="*20)

running MultiVAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|███████████████████████████| 7/7 [00:00<00:00, 15.75it/s, GPU RAM: 0.38 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:00<00:00, 19.91it/s, GPU RAM: 0.38 G/14.75 G]
Train     2: 100%|███████████████████████████| 7/7 [00:00<00:00, 20.03it/s, GPU RAM: 0.38 G/14.75 G]
Train     3: 100%|███████████████████████████| 7/7 [00:00<00:00, 19.64it/s, GPU RAM: 0.38 G/14.75 G]
Train     4: 100%|███████████████████████████| 7/7 [00:00<00:00, 18.83it/s, GPU RAM: 0.38 G/14.75 G]
Train     5: 100%|███████████████████████████| 7/7 [00:00<00:00, 20.05it/s, GPU RAM: 0.38 G/14.75 G]
Train     6: 100%|███████████████████████████| 7/7 [00:00<00:00, 19.57it/s, GPU RAM: 0.38 G/14.75 G]
Train     7: 100%|███████████████████████████| 7/7 [00:00<00:00, 1

It took 3.19 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0834), ('mrr@10', 0.1671), ('ndcg@10', 0.0816), ('hit@10', 0.3466), ('precision@10', 0.0462)])}
running MultiDAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|███████████████████████████| 7/7 [00:00<00:00, 12.48it/s, GPU RAM: 0.38 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:00<00:00,  8.58it/s, GPU RAM: 0.38 G/14.75 G]
Train     2: 100%|███████████████████████████| 7/7 [00:01<00:00,  5.67it/s, GPU RAM: 0.38 G/14.75 G]
Train     3: 100%|███████████████████████████| 7/7 [00:00<00:00, 11.32it/s, GPU RAM: 0.38 G/14.75 G]
Train     4: 100%|███████████████████████████| 7/7 [00:00<00:00, 12.06it/s, GPU RAM: 0.40 G/14.75 G]
Train     5: 100%|███████████████████████████| 7/7 [00:02<00:00,  2.86it/s, GPU RAM: 0.40 G/14.75 G]
Train     6: 100%|███████████████████████████| 7/7 [00:00<00:00, 13.85it/s, GPU RAM: 0.40 G/14.75 G]
Train     7: 100%|███████████████████████████| 7/7 [00:00<00:00,  

It took 4.80 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0837), ('mrr@10', 0.1657), ('ndcg@10', 0.0814), ('hit@10', 0.3466), ('precision@10', 0.0463)])}
running MacridVAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|███████████████████████████| 7/7 [00:01<00:00,  3.60it/s, GPU RAM: 0.95 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:01<00:00,  3.79it/s, GPU RAM: 0.95 G/14.75 G]
Train     2: 100%|███████████████████████████| 7/7 [00:01<00:00,  3.53it/s, GPU RAM: 0.95 G/14.75 G]
Train     3: 100%|███████████████████████████| 7/7 [00:02<00:00,  2.70it/s, GPU RAM: 0.95 G/14.75 G]
Train     4: 100%|███████████████████████████| 7/7 [00:02<00:00,  3.35it/s, GPU RAM: 0.95 G/14.75 G]
Train     5: 100%|███████████████████████████| 7/7 [00:01<00:00,  3.71it/s, GPU RAM: 0.95 G/14.75 G]
Train     6: 100%|███████████████████████████| 7/7 [00:01<00:00,  3.67it/s, GPU RAM: 0.95 G/14.75 G]
Train     7: 100%|███████████████████████████| 7/7 [00:01<00:00,  

It took 8.60 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0827), ('mrr@10', 0.1548), ('ndcg@10', 0.0775), ('hit@10', 0.3469), ('precision@10', 0.0455)])}
running NeuMF...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Train     0: 100%|███████████████████████| 755/755 [00:34<00:00, 21.58it/s, GPU RAM: 0.95 G/14.75 G]
Train     1: 100%|███████████████████████| 755/755 [00:35<00:00, 21.09it/s, GPU RAM: 0.95 G/14.75 G]
Train     2: 100%|███████████████████████| 755/755 [00:36<00:00, 20.93it/s, GPU RAM: 0.95 G/14.75 G]
Train     3: 100%|███████████████████████| 755/755 [00:35<00:00, 21.34it/s, GPU RAM: 0.95 G/14.75 G]
Train     4: 100%|███████████████████████| 755/755 [00:35<00:00, 21.56it/s, GPU RAM: 0.95 G/14.75 G]
Train     5: 100%|███████████████████████| 755/755 [00:35<00:00, 21.53it/s, GPU RAM: 0.95 G/14.75 G]
Train     6: 100%|███████████████████████| 755/755 [00:34<00:00, 21.67it/s, GPU RAM: 0.95 G/14.75 G]
Train     7: 100%|███████████████████████| 755/755 [00:35<00:00, 21.37it/s, GPU RAM: 0.95 G/14.75 G]
Train     8: 100%|███████████████████████| 755/755 [00:35

It took 10.09 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0687), ('mrr@10', 0.1181), ('ndcg@10', 0.0607), ('hit@10', 0.3008), ('precision@10', 0.038)])}
running RecVAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|███████████████████████████| 7/7 [00:00<00:00,  8.75it/s, GPU RAM: 0.95 G/14.75 G]
Train     0: 100%|███████████████████████████| 7/7 [00:00<00:00,  8.86it/s, GPU RAM: 0.95 G/14.75 G]
Train     0: 100%|███████████████████████████| 7/7 [00:00<00:00,  7.41it/s, GPU RAM: 0.95 G/14.75 G]
Train     0: 100%|███████████████████████████| 7/7 [00:01<00:00,  6.93it/s, GPU RAM: 0.95 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:01<00:00,  6.58it/s, GPU RAM: 0.95 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:01<00:00,  6.68it/s, GPU RAM: 0.95 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:00<00:00,  7.57it/s, GPU RAM: 0.95 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:00<00:00,  

It took 6.87 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0846), ('mrr@10', 0.1661), ('ndcg@10', 0.0818), ('hit@10', 0.3523), ('precision@10', 0.0469)])}
running ItemKNN...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Train     0: 100%|███████████████████████| 755/755 [00:27<00:00, 27.45it/s, GPU RAM: 0.95 G/14.75 G]
Evaluate   : 100%|███████████████████| 13354/13354 [06:20<00:00, 35.12it/s, GPU RAM: 0.95 G/14.75 G]


It took 10.06 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0909), ('mrr@10', 0.1768), ('ndcg@10', 0.088), ('hit@10', 0.3654), ('precision@10', 0.0504)])}
running DMF...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Max value of item's history interaction records has reached 44.36540621490079% of the total.
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|███████████████████████| 755/755 [00:50<00:00, 15.09it/s, GPU RAM: 1.59 G/14.75 G]
Train     1: 100%|███████████████████████| 755/755 [00:49<00:00, 15.12it/s, GPU RAM: 1.59 G/14.75 G]
Train     2: 100%|███████████████████████| 755/755 [00:50<00:00, 15.08it/s, GPU RAM: 1.59 G/14.75 G]
Train     3: 100%|███████████████████████| 755/755 [00:48<00:00, 15.42it/s, GPU RAM: 1.59 G/14.75 G]
Train     4: 100%|███████████████████████| 755/755 [00:49<00:00, 15.17it/s, GPU RAM: 1.59 G/14.75 G]
Train     5: 100%|███████████████████████| 755/755 [00:49<00:00, 15.37it/s, GPU RAM: 1.59 G/14.75 G]
Train     6: 100%|███████████████████████| 755/755 [00:49<00:00, 15.25it/s

It took 14.12 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0827), ('mrr@10', 0.1569), ('ndcg@10', 0.0781), ('hit@10', 0.3467), ('precision@10', 0.0455)])}
running RecVAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|███████████████████████████| 7/7 [00:01<00:00,  4.50it/s, GPU RAM: 1.59 G/14.75 G]
Train     0: 100%|███████████████████████████| 7/7 [00:01<00:00,  4.58it/s, GPU RAM: 1.59 G/14.75 G]
Train     0: 100%|███████████████████████████| 7/7 [00:01<00:00,  4.61it/s, GPU RAM: 1.59 G/14.75 G]
Train     0: 100%|███████████████████████████| 7/7 [00:01<00:00,  4.67it/s, GPU RAM: 1.59 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:01<00:00,  4.05it/s, GPU RAM: 1.59 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:01<00:00,  4.76it/s, GPU RAM: 1.59 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:01<00:00,  4.57it/s, GPU RAM: 1.59 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:01<00:00,  

It took 8.92 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0846), ('mrr@10', 0.1661), ('ndcg@10', 0.0818), ('hit@10', 0.3523), ('precision@10', 0.0469)])}
running ConvNCF...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Train     0: 100%|███████████████████████| 378/378 [02:20<00:00,  2.69it/s, GPU RAM: 1.59 G/14.75 G]
Train     1: 100%|███████████████████████| 378/378 [00:56<00:00,  6.67it/s, GPU RAM: 1.59 G/14.75 G]
Train     2: 100%|███████████████████████| 378/378 [00:58<00:00,  6.45it/s, GPU RAM: 1.59 G/14.75 G]
Train     3: 100%|███████████████████████| 378/378 [00:57<00:00,  6.53it/s, GPU RAM: 1.59 G/14.75 G]
Train     4: 100%|███████████████████████| 378/378 [00:58<00:00,  6.49it/s, GPU RAM: 1.59 G/14.75 G]
Train     5: 100%|███████████████████████| 378/378 [00:58<00:00,  6.47it/s, GPU RAM: 1.59 G/14.75 G]
Train     6: 100%|███████████████████████| 378/378 [00:58<00:00,  6.48it/s, GPU RAM: 1.59 G/14.75 G]
Train     7: 100%|███████████████████████| 378/378 [00:57<00:00,  6.55it/s, GPU RAM: 1.59 G/14.75 G]
Train     8: 100%|███████████████████████| 378/378 [00:58

It took 22.12 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.061), ('mrr@10', 0.1361), ('ndcg@10', 0.0628), ('hit@10', 0.2648), ('precision@10', 0.0331)])}
running LightGCN...


command line args [-f /root/.local/share/jupyter/runtime/kernel-ffe82819-fe22-4eb8-992f-4401f79ca9dd.json] will not be used in RecBole
Train     0: 100%|███████████████████████| 378/378 [00:51<00:00,  7.28it/s, GPU RAM: 1.59 G/14.75 G]
Train     1: 100%|███████████████████████| 378/378 [00:52<00:00,  7.23it/s, GPU RAM: 1.59 G/14.75 G]
Train     2: 100%|███████████████████████| 378/378 [00:55<00:00,  6.86it/s, GPU RAM: 1.59 G/14.75 G]
Train     3: 100%|███████████████████████| 378/378 [00:52<00:00,  7.17it/s, GPU RAM: 1.59 G/14.75 G]
Train     4: 100%|███████████████████████| 378/378 [00:52<00:00,  7.18it/s, GPU RAM: 1.59 G/14.75 G]
Train     5: 100%|███████████████████████| 378/378 [00:52<00:00,  7.23it/s, GPU RAM: 1.59 G/14.75 G]
Train     6: 100%|███████████████████████| 378/378 [00:52<00:00,  7.27it/s, GPU RAM: 1.59 G/14.75 G]
Train     7: 100%|███████████████████████| 378/378 [00:52<00:00,  7.20it/s, GPU RAM: 1.59 G/14.75 G]
Train     8: 100%|███████████████████████| 378/378 [00:52

It took 15.61 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0792), ('mrr@10', 0.1685), ('ndcg@10', 0.0795), ('hit@10', 0.3385), ('precision@10', 0.0441)])}
CPU times: user 1h 31min, sys: 6min 14s, total: 1h 37min 15s
Wall time: 1h 44min 24s


## Лучшими моделями (по `ndcg@10` на тесте) оказались `ItemKNN` (0.0880), `RecVAE` (0.0818), `MultiVAE` (0.0816)

In [None]:
# Train MultiVAE once more with the shared settings; the checkpoint written by
# this run is the one loaded in the following cell.
result = run_recbole(model="MultiVAE", dataset="recbox_data", config_dict=parameter_dict)

command line args [-f /root/.local/share/jupyter/runtime/kernel-38adbd9a-5a34-47f0-bcee-cd772e2c89d3.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|███████████████████████████| 7/7 [00:00<00:00, 15.03it/s, GPU RAM: 0.67 G/14.75 G]
Train     1: 100%|███████████████████████████| 7/7 [00:00<00:00, 16.96it/s, GPU RAM: 0.67 G/14.75 G]
Train     2: 100%|███████████████████████████| 7/7 [00:00<00:00, 16.31it/s, GPU RAM: 0.67 G/14.75 G]
Train     3: 100%|███████████████████████████| 7/7 [00:00<00:00, 16.58it/s, GPU RAM: 0.67 G/14.75 G]
Train     4: 100%|███████████████████████████| 7/7 [00:00<00:00, 16.93it/s, GPU RAM: 0.67 G/14.75 G]
Train     5: 100%|███████████████████████████| 7/7 [00:00<00:00, 15.74it/s, GPU RAM: 0.67 G/14.75 G]
Train     6: 100%|███████████████████████████| 7/7 [00:00<00:00, 13.57it/s, GPU RAM: 0.67 G/14.75 G]
Train     7: 100%|███████████████████████████| 7/7 [00:00<00:00, 1

In [None]:
# Rebuild the MultiVAE architecture and restore trained weights.
# Instead of a hard-coded, timestamped absolute path (which never matches the
# file produced by a fresh run), pick the most recently written MultiVAE
# checkpoint from the configured checkpoint directory.
model = MultiVAE(config, dataset)

checkpoint_dir = config["checkpoint_dir"]
multivae_checkpoints = [
    os.path.join(checkpoint_dir, file_name)
    for file_name in os.listdir(checkpoint_dir)
    if file_name.startswith("MultiVAE-") and file_name.endswith(".pth")
]
if not multivae_checkpoints:
    raise FileNotFoundError(
        f"No MultiVAE checkpoint found in {checkpoint_dir!r}; train the model first."
    )
checkpoint_path = max(multivae_checkpoints, key=os.path.getmtime)

# `map_location` lets a checkpoint saved on GPU be restored on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=config["device"])
model.load_state_dict(checkpoint["state_dict"])

Max value of user's history interaction records has reached 23.254401942926535% of the total.


<All keys matched successfully>

Напишем функцию, чтобы сделать предсказания для каждого из пользователей

In [None]:
# Move the restored model to the device stored in the RecBole config.
model = model.to(config["device"])

In [None]:
def recommend_to_user(user_id,
                      dataset,
                      model,
                      k=10,
                      device=None):
    """Return the top-``k`` item tokens that ``model`` scores highest for one user.

    Args:
        user_id: External user token, as stored in
            ``dataset.field2token_id[dataset.uid_field]``.
        dataset: Dataset the model was built on; supplies the token <-> internal
            id mappings, the interaction table and ``item_num``.
        model: Recommender exposing ``full_sort_predict``.
        k: Number of items to return (default 10). Capped at the number of
            non-padding items so ``torch.topk`` cannot be asked for too many.
        device: Device the model input is moved to. Defaults to the device of
            the model's parameters.

    Returns:
        ``[[token_1, ..., token_k]]`` - a one-element list holding the ranked
        item tokens (callers read ``result[0]``) - or ``[]`` when the user is
        unknown, is the ``"[PAD]"`` placeholder, or has no interaction rows.
    """
    uid_field = dataset.uid_field
    if user_id == "[PAD]" or user_id not in dataset.field2token_id[uid_field]:
        return []

    uid_series = dataset.token2id(uid_field, [user_id])
    user_rows = np.flatnonzero(np.isin(dataset[uid_field].numpy(), uid_series))
    if user_rows.size == 0:
        # Known token but nothing to feed the model with.
        return []

    if device is None:
        device = next(model.parameters()).device

    model.eval()
    with torch.no_grad():
        # Only the first row of scores was ever used, so score just the user's
        # first interaction row instead of all of them.
        new_inter = dataset[user_rows[:1]].to(device)
        # Item count comes from the `dataset` argument, not from a notebook global.
        new_scores = model.full_sort_predict(new_inter).view(-1, dataset.item_num)
        # Index 0 is the padding item and must never be recommended.
        new_scores[:, 0] = -np.inf
        top_k = min(k, new_scores.shape[1] - 1)
        recommended_item_indices = torch.topk(new_scores, top_k).indices[0].tolist()
    return dataset.id2token(dataset.iid_field, [recommended_item_indices]).tolist()

In [None]:
# Ask the model for recommendations for every user token known to the dataset.
users = dataset.field2token_id[dataset.uid_field]
recommendations = {}
for user_id in users:
    user_recs = recommend_to_user(user_id, dataset, model)
    if not user_recs:
        # An empty list means there is nothing to store for this token.
        continue
    # `recommend_to_user` wraps the ranked items in an outer list; keep the inner one.
    recommendations[user_id] = user_recs[0]

In [None]:
# Convert user and item tokens to plain ints before serialising.
recs = {
    int(uid): [int(item) for item in items]
    for uid, items in recommendations.items()
}

In [None]:
# Persist the recommendations for the serving side.
# A relative path is used, like every other data file in this notebook; it resolves
# to the same file as before when the working directory is /content.
with open("MultiVAE_recs.pkl", "wb") as f:
    pickle.dump(recs, f)

## Добавив предсказания модели в сервис и протестировав их в боте, получил метрику `MAP@10 = 0.0881563` (>0.075)