In [1]:
# TODO: will remove
import sys
sys.path.append("../../")

In [2]:
import numpy as np
import os
import pandas as pd
import itertools
import torch
import typing as tp
import warnings
from collections import Counter
from pathlib import Path
from functools import partial

from lightning_fabric import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from rectools import Columns, ExternalIds
from rectools.dataset import Dataset
from rectools.metrics import NDCG, Recall, Serendipity, calc_metrics

from rectools.models import BERT4RecModel, SASRecModel
from rectools.models.nn.item_net import IdEmbeddingsItemNet
from rectools.models.nn.transformer_base import TransformerModelBase

# Enable deterministic behaviour with CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
warnings.simplefilter("ignore", UserWarning)

# Load data

In [3]:
%%time
!wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_en.zip -O data_en.zip
!unzip -o data_en.zip
!rm data_en.zip

In [4]:
# Download dataset
DATA_PATH = Path("./data_en")
items = pd.read_csv(DATA_PATH / 'items_en.csv', index_col=0)
interactions = (
    pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=["last_watch_dt"])
    .rename(columns={"last_watch_dt": Columns.Datetime})
)

print(interactions.shape)
interactions.head(2)

(5476251, 5)


Unnamed: 0,user_id,item_id,datetime,total_dur,watched_pct
0,176549,9506,2021-05-11,4250,72.0
1,699317,1659,2021-05-29,8317,100.0


In [5]:
interactions[Columns.User].nunique(), interactions[Columns.Item].nunique()

(962179, 15706)

In [6]:
# Process interactions
interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)
raw_interactions = interactions[["user_id", "item_id", "datetime", "weight"]]
print(raw_interactions.shape)
raw_interactions.head(2)

(5476251, 4)


Unnamed: 0,user_id,item_id,datetime,weight
0,176549,9506,2021-05-11,3
1,699317,1659,2021-05-29,3


In [7]:
# Process item features
# items = items.loc[items[Columns.Item].isin(raw_interactions[Columns.Item])].copy()
# items["genre"] = items["genres"].str.lower().str.replace(", ", ",", regex=False).str.split(",")
# genre_feature = items[["item_id", "genre"]].explode("genre")
# genre_feature.columns = ["id", "value"]
# genre_feature["feature"] = "genre"
# content_feature = items.reindex(columns=[Columns.Item, "content_type"])
# content_feature.columns = ["id", "value"]
# content_feature["feature"] = "content_type"
# item_features = pd.concat((genre_feature, content_feature))

In [8]:
RANDOM_STATE=60
torch.use_deterministic_algorithms(True)
seed_everything(RANDOM_STATE, workers=True)

Seed set to 60


60

In [9]:
dataset_no_features = Dataset.construct(raw_interactions)
dataset_no_features

Dataset(user_id_map=IdMap(external_ids=array([176549, 699317, 656683, ..., 805174, 648596, 697262])), item_id_map=IdMap(external_ids=array([ 9506,  1659,  7107, ..., 10064, 13019, 10542])), interactions=Interactions(df=         user_id  item_id  weight   datetime
0              0        0     3.0 2021-05-11
1              1        1     3.0 2021-05-29
2              2        2     1.0 2021-05-09
3              3        3     3.0 2021-07-05
4              4        0     3.0 2021-04-30
...          ...      ...     ...        ...
5476246   962177      208     1.0 2021-08-13
5476247   224686     2690     3.0 2021-04-13
5476248   962178       21     3.0 2021-08-20
5476249     7934     1725     3.0 2021-04-19
5476250   631989      157     3.0 2021-08-15

[5476251 rows x 4 columns]), user_features=None, item_features=None)

# **Custome Validation** (Leave-One-Out Strategy)

**Functionality for obtaining logged metrics after fitting model:**

In [13]:
def get_log_dir(model: TransformerModelBase) -> Path:
    """
    Get logging directory.
    """
    path = model.fit_trainer.logger.log_dir
    return Path(path) / "metrics.csv"


def get_losses(epoch_metrics_df: pd.DataFrame, is_val: bool) -> pd.DataFrame:
    loss_df = epoch_metrics_df[["epoch", "train/loss"]].dropna()
    if is_val:
        val_loss_df = epoch_metrics_df[["epoch", "val/loss"]].dropna()
        loss_df = pd.merge(loss_df, val_loss_df, how="inner", on="epoch")
    return loss_df.reset_index(drop=True)


def get_val_metrics(epoch_metrics_df: pd.DataFrame) -> pd.DataFrame:
    metrics_df = epoch_metrics_df.drop(columns=["train/loss", "val/loss"]).dropna()
    return metrics_df.reset_index(drop=True)


