In [1]:
# TODO: will remove
# Temporary hack: put the directory two levels up on the import path
# (presumably the local repository checkout — confirm before removing).
import sys

# Guard the append so that re-running this cell does not keep adding duplicate entries.
if "../../" not in sys.path:
    sys.path.append("../../")

In [2]:
import numpy as np
import os
import pandas as pd
import itertools
import torch
import typing as tp
import warnings
from pathlib import Path

from lightning_fabric import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from rectools import Columns
from rectools.dataset import Dataset
from rectools.metrics import NDCG, Recall, Serendipity, calc_metrics

from rectools.models.sasrec import (
    SASRecModel,
    SASRecDataPreparator,
    SessionEncoderLightningModule,
    IdEmbeddingsItemNet,
)


# Deterministic behaviour with CUDA >= 10.2 requires this cuBLAS workspace setting
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Suppress every warning of category UserWarning for the rest of the notebook
warnings.simplefilter(action="ignore", category=UserWarning)

# Load data

In [3]:
%%time
# Fetch the KION dataset archive; the URL points at a fixed commit of the KION_DATASET repository.
# wget: -q keeps the output quiet, -O sets the local file name.
!wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_en.zip -O data_en.zip
# unzip -o overwrites existing files without prompting, so the cell can be re-run.
!unzip -o data_en.zip
# The archive is not needed once it has been unpacked.
!rm data_en.zip

In [4]:
# Read the unpacked tables from disk
DATA_PATH = Path("data_en")

items = pd.read_csv(DATA_PATH / "items_en.csv", index_col=0)

# Parse the watch date while reading, then give it the RecTools datetime column name
interactions = pd.read_csv(DATA_PATH / "interactions.csv", parse_dates=["last_watch_dt"])
interactions = interactions.rename(columns={"last_watch_dt": Columns.Datetime})

print(interactions.shape)
interactions.head(2)

(5476251, 5)


Unnamed: 0,user_id,item_id,datetime,total_dur,watched_pct
0,176549,9506,2021-05-11,4250,72.0
1,699317,1659,2021-05-29,8317,100.0


In [5]:
# (number of distinct users, number of distinct items) in the interaction log
tuple(interactions[column].nunique() for column in (Columns.User, Columns.Item))

(962179, 15706)

In [6]:
# Process interactions
# Weight is 3 when more than 10% of the item was watched and 1 otherwise;
# only the four columns listed below are kept.
interactions = interactions.assign(
    **{Columns.Weight: np.where(interactions["watched_pct"] > 10, 3, 1)}
)[["user_id", "item_id", "datetime", "weight"]]
print(interactions.shape)
interactions.head(2)

(5476251, 4)


Unnamed: 0,user_id,item_id,datetime,weight
0,176549,9506,2021-05-11,3
1,699317,1659,2021-05-29,3


In [7]:
# Process item features
# Keep only the items that occur in the interaction log
items = items.loc[items[Columns.Item].isin(interactions[Columns.Item])].copy()

# "genres" is a comma-separated string: lower-case it and split it into a list of genres
items["genre"] = (
    items["genres"]
    .str.lower()
    .str.replace(", ", ",", regex=False)
    .str.split(",")
)

# Long-format feature tables with columns (id, value, feature): one row per (item, genre) ...
genre_feature = (
    items[["item_id", "genre"]]
    .explode("genre")
    .rename(columns={"item_id": "id", "genre": "value"})
    .assign(feature="genre")
)
# ... and one row per item for its content type
content_feature = (
    items.reindex(columns=[Columns.Item, "content_type"])
    .rename(columns={Columns.Item: "id", "content_type": "value"})
    .assign(feature="content_type")
)
item_features = pd.concat([genre_feature, content_feature])

In [8]:
RANDOM_STATE = 60

# Ask torch to use deterministic algorithms
torch.use_deterministic_algorithms(mode=True)
# Seed the random generators; the returned seed is shown as the cell output
seed_everything(seed=RANDOM_STATE, workers=True)

Seed set to 60


