# Будем обучать Bert из базового библиотеки Replay

### Модель будет обучаться на решение задачи прогноза следующего фильма каждого пользователя.
### В качестве основы был взят пример из гита RePlay

https://github.com/sb-ai-lab/RePlay/blob/main/examples/10_bert4rec_example.ipynb

## 0. Установка и импорты нужных пакетов, библиотек и данных

Устанавливаем нужные пакеты

In [3]:
!pip install transformers
!pip install torch
!pip install tqdm
!pip install replay-rec
!pip install lightning

Collecting replay-rec
  Downloading replay_rec-0.18.0-py3-none-any.whl.metadata (10 kB)
Collecting fixed-install-nmslib==2.1.2 (from replay-rec)
  Downloading fixed-install-nmslib-2.1.2.tar.gz (196 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m196.8/196.8 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting hnswlib<0.8.0,>=0.7.0 (from replay-rec)
  Downloading hnswlib-0.7.0.tar.gz (33 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting optuna<3.3.0,>=3.2.0 (from replay-rec)
  Downloading optuna-3.2.0-py3-none-any.whl.metadata (17 kB)
Collecting polars<1.1.0,>=1.0.0 (from replay-rec)
  Downloading polars-1.0.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Collecting psutil<6.1.0,>=6.0.0 (from replay-rec)
  Downloading psutil-6.0.0-cp36-abi3-man

Collecting lightning
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.5.1-py3-none-any.whl.metadata (20 kB)
Collecting lightning
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.5.1-py3-none-any.whl.metadata (20 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.4.0-py3-none-any.whl (810 kB)
[2K   [90m━━━━━━━━━━━━━━

Импортируем библиотеки и фиксируем `seed`

In [4]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.splitters import LastNSplitter
from replay.utils import get_spark_session
from replay.data import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
    Dataset,
)
from replay.models.nn.optimizer_utils import FatOptimizerFactory
from replay.models.nn.sequential.callbacks import (
    ValidationMetricsCallback,
    SparkPredictionCallback,
    PandasPredictionCallback,
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback,
)
from replay.models.nn.sequential.postprocessors import RemoveSeenItems
from replay.data.nn import (
    SequenceTokenizer,
    SequentialDataset,
    TensorFeatureSource,
    TensorSchema,
    TensorFeatureInfo
)
from replay.models.nn.sequential import Bert4Rec
from replay.models.nn.sequential.bert4rec import (
    Bert4RecPredictionDataset,
    Bert4RecTrainingDataset,
    Bert4RecValidationDataset,
    Bert4RecPredictionBatch,
    Bert4RecModel
)

import pandas as pd

import random
import numpy as np

def set_global_seed(seed: int) -> None:

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed(42)

import warnings
warnings.filterwarnings("ignore")

In [5]:
spark_session = get_spark_session()



In [6]:
data = pd.read_csv('events.csv')
users = pd.read_csv('user_features.csv')
items = pd.read_csv('item_features.csv')

## 1. Предобработка данных

Переобозначаем `timestamp`, не нарушая порядок: пользователь с одним и тем же timestamp-ом, вероятно, баг - поэтому для поддержания последовательности для модели необходимо переобозначение

In [7]:
data["timestamp"] = data["timestamp"].astype("int64")
data = data.sort_values(by="timestamp")
data["timestamp"] = data.groupby("user_id").cumcount()

После нескольких прогонов моделей с разными гиперпараметрами на основе валидационной выборки был выбран лучший набор, представленный в ячейках скрипта.</br></br>
Итоговую модель обучаем без val выборки.</br>
В качестве тестовой выборки берутся последние интеракции для каждого пользователя.

In [8]:
splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
)

raw_test_events, raw_test_gt = splitter.split(data)
raw_train_events, raw_train_gt = splitter.split(raw_test_events)

Функция для объединения в одну рабочую структуру данных информации об интерациях, пользователях и фильмах. Накидываем ее на тренировочную и тестовую выборки

In [9]:
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    base_features = FeatureSchema(
        [
            FeatureInfo(
                column="user_id",
                feature_hint=FeatureHint.QUERY_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            FeatureInfo(
                column="item_id",
                feature_hint=FeatureHint.ITEM_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
        ]
    )
    if is_ground_truth:
        return base_features

    all_features = base_features + FeatureSchema(
        [
            FeatureInfo(
                column="timestamp",
                feature_type=FeatureType.NUMERICAL,
                feature_hint=FeatureHint.TIMESTAMP,
            ),
        ]
    )
    return all_features

In [10]:
train_dataset = Dataset(feature_schema=prepare_feature_schema(is_ground_truth=False),
                        interactions=raw_train_events,
                        query_features=users,
                        item_features=items,
                        check_consistency=True,
                        categorical_encoded=False)
test_dataset = Dataset(feature_schema=prepare_feature_schema(is_ground_truth=False),
                       interactions=raw_test_events,
                       query_features=users,
                       item_features=items,
                       check_consistency=True,
                       categorical_encoded=False)
test_gt = Dataset(feature_schema=prepare_feature_schema(is_ground_truth=True),
                  interactions=raw_test_gt,
                  check_consistency=True,
                  categorical_encoded=False)

Прописываем схему для соотнесения наших выборок с тензорным представлением для работы модели

In [11]:
ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    TensorFeatureInfo(name=ITEM_FEATURE_NAME,
                      is_seq=True,
                      feature_type=FeatureType.CATEGORICAL,
                      feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
                      feature_hint=FeatureHint.ITEM_ID,
                      embedding_dim=300) #тут пробовал варианты от 100 до 500
                            )

Используем `SequenceTokenizer` для токенизации данных в наших датасетах.</br>
- Токенизация учитывает последовательность интераций пользователя для построения токенов
- Учитывается внутренняя информация пользователей и айтемов
- Классическая BERT-овская токенизация с доп токенами и субтокенами `CLS`, `SEP` для последовательностей, а также `PAD`, `0` и `1` для `attention_mask`

In [12]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

sequential_validation_dataset = tokenizer.transform(test_dataset)
sequential_validation_gt = tokenizer.transform(test_gt, [tensor_schema.item_id_feature_name])

sequential_validation_dataset, sequential_validation_gt = SequentialDataset.keep_common_query_ids(
    sequential_validation_dataset, sequential_validation_gt)

test_query_ids = test_gt.query_ids
test_query_ids_np = tokenizer.query_id_encoder.transform(test_query_ids)["user_id"].values
sequential_test_dataset = tokenizer.transform(test_dataset).filter_by_query_id(test_query_ids_np)

## 2. Создание объектов модели, трейнера и даталоадеров

В качестве основы модели выступает архитектура `BERT`:</br>
- Внутренний слой размера `hidden_size=300`
- Количество голов трансформера `head_count=4`
- Количество трансформерных блоков `block_count=2`
- Дропаут `dropout_rate=0.52`</br>
  При начальном обучении лучше всего себя показал вариант с `dropout_rate=0.5`. Для итогового предсказания, так как модель обучается уже на train+val, дропаут немного увеличим для контроля переобучения
- В качестве оптимизатора используем `FatOptimizerFactory` (под капотом `Adam`)

In [13]:
MAX_SEQ_LEN = 100
BATCH_SIZE = 128
NUM_WORKERS = 4

model = Bert4Rec(tensor_schema,
                 block_count=2,
                 head_count=4, # 4, 6, 8
                 max_seq_len=MAX_SEQ_LEN,
                 hidden_size=300, #100, 300, 500
                 dropout_rate=0.52, #0.3, 0.5, 0.7
                 optimizer_factory=FatOptimizerFactory(learning_rate=1e-3)) #1e-3, 1e-4

В качестве основной метрики для коллбеков используем `recall@10` как в итоговой метрике соревнования

In [14]:
checkpoint_callback = ModelCheckpoint(dirpath=".checkpoints",
                                      save_top_k=1,
                                      verbose=True,
                                      monitor="recall@10",
                                      mode="max")

validation_metrics_callback = ValidationMetricsCallback(metrics=["recall"],
                                                        ks=[10],
                                                        item_count=train_dataset.item_count,
                                                        postprocessors=[RemoveSeenItems(sequential_validation_dataset)])

csv_logger = CSVLogger(save_dir=".logs/train", name="RecommenderBERTModelv7")

trainer = L.Trainer(max_epochs=100,
                    callbacks=[checkpoint_callback, validation_metrics_callback],
                    logger=csv_logger)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [15]:
train_dataloader = DataLoader(
    dataset=Bert4RecTrainingDataset(sequential_train_dataset,
                                    max_sequence_length=MAX_SEQ_LEN),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True)

validation_dataloader = DataLoader(
    dataset=Bert4RecValidationDataset(sequential_validation_dataset,
                                      sequential_validation_gt,
                                      sequential_train_dataset,
                                      max_sequence_length=MAX_SEQ_LEN),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True)

## 3. Обучение модели

In [16]:
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | _model | Bert4RecModel    | 4.4 M  | train
1 | _loss  | CrossEntropyLoss | 0      | train
----------------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total params
17.702    Total estimated model params size (MB)
38        Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | _model | Bert4RecModel    | 4.4 M  | train
1 | _loss  | CrossEntropyLoss | 0      | train
----------------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total params
17.702    Total estimated model params size (M

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 0, global step 48: 'recall@10' reached 0.03990 (best 0.03990), saving model to '/content/.checkpoints/epoch=0-step=48.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 48: 'recall@10' reached 0.03990 (best 0.03990), saving model to '/content/.checkpoints/epoch=0-step=48.ckpt' as top 1


k             10
recall  0.039901



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 1, global step 96: 'recall@10' reached 0.04056 (best 0.04056), saving model to '/content/.checkpoints/epoch=1-step=96.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 96: 'recall@10' reached 0.04056 (best 0.04056), saving model to '/content/.checkpoints/epoch=1-step=96.ckpt' as top 1


k             10
recall  0.040563



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 2, global step 144: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 144: 'recall@10' was not in top 1


k             10
recall  0.038411



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 3, global step 192: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 192: 'recall@10' was not in top 1


k             10
recall  0.037417



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 4, global step 240: 'recall@10' reached 0.04387 (best 0.04387), saving model to '/content/.checkpoints/epoch=4-step=240.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 240: 'recall@10' reached 0.04387 (best 0.04387), saving model to '/content/.checkpoints/epoch=4-step=240.ckpt' as top 1


k             10
recall  0.043874



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 5, global step 288: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 5, global step 288: 'recall@10' was not in top 1


k             10
recall  0.041887



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 6, global step 336: 'recall@10' reached 0.04421 (best 0.04421), saving model to '/content/.checkpoints/epoch=6-step=336.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 6, global step 336: 'recall@10' reached 0.04421 (best 0.04421), saving model to '/content/.checkpoints/epoch=6-step=336.ckpt' as top 1


k             10
recall  0.044205



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 7, global step 384: 'recall@10' reached 0.04934 (best 0.04934), saving model to '/content/.checkpoints/epoch=7-step=384.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 7, global step 384: 'recall@10' reached 0.04934 (best 0.04934), saving model to '/content/.checkpoints/epoch=7-step=384.ckpt' as top 1


k             10
recall  0.049338



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 8, global step 432: 'recall@10' reached 0.05033 (best 0.05033), saving model to '/content/.checkpoints/epoch=8-step=432.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 8, global step 432: 'recall@10' reached 0.05033 (best 0.05033), saving model to '/content/.checkpoints/epoch=8-step=432.ckpt' as top 1


k             10
recall  0.050331



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 9, global step 480: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 9, global step 480: 'recall@10' was not in top 1


k             10
recall  0.046026



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 10, global step 528: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 10, global step 528: 'recall@10' was not in top 1


k             10
recall  0.045364



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 11, global step 576: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 11, global step 576: 'recall@10' was not in top 1


k            10
recall  0.04702



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 12, global step 624: 'recall@10' reached 0.05298 (best 0.05298), saving model to '/content/.checkpoints/epoch=12-step=624.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 12, global step 624: 'recall@10' reached 0.05298 (best 0.05298), saving model to '/content/.checkpoints/epoch=12-step=624.ckpt' as top 1


k            10
recall  0.05298



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 13, global step 672: 'recall@10' reached 0.05546 (best 0.05546), saving model to '/content/.checkpoints/epoch=13-step=672.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 13, global step 672: 'recall@10' reached 0.05546 (best 0.05546), saving model to '/content/.checkpoints/epoch=13-step=672.ckpt' as top 1


k             10
recall  0.055464



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 14, global step 720: 'recall@10' reached 0.06341 (best 0.06341), saving model to '/content/.checkpoints/epoch=14-step=720.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 14, global step 720: 'recall@10' reached 0.06341 (best 0.06341), saving model to '/content/.checkpoints/epoch=14-step=720.ckpt' as top 1


k             10
recall  0.063411



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 15, global step 768: 'recall@10' reached 0.08129 (best 0.08129), saving model to '/content/.checkpoints/epoch=15-step=768.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 15, global step 768: 'recall@10' reached 0.08129 (best 0.08129), saving model to '/content/.checkpoints/epoch=15-step=768.ckpt' as top 1


k             10
recall  0.081291



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 16, global step 816: 'recall@10' reached 0.09371 (best 0.09371), saving model to '/content/.checkpoints/epoch=16-step=816.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 16, global step 816: 'recall@10' reached 0.09371 (best 0.09371), saving model to '/content/.checkpoints/epoch=16-step=816.ckpt' as top 1


k             10
recall  0.093709



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 17, global step 864: 'recall@10' reached 0.10960 (best 0.10960), saving model to '/content/.checkpoints/epoch=17-step=864.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 17, global step 864: 'recall@10' reached 0.10960 (best 0.10960), saving model to '/content/.checkpoints/epoch=17-step=864.ckpt' as top 1


k             10
recall  0.109603



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 18, global step 912: 'recall@10' reached 0.11060 (best 0.11060), saving model to '/content/.checkpoints/epoch=18-step=912.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 18, global step 912: 'recall@10' reached 0.11060 (best 0.11060), saving model to '/content/.checkpoints/epoch=18-step=912.ckpt' as top 1


k             10
recall  0.110596



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 19, global step 960: 'recall@10' reached 0.11424 (best 0.11424), saving model to '/content/.checkpoints/epoch=19-step=960.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 19, global step 960: 'recall@10' reached 0.11424 (best 0.11424), saving model to '/content/.checkpoints/epoch=19-step=960.ckpt' as top 1


k             10
recall  0.114238



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 20, global step 1008: 'recall@10' reached 0.12235 (best 0.12235), saving model to '/content/.checkpoints/epoch=20-step=1008.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 20, global step 1008: 'recall@10' reached 0.12235 (best 0.12235), saving model to '/content/.checkpoints/epoch=20-step=1008.ckpt' as top 1


k             10
recall  0.122351



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 21, global step 1056: 'recall@10' reached 0.12781 (best 0.12781), saving model to '/content/.checkpoints/epoch=21-step=1056.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 21, global step 1056: 'recall@10' reached 0.12781 (best 0.12781), saving model to '/content/.checkpoints/epoch=21-step=1056.ckpt' as top 1


k             10
recall  0.127815



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 22, global step 1104: 'recall@10' reached 0.12930 (best 0.12930), saving model to '/content/.checkpoints/epoch=22-step=1104.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 22, global step 1104: 'recall@10' reached 0.12930 (best 0.12930), saving model to '/content/.checkpoints/epoch=22-step=1104.ckpt' as top 1


k             10
recall  0.129305



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 23, global step 1152: 'recall@10' reached 0.13825 (best 0.13825), saving model to '/content/.checkpoints/epoch=23-step=1152.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 23, global step 1152: 'recall@10' reached 0.13825 (best 0.13825), saving model to '/content/.checkpoints/epoch=23-step=1152.ckpt' as top 1


k             10
recall  0.138245



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 24, global step 1200: 'recall@10' reached 0.14023 (best 0.14023), saving model to '/content/.checkpoints/epoch=24-step=1200.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 24, global step 1200: 'recall@10' reached 0.14023 (best 0.14023), saving model to '/content/.checkpoints/epoch=24-step=1200.ckpt' as top 1


k             10
recall  0.140232



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 25, global step 1248: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 25, global step 1248: 'recall@10' was not in top 1


k             10
recall  0.128311



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 26, global step 1296: 'recall@10' reached 0.14156 (best 0.14156), saving model to '/content/.checkpoints/epoch=26-step=1296.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 26, global step 1296: 'recall@10' reached 0.14156 (best 0.14156), saving model to '/content/.checkpoints/epoch=26-step=1296.ckpt' as top 1


k             10
recall  0.141556



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 27, global step 1344: 'recall@10' reached 0.14570 (best 0.14570), saving model to '/content/.checkpoints/epoch=27-step=1344.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 27, global step 1344: 'recall@10' reached 0.14570 (best 0.14570), saving model to '/content/.checkpoints/epoch=27-step=1344.ckpt' as top 1


k             10
recall  0.145695



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 28, global step 1392: 'recall@10' reached 0.14917 (best 0.14917), saving model to '/content/.checkpoints/epoch=28-step=1392.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 28, global step 1392: 'recall@10' reached 0.14917 (best 0.14917), saving model to '/content/.checkpoints/epoch=28-step=1392.ckpt' as top 1


k             10
recall  0.149172



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 29, global step 1440: 'recall@10' reached 0.15116 (best 0.15116), saving model to '/content/.checkpoints/epoch=29-step=1440.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 29, global step 1440: 'recall@10' reached 0.15116 (best 0.15116), saving model to '/content/.checkpoints/epoch=29-step=1440.ckpt' as top 1


k             10
recall  0.151159



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 30, global step 1488: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 30, global step 1488: 'recall@10' was not in top 1


k             10
recall  0.150166



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 31, global step 1536: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 31, global step 1536: 'recall@10' was not in top 1


k            10
recall  0.14106



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 32, global step 1584: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 32, global step 1584: 'recall@10' was not in top 1


k             10
recall  0.146854



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 33, global step 1632: 'recall@10' reached 0.15662 (best 0.15662), saving model to '/content/.checkpoints/epoch=33-step=1632.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 33, global step 1632: 'recall@10' reached 0.15662 (best 0.15662), saving model to '/content/.checkpoints/epoch=33-step=1632.ckpt' as top 1


k             10
recall  0.156623



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 34, global step 1680: 'recall@10' reached 0.16126 (best 0.16126), saving model to '/content/.checkpoints/epoch=34-step=1680.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 34, global step 1680: 'recall@10' reached 0.16126 (best 0.16126), saving model to '/content/.checkpoints/epoch=34-step=1680.ckpt' as top 1


k             10
recall  0.161258



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 35, global step 1728: 'recall@10' reached 0.16225 (best 0.16225), saving model to '/content/.checkpoints/epoch=35-step=1728.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 35, global step 1728: 'recall@10' reached 0.16225 (best 0.16225), saving model to '/content/.checkpoints/epoch=35-step=1728.ckpt' as top 1


k             10
recall  0.162252



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 36, global step 1776: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 36, global step 1776: 'recall@10' was not in top 1


k             10
recall  0.161093



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 37, global step 1824: 'recall@10' reached 0.16358 (best 0.16358), saving model to '/content/.checkpoints/epoch=37-step=1824.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 37, global step 1824: 'recall@10' reached 0.16358 (best 0.16358), saving model to '/content/.checkpoints/epoch=37-step=1824.ckpt' as top 1


k             10
recall  0.163576



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 38, global step 1872: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 38, global step 1872: 'recall@10' was not in top 1


k             10
recall  0.160265



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 39, global step 1920: 'recall@10' reached 0.16606 (best 0.16606), saving model to '/content/.checkpoints/epoch=39-step=1920.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 39, global step 1920: 'recall@10' reached 0.16606 (best 0.16606), saving model to '/content/.checkpoints/epoch=39-step=1920.ckpt' as top 1


k            10
recall  0.16606



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 40, global step 1968: 'recall@10' reached 0.16639 (best 0.16639), saving model to '/content/.checkpoints/epoch=40-step=1968.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 40, global step 1968: 'recall@10' reached 0.16639 (best 0.16639), saving model to '/content/.checkpoints/epoch=40-step=1968.ckpt' as top 1


k             10
recall  0.166391



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 41, global step 2016: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 41, global step 2016: 'recall@10' was not in top 1


k            10
recall  0.16043



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 42, global step 2064: 'recall@10' reached 0.16705 (best 0.16705), saving model to '/content/.checkpoints/epoch=42-step=2064.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 42, global step 2064: 'recall@10' reached 0.16705 (best 0.16705), saving model to '/content/.checkpoints/epoch=42-step=2064.ckpt' as top 1


k             10
recall  0.167053



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 43, global step 2112: 'recall@10' reached 0.17500 (best 0.17500), saving model to '/content/.checkpoints/epoch=43-step=2112.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 43, global step 2112: 'recall@10' reached 0.17500 (best 0.17500), saving model to '/content/.checkpoints/epoch=43-step=2112.ckpt' as top 1


k          10
recall  0.175



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 44, global step 2160: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 44, global step 2160: 'recall@10' was not in top 1


k             10
recall  0.164238



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 45, global step 2208: 'recall@10' reached 0.17517 (best 0.17517), saving model to '/content/.checkpoints/epoch=45-step=2208.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 45, global step 2208: 'recall@10' reached 0.17517 (best 0.17517), saving model to '/content/.checkpoints/epoch=45-step=2208.ckpt' as top 1


k             10
recall  0.175166



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 46, global step 2256: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 46, global step 2256: 'recall@10' was not in top 1


k             10
recall  0.170199



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 47, global step 2304: 'recall@10' reached 0.17748 (best 0.17748), saving model to '/content/.checkpoints/epoch=47-step=2304.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 47, global step 2304: 'recall@10' reached 0.17748 (best 0.17748), saving model to '/content/.checkpoints/epoch=47-step=2304.ckpt' as top 1


k             10
recall  0.177483



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 48, global step 2352: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 48, global step 2352: 'recall@10' was not in top 1


k             10
recall  0.173344



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 49, global step 2400: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 49, global step 2400: 'recall@10' was not in top 1


k             10
recall  0.171192



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 50, global step 2448: 'recall@10' reached 0.17947 (best 0.17947), saving model to '/content/.checkpoints/epoch=50-step=2448.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 50, global step 2448: 'recall@10' reached 0.17947 (best 0.17947), saving model to '/content/.checkpoints/epoch=50-step=2448.ckpt' as top 1


k            10
recall  0.17947



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 51, global step 2496: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 51, global step 2496: 'recall@10' was not in top 1


k             10
recall  0.174834



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 52, global step 2544: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 52, global step 2544: 'recall@10' was not in top 1


k             10
recall  0.172848



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 53, global step 2592: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 53, global step 2592: 'recall@10' was not in top 1


k             10
recall  0.175497



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 54, global step 2640: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 54, global step 2640: 'recall@10' was not in top 1


k            10
recall  0.17202



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 55, global step 2688: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 55, global step 2688: 'recall@10' was not in top 1


k             10
recall  0.174007



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 56, global step 2736: 'recall@10' reached 0.18560 (best 0.18560), saving model to '/content/.checkpoints/epoch=56-step=2736.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 56, global step 2736: 'recall@10' reached 0.18560 (best 0.18560), saving model to '/content/.checkpoints/epoch=56-step=2736.ckpt' as top 1


k             10
recall  0.185596



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 57, global step 2784: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 57, global step 2784: 'recall@10' was not in top 1


k             10
recall  0.173013



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 58, global step 2832: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 58, global step 2832: 'recall@10' was not in top 1


k             10
recall  0.176987



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 59, global step 2880: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 59, global step 2880: 'recall@10' was not in top 1


k             10
recall  0.182781



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 60, global step 2928: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 60, global step 2928: 'recall@10' was not in top 1


k             10
recall  0.181954



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 61, global step 2976: 'recall@10' reached 0.18825 (best 0.18825), saving model to '/content/.checkpoints/epoch=61-step=2976.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 61, global step 2976: 'recall@10' reached 0.18825 (best 0.18825), saving model to '/content/.checkpoints/epoch=61-step=2976.ckpt' as top 1


k             10
recall  0.188245



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 62, global step 3024: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 62, global step 3024: 'recall@10' was not in top 1


k             10
recall  0.174172



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 63, global step 3072: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 63, global step 3072: 'recall@10' was not in top 1


k             10
recall  0.184934



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 64, global step 3120: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 64, global step 3120: 'recall@10' was not in top 1


k             10
recall  0.180464



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 65, global step 3168: 'recall@10' reached 0.18924 (best 0.18924), saving model to '/content/.checkpoints/epoch=65-step=3168.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 65, global step 3168: 'recall@10' reached 0.18924 (best 0.18924), saving model to '/content/.checkpoints/epoch=65-step=3168.ckpt' as top 1


k             10
recall  0.189238



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 66, global step 3216: 'recall@10' reached 0.19255 (best 0.19255), saving model to '/content/.checkpoints/epoch=66-step=3216.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 66, global step 3216: 'recall@10' reached 0.19255 (best 0.19255), saving model to '/content/.checkpoints/epoch=66-step=3216.ckpt' as top 1


k            10
recall  0.19255



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 67, global step 3264: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 67, global step 3264: 'recall@10' was not in top 1


k             10
recall  0.191225



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 68, global step 3312: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 68, global step 3312: 'recall@10' was not in top 1


k             10
recall  0.183444



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 69, global step 3360: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 69, global step 3360: 'recall@10' was not in top 1


k             10
recall  0.190728



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 70, global step 3408: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 70, global step 3408: 'recall@10' was not in top 1


k             10
recall  0.187086



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 71, global step 3456: 'recall@10' reached 0.19288 (best 0.19288), saving model to '/content/.checkpoints/epoch=71-step=3456.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 71, global step 3456: 'recall@10' reached 0.19288 (best 0.19288), saving model to '/content/.checkpoints/epoch=71-step=3456.ckpt' as top 1


k             10
recall  0.192881



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 72, global step 3504: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 72, global step 3504: 'recall@10' was not in top 1


k             10
recall  0.189073



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 73, global step 3552: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 73, global step 3552: 'recall@10' was not in top 1


k             10
recall  0.187748



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 74, global step 3600: 'recall@10' reached 0.19470 (best 0.19470), saving model to '/content/.checkpoints/epoch=74-step=3600.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 74, global step 3600: 'recall@10' reached 0.19470 (best 0.19470), saving model to '/content/.checkpoints/epoch=74-step=3600.ckpt' as top 1


k             10
recall  0.194702



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 75, global step 3648: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 75, global step 3648: 'recall@10' was not in top 1


k             10
recall  0.193046



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 76, global step 3696: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 76, global step 3696: 'recall@10' was not in top 1


k             10
recall  0.173013



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 77, global step 3744: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 77, global step 3744: 'recall@10' was not in top 1


k             10
recall  0.189073



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 78, global step 3792: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 78, global step 3792: 'recall@10' was not in top 1


k             10
recall  0.177649



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 79, global step 3840: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 79, global step 3840: 'recall@10' was not in top 1


k             10
recall  0.184934



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 80, global step 3888: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 80, global step 3888: 'recall@10' was not in top 1


k             10
recall  0.189238



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 81, global step 3936: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 81, global step 3936: 'recall@10' was not in top 1


k             10
recall  0.194702



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 82, global step 3984: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 82, global step 3984: 'recall@10' was not in top 1


k             10
recall  0.194536



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 83, global step 4032: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 83, global step 4032: 'recall@10' was not in top 1


k            10
recall  0.18394



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 84, global step 4080: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 84, global step 4080: 'recall@10' was not in top 1


k             10
recall  0.169536



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 85, global step 4128: 'recall@10' reached 0.19636 (best 0.19636), saving model to '/content/.checkpoints/epoch=85-step=4128.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 85, global step 4128: 'recall@10' reached 0.19636 (best 0.19636), saving model to '/content/.checkpoints/epoch=85-step=4128.ckpt' as top 1


k             10
recall  0.196358



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 86, global step 4176: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 86, global step 4176: 'recall@10' was not in top 1


k             10
recall  0.168377



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 87, global step 4224: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 87, global step 4224: 'recall@10' was not in top 1


k             10
recall  0.195364



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 88, global step 4272: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 88, global step 4272: 'recall@10' was not in top 1


k             10
recall  0.193543



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 89, global step 4320: 'recall@10' reached 0.19652 (best 0.19652), saving model to '/content/.checkpoints/epoch=89-step=4320.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 89, global step 4320: 'recall@10' reached 0.19652 (best 0.19652), saving model to '/content/.checkpoints/epoch=89-step=4320.ckpt' as top 1


k             10
recall  0.196523



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 90, global step 4368: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 90, global step 4368: 'recall@10' was not in top 1


k             10
recall  0.189238



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 91, global step 4416: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 91, global step 4416: 'recall@10' was not in top 1


k             10
recall  0.190894



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 92, global step 4464: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 92, global step 4464: 'recall@10' was not in top 1


k            10
recall  0.18245



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 93, global step 4512: 'recall@10' reached 0.19735 (best 0.19735), saving model to '/content/.checkpoints/epoch=93-step=4512.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 93, global step 4512: 'recall@10' reached 0.19735 (best 0.19735), saving model to '/content/.checkpoints/epoch=93-step=4512.ckpt' as top 1


k             10
recall  0.197351



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 94, global step 4560: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 94, global step 4560: 'recall@10' was not in top 1


k             10
recall  0.183444



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 95, global step 4608: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 95, global step 4608: 'recall@10' was not in top 1


k             10
recall  0.196192



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 96, global step 4656: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 96, global step 4656: 'recall@10' was not in top 1


k             10
recall  0.179967



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 97, global step 4704: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 97, global step 4704: 'recall@10' was not in top 1


k             10
recall  0.190563



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 98, global step 4752: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 98, global step 4752: 'recall@10' was not in top 1


k             10
recall  0.192881



Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 99, global step 4800: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 99, global step 4800: 'recall@10' was not in top 1
INFO: `Trainer.fit` stopped: `max_epochs=100` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


k             10
recall  0.183775



In [17]:
best_model = Bert4Rec.load_from_checkpoint(checkpoint_callback.best_model_path)

In [18]:
torch.save(best_model.state_dict(), "./RecommenderBERTModelv7.pth")

## 4. Инференс модели

Перезапуск ячеек был вызван сначала получением топ100 скоров для последующего отправления в бустинг. В этом скрипте остается топ10 для формирования итоговых рекомендаций

In [25]:
prediction_dataloader = DataLoader(
    dataset=Bert4RecPredictionDataset(
        sequential_test_dataset,
        max_sequence_length=MAX_SEQ_LEN),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True)

csv_logger = CSVLogger(save_dir=".logs/test", name="RecommenderBERTModelv7")

В качестве коллбека используем специальный для `PySpark DataFrame`

In [26]:
TOPK = [10]

postprocessors = [RemoveSeenItems(sequential_test_dataset)]

spark_prediction_callback = SparkPredictionCallback(spark_session=spark_session,
                                                    top_k=max(TOPK),
                                                    query_column="user_id",
                                                    item_column="item_id",
                                                    rating_column="score",
                                                    postprocessors=postprocessors)


Прокидываем в инференс тестовую выборку (последнее наблюдение датафрейма)

In [27]:
trainer = L.Trainer(callbacks=[spark_prediction_callback],
                    logger=csv_logger,
                    inference_mode=True)

trainer.predict(best_model, dataloaders=prediction_dataloader, return_predictions=False)

spark_res = spark_prediction_callback.get_result()

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [28]:
recs = tokenizer.query_and_item_id_encoder.inverse_transform(spark_res)



In [29]:
result = recs.toPandas()
result

Unnamed: 0,score,user_id,item_id
0,2.702602,3,94
1,2.635417,3,2947
2,7.894511,4,755
3,7.267799,4,566
4,7.451857,18,495
...,...,...,...
60395,-0.840430,6016,793
60396,5.580224,6023,1868
60397,5.453481,6023,3035
60398,8.361889,6027,470


In [30]:
# result.to_csv('bert_scores.csv', index=False, header=True)

## 5. Формируем итоговый submission

In [31]:
top_recs = (
    result.sort_values(['user_id', 'score'], ascending=[True, False])
    .groupby('user_id')
    .agg(top_items=('item_id', lambda x: ' '.join(map(str, x.head(10)))))
    .reset_index())

top_recs.columns = ['user_id', 'item_id']
top_recs['user_id'] = top_recs['user_id'].astype(int)
top_recs = top_recs.sort_values('user_id', ascending=True)
top_recs.to_csv('submit_The_boysV6.csv', index=False, header=True)

top_recs

Unnamed: 0,user_id,item_id
0,0,1422 2003 1101 2593 708 1287 2801 1323 1259 3411
1,1,232 1246 350 560 3101 2518 1813 1459 1686 2561
2,2,2774 234 1371 2643 1781 2311 3431 1560 2354 382
3,3,3365 2169 3390 94 1316 3562 2947 3272 3441 3153
4,4,1160 983 755 394 3002 2185 566 3667 3602 3087
...,...,...
6035,6035,3665 2800 674 2366 513 3646 1011 3668 3216 1340
6036,6036,1102 3142 2502 922 2664 1057 3059 1011 3646 2800
6037,6037,1102 2502 1134 1379 1439 2800 3059 3646 3475 12
6038,6038,1514 3278 75 2420 1762 3185 1100 700 776 84
