From 8fc3999a03e3ac4f5ad5e7fa6c79f934a6251b70 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Fri, 25 Oct 2024 13:17:48 +0300 Subject: [PATCH 1/8] Added bert4rec --- examples/bert4rec.ipynb | 743 ++++++++++++++++++++++++++++++++++++ rectools/models/bert4rec.py | 578 ++++++++++++++++++++++++++++ rectools/models/sasrec.py | 36 +- 3 files changed, 1341 insertions(+), 16 deletions(-) create mode 100644 examples/bert4rec.ipynb create mode 100644 rectools/models/bert4rec.py diff --git a/examples/bert4rec.ipynb b/examples/bert4rec.ipynb new file mode 100644 index 00000000..72acebf9 --- /dev/null +++ b/examples/bert4rec.ipynb @@ -0,0 +1,743 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"/data/home/maspirina1/tasks/repo/RecTools/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import threadpoolctl\n", + "from pathlib import Path\n", + "from lightning_fabric import seed_everything\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from rectools import Columns\n", + "\n", + "\n", + "from rectools.dataset import Dataset\n", + "from rectools.metrics import MAP, calc_metrics, MeanInvUserFreq, Serendipity\n", + "from rectools.models.bert4rec import CatFeaturesItemNet, IdEmbeddingsItemNet, BERT4RecModel" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", + "os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\n", + "threadpoolctl.threadpool_limits(1, \"blas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prepare data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + 
"metadata": {}, + "outputs": [], + "source": [ + "# %%time\n", + "# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_original.zip\n", + "# !unzip -o data_original.zip\n", + "# !rm data_original.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_PATH = Path(\"data_original\")\n", + "\n", + "interactions = (\n", + " pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=[\"last_watch_dt\"])\n", + " .rename(columns={\"last_watch_dt\": \"datetime\"})\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)\n", + "\n", + "# Split to train / test\n", + "max_date = interactions[Columns.Datetime].max()\n", + "train = interactions[interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)].copy()\n", + "test = interactions[interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)].copy()\n", + "train.drop(train.query(\"total_dur < 300\").index, inplace=True)\n", + "\n", + "# drop items with less than 20 interactions in train\n", + "items = train[\"item_id\"].value_counts()\n", + "items = items[items >= 20]\n", + "items = items.index.to_list()\n", + "train = train[train[\"item_id\"].isin(items)]\n", + " \n", + "# drop users with less than 2 interactions in train\n", + "users = train[\"user_id\"].value_counts()\n", + "users = users[users >= 2]\n", + "users = users.index.to_list()\n", + "train = train[(train[\"user_id\"].isin(users))]\n", + "\n", + "users = train[\"user_id\"].drop_duplicates().to_list()\n", + "\n", + "# drop cold users from test\n", + "test_users_sasrec = test[Columns.User].unique()\n", + "cold_users = set(test[Columns.User]) - set(train[Columns.User])\n", + "test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)\n", + "test_users = 
test[Columns.User].unique()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "items = pd.read_csv(DATA_PATH / 'items.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Process item features to the form of a flatten dataframe\n", + "items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()\n", + "items[\"genre\"] = items[\"genres\"].str.lower().str.replace(\", \", \",\", regex=False).str.split(\",\")\n", + "genre_feature = items[[\"item_id\", \"genre\"]].explode(\"genre\")\n", + "genre_feature.columns = [\"id\", \"value\"]\n", + "genre_feature[\"feature\"] = \"genre\"\n", + "content_feature = items.reindex(columns=[Columns.Item, \"content_type\"])\n", + "content_feature.columns = [\"id\", \"value\"]\n", + "content_feature[\"feature\"] = \"content_type\"\n", + "item_features = pd.concat((genre_feature, content_feature))\n", + "\n", + "candidate_items = interactions['item_id'].drop_duplicates().astype(int)\n", + "test[\"user_id\"] = test[\"user_id\"].astype(int)\n", + "test[\"item_id\"] = test[\"item_id\"].astype(int)\n", + "\n", + "catalog=train[Columns.Item].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_no_features = Dataset.construct(\n", + " interactions_df=train,\n", + ")\n", + "\n", + "dataset_item_features = Dataset.construct(\n", + " interactions_df=train,\n", + " item_features_df=item_features,\n", + " cat_item_features=[\"genre\", \"content_type\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "metrics_name = {\n", + " 'MAP': MAP,\n", + " 'MIUF': MeanInvUserFreq,\n", + " 'Serendipity': Serendipity\n", + " \n", + "\n", + "}\n", + "metrics = {}\n", + "for metric_name, metric in metrics_name.items():\n", + " for k in (1, 5, 10):\n", + " 
metrics[f'{metric_name}@{k}'] = metric(k=k)\n", + "\n", + "# list with metrics results of all models\n", + "features_results = []\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BERT4Rec" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 32\n" + ] + }, + { + "data": { + "text/plain": [ + "32" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "RANDOM_SEED = 32\n", + "torch.use_deterministic_algorithms(True)\n", + "seed_everything(RANDOM_SEED, workers=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### BERT4Rec with item ids embeddings in ItemNetBlock" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. 
Your mileage may vary.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], + "source": [ + "model = BERT4RecModel(\n", + " n_blocks=3,\n", + " n_heads=4,\n", + " dropout_rate=0.2,\n", + " session_max_len=32,\n", + " lr=1e-3,\n", + " epochs=5,\n", + " verbose=1,\n", + " mask_prob=0.5,\n", + " deterministic=True,\n", + " item_net_block_types=(IdEmbeddingsItemNet, ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "\n", + " | Name | Type | Params\n", + "---------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 1.3 M \n", + "---------------------------------------------------------------\n", + "1.3 M Trainable params\n", + "0 Non-trainable params\n", + "1.3 M Total params\n", + "5.291 Total estimated model params size (MB)\n", + "/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "221b2c4ffa834b80a47552e1b8ffd21d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? 
[00:00:1\u001b[0m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/rectools/models/base.py:69\u001b[0m, in \u001b[0;36mModelBase.fit\u001b[0;34m(self, dataset, *args, **kwargs)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m: T, dataset: Dataset, \u001b[38;5;241m*\u001b[39margs: tp\u001b[38;5;241m.\u001b[39mAny, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: tp\u001b[38;5;241m.\u001b[39mAny) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 57\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;124;03m Fit model.\u001b[39;00m\n\u001b[1;32m 59\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[38;5;124;03m self\u001b[39;00m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 69\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_fitted \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/rectools/models/bert4rec.py:485\u001b[0m, in \u001b[0;36mBERT4RecModel._fit\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 483\u001b[0m lightning_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module_type(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtorch_model, 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlr, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss)\n\u001b[1;32m 484\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer \u001b[38;5;241m=\u001b[39m deepcopy(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer)\n\u001b[0;32m--> 485\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlightning_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:544\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 542\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 543\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 544\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 545\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 546\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:44\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 44\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 47\u001b[0m _call_teardown_hook(trainer)\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:580\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 574\u001b[0m ckpt_path 
\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 575\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 576\u001b[0m ckpt_path,\n\u001b[1;32m 577\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 578\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 579\u001b[0m )\n\u001b[0;32m--> 580\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 583\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:987\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 982\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 986\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 987\u001b[0m 
results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 989\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 990\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 991\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 992\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1033\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1031\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1032\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1033\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1034\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1035\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one 
optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:157\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 154\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 157\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 160\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/core/module.py:1303\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1264\u001b[0m 
\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1266\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1269\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1270\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1272\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1273\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1301\u001b[0m \n\u001b[1;32m 1302\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1303\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py:152\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 152\u001b[0m step_output \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py:239\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 238\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 239\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File 
\u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/optim/optimizer.py:391\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 386\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 387\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 388\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 389\u001b[0m )\n\u001b[0;32m--> 391\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 392\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 394\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/optim/optimizer.py:76\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 75\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 76\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 78\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/optim/adam.py:148\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 148\u001b[0m loss 
\u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 151\u001b[0m params_with_grad \u001b[38;5;241m=\u001b[39m []\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Optimizer,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n", + "File 
\u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. 
If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:318\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[1;32m 317\u001b[0m \u001b[38;5;66;03m# manually capture logged metrics\u001b[39;00m\n\u001b[0;32m--> 318\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 319\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_result_cls\u001b[38;5;241m.\u001b[39mfrom_training_step_output(training_step_output, trainer\u001b[38;5;241m.\u001b[39maccumulate_grad_batches)\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:309\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 309\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 312\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py:391\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 391\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/rectools/models/bert4rec.py:386\u001b[0m, in \u001b[0;36mSessionEncoderLightningModule.training_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 373\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(x) \u001b[38;5;66;03m# [batch_size, session_max_len, n_items + 2]\u001b[39;00m\n\u001b[1;32m 374\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msoftmax\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 375\u001b[0m \u001b[38;5;66;03m# We are using CrossEntropyLoss with a multi-dimensional case\u001b[39;00m\n\u001b[1;32m 376\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[38;5;66;03m# Loss output will have a shape of [batch_size, session_max_len]\u001b[39;00m\n\u001b[1;32m 384\u001b[0m \u001b[38;5;66;03m# and will have zeros for every `0` target label\u001b[39;00m\n\u001b[0;32m--> 386\u001b[0m loss \u001b[38;5;241m=\u001b[39m 
\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_entropy\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 387\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogits\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnone\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m 388\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# [batch_size, session_max_len]\u001b[39;00m\n\u001b[1;32m 389\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss \u001b[38;5;241m*\u001b[39m w\n\u001b[1;32m 390\u001b[0m n \u001b[38;5;241m=\u001b[39m (loss \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(loss\u001b[38;5;241m.\u001b[39mdtype)\n", + "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/nn/functional.py:3086\u001b[0m, in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[0m\n\u001b[1;32m 3084\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m size_average \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3085\u001b[0m reduction 
\u001b[38;5;241m=\u001b[39m _Reduction\u001b[38;5;241m.\u001b[39mlegacy_get_string(size_average, reduce)\n\u001b[0;32m-> 3086\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_entropy_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_Reduction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_enum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 90.00 MiB. 
GPU " + ] + } + ], + "source": [ + "%%time\n", + "model.fit(dataset_no_features)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/bert4rec.py:181: UserWarning: 91202 target users were considered cold because of missing known items\n", + " warnings.warn(explanation)\n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/base.py:406: UserWarning: \n", + " Model `` doesn't support recommendations for cold users,\n", + " but some of given users are cold: they are not in the `dataset.user_id_map`\n", + " \n", + " warnings.warn(explanation)\n", + "100%|██████████| 740/740 [00:15<00:00, 49.03it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 17min 5s, sys: 46.6 s, total: 17min 52s\n", + "Wall time: 34.8 s\n" + ] + } + ], + "source": [ + "%%time\n", + "recos = model.recommend(\n", + " users=test_users_sasrec, \n", + " dataset=dataset_item_features,\n", + " k=10,\n", + " filter_viewed=True,\n", + " on_unsupported_targets=\"warn\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "del interactions\n", + "del model\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "recos[\"item_id\"] = recos[\"item_id\"].apply(str)\n", + "test[\"item_id\"] = test[\"item_id\"].astype(str)\n", + "metric_values = calc_metrics(metrics, recos[[\"user_id\", \"item_id\", \"rank\"]], test, train, catalog)\n", + "metric_values[\"model\"] = \"bert4rec_ids\"\n", + "features_results.append(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerank
5755503138650.6415951
5755513152970.2407842
575552344950.0721063
575553378290.0463104
57555437102-0.1499295
...............
22495510975447102-0.3630396
22495610975444151-0.3999357
22495710975447793-0.4406678
22495810975444457-0.6528259
224959109754412995-0.70826310
\n", + "

947050 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank\n", + "575550 3 13865 0.641595 1\n", + "575551 3 15297 0.240784 2\n", + "575552 3 4495 0.072106 3\n", + "575553 3 7829 0.046310 4\n", + "575554 3 7102 -0.149929 5\n", + "... ... ... ... ...\n", + "224955 1097544 7102 -0.363039 6\n", + "224956 1097544 4151 -0.399935 7\n", + "224957 1097544 7793 -0.440667 8\n", + "224958 1097544 4457 -0.652825 9\n", + "224959 1097544 12995 -0.708263 10\n", + "\n", + "[947050 rows x 4 columns]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# major recommend\n", + "recos.sort_values([\"user_id\", \"rank\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With timeline mask in the end of the block, with attention mask" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'MAP@1': 0.03386095770656615,\n", + " 'MAP@5': 0.059875092311754766,\n", + " 'MAP@10': 0.06626564554123239,\n", + " 'MIUF@1': 18.824620072061013,\n", + " 'MIUF@5': 18.824620072061013,\n", + " 'MIUF@10': 18.824620072061013,\n", + " 'Serendipity@1': 0.06777889234992873,\n", + " 'Serendipity@5': 0.04409114066936074,\n", + " 'Serendipity@10': 0.031205145274404236,\n", + " 'model': 'bert4rec_ids'}]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Without timeline mask, with attention mask" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'MAP@1': 0.031715044770102244,\n", + " 'MAP@5': 0.058107653322795036,\n", + " 'MAP@10': 0.06400667270068171,\n", + " 'MIUF@1': 18.824620072061013,\n", + " 'MIUF@5': 18.824620072061013,\n", + " 'MIUF@10': 18.824620072061013,\n", + " 'Serendipity@1': 0.0633651866321736,\n", + " 
'Serendipity@5': 0.04325255649454838,\n", + " 'Serendipity@10': 0.030283831925392017,\n", + " 'model': 'bert4rec_ids'}]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With timeline mask in the end of the block, without attention mask" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'MAP@1': 0.03521807589657321,\n", + " 'MAP@5': 0.0635501105108066,\n", + " 'MAP@10': 0.07042686574268418,\n", + " 'MIUF@1': 18.824620072061013,\n", + " 'MIUF@5': 18.824620072061013,\n", + " 'MIUF@10': 18.824620072061013,\n", + " 'Serendipity@1': 0.07181247030251835,\n", + " 'Serendipity@5': 0.048066313492978796,\n", + " 'Serendipity@10': 0.03423476251676267,\n", + " 'model': 'bert4rec_ids'}]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With timeline mask, without attention mask, 5 and 7 epochs" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'MAP@1': 0.03521807589657321,\n", + " 'MAP@5': 0.0635501105108066,\n", + " 'MAP@10': 0.07042686574268418,\n", + " 'MIUF@1': 18.824620072061013,\n", + " 'MIUF@5': 18.824620072061013,\n", + " 'MIUF@10': 18.824620072061013,\n", + " 'Serendipity@1': 0.07181247030251835,\n", + " 'Serendipity@5': 0.048066313492978796,\n", + " 'Serendipity@10': 0.03423476251676267,\n", + " 'model': 'bert4rec_ids'},\n", + " {'MAP@1': 0.03613885396129421,\n", + " 'MAP@5': 0.0626756506459862,\n", + " 'MAP@10': 0.06914741192133474,\n", + " 'MIUF@1': 18.824620072061013,\n", + " 'MIUF@5': 18.824620072061013,\n", + " 'MIUF@10': 18.824620072061013,\n", + " 'Serendipity@1': 0.07439945092656143,\n", + 
" 'Serendipity@5': 0.04685051880216868,\n", + " 'Serendipity@10': 0.033803948060973046,\n", + " 'model': 'bert4rec_ids'}]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features_results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/rectools/models/bert4rec.py b/rectools/models/bert4rec.py new file mode 100644 index 00000000..7dd4351a --- /dev/null +++ b/rectools/models/bert4rec.py @@ -0,0 +1,578 @@ +import typing as tp +import warnings +from copy import deepcopy +from typing import List, Tuple + +import numpy as np +import torch +import tqdm +from pytorch_lightning import Trainer +from scipy import sparse +from torch import nn +from torch.utils.data import DataLoader + +from rectools import Columns, ExternalIds +from rectools.dataset import Dataset, Interactions +from rectools.dataset.features import SparseFeatures +from rectools.dataset.identifiers import IdMap +from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase +from rectools.models.rank import Distance, ImplicitRanker +from rectools.models.sasrec import ( + CatFeaturesItemNet, + IdEmbeddingsItemNet, + ItemNetBase, + ItemNetConstructor, + LearnableInversePositionalEncoding, + PositionalEncodingBase, + SequenceDataset, + SessionEncoderDataPreparatorBase, + SessionEncoderLightningModuleBase, + TransformerLayersBase, +) +from rectools.types import InternalIdsArray + +PADDING_VALUE = "PAD" +MASKING_VALUE = "MASK" + + +class 
BERT4RecDataPreparator(SessionEncoderDataPreparatorBase): + """TODO""" + + def __init__( + self, + session_max_len: int, + batch_size: int, + dataloader_num_workers: int, + train_min_user_interactions: int, + mask_prob: float, + item_extra_tokens: tp.Sequence[tp.Hashable] = (PADDING_VALUE, MASKING_VALUE), + shuffle_train: bool = True, + ) -> None: + super().__init__() + self.session_max_len = session_max_len + self.batch_size = batch_size + self.dataloader_num_workers = dataloader_num_workers + self.train_min_user_interactions = train_min_user_interactions + self.item_extra_tokens = item_extra_tokens + self.mask_prob = mask_prob + self.shuffle_train = shuffle_train + # TODO: add SequenceDatasetType for fit and recommend + + def process_dataset_train(self, dataset: Dataset) -> Dataset: + """TODO""" + interactions = dataset.get_raw_interactions() + + # Filter interactions + user_stats = interactions[Columns.User].value_counts() + users = user_stats[user_stats >= self.train_min_user_interactions].index + interactions = interactions[(interactions[Columns.User].isin(users))] + interactions = interactions.sort_values(Columns.Datetime).groupby(Columns.User).tail(self.session_max_len) + + # Construct dataset + # TODO: user features are dropped for now + user_id_map = IdMap.from_values(interactions[Columns.User].values) + item_id_map = IdMap.from_values(self.item_extra_tokens) + item_id_map = item_id_map.add_ids(interactions[Columns.Item]) + + # get item features + item_features = None + if dataset.item_features is not None: + item_features = dataset.item_features + # TODO: remove assumption on SparseFeatures and add Dense Features support + if not isinstance(item_features, SparseFeatures): + raise ValueError("`item_features` in `dataset` must be `SparseFeatures` instance.") + + internal_ids = dataset.item_id_map.convert_to_internal( + item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :] + ) + sorted_item_features = item_features.take(internal_ids) + + 
dtype = sorted_item_features.values.dtype + n_features = sorted_item_features.values.shape[1] + extra_token_feature_values = sparse.csr_matrix((self.n_item_extra_tokens, n_features), dtype=dtype) + + full_feature_values: sparse.scr_matrix = sparse.vstack( + [extra_token_feature_values, sorted_item_features.values], format="csr" + ) + + item_features = SparseFeatures.from_iterables(values=full_feature_values, names=item_features.names) + + interactions = Interactions.from_raw(interactions, user_id_map, item_id_map) + + dataset = Dataset(user_id_map, item_id_map, interactions, item_features=item_features) + + self.item_id_map = dataset.item_id_map + return dataset + + def _mask_session(self, ses: List[int]) -> Tuple[List[int], List[int]]: + masked_session = ses.copy() + target = ses.copy() + random_probs = np.random.rand(len(ses)) + for j in range(len(ses)): + if random_probs[j] < self.mask_prob: + random_probs[j] /= self.mask_prob + if random_probs[j] < 0.8: + masked_session[j] = 1 + elif random_probs[j] < 0.9: + masked_session[j] = np.random.randint(low=2, high=self.item_id_map.size, size=1)[0] + else: + target[j] = 0 + return masked_session, target + + def _collate_fn_train( + self, + batch: List[Tuple[List[int], List[float]]], + ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]: + """TODO""" + batch_size = len(batch) + x = np.zeros((batch_size, self.session_max_len)) + y = np.zeros((batch_size, self.session_max_len)) + yw = np.zeros((batch_size, self.session_max_len)) + for i, (ses, ses_weights) in enumerate(batch): + masked_session, target = self._mask_session(ses) + x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len] + y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len] + yw[i, -len(ses) :] = ses_weights # ses_weights: [session_len] -> yw[i]: [session_max_len] + + return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) + + def get_dataloader_train(self, processed_dataset: 
Dataset) -> DataLoader: + """TODO""" + sequence_dataset = SequenceDataset.from_interactions(processed_dataset.interactions.df) + train_dataloader = DataLoader( + sequence_dataset, + collate_fn=self._collate_fn_train, + batch_size=self.batch_size, + num_workers=self.dataloader_num_workers, + shuffle=self.shuffle_train, + ) + return train_dataloader + + def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset: + """ + Filter out interactions and adapt id maps. + Final dataset will consist only of model known items during fit and only of required + (and supported) target users for recommendations. + All users beyond target users for recommendations are dropped. + All target users that do not have at least one known item in interactions are dropped. + Final user_id_map is an enumerated list of supported (filtered) target users + Final item_id_map is model item_id_map constructed during training + """ + # Filter interactions in dataset internal ids + interactions = dataset.interactions.df + users_internal = dataset.user_id_map.convert_to_internal(users, strict=False) + items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False) + interactions = interactions[interactions[Columns.User].isin(users_internal)] # todo: fast_isin + interactions = interactions[interactions[Columns.Item].isin(items_internal)] + + # Convert to external ids + interactions[Columns.Item] = dataset.item_id_map.convert_to_external(interactions[Columns.Item]) + interactions[Columns.User] = dataset.user_id_map.convert_to_external(interactions[Columns.User]) + + # Prepare new user id mapping + rec_user_id_map = IdMap.from_values(interactions[Columns.User]) + + # Construct dataset + # TODO: For now features are dropped because model doesn't support them + n_filtered = len(users) - rec_user_id_map.size + if n_filtered > 0: + explanation = f"""{n_filtered} target users were considered cold because of missing known items""" + 
warnings.warn(explanation) + filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map) + filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions) + return filtered_dataset + + def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: + """ + Filter out interactions and adapt id maps. + Final dataset will consist only of model known items during fit. + Final user_id_map is the same as dataset original + Final item_id_map is model item_id_map constructed during training + """ + # TODO: optimize by filtering in internal ids + interactions = dataset.get_raw_interactions() + interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] + filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map) + filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) + return filtered_dataset + + def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> torch.LongTensor: + """Right truncation, left padding to session_max_len""" + x = np.zeros((len(batch), self.session_max_len)) + for i, (ses, _) in enumerate(batch): + session = ses.copy() + session = session + [1] + x[i, -len(ses) :] = ses[-self.session_max_len :] + return torch.LongTensor(x) + + def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: + """TODO""" + sequence_dataset = SequenceDataset.from_interactions(dataset.interactions.df) + recommend_dataloader = DataLoader( + sequence_dataset, + batch_size=self.batch_size, + collate_fn=self._collate_fn_recommend, + num_workers=self.dataloader_num_workers, + shuffle=False, + ) + return recommend_dataloader + + +class PointWiseFeedForward(nn.Module): + """TODO""" + + def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float) -> None: + """TODO""" + super().__init__() + self.ff_linear1 = nn.Linear(n_factors, n_factors_ff) + self.ff_gelu = torch.nn.GELU() + 
self.ff_dropout = torch.nn.Dropout(dropout_rate) + self.ff_linear2 = nn.Linear(n_factors_ff, n_factors) + + def forward(self, seqs: torch.Tensor) -> torch.Tensor: + """TODO""" + output = self.ff_gelu(self.ff_linear1(seqs)) + fin = self.ff_linear2(self.ff_dropout(output)) + return fin + + +class BERT4RecTransformerLayers(TransformerLayersBase): + """TODO""" + + def __init__( + self, + n_blocks: int, + n_factors: int, + n_heads: int, + dropout_rate: float, + ): + super().__init__() + self.n_blocks = n_blocks + self.multi_head_attn = nn.ModuleList( + [nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True) for _ in range(n_blocks)] + ) + self.layer_norm1 = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) + self.dropout1 = nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(n_blocks)]) + self.layer_norm2 = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) + self.feed_forward = nn.ModuleList( + [PointWiseFeedForward(n_factors, n_factors * 4, dropout_rate) for _ in range(n_blocks)] + ) + self.dropout2 = nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(n_blocks)]) + # self.dropout3 = nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(n_blocks)]) + + def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + """TODO""" + for i in range(self.n_blocks): + mha_input = self.layer_norm1[i](seqs) + # mha_output, _ = + # self.multi_head_attn[i](mha_input, mha_input, mha_input, attn_mask=attn_mask, need_weights=False) + mha_output, _ = self.multi_head_attn[i](mha_input, mha_input, mha_input, need_weights=False) + seqs = seqs + self.dropout1[i](mha_output) + ff_input = self.layer_norm2[i](seqs) + ff_output = self.feed_forward[i](ff_input) + seqs = seqs + self.dropout2[i](ff_output) + seqs = seqs * timeline_mask + # seqs = self.dropout3[i](seqs) + + return seqs + + +# #### -------------- Session Encoder -------------- #### # + + +class 
TransformerBasedSessionEncoder(torch.nn.Module): + """TODO""" + + def __init__( + self, + n_blocks: int, + n_factors: int, + n_heads: int, + session_max_len: int, + dropout_rate: float, + use_pos_emb: bool = True, + use_causal_attn: bool = True, + transformer_layers_type: tp.Type[TransformerLayersBase] = BERT4RecTransformerLayers, + item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), + pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, + ) -> None: + super().__init__() + + self.item_model: ItemNetConstructor + self.pos_encoding = pos_encoding_type(use_pos_emb, session_max_len, n_factors) + self.emb_dropout = torch.nn.Dropout(dropout_rate) + self.transformer_layers = transformer_layers_type( + n_blocks=n_blocks, + n_factors=n_factors, + n_heads=n_heads, + dropout_rate=dropout_rate, + ) + self.use_causal_attn = use_causal_attn + self.n_factors = n_factors + self.dropout_rate = dropout_rate + self.n_heads = n_heads + + self.item_net_block_types = item_net_block_types + + def construct_item_net(self, dataset: Dataset) -> None: + """TODO""" + self.item_model = ItemNetConstructor.from_dataset( + dataset, self.n_factors, self.dropout_rate, self.item_net_block_types + ) + + def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + """ + Pass user history through item embeddings and transformer blocks. + + Returns + ------- + torch.Tensor. 
[batch_size, session_max_len, n_factors] + + """ + session_max_len = sessions.shape[1] + attn_mask = None + if self.use_causal_attn: + attn_mask = ~torch.tril( + torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device) + ) + timeline_mask = sessions != 0 + attn_mask = ~timeline_mask.unsqueeze(1).repeat(self.n_heads, timeline_mask.squeeze(-1).shape[1], 1) + timeline_mask = timeline_mask.unsqueeze(-1) + seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors] + seqs = self.pos_encoding(seqs, timeline_mask) + seqs = self.emb_dropout(seqs) + seqs = self.transformer_layers(seqs, timeline_mask, attn_mask) + return seqs + + def forward( + self, + sessions: torch.Tensor, # [batch_size, session_max_len] + ) -> torch.Tensor: + """TODO""" + item_embs = self.item_model.get_all_embeddings() # [n_items + 2, n_factors] + session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_max_len, n_factors] + logits = session_embs @ item_embs.T # [batch_size, session_max_len, n_items + 2] + return logits + + +class SessionEncoderLightningModule(SessionEncoderLightningModuleBase): + """TODO""" + + def on_train_start(self) -> None: + """TODO""" + self._truncated_normal_init() + + def configure_optimizers(self) -> torch.optim.Adam: + """TODO""" + optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas) + return optimizer + + def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: + """TODO""" + x, y, w = batch + logits = self.forward(x) # [batch_size, session_max_len, n_items + 2] + if self.loss == "softmax": + # We are using CrossEntropyLoss with a multi-dimensional case + + # Logits must be passed in form of [batch_size, n_items + 2, session_max_len], + # where n_items + 2 is number of classes + + # Target label indexes must be passed in a form of [batch_size, session_max_len] + # (`0` index for "PAD" ix excluded from loss) + + # Loss output will have a shape of 
[batch_size, session_max_len] + # and will have zeros for every `0` target label + + loss = torch.nn.functional.cross_entropy( + logits.transpose(1, 2), y, ignore_index=0, reduction="none" + ) # [batch_size, session_max_len] + loss = loss * w + n = (loss > 0).to(loss.dtype) + loss = torch.sum(loss) / torch.sum(n) + return loss + raise ValueError(f"loss {loss} is not supported") + + def _truncated_normal_init(self) -> None: + """TODO""" + for _, param in self.torch_model.named_parameters(): + try: + torch.nn.init.trunc_normal_(param.data) + except ValueError: + pass + + +class BERT4RecModel(ModelBase): + """TODO""" + + def __init__( # pylint: disable=too-many-arguments, too-many-locals + self, + n_blocks: int = 1, + n_heads: int = 1, + n_factors: int = 128, + use_pos_emb: bool = True, + dropout_rate: float = 0.2, + epochs: int = 3, + verbose: int = 0, + deterministic: bool = False, + cpu_n_threads: int = 0, + session_max_len: int = 32, + batch_size: int = 128, + loss: str = "softmax", + lr: float = 0.01, + dataloader_num_workers: int = 0, + train_min_user_interaction: int = 2, + mask_prob: float = 0.15, + trainer: tp.Optional[Trainer] = None, + item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), + pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, + transformer_layers_type: tp.Type[TransformerLayersBase] = BERT4RecTransformerLayers, + data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase] = BERT4RecDataPreparator, + lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, + device: str = "cpu", # TODO: remove + ): + super().__init__(verbose=verbose) + self.data_preparator = data_preparator_type( + session_max_len=session_max_len, + batch_size=batch_size, + dataloader_num_workers=dataloader_num_workers, + train_min_user_interactions=train_min_user_interaction, + mask_prob=mask_prob, + ) + self.torch_model: 
TransformerBasedSessionEncoder + self._torch_model = TransformerBasedSessionEncoder( + n_blocks=n_blocks, + n_factors=n_factors, + n_heads=n_heads, + session_max_len=session_max_len, + dropout_rate=dropout_rate, + use_pos_emb=use_pos_emb, + use_causal_attn=False, + transformer_layers_type=transformer_layers_type, + item_net_block_types=item_net_block_types, + pos_encoding_type=pos_encoding_type, + ) + self.lightning_module_type = lightning_module_type + self.trainer: Trainer + if trainer is None: + self._trainer = Trainer( + max_epochs=epochs, + min_epochs=epochs, + deterministic=deterministic, + enable_progress_bar=verbose > 0, + enable_model_summary=verbose > 0, + logger=verbose > 0, + ) + else: + self._trainer = trainer + self.lr = lr + self.loss = loss + self.n_threads = cpu_n_threads + self.u2i_dist = Distance.DOT + self.i2i_dist = Distance.COSINE + self.device = torch.device(device) # TODO: remove + + def _fit( + self, + dataset: Dataset, + ) -> None: + processed_dataset = self.data_preparator.process_dataset_train(dataset) + train_dataloader = self.data_preparator.get_dataloader_train(processed_dataset) + self.torch_model = deepcopy(self._torch_model) # TODO: check that it works + self.torch_model.construct_item_net(processed_dataset) + + lightning_model = self.lightning_module_type(self.torch_model, self.lr, self.loss) + self.trainer = deepcopy(self._trainer) + self.trainer.fit(lightning_model, train_dataloader) + + def _custom_transform_dataset_u2i( + self, dataset: Dataset, users: ExternalIds, on_unsupported_targets: ErrorBehaviour + ) -> Dataset: + return self.data_preparator.transform_dataset_u2i(dataset, users) + + def _custom_transform_dataset_i2i( + self, dataset: Dataset, target_items: ExternalIds, on_unsupported_targets: ErrorBehaviour + ) -> Dataset: + return self.data_preparator.transform_dataset_i2i(dataset) + + def _recommend_u2i( + self, + user_ids: InternalIdsArray, + dataset: Dataset, # [n_rec_users x n_items + 2] + k: int, + filter_viewed: 
bool, + sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], # model_internal + ) -> InternalRecoTriplet: + if sorted_item_ids_to_recommend is None: # TODO: move to _get_sorted_item_ids_to_recommend + sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() # model internal + + self.torch_model = self.torch_model.eval() + self.torch_model.to(self.device) + + # Dataset has already been filtered and adapted to known item_id_map + recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset) + + session_embs = [] + item_embs = self.torch_model.item_model.get_all_embeddings() # [n_items + 2, n_factors] + with torch.no_grad(): + for x_batch in tqdm.tqdm(recommend_dataloader): # TODO: from tqdm.auto import tqdm. Also check `verbose`` + x_batch = x_batch.to(self.device) # [batch_size, session_max_len] + encoded = self.torch_model.encode_sessions(x_batch, item_embs)[:, -1, :] # [batch_size, n_factors] + encoded = encoded.detach().cpu().numpy() + session_embs.append(encoded) + + user_embs = np.concatenate(session_embs, axis=0) + user_embs = user_embs[user_ids] + item_embs_np = item_embs.detach().cpu().numpy() + + ranker = ImplicitRanker( + self.u2i_dist, + user_embs, # [n_rec_users, n_factors] + item_embs_np, # [n_items + 2, n_factors] + ) + if filter_viewed: + user_items = dataset.get_user_item_matrix(include_weights=False) + ui_csr_for_filter = user_items[user_ids] + else: + ui_csr_for_filter = None + + # TODO: When filter_viewed is not needed and user has GPU, torch DOT and topk should be faster + + user_ids_indices, all_reco_ids, all_scores = ranker.rank( + subject_ids=np.arange(user_embs.shape[0]), # n_rec_users + k=k, + filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + 2] + sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal + num_threads=self.n_threads, + ) + all_target_ids = user_ids[user_ids_indices] + + return all_target_ids, all_reco_ids, all_scores # n_rec_users, 
model_internal, scores + + def _recommend_i2i( + self, + target_ids: InternalIdsArray, # model internal + dataset: Dataset, + k: int, + sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], + ) -> InternalRecoTriplet: + if sorted_item_ids_to_recommend is None: + sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() + + self.torch_model = self.torch_model.eval() + item_embs = self.torch_model.item_model.get_all_embeddings().detach().cpu().numpy() # [n_items + 2, n_factors] + + # TODO: i2i reco do not need filtering viewed. And user most of the times has GPU + # Should we use torch dot and topk? Should be faster + + ranker = ImplicitRanker( + self.i2i_dist, + item_embs, # [n_items + 2, n_factors] + item_embs, # [n_items + 2, n_factors] + ) + return ranker.rank( + subject_ids=target_ids, # model internal + k=k, + filter_pairs_csr=None, + sorted_object_whitelist=sorted_item_ids_to_recommend, # model internal + num_threads=0, + ) diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py index 9bc2ba28..b15f6eeb 100644 --- a/rectools/models/sasrec.py +++ b/rectools/models/sasrec.py @@ -471,23 +471,10 @@ def from_interactions( class SessionEncoderDataPreparatorBase: """Base class for data preparator. 
Used only for type hinting.""" - def __init__( - self, - session_max_len: int, - batch_size: int, - dataloader_num_workers: int, - item_extra_tokens: tp.Sequence[tp.Hashable] = (PADDING_VALUE,), - shuffle_train: bool = True, # not shuffling train dataloader hurts performance - train_min_user_interactions: int = 2, - ) -> None: - self.session_max_len = session_max_len - self.batch_size = batch_size - self.dataloader_num_workers = dataloader_num_workers - self.item_extra_tokens = item_extra_tokens - self.shuffle_train = shuffle_train - self.train_min_user_interactions = train_min_user_interactions + def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None: + """TODO""" self.item_id_map: IdMap - # TODO: add SequenceDatasetType for fit and recommend + self.item_extra_tokens: tp.Sequence[tp.Hashable] def get_known_items_sorted_internal_ids(self) -> np.ndarray: """TODO""" @@ -526,6 +513,23 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: class SASRecDataPreparator(SessionEncoderDataPreparatorBase): """TODO""" + def __init__( + self, + session_max_len: int, + batch_size: int, + dataloader_num_workers: int, + item_extra_tokens: tp.Sequence[tp.Hashable] = (PADDING_VALUE,), + shuffle_train: bool = True, # not shuffling train dataloader hurts performance + train_min_user_interactions: int = 2, + ) -> None: + super().__init__ + self.session_max_len = session_max_len + self.batch_size = batch_size + self.dataloader_num_workers = dataloader_num_workers + self.item_extra_tokens = item_extra_tokens + self.shuffle_train = shuffle_train + self.train_min_user_interactions = train_min_user_interactions + def process_dataset_train(self, dataset: Dataset) -> Dataset: """TODO""" interactions = dataset.get_raw_interactions() From a4438ab8a4397e66e6eeb586114e842bbd1e0b79 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Sat, 2 Nov 2024 14:18:29 +0300 Subject: [PATCH 2/8] refactored classes --- rectools/models/bert4rec.py | 444 +++--------------------------------- 
rectools/models/sasrec.py | 289 ++++++++++++----------- 2 files changed, 188 insertions(+), 545 deletions(-) diff --git a/rectools/models/bert4rec.py b/rectools/models/bert4rec.py index 7dd4351a..f593ed0e 100644 --- a/rectools/models/bert4rec.py +++ b/rectools/models/bert4rec.py @@ -1,35 +1,23 @@ import typing as tp -import warnings -from copy import deepcopy from typing import List, Tuple import numpy as np import torch -import tqdm from pytorch_lightning import Trainer -from scipy import sparse from torch import nn -from torch.utils.data import DataLoader -from rectools import Columns, ExternalIds -from rectools.dataset import Dataset, Interactions -from rectools.dataset.features import SparseFeatures -from rectools.dataset.identifiers import IdMap -from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase -from rectools.models.rank import Distance, ImplicitRanker from rectools.models.sasrec import ( CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase, - ItemNetConstructor, LearnableInversePositionalEncoding, PositionalEncodingBase, - SequenceDataset, SessionEncoderDataPreparatorBase, + SessionEncoderLightningModule, SessionEncoderLightningModuleBase, TransformerLayersBase, + TransformerModelBase, ) -from rectools.types import InternalIdsArray PADDING_VALUE = "PAD" MASKING_VALUE = "MASK" @@ -45,64 +33,18 @@ def __init__( dataloader_num_workers: int, train_min_user_interactions: int, mask_prob: float, - item_extra_tokens: tp.Sequence[tp.Hashable] = (PADDING_VALUE, MASKING_VALUE), + item_extra_tokens: tp.Sequence[tp.Hashable], shuffle_train: bool = True, ) -> None: - super().__init__() - self.session_max_len = session_max_len - self.batch_size = batch_size - self.dataloader_num_workers = dataloader_num_workers - self.train_min_user_interactions = train_min_user_interactions - self.item_extra_tokens = item_extra_tokens + super().__init__( + session_max_len=session_max_len, + batch_size=batch_size, + 
dataloader_num_workers=dataloader_num_workers, + train_min_user_interactions=train_min_user_interactions, + item_extra_tokens=item_extra_tokens, + shuffle_train=shuffle_train, + ) self.mask_prob = mask_prob - self.shuffle_train = shuffle_train - # TODO: add SequenceDatasetType for fit and recommend - - def process_dataset_train(self, dataset: Dataset) -> Dataset: - """TODO""" - interactions = dataset.get_raw_interactions() - - # Filter interactions - user_stats = interactions[Columns.User].value_counts() - users = user_stats[user_stats >= self.train_min_user_interactions].index - interactions = interactions[(interactions[Columns.User].isin(users))] - interactions = interactions.sort_values(Columns.Datetime).groupby(Columns.User).tail(self.session_max_len) - - # Construct dataset - # TODO: user features are dropped for now - user_id_map = IdMap.from_values(interactions[Columns.User].values) - item_id_map = IdMap.from_values(self.item_extra_tokens) - item_id_map = item_id_map.add_ids(interactions[Columns.Item]) - - # get item features - item_features = None - if dataset.item_features is not None: - item_features = dataset.item_features - # TODO: remove assumption on SparseFeatures and add Dense Features support - if not isinstance(item_features, SparseFeatures): - raise ValueError("`item_features` in `dataset` must be `SparseFeatures` instance.") - - internal_ids = dataset.item_id_map.convert_to_internal( - item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :] - ) - sorted_item_features = item_features.take(internal_ids) - - dtype = sorted_item_features.values.dtype - n_features = sorted_item_features.values.shape[1] - extra_token_feature_values = sparse.csr_matrix((self.n_item_extra_tokens, n_features), dtype=dtype) - - full_feature_values: sparse.scr_matrix = sparse.vstack( - [extra_token_feature_values, sorted_item_features.values], format="csr" - ) - - item_features = SparseFeatures.from_iterables(values=full_feature_values, 
names=item_features.names) - - interactions = Interactions.from_raw(interactions, user_id_map, item_id_map) - - dataset = Dataset(user_id_map, item_id_map, interactions, item_features=item_features) - - self.item_id_map = dataset.item_id_map - return dataset def _mask_session(self, ses: List[int]) -> Tuple[List[int], List[int]]: masked_session = ses.copy() @@ -136,66 +78,6 @@ def _collate_fn_train( return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) - def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: - """TODO""" - sequence_dataset = SequenceDataset.from_interactions(processed_dataset.interactions.df) - train_dataloader = DataLoader( - sequence_dataset, - collate_fn=self._collate_fn_train, - batch_size=self.batch_size, - num_workers=self.dataloader_num_workers, - shuffle=self.shuffle_train, - ) - return train_dataloader - - def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset: - """ - Filter out interactions and adapt id maps. - Final dataset will consist only of model known items during fit and only of required - (and supported) target users for recommendations. - All users beyond target users for recommendations are dropped. - All target users that do not have at least one known item in interactions are dropped. 
- Final user_id_map is an enumerated list of supported (filtered) target users - Final item_id_map is model item_id_map constructed during training - """ - # Filter interactions in dataset internal ids - interactions = dataset.interactions.df - users_internal = dataset.user_id_map.convert_to_internal(users, strict=False) - items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False) - interactions = interactions[interactions[Columns.User].isin(users_internal)] # todo: fast_isin - interactions = interactions[interactions[Columns.Item].isin(items_internal)] - - # Convert to external ids - interactions[Columns.Item] = dataset.item_id_map.convert_to_external(interactions[Columns.Item]) - interactions[Columns.User] = dataset.user_id_map.convert_to_external(interactions[Columns.User]) - - # Prepare new user id mapping - rec_user_id_map = IdMap.from_values(interactions[Columns.User]) - - # Construct dataset - # TODO: For now features are dropped because model doesn't support them - n_filtered = len(users) - rec_user_id_map.size - if n_filtered > 0: - explanation = f"""{n_filtered} target users were considered cold because of missing known items""" - warnings.warn(explanation) - filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map) - filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions) - return filtered_dataset - - def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: - """ - Filter out interactions and adapt id maps. - Final dataset will consist only of model known items during fit. 
- Final user_id_map is the same as dataset original - Final item_id_map is model item_id_map constructed during training - """ - # TODO: optimize by filtering in internal ids - interactions = dataset.get_raw_interactions() - interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] - filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map) - filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) - return filtered_dataset - def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> torch.LongTensor: """Right truncation, left padding to session_max_len""" x = np.zeros((len(batch), self.session_max_len)) @@ -205,18 +87,6 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> t x[i, -len(ses) :] = ses[-self.session_max_len :] return torch.LongTensor(x) - def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: - """TODO""" - sequence_dataset = SequenceDataset.from_interactions(dataset.interactions.df) - recommend_dataloader = DataLoader( - sequence_dataset, - batch_size=self.batch_size, - collate_fn=self._collate_fn_recommend, - num_workers=self.dataloader_num_workers, - shuffle=False, - ) - return recommend_dataloader - class PointWiseFeedForward(nn.Module): """TODO""" @@ -277,131 +147,7 @@ def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: to return seqs -# #### -------------- Session Encoder -------------- #### # - - -class TransformerBasedSessionEncoder(torch.nn.Module): - """TODO""" - - def __init__( - self, - n_blocks: int, - n_factors: int, - n_heads: int, - session_max_len: int, - dropout_rate: float, - use_pos_emb: bool = True, - use_causal_attn: bool = True, - transformer_layers_type: tp.Type[TransformerLayersBase] = BERT4RecTransformerLayers, - item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), - pos_encoding_type: 
tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, - ) -> None: - super().__init__() - - self.item_model: ItemNetConstructor - self.pos_encoding = pos_encoding_type(use_pos_emb, session_max_len, n_factors) - self.emb_dropout = torch.nn.Dropout(dropout_rate) - self.transformer_layers = transformer_layers_type( - n_blocks=n_blocks, - n_factors=n_factors, - n_heads=n_heads, - dropout_rate=dropout_rate, - ) - self.use_causal_attn = use_causal_attn - self.n_factors = n_factors - self.dropout_rate = dropout_rate - self.n_heads = n_heads - - self.item_net_block_types = item_net_block_types - - def construct_item_net(self, dataset: Dataset) -> None: - """TODO""" - self.item_model = ItemNetConstructor.from_dataset( - dataset, self.n_factors, self.dropout_rate, self.item_net_block_types - ) - - def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: - """ - Pass user history through item embeddings and transformer blocks. - - Returns - ------- - torch.Tensor. 
[batch_size, session_max_len, n_factors] - - """ - session_max_len = sessions.shape[1] - attn_mask = None - if self.use_causal_attn: - attn_mask = ~torch.tril( - torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device) - ) - timeline_mask = sessions != 0 - attn_mask = ~timeline_mask.unsqueeze(1).repeat(self.n_heads, timeline_mask.squeeze(-1).shape[1], 1) - timeline_mask = timeline_mask.unsqueeze(-1) - seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors] - seqs = self.pos_encoding(seqs, timeline_mask) - seqs = self.emb_dropout(seqs) - seqs = self.transformer_layers(seqs, timeline_mask, attn_mask) - return seqs - - def forward( - self, - sessions: torch.Tensor, # [batch_size, session_max_len] - ) -> torch.Tensor: - """TODO""" - item_embs = self.item_model.get_all_embeddings() # [n_items + 2, n_factors] - session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_max_len, n_factors] - logits = session_embs @ item_embs.T # [batch_size, session_max_len, n_items + 2] - return logits - - -class SessionEncoderLightningModule(SessionEncoderLightningModuleBase): - """TODO""" - - def on_train_start(self) -> None: - """TODO""" - self._truncated_normal_init() - - def configure_optimizers(self) -> torch.optim.Adam: - """TODO""" - optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas) - return optimizer - - def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: - """TODO""" - x, y, w = batch - logits = self.forward(x) # [batch_size, session_max_len, n_items + 2] - if self.loss == "softmax": - # We are using CrossEntropyLoss with a multi-dimensional case - - # Logits must be passed in form of [batch_size, n_items + 2, session_max_len], - # where n_items + 2 is number of classes - - # Target label indexes must be passed in a form of [batch_size, session_max_len] - # (`0` index for "PAD" ix excluded from loss) - - # Loss output will have a shape of 
[batch_size, session_max_len] - # and will have zeros for every `0` target label - - loss = torch.nn.functional.cross_entropy( - logits.transpose(1, 2), y, ignore_index=0, reduction="none" - ) # [batch_size, session_max_len] - loss = loss * w - n = (loss > 0).to(loss.dtype) - loss = torch.sum(loss) / torch.sum(n) - return loss - raise ValueError(f"loss {loss} is not supported") - - def _truncated_normal_init(self) -> None: - """TODO""" - for _, param in self.torch_model.named_parameters(): - try: - torch.nn.init.trunc_normal_(param.data) - except ValueError: - pass - - -class BERT4RecModel(ModelBase): +class BERT4RecModel(TransformerModelBase): """TODO""" def __init__( # pylint: disable=too-many-arguments, too-many-locals @@ -410,6 +156,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals n_heads: int = 1, n_factors: int = 128, use_pos_emb: bool = True, + use_causal_attn: bool = False, + use_mlm_attn: bool = True, dropout_rate: float = 0.2, epochs: int = 3, verbose: int = 0, @@ -426,153 +174,37 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, transformer_layers_type: tp.Type[TransformerLayersBase] = BERT4RecTransformerLayers, - data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase] = BERT4RecDataPreparator, + data_preparator_type: tp.Type[BERT4RecDataPreparator] = BERT4RecDataPreparator, lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, device: str = "cpu", # TODO: remove ): - super().__init__(verbose=verbose) - self.data_preparator = data_preparator_type( - session_max_len=session_max_len, - batch_size=batch_size, - dataloader_num_workers=dataloader_num_workers, - train_min_user_interactions=train_min_user_interaction, - mask_prob=mask_prob, - ) - self.torch_model: 
TransformerBasedSessionEncoder - self._torch_model = TransformerBasedSessionEncoder( + super().__init__( + transformer_layers_type=transformer_layers_type, n_blocks=n_blocks, - n_factors=n_factors, n_heads=n_heads, - session_max_len=session_max_len, - dropout_rate=dropout_rate, + n_factors=n_factors, use_pos_emb=use_pos_emb, - use_causal_attn=False, - transformer_layers_type=transformer_layers_type, + use_causal_attn=use_causal_attn, + use_mlm_attn=use_mlm_attn, + dropout_rate=dropout_rate, + epochs=epochs, + verbose=verbose, + deterministic=deterministic, + cpu_n_threads=cpu_n_threads, + loss=loss, + lr=lr, + session_max_len=session_max_len, + trainer=trainer, item_net_block_types=item_net_block_types, pos_encoding_type=pos_encoding_type, + lightning_module_type=lightning_module_type, + device=device, # TODO: remove ) - self.lightning_module_type = lightning_module_type - self.trainer: Trainer - if trainer is None: - self._trainer = Trainer( - max_epochs=epochs, - min_epochs=epochs, - deterministic=deterministic, - enable_progress_bar=verbose > 0, - enable_model_summary=verbose > 0, - logger=verbose > 0, - ) - else: - self._trainer = trainer - self.lr = lr - self.loss = loss - self.n_threads = cpu_n_threads - self.u2i_dist = Distance.DOT - self.i2i_dist = Distance.COSINE - self.device = torch.device(device) # TODO: remove - - def _fit( - self, - dataset: Dataset, - ) -> None: - processed_dataset = self.data_preparator.process_dataset_train(dataset) - train_dataloader = self.data_preparator.get_dataloader_train(processed_dataset) - self.torch_model = deepcopy(self._torch_model) # TODO: check that it works - self.torch_model.construct_item_net(processed_dataset) - - lightning_model = self.lightning_module_type(self.torch_model, self.lr, self.loss) - self.trainer = deepcopy(self._trainer) - self.trainer.fit(lightning_model, train_dataloader) - - def _custom_transform_dataset_u2i( - self, dataset: Dataset, users: ExternalIds, on_unsupported_targets: ErrorBehaviour - ) 
-> Dataset: - return self.data_preparator.transform_dataset_u2i(dataset, users) - - def _custom_transform_dataset_i2i( - self, dataset: Dataset, target_items: ExternalIds, on_unsupported_targets: ErrorBehaviour - ) -> Dataset: - return self.data_preparator.transform_dataset_i2i(dataset) - - def _recommend_u2i( - self, - user_ids: InternalIdsArray, - dataset: Dataset, # [n_rec_users x n_items + 2] - k: int, - filter_viewed: bool, - sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], # model_internal - ) -> InternalRecoTriplet: - if sorted_item_ids_to_recommend is None: # TODO: move to _get_sorted_item_ids_to_recommend - sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() # model internal - - self.torch_model = self.torch_model.eval() - self.torch_model.to(self.device) - - # Dataset has already been filtered and adapted to known item_id_map - recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset) - - session_embs = [] - item_embs = self.torch_model.item_model.get_all_embeddings() # [n_items + 2, n_factors] - with torch.no_grad(): - for x_batch in tqdm.tqdm(recommend_dataloader): # TODO: from tqdm.auto import tqdm. 
Also check `verbose`` - x_batch = x_batch.to(self.device) # [batch_size, session_max_len] - encoded = self.torch_model.encode_sessions(x_batch, item_embs)[:, -1, :] # [batch_size, n_factors] - encoded = encoded.detach().cpu().numpy() - session_embs.append(encoded) - - user_embs = np.concatenate(session_embs, axis=0) - user_embs = user_embs[user_ids] - item_embs_np = item_embs.detach().cpu().numpy() - - ranker = ImplicitRanker( - self.u2i_dist, - user_embs, # [n_rec_users, n_factors] - item_embs_np, # [n_items + 2, n_factors] - ) - if filter_viewed: - user_items = dataset.get_user_item_matrix(include_weights=False) - ui_csr_for_filter = user_items[user_ids] - else: - ui_csr_for_filter = None - - # TODO: When filter_viewed is not needed and user has GPU, torch DOT and topk should be faster - - user_ids_indices, all_reco_ids, all_scores = ranker.rank( - subject_ids=np.arange(user_embs.shape[0]), # n_rec_users - k=k, - filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + 2] - sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal - num_threads=self.n_threads, - ) - all_target_ids = user_ids[user_ids_indices] - - return all_target_ids, all_reco_ids, all_scores # n_rec_users, model_internal, scores - - def _recommend_i2i( - self, - target_ids: InternalIdsArray, # model internal - dataset: Dataset, - k: int, - sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], - ) -> InternalRecoTriplet: - if sorted_item_ids_to_recommend is None: - sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() - - self.torch_model = self.torch_model.eval() - item_embs = self.torch_model.item_model.get_all_embeddings().detach().cpu().numpy() # [n_items + 2, n_factors] - - # TODO: i2i reco do not need filtering viewed. And user most of the times has GPU - # Should we use torch dot and topk? 
Should be faster - - ranker = ImplicitRanker( - self.i2i_dist, - item_embs, # [n_items + 2, n_factors] - item_embs, # [n_items + 2, n_factors] - ) - return ranker.rank( - subject_ids=target_ids, # model internal - k=k, - filter_pairs_csr=None, - sorted_object_whitelist=sorted_item_ids_to_recommend, # model internal - num_threads=0, + self.data_preparator = data_preparator_type( + session_max_len=session_max_len, + batch_size=batch_size, + dataloader_num_workers=dataloader_num_workers, + train_min_user_interactions=train_min_user_interaction, + item_extra_tokens=(PADDING_VALUE, MASKING_VALUE), + mask_prob=mask_prob, ) diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py index b15f6eeb..c56bf60e 100644 --- a/rectools/models/sasrec.py +++ b/rectools/models/sasrec.py @@ -287,49 +287,6 @@ def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: to return seqs -class PreLNTransformerLayers(TransformerLayersBase): - """ - Based on https://arxiv.org/pdf/2002.04745 - On Kion open dataset didn't change metrics, even got a bit worse - But let's keep it for now - """ - - def __init__( - self, - n_blocks: int, - n_factors: int, - n_heads: int, - dropout_rate: float, - ): - super().__init__() - self.n_blocks = n_blocks - self.multi_head_attn = nn.ModuleList( - [torch.nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True) for _ in range(n_blocks)] - ) - self.mha_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) - self.mha_dropout = nn.Dropout(dropout_rate) - self.ff_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) - self.feed_forward = nn.ModuleList( - [PointWiseFeedForward(n_factors, n_factors, dropout_rate) for _ in range(n_blocks)] - ) - - def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: - """TODO""" - for i in range(self.n_blocks): - mha_input = self.mha_layer_norm[i](seqs) - mha_output, _ = 
self.multi_head_attn[i]( - mha_input, mha_input, mha_input, attn_mask=attn_mask, need_weights=False - ) - mha_output = self.mha_dropout(mha_output) - seqs = seqs + mha_output - ff_input = self.ff_layer_norm[i](seqs) - ff_output = self.feed_forward[i](ff_input) - seqs = seqs + ff_output - seqs *= timeline_mask - - return seqs - - class LearnableInversePositionalEncoding(PositionalEncodingBase): """TODO""" @@ -371,6 +328,7 @@ def __init__( dropout_rate: float, use_pos_emb: bool = True, use_causal_attn: bool = True, + use_mlm_attn: bool = False, transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, @@ -387,8 +345,10 @@ def __init__( dropout_rate=dropout_rate, ) self.use_causal_attn = use_causal_attn + self.use_mlm_attn = use_mlm_attn self.n_factors = n_factors self.dropout_rate = dropout_rate + self.n_heads = n_heads self.item_net_block_types = item_net_block_types @@ -413,6 +373,10 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to attn_mask = ~torch.tril( torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device) ) + if self.use_mlm_attn: + timeline_mask = sessions != 0 + attn_mask = ~timeline_mask.unsqueeze(1).repeat(self.n_heads, timeline_mask.squeeze(-1).shape[1], 1) + timeline_mask = timeline_mask.unsqueeze(-1) timeline_mask = (sessions != 0).unsqueeze(-1) # [batch_size, session_max_len, 1] seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors] seqs = self.pos_encoding(seqs, timeline_mask) @@ -471,10 +435,23 @@ def from_interactions( class SessionEncoderDataPreparatorBase: """Base class for data preparator. 
Used only for type hinting.""" - def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None: + def __init__( + self, + session_max_len: int, + batch_size: int, + dataloader_num_workers: int, + item_extra_tokens: tp.Sequence[tp.Hashable], + train_min_user_interactions: int, + shuffle_train: bool = True, + ) -> None: """TODO""" self.item_id_map: IdMap - self.item_extra_tokens: tp.Sequence[tp.Hashable] + self.session_max_len = session_max_len + self.batch_size = batch_size + self.dataloader_num_workers = dataloader_num_workers + self.train_min_user_interactions = train_min_user_interactions + self.item_extra_tokens = item_extra_tokens + self.shuffle_train = shuffle_train def get_known_items_sorted_internal_ids(self) -> np.ndarray: """TODO""" @@ -489,47 +466,6 @@ def n_item_extra_tokens(self) -> int: """TODO""" return len(self.item_extra_tokens) - def process_dataset_train(self, dataset: Dataset) -> Dataset: - """TODO""" - raise NotImplementedError() - - def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: - """TODO""" - raise NotImplementedError() - - def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: - """TODO""" - raise NotImplementedError() - - def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset: - """TODO""" - raise NotImplementedError() - - def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: - """TODO""" - raise NotImplementedError() - - -class SASRecDataPreparator(SessionEncoderDataPreparatorBase): - """TODO""" - - def __init__( - self, - session_max_len: int, - batch_size: int, - dataloader_num_workers: int, - item_extra_tokens: tp.Sequence[tp.Hashable] = (PADDING_VALUE,), - shuffle_train: bool = True, # not shuffling train dataloader hurts performance - train_min_user_interactions: int = 2, - ) -> None: - super().__init__ - self.session_max_len = session_max_len - self.batch_size = batch_size - self.dataloader_num_workers = dataloader_num_workers - self.item_extra_tokens = 
item_extra_tokens - self.shuffle_train = shuffle_train - self.train_min_user_interactions = train_min_user_interactions - def process_dataset_train(self, dataset: Dataset) -> Dataset: """TODO""" interactions = dataset.get_raw_interactions() @@ -576,25 +512,6 @@ def process_dataset_train(self, dataset: Dataset) -> Dataset: self.item_id_map = dataset.item_id_map return dataset - def _collate_fn_train( - self, - batch: List[Tuple[List[int], List[float]]], - ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]: - """ - Truncate each session from right to keep (session_max_len+1) last items. - Do left padding until (session_max_len+1) is reached. - Split to `x`, `y`, and `yw`. - """ - batch_size = len(batch) - x = np.zeros((batch_size, self.session_max_len)) - y = np.zeros((batch_size, self.session_max_len)) - yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): - x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len] - y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len] - yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len] - return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) - def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: """TODO""" sequence_dataset = SequenceDataset.from_interactions(processed_dataset.interactions.df) @@ -607,6 +524,18 @@ def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: ) return train_dataloader + def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: + """TODO""" + sequence_dataset = SequenceDataset.from_interactions(dataset.interactions.df) + recommend_dataloader = DataLoader( + sequence_dataset, + batch_size=self.batch_size, + collate_fn=self._collate_fn_recommend, + num_workers=self.dataloader_num_workers, + shuffle=False, + ) + return recommend_dataloader + def transform_dataset_u2i(self, dataset: Dataset, 
users: ExternalIds) -> Dataset: """ Filter out interactions and adapt id maps. @@ -649,13 +578,49 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: Final item_id_map is model item_id_map constructed during training """ # TODO: optimize by filtering in internal ids - # TODO: For now features are dropped because model doesn't support them interactions = dataset.get_raw_interactions() interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map) filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset + def _collate_fn_train( + self, + batch: List[Tuple[List[int], List[float]]], + ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]: + """TODO""" + raise NotImplementedError() + + def _collate_fn_recommend( + self, + batch: List[Tuple[List[int], List[float]]], + ) -> torch.LongTensor: + """TODO""" + raise NotImplementedError() + + +class SASRecDataPreparator(SessionEncoderDataPreparatorBase): + """TODO""" + + def _collate_fn_train( + self, + batch: List[Tuple[List[int], List[float]]], + ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]: + """ + Truncate each session from right to keep (session_max_len+1) last items. + Do left padding until (session_max_len+1) is reached. + Split to `x`, `y`, and `yw`. 
+ """ + batch_size = len(batch) + x = np.zeros((batch_size, self.session_max_len)) + y = np.zeros((batch_size, self.session_max_len)) + yw = np.zeros((batch_size, self.session_max_len)) + for i, (ses, ses_weights) in enumerate(batch): + x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len] + y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len] + yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len] + return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) + def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> torch.LongTensor: """Right truncation, left padding to session_max_len""" x = np.zeros((len(batch), self.session_max_len)) @@ -663,18 +628,6 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> t x[i, -len(ses) :] = ses[-self.session_max_len :] return torch.LongTensor(x) - def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: - """TODO""" - sequence_dataset = SequenceDataset.from_interactions(dataset.interactions.df) - recommend_dataloader = DataLoader( - sequence_dataset, - batch_size=self.batch_size, - collate_fn=self._collate_fn_recommend, - num_workers=self.dataloader_num_workers, - shuffle=False, - ) - return recommend_dataloader - # #### -------------- Lightning Model -------------- #### # @@ -753,22 +706,20 @@ def _xavier_normal_init(self) -> None: pass -# #### -------------- SASRec Model -------------- #### # - - -class SASRecModel(ModelBase): +class TransformerModelBase(ModelBase): """TODO""" def __init__( # pylint: disable=too-many-arguments self, + transformer_layers_type: tp.Type[TransformerLayersBase], n_blocks: int = 1, n_heads: int = 1, n_factors: int = 128, use_pos_emb: bool = True, + use_causal_attn: bool = True, + use_mlm_attn: bool = False, dropout_rate: float = 0.2, session_max_len: int = 32, - dataloader_num_workers: int = 0, - batch_size: int = 128, loss: str 
= "softmax", lr: float = 0.01, epochs: int = 3, @@ -779,11 +730,9 @@ def __init__( # pylint: disable=too-many-arguments trainer: tp.Optional[Trainer] = None, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, - transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, # SASRec authors net - data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase] = SASRecDataPreparator, lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, - ): - super().__init__(verbose=verbose) + ) -> None: + super().__init__(verbose) self.device = torch.device(device) self.n_threads = cpu_n_threads self.torch_model: TransformerBasedSessionEncoder @@ -794,7 +743,8 @@ def __init__( # pylint: disable=too-many-arguments session_max_len=session_max_len, dropout_rate=dropout_rate, use_pos_emb=use_pos_emb, - use_causal_attn=True, + use_causal_attn=use_causal_attn, + use_mlm_attn=use_mlm_attn, transformer_layers_type=transformer_layers_type, item_net_block_types=item_net_block_types, pos_encoding_type=pos_encoding_type, @@ -812,7 +762,7 @@ def __init__( # pylint: disable=too-many-arguments ) else: self._trainer = trainer - self.data_preparator = data_preparator_type(session_max_len, batch_size, dataloader_num_workers) + self.data_preparator: SessionEncoderDataPreparatorBase self.u2i_dist = Distance.DOT self.i2i_dist = Distance.COSINE self.lr = lr @@ -824,7 +774,6 @@ def _fit( ) -> None: processed_dataset = self.data_preparator.process_dataset_train(dataset) train_dataloader = self.data_preparator.get_dataloader_train(processed_dataset) - self.torch_model = deepcopy(self._torch_model) # TODO: check that it works self.torch_model.construct_item_net(processed_dataset) @@ -845,7 +794,7 @@ def _custom_transform_dataset_i2i( def _recommend_u2i( self, user_ids: InternalIdsArray, - dataset: Dataset, 
# [n_rec_users x n_items + 1] + dataset: Dataset, # [n_rec_users x n_items + 2] k: int, filter_viewed: bool, sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], # model_internal @@ -860,7 +809,7 @@ def _recommend_u2i( recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset) session_embs = [] - item_embs = self.torch_model.item_model.get_all_embeddings() # [n_items + 1, n_factors] + item_embs = self.torch_model.item_model.get_all_embeddings() # [n_items + 2, n_factors] with torch.no_grad(): for x_batch in tqdm.tqdm(recommend_dataloader): # TODO: from tqdm.auto import tqdm. Also check `verbose`` x_batch = x_batch.to(self.device) # [batch_size, session_max_len] @@ -875,7 +824,7 @@ def _recommend_u2i( ranker = ImplicitRanker( self.u2i_dist, user_embs, # [n_rec_users, n_factors] - item_embs_np, # [n_items + 1, n_factors] + item_embs_np, # [n_items + 2, n_factors] ) if filter_viewed: user_items = dataset.get_user_item_matrix(include_weights=False) @@ -888,7 +837,7 @@ def _recommend_u2i( user_ids_indices, all_reco_ids, all_scores = ranker.rank( subject_ids=np.arange(user_embs.shape[0]), # n_rec_users k=k, - filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + 1] + filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + 2] sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal num_threads=self.n_threads, ) @@ -907,15 +856,15 @@ def _recommend_i2i( sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() self.torch_model = self.torch_model.eval() - item_embs = self.torch_model.item_model.get_all_embeddings().detach().cpu().numpy() # [n_items + 1, n_factors] + item_embs = self.torch_model.item_model.get_all_embeddings().detach().cpu().numpy() # [n_items + 2, n_factors] # TODO: i2i reco do not need filtering viewed. And user most of the times has GPU # Should we use torch dot and topk? 
Should be faster ranker = ImplicitRanker( self.i2i_dist, - item_embs, # [n_items + 1, n_factors] - item_embs, # [n_items + 1, n_factors] + item_embs, # [n_items + 2, n_factors] + item_embs, # [n_items + 2, n_factors] ) return ranker.rank( subject_ids=target_ids, # model internal @@ -926,6 +875,68 @@ def _recommend_i2i( ) @property - def lightning_model(self) -> SessionEncoderLightningModule: + def lightning_model(self) -> LightningModule: """TODO""" return self.trainer.lightning_module + + +# #### -------------- SASRec Model -------------- #### # + + +class SASRecModel(TransformerModelBase): + """TODO""" + + # pylint: disable=too-many-locals + + def __init__( # pylint: disable=too-many-arguments + self, + n_blocks: int = 1, + n_heads: int = 1, + n_factors: int = 128, + use_pos_emb: bool = True, + use_causal_attn: bool = True, + use_mlm_attn: bool = False, + dropout_rate: float = 0.2, + session_max_len: int = 32, + dataloader_num_workers: int = 0, + batch_size: int = 128, + loss: str = "softmax", + lr: float = 0.01, + epochs: int = 3, + verbose: int = 0, + deterministic: bool = False, + device: str = "cuda:1", + cpu_n_threads: int = 0, + train_min_user_interaction: int = 2, + trainer: tp.Optional[Trainer] = None, + item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), + pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, + transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, # SASRec authors net + data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase] = SASRecDataPreparator, + lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, + ): + super().__init__( + transformer_layers_type, # SASRec authors net + n_blocks, + n_heads, + n_factors, + use_pos_emb, + use_causal_attn, + use_mlm_attn, + dropout_rate, + session_max_len, + loss, + lr, + epochs, + verbose, + deterministic, + device, + cpu_n_threads, + 
trainer, + item_net_block_types, + pos_encoding_type, + lightning_module_type, + ) + self.data_preparator = data_preparator_type( + session_max_len, batch_size, dataloader_num_workers, (PADDING_VALUE,), train_min_user_interaction + ) From fc1b3cda5af462d50750e1ecf7949ffef77beba8 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Sat, 2 Nov 2024 15:38:12 +0300 Subject: [PATCH 3/8] fixed merge --- rectools/models/sasrec.py | 8 ++------ tests/models/test_sasrec.py | 8 +++++++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py index fe0371c0..c5555869 100644 --- a/rectools/models/sasrec.py +++ b/rectools/models/sasrec.py @@ -1111,8 +1111,6 @@ def _recommend_i2i( if sorted_item_ids_to_recommend is None: sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() - self.torch_model = self.torch_model.eval() - item_embs = self.lightning_model.item_embs.detach().cpu().numpy() # TODO: i2i reco do not need filtering viewed. And user most of the times has GPU # Should we use torch dot and topk? 
Should be faster @@ -1131,9 +1129,9 @@ def _recommend_i2i( ) @property - def lightning_model(self) -> LightningModule: + def torch_model(self) -> TransformerBasedSessionEncoder: """TODO""" - return self.trainer.lightning_module + return self.lightning_model.torch_model # #### -------------- SASRec Model -------------- #### # @@ -1161,7 +1159,6 @@ def __init__( # pylint: disable=too-many-arguments epochs: int = 3, verbose: int = 0, deterministic: bool = False, - device: str = "cuda:1", cpu_n_threads: int = 0, train_min_user_interaction: int = 2, trainer: tp.Optional[Trainer] = None, @@ -1186,7 +1183,6 @@ def __init__( # pylint: disable=too-many-arguments epochs, verbose, deterministic, - device, cpu_n_threads, trainer, item_net_block_types, diff --git a/tests/models/test_sasrec.py b/tests/models/test_sasrec.py index a7af7644..f67e5f8a 100644 --- a/tests/models/test_sasrec.py +++ b/tests/models/test_sasrec.py @@ -433,7 +433,13 @@ def dataset(self) -> Dataset: @pytest.fixture def data_preparator(self) -> SASRecDataPreparator: - return SASRecDataPreparator(session_max_len=3, batch_size=4, dataloader_num_workers=0) + return SASRecDataPreparator( + session_max_len=3, + batch_size=4, + dataloader_num_workers=0, + item_extra_tokens=("PAD",), + train_min_user_interactions=2, + ) @pytest.mark.parametrize( "expected_user_id_map, expected_item_id_map, expected_interactions", From f8c230095e32836944c9b2c73a3036ad04c80a49 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Fri, 8 Nov 2024 13:38:33 +0300 Subject: [PATCH 4/8] changed timeline mask --- rectools/models/bert4rec.py | 36 ++++++++++++++++++--------------- rectools/models/sasrec.py | 40 +++++++++++++++++++++---------------- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/rectools/models/bert4rec.py b/rectools/models/bert4rec.py index 5cb630ee..1689a3f3 100644 --- a/rectools/models/bert4rec.py +++ b/rectools/models/bert4rec.py @@ -72,9 +72,9 @@ def _collate_fn_train( yw = np.zeros((batch_size, 
self.session_max_len)) for i, (ses, ses_weights) in enumerate(batch): masked_session, target = self._mask_session(ses) - x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len] - y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len] - yw[i, -len(ses) :] = ses_weights # ses_weights: [session_len] -> yw[i]: [session_max_len] + x[i, -len(ses) + 1:] = masked_session[:-1] # ses: [session_len] -> x[i]: [session_max_len] + y[i, -len(ses) + 1:] = target[:-1] # ses: [session_len] -> y[i]: [session_max_len] + yw[i, -len(ses) + 1:] = ses_weights[:-1] # ses_weights: [session_len] -> yw[i]: [session_max_len] return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) @@ -83,11 +83,10 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> t x = np.zeros((len(batch), self.session_max_len)) for i, (ses, _) in enumerate(batch): session = ses.copy() - session = session + [1] - x[i, -len(ses) :] = ses[-self.session_max_len :] + session = session[1:] + [1] + x[i, -len(ses) :] = session[-self.session_max_len:] return torch.LongTensor(x) - class PointWiseFeedForward(nn.Module): """TODO""" @@ -104,8 +103,7 @@ def forward(self, seqs: torch.Tensor) -> torch.Tensor: output = self.ff_gelu(self.ff_linear1(seqs)) fin = self.ff_linear2(self.ff_dropout(output)) return fin - - + class BERT4RecTransformerLayers(TransformerLayersBase): """TODO""" @@ -125,24 +123,30 @@ def __init__( self.dropout1 = nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(n_blocks)]) self.layer_norm2 = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) self.feed_forward = nn.ModuleList( - [PointWiseFeedForward(n_factors, n_factors * 4, dropout_rate) for _ in range(n_blocks)] + [PointWiseFeedForward(n_factors, n_factors * 4, dropout_rate, torch.nn.GELU()) for _ in range(n_blocks)] ) self.dropout2 = nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(n_blocks)]) - # self.dropout3 = 
nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(n_blocks)]) + self.dropout3 = nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(n_blocks)]) - def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + def forward( + self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor, key_padding_mask: torch.Tensor + ) -> torch.Tensor: """TODO""" for i in range(self.n_blocks): mha_input = self.layer_norm1[i](seqs) - # mha_output, _ = - # self.multi_head_attn[i](mha_input, mha_input, mha_input, attn_mask=attn_mask, need_weights=False) - mha_output, _ = self.multi_head_attn[i](mha_input, mha_input, mha_input, need_weights=False) + mha_output, _ = self.multi_head_attn[i]( + mha_input, + mha_input, + mha_input, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + ) seqs = seqs + self.dropout1[i](mha_output) ff_input = self.layer_norm2[i](seqs) ff_output = self.feed_forward[i](ff_input) seqs = seqs + self.dropout2[i](ff_output) - seqs = seqs * timeline_mask - # seqs = self.dropout3[i](seqs) + seqs = self.dropout3[i](seqs) return seqs diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py index c5555869..d1229b4b 100644 --- a/rectools/models/sasrec.py +++ b/rectools/models/sasrec.py @@ -52,7 +52,9 @@ def device(self) -> torch.device: class TransformerLayersBase(nn.Module): """TODO: use Protocol""" - def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + def forward( + self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor, key_padding_mask: torch.Tensor + ) -> torch.Tensor: """Forward pass.""" raise NotImplementedError() @@ -60,7 +62,7 @@ def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: to class PositionalEncodingBase(torch.nn.Module): """TODO: use Protocol""" - def forward(self, sessions: torch.Tensor, timeline_mask: torch.Tensor) -> torch.Tensor: 
+ def forward(self, sessions: torch.Tensor) -> torch.Tensor: """Forward pass.""" raise NotImplementedError() @@ -264,7 +266,7 @@ class PointWiseFeedForward(nn.Module): Probability of a hidden unit to be zeroed. """ - def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float) -> None: + def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: torch.nn.Module) -> None: super().__init__() self.ff_linear1 = nn.Linear(n_factors, n_factors_ff) self.ff_dropout1 = torch.nn.Dropout(dropout_rate) @@ -322,11 +324,14 @@ def __init__( self.q_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) self.ff_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) self.feed_forward = nn.ModuleList( - [PointWiseFeedForward(n_factors, n_factors, dropout_rate) for _ in range(n_blocks)] + [PointWiseFeedForward(n_factors, n_factors, dropout_rate, torch.nn.ReLU()) for _ in range(n_blocks)] ) + self.dropout = nn.ModuleList([torch.nn.Dropout(dropout_rate) for _ in range(n_blocks)]) self.last_layernorm = torch.nn.LayerNorm(n_factors, eps=1e-8) - def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + def forward( + self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor, key_padding_mask: torch.Tensor + ) -> torch.Tensor: """ Forward pass through transformer blocks. @@ -344,9 +349,14 @@ def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: to torch.Tensor User sequences passed through transformer layers. """ + # TODO: do we need to fill padding embeds in sessions to all zeros + # or should we use the learnt padding embedding? Should we make it an option for user to decide? 
+ seqs *= timeline_mask # [batch_size, session_max_len, n_factors] for i in range(self.n_blocks): q = self.q_layer_norm[i](seqs) - mha_output, _ = self.multi_head_attn[i](q, seqs, seqs, attn_mask=attn_mask, need_weights=False) + mha_output, _ = self.multi_head_attn[i]( + q, seqs, seqs, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False + ) seqs = q + mha_output ff_input = self.ff_layer_norm[i](seqs) seqs = self.feed_forward[i](ff_input) @@ -376,7 +386,7 @@ def __init__(self, use_pos_emb: bool, session_max_len: int, n_factors: int): super().__init__() self.pos_emb = torch.nn.Embedding(session_max_len, n_factors) if use_pos_emb else None - def forward(self, sessions: torch.Tensor, timeline_mask: torch.Tensor) -> torch.Tensor: + def forward(self, sessions: torch.Tensor) -> torch.Tensor: """ Forward pass to add learnable positional encoding to sessions and mask padding elements. @@ -402,10 +412,6 @@ def forward(self, sessions: torch.Tensor, timeline_mask: torch.Tensor) -> torch. ) # [batch_size, session_max_len] sessions += self.pos_emb(positions.to(sessions.device)) - # TODO: do we need to fill padding embeds in sessions to all zeros - # or should we use the learnt padding embedding? Should we make it an option for user to decide? 
- sessions *= timeline_mask # [batch_size, session_max_len, n_factors] - return sessions @@ -506,19 +512,18 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to """ session_max_len = sessions.shape[1] attn_mask = None + key_padding_mask = None if self.use_causal_attn: attn_mask = ~torch.tril( torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device) ) if self.use_mlm_attn: - timeline_mask = sessions != 0 - attn_mask = ~timeline_mask.unsqueeze(1).repeat(self.n_heads, timeline_mask.squeeze(-1).shape[1], 1) - timeline_mask = timeline_mask.unsqueeze(-1) + key_padding_mask = sessions == 0 timeline_mask = (sessions != 0).unsqueeze(-1) # [batch_size, session_max_len, 1] seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors] - seqs = self.pos_encoding(seqs, timeline_mask) + seqs = self.pos_encoding(seqs) seqs = self.emb_dropout(seqs) - seqs = self.transformer_layers(seqs, timeline_mask, attn_mask) + seqs = self.transformer_layers(seqs, timeline_mask, attn_mask, key_padding_mask) return seqs def forward( @@ -908,6 +913,7 @@ class SessionEncoderLightningModule(SessionEncoderLightningModuleBase): def on_train_start(self) -> None: """Initialize parameters with values from Xavier normal distribution.""" + # TODO: init padding embedding with zeros self._xavier_normal_init() def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: @@ -1169,7 +1175,7 @@ def __init__( # pylint: disable=too-many-arguments lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, ): super().__init__( - transformer_layers_type, # SASRec authors net + transformer_layers_type, n_blocks, n_heads, n_factors, From 8d15265a3120953944bc01edbabdee8eae0c74e6 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Fri, 8 Nov 2024 13:59:24 +0300 Subject: [PATCH 5/8] changed pointwisefeedforward --- examples/bert4rec.ipynb | 419 +++--------------------------------- 
rectools/models/bert4rec.py | 25 +-- rectools/models/sasrec.py | 17 +- tests/models/test_sasrec.py | 8 +- 4 files changed, 42 insertions(+), 427 deletions(-) diff --git a/examples/bert4rec.ipynb b/examples/bert4rec.ipynb index 72acebf9..c0b48061 100644 --- a/examples/bert4rec.ipynb +++ b/examples/bert4rec.ipynb @@ -29,7 +29,7 @@ "\n", "from rectools.dataset import Dataset\n", "from rectools.metrics import MAP, calc_metrics, MeanInvUserFreq, Serendipity\n", - "from rectools.models.bert4rec import CatFeaturesItemNet, IdEmbeddingsItemNet, BERT4RecModel" + "from rectools.models.bert4rec import IdEmbeddingsItemNet, BERT4RecModel" ] }, { @@ -40,7 +40,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -124,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -133,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -157,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -174,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -203,7 +203,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -219,7 +219,7 @@ "32" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -271,82 +271,9 @@ }, { "cell_type": "code", - "execution_count": 21, + 
"execution_count": 18, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 1.3 M \n", - "---------------------------------------------------------------\n", - "1.3 M Trainable params\n", - "0 Non-trainable params\n", - "1.3 M Total params\n", - "5.291 Total estimated model params size (MB)\n", - "/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "221b2c4ffa834b80a47552e1b8ffd21d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? 
[00:00:1\u001b[0m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/rectools/models/base.py:69\u001b[0m, in \u001b[0;36mModelBase.fit\u001b[0;34m(self, dataset, *args, **kwargs)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m: T, dataset: Dataset, \u001b[38;5;241m*\u001b[39margs: tp\u001b[38;5;241m.\u001b[39mAny, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: tp\u001b[38;5;241m.\u001b[39mAny) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 57\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;124;03m Fit model.\u001b[39;00m\n\u001b[1;32m 59\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[38;5;124;03m self\u001b[39;00m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 69\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_fitted \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/rectools/models/bert4rec.py:485\u001b[0m, in \u001b[0;36mBERT4RecModel._fit\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 483\u001b[0m lightning_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module_type(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtorch_model, 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlr, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss)\n\u001b[1;32m 484\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer \u001b[38;5;241m=\u001b[39m deepcopy(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer)\n\u001b[0;32m--> 485\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlightning_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:544\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 542\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 543\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 544\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 545\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 546\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:44\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 44\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 47\u001b[0m _call_teardown_hook(trainer)\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:580\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 574\u001b[0m ckpt_path 
\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 575\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 576\u001b[0m ckpt_path,\n\u001b[1;32m 577\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 578\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 579\u001b[0m )\n\u001b[0;32m--> 580\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 583\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:987\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 982\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 986\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 987\u001b[0m 
results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 989\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 990\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 991\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 992\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1033\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1031\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1032\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1033\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1034\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1035\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one 
optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:157\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 154\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 157\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 160\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/core/module.py:1303\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1264\u001b[0m 
\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1266\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1269\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1270\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1272\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1273\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1301\u001b[0m \n\u001b[1;32m 1302\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1303\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py:152\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 152\u001b[0m step_output \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py:239\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 238\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 239\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File 
\u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/optim/optimizer.py:391\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 386\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 387\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 388\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 389\u001b[0m )\n\u001b[0;32m--> 391\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 392\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 394\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/optim/optimizer.py:76\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 75\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 76\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 78\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/optim/adam.py:148\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 148\u001b[0m loss 
\u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 151\u001b[0m params_with_grad \u001b[38;5;241m=\u001b[39m []\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Optimizer,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n", - "File 
\u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. 
If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py:318\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[1;32m 317\u001b[0m \u001b[38;5;66;03m# manually capture logged metrics\u001b[39;00m\n\u001b[0;32m--> 318\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 319\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_result_cls\u001b[38;5;241m.\u001b[39mfrom_training_step_output(training_step_output, trainer\u001b[38;5;241m.\u001b[39maccumulate_grad_batches)\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:309\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 309\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 312\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py:391\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 391\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/rectools/models/bert4rec.py:386\u001b[0m, in \u001b[0;36mSessionEncoderLightningModule.training_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 373\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(x) \u001b[38;5;66;03m# [batch_size, session_max_len, n_items + 2]\u001b[39;00m\n\u001b[1;32m 374\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msoftmax\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 375\u001b[0m \u001b[38;5;66;03m# We are using CrossEntropyLoss with a multi-dimensional case\u001b[39;00m\n\u001b[1;32m 376\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[38;5;66;03m# Loss output will have a shape of [batch_size, session_max_len]\u001b[39;00m\n\u001b[1;32m 384\u001b[0m \u001b[38;5;66;03m# and will have zeros for every `0` target label\u001b[39;00m\n\u001b[0;32m--> 386\u001b[0m loss \u001b[38;5;241m=\u001b[39m 
\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_entropy\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 387\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogits\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnone\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m 388\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# [batch_size, session_max_len]\u001b[39;00m\n\u001b[1;32m 389\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss \u001b[38;5;241m*\u001b[39m w\n\u001b[1;32m 390\u001b[0m n \u001b[38;5;241m=\u001b[39m (loss \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(loss\u001b[38;5;241m.\u001b[39mdtype)\n", - "File \u001b[0;32m/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/torch/nn/functional.py:3086\u001b[0m, in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[0m\n\u001b[1;32m 3084\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m size_average \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3085\u001b[0m reduction 
\u001b[38;5;241m=\u001b[39m _Reduction\u001b[38;5;241m.\u001b[39mlegacy_get_string(size_average, reduce)\n\u001b[0;32m-> 3086\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_entropy_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_Reduction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_enum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 90.00 MiB. 
GPU " - ] - } - ], + "outputs": [], "source": [ "%%time\n", "model.fit(dataset_no_features)" @@ -354,32 +281,9 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/bert4rec.py:181: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/base.py:406: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "100%|██████████| 740/740 [00:15<00:00, 49.03it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 17min 5s, sys: 46.6 s, total: 17min 52s\n", - "Wall time: 34.8 s\n" - ] - } - ], + "outputs": [], "source": [ "%%time\n", "recos = model.recommend(\n", @@ -393,18 +297,7 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "del interactions\n", - "del model\n", - "torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -415,155 +308,6 @@ "features_results.append(metric_values)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
5755503138650.6415951
5755513152970.2407842
575552344950.0721063
575553378290.0463104
57555437102-0.1499295
...............
22495510975447102-0.3630396
22495610975444151-0.3999357
22495710975447793-0.4406678
22495810975444457-0.6528259
224959109754412995-0.70826310
\n", - "

947050 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " user_id item_id score rank\n", - "575550 3 13865 0.641595 1\n", - "575551 3 15297 0.240784 2\n", - "575552 3 4495 0.072106 3\n", - "575553 3 7829 0.046310 4\n", - "575554 3 7102 -0.149929 5\n", - "... ... ... ... ...\n", - "224955 1097544 7102 -0.363039 6\n", - "224956 1097544 4151 -0.399935 7\n", - "224957 1097544 7793 -0.440667 8\n", - "224958 1097544 4457 -0.652825 9\n", - "224959 1097544 12995 -0.708263 10\n", - "\n", - "[947050 rows x 4 columns]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# major recommend\n", - "recos.sort_values([\"user_id\", \"rank\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With timeline mask in the end of the block, with attention mask" - ] - }, { "cell_type": "code", "execution_count": 17, @@ -572,15 +316,15 @@ { "data": { "text/plain": [ - "[{'MAP@1': 0.03386095770656615,\n", - " 'MAP@5': 0.059875092311754766,\n", - " 'MAP@10': 0.06626564554123239,\n", + "[{'MAP@1': 0.04153170911358253,\n", + " 'MAP@5': 0.07096106984411608,\n", + " 'MAP@10': 0.07874644762957389,\n", " 'MIUF@1': 18.824620072061013,\n", " 'MIUF@5': 18.824620072061013,\n", " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.06777889234992873,\n", - " 'Serendipity@5': 0.04409114066936074,\n", - " 'Serendipity@10': 0.031205145274404236,\n", + " 'Serendipity@1': 0.08494799640990444,\n", + " 'Serendipity@5': 0.05316937762913509,\n", + " 'Serendipity@10': 0.03892074762532452,\n", " 'model': 'bert4rec_ids'}]" ] }, @@ -593,123 +337,12 @@ "features_results" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Without timeline mask, with attention mask" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'MAP@1': 0.031715044770102244,\n", - " 'MAP@5': 0.058107653322795036,\n", - " 'MAP@10': 0.06400667270068171,\n", - " 'MIUF@1': 
18.824620072061013,\n", - " 'MIUF@5': 18.824620072061013,\n", - " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.0633651866321736,\n", - " 'Serendipity@5': 0.04325255649454838,\n", - " 'Serendipity@10': 0.030283831925392017,\n", - " 'model': 'bert4rec_ids'}]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features_results" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With timeline mask in the end of the block, whithout attention mask" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'MAP@1': 0.03521807589657321,\n", - " 'MAP@5': 0.0635501105108066,\n", - " 'MAP@10': 0.07042686574268418,\n", - " 'MIUF@1': 18.824620072061013,\n", - " 'MIUF@5': 18.824620072061013,\n", - " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.07181247030251835,\n", - " 'Serendipity@5': 0.048066313492978796,\n", - " 'Serendipity@10': 0.03423476251676267,\n", - " 'model': 'bert4rec_ids'}]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features_results" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With timeline mask, whithout attention mask, 5 and 7 epochs" - ] - }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'MAP@1': 0.03521807589657321,\n", - " 'MAP@5': 0.0635501105108066,\n", - " 'MAP@10': 0.07042686574268418,\n", - " 'MIUF@1': 18.824620072061013,\n", - " 'MIUF@5': 18.824620072061013,\n", - " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.07181247030251835,\n", - " 'Serendipity@5': 0.048066313492978796,\n", - " 'Serendipity@10': 0.03423476251676267,\n", - " 'model': 'bert4rec_ids'},\n", - " {'MAP@1': 0.03613885396129421,\n", - " 'MAP@5': 0.0626756506459862,\n", - " 'MAP@10': 
0.06914741192133474,\n", - " 'MIUF@1': 18.824620072061013,\n", - " 'MIUF@5': 18.824620072061013,\n", - " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.07439945092656143,\n", - " 'Serendipity@5': 0.04685051880216868,\n", - " 'Serendipity@10': 0.033803948060973046,\n", - " 'model': 'bert4rec_ids'}]" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features_results" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", @@ -735,7 +368,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.10" + "version": "3.8.2" } }, "nbformat": 4, diff --git a/rectools/models/bert4rec.py b/rectools/models/bert4rec.py index 1689a3f3..7193de8a 100644 --- a/rectools/models/bert4rec.py +++ b/rectools/models/bert4rec.py @@ -11,6 +11,7 @@ IdEmbeddingsItemNet, ItemNetBase, LearnableInversePositionalEncoding, + PointWiseFeedForward, PositionalEncodingBase, SessionEncoderDataPreparatorBase, SessionEncoderLightningModule, @@ -72,9 +73,9 @@ def _collate_fn_train( yw = np.zeros((batch_size, self.session_max_len)) for i, (ses, ses_weights) in enumerate(batch): masked_session, target = self._mask_session(ses) - x[i, -len(ses) + 1:] = masked_session[:-1] # ses: [session_len] -> x[i]: [session_max_len] - y[i, -len(ses) + 1:] = target[:-1] # ses: [session_len] -> y[i]: [session_max_len] - yw[i, -len(ses) + 1:] = ses_weights[:-1] # ses_weights: [session_len] -> yw[i]: [session_max_len] + x[i, -len(ses) + 1 :] = masked_session[:-1] # ses: [session_len] -> x[i]: [session_max_len] + y[i, -len(ses) + 1 :] = target[:-1] # ses: [session_len] -> y[i]: [session_max_len] + yw[i, -len(ses) + 1 :] = ses_weights[:-1] # ses_weights: [session_len] -> yw[i]: [session_max_len] return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) @@ -84,26 +85,10 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> t for i, (ses, _) in enumerate(batch): 
session = ses.copy() session = session[1:] + [1] - x[i, -len(ses) :] = session[-self.session_max_len:] + x[i, -len(ses) :] = session[-self.session_max_len :] return torch.LongTensor(x) -class PointWiseFeedForward(nn.Module): - """TODO""" - def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float) -> None: - """TODO""" - super().__init__() - self.ff_linear1 = nn.Linear(n_factors, n_factors_ff) - self.ff_gelu = torch.nn.GELU() - self.ff_dropout = torch.nn.Dropout(dropout_rate) - self.ff_linear2 = nn.Linear(n_factors_ff, n_factors) - - def forward(self, seqs: torch.Tensor) -> torch.Tensor: - """TODO""" - output = self.ff_gelu(self.ff_linear1(seqs)) - fin = self.ff_linear2(self.ff_dropout(output)) - return fin - class BERT4RecTransformerLayers(TransformerLayersBase): """TODO""" diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py index d1229b4b..3e7bb50e 100644 --- a/rectools/models/sasrec.py +++ b/rectools/models/sasrec.py @@ -270,9 +270,8 @@ def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, activ super().__init__() self.ff_linear1 = nn.Linear(n_factors, n_factors_ff) self.ff_dropout1 = torch.nn.Dropout(dropout_rate) - self.ff_relu = torch.nn.ReLU() + self.ff_activation = activation self.ff_linear2 = nn.Linear(n_factors_ff, n_factors) - self.ff_dropout2 = torch.nn.Dropout(dropout_rate) def forward(self, seqs: torch.Tensor) -> torch.Tensor: """ @@ -288,8 +287,8 @@ def forward(self, seqs: torch.Tensor) -> torch.Tensor: torch.Tensor User sequence that passed through all layers. 
""" - output = self.ff_relu(self.ff_dropout1(self.ff_linear1(seqs))) - fin = self.ff_dropout2(self.ff_linear2(output)) + output = self.ff_activation(self.ff_linear1(seqs)) + fin = self.ff_linear2(self.ff_dropout1(output)) return fin @@ -360,6 +359,7 @@ def forward( seqs = q + mha_output ff_input = self.ff_layer_norm[i](seqs) seqs = self.feed_forward[i](ff_input) + seqs = self.dropout[i](seqs) seqs += ff_input seqs *= timeline_mask @@ -632,9 +632,9 @@ def __init__( session_max_len: int, batch_size: int, dataloader_num_workers: int, - item_extra_tokens: tp.Sequence[tp.Hashable], - train_min_user_interactions: int, shuffle_train: bool = True, + item_extra_tokens: tp.Sequence[tp.Hashable] = (PADDING_VALUE,), + train_min_user_interactions: int = 2, ) -> None: """TODO""" self.item_id_map: IdMap @@ -1196,5 +1196,8 @@ def __init__( # pylint: disable=too-many-arguments lightning_module_type, ) self.data_preparator = data_preparator_type( - session_max_len, batch_size, dataloader_num_workers, (PADDING_VALUE,), train_min_user_interaction + session_max_len=session_max_len, + batch_size=batch_size, + dataloader_num_workers=dataloader_num_workers, + train_min_user_interactions=train_min_user_interaction, ) diff --git a/tests/models/test_sasrec.py b/tests/models/test_sasrec.py index f67e5f8a..a7af7644 100644 --- a/tests/models/test_sasrec.py +++ b/tests/models/test_sasrec.py @@ -433,13 +433,7 @@ def dataset(self) -> Dataset: @pytest.fixture def data_preparator(self) -> SASRecDataPreparator: - return SASRecDataPreparator( - session_max_len=3, - batch_size=4, - dataloader_num_workers=0, - item_extra_tokens=("PAD",), - train_min_user_interactions=2, - ) + return SASRecDataPreparator(session_max_len=3, batch_size=4, dataloader_num_workers=0) @pytest.mark.parametrize( "expected_user_id_map, expected_item_id_map, expected_interactions", From e7954428988caa9ca6f3fd4c6c32d577147c9403 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Wed, 13 Nov 2024 19:01:24 +0300 Subject: [PATCH 6/8] 
fixed bert4rec collate_fn --- examples/bert4rec.ipynb | 295 +++++++++++++++++++++++++++++++----- rectools/models/bert4rec.py | 30 ++-- rectools/models/sasrec.py | 45 +++--- 3 files changed, 298 insertions(+), 72 deletions(-) diff --git a/examples/bert4rec.ipynb b/examples/bert4rec.ipynb index c0b48061..f30e17a8 100644 --- a/examples/bert4rec.ipynb +++ b/examples/bert4rec.ipynb @@ -1,15 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.append(\"/data/home/maspirina1/tasks/repo/RecTools/\")" - ] - }, { "cell_type": "code", "execution_count": 2, @@ -40,7 +30,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -75,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -89,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -124,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -133,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -157,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -174,7 +164,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -203,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -219,7 +209,7 @@ "32" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -239,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -271,9 +261,66 @@ }, { "cell_type": "code", - 
"execution_count": 18, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "\n", + " | Name | Type | Params\n", + "---------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 1.3 M \n", + "---------------------------------------------------------------\n", + "1.3 M Trainable params\n", + "0 Non-trainable params\n", + "1.3 M Total params\n", + "5.292 Total estimated model params size (MB)\n", + "/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "031d93173b9b455a8b7e04e4933e13c9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? 
[00:00" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "%%time\n", "model.fit(dataset_no_features)" @@ -281,9 +328,47 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/sasrec.py:786: UserWarning: 91202 target users were considered cold because of missing known items\n", + " warnings.warn(explanation)\n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/base.py:420: UserWarning: \n", + " Model `` doesn't support recommendations for cold users,\n", + " but some of given users are cold: they are not in the `dataset.user_id_map`\n", + " \n", + " warnings.warn(explanation)\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cdbad90214f24244a0505fbe3955c0d4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Predicting: | | 0/? 
[00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerank
07344677930.9713501
17344678290.9338672
27344637840.6187423
37344697280.6087454
473446121920.2983885
...............
94704585716237341.4075016
94704685716241511.2583127
94704785716286361.2272388
94704885716218441.1099769
94704985716244360.99829510
\n", + "

947050 rows × 4 columns

\n", + "" + ], + "text/plain": [ + " user_id item_id score rank\n", + "0 73446 7793 0.971350 1\n", + "1 73446 7829 0.933867 2\n", + "2 73446 3784 0.618742 3\n", + "3 73446 9728 0.608745 4\n", + "4 73446 12192 0.298388 5\n", + "... ... ... ... ...\n", + "947045 857162 3734 1.407501 6\n", + "947046 857162 4151 1.258312 7\n", + "947047 857162 8636 1.227238 8\n", + "947048 857162 1844 1.109976 9\n", + "947049 857162 4436 0.998295 10\n", + "\n", + "[947050 rows x 4 columns]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recos" + ] + }, { "cell_type": "code", "execution_count": 17, @@ -316,15 +542,15 @@ { "data": { "text/plain": [ - "[{'MAP@1': 0.04153170911358253,\n", - " 'MAP@5': 0.07096106984411608,\n", - " 'MAP@10': 0.07874644762957389,\n", + "[{'MAP@1': 0.0457901198911608,\n", + " 'MAP@5': 0.07710723775026486,\n", + " 'MAP@10': 0.08559323634049909,\n", " 'MIUF@1': 18.824620072061013,\n", " 'MIUF@5': 18.824620072061013,\n", " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.08494799640990444,\n", - " 'Serendipity@5': 0.05316937762913509,\n", - " 'Serendipity@10': 0.03892074762532452,\n", + " 'Serendipity@1': 0.09274061559579748,\n", + " 'Serendipity@5': 0.056047439499790956,\n", + " 'Serendipity@10': 0.04129842262611581,\n", " 'model': 'bert4rec_ids'}]" ] }, @@ -337,13 +563,6 @@ "features_results" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, diff --git a/rectools/models/bert4rec.py b/rectools/models/bert4rec.py index 7193de8a..9e7d596c 100644 --- a/rectools/models/bert4rec.py +++ b/rectools/models/bert4rec.py @@ -57,7 +57,7 @@ def _mask_session(self, ses: List[int]) -> Tuple[List[int], List[int]]: if random_probs[j] < 0.8: masked_session[j] = 1 elif random_probs[j] < 0.9: - masked_session[j] = np.random.randint(low=2, high=self.item_id_map.size, size=1)[0] + masked_session[j] = 
np.random.randint(low=self.n_item_extra_tokens, high=self.item_id_map.size) else: target[j] = 0 return masked_session, target @@ -68,24 +68,24 @@ def _collate_fn_train( ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]: """TODO""" batch_size = len(batch) - x = np.zeros((batch_size, self.session_max_len)) - y = np.zeros((batch_size, self.session_max_len)) - yw = np.zeros((batch_size, self.session_max_len)) + x = np.zeros((batch_size, self.session_max_len + 1)) + y = np.zeros((batch_size, self.session_max_len + 1)) + yw = np.zeros((batch_size, self.session_max_len + 1)) for i, (ses, ses_weights) in enumerate(batch): masked_session, target = self._mask_session(ses) - x[i, -len(ses) + 1 :] = masked_session[:-1] # ses: [session_len] -> x[i]: [session_max_len] - y[i, -len(ses) + 1 :] = target[:-1] # ses: [session_len] -> y[i]: [session_max_len] - yw[i, -len(ses) + 1 :] = ses_weights[:-1] # ses_weights: [session_len] -> yw[i]: [session_max_len] + x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len + 1] + y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len + 1] + yw[i, -len(ses) :] = ses_weights # ses_weights: [session_len] -> yw[i]: [session_max_len + 1] return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> torch.LongTensor: """Right truncation, left padding to session_max_len""" - x = np.zeros((len(batch), self.session_max_len)) + x = np.zeros((len(batch), self.session_max_len + 1)) for i, (ses, _) in enumerate(batch): session = ses.copy() - session = session[1:] + [1] - x[i, -len(ses) :] = session[-self.session_max_len :] + session = session + [1] + x[i, -len(ses) - 1 :] = session[-self.session_max_len - 1 :] return torch.LongTensor(x) @@ -132,7 +132,8 @@ def forward( ff_output = self.feed_forward[i](ff_input) seqs = seqs + self.dropout2[i](ff_output) seqs = self.dropout3[i](seqs) - + # TODO: test with 
torch.nn.Linear and cross-entropy loss as in + # https://github.com/jaywonchung/BERT4Rec-VAE-Pytorch/blob/f66f2534ebfd937778c7174b5f9f216efdebe5de/models/bert.py#L11C1-L11C67 return seqs @@ -146,7 +147,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals n_factors: int = 128, use_pos_emb: bool = True, use_causal_attn: bool = False, - use_mlm_attn: bool = True, + use_key_padding_mask: bool = True, dropout_rate: float = 0.2, epochs: int = 3, verbose: int = 0, @@ -168,12 +169,13 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals ): super().__init__( transformer_layers_type=transformer_layers_type, + data_preparator_type=data_preparator_type, n_blocks=n_blocks, n_heads=n_heads, n_factors=n_factors, use_pos_emb=use_pos_emb, use_causal_attn=use_causal_attn, - use_mlm_attn=use_mlm_attn, + use_key_padding_mask=use_key_padding_mask, dropout_rate=dropout_rate, epochs=epochs, verbose=verbose, @@ -181,7 +183,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals cpu_n_threads=cpu_n_threads, loss=loss, lr=lr, - session_max_len=session_max_len, + session_max_len=session_max_len + 1, trainer=trainer, item_net_block_types=item_net_block_types, pos_encoding_type=pos_encoding_type, diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py index 3e7bb50e..d0584c4b 100644 --- a/rectools/models/sasrec.py +++ b/rectools/models/sasrec.py @@ -455,7 +455,7 @@ def __init__( dropout_rate: float, use_pos_emb: bool = True, use_causal_attn: bool = True, - use_mlm_attn: bool = False, + use_key_padding_mask: bool = False, transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, @@ -472,7 +472,7 @@ def __init__( dropout_rate=dropout_rate, ) self.use_causal_attn = use_causal_attn - self.use_mlm_attn = use_mlm_attn + 
self.use_key_padding_mask = use_key_padding_mask self.n_factors = n_factors self.dropout_rate = dropout_rate self.n_heads = n_heads @@ -513,11 +513,12 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to session_max_len = sessions.shape[1] attn_mask = None key_padding_mask = None + # TODO: att_mask and key_padding_mask together result into NaN scores if self.use_causal_attn: attn_mask = ~torch.tril( torch.ones((session_max_len, session_max_len), dtype=torch.bool, device=sessions.device) ) - if self.use_mlm_attn: + if self.use_key_padding_mask: key_padding_mask = sessions == 0 timeline_mask = (sessions != 0).unsqueeze(-1) # [batch_size, session_max_len, 1] seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors] @@ -546,9 +547,9 @@ def forward( torch.Tensor Logits. """ - item_embs = self.item_model.get_all_embeddings() # [n_items + 1, n_factors] + item_embs = self.item_model.get_all_embeddings() # [n_items + n_special_tokens, n_factors] session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_max_len, n_factors] - logits = session_embs @ item_embs.T # [batch_size, session_max_len, n_items + 1] + logits = session_embs @ item_embs.T # [batch_size, session_max_len, n_items + n_special_tokens] return logits @@ -934,12 +935,12 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: Loss. 
""" x, y, w = batch - logits = self.forward(x) # [batch_size, session_max_len, n_items + 1] + logits = self.forward(x) # [batch_size, session_max_len, n_items + n_special_tokens] if self.loss == "softmax": # We are using CrossEntropyLoss with a multi-dimensional case - # Logits must be passed in form of [batch_size, n_items + 1, session_max_len], - # where n_items + 1 is number of classes + # Logits must be passed in form of [batch_size, n_items + n_special_tokens, session_max_len], + # where n_items + n_special_tokens is number of classes # Target label indexes must be passed in a form of [batch_size, session_max_len] # (`0` index for "PAD" ix excluded from loss) @@ -978,17 +979,22 @@ def _xavier_normal_init(self) -> None: class TransformerModelBase(ModelBase): - """TODO""" + """ + Base model for all recommender algorithms that work on transformer architecture (e.g. SASRec, Bert4Rec). + To create a custom transformer model it is necessary to inherit from this class + and write self.data_preparator initialization logic. 
+ """ def __init__( # pylint: disable=too-many-arguments self, transformer_layers_type: tp.Type[TransformerLayersBase], + data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase], n_blocks: int = 1, n_heads: int = 1, n_factors: int = 128, use_pos_emb: bool = True, use_causal_attn: bool = True, - use_mlm_attn: bool = False, + use_key_padding_mask: bool = False, dropout_rate: float = 0.2, session_max_len: int = 32, loss: str = "softmax", @@ -1012,7 +1018,7 @@ def __init__( # pylint: disable=too-many-arguments dropout_rate=dropout_rate, use_pos_emb=use_pos_emb, use_causal_attn=use_causal_attn, - use_mlm_attn=use_mlm_attn, + use_key_padding_mask=use_key_padding_mask, transformer_layers_type=transformer_layers_type, item_net_block_types=item_net_block_types, pos_encoding_type=pos_encoding_type, @@ -1065,7 +1071,7 @@ def _custom_transform_dataset_i2i( def _recommend_u2i( self, user_ids: InternalIdsArray, - dataset: Dataset, # [n_rec_users x n_items + 2] + dataset: Dataset, # [n_rec_users x n_items + n_special_tokens] k: int, filter_viewed: bool, sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], # model_internal @@ -1084,7 +1090,7 @@ def _recommend_u2i( ranker = ImplicitRanker( self.u2i_dist, user_embs, # [n_rec_users, n_factors] - item_embs_np, # [n_items + 1, n_factors] + item_embs_np, # [n_items + n_special_tokens, n_factors] ) if filter_viewed: user_items = dataset.get_user_item_matrix(include_weights=False) @@ -1123,8 +1129,8 @@ def _recommend_i2i( ranker = ImplicitRanker( self.i2i_dist, - item_embs, # [n_items + 2, n_factors] - item_embs, # [n_items + 2, n_factors] + item_embs, # [n_items + n_special_tokens, n_factors] + item_embs, # [n_items + n_special_tokens, n_factors] ) return ranker.rank( subject_ids=target_ids, # model internal @@ -1146,16 +1152,14 @@ def torch_model(self) -> TransformerBasedSessionEncoder: class SASRecModel(TransformerModelBase): """TODO""" - # pylint: disable=too-many-locals - - def __init__( # pylint: 
disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments, too-many-locals self, n_blocks: int = 1, n_heads: int = 1, n_factors: int = 128, use_pos_emb: bool = True, use_causal_attn: bool = True, - use_mlm_attn: bool = False, + use_key_padding_mask: bool = False, dropout_rate: float = 0.2, session_max_len: int = 32, dataloader_num_workers: int = 0, @@ -1176,12 +1180,13 @@ def __init__( # pylint: disable=too-many-arguments ): super().__init__( transformer_layers_type, + data_preparator_type, n_blocks, n_heads, n_factors, use_pos_emb, use_causal_attn, - use_mlm_attn, + use_key_padding_mask, dropout_rate, session_max_len, loss, From 9c6d3a48730ad09aa46866ec360e9576ea88c02e Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Thu, 14 Nov 2024 13:01:23 +0300 Subject: [PATCH 7/8] sasrec metrics --- examples/sasrec_metrics_comp.ipynb | 327 ++++++++++++++--------------- 1 file changed, 152 insertions(+), 175 deletions(-) diff --git a/examples/sasrec_metrics_comp.ipynb b/examples/sasrec_metrics_comp.ipynb index 9ed4963d..b3053f8f 100644 --- a/examples/sasrec_metrics_comp.ipynb +++ b/examples/sasrec_metrics_comp.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -119,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -128,7 +128,7 @@ }, { 
"cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -152,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -198,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -214,7 +214,7 @@ "32" ] }, - "execution_count": 11, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" } @@ -227,7 +227,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -243,7 +243,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -254,8 +254,7 @@ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. 
Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n" + "HPU available: False, using: 0 HPUs\n" ] } ], @@ -273,7 +272,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -290,13 +289,13 @@ "0 Non-trainable params\n", "927 K Total params\n", "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "00f2ec3343d24c3296e3b6e217689b84", + "model_id": "872b6e4e393b469db004bdd889a89533", "version_major": 2, "version_minor": 0 }, @@ -317,10 +316,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -332,29 +331,44 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/models/sasrec.py:635: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/sasrec.py:786: UserWarning: 91202 target users were considered cold because of missing known items\n", " 
warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/models/base.py:406: UserWarning: \n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/base.py:420: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", - "100%|██████████| 740/740 [00:02<00:00, 251.43it/s]\n" + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ddd4a5fc9400481f98f0f0c0c086b96f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Predicting: | | 0/? 
[00:00" + "" ] }, - "execution_count": 19, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } @@ -629,31 +637,44 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/models/sasrec.py:635: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/sasrec.py:786: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/models/base.py:406: UserWarning: \n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/base.py:420: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/dataset/features.py:424: UserWarning: Converting sparse features to dense array may cause MemoryError\n", - " warnings.warn(\"Converting sparse features to dense array may cause MemoryError\")\n", - "100%|██████████| 740/740 [00:05<00:00, 147.19it/s]\n" + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "612e92b761d741348fd7ec531d2a1964", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Predicting: | | 0/? 
[00:00" + "" ] }, - "execution_count": 28, + "execution_count": 52, "metadata": {}, "output_type": "execute_result" } @@ -780,31 +801,44 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/models/sasrec.py:635: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/sasrec.py:786: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/models/base.py:406: UserWarning: \n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/models/base.py:420: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/dataset/features.py:424: UserWarning: Converting sparse features to dense array may cause MemoryError\n", - " warnings.warn(\"Converting sparse features to dense array may cause MemoryError\")\n", - "100%|██████████| 740/740 [00:03<00:00, 190.30it/s]\n" + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "/data/home/maspirina1/tasks/repo/RecTools/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "59de5ff7662c493f8e96359a7c3b2190", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Predicting: | | 0/? 
[00:00" ], "text/plain": [ - " target_item_id item_id score rank\n", - "0 13865 15648 1.000000 1\n", - "1 13865 3386 1.000000 2\n", - "2 13865 147 0.898218 3\n", - "3 13865 16194 0.898218 4\n", - "4 13865 12309 0.898218 5\n", - "5 13865 12586 0.898218 6\n", - "6 13865 6661 0.898218 7\n", - "7 13865 2255 0.898218 8\n", - "8 13865 3792 0.898218 9\n", - "9 13865 4130 0.898218 10\n", - "10 4457 5109 1.000000 1\n", - "11 4457 8851 1.000000 2\n", - "12 4457 8486 1.000000 3\n", - "13 4457 12087 1.000000 4\n", - "14 4457 2313 1.000000 5\n", - "15 4457 11977 1.000000 6\n", - "16 4457 7928 1.000000 7\n", - "17 4457 3384 1.000000 8\n", - "18 4457 11513 1.000000 9\n", - "19 4457 6285 1.000000 10\n", - "20 15297 8723 1.000000 1\n", - "21 15297 5926 1.000000 2\n", - "22 15297 4131 1.000000 3\n", - "23 15297 4229 1.000000 4\n", - "24 15297 7005 1.000000 5\n", - "25 15297 10797 1.000000 6\n", - "26 15297 10535 1.000000 7\n", - "27 15297 5400 1.000000 8\n", - "28 15297 4716 1.000000 9\n", - "29 15297 13103 1.000000 10" + " target_item_id item_id score rank\n", + "0 13865 15648 1.000000 1\n", + "1 13865 3386 1.000000 2\n", + "2 13865 147 0.898218 3\n", + "3 13865 16194 0.898218 4\n", + "4 13865 12309 0.898218 5\n", + "5 13865 12586 0.898218 6\n", + "6 13865 6661 0.898218 7\n", + "7 13865 2255 0.898218 8\n", + "8 13865 3792 0.898218 9\n", + "9 13865 4130 0.898218 10\n", + "10 4457 5109 1.000000 1\n", + "11 4457 8851 1.000000 2\n", + "12 4457 8486 1.000000 3\n", + "13 4457 12087 1.000000 4\n", + "14 4457 2313 1.000000 5\n", + "15 4457 11977 1.000000 6\n", + "16 4457 7928 1.000000 7\n", + "17 4457 3384 1.000000 8\n", + "18 4457 11513 1.000000 9\n", + "19 4457 6285 1.000000 10\n", + "20 15297 8723 1.000000 1\n", + "21 15297 5926 1.000000 2\n", + "22 15297 4131 1.000000 3\n", + "23 15297 4229 1.000000 4\n", + "24 15297 7005 1.000000 5\n", + "25 15297 10797 1.000000 6\n", + "26 15297 10535 1.000000 7\n", + "27 15297 5400 1.000000 8\n", + "28 15297 4716 1.000000 9\n", + "29 15297 13103 
1.000000 10" ] }, - "execution_count": 34, + "execution_count": 58, "metadata": {}, "output_type": "execute_result" } @@ -1229,7 +1206,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -1238,7 +1215,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ @@ -1256,7 +1233,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ @@ -1271,7 +1248,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ @@ -1296,7 +1273,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 72, "metadata": {}, "outputs": [], "source": [ @@ -1322,14 +1299,14 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 73, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/rectools/dataset/features.py:424: UserWarning: Converting sparse features to dense array may cause MemoryError\n", + "/data/home/maspirina1/tasks/repo/RecTools/rectools/dataset/features.py:424: UserWarning: Converting sparse features to dense array may cause MemoryError\n", " warnings.warn(\"Converting sparse features to dense array may cause MemoryError\")\n" ] } @@ -1353,7 +1330,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 74, "metadata": {}, "outputs": [ { @@ -1499,7 +1476,7 @@ "sasrec_cat 0.005200 " ] }, - "execution_count": 48, + "execution_count": 74, "metadata": {}, "output_type": "execute_result" } @@ -1523,9 +1500,9 @@ ], "metadata": { "kernelspec": { - "display_name": "rectools_origin", + "display_name": ".venv", "language": "python", - "name": "rectools_origin" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1537,7 +1514,7 @@ "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.8.2" } }, "nbformat": 4, From 3522361bc9808ec50a1f26ac56e60a3dd6666756 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Thu, 14 Nov 2024 13:32:38 +0300 Subject: [PATCH 8/8] added explicit masking value --- rectools/models/bert4rec.py | 4 ++-- rectools/models/sasrec.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/rectools/models/bert4rec.py b/rectools/models/bert4rec.py index 9e7d596c..1a054864 100644 --- a/rectools/models/bert4rec.py +++ b/rectools/models/bert4rec.py @@ -55,7 +55,7 @@ def _mask_session(self, ses: List[int]) -> Tuple[List[int], List[int]]: if random_probs[j] < self.mask_prob: random_probs[j] /= self.mask_prob if random_probs[j] < 0.8: - masked_session[j] = 1 + masked_session[j] = self.extra_token_ids[MASKING_VALUE] elif random_probs[j] < 0.9: masked_session[j] = np.random.randint(low=self.n_item_extra_tokens, high=self.item_id_map.size) else: @@ -84,7 +84,7 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> t x = np.zeros((len(batch), self.session_max_len + 1)) for i, (ses, _) in enumerate(batch): session = ses.copy() - session = session + [1] + session = session + [self.extra_token_ids[MASKING_VALUE]] x[i, -len(ses) - 1 :] = session[-self.session_max_len - 1 :] return torch.LongTensor(x) diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py index d0584c4b..a260275c 100644 --- a/rectools/models/sasrec.py +++ b/rectools/models/sasrec.py @@ -524,6 +524,7 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors] seqs = self.pos_encoding(seqs) seqs = self.emb_dropout(seqs) + # TODO: stop passing timeline_mask together with key_padding_mask because they have same information seqs = self.transformer_layers(seqs, timeline_mask, attn_mask, key_padding_mask) return seqs @@ -639,6 +640,7 @@ def 
__init__( ) -> None: """TODO""" self.item_id_map: IdMap + self.extra_token_ids: tp.Dict self.session_max_len = session_max_len self.batch_size = batch_size self.dataloader_num_workers = dataloader_num_workers @@ -703,6 +705,9 @@ def process_dataset_train(self, dataset: Dataset) -> Dataset: dataset = Dataset(user_id_map, item_id_map, interactions, item_features=item_features) self.item_id_map = dataset.item_id_map + + extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens) + self.extra_token_ids = dict(zip(self.item_extra_tokens, extra_token_ids)) return dataset def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: