<a href="https://colab.research.google.com/github/Sergey-Kit/RecoServiceTemplate/blob/hww_5/itmo_recsys_dz_5_Recbole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install recbole ray kmeans_pytorch

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.6/62.6 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [18]:
import ast
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle

import torch
from recbole.model.general_recommender.multivae import MultiVAE

import warnings
warnings.filterwarnings('ignore')

from collections import Counter
from random import randint, random
from scipy.sparse import coo_matrix, hstack
from sklearn.metrics.pairwise import euclidean_distances, cosine_distances, cosine_similarity
import logging
from logging import getLogger
from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.model.sequential_recommender import GRU4Rec, Caser
from recbole.trainer import Trainer
from recbole.utils import init_seed, init_logger
from recbole.quick_start import run_recbole

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [52]:
!wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_KION.zip
!unzip -o data_KION.zip
!rm data_KION.zip

Archive:  data_KION.zip
  inflating: data_original/interactions.csv  
  inflating: __MACOSX/data_original/._interactions.csv  
  inflating: data_original/users.csv  
  inflating: __MACOSX/data_original/._users.csv  
  inflating: data_original/items.csv  
  inflating: __MACOSX/data_original/._items.csv  


In [53]:
items = pd.read_csv('data_original/items.csv')
interactions = pd.read_csv("data_original/interactions.csv",
                           parse_dates=['last_watch_dt'])
users = pd.read_csv('data_original/users.csv')

## User /  Item / Interaction preparation

In [54]:
users.fillna('unknown', inplace=True)
users['age'] = users['age'].astype('category')
users['income'] = users['income'].astype('category')
users['sex'] = users['sex'].astype('category')
users['kids_flg'] = users['kids_flg'].astype('bool')

items['content_type'] = items['content_type'].astype('category')
items['title'] = items['title'].str.lower()
items['title_orig'] = items['title_orig'].fillna('unknown')

items['release_year'] = items['release_year'].fillna(2020)
items.loc[items['release_year'] < 1920, 'release_year_cat'] = 'inf_1920'
items.loc[items['release_year'] >= 2020, 'release_year_cat'] = '2020_inf'
for i in range (1920, 2020, 10):
    items.loc[(items['release_year'] >= i) & (items['release_year'] < i+10), 'release_year_cat'] = f'{i}-{i+10}'
items = items.drop(columns=['release_year'])
items['release_year_cat'] = items['release_year_cat'].astype('category')

items['genres'] = items['genres'].astype('category')

items['countries'] = items['countries'].fillna('Россия')
items['countries'] = items['countries'].str.lower()
items['countries'] = items['countries'].apply(
    lambda x: ', '.join(sorted(list(set(x.split(', '))))))
items['countries'] = items['countries'].astype('category')

items['for_kids'] = items['for_kids'].fillna(0).astype('bool')
items['age_rating'] = items['age_rating'].fillna(0).astype('category')

items['studios'] = items['studios'].fillna('unknown').str.lower()
items['studios'] = items['studios'].apply(
    lambda x: ', '.join(sorted(list(set(x.split(', '))))))
items['studios'] = items['studios'].astype('category')

items['directors'] = items['directors'].fillna('unknown').str.lower().\
  astype('category')

items['actors'] = items['actors'].fillna('unknown').astype('category')

items['keywords'] = items['keywords'].fillna('unknown').\
  apply(lambda x: list(x.lower().replace(',','').split()))

items['description'] = items['description'].fillna('unknown')

interactions['watched_pct'] = interactions['watched_pct'].astype(pd.Int8Dtype())
interactions['watched_pct'] = interactions['watched_pct'].fillna(0)

In [55]:
user_cat_feats = ["age", "income", "sex", "kids_flg"]
users_ohe = users.user_id
for feat in user_cat_feats:
    ohe_feat = pd.get_dummies(users[feat], prefix=feat)
    users_ohe = pd.concat([users_ohe, ohe_feat], axis=1)
