# Transformer Models Advanced Training Guide
This guide shows advanced training features of RecTools transformer models.

### Table of Contents

* Prepare data
* Advanced training guide
    * Validation fold
    * Validation loss
    * Callback for Early Stopping
    * Callbacks for Checkpoints
        * Loading Checkpoints
    * Callbacks for RecSys metrics
        * RecSys metrics for Early Stopping and Checkpoints
* Advanced training full example
    * Running full training with all of the described validation features on Kion dataset
* More RecTools features for transformers
    * Saving and loading models
    * Configs for transformer models
        * Classes and functions in configs
    * Multi-gpu training


In [1]:
import os
import itertools
import typing as tp
import warnings
from collections import Counter
from pathlib import Path

import pandas as pd
import numpy as np
import torch
from lightning_fabric import seed_everything
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Callback

from rectools import Columns, ExternalIds
from rectools.dataset import Dataset
from rectools.metrics import NDCG, Recall, Serendipity, calc_metrics, HitRate
from rectools.models import BERT4RecModel, SASRecModel, load_model
from rectools.models.nn.transformers.hstu import HSTUModel
from rectools.models.nn.item_net import IdEmbeddingsItemNet
from rectools.models.nn.transformers.base import TransformerModelBase
from scipy import sparse

# Enable deterministic behaviour with CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", FutureWarning)

## Prepare data

%%time
!wget -q https://files.grouplens.org/datasets/movielens/ml-1m.zip -O ml-1m.zip
!unzip -o ml-1m.zip
!rm ml-1m.zip

In [2]:
%%time
ratings = pd.read_csv(
    "ml-1m/ratings.dat", 
    sep="::",
    engine="python",  # Because of 2-chars separators
    header=None,
    names=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
)
print(ratings.shape)
ratings[Columns.Weight] = 1
ratings.head()

(1000209, 4)
CPU times: user 3.39 s, sys: 181 ms, total: 3.57 s
Wall time: 3.57 s


Unnamed: 0,user_id,item_id,weight,datetime
0,1,1193,1,978300760
1,1,661,1,978302109
2,1,914,1,978301968
3,1,3408,1,978300275
4,1,2355,1,978824291


In [3]:
ratings[Columns.Datetime] = ratings[Columns.Datetime].astype("datetime64[s]")

In [4]:
nan_count = ratings.isna().sum()
print(f"Number of NaN values per column:\n{nan_count}")

Number of NaN values per column:
user_id     0
item_id     0
weight      0
datetime    0
dtype: int64


In [5]:
dataset = Dataset.construct(ratings)


In [6]:
print(dataset.interactions.df)

         user_id  item_id  weight            datetime
0              0        0     1.0 2000-12-31 22:12:40
1              0        1     1.0 2000-12-31 22:35:09
2              0        2     1.0 2000-12-31 22:32:48
3              0        3     1.0 2000-12-31 22:04:35
4              0        4     1.0 2001-01-06 23:38:11
...          ...      ...     ...                 ...
1000204     6039      772     1.0 2000-04-26 02:35:41
1000205     6039     1106     1.0 2000-04-25 23:21:27
1000206     6039      365     1.0 2000-04-25 23:19:06
1000207     6039      152     1.0 2000-04-26 02:20:48
1000208     6039       26     1.0 2000-04-26 02:19:29

[1000209 rows x 4 columns]


In [7]:
RANDOM_STATE=60
torch.use_deterministic_algorithms(True)
seed_everything(RANDOM_STATE, workers=True)

Seed set to 60


60

## Advanced Training

### Validation fold

Models do not create a validation fold during `fit` by default, but there is a simple way to enable one.

Let's assume that we want to use Leave-One-Out validation for a specific set of users. To apply it, we need to implement `get_val_mask_func` with the required logic and pass it to the model during initialization.

This function receives interactions with standard RecTools columns and returns a binary mask. The mask marks interactions that are excluded from model training and used for validation loss calculation instead. These interactions are also available to Lightning callbacks, which makes RecSys metrics computation possible.

*Please make sure you do not use `partial` while doing this. Partial functions cannot be serialized by RecTools.*

In [8]:
# Implement `get_val_mask_func`

unique_users = ratings[Columns.User].unique()
VAL_USERS = unique_users

def leave_one_out_mask_for_users(interactions: pd.DataFrame, val_users: ExternalIds) -> np.ndarray:
    rank = (
        interactions
        .sort_values(Columns.Datetime, ascending=False, kind="stable")
        .groupby(Columns.User, sort=False)
        .cumcount()
    )
    val_mask = (
        (interactions[Columns.User].isin(val_users))
        & (rank == 0)
    )
    return val_mask.values

# We do not use `partial` for correct serialization of the model
def get_val_mask_func(interactions: pd.DataFrame):
    return leave_one_out_mask_for_users(interactions, val_users = VAL_USERS)
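
To see what such a mask selects, here is a small sanity check on toy data. It is only a sketch: the frame, the `toy_leave_one_out_mask` name and the user ids are made up for illustration, and the literal column names `user_id` and `datetime` are assumed to match `Columns.User` and `Columns.Datetime` (as the ratings output above suggests), so the block runs without RecTools.

```python
import numpy as np
import pandas as pd

# Toy interactions (illustrative data, not taken from MovieLens)
toy = pd.DataFrame(
    {
        "user_id": [1, 1, 1, 2, 2, 3],
        "item_id": [10, 11, 12, 10, 13, 14],
        "weight": [1, 1, 1, 1, 1, 1],
        "datetime": pd.to_datetime(
            ["2024-01-01", "2024-01-03", "2024-01-02", "2024-01-05", "2024-01-04", "2024-01-06"]
        ),
    }
)


def toy_leave_one_out_mask(interactions: pd.DataFrame, val_users) -> np.ndarray:
    # Rank each user's interactions from newest (0) to oldest
    rank = (
        interactions.sort_values("datetime", ascending=False, kind="stable")
        .groupby("user_id", sort=False)
        .cumcount()
    )
    # Keep only the newest interaction of users that take part in validation
    return (interactions["user_id"].isin(val_users) & (rank == 0)).values


toy_mask = toy_leave_one_out_mask(toy, val_users=[1, 2])
# User 3 is not in val_users, so none of its rows are selected
print(toy[toy_mask])
```

Rows where the mask is `True` would be held out for validation; all other rows would stay in the training fold.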

In [9]:
class RecallCallback(Callback):  # Recall@k with already seen items filtered out
    name: str = "recall"

    def __init__(self, k: int, prog_bar: bool = True) -> None:
        self.k = k
        self.name += f"@{k}"
        self.prog_bar = prog_bar

        self.batch_recall_per_users: tp.List[torch.Tensor] = []

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: tp.Dict[str, torch.Tensor],
        batch: tp.Dict[str, torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:

        if "logits" not in outputs:
            session_embs = pl_module.torch_model.encode_sessions(
                batch, pl_module.item_embs
            )[:, -1, :]
            logits = pl_module.torch_model.similarity_module(
                session_embs, pl_module.item_embs
            )
        else:
            logits = outputs["logits"]

        x = batch["x"]
        users = x.shape[0]
        row_ind = np.arange(users).repeat(x.shape[1])
        col_ind = x.flatten().detach().cpu().numpy()
        mask = col_ind != 0
        data = np.ones_like(row_ind[mask])
        filter_csr = sparse.csr_matrix(
            (data, (row_ind[mask], col_ind[mask])),
            shape=(users, pl_module.torch_model.item_model.n_items),
        )
        mask = torch.from_numpy((filter_csr != 0).toarray()).to(logits.device)
        scores = torch.masked_fill(logits, mask, float("-inf"))

        _, batch_recos = scores.topk(k=self.k)

        targets = batch["y"]

        # assume all users have the same amount of TP
        liked = targets.shape[1]
        tp_mask = torch.stack(
            [
                torch.isin(batch_recos[uid], targets[uid])
                for uid in range(batch_recos.shape[0])
            ]
        )
        recall_per_users = tp_mask.sum(dim=1) / liked

        self.batch_recall_per_users.append(recall_per_users)

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        recall = float(torch.concat(self.batch_recall_per_users).mean())
        self.log_dict(
            {self.name: recall}, on_step=False, on_epoch=True, prog_bar=self.prog_bar
        )

        self.batch_recall_per_users.clear()
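
The filtering and top-k steps in the callback can be hard to follow on real batches. The sketch below replays the same steps with NumPy on made-up scores. The arrays, the value of `k` and the variable names are illustrative assumptions and are not part of RecTools or Lightning.

```python
import numpy as np

# Toy scores for 2 users over 6 items (item id 0 plays the role of padding)
logits = np.array(
    [
        [0.0, 0.9, 0.8, 0.1, 0.7, 0.2],
        [0.0, 0.3, 0.95, 0.6, 0.1, 0.5],
    ]
)
history = np.array([[0, 1, 2], [0, 0, 2]])  # left-padded sessions, like batch["x"]
targets = np.array([[4], [1]])  # one held-out item per user, like batch["y"]
k = 2

# Mark items each user has already seen, skipping padding positions
seen = np.zeros_like(logits, dtype=bool)
rows = np.arange(history.shape[0]).repeat(history.shape[1])
cols = history.flatten()
keep = cols != 0
seen[rows[keep], cols[keep]] = True

# Seen items get -inf so they can never enter the top-k
scores = np.where(seen, -np.inf, logits)
top_k = np.argsort(-scores, axis=1)[:, :k]

# A hit is a recommended item that is among the user's targets
hits = (top_k[:, :, None] == targets[:, None, :]).any(axis=2)
recall_per_user = hits.sum(axis=1) / targets.shape[1]
print(top_k, recall_per_user, recall_per_user.mean())
```

In this toy batch the held-out item of the first user lands in the top 2 after filtering, while the held-out item of the second user does not, so the mean recall is 0.5.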

In this guide we use custom Lightning trainers. We need to implement a function that returns the desired Lightning trainer and pass it to the model during initialization.

In [10]:

# Callback for calculating RecSys metrics
recall_callback = RecallCallback(k=10, prog_bar=True)

In [11]:
# Function to get custom trainer

def get_debug_trainer() -> Trainer:
    return Trainer(
        accelerator="gpu",
        devices=1,
        min_epochs=101,
        max_epochs=101,
        deterministic=True,
        enable_model_summary=False,
        enable_progress_bar=True,
        enable_checkpointing=False,
        callbacks=[recall_callback],  # pass our callbacks
        logger=CSVLogger("test_logs"),  # We use CSV logging for this guide but there are many other options
    )

In [12]:
session_max_len = 200
extra_cols_kwargs = {"extra_cols_kwargs": {"extra_cols": [Columns.Datetime]}}
transformer_layers_kwargs = {
    "attention_dim": 50,
    "linear_hidden_dim": 50,
    "attn_dropout_ratio": 0.1,
    "session_max_len": session_max_len,
    "attention_mode": "rel_pos_bias",
}
lightning_module_kwargs = {"temperature": 0.3}
model = HSTUModel(
    session_max_len=session_max_len,
    data_preparator_kwargs=extra_cols_kwargs,
    transformer_layers_kwargs=transformer_layers_kwargs,
    lightning_module_kwargs=lightning_module_kwargs,
    item_net_block_types=(IdEmbeddingsItemNet,),
    get_val_mask_func=get_val_mask_func,  # pass our custom `get_val_mask_func`
    get_trainer_func=get_debug_trainer,  # pass our custom trainer func
    verbose=1,
    loss="sampled_softmax",
    n_negatives=128,
    use_pos_emb=True,
    dropout_rate=0.2,
    n_factors=50,
    n_heads=1,
    n_blocks=2,
    lr=0.005,
    batch_size=128,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs




### Validation loss

Let's check how the validation loss is being logged.

In [13]:
# Fit model. Validation fold and validation loss computation will be done under the hood.

model.fit(dataset)

You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


item_model.item_net_blocks.0.ids_emb.weight torch.Size([3646, 50])
pos_encoding_layer._pos_emb.weight torch.Size([200, 50])
transformer_layers.stu_blocks.0._uvqk torch.Size([50, 200])
transformer_layers.stu_blocks.0._rel_attn_bias._ts_w torch.Size([129])
transformer_layers.stu_blocks.0._rel_attn_bias._pos_w torch.Size([401])
transformer_layers.stu_blocks.0._o.weight torch.Size([50, 50])
transformer_layers.stu_blocks.0._o.bias torch.Size([50])
transformer_layers.stu_blocks.1._uvqk torch.Size([50, 200])
transformer_layers.stu_blocks.1._rel_attn_bias._ts_w torch.Size([129])
transformer_layers.stu_blocks.1._rel_attn_bias._pos_w torch.Size([401])
transformer_layers.stu_blocks.1._o.weight torch.Size([50, 50])
transformer_layers.stu_blocks.1._o.bias torch.Size([50])
transformer_layers.last_layernorm.weight torch.Size([50])
transformer_layers.last_layernorm.bias torch.Size([50])


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=101` reached.


<rectools.models.nn.transformers.hstu.HSTUModel at 0x7ced305d30a0>

Let's look at the model logs. The logs directory is available as `model.fit_trainer.log_dir`.
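
One possible way to inspect these logs from Python is sketched below. `CSVLogger` stores logged values in a `metrics.csv` file inside the log directory, with training and validation metrics usually written on separate rows. The `load_epoch_metrics` helper is a hypothetical name introduced here for illustration, and it assumes the file has the `epoch` and `step` columns that `CSVLogger` normally writes.

```python
from pathlib import Path

import pandas as pd


def load_epoch_metrics(log_dir) -> pd.DataFrame:
    """Read CSVLogger's metrics.csv and collapse it to one row per epoch."""
    metrics = pd.read_csv(Path(log_dir) / "metrics.csv")
    # Train and validation metrics sit on different rows, so take the
    # first non-null value of every column within each epoch.
    per_epoch = metrics.groupby("epoch").first()
    # The step counter is not a metric, so drop it if it is present
    return per_epoch.drop(columns="step", errors="ignore")
```

With the model fitted above, calling `load_epoch_metrics(model.fit_trainer.log_dir)` should return one row per epoch with whatever metrics were logged, such as the losses and `recall@10`.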