In [1]:
import numpy as np
import os
import pandas as pd
import torch
import typing as tp
import warnings
from pathlib import Path

import torch.nn as nn
import typing_extensions as tpe

from lightning_fabric import seed_everything
from pytorch_lightning import Trainer
from rectools import Columns
from rectools.dataset import Dataset
from rectools.models import BERT4RecModel, SASRecModel


from rectools.dataset.dataset import Dataset, DatasetSchema
from rectools.models.nn.item_net import (
    ItemNetBase,
    SumOfEmbeddingsConstructor,
)
from rectools.models.nn.transformer_net_blocks import (
    PreLNTransformerLayer,
    TransformerLayersBase,
)
from rectools.models.nn.constants import InitKwargs

# Enable deterministic behaviour with CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
warnings.simplefilter("ignore", UserWarning)

# Load data

In [2]:
# %%time
# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_original.zip
# !unzip -o data_original.zip
# !rm data_original.zip

In [3]:
# Directory that holds the unpacked dataset files
DATA_PATH = Path("data_original")

# Read the interactions log, expose the watch date under the name "datetime"
# and derive the interaction weight: 3 when `watched_pct` exceeds 10, else 1.
interactions = (
    pd.read_csv(DATA_PATH / "interactions.csv", parse_dates=["last_watch_dt"])
    .rename(columns={"last_watch_dt": "datetime"})
    .assign(**{Columns.Weight: lambda frame: np.where(frame["watched_pct"] > 10, 3, 1)})
)

# Dataset constructed from interactions only (no features are passed)
dataset_no_features = Dataset.construct(interactions_df=interactions)

In [4]:
# Seed value shared by the calls below
RANDOM_STATE=60
# Ask PyTorch to use deterministic algorithms
torch.use_deterministic_algorithms(True)
# Seed the random generators; the returned seed is what the cell displays
seed_everything(RANDOM_STATE, workers=True)

Seed set to 60


60

In [5]:
# Function to get custom trainer for quick debugging
def get_debug_trainer(
    accelerator: str = "gpu",
    devices: tp.Optional[tp.Union[int, str, tp.List[int]]] = None,
) -> Trainer:
    """
    Build a minimal `Trainer` for quick debugging runs.

    The trainer runs exactly one epoch limited to two training batches with
    deterministic mode enabled, so a model can be smoke-tested quickly.

    Parameters
    ----------
    accelerator : str, default ``"gpu"``
        Value forwarded to `Trainer(accelerator=...)`.
    devices : int, str, list of int or None, default ``None``
        Value forwarded to `Trainer(devices=...)`. ``None`` keeps the original
        behaviour of this notebook and selects the device with index 1
        (``[1]``). Pass e.g. ``[0]`` on a single-GPU machine.

    Returns
    -------
    Trainer
        Configured trainer instance.

    Notes
    -----
    Calling the function without arguments is equivalent to the previous
    hard-coded version, so it can still be passed directly as
    ``get_trainer_func``. Use `functools.partial` to bind other devices.
    """
    return Trainer(
        accelerator=accelerator,
        # A fresh list is created on every call to avoid a shared mutable default
        devices=[1] if devices is None else devices,
        min_epochs=1,
        max_epochs=1,
        deterministic=True,
        limit_train_batches=2,
    )

# **Training Objective**

https://arxiv.org/pdf/2205.04507

## **Next Action**

In [6]:
from typing import Dict, List, Tuple

from rectools.models.nn.bert4rec import BERT4RecDataPreparator
from rectools.models.nn.constants import MASKING_VALUE
from rectools.models.nn.transformer_lightning import TransformerLightningModule