users_ohe.head()

Unnamed: 0,user_id,age_age_18_24,age_age_25_34,age_age_35_44,age_age_45_54,age_age_55_64,age_age_65_inf,age_unknown,income_income_0_20,income_income_150_inf,income_income_20_40,income_income_40_60,income_income_60_90,income_income_90_150,income_unknown,sex_unknown,sex_Ж,sex_М,kids_flg_False,kids_flg_True
0,973171,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,1
1,962099,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0
2,1047345,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,1,0
3,721985,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,1,0
4,704055,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1,0,1,0


In [56]:
users_ohe["uid"] = users_ohe["user_id"].astype("category")
users_ohe["uid"] = users_ohe["uid"].cat.codes

uid_to_user_id = users_ohe[["uid", "user_id"]].to_dict()["user_id"]
user_id_to_uid = {v:k for k, v in zip(uid_to_user_id.keys(), uid_to_user_id.values())}

users_ohe.drop(columns=["uid"], inplace=True)

In [57]:
print(f"N users before: {interactions.user_id.nunique()}")
print(f"N items before: {interactions.item_id.nunique()}\n")

# отфильтруем все события взаимодействий, в которых пользователь посмотрел
# фильм менее чем на 10 процентов
interactions = interactions[interactions.watched_pct > 10]

# соберем всех пользователей, которые посмотрели
# больше 10 фильмов
valid_users = []

c = Counter(interactions.user_id)
for user_id, entries in c.most_common():
    if entries > 10:
        valid_users.append(user_id)

# соберем все фильмы, которые посмотрели больше 3 пользователей
valid_items = []

c = Counter(interactions.item_id)
for item_id, entries in c.most_common():
    if entries > 3:
        valid_items.append(item_id)

# отбросим непопулярные фильмы и неактивных юзеров
interactions = interactions[interactions.user_id.isin(valid_users)]
interactions = interactions[interactions.item_id.isin(valid_items)]

print(f"N users after: {interactions.user_id.nunique()}")
print(f"N items after: {interactions.item_id.nunique()}")

N users before: 962179
N items before: 15706

N users after: 79515
N items after: 9387


In [None]:
common_users = set(interactions.user_id.unique()).intersection(set(users_ohe.user_id.unique()))

interactions = interactions[interactions.user_id.isin(common_users)]
users_ohe = users_ohe[users_ohe.user_id.isin(common_users)]

common_items = set(interactions.item_id.unique()).intersection(set(items_ohe.item_id.unique()))

interactions = interactions[interactions.item_id.isin(common_items)]
items_ohe = items_ohe[items_ohe.item_id.isin(common_items)]

print(len(common_users))
print(len(common_items))

In [None]:
max_date = interactions['last_watch_dt'].max()

interactions_train = interactions[(interactions['last_watch_dt'] < max_date - pd.Timedelta(days=7))].copy()
users_ohe_train = users_ohe[users_ohe['user_id'].isin(interactions_train['user_id'].unique())].copy()
items_ohe_train = items_ohe[items_ohe['item_id'].isin(interactions_train['item_id'].unique())].copy()

interactions_test = interactions[(interactions['last_watch_dt'] >= max_date - pd.Timedelta(days=7))].copy()

# оставляем только теплых пользователей в тесте
interactions_hot_test = interactions_test[(interactions_test['user_id'].isin(interactions_train['user_id'].unique())) & (interactions_test['item_id'].isin(interactions_train['item_id'].unique()))].copy()
users_ohe_hot_test = users_ohe[users_ohe['user_id'].isin(interactions_hot_test['user_id'].unique())].copy()

catalog = interactions_train['item_id'].unique()

print(f"train: {interactions_train.shape}")
print(f"test: {interactions_test.shape}")
print(f"hot test: {interactions_hot_test.shape}")

In [None]:
interactions_train["uid"] = interactions_train["user_id"].astype("category")
interactions_train["uid"] = interactions_train["uid"].cat.codes

