In [1]:
# TODO: will remove
import sys
sys.path.append("../../")

In [2]:
import numpy as np
import os
import pandas as pd
import torch
import typing as tp
import warnings
from pathlib import Path

from lightning_fabric import seed_everything
from pytorch_lightning import Trainer
from rectools import Columns
from rectools.dataset import Dataset
from rectools.metrics import MAP, Serendipity, MeanInvUserFreq, calc_metrics

from rectools.models import BERT4RecModel, SASRecModel
from rectools.models.nn.item_net import IdEmbeddingsItemNet
from rectools.models.nn.transformer_base import TransformerModelBase

# Enable deterministic behaviour with CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
warnings.simplefilter("ignore", UserWarning)

# Load data

In [3]:
# %%time
# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_original.zip
# !unzip -o data_original.zip
# !rm data_original.zip

In [4]:
DATA_PATH = Path("data_original")

# Read the raw interactions, parsing `last_watch_dt` as a date,
# then expose that column under the name `datetime`.
interactions = pd.read_csv(DATA_PATH / "interactions.csv", parse_dates=["last_watch_dt"])
interactions = interactions.rename(columns={"last_watch_dt": "datetime"})

In [5]:
# Rows with `watched_pct` > 10 get weight 3, all other rows get weight 1
interactions[Columns.Weight] = np.where(interactions["watched_pct"] > 10, 3, 1)

# Time-based split: the last 7 days (counted from the latest interaction) go to test
max_date = interactions[Columns.Datetime].max()
split_date = max_date - pd.Timedelta(days=7)
train = interactions[interactions[Columns.Datetime] < split_date].copy()
test = interactions[interactions[Columns.Datetime] >= split_date].copy()

# Remove train interactions with `total_dur` below 300
train = train.drop(index=train.index[train["total_dur"] < 300])

# Keep only items with at least 20 interactions in train
item_counts = train["item_id"].value_counts()
items = item_counts[item_counts >= 20].index.to_list()
train = train[train["item_id"].isin(items)]

# Keep only users with at least 2 interactions in train
user_counts = train["user_id"].value_counts()
users = user_counts[user_counts >= 2].index.to_list()
train = train[train["user_id"].isin(users)]

# Users that remain in train, in order of first appearance
users = train["user_id"].drop_duplicates().to_list()

# `test_users_sasrec` is collected BEFORE cold users (absent from train) are removed from test
test_users_sasrec = test[Columns.User].unique()
cold_users = set(test[Columns.User]) - set(train[Columns.User])
test = test.drop(index=test.index[test[Columns.User].isin(cold_users)])
test_users = test[Columns.User].unique()


In [6]:
# Load the items table. Note: this rebinds `items`, which held a list of item ids in the split cell.
items = pd.read_csv(DATA_PATH / 'items.csv')

In [7]:
# Build item features as a flat dataframe with columns `id`, `value`, `feature`.
# Only items that are present in train are kept.
items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()

# Genres: lower-cased comma-separated string -> list of genres -> one row per (item, genre)
items["genre"] = (
    items["genres"]
    .str.lower()
    .str.replace(", ", ",", regex=False)
    .str.split(",")
)
genre_feature = (
    items[["item_id", "genre"]]
    .explode("genre")
    .rename(columns={"item_id": "id", "genre": "value"})
    .assign(feature="genre")
)

# Content type: one row per item
content_feature = (
    items.reindex(columns=[Columns.Item, "content_type"])
    .rename(columns={Columns.Item: "id", "content_type": "value"})
    .assign(feature="content_type")
)
item_features = pd.concat((genre_feature, content_feature))

candidate_items = interactions["item_id"].drop_duplicates().astype(int)
test = test.astype({"user_id": int, "item_id": int})

# Unique train items; passed as `catalog` to `calc_metrics` in the evaluation cells
catalog = train[Columns.Item].unique()

In [8]:
# Dataset built from train interactions only
dataset_no_features = Dataset.construct(interactions_df=train)

# Dataset built from train interactions plus categorical item features
dataset_item_features = Dataset.construct(
    interactions_df=train,
    item_features_df=item_features,
    cat_item_features=["genre", "content_type"],
)

In [9]:
# Metric classes to evaluate, keyed by the prefix used in the result names
metrics_name = {"MAP": MAP, "MIUF": MeanInvUserFreq, "Serendipity": Serendipity}

# One metric instance per (metric, k) pair, named like `MAP@10`
metrics = {
    f"{metric_name}@{k}": metric(k=k)
    for metric_name, metric in metrics_name.items()
    for k in (1, 5, 10)
}

# list with metrics results of all models
features_results = []

