# Transformer Models Advanced Training Guide
This guide is showing advanced features of RecTools transformer models training.

### Table of Contents

* Prepare data
* Advanced training guide
    * Validation fold
    * Validation loss
    * Callback for Early Stopping
    * Callbacks for Checkpoints (+ loading checkpoints)
    * Callbacks for RecSys metrics (+ checkpoints on RecSys metrics)
* Advanced training full example
    * Running full training with all of the described validation features on Kion dataset
* More RecTools features for transformers
    * Saving and loading models
    * Configs for transformer models


In [45]:
import os
import itertools
import typing as tp
import warnings
from collections import Counter
from pathlib import Path

import pandas as pd
import numpy as np
import torch
from lightning_fabric import seed_everything
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Callback

from rectools import Columns, ExternalIds
from rectools.dataset import Dataset
from rectools.metrics import NDCG, Recall, Serendipity, calc_metrics
from rectools.models import BERT4RecModel, SASRecModel, load_model
from rectools.models.nn.item_net import IdEmbeddingsItemNet
from rectools.models.nn.transformer_base import TransformerModelBase

# Enable deterministic behaviour with CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", FutureWarning)

# Prepare data

In [2]:
%%time
!wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_en.zip -O data_en.zip
!unzip -o data_en.zip
!rm data_en.zip

Archive:  data_en.zip
  inflating: data_en/items_en.csv    
  inflating: __MACOSX/data_en/._items_en.csv  
  inflating: data_en/interactions.csv  
  inflating: __MACOSX/data_en/._interactions.csv  
  inflating: data_en/users_en.csv    
  inflating: __MACOSX/data_en/._users_en.csv  
CPU times: user 107 ms, sys: 42.1 ms, total: 149 ms
Wall time: 12.4 s


In [3]:
# Download dataset
DATA_PATH = Path("./data_en")
items = pd.read_csv(DATA_PATH / 'items_en.csv', index_col=0)
interactions = (
    pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=["last_watch_dt"])
    .rename(columns={"last_watch_dt": Columns.Datetime})
)

print(interactions.shape)
interactions.head(2)

(5476251, 5)


Unnamed: 0,user_id,item_id,datetime,total_dur,watched_pct
0,176549,9506,2021-05-11,4250,72.0
1,699317,1659,2021-05-29,8317,100.0


In [7]:
interactions[Columns.User].nunique(), interactions[Columns.Item].nunique()

(962179, 15706)

In [8]:
# Process interactions
interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)
raw_interactions = interactions[["user_id", "item_id", "datetime", "weight"]]
print(raw_interactions.shape)
raw_interactions.head(2)

dataset = Dataset.construct(raw_interactions)

(5476251, 4)


In [9]:
RANDOM_STATE=60
torch.use_deterministic_algorithms(True)
seed_everything(RANDOM_STATE, workers=True)

Seed set to 60


60

# Advanced Training

## Validation fold

Models do not create validation fold during `fit` by default. However, there is a simple way to force it.

Let's assume that we want to use Leave-One-Out validation for specific set of users. To apply it we need to implement `get_val_mask_func` with required logic and pass it to model during initialization. 

This function should receive interactions with standard RecTools columns and return a binary mask which identifies interactions that should not be used during model training. But instrad should be used for validation loss calculation. They will also be available for Lightning Callbacks to allow RecSys metrics computations.

*Please make sure you do not use `partial` while doing this. Partial functions cannot be by serialized using RecTools.*

In [10]:
# Implement `get_val_mask_func`

N_VAL_USERS = 2048
unique_users = raw_interactions[Columns.User].unique()
VAL_USERS = unique_users[: N_VAL_USERS]

def leave_one_out_mask_for_users(interactions: pd.DataFrame, val_users: ExternalIds) -> np.ndarray:
    rank = (
        interactions
        .sort_values(Columns.Datetime, ascending=False, kind="stable")
        .groupby(Columns.User, sort=False)
        .cumcount()
    )
    val_mask = (
        (interactions[Columns.User].isin(val_users))
        & (rank == 0)
    )
    return val_mask.values

# We do not use `partial` for correct serialization of the model
def get_val_mask_func(interactions: pd.DataFrame):
    return leave_one_out_mask_for_users(interactions, val_users = VAL_USERS)

In [11]:
model = SASRecModel(
    n_factors=64,
    n_blocks=2,
    n_heads=2,
    dropout_rate=0.2,
    train_min_user_interactions=5,
    session_max_len=50,
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet,),
    get_val_mask_func=get_val_mask_func,  # pass our custom `get_val_mask_func`
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


## Validation loss

Let's check how the validation loss is being logged.
We just want to quickly check functionality for now so let's create a custom Lightning trainer for that.