interactions_hot_test["uid"] = interactions_hot_test["user_id"].astype("category")
interactions_hot_test["uid"] = interactions_hot_test["uid"].cat.codes

interactions_train["iid"] = interactions_train["item_id"].astype("category")
interactions_train["iid"] = interactions_train["iid"].cat.codes

mapping_iid = dict(zip(interactions_train['item_id'], interactions_train['iid']))
interactions_hot_test['iid'] = interactions_hot_test['item_id'].map(mapping_iid)
del mapping_iid

print(sorted(interactions_train.iid.unique())[:5])
print(sorted(interactions_train.uid.unique())[:5])
print(sorted(interactions_hot_test.iid.unique())[:5])
print(sorted(interactions_hot_test.uid.unique())[:5])

In [None]:
iid_to_item_id_train = interactions_train[["iid", "item_id"]].drop_duplicates().set_index("iid").to_dict()["item_id"]
item_id_to_iid_train = interactions_train[["iid", "item_id"]].drop_duplicates().set_index("item_id").to_dict()["iid"]

uid_to_user_id_train = interactions_train[["uid", "user_id"]].drop_duplicates().set_index("uid").to_dict()["user_id"]
user_id_to_uid_train = interactions_train[["uid", "user_id"]].drop_duplicates().set_index("user_id").to_dict()["uid"]

uid_to_user_id_hot_test = interactions_hot_test[["uid", "user_id"]].drop_duplicates().set_index("uid").to_dict()["user_id"]
user_id_to_uid_hot_test = interactions_hot_test[["uid", "user_id"]].drop_duplicates().set_index("user_id").to_dict()["uid"]

In [None]:
items_ohe_train["iid"] = items_ohe_train["item_id"].apply(lambda x: item_id_to_iid_train[x])
items_ohe_train = items_ohe_train.set_index("iid")

users_ohe_train["uid"] = users_ohe_train["user_id"].apply(lambda x: user_id_to_uid_train[x])
users_ohe_train = users_ohe_train.set_index("uid")

users_ohe_hot_test["uid"] = users_ohe_hot_test["user_id"].apply(lambda x: user_id_to_uid_hot_test[x])
users_ohe_hot_test = users_ohe_hot_test.set_index("uid")

In [None]:
items_ohe_train.sort_index(inplace=True)
users_ohe_train.sort_index(inplace=True)
users_ohe_hot_test.sort_index(inplace=True)

In [37]:
interactions['t_dat'] = pd.to_datetime(interactions['last_watch_dt'], format="%Y-%m-%d")
interactions['timestamp'] = interactions.t_dat.values.astype(np.int64) // 10 ** 9

In [38]:
df = interactions[['user_id', 'item_id', 'timestamp']].rename(
    columns={'user_id': 'user_id:token', 'item_id': 'item_id:token', 'timestamp': 'timestamp:float'})

In [39]:
!mkdir recbox_data

mkdir: cannot create directory ‘recbox_data’: File exists


In [10]:
df.to_csv('recbox_data/recbox_data.inter', index=False, sep='\t')

In [40]:
parameter_dict = {
    'data_path': '',
    'USER_ID_FIELD': 'user_id',
    'ITEM_ID_FIELD': 'item_id',
    'TIME_FIELD': 'timestamp',
    'device': 'GPU',
    'user_inter_num_interval': "[40,inf)",
    'item_inter_num_interval': "[40,inf)",
    'load_col': {'inter': ['user_id', 'item_id', 'timestamp']},
    'neg_sampling': None,
    'epochs': 10,
    'eval_args': {
        'split': {'RS': [9, 0, 1]},
        'group_by': 'user',
        'order': 'TO',
        'mode': 'full'}
}
config = Config(model='MultiVAE', dataset='recbox_data', config_dict=parameter_dict)

# init random seed
init_seed(config['seed'], config['reproducibility'])