In [10]:
# Reproducibility: require deterministic torch algorithms and seed everything with one fixed value.
# `CUBLAS_WORKSPACE_CONFIG` is set in the imports cell for the same purpose.
RANDOM_STATE=60
torch.use_deterministic_algorithms(True)
seed_everything(RANDOM_STATE, workers=True)

Seed set to 60


60

**Common model params**

In [11]:
# Hyperparameters shared by all models in this notebook; each constant is passed to the
# same-named (lower-case) keyword argument of the model constructors below.

# Training / data preparation
EPOCHS = 5
TRAIN_MIN_USER_INTERACTIONS = 5
SESSION_MAX_LEN = 50

# Architecture and optimisation
N_FACTORS = 256
N_BLOCKS = 4
N_HEADS = 4
USE_POS_EMB = True
LOSS = "softmax"
DROPOUT_RATE = 0.2
BATCH_SIZE = 128
LR = 1e-3


# **Training Objective**

https://arxiv.org/pdf/2205.04507

## **Next Action**

In [12]:
from typing import Dict, List, Tuple

from rectools.models.nn.bert4rec import BERT4RecDataPreparator
from rectools.models.nn.constants import MASKING_VALUE
from rectools.models.nn.transformer_lightning import TransformerLightningModule


# "MASK" token is used for predicting one next token.
class NextItemDataPreparator(BERT4RecDataPreparator):
    """
    Data preparator for the "next item" training objective.

    Only the last item of each training session is a target: in the model input it is replaced
    by the "MASK" token, and it is returned as the single label of the session.
    """

    def _collate_fn_train(
        self,
        batch: List[Tuple[List[int], List[float]]],
    ) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of `(session, session_weights)` pairs into training tensors.

        Each session is truncated from the left so that at most `session_max_len` most recent
        items are kept, its last item is replaced by the "MASK" token, and the result is
        left-padded with zeros up to `session_max_len`.

        Parameters
        ----------
        batch : list of (session, session_weights)
            Item ids of each session and the weights of the corresponding interactions.

        Returns
        -------
        dict
            ``"x"``: LongTensor [batch_size, session_max_len] - padded input with the last item masked.
            ``"y"``: LongTensor [batch_size, 1] - the last item of each session.
            ``"yw"``: FloatTensor [batch_size, 1] - the weight of the last item.
            ``"negatives"``: [batch_size, 1, n_negatives] - random item ids, present only when
            `self.n_negatives` is not None.
        """
        batch_size = len(batch)
        x = np.zeros((batch_size, self.session_max_len))
        y = np.zeros((batch_size, 1))
        yw = np.zeros((batch_size, 1))
        for i, (ses, ses_weights) in enumerate(batch):
            # Keep only the most recent `session_max_len` items so that a longer session cannot
            # overflow the row of `x`. `list(...)` copies, so the original `ses` is never mutated.
            session = list(ses[-self.session_max_len :])
            session[-1] = self.extra_token_ids[MASKING_VALUE]
            x[i, -len(session) :] = session  # session: [session_len] -> x[i]: [session_max_len]
            y[i] = ses[-1]  # ses: [session_len] -> y[i]: [1]
            yw[i] = ses_weights[-1]  # ses_weights: [session_len] -> yw[i]: [1]

        batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)}
        if self.n_negatives is not None:
            # Negatives are drawn uniformly from the ids that follow the extra tokens
            negatives = torch.randint(
                low=self.n_item_extra_tokens,
                high=self.item_id_map.size,
                size=(batch_size, 1, self.n_negatives),
            )  # [batch_size, 1, n_negatives]
            batch_dict["negatives"] = negatives
        return batch_dict


# Only the logits of the last position are used, which reduces the amount of computation on the training step.
# Alternatively, you could fill `y` with zeros except for the last item in `_collate_fn_train` and leave the training step unchanged.
class NextItemLightningModule(TransformerLightningModule):
    """Lightning module whose training loss is computed only on the last position of each session."""

    def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Compute and log the training loss for one batch.

        For the built-in losses ("softmax", "BCE", "gBCE") the logits are sliced to the last
        position before the loss is calculated; any other loss is delegated to `_calc_custom_loss`.
        """
        x, y, w = batch["x"], batch["y"], batch["yw"]

        if self.loss == "softmax":
            last_logits = self._get_full_catalog_logits(x)[:, -1:]
            loss = self._calc_softmax_loss(last_logits, y, w)
        elif self.loss in ("BCE", "gBCE"):
            # Both sampled losses share the same positive/negative logits
            negatives = batch["negatives"]
            last_logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:]
            if self.loss == "BCE":
                loss = self._calc_bce_loss(last_logits, y, w)
            else:
                loss = self._calc_gbce_loss(last_logits, y, w, negatives)
        else:
            loss = self._calc_custom_loss(batch, batch_idx)

        show_in_prog_bar = self.verbose > 0
        self.log(self.train_loss_name, loss, on_step=False, on_epoch=True, prog_bar=show_in_prog_bar)

        return loss
    