In [12]:
trainer = Trainer(
    accelerator='cpu',  # TODO: change
    devices=1,
    min_epochs=2,
    max_epochs=2, 
    deterministic=True,
    limit_train_batches=2,  # use only 2 batches for each epoch for a test run
    enable_checkpointing=False,
    logger = CSVLogger("test_logs")
)

# Replace default trainer with our custom one
model._trainer = trainer

# Fit model. Validation fold and validation loss computation will be done under the hood.
model.fit(dataset)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name        | Type                           | Params | Mode 
-----------------------------------------------------------------------
0 | torch_model | TransformerBasedSessionEncoder | 987 K  | train
-----------------------------------------------------------------------
987 K     Trainable params
0         Non-trainable params
987 K     Total params
3.951     Total estimated model params size (MB)
36        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


<rectools.models.nn.sasrec.SASRecModel at 0x296553d60>

Let's look at model logs. We can access logs directory with `model.fit_trainer.log_dir`

In [13]:
# What's inside the logs directory?
!ls $model.fit_trainer.log_dir

hparams.yaml metrics.csv


In [14]:
# Losses and metrics are in the `metrics.csv`
# Let's look at logs

!tail $model.fit_trainer.log_dir/metrics.csv

epoch,step,train_loss,val_loss
0,1,,22.41293716430664
0,1,22.974777221679688,
1,3,,22.27031898498535
1,3,22.650423049926758,


## Callback for Early Stopping

Now that we have validation loss logged, let's use ot for model Early Stopping. It will ensure that model will not resume training if validation loss (or any other custom metric) doesn't impove. We have Lightning Callbacks for that.

In [15]:
early_stopping_callback = EarlyStopping(
    monitor=SASRecModel.val_loss_name,   # or just pass "val_loss" here
    mode="min",
    min_delta=1.  # just for a quick test of functionality
)

In [16]:
trainer = Trainer(
    accelerator='cpu',  # TODO: change
    devices=1,
    min_epochs=1,  # minimum number of epochs to train before early stopping
    max_epochs=20,  # maximum number of epochs to train
    deterministic=True,
    limit_train_batches=2,  # use only 2 batches for each epoch for a test run
    enable_checkpointing=False,
    logger = CSVLogger("test_logs"),
    callbacks=early_stopping_callback,  # pass our callback
)

# Replace default trainer with our custom one
model._trainer = trainer

# Fit model. Everything will happen under the hood
model.fit(dataset)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name        | Type                           | Params | Mode 
-----------------------------------------------------------------------
0 | torch_model | TransformerBasedSessionEncoder | 987 K  | train
-----------------------------------------------------------------------
987 K     Trainable params
0         Non-trainable params
987 K     Total params
3.951     Total estimated model params size (MB)
36        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

<rectools.models.nn.sasrec.SASRecModel at 0x296553d60>

Here model stopped training after 4 epochs because validation loss wasn't improving by our specified `min_delta`

In [17]:
# Let's check out logs
!tail $model.fit_trainer.log_dir/metrics.csv

epoch,step,train_loss,val_loss
0,1,,22.35995864868164
0,1,22.873361587524414,
1,3,,22.200777053833008
1,3,22.538841247558594,
2,5,,21.98937225341797
2,5,22.36414909362793,
3,7,,21.726999282836914
3,7,22.734487533569336,


## Callback for Checkpoints

In [31]:
# Checkpoint after last epoch
last_epoch_ckpt = ModelCheckpoint(filename="last_epoch")

# Checkpoints based on validation loss
least_val_loss_ckpt = ModelCheckpoint(
    monitor=SASRecModel.val_loss_name,   # or just pass "val_loss" here,
    mode="min",
    filename="{epoch}-{val_loss:.2f}",
    save_top_k=2,  # Let's save top 2 checkpoints for validation loss
)

In [32]:
trainer = Trainer(
    accelerator='cpu',  # TODO: change
    devices=1,
    min_epochs=1,
    max_epochs=6,
    deterministic=True,
    limit_train_batches=2,  # use only 2 batches for each epoch for a test run
    logger = CSVLogger("test_logs"),
    callbacks=[last_epoch_ckpt, least_val_loss_ckpt],  # pass our callbacks for checkpoints
)

# Replace default trainer with our custom one
model._trainer = trainer