60

In [9]:
# Build the RecTools dataset from the interactions alone: no feature frames are passed,
# so the repr below reports user_features=None and item_features=None.
dataset_no_features = Dataset.construct(interactions)
dataset_no_features

Dataset(user_id_map=IdMap(external_ids=array([176549, 699317, 656683, ..., 805174, 648596, 697262])), item_id_map=IdMap(external_ids=array([ 9506,  1659,  7107, ..., 10064, 13019, 10542])), interactions=Interactions(df=         user_id  item_id  weight   datetime
0              0        0     3.0 2021-05-11
1              1        1     3.0 2021-05-29
2              2        2     1.0 2021-05-09
3              3        3     3.0 2021-07-05
4              4        0     3.0 2021-04-30
...          ...      ...     ...        ...
5476246   962177      208     1.0 2021-08-13
5476247   224686     2690     3.0 2021-04-13
5476248   962178       21     3.0 2021-08-20
5476249     7934     1725     3.0 2021-04-19
5476250   631989      157     3.0 2021-08-15

[5476251 rows x 4 columns]), user_features=None, item_features=None)

# **Custom Validation** (Leave-One-Out Strategy)

In [10]:
def get_log_dir(trainer: Trainer) -> Path:
    """
    Get the path to the ``metrics.csv`` file of the previous logger version.

    ``trainer.logger.log_dir`` is expected to end with ``version_<N>``. The returned path is
    ``<same parent>/version_<N - 1>/metrics.csv``.
    NOTE(review): the ``N - 1`` step presumably exists because this trainer's logger reports
    the next, not yet written, version directory — confirm against the logger in use.

    Parameters
    ----------
    trainer : Trainer
        Trainer whose ``logger.log_dir`` is inspected.

    Returns
    -------
    Path
        Location of ``metrics.csv`` inside the previous version directory.

    Raises
    ------
    ValueError
        If the text after the last ``version_`` in the log dir is not an integer.
    """
    path = trainer.logger.log_dir
    # Split on the LAST "version_" only: a parent directory whose name also contains
    # "version_" must stay part of the prefix instead of truncating it.
    prefix, _, version_str = path.rpartition("version_")
    previous_version = int(version_str) - 1
    return Path(f"{prefix}version_{previous_version}") / "metrics.csv"


def get_losses(epoch_metrics_df: pd.DataFrame, is_val: bool) -> pd.DataFrame:
    """
    Extract per-epoch losses from the logged metrics table.

    Rows with a missing train loss are dropped. When `is_val` is true, the rows holding a
    validation loss are inner-joined on ``epoch``, so the result has the columns
    ``epoch``, ``train/loss`` and ``val/loss``; otherwise only ``epoch`` and ``train/loss``.
    The returned frame has a fresh 0..n-1 index.
    """
    train_part = epoch_metrics_df.loc[:, ["epoch", "train/loss"]].dropna()
    if not is_val:
        return train_part.reset_index(drop=True)

    val_part = epoch_metrics_df.loc[:, ["epoch", "val/loss"]].dropna()
    joined = train_part.merge(val_part, on="epoch", how="inner")
    return joined.reset_index(drop=True)


def get_val_metrics(epoch_metrics_df: pd.DataFrame) -> pd.DataFrame:
    """
    Extract the validation-metric rows from the logged metrics table.

    The ``train/loss`` and ``val/loss`` columns are removed, then every row that still
    contains a missing value is dropped. The returned frame has a fresh 0..n-1 index.
    """
    loss_columns = ["train/loss", "val/loss"]
    without_losses = epoch_metrics_df.drop(columns=loss_columns)
    complete_rows = without_losses.dropna()
    return complete_rows.reset_index(drop=True)