class NextItemSASRecModel(SASRecModel):
    """SASRec model whose data preparator is created with an additional `mask_prob` argument."""

    def _init_data_preparator(self) -> None:
        """Create `self.data_preparator` from `self.data_preparator_type`."""
        # Negatives are requested for every loss except softmax
        n_negatives = None if self.loss == "softmax" else self.n_negatives
        preparator_kwargs = {
            "session_max_len": self.session_max_len,
            "n_negatives": n_negatives,
            "batch_size": self.batch_size,
            "dataloader_num_workers": self.dataloader_num_workers,
            "train_min_user_interactions": self.train_min_user_interactions,
            # SASRec has no `mask_prob` parameter of its own, so a fixed default is passed
            "mask_prob": 0.15,
            "get_val_mask_func": self.get_val_mask_func,
        }
        self.data_preparator = self.data_preparator_type(**preparator_kwargs)

In [13]:
# Parameters shared by both "next item" models
next_item_params = dict(
    n_factors=N_FACTORS,
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=USE_POS_EMB,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    lr=LR,
    batch_size=BATCH_SIZE,
    loss=LOSS,
    epochs=EPOCHS,
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    data_preparator_type=NextItemDataPreparator,
    lightning_module_type=NextItemLightningModule,
)

# BERT4Rec additionally sets `use_causal_attn=False` explicitly
bert4rec_nextitem_model = BERT4RecModel(use_causal_attn=False, **next_item_params)

sasrec_nextitem_model = NextItemSASRecModel(**next_item_params)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
%%time
bert4rec_nextitem_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 4.6 M 
---------------------------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.478    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 7min 9s, sys: 8.27 s, total: 7min 17s
Wall time: 6min 56s


<rectools.models.nn.bert4rec.BERT4RecModel at 0x7f91fc1d1460>

In [15]:
%%time
# Top-10 recommendations for the test users (`filter_viewed=True`)
recos = bert4rec_nextitem_model.recommend(
    users=test_users_sasrec,
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn",
)

# Evaluate against test; train is passed as previous interactions
metric_values = calc_metrics(
    metrics,
    reco=recos[["user_id", "item_id", "rank"]],
    interactions=test,
    prev_interactions=train,
    catalog=catalog,
)
metric_values["model"] = "bert_next_action_softmax"
features_results.append(metric_values)

CPU times: user 2min 5s, sys: 9min 37s, total: 11min 43s
Wall time: 25.5 s


In [17]:
%%time
sasrec_nextitem_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 3.0 M 
---------------------------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
12.176    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 6min 20s, sys: 8 s, total: 6min 28s
Wall time: 5min 58s


<__main__.NextItemSASRecModel at 0x7f9202293760>

In [18]:
%%time
# Top-10 recommendations for the test users (`filter_viewed=True`)
recos = sasrec_nextitem_model.recommend(
    users=test_users_sasrec,
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn",
)

# Evaluate against test; train is passed as previous interactions
metric_values = calc_metrics(
    metrics,
    reco=recos[["user_id", "item_id", "rank"]],
    interactions=test,
    prev_interactions=train,
    catalog=catalog,
)
metric_values["model"] = "sasrec_next_action_softmax"
features_results.append(metric_values)

CPU times: user 2min, sys: 9min 35s, total: 11min 35s
Wall time: 24.3 s


**BERT4Rec with use_causal_attn = True**

In [20]:
bert4rec_nextitem_model_with_casual_mask = BERT4RecModel(
    # The only difference from `bert4rec_nextitem_model`: causal attention is enabled
    use_causal_attn=True,
    # Architecture
    n_factors=N_FACTORS,
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=USE_POS_EMB,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    # Data preparation
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    batch_size=BATCH_SIZE,
    data_preparator_type=NextItemDataPreparator,
    # Training
    lr=LR,
    loss=LOSS,
    epochs=EPOCHS,
    verbose=1,
    deterministic=True,
    lightning_module_type=NextItemLightningModule,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [21]:
%%time
bert4rec_nextitem_model_with_casual_mask.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 4.6 M 
---------------------------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.478    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 7min 2s, sys: 6.52 s, total: 7min 9s
Wall time: 6min 47s


<rectools.models.nn.bert4rec.BERT4RecModel at 0x7f908487a190>

In [22]:
%%time
# Top-10 recommendations for the test users (`filter_viewed=True`)
recos = bert4rec_nextitem_model_with_casual_mask.recommend(
    users=test_users_sasrec,
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn",
)

# Evaluate against test; train is passed as previous interactions
metric_values = calc_metrics(
    metrics,
    reco=recos[["user_id", "item_id", "rank"]],
    interactions=test,
    prev_interactions=train,
    catalog=catalog,
)
metric_values["model"] = "bert_next_action_softmax_causal"
features_results.append(metric_values)

CPU times: user 1min 56s, sys: 9min 8s, total: 11min 5s
Wall time: 25.7 s


# ALBERT

In [24]:
import typing as tp

import torch
import torch.nn as nn
import typing_extensions as tpe

from rectools.dataset.dataset import Dataset, DatasetSchema
from rectools.models.nn.transformer_data_preparator import TransformerDataPreparatorBase
from rectools.models.nn.item_net import (
    CatFeaturesItemNet,
    IdEmbeddingsItemNet,
    ItemNetBase,
    ItemNetConstructorBase,
    SumOfEmbeddingsConstructor,
)
from rectools.models import BERT4RecModel
from rectools.models.nn.bert4rec import BERT4RecModelConfig, BERT4RecDataPreparator
from rectools.models.nn.transformer_base import ValMaskCallable, TrainerCallable
from rectools.models.nn.transformer_lightning import TransformerLightningModuleBase, TransformerLightningModule
from rectools.models.nn.transformer_net_blocks import (
    LearnableInversePositionalEncoding,
    PreLNTransformerLayer,
    PositionalEncodingBase,
    TransformerLayersBase,
)


class AlBERT4RecSumOfEmbeddingsConstructor(SumOfEmbeddingsConstructor):
    """
    Item net constructor with a linear projection on top of the parent's item embeddings.

    Item net blocks are created with `emb_factors` dimensions; the output of the parent
    `forward` is then passed through `nn.Linear(emb_factors, n_factors)`.
    """

    def __init__(
        self,
        n_items: int,
        emb_factors: int,
        n_factors: int,
        item_net_blocks: tp.Sequence[ItemNetBase],
    ) -> None:
        """Initialise the parent constructor and create the `emb_factors -> n_factors` projection."""
        super().__init__(n_items=n_items, item_net_blocks=item_net_blocks)
        self.item_emb_proj = nn.Linear(emb_factors, n_factors)

    @classmethod
    def from_dataset(
        cls,
        dataset: Dataset,
        emb_factors: int,
        n_factors: int,
        dropout_rate: float,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]],
    ) -> tpe.Self:
        """Build the constructor from a dataset; block types that return None are skipped."""
        n_items = dataset.item_id_map.size

        built_blocks = (
            block_type.from_dataset(dataset, emb_factors, dropout_rate) for block_type in item_net_block_types
        )
        item_net_blocks: tp.List[ItemNetBase] = [block for block in built_blocks if block is not None]

        return cls(n_items, emb_factors, n_factors, item_net_blocks)

    @classmethod
    def from_dataset_schema(
        cls,
        dataset_schema: DatasetSchema,
        emb_factors: int,
        n_factors: int,
        dropout_rate: float,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]],
    ) -> tpe.Self:
        """Build the constructor from a dataset schema; block types that return None are skipped."""
        n_items = dataset_schema.items.n_hot

        built_blocks = (
            block_type.from_dataset_schema(dataset_schema, emb_factors, dropout_rate)
            for block_type in item_net_block_types
        )
        item_net_blocks: tp.List[ItemNetBase] = [block for block in built_blocks if block is not None]

        return cls(n_items, emb_factors, n_factors, item_net_blocks)

    def forward(self, items: torch.Tensor) -> torch.Tensor:
        """Return the parent's item embeddings passed through the linear projection."""
        return self.item_emb_proj(super().forward(items))


class AlBERT4RecPreLNTransformerLayers(TransformerLayersBase):
    """
    Transformer layers with ALBERT-style sharing of layers between blocks.

    Only `n_hidden_groups * n_inner_groups` layers are created. The forward pass still runs
    `n_blocks` blocks; each block applies the `n_inner_groups` layers of the hidden group it
    belongs to.
    """

    def __init__(
        self,
        n_blocks: int,
        n_hidden_groups: int,
        n_inner_groups: int,
        n_factors: int,
        n_heads: int,
        dropout_rate: float,
        ff_factors_multiplier: int = 4,
    ):
        super().__init__()
        self.n_blocks = n_blocks
        self.n_hidden_groups = n_hidden_groups
        self.n_inner_groups = n_inner_groups

        # Layers that own parameters (the counterpart of ALBERT layer groups):
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/albert/modeling_albert.py#L428
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/albert/modeling_albert.py#L469
        shared_layers = []
        for _ in range(int(n_hidden_groups * n_inner_groups)):
            shared_layers.append(PreLNTransformerLayer(n_factors, n_heads, dropout_rate, ff_factors_multiplier))
        self.transformer_blocks = nn.ModuleList(shared_layers)

        # Number of consecutive blocks served by one hidden group (a float, may be fractional)
        self.n_layers_per_group = n_blocks / n_hidden_groups

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run `n_blocks` blocks over `seqs`, reusing the layers of the hidden group of each block."""
        for block_idx in range(self.n_blocks):
            # Index of the first layer of the hidden group this block belongs to
            group_offset = int(block_idx / self.n_layers_per_group) * self.n_inner_groups
            for inner_idx in range(self.n_inner_groups):
                shared_layer = self.transformer_blocks[group_offset + inner_idx]
                seqs = shared_layer(seqs, timeline_mask, attn_mask, key_padding_mask)
        return seqs


class AlBERT4RecModelConfig(BERT4RecModelConfig):
    """Config of `AlBERT4RecModel`: `BERT4RecModelConfig` extended with three extra fields."""

    # The defaults equal the defaults of the same-named `AlBERT4RecModel.__init__` arguments
    n_hidden_groups: int = 1
    n_inner_groups: int = 1
    emb_factors: int = 64


class AlBERT4RecModel(BERT4RecModel):
    """
    BERT4Rec model with ALBERT-style parameter reduction.

    https://arxiv.org/pdf/1909.11942

    Compared to `BERT4RecModel`, this class:

    - passes `emb_factors` together with `n_factors` to the item net constructor;
    - creates transformer layers from `n_hidden_groups` and `n_inner_groups` in addition
      to `n_blocks`.

    Parameters
    ----------
    n_hidden_groups : int, default 1
        Number of hidden groups, passed to `transformer_layers_type`.
    n_inner_groups : int, default 1
        Number of inner groups, passed to `transformer_layers_type`.
    emb_factors : int, default 64
        Embedding dimension passed to the item net constructor together with `n_factors`.

    All other parameters are forwarded unchanged to `BERT4RecModel.__init__`.
    """
    
    config_class = AlBERT4RecModelConfig

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        n_blocks: int = 2,
        n_hidden_groups: int = 1,
        n_inner_groups: int = 1,
        n_heads: int = 4,
        n_factors: int = 256,
        emb_factors: int = 64,
        dropout_rate: float = 0.0,
        mask_prob: float = 0.15,
        session_max_len: int = 100,
        train_min_user_interactions: int = 2,
        loss: str = "softmax",
        n_negatives: int = 1,
        gbce_t: float = 0.2,
        lr: float = 0.001,
        batch_size: int = 128,
        epochs: int = 3,
        deterministic: bool = False,
        verbose: int = 0,
        dataloader_num_workers: int = 0,
        use_pos_emb: bool = True,
        use_key_padding_mask: bool = True,
        use_causal_attn: bool = False,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet ),
        item_net_constructor_type: tp.Type[ItemNetConstructorBase] = AlBERT4RecSumOfEmbeddingsConstructor,
        pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
        transformer_layers_type: tp.Type[TransformerLayersBase] = AlBERT4RecPreLNTransformerLayers,
        data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
        lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
        get_val_mask_func: tp.Optional[ValMaskCallable] = None,
        get_trainer_func: tp.Optional[TrainerCallable] = None,
        recommend_batch_size: int = 256,
        recommend_device: tp.Optional[str] = None,
        recommend_n_threads: int = 0,
        recommend_use_gpu_ranking: bool = True,  # TODO: remove after TorchRanker
    ):
        # ALBERT-specific attributes are stored before the parent constructor is called
        self.n_hidden_groups = n_hidden_groups
        self.n_inner_groups = n_inner_groups
        self.emb_factors = emb_factors

        if n_blocks < n_hidden_groups:
            # With fewer blocks than hidden groups not every group can be visited in the forward pass
            warnings.warn(
                "`n_blocks` is less than `n_hidden_groups`: "
                "some hidden groups will never be used in the forward pass."
            )

        super().__init__(
            transformer_layers_type=transformer_layers_type,
            data_preparator_type=data_preparator_type,
            n_blocks=n_blocks,
            n_heads=n_heads,
            n_factors=n_factors,
            use_pos_emb=use_pos_emb,
            use_causal_attn=use_causal_attn,
            use_key_padding_mask=use_key_padding_mask,
            dropout_rate=dropout_rate,
            session_max_len=session_max_len,
            dataloader_num_workers=dataloader_num_workers,
            batch_size=batch_size,
            loss=loss,
            n_negatives=n_negatives,
            gbce_t=gbce_t,
            lr=lr,
            epochs=epochs,
            verbose=verbose,
            deterministic=deterministic,
            recommend_device=recommend_device,
            recommend_batch_size=recommend_batch_size,
            recommend_n_threads=recommend_n_threads,
            recommend_use_gpu_ranking=recommend_use_gpu_ranking,
            train_min_user_interactions=train_min_user_interactions,
            mask_prob=mask_prob,
            item_net_block_types=item_net_block_types,
            item_net_constructor_type=item_net_constructor_type,
            pos_encoding_type=pos_encoding_type,
            lightning_module_type=lightning_module_type,
            get_val_mask_func=get_val_mask_func,
            get_trainer_func=get_trainer_func,
        )
    
    def _construct_item_net(self, dataset: Dataset) -> ItemNetConstructorBase:
        """Build the item net from a dataset, passing `emb_factors` in addition to `n_factors`."""
        return self.item_net_constructor_type.from_dataset(
            dataset, self.emb_factors, self.n_factors, self.dropout_rate, self.item_net_block_types
        )

    def _construct_item_net_dataset_schema(self, dataset_schema: DatasetSchema) -> ItemNetConstructorBase:
        """Build the item net from a dataset schema, passing `emb_factors` in addition to `n_factors`."""
        return self.item_net_constructor_type.from_dataset_schema(
            dataset_schema, self.emb_factors, self.n_factors, self.dropout_rate, self.item_net_block_types
        )

    def _init_transformer_layers(self) -> TransformerLayersBase:
        """Create transformer layers, passing the group counts in addition to the usual sizes."""
        return self.transformer_layers_type(
            n_blocks=self.n_blocks,
            n_hidden_groups=self.n_hidden_groups,
            n_inner_groups=self.n_inner_groups,
            n_factors=self.n_factors,
            n_heads=self.n_heads,
            dropout_rate=self.dropout_rate,
        )


In [25]:
# ALBERT-specific hyperparameters, passed to `AlBERT4RecModel` and `AlSASRecModel` below.
# With N_BLOCKS = 4 and N_HIDDEN_GROUPS = 2, each hidden group serves 2 consecutive blocks.
EMB_FACTORS = 64
N_HIDDEN_GROUPS = 2
N_INNER_GROUPS = 1

In [26]:
albert_model = AlBERT4RecModel(
    # ALBERT-specific parameters
    emb_factors=EMB_FACTORS,
    n_hidden_groups=N_HIDDEN_GROUPS,
    n_inner_groups=N_INNER_GROUPS,
    item_net_constructor_type=AlBERT4RecSumOfEmbeddingsConstructor,
    transformer_layers_type=AlBERT4RecPreLNTransformerLayers,
    # Architecture
    n_factors=N_FACTORS,
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=USE_POS_EMB,
    use_causal_attn=False,
    item_net_block_types=(IdEmbeddingsItemNet, ),
    # Data preparation
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    batch_size=BATCH_SIZE,
    # Training
    lr=LR,
    loss=LOSS,
    epochs=EPOCHS,
    verbose=1,
    deterministic=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [27]:
%%time
albert_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 2.0 M 
---------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.884     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 8min 5s, sys: 6.01 s, total: 8min 11s
Wall time: 7min 59s


<__main__.AlBERT4RecModel at 0x7f900c0f3eb0>

In [28]:
%%time
# Top-10 recommendations for the test users (`filter_viewed=True`)
recos = albert_model.recommend(
    users=test_users_sasrec,
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn",
)

# Evaluate against test; train is passed as previous interactions
metric_values = calc_metrics(
    metrics,
    reco=recos[["user_id", "item_id", "rank"]],
    interactions=test,
    prev_interactions=train,
    catalog=catalog,
)
metric_values["model"] = "albert_softmax"
features_results.append(metric_values)

CPU times: user 2min 7s, sys: 10min 4s, total: 12min 12s
Wall time: 25 s


**AlSASRec**

In [31]:
from rectools.models.nn.sasrec import SASRecModelConfig, SASRecDataPreparator, SASRecTransformerLayer


class AlSASRecTransformerLayers(AlBERT4RecPreLNTransformerLayers):
    """
    ALBERT-style shared transformer layers built from `SASRecTransformerLayer`.

    The group logic of the parent class is reused. The layers created by the parent are
    replaced with the same number of `SASRecTransformerLayer`s, and a final layer norm is added.
    """

    def __init__(
        self,
        n_blocks: int,
        n_hidden_groups: int,
        n_inner_groups: int,
        n_factors: int,
        n_heads: int,
        dropout_rate: float,
        ff_factors_multiplier: int = 4,
    ):
        super().__init__(
            n_blocks=n_blocks,
            n_hidden_groups=n_hidden_groups,
            n_inner_groups=n_inner_groups,
            n_factors=n_factors,
            n_heads=n_heads,
            dropout_rate=dropout_rate,
            ff_factors_multiplier=ff_factors_multiplier,
        )

        # Replace the parent's layers with SASRec layers (the counterpart of ALBERT layer groups):
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/albert/modeling_albert.py#L428
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/albert/modeling_albert.py#L469
        sasrec_layers = [
            SASRecTransformerLayer(n_factors, n_heads, dropout_rate)
            for _ in range(int(n_hidden_groups * n_inner_groups))
        ]
        self.transformer_blocks = nn.ModuleList(sasrec_layers)
        self.last_layernorm = torch.nn.LayerNorm(n_factors, eps=1e-8)

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Mask `seqs` in place, run the shared blocks of the parent and apply the final layer norm."""
        seqs *= timeline_mask  # [batch_size, session_max_len, n_factors]
        blocks_output = super().forward(seqs, timeline_mask, attn_mask, key_padding_mask)
        return self.last_layernorm(blocks_output)


