In [1]:
# TODO: will remove
import sys
sys.path.append("../../")

In [2]:
import numpy as np
import os
import pandas as pd
import torch
import typing as tp
import warnings
from copy import deepcopy
from typing import Dict, List, Tuple
from pathlib import Path

from lightning_fabric import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from torch import Tensor
from torch.utils.data import DataLoader
from rectools import Columns
from rectools import ExternalIds
from rectools.dataset import Dataset, Interactions, IdMap
from rectools.metrics import NDCG, Recall, Serendipity, calc_metrics

from rectools.models.sasrec import (
    SASRecModel,
    SASRecDataPreparator,
    SequenceDataset, 
    SessionEncoderLightningModule,
    TransformerBasedSessionEncoder,
    PADDING_VALUE,
    IdEmbeddingsItemNet,
)


# Enable deterministic behaviour with CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
warnings.simplefilter("ignore", UserWarning)

# Load data

In [3]:
%%time
!wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_en.zip -O data_en.zip
!unzip -o data_en.zip
!rm data_en.zip

In [4]:
# Download dataset
DATA_PATH = Path("data_en")
items = pd.read_csv(DATA_PATH / 'items_en.csv', index_col=0)
interactions = (
    pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=["last_watch_dt"])
    .rename(columns={"last_watch_dt": Columns.Datetime})
)

print(interactions.shape)
interactions.head(2)

(5476251, 5)


Unnamed: 0,user_id,item_id,datetime,total_dur,watched_pct
0,176549,9506,2021-05-11,4250,72.0
1,699317,1659,2021-05-29,8317,100.0


In [5]:
interactions[Columns.User].nunique(), interactions[Columns.Item].nunique()

(962179, 15706)

In [6]:
# Process interactions
interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)
interactions = interactions[["user_id", "item_id", "datetime", "weight"]]
print(interactions.shape)
interactions.head(2)

(5476251, 4)


Unnamed: 0,user_id,item_id,datetime,weight
0,176549,9506,2021-05-11,3
1,699317,1659,2021-05-29,3


In [7]:
# Process item features
items = items.loc[items[Columns.Item].isin(interactions[Columns.Item])].copy()
items["genre"] = items["genres"].str.lower().str.replace(", ", ",", regex=False).str.split(",")
genre_feature = items[["item_id", "genre"]].explode("genre")
genre_feature.columns = ["id", "value"]
genre_feature["feature"] = "genre"
content_feature = items.reindex(columns=[Columns.Item, "content_type"])
content_feature.columns = ["id", "value"]
content_feature["feature"] = "content_type"
item_features = pd.concat((genre_feature, content_feature))

In [8]:
RANDOM_STATE=60
torch.use_deterministic_algorithms(True)
seed_everything(RANDOM_STATE, workers=True)

Seed set to 60


60

In [9]:
dataset_no_features = Dataset.construct(interactions)
dataset_no_features

Dataset(user_id_map=IdMap(external_ids=array([176549, 699317, 656683, ..., 805174, 648596, 697262])), item_id_map=IdMap(external_ids=array([ 9506,  1659,  7107, ..., 10064, 13019, 10542])), interactions=Interactions(df=         user_id  item_id  weight   datetime
0              0        0     3.0 2021-05-11
1              1        1     3.0 2021-05-29
2              2        2     1.0 2021-05-09
3              3        3     3.0 2021-07-05
4              4        0     3.0 2021-04-30
...          ...      ...     ...        ...
5476246   962177      208     1.0 2021-08-13
5476247   224686     2690     3.0 2021-04-13
5476248   962178       21     3.0 2021-08-20
5476249     7934     1725     3.0 2021-04-19
5476250   631989      157     3.0 2021-08-15

[5476251 rows x 4 columns]), user_features=None, item_features=None)

# **Custome Validation** (Leave-One-Out Strategy)

In [10]:
def create_recos_df_from_logits(
    logits: torch.Tensor, candidates: torch.Tensor, user_ids: tp.List[int], top_k: int
) -> pd.DataFrame:
    _, indexes = logits.topk(k=top_k)
    sorted_recos = candidates.gather(1, indexes).tolist()

    batch_recos = pd.DataFrame(
        {
            Columns.User: user_ids,
            Columns.Item: sorted_recos
        }
    ).explode(column=Columns.Item)
    batch_recos[Columns.Rank] = batch_recos.groupby(Columns.User, sort=False).cumcount() + 1

    return batch_recos