# Fit model. Everything will happen under the hood
model.fit(dataset)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name        | Type                           | Params | Mode 
-----------------------------------------------------------------------
0 | torch_model | TransformerBasedSessionEncoder | 987 K  | train
-----------------------------------------------------------------------
987 K     Trainable params
0         Non-trainable params
987 K     Total params
3.951     Total estimated model params size (MB)
36        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=6` reached.


<rectools.models.nn.sasrec.SASRecModel at 0x296553d60>

Let's look at model checkpoints that were saved. By default they are neing saved to `checkpoints` directory in  `model.fit_trainer.log_dir`

In [35]:
# We have 2 checkpoints for 2 best validation loss values and one for last epoch
!ls $model.fit_trainer.log_dir/checkpoints

epoch=4-val_loss=21.46.ckpt last_epoch.ckpt
epoch=5-val_loss=21.16.ckpt


Loading checkpoints is very simple with `load_model` function.

In [41]:
model.save("temp.ckpt")

11938674

In [42]:
loaded = SASRecModel.load("temp.ckpt")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [40]:
loaded = SASRecModel.load(model.fit_trainer.log_dir + "/checkpoints/last_epoch.ckpt")

UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.

## Callbacks for RecSys metrics during training

Monitoring RecSys metrics (or any other custom things) on validation fold is not available out of the box, but we can create a custom Lightning Callback for that.

Below is an example of calculating standard RecTools metrics on validation fold during training. We use it as an explicit example that any customization is possible. But it is recommend to implement metrics calculation using `torch` for faster computations.

Please look at PyTorch Lightning documentation for more details on custom callbacks.

In [46]:
# Implement custom Callback for RecTools metrics computation within validation epochs during training.

class ValidationMetrics(Callback):
    
    def __init__(self, top_k: int, val_metrics: tp.Dict, verbose: int = 0) -> None:
        self.top_k = top_k
        self.val_metrics = val_metrics
        self.verbose = verbose

        self.epoch_n_users: int = 0
        self.batch_metrics: tp.List[tp.Dict[str, float]] = []

    def on_validation_batch_end(
        self, 
        trainer: Trainer, 
        pl_module: LightningModule, 
        outputs: tp.Dict[str, torch.Tensor], 
        batch: tp.Dict[str, torch.Tensor], 
        batch_idx: int, 
        dataloader_idx: int = 0
    ) -> None:
        logits = outputs["logits"]
        if logits is None:
            logits = pl_module.torch_model.encode_sessions(batch["x"], pl_module.item_embs)[:, -1, :]
        _, sorted_batch_recos = logits.topk(k=self.top_k)

        batch_recos = sorted_batch_recos.tolist()
        targets = batch["y"].tolist()

        batch_val_users = list(
            itertools.chain.from_iterable(
                itertools.repeat(idx, len(recos)) for idx, recos in enumerate(batch_recos)
            )
        )

        batch_target_users = list(
            itertools.chain.from_iterable(
                itertools.repeat(idx, len(targets)) for idx, targets in enumerate(targets)
            )
        )

        batch_recos_df = pd.DataFrame(
            {
                Columns.User: batch_val_users,
                Columns.Item: list(itertools.chain.from_iterable(batch_recos)),
            }
        )
        batch_recos_df[Columns.Rank] = batch_recos_df.groupby(Columns.User, sort=False).cumcount() + 1

        interactions = pd.DataFrame(
            {
                Columns.User: batch_target_users,
                Columns.Item: list(itertools.chain.from_iterable(targets)),
            }
        )

        prev_interactions = pl_module.data_preparator.train_dataset.interactions.df
        catalog = prev_interactions[Columns.Item].unique()

        batch_metrics = calc_metrics(
            self.val_metrics, 
            batch_recos_df,
            interactions, 
            prev_interactions,
            catalog
        )

        batch_n_users = batch["x"].shape[0]
        self.batch_metrics.append({metric: value * batch_n_users for metric, value in batch_metrics.items()})
        self.epoch_n_users += batch_n_users

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        epoch_metrics = dict(sum(map(Counter, self.batch_metrics), Counter()))
        epoch_metrics = {metric: value / self.epoch_n_users for metric, value in epoch_metrics.items()}

        self.log_dict(epoch_metrics, on_step=False, on_epoch=True, prog_bar=self.verbose > 0)

        self.batch_metrics.clear()
        self.epoch_n_users = 0

When custom metrics callback is implemented, we can use the values of these metrics for both Early Stopping and Checkpoints.

In [49]:
# Initialize callbacks for metrics calculation and checkpoint based on NDCG value

metrics = {
    "NDCG@10": NDCG(k=10),
    "Recall@10": Recall(k=10),
    "Serendipity@10": Serendipity(k=10),
}
top_k = max([metric.k for metric in metrics.values()])

# Callback for calculating RecSys metrics
val_metrics_callback = ValidationMetrics(top_k=top_k, val_metrics=metrics, verbose=1)

# Callback for checkpoint based on maximization of NDCG@10
best_ndcg_ckpt = ModelCheckpoint(
    monitor="NDCG@10",
    mode="max",
    filename="{epoch}-{NDCG@10:.2f}",
)

In [50]:
trainer = Trainer(
    accelerator='cpu',  # TODO: change
    devices=1,
    min_epochs=1,
    max_epochs=6,
    deterministic=True,
    limit_train_batches=2,  # use only 2 batches for each epoch for a test run
    logger = CSVLogger("test_logs"),
    callbacks=[val_metrics_callback, best_ndcg_ckpt],  # pass our callbacks
)

# Replace default trainer with our custom one
model._trainer = trainer

# Fit model. Everything will happen under the hood
model.fit(dataset)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name        | Type                           | Params | Mode 
-----------------------------------------------------------------------
0 | torch_model | TransformerBasedSessionEncoder | 987 K  | train
-----------------------------------------------------------------------
987 K     Trainable params
0         Non-trainable params
987 K     Total params
3.951     Total estimated model params size (MB)
36        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=6` reached.


