In [1]:
# TODO: will remove
import sys
sys.path.append("../../")

In [2]:
import numpy as np
import os
import pandas as pd
import torch
import typing as tp
import warnings
from pathlib import Path

from lightning_fabric import seed_everything
from pytorch_lightning import Trainer
from rectools import Columns
from rectools.dataset import Dataset
from rectools.metrics import MAP, Serendipity, MeanInvUserFreq, calc_metrics

from rectools.models import BERT4RecModel, SASRecModel
from rectools.models.nn.item_net import IdEmbeddingsItemNet
from rectools.models.nn.transformer_base import TransformerModelBase

# Enable deterministic behaviour with CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
warnings.simplefilter("ignore", UserWarning)

# Load data

In [3]:
# %%time
# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_original.zip
# !unzip -o data_original.zip
# !rm data_original.zip

In [4]:
DATA_PATH = Path("data_original")

interactions = (
    pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=["last_watch_dt"])
    .rename(columns={"last_watch_dt": "datetime"})
)

In [5]:
interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)

# Split to train / test
max_date = interactions[Columns.Datetime].max()
train = interactions[interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)].copy()
test = interactions[interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)].copy()
train.drop(train.query("total_dur < 300").index, inplace=True)

# drop items with less than 20 interactions in train
items = train["item_id"].value_counts()
items = items[items >= 20]
items = items.index.to_list()
train = train[train["item_id"].isin(items)]
    
# drop users with less than 2 interactions in train
users = train["user_id"].value_counts()
users = users[users >= 2]
users = users.index.to_list()
train = train[(train["user_id"].isin(users))]

users = train["user_id"].drop_duplicates().to_list()

# drop cold users from test
test_users_sasrec = test[Columns.User].unique()
cold_users = set(test[Columns.User]) - set(train[Columns.User])
test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)
test_users = test[Columns.User].unique()


In [6]:
items = pd.read_csv(DATA_PATH / 'items.csv')

In [7]:
# Process item features to the form of a flatten dataframe
items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()
items["genre"] = items["genres"].str.lower().str.replace(", ", ",", regex=False).str.split(",")
genre_feature = items[["item_id", "genre"]].explode("genre")
genre_feature.columns = ["id", "value"]
genre_feature["feature"] = "genre"
content_feature = items.reindex(columns=[Columns.Item, "content_type"])
content_feature.columns = ["id", "value"]
content_feature["feature"] = "content_type"
item_features = pd.concat((genre_feature, content_feature))

candidate_items = interactions['item_id'].drop_duplicates().astype(int)
test["user_id"] = test["user_id"].astype(int)
test["item_id"] = test["item_id"].astype(int)

catalog=train[Columns.Item].unique()

In [8]:
dataset_no_features = Dataset.construct(
    interactions_df=train,
)

dataset_item_features = Dataset.construct(
    interactions_df=train,
    item_features_df=item_features,
    cat_item_features=["genre", "content_type"],
)

In [9]:
metrics_name = {
    'MAP': MAP,
    'MIUF': MeanInvUserFreq,
    'Serendipity': Serendipity
    

}
metrics = {}
for metric_name, metric in metrics_name.items():
    for k in (1, 5, 10):
        metrics[f'{metric_name}@{k}'] = metric(k=k)

# list with metrics results of all models
features_results = []

In [10]:
RANDOM_STATE=60
torch.use_deterministic_algorithms(True)
seed_everything(RANDOM_STATE, workers=True)

Seed set to 60


60

In [11]:
def get_log_dir(model: TransformerModelBase) -> Path:
    """
    Get logging directory.
    """
    path = model.fit_trainer.log_dir
    return Path(path) / "metrics.csv"


def get_losses(epoch_metrics_df: pd.DataFrame, is_val: bool) -> pd.DataFrame:
    loss_df = epoch_metrics_df[["epoch", "train_loss"]].dropna()
    if is_val:
        val_loss_df = epoch_metrics_df[["epoch", "val_loss"]].dropna()
        loss_df = pd.merge(loss_df, val_loss_df, how="inner", on="epoch")
    return loss_df.reset_index(drop=True)


def get_val_metrics(epoch_metrics_df: pd.DataFrame) -> pd.DataFrame:
    metrics_df = epoch_metrics_df.drop(columns=["train_loss", "val_loss"]).dropna()
    return metrics_df.reset_index(drop=True)


def get_log_values(model: TransformerModelBase, is_val: bool = False) -> tp.Tuple[pd.DataFrame, tp.Optional[pd.DataFrame]]:
    log_path = get_log_dir(model)
    epoch_metrics_df = pd.read_csv(log_path)

    loss_df = get_losses(epoch_metrics_df, is_val)
    val_metrics = None
    if is_val:
        val_metrics = get_val_metrics(epoch_metrics_df)
    return loss_df, val_metrics

**Model common params**