class AlSASRecModelConfig(SASRecModelConfig):
    """Config of `AlSASRecModel`: `SASRecModelConfig` extended with three extra fields."""

    # The defaults equal the defaults of the same-named `AlSASRecModel.__init__` arguments
    n_hidden_groups: int = 1
    n_inner_groups: int = 1
    emb_factors: int = 64


class AlSASRecModel(AlBERT4RecModel):
    """
    SASRec-style model with ALBERT-style parameter reduction.

    https://arxiv.org/pdf/1909.11942

    The class reuses `AlBERT4RecModel` and differs from it in the defaults of
    `transformer_layers_type` (`AlSASRecTransformerLayers`) and `data_preparator_type`
    (`SASRecDataPreparator`), and in how the data preparator is created: `mask_prob` is not
    passed to it.

    `n_hidden_groups`, `n_inner_groups` and `emb_factors` are stored as attributes by
    `AlBERT4RecModel.__init__`, which also warns when `n_blocks < n_hidden_groups`; they are not
    handled a second time here, so the warning is issued only once.
    """
    
    config_class = AlSASRecModelConfig

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        n_blocks: int = 2,
        n_hidden_groups: int = 1,
        n_inner_groups: int = 1,
        n_heads: int = 4,
        n_factors: int = 256,
        emb_factors: int = 64,
        dropout_rate: float = 0.0,
        mask_prob: float = 0.15,
        session_max_len: int = 100,
        train_min_user_interactions: int = 2,
        loss: str = "softmax",
        n_negatives: int = 1,
        gbce_t: float = 0.2,
        lr: float = 0.001,
        batch_size: int = 128,
        epochs: int = 3,
        deterministic: bool = False,
        verbose: int = 0,
        dataloader_num_workers: int = 0,
        use_pos_emb: bool = True,
        use_key_padding_mask: bool = True,
        use_causal_attn: bool = False,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet ),
        item_net_constructor_type: tp.Type[ItemNetConstructorBase] = AlBERT4RecSumOfEmbeddingsConstructor,
        pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
        transformer_layers_type: tp.Type[TransformerLayersBase] = AlSASRecTransformerLayers,  # Change
        data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator,  # Change
        lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
        get_val_mask_func: tp.Optional[ValMaskCallable] = None,
        get_trainer_func: tp.Optional[TrainerCallable] = None,
        recommend_batch_size: int = 256,
        recommend_device: tp.Optional[str] = None,
        recommend_n_threads: int = 0,
        recommend_use_gpu_ranking: bool = True,  # TODO: remove after TorchRanker
    ):
        # The parent constructor stores the ALBERT-specific attributes and performs the
        # `n_blocks` / `n_hidden_groups` check before it initialises the base model.
        super().__init__(
            n_blocks=n_blocks,
            n_hidden_groups=n_hidden_groups,
            n_inner_groups=n_inner_groups,
            n_heads=n_heads,
            n_factors=n_factors,
            emb_factors=emb_factors,
            transformer_layers_type=transformer_layers_type,
            data_preparator_type=data_preparator_type,
            use_pos_emb=use_pos_emb,
            use_causal_attn=use_causal_attn,
            use_key_padding_mask=use_key_padding_mask,
            dropout_rate=dropout_rate,
            session_max_len=session_max_len,
            dataloader_num_workers=dataloader_num_workers,
            batch_size=batch_size,
            loss=loss,
            n_negatives=n_negatives,
            gbce_t=gbce_t,
            lr=lr,
            epochs=epochs,
            verbose=verbose,
            deterministic=deterministic,
            recommend_device=recommend_device,
            recommend_batch_size=recommend_batch_size,
            recommend_n_threads=recommend_n_threads,
            recommend_use_gpu_ranking=recommend_use_gpu_ranking,
            train_min_user_interactions=train_min_user_interactions,
            mask_prob=mask_prob,
            item_net_block_types=item_net_block_types,
            item_net_constructor_type=item_net_constructor_type,
            pos_encoding_type=pos_encoding_type,
            lightning_module_type=lightning_module_type,
            get_val_mask_func=get_val_mask_func,
            get_trainer_func=get_trainer_func,
        )
        
    def _init_data_preparator(self) -> None:
        """Create `self.data_preparator` without `mask_prob`; negatives are used for non-softmax losses only."""
        self.data_preparator = self.data_preparator_type(
            session_max_len=self.session_max_len,
            n_negatives=self.n_negatives if self.loss != "softmax" else None,
            batch_size=self.batch_size,
            dataloader_num_workers=self.dataloader_num_workers,
            train_min_user_interactions=self.train_min_user_interactions,
            get_val_mask_func=self.get_val_mask_func,
        )