class SequenceDatasetValidation(SequenceDataset):
    
    def __init__(self, sessions: List[List[int]], weights: List[List[float]], user_ids: List[int]):
        super().__init__(sessions=sessions, weights=weights)
        self.user_ids = user_ids

    def __len__(self) -> int:
        return len(self.sessions)

    def __getitem__(self, index: int) -> Tuple[List[int], List[float], List[int]]:
        session, weights = super().__getitem__(index=index)
        user_ids = self.user_ids[index]
        return session, weights, user_ids

    @classmethod
    def from_interactions(
        cls,
        interactions: pd.DataFrame,
    ) -> "SequenceDatasetValidation":
        """
        Group interactions by user.
        Construct SequenceDataset from grouped interactions.

        Parameters
        ----------
        interactions: pd.DataFrame
            User-item interactions.
        """
        sessions = (
            interactions.sort_values(Columns.Datetime)
            .groupby(Columns.User, sort=True, as_index=False)[[Columns.Item, Columns.Weight]]
            .agg(list)
        )
        user_ids, sessions, weights = (
            sessions[Columns.User].to_list(),
            sessions[Columns.Item].to_list(),
            sessions[Columns.Weight].to_list(),
        )

        return cls(sessions=sessions, weights=weights, user_ids=user_ids)