In [12]:
MIN_EPOCHS = 5
MAX_EPOCHS = 5
TRAIN_MIN_USER_INTERACTIONS = 5
SESSION_MAX_LEN = 50
N_NEGATIVES = 5

ACCELERATOR = "gpu"
DEVICES = [0]

N_FACTORS = 256
N_BLOCKS = 4
N_HEADS = 4
DROPOUT_RATE = 0.2
BATCH_SIZE = 128
LR = 1e-3


# **Training Objective**

https://arxiv.org/pdf/2205.04507

## **Next Action**

In [13]:
from typing import Dict, List, Tuple

from rectools.models.nn.transformer_data_preparator import TransformerDataPreparatorBase
from rectools.models.nn.transformer_lightning import TransformerLightningModule


# For Bert-like models need add MASK to `_collate_fn` methods.

class NextItemDataPreparator(TransformerDataPreparatorBase):
    """Data preparator for SASRecModel."""

    train_session_max_len_addition: int = 1

    def _collate_fn_train(
        self,
        batch: List[Tuple[List[int], List[float]]],
    ) -> Dict[str, torch.Tensor]:
        """
        Truncate each session from right to keep `session_max_len` items.
        Do left padding until `session_max_len` is reached.
        Split to `x`, `y`, and `yw`.
        """
        batch_size = len(batch)
        x = np.zeros((batch_size, self.session_max_len))
        y = np.zeros((batch_size, 1))
        yw = np.zeros((batch_size, 1))
        for i, (ses, ses_weights) in enumerate(batch):
            x[i, -len(ses) + 1 :] = ses[:-1]  # ses: [session_len] -> x[i]: [session_max_len]
            y[i] = ses[-1]  # ses: [session_len] -> y[i]: [1]
            yw[i] = ses_weights[-1]  # ses_weights: [session_len] -> yw[i]: [1]

        batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)}
        if self.n_negatives is not None:
            negatives = torch.randint(
                low=self.n_item_extra_tokens,
                high=self.item_id_map.size,
                size=(batch_size, 1, self.n_negatives),
            )  # [batch_size, 1, n_negatives]
            batch_dict["negatives"] = negatives
        return batch_dict
    
    def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
        """Right truncation, left padding to session_max_len"""
        x = np.zeros((len(batch), self.session_max_len))
        for i, (ses, _) in enumerate(batch):
            x[i, -len(ses) :] = ses[-self.session_max_len :]
        return {"x": torch.LongTensor(x)}