In [32]:
alsasrec_model = AlSASRecModel(
    # ALBERT-specific parameters
    emb_factors=EMB_FACTORS,
    n_hidden_groups=N_HIDDEN_GROUPS,
    n_inner_groups=N_INNER_GROUPS,
    item_net_constructor_type=AlBERT4RecSumOfEmbeddingsConstructor,
    transformer_layers_type=AlSASRecTransformerLayers,
    # Architecture
    n_factors=N_FACTORS,
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    dropout_rate=DROPOUT_RATE,
    use_pos_emb=USE_POS_EMB,
    use_causal_attn=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),
    # Data preparation
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    session_max_len=SESSION_MAX_LEN,
    batch_size=BATCH_SIZE,
    data_preparator_type=SASRecDataPreparator,
    # Training
    lr=LR,
    loss=LOSS,
    epochs=EPOCHS,
    verbose=1,
    deterministic=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [33]:
%%time
alsasrec_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params
---------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 787 K 
---------------------------------------------------------
787 K     Trainable params
0         Non-trainable params
787 K     Total params
3.150     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


CPU times: user 7min 19s, sys: 7.47 s, total: 7min 26s
Wall time: 7min 10s


<__main__.AlSASRecModel at 0x7f8fcaa35c10>

In [34]:
%%time
# Top-10 recommendations for the test users (`filter_viewed=True`)
recos = alsasrec_model.recommend(
    users=test_users_sasrec,
    dataset=dataset_no_features,
    k=10,
    filter_viewed=True,
    on_unsupported_targets="warn",
)