class SASRecDataPreparatorValidate(SASRecDataPreparator):

    def __init__(
        self,
        session_max_len: int,
        batch_size: int,
        dataloader_num_workers: int,
        shuffle_train: bool = True,
        item_extra_tokens: tp.Sequence[tp.Hashable] = (PADDING_VALUE,),
        train_min_user_interactions: int = 2,
        n_negatives: tp.Optional[int] = None,
    ) -> None:
        super().__init__(
            session_max_len=session_max_len,
            batch_size=batch_size,
            dataloader_num_workers=dataloader_num_workers,
            shuffle_train=shuffle_train,
            item_extra_tokens=item_extra_tokens,
            train_min_user_interactions=train_min_user_interactions,
            n_negatives=n_negatives,
        )
        
        self.val_k_out: int
        self.num_users_interacted_with_item: pd.DataFrame

    def process_dataset_train(self, dataset: Dataset, val_users: ExternalIds, val_k_out: int) -> tp.Dict[str, Dataset]:
        self.val_k_out = val_k_out

        interactions = dataset.get_raw_interactions()

        interactions_train = interactions.copy()
        interactions_val = interactions.copy()

        ### Ctreating train dataset
        interactions_train[f"{Columns.Rank}_inverse"] = (
            interactions_train.sort_values(Columns.Datetime, ascending=False)
            .groupby(Columns.User)
            .cumcount() + 1
        )
        mask_train = ~(
            (interactions_train[Columns.User].isin(val_users))
            & (interactions_train[f"{Columns.Rank}_inverse"].isin(range(1, self.val_k_out + 1)))
        )
        interactions_train.drop(columns=f"{Columns.Rank}_inverse", inplace=True)

        interactions_train = interactions_train[mask_train]

        user_id_map = IdMap.from_values(interactions_train[Columns.User].values)
        item_id_map = IdMap.from_values(interactions_train[Columns.Item].values)
        item_features = None
        if dataset.item_features is not None:
            item_features = dataset.item_features.take(item_id_map.internal_ids)

        interactions_train = Interactions.from_raw(
            interactions_train, user_id_map,  item_id_map, keep_extra_cols=False
        )
        dataset_train = Dataset(user_id_map, dataset.item_id_map, interactions_train, item_features=item_features)
        processed_dataset_train = super().process_dataset_train(dataset_train)
        
        ### Ctreating validation dataset
        interactions_val = (
            interactions_val[
                (interactions_val[Columns.User].isin(val_users))
                & (interactions_val[Columns.User].isin(processed_dataset_train.user_id_map.to_external))
                & (interactions_val[Columns.Item].isin(processed_dataset_train.item_id_map.to_external))
            ]
        )
        interactions_val[f"{Columns.Rank}_inverse"] = (
            interactions_val.sort_values(Columns.Datetime, ascending=False)
            .groupby(Columns.User)
            .cumcount() + 1
        )
        mask_val = interactions_val[f"{Columns.Rank}_inverse"].isin(range(1, self.val_k_out + 1))
        interactions_val.drop(columns=f"{Columns.Rank}_inverse", inplace=True)

        interactions_val.loc[~mask_val, Columns.Weight] = 0
        interactions_val = interactions_val.sort_values(Columns.Datetime).groupby(Columns.User).tail(self.session_max_len + self.val_k_out)
        interactions_val = Interactions.from_raw(
            interactions_val, 
            processed_dataset_train.user_id_map, 
            processed_dataset_train.item_id_map, 
            keep_extra_cols=False,
        )
        processed_dataset_val = Dataset(
            processed_dataset_train.user_id_map, 
            processed_dataset_train.item_id_map, 
            interactions_val,
            item_features=processed_dataset_train.item_features,
        )

        num_users_interacted_with_item = (
            processed_dataset_train.interactions.df.groupby(Columns.Item, sort=False, as_index=False)[
                Columns.User
            ].nunique()
            .rename(columns={Columns.User: "popularity"})
        )
        num_users_interacted_with_item = num_users_interacted_with_item[
            ~num_users_interacted_with_item[Columns.Item].isin(range(self.n_item_extra_tokens))
        ].reset_index(drop=True)
        num_users_interacted_with_item["popularity"] /= num_users_interacted_with_item["popularity"].sum()
        
        self.num_users_interacted_with_item = num_users_interacted_with_item
        return {"train": processed_dataset_train, "val": processed_dataset_val}
    
    def get_dataloader_val(self, processed_dataset: tp.Optional[Dataset], val_neg_candidates: int) -> tp.Optional[DataLoader]:
        sequence_dataset = SequenceDatasetValidation.from_interactions(processed_dataset.interactions.df)
        train_dataloader = DataLoader(
            sequence_dataset,
            collate_fn=lambda batch: self._collate_fn_val(batch, val_neg_candidates),
            batch_size=self.batch_size,
            num_workers=self.dataloader_num_workers,
            shuffle=False,
        )
        return train_dataloader

    def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]], val_neg_candidates: int) -> Dict[str, Tensor]:
        batch_size = len(batch)
        x = np.zeros((batch_size, self.session_max_len))
        y = np.zeros((batch_size, self.val_k_out))
        yw = np.zeros((batch_size, self.val_k_out))
        user_ids = np.zeros((batch_size, self.val_k_out))
        # uniformly_neg_candidates = np.zeros((batch_size, val_neg_candidates))
        popularity_neg_candidates = np.zeros((batch_size, val_neg_candidates))

        for i, (ses, ses_weights, user_id) in enumerate(batch):
            input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
            
            target_idx = [idx for idx, weight in enumerate(ses_weights) if weight != 0]

            targets = list(map(ses.__getitem__, target_idx))
            targets_weights = list(map(ses_weights.__getitem__, target_idx))
            user = [user_id for _ in range(len(target_idx))]

            x[i, -len(input_session) :] = input_session[-self.session_max_len :]  # ses: [session_len] -> x[i]: [session_max_len]
            y[i, :] = targets # y[i]: [val_k_out]
            yw[i, :] = targets_weights  # yw[i]: [val_k_out]
            user_ids[i, :] = user  # u_ids[i]: [val_k_out]
            
            # # Uniformly Sampled
            # low_neg_item_id = self.n_item_extra_tokens
            # high_neg_item_id = self.item_id_map.size
            # neg = torch.randperm(high_neg_item_id - low_neg_item_id) + low_neg_item_id  # [self.item_id_map.size - self.n_item_extra_tokens]
            # uniformly_neg_candidates[i, :] = [n for n in neg.tolist() if n not in targets][: val_neg_candidates]

            # Popular Sampled
            index_neg = torch.multinomial(
                torch.tensor(self.num_users_interacted_with_item["popularity"].tolist()), 
                num_samples=val_neg_candidates + self.val_k_out,
                replacement=False
            )
            neg = self.num_users_interacted_with_item[Columns.Item].iloc[index_neg].tolist()
            popularity_neg_candidates[i, :] = [n for n in neg if n not in targets][: val_neg_candidates]

        batch_dict = {
            "x": torch.LongTensor(x), 
            "y": torch.LongTensor(y), 
            "yw": torch.FloatTensor(yw),
            "user_ids": torch.LongTensor(user_ids),
            # "uniformly_neg_candidates": torch.LongTensor(uniformly_neg_candidates),
            "popularity_neg_candidates": torch.LongTensor(popularity_neg_candidates),
        }
        # TODO: we are sampling negatives for paddings
        if self.n_negatives is not None:
            negatives = torch.randint(
                low=self.n_item_extra_tokens,
                high=self.item_id_map.size,
                size=(batch_size, self.val_k_out, self.n_negatives),
            )  # [batch_size, session_max_len, n_negatives]
            batch_dict["negatives"] = negatives
        return batch_dict


class SessionEncoderLightningModuleValidate(SessionEncoderLightningModule):

    def __init__(
        self,
        torch_model: TransformerBasedSessionEncoder,
        lr: float,
        gbce_t: float,
        n_item_extra_tokens: int,
        val_metrics: tp.Dict,
        val_dataset: Dataset,
        loss: str = "softmax",
        adam_betas: Tuple[float, float] = (0.9, 0.98),
        val_k_out: int = 1,
    ):
        super().__init__(
            torch_model=torch_model,
            lr=lr,
            gbce_t=gbce_t,
            n_item_extra_tokens=n_item_extra_tokens,
            loss=loss,
            adam_betas=adam_betas
        )        
        
        interactions = val_dataset.interactions.df

        self.val_dataset = val_dataset
        self.interactions = interactions[interactions[Columns.Weight] != 0]
        self.prev_interactions = interactions[interactions[Columns.Weight] == 0]
        self.catalog = val_dataset.item_id_map.to_internal
        
        self.val_metrics = val_metrics
        self.metric_max_k = max([metric.k for metric in val_metrics.values()])

        self.validation_step_recos_sampled = []
        self.validation_step_recos_unsampled = []
        self.val_metrics_result = []
        self.val_k_out = val_k_out

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        train_loss = super().training_step(batch, batch_idx)
        self.log("train/loss", train_loss, on_step=False, on_epoch=True, prog_bar=True)
        return train_loss

    def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor:
        """Validation step."""
        x, y, w = batch["x"], batch["y"], batch["yw"]
        sampled_neg_candidates = batch["popularity_neg_candidates"]

        user_ids = []
        for uid in batch["user_ids"].tolist():
            user_ids.extend(uid)

        last_logits_for_metrics = None
        if self.loss == "softmax":
            logits = self._get_full_catalog_logits(x)
            last_logits_for_metrics = logits[:, -1:, :]
            val_loss = self._calc_softmax_loss(last_logits_for_metrics, y, w)
        elif self.loss == "BCE":
            negatives = batch["negatives"]
            logits = self._get_pos_neg_logits(x, y, negatives)
            last_logits = logits[:, -1:, :]
            val_loss = self._calc_bce_loss(last_logits, y, w)
        elif self.loss == "gBCE":
            negatives = batch["negatives"]
            logits = self._get_pos_neg_logits(x, y, negatives)
            last_logits = logits[:, -1:, :]
            val_loss = self._calc_gbce_loss(last_logits, y, w, negatives)
        else:
            raise ValueError(f"loss {self.loss} is not supported")

        if last_logits_for_metrics is not None:
            last_logits_for_metrics = last_logits_for_metrics[:, -1, :].detach().cpu()

            # Popularity sampled metrics
            sampled_candidates = torch.concat([y, sampled_neg_candidates], dim=1).detach().cpu()

            logits_candidates = last_logits_for_metrics.gather(1, sampled_candidates)
            batch_recos_sampled = create_recos_df_from_logits(
                logits=logits_candidates, 
                candidates=sampled_candidates, 
                user_ids=user_ids, 
                top_k=logits_candidates.shape[1]
            )
            
            self.validation_step_recos_sampled.append(batch_recos_sampled)

            # Unsampled metrics
            unsampled_candidates = torch.arange(0, self.torch_model.item_model.n_items + 1)
            batch_recos_unsampled = create_recos_df_from_logits(
                logits=last_logits_for_metrics,
                candidates=unsampled_candidates.repeat(last_logits_for_metrics.shape[0], 1), 
                user_ids=user_ids,
                top_k=self.metric_max_k
            )

            self.validation_step_recos_unsampled.append(batch_recos_unsampled)

        else:
            warnings.warn("Can not caclulate RecSys metrics with `BCE` and `gBCE` losses")

        self.log("val/loss", val_loss, on_step=False, on_epoch=True, prog_bar=True)
        return val_loss
    
    def on_validation_epoch_end(self) -> None:
        if self.loss not in ["BCE", "gBCE"]:
            sampled_recos = pd.concat(self.validation_step_recos_sampled, ignore_index=True)
            unsampled_recos = pd.concat(self.validation_step_recos_unsampled, ignore_index=True)

            sampled_metrics = calc_metrics(self.val_metrics, sampled_recos, self.interactions, self.prev_interactions, self.catalog)
            unsampled_metrics = calc_metrics(self.val_metrics, unsampled_recos, self.interactions, self.prev_interactions, self.catalog)

            sampled_metrics = {f"{metric_name}_sampled": metric_value for metric_name, metric_value in sampled_metrics.items()}
            unsampled_metrics = {f"{metric_name}_unsampled": metric_value for metric_name, metric_value in unsampled_metrics.items()}

            metrics = {**sampled_metrics, **unsampled_metrics}

            self.log_dict(metrics)
            self.validation_step_recos_sampled.clear()