# "MASK" token is used for predicting one next token.
class NextItemDataPreparator(BERT4RecDataPreparator):
    """
    Data preparator for the "next item" training objective.

    For every training session the last item is replaced with the "MASK"
    token in the model input and becomes the single prediction target.
    """

    def _collate_fn_train(
        self,
        batch: List[Tuple[List[int], List[float]]],
    ) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of training sessions.

        Truncate each session from the left so that only the last
        `session_max_len` items are kept, replace the last item with the
        "MASK" token and left-pad with zeros until `session_max_len` is reached.
        Split to `x`, `y`, and `yw`.

        Parameters
        ----------
        batch : list of (session, session_weights)
            Each session is a sequence of item ids and `session_weights`
            holds the weight of every interaction in the session.

        Returns
        -------
        dict
            ``"x"``: LongTensor ``[batch_size, session_max_len]`` with masked sessions,
            ``"y"``: LongTensor ``[batch_size, 1]`` with the last item of each session,
            ``"yw"``: FloatTensor ``[batch_size, 1]`` with the weight of that last item,
            ``"negatives"`` (only when `n_negatives` is not None): LongTensor
            ``[batch_size, 1, n_negatives]`` with randomly sampled item ids.
        """
        batch_size = len(batch)
        x = np.zeros((batch_size, self.session_max_len))
        y = np.zeros((batch_size, 1))
        yw = np.zeros((batch_size, 1))
        mask_token_id = self.extra_token_ids[MASKING_VALUE]
        for i, (ses, ses_weights) in enumerate(batch):
            # Keep only the most recent items: a longer session cannot be written into
            # a row of `x`. `list(...)` makes a copy, so `ses` itself is never mutated.
            session = list(ses[-self.session_max_len :])
            session[-1] = mask_token_id
            x[i, -len(session) :] = session  # session: [<= session_max_len] -> x[i]: [session_max_len]
            y[i] = ses[-1]  # ses: [session_len] -> y[i]: [1]
            yw[i] = ses_weights[-1]  # ses_weights: [session_len] -> yw[i]: [1]

        batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)}
        if self.n_negatives is not None:
            # Ids are drawn from [n_item_extra_tokens, item_id_map.size)
            negatives = torch.randint(
                low=self.n_item_extra_tokens,
                high=self.item_id_map.size,
                size=(batch_size, 1, self.n_negatives),
            )  # [batch_size, 1, n_negatives]
            batch_dict["negatives"] = negatives
        return batch_dict


# Only the logits of the last position are used, which reduces the amount of computation in the training step.
# Alternatively, you could fill `y` with zeros everywhere except the last item in `_collate_fn_train` and leave the training step unchanged.
class NextItemLightningModule(TransformerLightningModule):
    """
    Lightning module for the "next item" training objective.

    Only the logits at the last sequence position are kept for the loss,
    matching the single target per session stored in ``batch["y"]``.
    """

    def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Training step.

        Parameters
        ----------
        batch : dict
            Must contain ``"x"``, ``"y"`` and ``"yw"``; ``"negatives"`` is
            additionally read for the "BCE" and "gBCE" losses.
        batch_idx : int
            Index of the batch, forwarded to the custom loss only.

        Returns
        -------
        torch.Tensor
            Loss value, also logged under `train_loss_name` once per epoch.
        """
        x, y, w = batch["x"], batch["y"], batch["yw"]
        if self.loss == "softmax":
            # Keep the last position only (the slice keeps the sequence dimension)
            logits = self._get_full_catalog_logits(x)[:, -1:, :]
            loss = self._calc_softmax_loss(logits, y, w)
        elif self.loss in ("BCE", "gBCE"):
            negatives = batch["negatives"]
            logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:, :]
            if self.loss == "BCE":
                loss = self._calc_bce_loss(logits, y, w)
            else:
                loss = self._calc_gbce_loss(logits, y, w, negatives)
        else:
            loss = self._calc_custom_loss(batch, batch_idx)

        self.log(self.train_loss_name, loss, on_step=False, on_epoch=True, prog_bar=self.verbose > 0)

        return loss

In [7]:
# Both models share the "NextItem" data preparator and lightning module and the
# debug trainer; the SASRec-based one additionally applies the causal attention mask.
nextitem_transformer_bidirectional = BERT4RecModel(
    get_trainer_func=get_debug_trainer,
    lightning_module_type=NextItemLightningModule,
    data_preparator_type=NextItemDataPreparator,
)