# Evaluate against test; train is passed as previous interactions
metric_values = calc_metrics(
    metrics,
    reco=recos[["user_id", "item_id", "rank"]],
    interactions=test,
    prev_interactions=train,
    catalog=catalog,
)
metric_values["model"] = "alsasrec_softmax"
features_results.append(metric_values)

CPU times: user 1min 53s, sys: 8min 28s, total: 10min 22s
Wall time: 25.1 s


# Results

In [36]:
df_metrics = pd.DataFrame(features_results)
df_metrics

Unnamed: 0,MAP@1,MAP@5,MAP@10,MIUF@1,MIUF@5,MIUF@10,Serendipity@1,Serendipity@5,Serendipity@10,model
0,0.041575,0.073377,0.081693,3.770911,4.51456,5.049877,0.000477,0.00046,0.000443,bert_next_action_softmax
1,0.034768,0.066949,0.072362,2.022693,2.805278,3.624555,8e-06,2.5e-05,2.9e-05,sasrec_next_action_softmax
2,0.042906,0.073013,0.081317,3.995624,4.609529,4.996234,0.000449,0.000381,0.000353,bert_next_action_softmax_causal
3,0.038995,0.066442,0.074229,3.670593,4.394567,4.882518,0.000292,0.000275,0.000273,albert_softmax
4,0.046672,0.079716,0.088329,3.798423,4.560263,5.155399,0.000816,0.000745,0.00068,alsasrec_softmax