class SASRecModelValidateLeaveOneOut(SASRecModel):
    val_k_out: int = 1

    def _fit(self, dataset: Dataset, val_users: ExternalIds, val_metrics: tp.Dict, val_neg_candidates: int) -> None:
        processed_datasets = self.data_preparator.process_dataset_train(dataset, val_users, self.val_k_out)
        train_dataloader = self.data_preparator.get_dataloader_train(processed_datasets["train"])
        val_dataloader = self.data_preparator.get_dataloader_val(processed_datasets["val"], val_neg_candidates)

        torch_model = deepcopy(self._torch_model)  # TODO: check that it works
        torch_model.construct_item_net(processed_datasets["train"])

        n_item_extra_tokens = self.data_preparator.n_item_extra_tokens
        self.lightning_model = self.lightning_module_type(
            torch_model=torch_model,
            lr=self.lr,
            loss=self.loss,
            gbce_t=self.gbce_t,
            n_item_extra_tokens=n_item_extra_tokens,
            val_metrics=val_metrics,
            val_dataset=processed_datasets["val"],
            val_k_out=self.val_k_out,
        )
        
        self.fit_trainer = deepcopy(self._trainer)
        self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)


In [12]:
def get_log_dir(trainer: Trainer) -> Path:
    """
    Get logging directory.
    """
    path = trainer.logger.log_dir
    vesrion = int(path.split("version_")[-1])
    last_path = path.split("version_")[0] + f"version_{vesrion - 1}"
    return Path(last_path) / "metrics.csv"


def get_losses(epoch_metrics_df: pd.DataFrame) -> pd.DataFrame:
    train_loss_df = epoch_metrics_df[["epoch", "train/loss"]].dropna()
    val_loss_df = epoch_metrics_df[["epoch", "val/loss"]].dropna()
    loss_df = pd.merge(train_loss_df, val_loss_df, how="inner", on="epoch")
    return loss_df.reset_index(drop=True)


def get_val_metrics(epoch_metrics_df: pd.DataFrame) -> pd.DataFrame:
    metrics_df = epoch_metrics_df.drop(columns=["train/loss", "val/loss"]).dropna()
    return metrics_df.reset_index(drop=True)