# logger initialization
init_logger(config)
logger = getLogger()
# Create handlers
c_handler = logging.StreamHandler()
c_handler.setLevel(logging.INFO)
logger.addHandler(c_handler)

# write config info into log
# logger.info(config)

command line args [-f /root/.local/share/jupyter/runtime/kernel-2a9840b7-23cb-44a5-b52a-2d8504614f24.json] will not be used in RecBole


In [41]:
dataset = create_dataset(config)
logger.info(dataset)

In [42]:
# dataset splitting
train_data, valid_data, test_data = data_preparation(config, dataset)

In [14]:
import time
from recbole.quick_start import run_recbole

model_list = ['MultiVAE', 'MultiDAE', 'MacridVAE',"BPR", "NeuMF","RecVAE", 'RepeatNet', "LightGCN"]  # Added other models

for model_name in model_list:
    print(f"Running {model_name}...")
    start_time = time.time()

    # Run RecBole with the specified model
    result = run_recbole(model=model_name, dataset="recbox_data", config_dict=parameter_dict)

    elapsed_time = time.time() - start_time
    print(f"{model_name} took {elapsed_time / 60:.2f} mins")
    print(result)


Running MultiVAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-78e4b8d9-24ce-4fb7-960c-f0cc1344211a.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.14it/s]
Train     1: 100%|████████████████████████████████████████████████████| 7/7 [00:07<00:00,  1.03s/it]
Train     2: 100%|████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.17it/s]
Train     3: 100%|████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.00it/s]
Train     4: 100%|████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.07it/s]
Train     5: 100%|████████████████████████████████████████████████████| 7/7 [00:07<00:00,  1.03s/it]
Train     6: 100%|████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.18it/s]
Train     7: 100%|████████████████████████████████████████████████

MultiVAE took 2.85 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.084), ('mrr@10', 0.1695), ('ndcg@10', 0.0825), ('hit@10', 0.3503), ('precision@10', 0.0467)])}
Running MultiDAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-78e4b8d9-24ce-4fb7-960c-f0cc1344211a.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.11it/s]
Train     1: 100%|████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.08it/s]
Train     2: 100%|████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.19it/s]
Train     3: 100%|████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.03it/s]
Train     4: 100%|████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.21it/s]
Train     5: 100%|████████████████████████████████████████████████████| 7/7 [00:07<00:00,  1.00s/it]
Train     6: 100%|████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.23it/s]
Train     7: 100%|████████████████████████████████████████████████

MultiDAE took 2.77 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0829), ('mrr@10', 0.1655), ('ndcg@10', 0.081), ('hit@10', 0.3438), ('precision@10', 0.0459)])}
Running MacridVAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-78e4b8d9-24ce-4fb7-960c-f0cc1344211a.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|████████████████████████████████████████████████████| 7/7 [00:52<00:00,  7.53s/it]
Train     1: 100%|████████████████████████████████████████████████████| 7/7 [00:50<00:00,  7.21s/it]
Train     2: 100%|████████████████████████████████████████████████████| 7/7 [00:49<00:00,  7.12s/it]
Train     3: 100%|████████████████████████████████████████████████████| 7/7 [00:48<00:00,  6.90s/it]
Train     4: 100%|████████████████████████████████████████████████████| 7/7 [00:52<00:00,  7.45s/it]
Train     5: 100%|████████████████████████████████████████████████████| 7/7 [01:00<00:00,  8.59s/it]
Train     6: 100%|████████████████████████████████████████████████████| 7/7 [00:48<00:00,  6.94s/it]
Train     7: 100%|████████████████████████████████████████████████

MacridVAE took 12.60 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0835), ('mrr@10', 0.1574), ('ndcg@10', 0.0788), ('hit@10', 0.3499), ('precision@10', 0.0461)])}
Running BPR...