def get_log_values(model: TransformerModelBase, is_val: bool = False) -> tp.Tuple[pd.DataFrame, tp.Optional[pd.DataFrame]]:
    log_path = get_log_dir(model)
    epoch_metrics_df = pd.read_csv(log_path)

    loss_df = get_losses(epoch_metrics_df, is_val)
    val_metrics = None
    if is_val:
        val_metrics = get_val_metrics(epoch_metrics_df)
    return loss_df, val_metrics

**Callback for calculation RecSys metrics on validation step:**

In [14]:
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback


class ValidationMetrics(Callback):
    
    def __init__(self, top_k_saved_val_reco: int, val_metrics: tp.Dict, verbose: int = 0) -> None:
        self.top_k_saved_val_reco = top_k_saved_val_reco
        self.val_metrics = val_metrics
        self.verbose = verbose

        self.epoch_n_users: int = 0
        self.batch_metrics: tp.List[tp.Dict[str, float]] = []

    def on_validation_batch_end(
        self, 
        trainer: Trainer, 
        pl_module: LightningModule, 
        outputs: tp.Dict[str, torch.Tensor], 
        batch: tp.Dict[str, torch.Tensor], 
        batch_idx: int, 
        dataloader_idx: int = 0
    ) -> None:
        logits = outputs["logits"]
        if logits is None:
            logits = pl_module.torch_model.encode_sessions(batch["x"], pl_module.item_embs)[:, -1, :]
        _, sorted_batch_recos = logits.topk(k=self.top_k_saved_val_reco)

        batch_recos = sorted_batch_recos.tolist()
        targets = batch["y"].tolist()

        batch_val_users = list(
            itertools.chain.from_iterable(
                itertools.repeat(idx, len(recos)) for idx, recos in enumerate(batch_recos)
            )
        )

        batch_target_users = list(
            itertools.chain.from_iterable(
                itertools.repeat(idx, len(targets)) for idx, targets in enumerate(targets)
            )
        )

        batch_recos_df = pd.DataFrame(
            {
                Columns.User: batch_val_users,
                Columns.Item: list(itertools.chain.from_iterable(batch_recos)),
            }
        )
        batch_recos_df[Columns.Rank] = batch_recos_df.groupby(Columns.User, sort=False).cumcount() + 1

        interactions = pd.DataFrame(
            {
                Columns.User: batch_target_users,
                Columns.Item: list(itertools.chain.from_iterable(targets)),
            }
        )

        prev_interactions = pl_module.data_preparator.train_dataset.interactions.df
        catalog = prev_interactions[Columns.Item].unique()

        batch_metrics = calc_metrics(
            self.val_metrics, 
            batch_recos_df,
            interactions, 
            prev_interactions,
            catalog
        )

        batch_n_users = batch["x"].shape[0]
        self.batch_metrics.append({metric: value * batch_n_users for metric, value in batch_metrics.items()})
        self.epoch_n_users += batch_n_users

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        epoch_metrics = dict(sum(map(Counter, self.batch_metrics), Counter()))
        epoch_metrics = {metric: value / self.epoch_n_users for metric, value in epoch_metrics.items()}

        self.log_dict(epoch_metrics, on_step=False, on_epoch=True, prog_bar=self.verbose > 0)

        self.batch_metrics.clear()
        self.epoch_n_users = 0

**Set up hyperparameters**

In [15]:
VAL_K_OUT = 1
N_VAL_USERS = 2048

unique_users = raw_interactions[Columns.User].unique()
VAL_USERS = unique_users[: N_VAL_USERS]

VAL_METRICS = {
    "NDCG@10": NDCG(k=10),
    "Recall@10": Recall(k=10),
    "Serendipity@10": Serendipity(k=10),
}
VAL_MAX_K = max([metric.k for metric in VAL_METRICS.values()])

MIN_EPOCHS = 2
MAX_EPOCHS = 10

MONITOR_METRIC = "NDCG@10"
MODE_MONITOR_METRIC = "max"

callback_metrics = ValidationMetrics(top_k_saved_val_reco=VAL_MAX_K, val_metrics=VAL_METRICS, verbose=1)
callback_early_stopping = EarlyStopping(monitor=MONITOR_METRIC, patience=MIN_EPOCHS, min_delta=0.0, mode=MODE_MONITOR_METRIC)
CALLBACKS = [callback_metrics, callback_early_stopping]

TRAIN_MIN_USER_INTERACTIONS = 5
SESSION_MAX_LEN = 50

unique_users.shape, VAL_USERS.shape

((962179,), (2048,))