class NextItemLightningModule(TransformerLightningModule):

    def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Training step."""
        x, y, w = batch["x"], batch["y"], batch["yw"]
        if self.loss == "softmax":
            logits = self._get_full_catalog_logits(x)[:, -1: :]
            loss = self._calc_softmax_loss(logits, y, w)
        elif self.loss == "BCE":
            negatives = batch["negatives"]
            logits = self._get_pos_neg_logits(x, y, negatives)[:, -1: :]
            loss = self._calc_bce_loss(logits, y, w)
        elif self.loss == "gBCE":
            negatives = batch["negatives"]
            logits = self._get_pos_neg_logits(x, y, negatives)[:, -1: :]
            loss = self._calc_gbce_loss(logits, y, w, negatives)
        else:
            loss = self._calc_custom_loss(batch, batch_idx)

        self.log(self.train_loss_name, loss, on_step=False, on_epoch=True, prog_bar=self.verbose > 0)

        return loss 

In [14]:
def get_nextitem_trainer():
    return Trainer(
        accelerator=ACCELERATOR,
        devices=DEVICES,
        min_epochs=MIN_EPOCHS,
        max_epochs=MAX_EPOCHS, 
        deterministic=True,
    )

In [15]:
sasrec_nextitem_model = SASRecModel(
    n_factors=N_FACTORS,
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=True,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=LR,
    batch_size=BATCH_SIZE,
    loss="softmax",
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    data_preparator_type=NextItemDataPreparator,
    lightning_module_type=NextItemLightningModule,
    get_trainer_func=get_nextitem_trainer,
)

N_NEGATIVES = 5

sasrec_nextitem_bce_model = SASRecModel(
    n_factors=N_FACTORS,
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=True,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=LR,
    batch_size=BATCH_SIZE,
    loss="BCE",
    n_negatives=N_NEGATIVES,
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    data_preparator_type=NextItemDataPreparator,
    lightning_module_type=NextItemLightningModule,
    get_trainer_func=get_nextitem_trainer,
)

sasrec_nextitem_gbce_model = SASRecModel(
    n_factors=N_FACTORS,
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=True,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=LR,
    batch_size=BATCH_SIZE,
    loss="gBCE",
    n_negatives=N_NEGATIVES,
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    data_preparator_type=NextItemDataPreparator,
    lightning_module_type=NextItemLightningModule,
    get_trainer_func=get_nextitem_trainer,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [16]:
%%time
sasrec_nextitem_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 3.0 M 
---------------------------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
12.175    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 4min 21s, sys: 6.17 s, total: 4min 27s
Wall time: 4min 15s


<rectools.models.nn.sasrec.SASRecModel at 0x7fc82bf640d0>

In [17]:
%%time
recos = sasrec_nextitem_model.recommend(
    users=test_users_sasrec, 
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn"
)

metric_values = calc_metrics(metrics, recos[["user_id", "item_id", "rank"]], test, train, catalog)
metric_values["model"] = "sasrec_next_action_softmax"
features_results.append(metric_values)
features_results

CPU times: user 1min 50s, sys: 8min 36s, total: 10min 26s
Wall time: 20.3 s


[{'MAP@1': 0.0347556169412836,
  'MAP@5': 0.0670784020904029,
  'MAP@10': 0.07305790205735346,
  'MIUF@1': 2.026684699022298,
  'MIUF@5': 2.809957769256009,
  'MIUF@10': 3.7957813994754876,
  'Serendipity@1': 8.831319032679852e-06,
  'Serendipity@5': 2.9488162768383078e-05,
  'Serendipity@10': 4.3597166218751995e-05,
  'model': 'sasrec_next_action_softmax'}]

In [18]:
%%time
sasrec_nextitem_bce_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 3.0 M 
---------------------------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
12.175    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 4min 22s, sys: 6.2 s, total: 4min 28s
Wall time: 4min 15s


<rectools.models.nn.sasrec.SASRecModel at 0x7fc9fd8f9af0>

In [19]:
%%time
recos = sasrec_nextitem_bce_model.recommend(
    users=test_users_sasrec, 
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn"
)

metric_values = calc_metrics(metrics, recos[["user_id", "item_id", "rank"]], test, train, catalog)
metric_values["model"] = "sasrec_next_action_bce"
features_results.append(metric_values)
features_results

CPU times: user 2min 15s, sys: 11min 4s, total: 13min 20s
Wall time: 23.4 s


[{'MAP@1': 0.0347556169412836,
  'MAP@5': 0.0670784020904029,
  'MAP@10': 0.07305790205735346,
  'MIUF@1': 2.026684699022298,
  'MIUF@5': 2.809957769256009,
  'MIUF@10': 3.7957813994754876,
  'Serendipity@1': 8.831319032679852e-06,
  'Serendipity@5': 2.9488162768383078e-05,
  'Serendipity@10': 4.3597166218751995e-05,
  'model': 'sasrec_next_action_softmax'},
 {'MAP@1': 0.030642412703662827,
  'MAP@5': 0.06354385798198152,
  'MAP@10': 0.06939375803777235,
  'MIUF@1': 2.1760997525872234,
  'MIUF@5': 2.8881699041484103,
  'MIUF@10': 3.6832272807506397,
  'Serendipity@1': 1.0021156220458501e-05,
  'Serendipity@5': 2.479979331575515e-05,
  'Serendipity@10': 3.303582017928996e-05,
  'model': 'sasrec_next_action_bce'}]

In [20]:
%%time
sasrec_nextitem_gbce_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 3.0 M 
---------------------------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
12.175    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 4min 34s, sys: 5.86 s, total: 4min 39s
Wall time: 4min 20s


<rectools.models.nn.sasrec.SASRecModel at 0x7fc82bf64a60>

In [21]:
%%time
recos = sasrec_nextitem_gbce_model.recommend(
    users=test_users_sasrec, 
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn"
)

metric_values = calc_metrics(metrics, recos[["user_id", "item_id", "rank"]], test, train, catalog)
metric_values["model"] = "sasrec_next_action_gbce"
features_results.append(metric_values)
features_results

CPU times: user 2min 16s, sys: 10min 4s, total: 12min 21s
Wall time: 23.3 s


[{'MAP@1': 0.0347556169412836,
  'MAP@5': 0.0670784020904029,
  'MAP@10': 0.07305790205735346,
  'MIUF@1': 2.026684699022298,
  'MIUF@5': 2.809957769256009,
  'MIUF@10': 3.7957813994754876,
  'Serendipity@1': 8.831319032679852e-06,
  'Serendipity@5': 2.9488162768383078e-05,
  'Serendipity@10': 4.3597166218751995e-05,
  'model': 'sasrec_next_action_softmax'},
 {'MAP@1': 0.030642412703662827,
  'MAP@5': 0.06354385798198152,
  'MAP@10': 0.06939375803777235,
  'MIUF@1': 2.1760997525872234,
  'MIUF@5': 2.8881699041484103,
  'MIUF@10': 3.6832272807506397,
  'Serendipity@1': 1.0021156220458501e-05,
  'Serendipity@5': 2.479979331575515e-05,
  'Serendipity@10': 3.303582017928996e-05,
  'model': 'sasrec_next_action_bce'},
 {'MAP@1': 0.03477402284752523,
  'MAP@5': 0.06698163027041105,
  'MAP@10': 0.07277966423896895,
  'MIUF@1': 2.020424706174193,
  'MIUF@5': 2.823113684972394,
  'MIUF@10': 3.7625225192486647,
  'Serendipity@1': 7.626538834693788e-06,
  'Serendipity@5': 2.4311904032477694e-05,
 

In [22]:
softmax_loss_df, _ = get_log_values(sasrec_nextitem_model, is_val=False)
softmax_loss_df["loss_type"] = "softmax"
bce_loss_df, _ = get_log_values(sasrec_nextitem_bce_model, is_val=False)
bce_loss_df["loss_type"] = "bce"
gbce_loss_df, _ = get_log_values(sasrec_nextitem_gbce_model, is_val=False)
gbce_loss_df["loss_type"] = "gbce"
pd.concat([softmax_loss_df, bce_loss_df, gbce_loss_df], axis=1)

Unnamed: 0,epoch,train_loss,loss_type,epoch.1,train_loss.1,loss_type.1,epoch.2,train_loss.2,loss_type.2
0,0,18.603785,softmax,0,0.76706,bce,0,0.6705,gbce
1,1,18.375935,softmax,1,0.721544,bce,1,0.627388,gbce
2,2,18.355379,softmax,2,0.715669,bce,2,0.622248,gbce
3,3,18.337034,softmax,3,0.711058,bce,3,0.620583,gbce
4,4,18.291851,softmax,4,0.709537,bce,4,0.617849,gbce


**use_causal_attn=True**

In [23]:
nextitem_model_with_casual_mask = SASRecModel(
    n_factors=N_FACTORS,
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=True,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=LR,
    batch_size=BATCH_SIZE,
    use_causal_attn=True,
    loss="softmax",
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    data_preparator_type=NextItemDataPreparator,
    lightning_module_type=NextItemLightningModule,
    get_trainer_func=get_nextitem_trainer,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [24]:
%%time
nextitem_model_with_casual_mask.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 3.0 M 
---------------------------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
12.175    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 4min 24s, sys: 5.98 s, total: 4min 30s
Wall time: 4min 19s


<rectools.models.nn.sasrec.SASRecModel at 0x7fc743ca8040>

In [25]:
%%time
recos = nextitem_model_with_casual_mask.recommend(
    users=test_users_sasrec, 
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn"
)

metric_values = calc_metrics(metrics, recos[["user_id", "item_id", "rank"]], test, train, catalog)
metric_values["model"] = "sasrec_next_action_softmax_casual"
features_results.append(metric_values)
features_results

CPU times: user 2min 34s, sys: 10min 11s, total: 12min 46s
Wall time: 22.3 s


[{'MAP@1': 0.0347556169412836,
  'MAP@5': 0.0670784020904029,
  'MAP@10': 0.07305790205735346,
  'MIUF@1': 2.026684699022298,
  'MIUF@5': 2.809957769256009,
  'MIUF@10': 3.7957813994754876,
  'Serendipity@1': 8.831319032679852e-06,
  'Serendipity@5': 2.9488162768383078e-05,
  'Serendipity@10': 4.3597166218751995e-05,
  'model': 'sasrec_next_action_softmax'},
 {'MAP@1': 0.030642412703662827,
  'MAP@5': 0.06354385798198152,
  'MAP@10': 0.06939375803777235,
  'MIUF@1': 2.1760997525872234,
  'MIUF@5': 2.8881699041484103,
  'MIUF@10': 3.6832272807506397,
  'Serendipity@1': 1.0021156220458501e-05,
  'Serendipity@5': 2.479979331575515e-05,
  'Serendipity@10': 3.303582017928996e-05,
  'model': 'sasrec_next_action_bce'},
 {'MAP@1': 0.03477402284752523,
  'MAP@5': 0.06698163027041105,
  'MAP@10': 0.07277966423896895,
  'MIUF@1': 2.020424706174193,
  'MIUF@5': 2.823113684972394,
  'MIUF@10': 3.7625225192486647,
  'Serendipity@1': 7.626538834693788e-06,
  'Serendipity@5': 2.4311904032477694e-05,
 

In [26]:
loss_df, _ = get_log_values(nextitem_model_with_casual_mask, is_val=False)
loss_df

Unnamed: 0,epoch,train_loss
0,0,18.612335
1,1,18.424408
2,2,18.361567
3,3,18.333748
4,4,18.335825


# ALBERT

In [27]:
import typing as tp

import torch
import torch.nn as nn
import typing_extensions as tpe
from pytorch_lightning import Trainer

from rectools.dataset.dataset import Dataset, DatasetSchema
from rectools.models.nn.transformer_data_preparator import TransformerDataPreparatorBase
from rectools.models.nn.item_net import (
    CatFeaturesItemNet,
    IdEmbeddingsItemNet,
    ItemNetBase,
    ItemNetConstructorBase,
    SumOfEmbeddingsConstructor,
)
from rectools.models import BERT4RecModel
from rectools.models.nn.bert4rec import BERT4RecModelConfig, BERT4RecDataPreparator
from rectools.models.nn.transformer_base import ValMaskCallable, TrainerCallable
from rectools.models.nn.transformer_lightning import TransformerLightningModuleBase, TransformerLightningModule
from rectools.models.nn.transformer_net_blocks import (
    LearnableInversePositionalEncoding,
    PreLNTransformerLayer,
    PositionalEncodingBase,
    TransformerLayersBase,
)


class AlBERT4RecSumOfEmbeddingsConstructor(SumOfEmbeddingsConstructor):

    def __init__(
        self,
        n_items: int,
        emb_factors: int,
        n_factors: int,
        item_net_blocks: tp.Sequence[ItemNetBase],
    ) -> None:
        super().__init__(
            n_items=n_items,
            item_net_blocks=item_net_blocks
        )
        self.item_emb_proj = nn.Linear(emb_factors, n_factors)

    @classmethod
    def from_dataset(
        cls,
        dataset: Dataset,
        emb_factors: int,
        n_factors: int,
        dropout_rate: float,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]],
    ) -> tpe.Self:
        n_items = dataset.item_id_map.size

        item_net_blocks: tp.List[ItemNetBase] = []
        for item_net in item_net_block_types:
            item_net_block = item_net.from_dataset(dataset, emb_factors, dropout_rate)
            if item_net_block is not None:
                item_net_blocks.append(item_net_block)

        return cls(n_items, emb_factors, n_factors, item_net_blocks)

    @classmethod
    def from_dataset_schema(
        cls,
        dataset_schema: DatasetSchema,
        emb_factors: int,
        n_factors: int,
        dropout_rate: float,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]],
    ) -> tpe.Self:
        n_items = dataset_schema.items.n_hot

        item_net_blocks: tp.List[ItemNetBase] = []
        for item_net in item_net_block_types:
            item_net_block = item_net.from_dataset_schema(dataset_schema, emb_factors, dropout_rate)
            if item_net_block is not None:
                item_net_blocks.append(item_net_block)

        return cls(n_items, emb_factors, n_factors, item_net_blocks)

    def forward(self, items: torch.Tensor) -> torch.Tensor:
        item_embs = super().forward(items)
        item_embs = self.item_emb_proj(item_embs)
        return item_embs


class AlBERT4RecPreLNTransformerLayers(TransformerLayersBase):

    def __init__(
        self,
        n_blocks: int,
        n_hidden_groups: int,
        n_inner_groups: int,
        n_factors: int,
        n_heads: int,
        dropout_rate: float,
        ff_factors_multiplier: int = 4,
    ):
        super().__init__()
        self.n_blocks = n_blocks
        self.n_hidden_groups = n_hidden_groups
        self.n_inner_groups = n_inner_groups
        n_fitted_blocks = int(n_hidden_groups * n_inner_groups)
        self.transformer_layers = nn.ModuleList(
            [
                PreLNTransformerLayer(
                    # number of encoder layer (AlBERTLayers)
                    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/albert/modeling_albert.py#L428
                    n_factors,
                    n_heads,
                    dropout_rate,
                    ff_factors_multiplier,
                ) 
                # https://github.com/huggingface/transformers/blob/main/src/transformers/models/albert/modeling_albert.py#L469
                for _ in range(n_fitted_blocks)
            ]
        )
        self.n_layers_per_group = n_blocks / n_hidden_groups

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
    ) -> torch.Tensor:
        for layer_idx_in_group in range(self.n_blocks):
            group_idx = int(layer_idx_in_group / self.n_layers_per_group)
            seqs = self.transformer_layers[group_idx](seqs, timeline_mask, attn_mask, key_padding_mask)
        return seqs


class AlBERT4RecModelConfig(BERT4RecModelConfig):

    n_hidden_groups: int = 1
    n_inner_groups: int = 1
    emb_factors: int = 64


class AlBERT4RecModel(BERT4RecModel):
    """
    https://arxiv.org/pdf/1909.11942
    """
    
    config_class = AlBERT4RecModelConfig

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        n_blocks: int = 2,
        n_hidden_groups: int = 1,
        n_inner_groups: int = 1,
        n_heads: int = 4,
        n_factors: int = 256,
        emb_factors: int = 64,
        dropout_rate: float = 0.0,
        mask_prob: float = 0.15,
        session_max_len: int = 100,
        train_min_user_interactions: int = 2,
        loss: str = "softmax",
        n_negatives: int = 1,
        gbce_t: float = 0.2,
        lr: float = 0.001,
        batch_size: int = 128,
        epochs: int = 3,
        deterministic: bool = False,
        verbose: int = 0,
        dataloader_num_workers: int = 0,
        use_pos_emb: bool = True,
        use_key_padding_mask: bool = True,
        use_causal_attn: bool = False,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet ),
        item_net_constructor_type: tp.Type[ItemNetConstructorBase] = AlBERT4RecSumOfEmbeddingsConstructor,
        pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
        transformer_layers_type: tp.Type[TransformerLayersBase] = AlBERT4RecPreLNTransformerLayers,
        data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
        lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
        get_val_mask_func: tp.Optional[ValMaskCallable] = None,
        get_trainer_func: tp.Optional[TrainerCallable] = None,
        recommend_batch_size: int = 256,
        recommend_device: tp.Optional[str] = None,
        recommend_n_threads: int = 0,
        recommend_use_gpu_ranking: bool = True,  # TODO: remove after TorchRanker
    ):
        self.n_hidden_groups = n_hidden_groups
        self.n_inner_groups = n_inner_groups
        self.emb_factors = emb_factors

        if n_blocks < n_hidden_groups:
            warnings.warn(
                "When `n_hidden_groups` less than `n_blocks` that will use in the forward only one hidden group."
            ) 

        super().__init__(
            transformer_layers_type=transformer_layers_type,
            data_preparator_type=data_preparator_type,
            n_blocks=n_blocks,
            n_heads=n_heads,
            n_factors=n_factors,
            use_pos_emb=use_pos_emb,
            use_causal_attn=use_causal_attn,
            use_key_padding_mask=use_key_padding_mask,
            dropout_rate=dropout_rate,
            session_max_len=session_max_len,
            dataloader_num_workers=dataloader_num_workers,
            batch_size=batch_size,
            loss=loss,
            n_negatives=n_negatives,
            gbce_t=gbce_t,
            lr=lr,
            epochs=epochs,
            verbose=verbose,
            deterministic=deterministic,
            recommend_device=recommend_device,
            recommend_batch_size=recommend_batch_size,
            recommend_n_threads=recommend_n_threads,
            recommend_use_gpu_ranking=recommend_use_gpu_ranking,
            train_min_user_interactions=train_min_user_interactions,
            mask_prob=mask_prob,
            item_net_block_types=item_net_block_types,
            item_net_constructor_type=item_net_constructor_type,
            pos_encoding_type=pos_encoding_type,
            lightning_module_type=lightning_module_type,
            get_val_mask_func=get_val_mask_func,
            get_trainer_func=get_trainer_func,
        )
    
    def init_item_net_from_dataset(self, dataset: Dataset) -> ItemNetConstructorBase:
        return self.item_net_constructor_type.from_dataset(
            dataset, self.emb_factors, self.n_factors, self.dropout_rate, self.item_net_block_types
        )

    def init_item_net_from_dataset_schema(self, dataset_schema: DatasetSchema) -> ItemNetConstructorBase:
        return self.item_net_constructor_type.from_dataset_schema(
            dataset_schema, self.emb_factors, self.n_factors, self.dropout_rate, self.item_net_block_types
        )

    def init_transformer_layers(self) -> TransformerLayersBase:
        return self.transformer_layers_type(
            n_blocks=self.n_blocks,
            n_hidden_groups=self.n_hidden_groups,
            n_inner_groups=self.n_inner_groups,
            n_factors=self.n_factors,
            n_heads=self.n_heads,
            dropout_rate=self.dropout_rate,
        )


In [28]:
def get_albert_trainer():
    return Trainer(
        accelerator=ACCELERATOR,
        devices=DEVICES,
        min_epochs=MIN_EPOCHS,
        max_epochs=MAX_EPOCHS, 
        deterministic=True,
    )

In [29]:
EMB_FACTORS = 64
N_HIDDEN_GROUPS = 2
N_INNER_GROUPS = 1

In [30]:
albert_model = AlBERT4RecModel(
    n_factors=N_FACTORS,
    emb_factors=EMB_FACTORS,
    n_blocks=N_BLOCKS,
    n_hidden_groups=N_HIDDEN_GROUPS,
    n_inner_groups=N_INNER_GROUPS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=True,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=LR,
    use_causal_attn=False,
    batch_size=BATCH_SIZE,
    loss="softmax",
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),
    item_net_constructor_type=AlBERT4RecSumOfEmbeddingsConstructor,
    transformer_layers_type=AlBERT4RecPreLNTransformerLayers,
    get_trainer_func=get_albert_trainer,
)


N_NEGATIVES = 5

albert_model_bce = AlBERT4RecModel(
    n_factors=N_FACTORS,
    emb_factors=EMB_FACTORS,
    n_blocks=N_BLOCKS,
    n_hidden_groups=N_HIDDEN_GROUPS,
    n_inner_groups=N_INNER_GROUPS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=True,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=LR,
    use_causal_attn=False,
    batch_size=BATCH_SIZE,
    loss="BCE",
    n_negatives=N_NEGATIVES,
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    item_net_constructor_type=AlBERT4RecSumOfEmbeddingsConstructor,
    transformer_layers_type=AlBERT4RecPreLNTransformerLayers,
    get_trainer_func=get_albert_trainer,
)

albert_model_gbce = AlBERT4RecModel(
    n_factors=N_FACTORS,
    emb_factors=EMB_FACTORS,
    n_blocks=N_BLOCKS,
    n_hidden_groups=N_HIDDEN_GROUPS,
    n_inner_groups=N_INNER_GROUPS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=True,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=LR,
    use_causal_attn=False,
    batch_size=BATCH_SIZE,
    loss="gBCE",
    n_negatives=N_NEGATIVES,
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    item_net_constructor_type=AlBERT4RecSumOfEmbeddingsConstructor,
    transformer_layers_type=AlBERT4RecPreLNTransformerLayers,
    get_trainer_func=get_albert_trainer,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [31]:
%%time
albert_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 2.0 M 
---------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.884     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 5min 57s, sys: 4.42 s, total: 6min 1s
Wall time: 5min 48s


<__main__.AlBERT4RecModel at 0x7fc7205daeb0>

In [32]:
%%time
recos = albert_model.recommend(
    users=test_users_sasrec, 
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn"
)

metric_values = calc_metrics(metrics, recos[["user_id", "item_id", "rank"]], test, train, catalog)
metric_values["model"] = "albert_softmax"
features_results.append(metric_values)
features_results

CPU times: user 2min 33s, sys: 10min 25s, total: 12min 59s
Wall time: 25.1 s


[{'MAP@1': 0.0347556169412836,
  'MAP@5': 0.0670784020904029,
  'MAP@10': 0.07305790205735346,
  'MIUF@1': 2.026684699022298,
  'MIUF@5': 2.809957769256009,
  'MIUF@10': 3.7957813994754876,
  'Serendipity@1': 8.831319032679852e-06,
  'Serendipity@5': 2.9488162768383078e-05,
  'Serendipity@10': 4.3597166218751995e-05,
  'model': 'sasrec_next_action_softmax'},
 {'MAP@1': 0.030642412703662827,
  'MAP@5': 0.06354385798198152,
  'MAP@10': 0.06939375803777235,
  'MIUF@1': 2.1760997525872234,
  'MIUF@5': 2.8881699041484103,
  'MIUF@10': 3.6832272807506397,
  'Serendipity@1': 1.0021156220458501e-05,
  'Serendipity@5': 2.479979331575515e-05,
  'Serendipity@10': 3.303582017928996e-05,
  'model': 'sasrec_next_action_bce'},
 {'MAP@1': 0.03477402284752523,
  'MAP@5': 0.06698163027041105,
  'MAP@10': 0.07277966423896895,
  'MIUF@1': 2.020424706174193,
  'MIUF@5': 2.823113684972394,
  'MIUF@10': 3.7625225192486647,
  'Serendipity@1': 7.626538834693788e-06,
  'Serendipity@5': 2.4311904032477694e-05,
 

In [33]:
%%time
albert_model_bce.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 2.0 M 
---------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.884     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 4min 59s, sys: 3.91 s, total: 5min 3s
Wall time: 4min 52s


<__main__.AlBERT4RecModel at 0x7fc7205da4f0>

In [34]:
%%time
recos = albert_model_bce.recommend(
    users=test_users_sasrec, 
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn"
)

metric_values = calc_metrics(metrics, recos[["user_id", "item_id", "rank"]], test, train, catalog)
metric_values["model"] = "albert_bce"
features_results.append(metric_values)
features_results

CPU times: user 2min 46s, sys: 10min 19s, total: 13min 6s
Wall time: 22.7 s


[{'MAP@1': 0.0347556169412836,
  'MAP@5': 0.0670784020904029,
  'MAP@10': 0.07305790205735346,
  'MIUF@1': 2.026684699022298,
  'MIUF@5': 2.809957769256009,
  'MIUF@10': 3.7957813994754876,
  'Serendipity@1': 8.831319032679852e-06,
  'Serendipity@5': 2.9488162768383078e-05,
  'Serendipity@10': 4.3597166218751995e-05,
  'model': 'sasrec_next_action_softmax'},
 {'MAP@1': 0.030642412703662827,
  'MAP@5': 0.06354385798198152,
  'MAP@10': 0.06939375803777235,
  'MIUF@1': 2.1760997525872234,
  'MIUF@5': 2.8881699041484103,
  'MIUF@10': 3.6832272807506397,
  'Serendipity@1': 1.0021156220458501e-05,
  'Serendipity@5': 2.479979331575515e-05,
  'Serendipity@10': 3.303582017928996e-05,
  'model': 'sasrec_next_action_bce'},
 {'MAP@1': 0.03477402284752523,
  'MAP@5': 0.06698163027041105,
  'MAP@10': 0.07277966423896895,
  'MIUF@1': 2.020424706174193,
  'MIUF@5': 2.823113684972394,
  'MIUF@10': 3.7625225192486647,
  'Serendipity@1': 7.626538834693788e-06,
  'Serendipity@5': 2.4311904032477694e-05,
 

In [35]:
%%time
albert_model_gbce.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 2.0 M 
---------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.884     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 5min 26s, sys: 5.2 s, total: 5min 31s
Wall time: 5min


<__main__.AlBERT4RecModel at 0x7fc7205dad90>

In [36]:
%%time
recos = albert_model_gbce.recommend(
    users=test_users_sasrec, 
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn"
)

metric_values = calc_metrics(metrics, recos[["user_id", "item_id", "rank"]], test, train, catalog)
metric_values["model"] = "albert_gbce"
features_results.append(metric_values)
features_results

CPU times: user 2min 10s, sys: 10min 27s, total: 12min 38s
Wall time: 22.5 s


[{'MAP@1': 0.0347556169412836,
  'MAP@5': 0.0670784020904029,
  'MAP@10': 0.07305790205735346,
  'MIUF@1': 2.026684699022298,
  'MIUF@5': 2.809957769256009,
  'MIUF@10': 3.7957813994754876,
  'Serendipity@1': 8.831319032679852e-06,
  'Serendipity@5': 2.9488162768383078e-05,
  'Serendipity@10': 4.3597166218751995e-05,
  'model': 'sasrec_next_action_softmax'},
 {'MAP@1': 0.030642412703662827,
  'MAP@5': 0.06354385798198152,
  'MAP@10': 0.06939375803777235,
  'MIUF@1': 2.1760997525872234,
  'MIUF@5': 2.8881699041484103,
  'MIUF@10': 3.6832272807506397,
  'Serendipity@1': 1.0021156220458501e-05,
  'Serendipity@5': 2.479979331575515e-05,
  'Serendipity@10': 3.303582017928996e-05,
  'model': 'sasrec_next_action_bce'},
 {'MAP@1': 0.03477402284752523,
  'MAP@5': 0.06698163027041105,
  'MAP@10': 0.07277966423896895,
  'MIUF@1': 2.020424706174193,
  'MIUF@5': 2.823113684972394,
  'MIUF@10': 3.7625225192486647,
  'Serendipity@1': 7.626538834693788e-06,
  'Serendipity@5': 2.4311904032477694e-05,
 

In [37]:
softmax_loss_df, _ = get_log_values(albert_model, is_val=False)
softmax_loss_df["loss_type"] = "softmax"
bce_loss_df, _ = get_log_values(albert_model_bce, is_val=False)
bce_loss_df["loss_type"] = "bce"
gbce_loss_df, _ = get_log_values(albert_model_gbce, is_val=False)
gbce_loss_df["loss_type"] = "gbce"
pd.concat([softmax_loss_df, bce_loss_df, gbce_loss_df], axis=1)

Unnamed: 0,epoch,train_loss,loss_type,epoch.1,train_loss.1,loss_type.1,epoch.2,train_loss.2,loss_type.2
0,0,19.219992,softmax,0,0.882489,bce,0,8.39836,gbce
1,1,18.369999,softmax,1,0.724247,bce,1,8.37926,gbce
2,2,17.939146,softmax,2,0.69663,bce,2,8.383474,gbce
3,3,17.678568,softmax,3,0.677156,bce,3,8.377734,gbce
4,4,17.43996,softmax,4,0.664593,bce,4,8.377232,gbce


In [38]:
df_metrics = pd.DataFrame(features_results)
df_metrics

Unnamed: 0,MAP@1,MAP@5,MAP@10,MIUF@1,MIUF@5,MIUF@10,Serendipity@1,Serendipity@5,Serendipity@10,model
0,0.034756,0.067078,0.073058,2.026685,2.809958,3.795781,9e-06,2.9e-05,4.4e-05,sasrec_next_action_softmax
1,0.030642,0.063544,0.069394,2.1761,2.88817,3.683227,1e-05,2.5e-05,3.3e-05,sasrec_next_action_bce
2,0.034774,0.066982,0.07278,2.020425,2.823114,3.762523,8e-06,2.4e-05,3.6e-05,sasrec_next_action_gbce
3,0.034764,0.067061,0.072467,2.029366,2.814538,3.616758,9e-06,2.7e-05,2.9e-05,sasrec_next_action_softmax_casual
4,0.04002,0.066784,0.074854,4.295447,5.091198,5.485631,0.000512,0.000448,0.000419,albert_softmax
5,0.038324,0.065402,0.073086,2.710455,3.571952,4.316234,6e-05,9.6e-05,0.000122,albert_bce
6,0.010698,0.016777,0.019512,4.308966,7.650483,7.218645,2.8e-05,5.1e-05,4.4e-05,albert_gbce


Всего прогонов через слои `n_blocks * n_inner_groups`, при этом обучаемых слоев `n_hidden_groups * n_inner_groups`

In [None]:
# n_blocks = 11
# n_hidden_groups = 2
# n_inner_groups = 3
# nb nhg nig
# 0   0   0
#         1
#         2
# 1   0   0
#         1
#         2
# 2   0   0
#         1
#         2
# 3   0   0
#         1
#         2
# 4   0   0
#         1
#         2
# 5   0   3
#         4
#         5
# 6   1   3
#         4
#         5
# 7   1   3
#         4
#         5
# 8   1   3
#         4
#         5
# 9   1   3
#         4
#         5
# 10  1   3
#         4
#         5