def get_log_values(trainer: Trainer, is_val: bool = False) -> tp.Tuple[pd.DataFrame, tp.Optional[pd.DataFrame]]:
    """
    Load the logged metrics CSV for `trainer` and split it into losses and validation metrics.

    Returns a ``(loss_df, val_metrics)`` pair. ``val_metrics`` is ``None`` unless `is_val`
    is true, in which case ``loss_df`` also carries the validation loss.
    """
    epoch_metrics_df = pd.read_csv(get_log_dir(trainer))
    losses = get_losses(epoch_metrics_df, is_val)
    metrics = get_val_metrics(epoch_metrics_df) if is_val else None
    return losses, metrics

In [11]:
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback


class ValidationMetrics(Callback):
    """
    Callback that computes recommendation metrics at the end of every validation epoch.

    It reads ``train_interactions``, ``epoch_val_recos`` and ``epoch_targets`` from the
    Lightning module, builds recommendation and ground-truth tables from them, evaluates
    `val_metrics` with ``calc_metrics`` and logs the result, so that other callbacks
    (e.g. early stopping) can monitor the values.

    Parameters
    ----------
    val_metrics : dict
        Mapping from metric name to metric object, passed to ``calc_metrics`` as is.
    """

    def __init__(self, val_metrics: tp.Dict) -> None:
        self.val_metrics = val_metrics
        # Item catalog; filled from the train interactions on the first validation epoch.
        self.catalog = None

    @staticmethod
    def _flatten_groups(groups: tp.Iterable) -> tp.Tuple[tp.List[int], tp.List]:
        """Flatten `groups` into (position of the source group for each element, elements)."""
        group_ids: tp.List[int] = []
        values: tp.List = []
        for group_idx, group in enumerate(groups):
            group_ids.extend(itertools.repeat(group_idx, len(group)))
            values.extend(group)
        return group_ids, values

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Compute and log validation metrics, then reset the per-epoch buffers of `pl_module`.

        When ``pl_module.epoch_val_recos`` is empty a warning is issued and nothing is logged.
        """
        prev_interactions = pl_module.train_interactions
        if self.catalog is None:
            self.catalog = prev_interactions[Columns.Item].unique()

        if len(pl_module.epoch_val_recos) == 0:
            warnings.warn("Can not caclulate RecSys metrics, because `epoch_val_recos` and `epoch_targets` from `pl_module` is empty.")
            return None

        # The position of each element in the buffers is used as the user id, so recos and
        # targets are matched by position.
        # NOTE(review): presumably one buffer element per validation user — confirm.
        epoch_val_users, epoch_val_recos = self._flatten_groups(pl_module.epoch_val_recos)
        epoch_target_users, epoch_targets = self._flatten_groups(pl_module.epoch_targets)

        epoch_recos_df = pd.DataFrame(
            {
                Columns.User: epoch_val_users,
                Columns.Item: epoch_val_recos,
            }
        )
        # Ranks follow the order in which the recommendations were stored, starting from 1.
        epoch_recos_df[Columns.Rank] = epoch_recos_df.groupby(Columns.User, sort=False).cumcount() + 1

        interactions = pd.DataFrame(
            {
                Columns.User: epoch_target_users,
                Columns.Item: epoch_targets,
            }
        )
        result_metrics = calc_metrics(
            self.val_metrics,
            epoch_recos_df,
            interactions,
            prev_interactions,
            self.catalog,
        )

        # Log through the LightningModule: `log_dict` is part of the LightningModule API and
        # is the documented way to log from callback hooks, so this does not rely on the
        # callback object itself providing a `log_dict` attribute.
        pl_module.log_dict(result_metrics)

        # Start the next validation epoch with empty buffers.
        pl_module.epoch_val_recos.clear()
        pl_module.epoch_targets.clear()

In [12]:
# Validation split settings; both are handed to `fit` further down
VAL_K_OUT = 1
N_VAL_USERS = 2048

# Validation users are the first N_VAL_USERS distinct user ids in order of appearance
unique_users = interactions[Columns.User].unique()
VAL_USERS = unique_users[:N_VAL_USERS]

# Metrics evaluated after every validation epoch
VAL_METRICS = {
    "NDCG@10": NDCG(k=10),
    "Recall@10": Recall(k=10),
    "Serendipity@10": Serendipity(k=10),
}
# Largest cutoff among the metrics above
VAL_MAX_K = max(metric.k for metric in VAL_METRICS.values())

MIN_EPOCHS = 5
MAX_EPOCHS = 100

# Early stopping watches this metric and treats larger values as better
MONITOR_METRIC = "NDCG@10"
MODE_MONITOR_METRIC = "max"

callback_metrics = ValidationMetrics(val_metrics=VAL_METRICS)
callback_early_stopping = EarlyStopping(
    monitor=MONITOR_METRIC,
    mode=MODE_MONITOR_METRIC,
    patience=MIN_EPOCHS,
    min_delta=0.0,
)
# The metrics callback is listed before early stopping
CALLBACKS = [callback_metrics, callback_early_stopping]

unique_users.shape, VAL_USERS.shape

((962179,), (2048,))

In [13]:
# NOTE(review): `devices=[1]` pins training to the GPU with index 1, so this cell needs a
# machine with at least two visible CUDA devices; adjust the index for other hardware.
trainer = Trainer(
    callbacks=CALLBACKS,
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS,
    deterministic=True,
    accelerator="gpu",
    devices=[1],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
sasrec_non_default_model = SASRecModel(
    # Network architecture
    n_blocks=2,
    n_heads=2,
    n_factors=64,
    dropout_rate=0.2,
    use_pos_emb=True,
    session_max_len=50,
    item_net_block_types=(IdEmbeddingsItemNet,),  # Use only item ids in ItemNetBlock
    # Data preparation
    train_min_user_interaction=5,
    data_preparator_type=SASRecDataPreparator,
    # Optimisation
    loss="softmax",
    lr=1e-3,
    batch_size=128,
    # Training loop: custom trainer (with the validation callbacks) and lightning module
    trainer=trainer,
    lightning_module_type=SessionEncoderLightningModule,
    deterministic=True,
    verbose=1,
    # Validation: largest cutoff required by the validation metrics
    val_max_k=VAL_MAX_K,
)

In [15]:
%%time
# Train on the interactions-only dataset. VAL_K_OUT and VAL_USERS are passed positionally
# as the second and third arguments of `fit`.
sasrec_non_default_model.fit(dataset_no_features, VAL_K_OUT, VAL_USERS)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                           | Params
---------------------------------------------------------------
0 | torch_model | TransformerBasedSessionEncoder | 988 K 
---------------------------------------------------------------
988 K     Trainable params
0         Non-trainable params
988 K     Total params
3.952     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

CPU times: user 13min 5s, sys: 20.5 s, total: 13min 26s
Wall time: 13min 17s


<rectools.models.sasrec.SASRecModel at 0x7fb6a613ab80>

In [16]:
# Read the logged metrics back from disk; with is_val=True the validation loss and the
# validation metrics table are returned as well.
loss_df, val_metrics_df = get_log_values(trainer, is_val=True)

In [17]:
# Train and validation loss for every completed epoch
loss_df

Unnamed: 0,epoch,train/loss,val/loss
0,0,16.38846,21.328363
1,1,15.719435,21.243851
2,2,15.568491,21.38291
3,3,15.500002,21.374329
4,4,15.457561,21.435516
5,5,15.432266,21.541876
6,6,15.409526,21.426628
7,7,15.390734,21.506046


In [18]:
# Validation metrics (NDCG@10, Recall@10, Serendipity@10) for every completed epoch
val_metrics_df

Unnamed: 0,NDCG@10,Recall@10,Serendipity@10,epoch,step
0,0.001894,0.011827,3.870708e-07,0,2362
1,0.00191,0.01117,5.202952e-06,1,4725
2,0.002181,0.013141,5.207453e-06,2,7088
3,0.001423,0.009855,5.180448e-06,3,9451
4,0.001829,0.012484,5.432494e-06,4,11814
5,0.002036,0.015112,6.062609e-06,5,14177
6,0.001941,0.014455,5.837568e-06,6,16540
7,0.001724,0.013141,5.79256e-06,7,18903