def get_log_values(trainer: Trainer) -> tp.Tuple[pd.DataFrame, pd.DataFrame]:
    log_path = get_log_dir(trainer)
    epoch_metrics_df = pd.read_csv(log_path)

    loss_df = get_losses(epoch_metrics_df)
    val_metrics = get_val_metrics(epoch_metrics_df)
    return loss_df, val_metrics

In [13]:
N_VAL_USERS = 2048
VAL_NEG_CANDIDATES = 100

unique_users = interactions[Columns.User].unique()
VAL_USERS = unique_users[: N_VAL_USERS]
unique_users.shape, VAL_USERS.shape


((962179,), (2048,))

In [14]:
VAL_METRICS = {
    "NDCG@10": NDCG(k=10),
    "Recall@10": Recall(k=10),
    "Serendipity@10": Serendipity(k=10),
}

In [16]:
PATIENCE = 20
MAX_EPOCHS = 100
early_stopping = EarlyStopping(monitor="val/loss", patience=PATIENCE, min_delta=0.0)

In [17]:
trainer = Trainer(
    accelerator='gpu',
    devices=[0],
    min_epochs=PATIENCE, 
    max_epochs=MAX_EPOCHS, 
    deterministic=True,
    callbacks=[early_stopping],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
sasrec_non_default_model = SASRecModelValidateLeaveOneOut(
    n_factors=64,
    n_blocks=2,
    n_heads=2,
    dropout_rate=0.2,
    use_pos_emb=True,
    train_min_user_interaction=5,
    session_max_len=50,
    lr=1e-3,
    batch_size=128,
    # epochs=5,
    loss="softmax",
    verbose=1,
    deterministic=True,
    item_net_block_types=(IdEmbeddingsItemNet, ),  # Use only item ids in ItemNetBlock
    data_preparator_type=SASRecDataPreparatorValidate, 
    lightning_module_type=SessionEncoderLightningModuleValidate,
    trainer=trainer,
)

In [19]:
%%time
sasrec_non_default_model.fit(dataset_no_features, VAL_USERS, VAL_METRICS, VAL_NEG_CANDIDATES)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                           | Params
---------------------------------------------------------------
0 | torch_model | TransformerBasedSessionEncoder | 988 K 
---------------------------------------------------------------
988 K     Trainable params
0         Non-trainable params
988 K     Total params
3.952     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

CPU times: user 6h 26min 18s, sys: 1min 41s, total: 6h 28min
Wall time: 1h 30min 41s


<__main__.SASRecModelValidateLeaveOneOut at 0x7f1d780cd220>

In [20]:
loss_df, val_metrics_df = get_log_values(trainer)

In [21]:
loss_df

Unnamed: 0,epoch,train/loss,val/loss
0,0,16.38846,21.328363
1,1,15.723532,21.301912
2,2,15.568251,21.398142
3,3,15.498148,21.538115
4,4,15.45577,21.400196
5,5,15.430855,21.358656
6,6,15.408437,21.334433
7,7,15.392024,21.515078
8,8,15.381497,21.288002
9,9,15.367785,21.428801


In [22]:
val_metrics_df

Unnamed: 0,NDCG@10_sampled,NDCG@10_unsampled,Recall@10_sampled,Recall@10_unsampled,Serendipity@10_sampled,Serendipity@10_unsampled,epoch,step
0,0.003752,0.001894,0.031537,0.011827,1.1e-05,2e-06,0,2362
1,0.003227,0.003946,0.026938,0.014455,7.3e-05,2e-06,1,4725
2,0.003435,0.005737,0.030223,0.015769,7.6e-05,2e-06,2,7088
3,0.002978,0.007539,0.026281,0.017083,8e-06,3e-06,3,9451
4,0.002953,0.009345,0.02431,0.01774,7e-06,3e-06,4,11814
5,0.002747,0.011063,0.024967,0.019054,7e-06,3e-06,5,14177
6,0.003015,0.013068,0.025624,0.019711,8e-06,3e-06,6,16540
7,0.00279,0.014837,0.022996,0.019711,7e-06,3e-06,7,18903
8,0.002847,0.016229,0.027595,0.019711,8e-06,3e-06,8,21266
9,0.002923,0.018152,0.024967,0.019711,7e-06,3e-06,9,23629