**Custom function for split data on train and validation:**

In [16]:
def get_val_mask(interactions: pd.DataFrame, val_users: ExternalIds) -> pd.Series:
    rank = (
        interactions
        .sort_values(Columns.Datetime, ascending=False, kind="stable")
        .groupby(Columns.User, sort=False)
        .cumcount()
        + 1
    )
    val_mask = (
        (interactions[Columns.User].isin(val_users))
        & (rank <= VAL_K_OUT)
    )
    return val_mask


GET_VAL_MASK = partial(
    get_val_mask, 
    val_users=VAL_USERS,
)

# SASRec

In [17]:
sasrec_trainer = Trainer(
    accelerator='gpu',
    devices=[0],
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS, 
    deterministic=True,
    callbacks=CALLBACKS,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
sasrec_non_default_model = SASRecModel(
    n_factors=64,
    n_blocks=2,
    n_heads=2,
    dropout_rate=0.2,
    use_pos_emb=True,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=1e-3,
    batch_size=128,
    loss="softmax",
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    trainer=sasrec_trainer,
    get_val_mask_func=GET_VAL_MASK,
)

In [19]:
%%time
sasrec_non_default_model.fit(dataset_no_features)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                           | Params
---------------------------------------------------------------
0 | torch_model | TransformerBasedSessionEncoder | 987 K 
---------------------------------------------------------------
987 K     Trainable params
0         Non-trainable params
987 K     Total params
3.951     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


CPU times: user 17min 30s, sys: 28.3 s, total: 17min 58s
Wall time: 17min 46s


<rectools.models.nn.sasrec.SASRecModel at 0x7fe11c518700>

In [20]:
loss_df, val_metrics_df = get_log_values(sasrec_non_default_model, is_val=True)

In [21]:
loss_df

Unnamed: 0,epoch,train/loss,val/loss
0,0,16.390102,15.514286
1,1,15.722713,15.147015
2,2,15.560143,15.003609
3,3,15.493325,14.91841
4,4,15.450736,14.874678
5,5,15.421854,14.841123
6,6,15.405242,14.814446
7,7,15.390318,14.782287
8,8,15.374591,14.762179
9,9,15.367148,14.763201


In [22]:
val_metrics_df

Unnamed: 0,NDCG@10,Recall@10,Serendipity@10,epoch,step
0,0.021677,0.175542,4.3e-05,0,2362
1,0.0234,0.188692,7.8e-05,1,4725
2,0.024569,0.196581,0.000103,2,7088
3,0.02486,0.197896,0.0001,3,9451
4,0.0261,0.207758,0.000121,4,11814
5,0.026255,0.206443,0.000139,5,14177
6,0.026567,0.207101,0.000131,6,16540
7,0.026694,0.203156,0.00013,7,18903
8,0.027346,0.205786,0.000147,8,21266
9,0.026959,0.205128,0.000139,9,23629


In [23]:
del sasrec_non_default_model
torch.cuda.empty_cache()

# BERT4Rec

In [24]:
bert_trainer = Trainer(
    accelerator='gpu',
    devices=[1],
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS, 
    deterministic=True,
    callbacks=CALLBACKS,
)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [25]:
bert4rec_id_softmax_model = BERT4RecModel(
    mask_prob=0.5,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),
    trainer=bert_trainer,
    get_val_mask_func=GET_VAL_MASK,
    verbose=1,
)

In [26]:
%%time
bert4rec_id_softmax_model.fit(dataset_no_features)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                           | Params
---------------------------------------------------------------
0 | torch_model | TransformerBasedSessionEncoder | 2.1 M 
---------------------------------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.202     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

CPU times: user 8min 2s, sys: 21.3 s, total: 8min 23s
Wall time: 8min 5s


<rectools.models.nn.bert4rec.BERT4RecModel at 0x7fe0c9656a60>

In [27]:
loss_df, val_metrics_df = get_log_values(bert4rec_id_softmax_model, is_val=True)

In [28]:
loss_df

Unnamed: 0,epoch,train/loss,val/loss
0,0,16.926027,16.13648
1,1,17.872343,16.835089
2,2,18.278784,16.187969
3,3,18.398537,16.172079


In [29]:
val_metrics_df

Unnamed: 0,NDCG@10,Recall@10,Serendipity@10,epoch,step
0,0.021723,0.172037,1.4e-05,0,4741
1,0.022332,0.177499,1e-05,1,9483
2,0.021432,0.169306,1.1e-05,2,14225
3,0.021572,0.170945,2.1e-05,3,18967


In [30]:
del bert4rec_id_softmax_model
torch.cuda.empty_cache()