nextitem_transformer_unidirectional = SASRecModel(
    get_trainer_func=get_debug_trainer,
    lightning_module_type=NextItemLightningModule,
    data_preparator_type=NextItemDataPreparator,
    use_causal_attn=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [8]:
%%time
# Fit the BERT4Rec-based "NextItem" model on the interactions-only dataset
nextitem_transformer_bidirectional.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params | Mode 
-----------------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 5.5 M  | train
-----------------------------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.040    Total estimated model params size (MB)
37        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 30.9 s, sys: 3.3 s, total: 34.2 s
Wall time: 27.8 s


<rectools.models.nn.bert4rec.BERT4RecModel at 0x7f3fd903bac0>

In [9]:
%%time
# Fit the SASRec-based "NextItem" model on the interactions-only dataset
nextitem_transformer_unidirectional.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params | Mode 
-----------------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 4.7 M  | train
-----------------------------------------------------------------
4.7 M     Trainable params
0         Non-trainable params
4.7 M     Total params
18.890    Total estimated model params size (MB)
34        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 28.9 s, sys: 3.31 s, total: 32.2 s
Wall time: 27 s


<rectools.models.nn.sasrec.SASRecModel at 0x7f3fd903b550>

# ALBERT
ALBERT has two main innovations, which can be used together or separately:
1. Learning embeddings of a smaller size and then projecting them to the required size through a linear projection
2. Sharing weights between transformer layers

In [28]:
# ### ---------- Special Albert logic for Embeddings ---------- ### #

class AlBERT4RecSumOfEmbeddingsConstructor(SumOfEmbeddingsConstructor):
    """
    Item net constructor with a reduced embedding size.

    Item net blocks are created with `emb_factors` as their size, and the result
    of the parent `forward` is passed through a linear layer that maps
    `emb_factors` to `n_factors`.
    """

    def __init__(
        self,
        n_items: int,
        n_factors: int,
        item_net_blocks: tp.Sequence[ItemNetBase],
        emb_factors: int = 16,  # accept new kwargs
    ) -> None:
        """Initialise the parent constructor and create the `emb_factors -> n_factors` projection."""
        super().__init__(n_items=n_items, item_net_blocks=item_net_blocks)
        self.item_emb_proj = nn.Linear(in_features=emb_factors, out_features=n_factors)

    @classmethod
    def from_dataset(
        cls,
        dataset: Dataset,
        n_factors: int,
        dropout_rate: float,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]],
        emb_factors: int,  # accept new kwargs
    ) -> tpe.Self:
        """Build the constructor from a dataset; block types that return None are skipped."""
        n_items = dataset.item_id_map.size
        # Blocks are built with `emb_factors`, not `n_factors`
        built_blocks = (
            block_type.from_dataset(dataset, emb_factors, dropout_rate) for block_type in item_net_block_types
        )
        item_net_blocks: tp.List[ItemNetBase] = [block for block in built_blocks if block is not None]
        return cls(n_items, n_factors, item_net_blocks, emb_factors)

    @classmethod
    def from_dataset_schema(
        cls,
        dataset_schema: DatasetSchema,
        n_factors: int,
        dropout_rate: float,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]],
        emb_factors: int,  # accept new kwargs
    ) -> tpe.Self:
        """Build the constructor from a dataset schema; block types that return None are skipped."""
        n_items = dataset_schema.items.n_hot
        # Blocks are built with `emb_factors`, not `n_factors`
        built_blocks = (
            block_type.from_dataset_schema(dataset_schema, emb_factors, dropout_rate)
            for block_type in item_net_block_types
        )
        item_net_blocks: tp.List[ItemNetBase] = [block for block in built_blocks if block is not None]
        return cls(n_items, n_factors, item_net_blocks, emb_factors)

    def forward(self, items: torch.Tensor) -> torch.Tensor:
        """Return the parent's item embeddings after the linear projection."""
        return self.item_emb_proj(super().forward(items))

In [29]:
# ### ---------- Special Albert logic for Transformer Layers ---------- ### #
    