command line args [-f /root/.local/share/jupyter/runtime/kernel-78e4b8d9-24ce-4fb7-960c-f0cc1344211a.json] will not be used in RecBole
Train     0: 100%|████████████████████████████████████████████████| 378/378 [00:12<00:00, 30.09it/s]
Train     1: 100%|████████████████████████████████████████████████| 378/378 [00:12<00:00, 30.21it/s]
Train     2: 100%|████████████████████████████████████████████████| 378/378 [00:12<00:00, 29.94it/s]
Train     3: 100%|████████████████████████████████████████████████| 378/378 [00:12<00:00, 30.01it/s]
Train     4: 100%|████████████████████████████████████████████████| 378/378 [00:12<00:00, 30.04it/s]
Train     5: 100%|████████████████████████████████████████████████| 378/378 [00:12<00:00, 30.02it/s]
Train     6: 100%|████████████████████████████████████████████████| 378/378 [00:12<00:00, 29.98it/s]
Train     7: 100%|████████████████████████████████████████████████| 378/378 [00:14<00:00, 25.87it/s]
Train     8: 100%|███████████████████████████████████████

BPR took 3.27 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0824), ('mrr@10', 0.1716), ('ndcg@10', 0.0819), ('hit@10', 0.3479), ('precision@10', 0.0457)])}
Running NeuMF...


command line args [-f /root/.local/share/jupyter/runtime/kernel-78e4b8d9-24ce-4fb7-960c-f0cc1344211a.json] will not be used in RecBole
Train     0: 100%|████████████████████████████████████████████████| 755/755 [00:48<00:00, 15.67it/s]
Train     1: 100%|████████████████████████████████████████████████| 755/755 [00:49<00:00, 15.32it/s]
Train     2: 100%|████████████████████████████████████████████████| 755/755 [00:48<00:00, 15.61it/s]
Train     3: 100%|████████████████████████████████████████████████| 755/755 [00:49<00:00, 15.39it/s]
Train     4: 100%|████████████████████████████████████████████████| 755/755 [00:55<00:00, 13.61it/s]
Train     5: 100%|████████████████████████████████████████████████| 755/755 [00:48<00:00, 15.66it/s]
Train     6: 100%|████████████████████████████████████████████████| 755/755 [00:49<00:00, 15.19it/s]
Train     7: 100%|████████████████████████████████████████████████| 755/755 [00:50<00:00, 14.81it/s]
Train     8: 100%|███████████████████████████████████████

NeuMF took 11.08 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.069), ('mrr@10', 0.1173), ('ndcg@10', 0.0605), ('hit@10', 0.3009), ('precision@10', 0.0381)])}
Running RecVAE...


command line args [-f /root/.local/share/jupyter/runtime/kernel-78e4b8d9-24ce-4fb7-960c-f0cc1344211a.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.62s/it]
Train     0: 100%|████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.60s/it]
Train     0: 100%|████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.60s/it]
Train     0: 100%|████████████████████████████████████████████████████| 7/7 [00:10<00:00,  1.49s/it]
Train     1: 100%|████████████████████████████████████████████████████| 7/7 [00:10<00:00,  1.51s/it]
Train     1: 100%|████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.61s/it]
Train     1: 100%|████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.62s/it]
Train     1: 100%|████████████████████████████████████████████████

RecVAE took 9.29 mins
{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0849), ('mrr@10', 0.1697), ('ndcg@10', 0.0828), ('hit@10', 0.3532), ('precision@10', 0.047)])}
Running RepeatNet...


command line args [-f /root/.local/share/jupyter/runtime/kernel-78e4b8d9-24ce-4fb7-960c-f0cc1344211a.json] will not be used in RecBole
Train     0:  19%|████████▌                                     | 138/743 [14:46<1:04:47,  6.42s/it]


KeyboardInterrupt: ignored

# Kyboard stoped cause it was taking too much time

Choosed MultiVAE

In [43]:
result = run_recbole(model='MultiVAE', dataset = 'recbox_data',config_dict = parameter_dict )