<rectools.models.nn.sasrec.SASRecModel at 0x296553d60>

We have checkpoint for best NDCG@10 model in the usual directory for checkpoints

In [51]:
!ls $model.fit_trainer.log_dir/checkpoints

epoch=5-NDCG@10=0.01.ckpt


We also now have metrics in our logs

In [53]:
!head $model.fit_trainer.log_dir/metrics.csv

NDCG@10,Recall@10,Serendipity@10,epoch,step,train_loss,val_loss
0.0006136037991382182,0.005259697791188955,4.036495283799013e-06,0,1,,22.36232566833496
,,,0,1,22.85256004333496,
0.00378932966850698,0.04470742866396904,5.4826059567858465e-06,1,3,,22.194259643554688
,,,1,3,22.471229553222656,
0.004971458576619625,0.048652201890945435,5.865532330062706e-06,2,5,,21.967544555664062
,,,2,5,22.728843688964844,
0.008074083365499973,0.04996712505817413,5.288889951771125e-06,3,7,,21.701507568359375
,,,3,7,22.52100372314453,
0.010768753476440907,0.0788954645395279,3.748174322026898e-06,4,9,,21.411954879760742


Let's load them to read more easily

In [86]:
def get_logs(model: TransformerModelBase) -> tp.Tuple[pd.DataFrame, ...]:
    log_path = Path(model.fit_trainer.log_dir) / "metrics.csv"
    epoch_metrics_df = pd.read_csv(log_path)
    
    loss_df = epoch_metrics_df[["epoch", "train_loss"]].dropna()
    val_loss_df = epoch_metrics_df[["epoch", "val_loss"]].dropna()
    loss_df = pd.merge(loss_df, val_loss_df, how="inner", on="epoch")
    loss_df.reset_index(drop=True, inplace=True)
    
    metrics_df = epoch_metrics_df.drop(columns=["train_loss", "val_loss"]).dropna()
    metrics_df.reset_index(drop=True, inplace=True)

    return loss_df, metrics_df

loss_df, metrics_df = get_logs(model)

loss_df.head()

Unnamed: 0,epoch,train_loss,val_loss
0,0,22.85256,22.362326
1,1,22.47123,22.19426
2,2,22.728844,21.967545
3,3,22.521004,21.701508
4,4,22.202381,21.411955


In [87]:
metrics_df.head()

Unnamed: 0,NDCG@10,Recall@10,Serendipity@10,epoch,step
0,0.000614,0.00526,4e-06,0,1
1,0.003789,0.044707,5e-06,1,3
2,0.004971,0.048652,6e-06,2,5
3,0.008074,0.049967,5e-06,3,7
4,0.010769,0.078895,4e-06,4,9


In [88]:
del model
torch.cuda.empty_cache()

# Advanced training full example
Running full training with all of the described validation features on Kion dataset

In [79]:
model = SASRecModel(
    n_factors=64,
    n_blocks=2,
    n_heads=2,
    dropout_rate=0.2,
    train_min_user_interactions=5,
    session_max_len=50,
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet,),
    get_val_mask_func=get_val_mask_func,  # pass our custom `get_val_mask_func`
)

early_stopping_callback = EarlyStopping(
    monitor=SASRecModel.val_loss_name,   # or just pass "val_loss" here
    mode="min",
)

trainer = Trainer(
    accelerator='cpu',  # TODO: change
    devices=1,
    min_epochs=1,
    max_epochs=100,
    deterministic=True,
    logger = CSVLogger("sasrec_logs"),
    callbacks=[
        val_metrics_callback,  # calculate RecSys metrics
        best_ndcg_ckpt,  # save best NDCG model checkpoint
        last_epoch_ckpt,  # save model checkpoint after the last train epoch
        early_stopping_callback,  # early stopping on validation loss
    ],
)

# Replace default trainer with our custom one
model._trainer = trainer

# Fit model. Everything will happen under the hood
model.fit(dataset)

In [None]:
loss_df, metrics_df = get_logs(model)

In [None]:
!ls $model.fit_trainer.log_dir/checkpoints