class AlBERT4RecPreLNTransformerLayers(TransformerLayersBase):
    """
    Transformer layers with ALBERT-style weight sharing.

    Only ``n_hidden_groups * n_inner_groups`` `PreLNTransformerLayer` modules
    are created. During the forward pass `n_blocks` steps are performed; each
    step is assigned to one hidden group and applies that group's
    `n_inner_groups` layers in order, so steps that fall into the same group
    reuse the same weights.

    Parameters
    ----------
    n_blocks : int
        Number of steps performed in `forward`.
    n_factors : int
        Passed to every `PreLNTransformerLayer`.
    n_heads : int
        Passed to every `PreLNTransformerLayer`.
    dropout_rate : float
        Passed to every `PreLNTransformerLayer`.
    ff_factors_multiplier : int, default 4
        Passed to every `PreLNTransformerLayer`.
    n_hidden_groups : int, default 1
        Number of groups that the `n_blocks` steps are split into.
    n_inner_groups : int, default 1
        Number of distinct layers inside every hidden group.
    """

    def __init__(
        self,
        n_blocks: int,
        n_factors: int,
        n_heads: int,
        dropout_rate: float,
        ff_factors_multiplier: int = 4,
        n_hidden_groups: int = 1,  # accept new kwarg
        n_inner_groups: int = 1,  # accept new kwarg
    ) -> None:
        super().__init__()

        self.n_blocks = n_blocks
        self.n_hidden_groups = n_hidden_groups
        self.n_inner_groups = n_inner_groups
        # Number of layers that actually own parameters
        n_fitted_blocks = int(n_hidden_groups * n_inner_groups)
        self.transformer_blocks = nn.ModuleList(
            [
                PreLNTransformerLayer(
                    # number of encoder layer (AlBERTLayers)
                    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/albert/modeling_albert.py#L428
                    n_factors,
                    n_heads,
                    dropout_rate,
                    ff_factors_multiplier,
                )
                # https://github.com/huggingface/transformers/blob/main/src/transformers/models/albert/modeling_albert.py#L469
                for _ in range(n_fitted_blocks)
            ]
        )
        # May be fractional when `n_blocks` is not a multiple of `n_hidden_groups`
        self.n_layers_per_group = n_blocks / n_hidden_groups

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Apply the shared transformer layers `n_blocks` times.

        `timeline_mask` is accepted for interface compatibility and is not used here.
        """
        for block_idx in range(self.n_blocks):
            # Hidden group this step belongs to; always < n_hidden_groups
            group_idx = int(block_idx / self.n_layers_per_group)
            for inner_layer_idx in range(self.n_inner_groups):
                # Index the shared layer of the group. Indexing by `block_idx` would bypass
                # weight sharing and fail when n_blocks > n_hidden_groups * n_inner_groups.
                layer_idx = group_idx * self.n_inner_groups + inner_layer_idx
                seqs = self.transformer_blocks[layer_idx](seqs, attn_mask, key_padding_mask)
        return seqs


In [30]:
# Extra keyword arguments for the custom item net constructor
# (`emb_factors` has no default in its `from_dataset*` class methods).
ALBERT_ITEM_NET_CONSTRUCTOR_KWARGS = dict(emb_factors=32)

# Extra keyword arguments for the custom transformer layers
ALBERT_TRANSFORMER_LAYERS_KWARGS = dict(n_hidden_groups=2, n_inner_groups=1)

In [31]:
# BERT4Rec with both ALBERT customisations: the custom item net constructor
# and the custom transformer layers, each with its own extra kwargs.
albert_model = BERT4RecModel(
    get_trainer_func=get_debug_trainer,
    transformer_layers_type=AlBERT4RecPreLNTransformerLayers,
    transformer_layers_kwargs=ALBERT_TRANSFORMER_LAYERS_KWARGS,
    item_net_constructor_type=AlBERT4RecSumOfEmbeddingsConstructor,
    item_net_constructor_kwargs=ALBERT_ITEM_NET_CONSTRUCTOR_KWARGS,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [32]:
%%time
# Fit the ALBERT-style BERT4Rec model on the interactions-only dataset
albert_model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params | Mode 
-----------------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 2.1 M  | train
-----------------------------------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.407     Total estimated model params size (MB)
38        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 31.6 s, sys: 4.36 s, total: 36 s
Wall time: 29.3 s


<rectools.models.nn.bert4rec.BERT4RecModel at 0x7f3cce308970>

In [33]:
# SASRec with the same ALBERT customisations: the custom item net constructor
# and the custom transformer layers, each with its own extra kwargs.
alsasrec = SASRecModel(
    get_trainer_func=get_debug_trainer,
    transformer_layers_type=AlBERT4RecPreLNTransformerLayers,
    transformer_layers_kwargs=ALBERT_TRANSFORMER_LAYERS_KWARGS,
    item_net_constructor_type=AlBERT4RecSumOfEmbeddingsConstructor,
    item_net_constructor_kwargs=ALBERT_ITEM_NET_CONSTRUCTOR_KWARGS,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [34]:
%%time
# Fit the ALBERT-style SASRec model on the interactions-only dataset
alsasrec.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params | Mode 
-----------------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 2.1 M  | train
-----------------------------------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.407     Total estimated model params size (MB)
38        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 57 s, sys: 5.67 s, total: 1min 2s
Wall time: 39.9 s


<rectools.models.nn.sasrec.SASRecModel at 0x7f3d36ce42b0>

# How about NextTokenTransformer with Albert logic and causal attention?
# Just because we can!

In [36]:
# Everything combined in one BERT4Rec model: ALBERT item net constructor and
# transformer layers, the "NextItem" data preparator and lightning module,
# and the causal attention mask.
next_action_albert_causal = BERT4RecModel(
    get_trainer_func=get_debug_trainer,
    use_causal_attn=True,
    lightning_module_type=NextItemLightningModule,
    data_preparator_type=NextItemDataPreparator,
    transformer_layers_type=AlBERT4RecPreLNTransformerLayers,
    transformer_layers_kwargs=ALBERT_TRANSFORMER_LAYERS_KWARGS,
    item_net_constructor_type=AlBERT4RecSumOfEmbeddingsConstructor,
    item_net_constructor_kwargs=ALBERT_ITEM_NET_CONSTRUCTOR_KWARGS,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [37]:
%%time
# Fit the combined (ALBERT + NextItem + causal attention) model
next_action_albert_causal.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params | Mode 
-----------------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 2.1 M  | train
-----------------------------------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.407     Total estimated model params size (MB)
38        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 47.5 s, sys: 4.16 s, total: 51.6 s
Wall time: 37.5 s


<rectools.models.nn.bert4rec.BERT4RecModel at 0x7f3c86767220>

## What about configs?

In [38]:
# Export the model configuration; with `simple_types=True` the custom classes
# appear as import-path strings such as '__main__.NextItemDataPreparator' (see output below)
params = next_action_albert_causal.get_params(simple_types=True)

In [39]:
# Display the exported configuration
params

{'cls': 'BERT4RecModel',
 'verbose': 0,
 'data_preparator_type': '__main__.NextItemDataPreparator',
 'n_blocks': 2,
 'n_heads': 4,
 'n_factors': 256,
 'use_pos_emb': True,
 'use_causal_attn': True,
 'use_key_padding_mask': True,
 'dropout_rate': 0.2,
 'session_max_len': 100,
 'dataloader_num_workers': 0,
 'batch_size': 128,
 'loss': 'softmax',
 'n_negatives': 1,
 'gbce_t': 0.2,
 'lr': 0.001,
 'epochs': 3,
 'deterministic': False,
 'recommend_batch_size': 256,
 'recommend_device': None,
 'recommend_n_threads': 0,
 'recommend_use_torch_ranking': True,
 'train_min_user_interactions': 2,
 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',
  'rectools.models.nn.item_net.CatFeaturesItemNet'],
 'item_net_constructor_type': '__main__.AlBERT4RecSumOfEmbeddingsConstructor',
 'pos_encoding_type': 'rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding',
 'transformer_layers_type': '__main__.AlBERT4RecPreLNTransformerLayers',
 'lightning_module_type':

In [40]:
# Build a new model from the exported configuration dict
model = BERT4RecModel.from_params(params)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [41]:
%%time
# Fit the model that was built from the exported configuration
model.fit(dataset_no_features)

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                     | Params | Mode 
-----------------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 2.1 M  | train
-----------------------------------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.407     Total estimated model params size (MB)
38        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 36.4 s, sys: 4.11 s, total: 40.5 s
Wall time: 35 s


<rectools.models.nn.bert4rec.BERT4RecModel at 0x7f3cfffecee0>