command line args [-f /root/.local/share/jupyter/runtime/kernel-2a9840b7-23cb-44a5-b52a-2d8504614f24.json] will not be used in RecBole
command line args [-f /root/.local/share/jupyter/runtime/kernel-2a9840b7-23cb-44a5-b52a-2d8504614f24.json] will not be used in RecBole
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Max value of user's history interaction records has reached 20.9471766848816% of the total.
Train     0: 100%|████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.23it/s]
Train     1: 100%|████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.01it/s]
Train     2: 100%|████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.23it/s]
Train     3: 100%|████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.01it/s]
Train     4: 100%|████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.21it/s]
Train     5: 100%|███████████████████████

In [44]:
result

{'best_valid_score': -inf,
 'valid_score_bigger': True,
 'best_valid_result': None,
 'test_result': OrderedDict([('recall@10', 0.084),
              ('mrr@10', 0.1695),
              ('ndcg@10', 0.0825),
              ('hit@10', 0.3503),
              ('precision@10', 0.0467)])}

In [46]:
# After trained model 'model' and configuration 'config'
# Save the model state dictionary and other necessary information

model = MultiVAE(config, dataset)
torch.save({
    'state_dict': model.state_dict(),
    'config': config
}, 'recbole.pth')

Max value of user's history interaction records has reached 23.254401942926535% of the total.
Max value of user's history interaction records has reached 23.254401942926535% of the total.


In [47]:
# Load model and optimizer state dictionaries
# model = MultiVAE(config, dataset)
checkpoint = torch.load('recbole.pth')
model.load_state_dict(checkpoint['state_dict'])


<All keys matched successfully>

In [50]:
import numpy as np

def recommend_items_to_user(external_user_id, dataset, model):
    if external_user_id in dataset.field2token_id[dataset.uid_field] and external_user_id != "[PAD]":
        # Map external user ID to internal user index
        internal_user_index = dataset.field2token_id[dataset.uid_field][external_user_id]

        # Extract user metadata features and interaction vector
        user_metadata_features = dataset.users_ohe.drop(["user_id"], axis=1).iloc[internal_user_index]
        user_interaction_vector = dataset.interactions_vec[internal_user_index]

        # Predict user vector using the trained user-to-vector model
        user_vector = dataset.u2v.predict(
            [np.array(user_metadata_features).reshape(1, -1), np.array(user_interaction_vector).reshape(1, -1)],
            verbose=False,
        )

        # Instead of calculating distance for all items, just select a random subset
        sampled_item_indices = np.random.choice(dataset.item_embeddings.shape[0], size=100, replace=False)
        sampled_item_embeddings = dataset.item_embeddings[sampled_item_indices, :]

        # Calculate distances between the user vector and sampled item embeddings
        distances = np.linalg.norm(user_vector - sampled_item_embeddings, axis=1)

        # Get the indices of the top 10 items from the sampled set
        topn_item_indices_sampled = np.argsort(distances)[:10]

        # Map internal item indices to item_ids
        topn_item_ids = [dataset.iid_to_item_id[iid] for iid in topn_item_indices_sampled]

        return topn_item_ids

    return []

In [None]:
recos = {}
users = dataset.field2token_id[dataset.uid_field]
for user_id in users:
    recommended_items = recommend_items_to_user(user_id, dataset, model)
    if recommended_items:
        recos[user_id] = recommended_items


In [None]:
# with open("recbole_offline.pkl", "wb") as f:
#     pickle.dump(recos, f)

In [24]:
dataset

[1;35mrecbox_data[0m
[1;34mThe number of users[0m: 13355
[1;34mAverage actions of users[0m: 63.815710648494836
[1;34mThe number of items[0m: 3294
[1;34mAverage actions of items[0m: 258.78985727300335
[1;34mThe number of inters[0m: 852195
[1;34mThe sparsity of the dataset[0m: 98.06281322904924%
[1;34mRemain Fields[0m: ['user_id', 'item_id', 'timestamp']