# Yambda Dataset: Used Benchmarks


# Library imports

In [1]:
# System packages needed to build SANSA from source (SuiteSparse dev files, compilers, git).
!apt-get update -qq
!apt-get install -y -qq --no-install-recommends libsuitesparse-dev build-essential git gfortran
# Install SANSA from a fresh clone, then remove the checkout.
# NOTE(review): the clone is not pinned to a tag/commit, so re-runs may pick up a different version.
!git clone https://github.com/glami/sansa.git
!pip install ./sansa && rm -rf sansa

W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Cloning into 'sansa'...
remote: Enumerating objects: 609, done.[K
remote: Counting objects: 100% (143/143), done.[K
remote: Compressing objects: 100% (85/85), done.[K
remote: Total 609 (delta 64), reused 84 (delta 49), pack-reused 466 (from 1)[K
Receiving objects: 100% (609/609), 1.80 MiB | 5.31 MiB/s, done.
Resolving deltas: 100% (358/358), done.
Processing ./sansa
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: sansa
  Building wheel for sansa (pyproject.toml) ... [?25l[?25hdone
  Created wheel for sansa: filename=sansa-1.1.0-py3-none-any.whl size=31025 sha256=0ea61877ebb7429789b76d2bc9ee9c38e91f83df61088ebf3e47d9948a512035
  Stored

In [2]:
from __future__ import annotations

import os

from abc import ABC, abstractmethod
from collections import defaultdict
import dataclasses
import functools
from functools import cached_property
import heapq
import json
from pathlib import Path
from typing import Any, Iterable, Literal

import numpy as np
import polars as pl
import scipy.sparse as sp
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from datasets import Dataset, DatasetDict, load_dataset

In [3]:
from sansa import SANSA, ICFGramianFactorizerConfig, SANSAConfig, UMRUnitLowerTriangleInverterConfig

# Utils & Helpers

In [4]:
class Constants:
    """Shared benchmark settings: time windows, cut-offs, metric names and reference scores."""

    # Time units, expressed in seconds.
    HOUR_SECONDS = 60 * 60
    DAY_SECONDS = 24 * HOUR_SECONDS

    # Split geometry: half-hour gap, one-day validation window, one-day test window.
    GAP_SIZE = HOUR_SECONDS // 2
    VAL_SIZE = DAY_SECONDS
    TEST_SIZE = DAY_SECONDS

    # The test window starts TEST_SIZE seconds before LAST_TIMESTAMP.
    LAST_TIMESTAMP = 26_000_000
    TEST_TIMESTAMP = LAST_TIMESTAMP - TEST_SIZE

    TRACK_LISTEN_THRESHOLD = 50

    NUM_RANKED_ITEMS = 100

    # "<metric>@<k>" names, grouped by metric: ndcg, recall, coverage at k = 10 and 100.
    METRICS = [f"{name}@{k}" for name in ("ndcg", "recall", "coverage") for k in (10, 100)]

    # Reference values printed next to the measured ones by the reporting cells.
    IDEAL_METRICS = {
        "most_pop": {
            "ndcg@10": 0.0046,
            "ndcg@100": 0.0097,
            "recall@10": 0.0083,
            "recall@100": 0.0222,
            "coverage@10": 0.0001,
            "coverage@100": 0.0006,
        },
        "decay_pop": {
            "ndcg@10": 0.0180,
            "ndcg@100": 0.0269,
            "recall@10": 0.0333,
            "recall@100": 0.0651,
            "coverage@10": 0.0001,
            "coverage@100": 0.0006,
        },
        "item_knn": {
            "ndcg@10": 0.0125,
            "ndcg@100": 0.0251,
            "recall@10": 0.0199,
            "recall@100": 0.0648,
            "coverage@10": 0.1050,
            "coverage@100": 0.3922,
        },
        "sansa": {
            "ndcg@10": 0.0068,
            "ndcg@100": 0.0203,
            "recall@10": 0.0105,
            "recall@100": 0.0616,
            "coverage@10": 0.0194,
            "coverage@100": 0.1601,
        },
    }

In [5]:
def rank_items(users: Embeddings, items: Embeddings, num_items: int, batch_size: int = 128) -> Ranked:
    """Score users against all items by dot product and keep the best ``num_items`` per user.

    Users are processed in batches of ``batch_size`` rows so that only one
    ``(batch, n_items)`` score matrix exists at a time.

    Args:
        users: user ids with one embedding row per user.
        items: item ids with one embedding row per item; must be on the same device as ``users``.
        num_items: length of the ranked list produced for every user.
        batch_size: number of users scored per step.

    Returns:
        ``Ranked`` holding, per user, the selected item ids and their scores.
    """
    assert users.device == items.device

    total_users = users.ids.shape[0]

    # Output buffers share device (and, for scores, dtype) with the user embeddings.
    scores = users.embeddings.new_empty((total_users, num_items))
    item_ids = users.embeddings.new_empty((total_users, num_items), dtype=torch.long)

    for lo in tqdm(range(0, total_users, batch_size), desc="Calc topk by batches"):
        hi = lo + batch_size  # slicing clips the final, possibly shorter, batch

        similarity = users.embeddings[lo:hi, :] @ items.embeddings.T
        best = similarity.topk(num_items, dim=-1)

        scores[lo:hi, :] = best.values
        # topk yields row positions in ``items``; translate them into real item ids.
        item_ids[lo:hi, :] = items.ids[best.indices]

    return Ranked(user_ids=users.ids, item_ids=item_ids, scores=scores, num_item_ids=items.ids.shape[0])


@dataclasses.dataclass
class Ranked:
    """Ranked item lists, one row per user, kept ordered by ascending ``user_ids``.

    When ``scores`` is omitted, every row receives the synthetic scores
    ``L, L-1, ..., 1`` (``L`` being the list length), so the given column
    order is treated as the ranking.
    """

    user_ids: torch.Tensor
    item_ids: torch.Tensor
    scores: torch.Tensor | None = None
    num_item_ids: int | None = None  # number of all items. Useful for coverage and etc.

    def __post_init__(self):
        if self.scores is None:
            list_len = self.item_ids.shape[1]
            countdown = torch.arange(list_len, 0, -1, device=self.item_ids.device, dtype=torch.float32)
            self.scores = countdown.expand((self.user_ids.shape[0], list_len))

        # Shape contract: 1-D users, 2-D scores/items with one row per user.
        assert self.user_ids.dim() == 1
        assert self.scores.dim() == 2
        assert self.scores.shape == self.item_ids.shape
        assert self.user_ids.shape[0] == self.scores.shape[0]

        assert self.user_ids.device == self.scores.device == self.item_ids.device

        assert torch.all(self.scores[:, :-1] >= self.scores[:, 1:]), "scores should be sorted"

        # Reorder rows only when the user ids are not already non-decreasing.
        users_in_order = torch.all(self.user_ids[:-1] <= self.user_ids[1:])
        if not users_in_order:
            order = torch.argsort(self.user_ids, descending=False)
            self.item_ids = self.item_ids[order, :]
            self.scores = self.scores[order, :]
            self.user_ids = self.user_ids[order]

    @property
    def device(self):
        """Device the tensors live on (taken from ``user_ids``)."""
        return self.user_ids.device


@dataclasses.dataclass
class Targets:
    """Ground-truth items per user, kept ordered by strictly ascending ``user_ids``.

    ``item_ids[i]`` is the 1-D tensor of relevant items for ``user_ids[i]``.
    """

    user_ids: torch.Tensor
    item_ids: list[torch.Tensor]

    def __post_init__(self):
        assert len(self.item_ids) > 0
        assert self.user_ids.dim() == 1
        assert self.user_ids.shape[0] == len(self.item_ids)
        assert all(x.dim() == 1 for x in self.item_ids), "all ids should be 1D"

        assert all(x.device == self.item_ids[0].device for x in self.item_ids), "all ids should be on the same device"
        assert self.user_ids.device == self.item_ids[0].device

        # Sort users (and their item lists in lockstep) unless already non-decreasing.
        if not torch.all(self.user_ids[:-1] <= self.user_ids[1:]):
            indexes = torch.argsort(self.user_ids, descending=False)
            self.item_ids = [self.item_ids[i] for i in indexes]
            self.user_ids = self.user_ids[indexes]

        assert torch.all(self.user_ids[:-1] < self.user_ids[1:]), "user_ids should be unique"

    @cached_property
    def lengths(self):
        """Number of target items per user, as a tensor on the targets' device."""
        return torch.tensor([ids.shape[0] for ids in self.item_ids], device=self.item_ids[0].device)

    def __len__(self):
        return len(self.item_ids)

    @property
    def device(self):
        """Device the tensors live on (taken from ``user_ids``)."""
        return self.user_ids.device

    @classmethod
    def from_sequential(cls, df: pl.LazyFrame | pl.DataFrame, device: torch.device | str) -> 'Targets':
        """Build targets from a frame with a ``uid`` column and a list-valued ``item_id`` column.

        A ``LazyFrame`` is collected first: columns can only be read off a
        materialised ``DataFrame``, so the lazy input allowed by the signature
        used to fail on the subscript.
        """
        if isinstance(df, pl.LazyFrame):
            df = df.collect()

        return cls(
            df.get_column("uid").to_torch().to(device),
            [torch.tensor(x, device=device) for x in df.get_column("item_id").to_list()],
        )



@dataclasses.dataclass
class Embeddings:
    """Ids paired row-by-row with a 2-D embedding matrix, kept ordered by strictly ascending id."""

    ids: torch.Tensor
    embeddings: torch.Tensor

    def __post_init__(self):
        assert self.ids.dim() == 1
        assert self.embeddings.dim() == 2
        assert self.ids.shape[0] == self.embeddings.shape[0]

        assert self.ids.device == self.embeddings.device

        # Rows are permuted only when the ids arrive out of order.
        already_sorted = torch.all(self.ids[:-1] <= self.ids[1:])
        if not already_sorted:
            order = torch.argsort(self.ids, descending=False)
            self.embeddings, self.ids = self.embeddings[order, :], self.ids[order]

        assert torch.all(self.ids[:-1] < self.ids[1:]), "ids should be unique"

    @property
    def device(self):
        """Device the tensors live on (taken from ``ids``)."""
        return self.ids.device

    def save(self, file_path: str):
        """Write ids and embeddings to ``file_path`` as a NumPy ``.npz`` archive."""
        np.savez(
            file_path,
            ids=self.ids.cpu().numpy(),
            embeddings=self.embeddings.cpu().numpy(),
        )

    @classmethod
    def load(cls, file_path: str, device: torch.device = torch.device('cpu')) -> 'Embeddings':
        """Read an archive written by ``save`` and place both tensors on ``device``."""
        with np.load(file_path) as archive:
            stored_ids = archive['ids']
            stored_embeddings = archive['embeddings']

        return cls(
            ids=torch.from_numpy(stored_ids).to(device),
            embeddings=torch.from_numpy(stored_embeddings).to(device),
        )


# Metrics

In [6]:
def cut_off_ranked(ranked: Ranked, targets: Targets) -> Ranked:
    """Keep only the rows of ``ranked`` whose user also appears in ``targets``.

    Raises:
        AssertionError: if ``ranked`` has no scores, or if the number of kept
            rows differs from the number of target users.
    """
    keep = torch.isin(ranked.user_ids, targets.user_ids)

    assert ranked.scores is not None

    filtered = Ranked(
        user_ids=ranked.user_ids[keep],
        item_ids=ranked.item_ids[keep, :],
        scores=ranked.scores[keep, :],
        num_item_ids=ranked.num_item_ids,
    )

    assert filtered.item_ids.shape[0] == len(targets), "Ranked doesn't contain all targets.user_ids"

    return filtered


class Metric(ABC):
    """Interface of a ranking metric evaluated at several cut-offs."""

    @abstractmethod
    def __call__(
        self, ranked: Ranked | None, targets: Targets | None, target_mask: torch.Tensor | None, ks: Iterable[int]
    ) -> dict[int, float]:
        """Return a ``{k: value}`` mapping with one entry per cut-off in ``ks``."""
        ...


class Recall(Metric):
    """Recall@k: fraction of a user's target items found in the top-k, averaged over users."""

    def __call__(
        self, ranked: Ranked | None, targets: Targets, target_mask: torch.Tensor, ks: Iterable[int]
    ) -> dict[int, float]:
        """Compute mean recall for every cut-off in ``ks``.

        Args:
            ranked: unused; accepted for interface compatibility.
            targets: provides the number of target items per user (``lengths``).
            target_mask: ``(users, list_len)`` mask, non-zero where the ranked item is a target.
            ks: cut-offs; each must satisfy ``0 < k <= list_len``.

        Returns:
            ``{k: mean recall@k}``.
        """
        # ``ks`` is iterated twice (validation + loop); materialise it so a
        # one-shot iterator is not exhausted by the assertion.
        ks = list(ks)
        assert all(0 < k <= target_mask.shape[1] for k in ks)

        # The denominator does not depend on k. Users without targets get +inf,
        # so their recall becomes 0 instead of 0/0.
        num_positives = targets.lengths.to(torch.float32)
        num_positives[num_positives == 0] = torch.inf

        values = {}
        for k in ks:
            hits = target_mask[:, :k].to(torch.float32).sum(dim=-1)
            values[k] = torch.mean(hits / num_positives).item()

        return values


class DCG(Metric):
    """DCG@k with binary relevance and ``1 / log2(position + 1)`` discounts, averaged over users."""

    def __call__(
        self, ranked: Ranked | None, targets: Targets | None, target_mask: torch.Tensor, ks: Iterable[int]
    ) -> dict[int, float]:
        """Compute mean DCG for every cut-off in ``ks``.

        Args:
            ranked: unused; accepted for interface compatibility.
            targets: unused; accepted for interface compatibility.
            target_mask: ``(users, list_len)`` relevance mask of the ranked lists.
            ks: cut-offs; each must satisfy ``0 < k <= list_len``.

        Returns:
            ``{k: mean DCG@k}``.
        """
        # ``ks`` is iterated twice (validation + loop); materialise it so a
        # one-shot iterator is not exhausted by the assertion.
        ks = list(ks)
        assert all(0 < k <= target_mask.shape[1] for k in ks)

        values = {}

        # Position p (1-based) is discounted by 1 / log2(p + 1).
        discounts = 1.0 / torch.log2(
            torch.arange(2, target_mask.shape[1] + 2, device=target_mask.device, dtype=torch.float32)
        )

        for k in ks:
            dcg_k = torch.sum(target_mask[:, :k] * discounts[:k], dim=1)
            values[k] = torch.mean(dcg_k).item()

        return values


class NDCG(Metric):
    """NDCG@k computed as (mean DCG@k) / (mean ideal DCG@k).

    Note that both DCG terms are averaged over users *before* the division,
    i.e. this is a ratio of means, not a mean of per-user ratios.
    """

    def __call__(
        self, ranked: Ranked | None, targets: Targets, target_mask: torch.Tensor, ks: Iterable[int]
    ) -> dict[int, float]:
        """Compute NDCG for every cut-off in ``ks``.

        Args:
            ranked: forwarded to ``DCG`` (which ignores it).
            targets: provides per-user target counts used to build the ideal ranking.
            target_mask: ``(users, list_len)`` relevance mask of the ranked lists.
            ks: cut-offs; each must satisfy ``0 < k <= list_len``.

        Returns:
            ``{k: NDCG@k}``; ``0.0`` where the ideal DCG is zero.
        """
        # ``ks`` is consumed three times below; materialise it so a one-shot
        # iterator does not come back empty on the later passes.
        ks = list(ks)

        actual_dcg = DCG()(ranked, targets, target_mask, ks)

        # Ideal list: a user's hits occupy the first min(#targets, list_len) positions.
        ideal_target_mask = (
            torch.arange(target_mask.shape[1], device=targets.device)[None, :] < targets.lengths[:, None]
        ).to(torch.float32)
        assert target_mask.shape == ideal_target_mask.shape

        ideal_dcg = DCG()(ranked, targets, ideal_target_mask, ks)

        ndcg_values = {k: (actual_dcg[k] / ideal_dcg[k] if ideal_dcg[k] != 0 else 0.0) for k in ks}

        return ndcg_values


class Coverage(Metric):
    """Coverage@k: share of the catalogue that appears in at least one user's top-k."""

    def __init__(self, cut_off_ranked: bool = False):
        # When True, only users present in ``targets`` contribute to the coverage.
        self.cut_off_ranked = cut_off_ranked

    def __call__(
        self, ranked: Ranked, targets: Targets | None, target_mask: torch.Tensor | None, ks: Iterable[int]
    ) -> dict[int, float]:
        """Compute coverage for every cut-off in ``ks``.

        Args:
            ranked: ranked lists; ``num_item_ids`` must be set (catalogue size).
            targets: required only when the instance was built with ``cut_off_ranked=True``.
            target_mask: unused; accepted for interface compatibility.
            ks: cut-offs; each must satisfy ``0 < k <= list_len``.

        Returns:
            ``{k: distinct items in the top-k lists / num_item_ids}``.
        """
        if self.cut_off_ranked:
            assert targets is not None
            ranked = cut_off_ranked(ranked, targets)

        # ``ks`` is iterated twice (validation + loop); materialise it so a
        # one-shot iterator is not exhausted by the assertion.
        ks = list(ks)
        assert all(0 < k <= ranked.item_ids.shape[1] for k in ks)

        assert ranked.num_item_ids is not None

        values = {}
        for k in ks:
            values[k] = ranked.item_ids[:, :k].flatten().unique().shape[0] / ranked.num_item_ids

        return values


# Metrics Utilities

In [7]:
# Metric name (the part before "@" in e.g. "ndcg@10") -> callable that computes it.
REGISTERED_METRIC_FN = dict(
    recall=Recall(),
    ndcg=NDCG(),
    coverage=Coverage(cut_off_ranked=False),
)


def calc_metrics(ranked: Ranked, targets: Targets, metrics: list[str]) -> dict[str, Any]:
    """Evaluate every requested ``"name@k"`` metric.

    Args:
        ranked: ranked lists per user.
        targets: ground-truth items per user.
        metrics: metric specs such as ``"recall@10"``; names must be keys of
            ``REGISTERED_METRIC_FN``.

    Returns:
        ``{metric name: {k: value}}``.
    """
    ks_by_metric = _parse_metrics(metrics)

    # The hit mask is shared by all metrics, so build it once.
    target_mask = create_target_mask(ranked, targets)

    return {
        name: REGISTERED_METRIC_FN[name](ranked, targets, target_mask, ks=ks)
        for name, ks in ks_by_metric.items()
    }


def _parse_metrics(metric_names: list[str]) -> dict[str, list[int]]:
    parsed_metrics = []

    for metric in metric_names:
        parts = metric.split('@')
        name = parts[0]

        assert len(parts) > 1, f"Invalid metric: {metric}, specify @k"

        value = int(parts[1])
        parsed_metrics.append((name, value))

    metrics = defaultdict(list)
    for m in parsed_metrics:
        metrics[m[0]].append(m[1])

    return metrics


def create_target_mask(ranked: Ranked, targets: Targets) -> torch.Tensor:
    """Mark, for each target user, which ranked items are among that user's targets.

    ``ranked`` is first restricted to the users in ``targets``; row ``i`` of the
    result is then matched against ``targets.item_ids[i]``.

    Returns:
        Float32 tensor shaped like the restricted ``item_ids``, 1.0 on hits and 0.0 elsewhere.
    """
    ranked = cut_off_ranked(ranked, targets)

    assert ranked.device == targets.device
    assert ranked.item_ids.shape[0] == len(targets)

    recommended = ranked.item_ids
    target_mask = recommended.new_zeros(recommended.shape, dtype=torch.float32)

    for row, wanted in enumerate(tqdm(targets.item_ids, desc="Making target mask")):
        is_hit = torch.isin(recommended[row], wanted)
        target_mask[row, is_hit] = 1.0

    return target_mask

# Dataset

In [8]:
class YambdaDataset:
    """Thin accessor for the ``yandex/yambda`` dataset on the Hugging Face Hub.

    An instance fixes the layout (``flat`` / ``sequential``) and the size
    (``50m`` / ``500m`` / ``5b``) used when loading interaction files.
    """

    INTERACTIONS = frozenset({
        "likes", "listens", "multi_event", "dislikes", "unlikes", "undislikes"
    })

    def __init__(
        self,
        dataset_type: Literal["flat", "sequential"] = "flat",
        dataset_size: Literal["50m", "500m", "5b"] = "50m"
    ):
        assert dataset_type in {"flat", "sequential"}
        assert dataset_size in {"50m", "500m", "5b"}
        self.dataset_type = dataset_type
        self.dataset_size = dataset_size

    def interaction(self, event_type: Literal[
        "likes", "listens", "multi_event", "dislikes", "unlikes", "undislikes"
    ]) -> Dataset:
        """Load one interaction file from ``<dataset_type>/<dataset_size>``."""
        assert event_type in YambdaDataset.INTERACTIONS
        subset_dir = f"{self.dataset_type}/{self.dataset_size}"
        return self._download(subset_dir, event_type)

    def audio_embeddings(self) -> Dataset:
        """Load ``embeddings.parquet`` (requested with an empty ``data_dir``)."""
        return self._download("", "embeddings")

    def album_item_mapping(self) -> Dataset:
        """Load ``album_item_mapping.parquet`` (requested with an empty ``data_dir``)."""
        return self._download("", "album_item_mapping")

    def artist_item_mapping(self) -> Dataset:
        """Load ``artist_item_mapping.parquet`` (requested with an empty ``data_dir``)."""
        return self._download("", "artist_item_mapping")

    @staticmethod
    def _download(data_dir: str, file: str) -> Dataset:
        """Load ``<file>.parquet`` from ``data_dir`` and return its ``train`` split."""
        loaded = load_dataset("yandex/yambda", data_dir=data_dir, data_files=f"{file}.parquet")
        # Returns DatasetDict; extracting the only split
        assert isinstance(loaded, DatasetDict)
        return loaded["train"]


In [9]:
# Load the "likes" interactions of the flat 50m variant through YambdaDataset.
yambda = YambdaDataset(dataset_type="flat", dataset_size="50m")
likes_ds = yambda.interaction("likes")  # load all likes
print(f"Loaded {len(likes_ds):,} records (likes)")

# Convert to polars: the splitting and model code below operates on polars frames.
likes_df = likes_ds.to_polars()
likes_df.head()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded 881,456 records (likes)


uid,timestamp,item_id,is_organic
u32,u32,u32,u8
100,44755,732449,1
100,1155860,6568592,0
100,1259125,5411243,1
100,1260005,7371186,0
100,1263935,4943655,0


In [10]:
def flat_split_train_val_test(
    df: pl.LazyFrame | pl.DataFrame,
    test_timestamp: int,
    val_size: int = 0,
    gap_size: int = Constants.GAP_SIZE,
    drop_non_train_items: bool = False,
    engine: str = "streaming",
) -> tuple[pl.LazyFrame, pl.LazyFrame | None, pl.LazyFrame]:
    """Split interactions by ``timestamp`` into train / optional validation / test.

    Windows (all on the ``timestamp`` column):
      * train:      ``timestamp < train_timestamp`` where ``train_timestamp`` is
        ``test_timestamp - gap_size - val_size`` minus one more ``gap_size`` when a
        validation window is requested;
      * validation: ``[test_timestamp - val_size - gap_size, test_timestamp - gap_size)``,
        built only when ``val_size != 0``;
      * test:       ``timestamp >= test_timestamp``.

    Validation and test keep only users that occur in train; with
    ``drop_non_train_items=True`` they also keep only items that occur in train.

    Args:
        df: interactions with ``uid``, ``item_id`` and ``timestamp`` columns
            (``.lazy()`` is called on it, so eager and lazy frames both work).
        test_timestamp: first timestamp of the test window.
        val_size: length of the validation window; ``0`` disables validation.
        gap_size: gap left before the validation and test windows.
        drop_non_train_items: whether to drop items unseen in train from validation/test.
        engine: forwarded to ``collect`` when the unique train users/items are materialised.

    Returns:
        ``(train, validation, test)`` as lazy frames; ``validation`` is ``None``
        when ``val_size == 0``.

    Raises:
        AssertionError: if ``gap_size`` or ``val_size`` is negative, or the
            resulting train boundary is not positive.
    """

    # Filter out new items
    def drop(df: pl.LazyFrame, unique_train_item_ids: pl.DataFrame) -> pl.LazyFrame:
        if not drop_non_train_items:
            return df

        return (
            df.with_columns(
                pl.col("item_id").is_in(unique_train_item_ids.get_column("item_id").implode()).alias("item_id_in_train")
            )
            .filter("item_id_in_train")
            .drop("item_id_in_train")
        )

    train_timestamp = test_timestamp - gap_size - val_size - (gap_size if val_size != 0 else 0)

    assert gap_size >= 0
    assert val_size >= 0
    assert train_timestamp > 0

    df_lazy = df.lazy()

    # Create train subset
    train = df_lazy.filter(pl.col("timestamp") < train_timestamp)

    # Warm-start setup
    # The unique train users/items are materialised here (eagerly) and reused by the filters below.
    unique_train_uids = train.select("uid").unique().collect(engine=engine)
    unique_train_item_ids = train.select("item_id").unique().collect(engine=engine)

    # Create validation subset
    validation = None
    if val_size != 0:
        validation = (
            df_lazy.filter(
                (pl.col("timestamp") >= test_timestamp - val_size - gap_size)
                & (pl.col("timestamp") < test_timestamp - gap_size)
            )
            #
            .with_columns(
                pl.col("uid").is_in(unique_train_uids.get_column("uid").implode()).alias("uid_in_train")
            )  # to prevent filter reordering
            .filter("uid_in_train")
            .drop("uid_in_train")
        )
        validation = drop(validation, unique_train_item_ids)

    # Create evaluation subset
    test = (
        df_lazy.filter(pl.col("timestamp") >= test_timestamp)
        #
        .with_columns(
            pl.col("uid").is_in(unique_train_uids.get_column("uid").implode()).alias("uid_in_train")
        )  # to prevent filter reordering
        .filter("uid_in_train")
        .drop("uid_in_train")
    )
    test = drop(test, unique_train_item_ids)

    return train, validation, test

In [11]:
# Temporal train/test split of the likes; the validation slot is discarded.
train_likes_df, _, test_likes_df = flat_split_train_val_test(
    df=likes_df,
    test_timestamp=Constants.TEST_TIMESTAMP,
    val_size=0,  # No validation is needed
    gap_size=Constants.GAP_SIZE
)

# The split returns lazy frames; materialise them once for all models below.
train_likes_df = train_likes_df.collect()
test_likes_df = test_likes_df.collect()

In [12]:
# Preview of the training split.
train_likes_df.head()

uid,timestamp,item_id,is_organic
u32,u32,u32,u8
100,44755,732449,1
100,1155860,6568592,0
100,1259125,5411243,1
100,1260005,7371186,0
100,1263935,4943655,0


In [13]:
# Preview of the test split.
test_likes_df.head()

uid,timestamp,item_id,is_organic
u32,u32,u32,u8
700,25939545,3383415,0
1400,25993865,7707460,1
1400,25993945,6005845,1
1400,25993945,7539311,0
1400,25993975,8908728,1


### MostPop & DecayPop

In [None]:
def training_popular(
    train_df: pl.DataFrame,
    hour: float,
    max_timestamp: float,
    device: str,
    decay: float = 0.9,
) -> Embeddings:
    """Compute a one-dimensional popularity score per item.

    * ``hour == 0`` (MostPop): the score is the number of training rows of the item.
    * otherwise (DecayPop): each row contributes ``tau ** (max_timestamp - timestamp)``
      with ``tau = decay ** (1 / (3600 * hour))``, i.e. a row that is ``hour`` hours
      older than ``max_timestamp`` weighs ``decay``.

    Args:
        train_df: interactions with ``item_id`` and ``timestamp`` columns.
        hour: decay period in hours; ``0`` selects plain counts.
        max_timestamp: reference timestamp the ages are measured from.
        device: torch device for the resulting tensors.
        decay: weight of an interaction that is ``hour`` hours old.

    Returns:
        ``Embeddings`` whose single embedding column holds the popularity score.
    """
    if hour == 0:  # MostPop option
        # ``pl.len()`` is the current spelling of the deprecated ``pl.count()``.
        embeddings = train_df.group_by("item_id").agg(pl.len().alias("item_embedding"))
    else:  # DecayPop option
        # Per-second decay factor.
        tau = decay ** (1 / Constants.DAY_SECONDS / (hour / 24))

        embeddings = (
            train_df.select(
                "item_id",
                (tau ** (max_timestamp - pl.col("timestamp"))).alias("value"),
            )
            .group_by("item_id")
            .agg(pl.col("value").sum().alias("item_embedding"))
        )

    item_ids = embeddings["item_id"].to_torch().to(device)
    # Add a trailing axis: Embeddings requires a 2-D matrix.
    item_embeddings = embeddings["item_embedding"].to_torch().to(device)[:, None]

    return Embeddings(item_ids, item_embeddings)


def evaluation_popular(
    train_df: pl.DataFrame,
    valid_df: pl.DataFrame,
    device: str,
    hour: float,
    metrics: list[str]
) -> dict[str, Any]:
    """Evaluate a popularity baseline: every train user gets the same top list.

    Args:
        train_df: training interactions (``uid``, ``item_id``, ``timestamp``).
        valid_df: held-out interactions providing the targets.
        device: torch device used for all tensors.
        hour: decay period forwarded to ``training_popular`` (``0`` = MostPop).
        metrics: ``"name@k"`` specs; the largest ``k`` sets the list length.

    Returns:
        ``{metric name: {k: value}}`` as produced by ``calc_metrics``.
    """
    num_ranked_items = max(int(name.split("@")[1]) for name in metrics)

    max_timestamp = train_df.select(pl.col("timestamp").max()).item()
    user_ids = train_df.select("uid").unique()["uid"].to_torch().to(device)

    targets = Targets.from_sequential(
        valid_df.group_by('uid', maintain_order=True).agg("item_id"),
        device,
    )

    item_embeddings = training_popular(
        train_df=train_df,
        hour=hour,
        max_timestamp=max_timestamp,
        device=device,
    )

    # Positions of the highest-scoring items, translated to item ids and
    # broadcast (without copying) to one identical row per user.
    top_positions = torch.topk(item_embeddings.embeddings, num_ranked_items, dim=0).indices
    top_item_ids = item_embeddings.ids[top_positions].ravel()

    ranked = Ranked(
        user_ids=user_ids,
        item_ids=top_item_ids.expand((user_ids.shape[0], num_ranked_items)),
        num_item_ids=item_embeddings.ids.shape[0],
    )

    return calc_metrics(ranked, targets, metrics)


def popularity(
    train_df: pl.DataFrame,
    valid_df: pl.DataFrame,
    device: str,
    hour: float,
    report_metrics: list[str],
) -> dict[str, Any]:
    """Entry point of the popularity baselines; delegates to ``evaluation_popular``."""
    return evaluation_popular(
        train_df=train_df,
        valid_df=valid_df,
        device=device,
        hour=hour,
        metrics=report_metrics,
    )

In [None]:
# MostPop: hour=0 selects the plain-count branch of training_popular.
most_pop_metrics = popularity(
    train_df=train_likes_df,
    valid_df=test_likes_df,
    device='cuda',
    hour=0,
    report_metrics=Constants.METRICS
)

In [None]:
# DecayPop: a non-zero hour selects the time-decayed branch of training_popular.
decay_pop_metrics = popularity(
    train_df=train_likes_df,
    valid_df=test_likes_df,
    device='cuda',
    hour=3,
    report_metrics=Constants.METRICS
)

In [None]:
# Print each measured MostPop metric (rounded) next to its reference value.
for metric_name, values in most_pop_metrics.items():
    for k, value in values.items():
        label = f'{metric_name}@{k}'
        print(label, round(value, 4), Constants.IDEAL_METRICS['most_pop'][label])

In [None]:
# Print each measured DecayPop metric (rounded) next to its reference value.
for metric_name, values in decay_pop_metrics.items():
    for k, value in values.items():
        label = f'{metric_name}@{k}'
        print(label, round(value, 4), Constants.IDEAL_METRICS['decay_pop'][label])

### ALS

In [None]:
def process_train_data(train_df: pl.DataFrame) -> tuple[pl.DataFrame, list[int], list[int]]:
    """Extract the distinct (uid, item_id) pairs plus the sorted distinct uids and item ids.

    Returns:
        ``(unique_pairs, unique_uids, unique_item_ids)``; both id lists are ascending.
    """

    def sorted_unique(column: str) -> list[int]:
        # Distinct values of one column, in ascending order.
        return train_df.select(column).unique().sort(column)[column].to_list()

    unique_pairs = train_df.select("uid", "item_id").unique()

    return unique_pairs, sorted_unique("uid"), sorted_unique("item_id")


def build_csr_matrix(pairs: pl.DataFrame, unique_uids: list[int], unique_item_ids: list[int]) -> sp.csr_matrix:
    """Build the user x item interaction matrix (float32 CSR) from (uid, item_id) pairs.

    Row/column positions are the positions of the ids in ``unique_uids`` /
    ``unique_item_ids``; every listed pair contributes a 1.
    """
    uid_to_idx = dict(zip(unique_uids, range(len(unique_uids))))
    item_id_to_idx = dict(zip(unique_item_ids, range(len(unique_item_ids))))

    # replace_strict raises on ids that are missing from the mapping.
    indexed = pairs.select(
        pl.col("uid").replace_strict(uid_to_idx, return_dtype=pl.UInt32),
        pl.col("item_id").replace_strict(item_id_to_idx, return_dtype=pl.UInt32),
    )

    rows = indexed["uid"].to_numpy()
    cols = indexed["item_id"].to_numpy()
    values = np.ones_like(rows, dtype=np.int32)

    interactions = sp.coo_matrix(
        (values, (rows, cols)), dtype=np.float32, shape=(len(unique_uids), len(unique_item_ids))
    )
    return interactions.tocsr()


def train_embbedings_with_als(
    user_item_interactions: sp.csr_matrix,
    regularization: float = 0.01,
    iterations: int = 100,
    random_state: int = 42,
    factors: int = 64,
) -> tuple[np.ndarray, np.ndarray]:
    """Fit ``implicit``'s GPU ALS on the interaction matrix.

    Returns:
        ``(user_factors, item_factors)`` converted with ``to_numpy()``.
    """
    # Imported here so the rest of the notebook runs without ``implicit`` installed.
    from implicit.gpu.als import AlternatingLeastSquares

    model = AlternatingLeastSquares(
        factors=factors,
        regularization=regularization,
        iterations=iterations,
        random_state=random_state,
        calculate_training_loss=False,
    )
    model.fit(user_item_interactions, show_progress=False)

    user_factors = model.user_factors.to_numpy()
    item_factors = model.item_factors.to_numpy()
    return user_factors, item_factors


def calc_embeddings_metrics(
    user_emb: np.ndarray,
    item_emb: np.ndarray,
    uid_tensor: torch.Tensor,
    item_id_tensor: torch.Tensor,
    targets: Targets,
    metrics: list[str],
    device: str,
) -> dict[str, dict[int, float]]:
    """Rank items for every user by embedding dot product and score the ranking.

    Args:
        user_emb: user factor matrix, row ``i`` belonging to ``uid_tensor[i]``.
        item_emb: item factor matrix, row ``j`` belonging to ``item_id_tensor[j]``.
        uid_tensor: user ids.
        item_id_tensor: item ids.
        targets: ground-truth items per user.
        metrics: ``"name@k"`` specs; the largest ``k`` sets the list length.
        device: torch device the factor matrices are moved to.

    Returns:
        ``{metric name: {k: value}}``.
    """
    list_length = max(int(spec.split("@")[1]) for spec in metrics)

    users = Embeddings(uid_tensor, torch.from_numpy(user_emb).to(device))
    items = Embeddings(item_id_tensor, torch.from_numpy(item_emb).to(device))

    return calc_metrics(rank_items(users, items, list_length), targets, metrics)


def _mean_metric_dicts(metrics_list: list[dict[str, dict[int, float]]]) -> dict[str, dict[int, float]]:
    """Average ``{metric: {k: value}}`` dictionaries element-wise over several runs."""
    if not metrics_list:
        return {}

    num_runs = len(metrics_list)
    return {
        name: {k: sum(run[name][k] for run in metrics_list) / num_runs for k in by_k}
        for name, by_k in metrics_list[0].items()
    }


def als(
    train_df: pl.DataFrame,
    valid_df: pl.DataFrame,
    random_seeds: list[int],
    device: str,
    report_metrics: list[str],
    hp: dict[str, Any] | None = None
) -> dict[str, dict[int, float]]:
    """Train ALS once per seed and report the metrics averaged over the seeds.

    Args:
        train_df: training interactions (``uid``, ``item_id``).
        valid_df: held-out interactions providing the targets.
        random_seeds: one ALS run is performed for every seed.
        device: torch device used for ranking and metrics.
        report_metrics: ``"name@k"`` specs.
        hp: optional hyper-parameters; recognised keys are ``regularization``
            (default 0.01), ``iterations`` (default 100) and ``factors`` (default 64).

    Returns:
        ``{metric name: {k: mean value over seeds}}``; empty when no seeds are given.
    """
    # ``None`` instead of a shared mutable ``{}`` default.
    hp = {} if hp is None else hp

    train_pairs, unique_uids, unique_item_ids = process_train_data(train_df)
    user_item_interactions = build_csr_matrix(train_pairs, unique_uids, unique_item_ids)

    targets = Targets.from_sequential(
        valid_df.group_by("uid").agg("item_id"),
        device,
    )

    # The id tensors do not depend on the seed; build them once.
    uid_tensor = torch.tensor(unique_uids, device=device)
    item_id_tensor = torch.tensor(unique_item_ids, device=device)

    metrics_list = []

    for seed in random_seeds:
        user_emb, item_emb = train_embbedings_with_als(
            user_item_interactions=user_item_interactions,
            regularization=hp.get("regularization", 0.01),
            iterations=hp.get("iterations", 100),
            random_state=seed,
            factors=hp.get("factors", 64),
        )

        metrics = calc_embeddings_metrics(
            user_emb,
            item_emb,
            uid_tensor,
            item_id_tensor,
            targets,
            report_metrics,
            device,
        )
        metrics_list.append(metrics)

    # The original referenced ``mean_dicts``, which is not defined in this notebook.
    return _mean_metric_dicts(metrics_list)

In [None]:
# ALS averaged over several seeds. Needs the `implicit` package with GPU support,
# which is imported inside train_embbedings_with_als.
als(
    train_df=train_likes_df,
    valid_df=test_likes_df,  # the parameter is called `valid_df`; `test_df=` raised a TypeError
    random_seeds=[42, 1337, 15978, 1234321],
    report_metrics=Constants.METRICS,
    device='cuda'
)

### ItemKNN

In [None]:
def eliminate_zeros(x: torch.Tensor, threshold: float = 1e-9) -> torch.Tensor:
    """Return a sparse COO copy of ``x`` keeping only stored values strictly above ``threshold``.

    Entries at or below the threshold (negative values included) are dropped;
    the shape is unchanged.
    """
    stored_values = x._values()
    stored_indices = x._indices()

    keep = stored_values > threshold

    return torch.sparse_coo_tensor(stored_indices[:, keep], stored_values[keep], x.shape)


def create_weighted_sparse_tensor(train: pl.DataFrame, tau: float) -> torch.Tensor:
    """Build a sparse user x item tensor of recency-weighted interaction sums.

    Each interaction weighs ``tau ** (user's latest timestamp - timestamp)``;
    weights are summed per (user, item) pair. With ``tau == 1.0`` every
    interaction weighs 1, so the entries are interaction counts.

    Row/column positions are the 0-based dense ranks of ``uid`` / ``item_id``.
    The result is passed through ``eliminate_zeros``, so only entries above
    its default threshold remain stored.

    Args:
        train: interactions with ``uid``, ``item_id`` and ``timestamp`` columns.
        tau: per-time-unit decay base.

    Returns:
        Sparse COO float tensor of shape ``(max uid index + 1, max item index + 1)``.
    """
    # uid -> 0-based dense rank (the "- 1" shifts polars' 1-based ranks).
    uid_mapping = (
        train.select("uid").unique().with_columns(pl.col("uid").rank(method="dense").alias("uid_idx") - 1)
    )

    # item_id -> 0-based dense rank.
    item_mapping = (
        train.select("item_id")
        .unique()
        .with_columns(pl.col("item_id").rank(method="dense").alias("item_idx") - 1)
    )

    # Age of every interaction relative to the same user's latest one, turned into a weight.
    processed = (
        train.with_columns(pl.max("timestamp").over("uid").alias("max_timestamp"))
        .with_columns((pl.col("max_timestamp") - pl.col("timestamp")).alias("delta"))
        .with_columns((tau ** pl.col("delta")).alias("weight"))
        .join(uid_mapping, on="uid", how="inner")
        .join(item_mapping, on="item_id", how="inner")
    )

    # One entry per (user, item): the sum of that pair's weights.
    coo_data = processed.group_by(["uid_idx", "item_idx"]).agg(pl.sum("weight").alias("total_weight"))

    # 2 x nnz index matrix: first row user positions, second row item positions.
    indices = torch.concat([coo_data["uid_idx"].to_torch()[None, :], coo_data["item_idx"].to_torch()[None, :]], dim=0)
    values = torch.tensor(coo_data["total_weight"].to_numpy(), dtype=torch.float)

    return eliminate_zeros(
        torch.sparse_coo_tensor(
            indices=indices, values=values, size=(uid_mapping["uid_idx"].max() + 1, item_mapping["item_idx"].max() + 1)
        )
    )


def sparse_normalize(sparse_tensor: torch.Tensor, dim=0, eps=1e-12):
    """L2-normalise the stored values of a sparse COO tensor, group by group.

    Stored entries are grouped by their index along ``dim`` (``indices[dim]``);
    every value is divided by ``sqrt(sum of squares in its group + eps)``.
    With ``dim=0`` on a matrix, entries that share a row index form a group.

    Returns:
        A new sparse COO tensor with the same indices and shape.
    """
    coalesced = sparse_tensor.coalesce()
    indices, values = coalesced.indices(), coalesced.values()

    # ``group_of[j]`` is the position of entry j's group in ``groups``.
    groups, group_of = torch.unique(indices[dim], return_inverse=True)

    squared_norms = torch.zeros_like(groups, dtype=torch.float32)
    squared_norms.scatter_add_(0, group_of, values.square())

    scale = torch.sqrt(squared_norms + eps)[group_of]

    return torch.sparse_coo_tensor(indices, values / scale, sparse_tensor.size())


def training(
    train: pl.DataFrame,
    hour: float,
    user_item: torch.Tensor,
    user_ids: torch.Tensor,
    device: str,
    decay: float = 0.9
) -> Embeddings:
    """Build ItemKNN user embeddings from recency-weighted interactions.

    The recency-weighted user x item tensor is multiplied by ``user_item.T``,
    densified and L2-normalised along the last dimension.

    Args:
        train: training interactions.
        hour: decay period in hours; ``0`` uses ``tau = 0.0``.
        user_item: sparse user x item tensor.
        user_ids: ids matching the rows of the result.
        device: torch device of the returned embeddings.
        decay: decay base used to derive ``tau``.
    """
    if hour == 0:
        tau = 0.0
    else:
        tau = decay ** (1 / 24 / 60 / 60 / (hour / 24))

    recency_weighted = create_weighted_sparse_tensor(train, tau)

    profiles = (recency_weighted @ user_item.T).to_dense()
    profiles = torch.nn.functional.normalize(profiles, dim=-1)

    return Embeddings(user_ids, profiles.to(device))


def item_knn(
    train_df: pl.DataFrame,
    valid_df: pl.DataFrame,
    device: str,
    hour: float,
    report_metrics: list[str]
) -> dict[str, Any]:
    """Train and evaluate the ItemKNN baseline.

    Args:
        train_df: training interactions (``uid``, ``item_id``, ``timestamp``).
        valid_df: held-out interactions providing the targets.
        device: torch device used for ranking and metrics.
        hour: decay period (hours) for the user profiles; see ``training``.
        report_metrics: ``"name@k"`` specs; the largest ``k`` sets the list length.

    Returns:
        ``{metric name: {k: value}}`` as produced by ``calc_metrics``.
    """
    num_ranked_items = max(int(name.split("@")[1]) for name in report_metrics)

    # Sorted unique ids: position i here corresponds to dense rank i used by
    # create_weighted_sparse_tensor for rows/columns.
    unique_user_ids = train_df.select("uid").unique().sort("uid")["uid"].to_torch().to(device)
    unique_item_ids = (
        train_df.select("item_id").unique().sort("item_id")["item_id"].to_torch().to(device)
    )

    # tau = 1.0 gives unweighted interaction counts.
    user_item = create_weighted_sparse_tensor(train_df, 1.0)
    item_embeddings = Embeddings(
        unique_item_ids,
        sparse_normalize(user_item.T.to(device), dim=-1),
    )

    targets = Targets.from_sequential(
        valid_df.group_by("uid").agg("item_id"),
        device,
    )

    user_embeddings = training(
        train=train_df,
        hour=hour,
        user_item=user_item,
        user_ids=unique_user_ids,
        device=device,
    )

    ranked = rank_items(
        users=user_embeddings,
        items=item_embeddings,
        num_items=num_ranked_items,
        batch_size=128,
    )

    return calc_metrics(ranked, targets, report_metrics)

In [None]:
# ItemKNN with a 1.5-hour decay period for the user profiles.
item_knn_metrics = item_knn(
    train_df=train_likes_df,
    valid_df=test_likes_df,
    device='cuda',
    hour=1.5,
    report_metrics=Constants.METRICS
)

In [None]:
# Print each measured ItemKNN metric (rounded) next to its reference value.
for metric_name, values in item_knn_metrics.items():
    for k, value in values.items():
        label = f'{metric_name}@{k}'
        print(label, round(value, 4), Constants.IDEAL_METRICS['item_knn'][label])

### EASE (SANSA)

In [17]:
def sansa(
    train_df: pl.DataFrame,
    valid_df: pl.DataFrame,
    device: str,
    report_metrics: list[str],
) -> dict[str, Any]:
    """Fit SANSA on the training interactions and evaluate it on the held-out ones.

    Returns:
        ``{metric name: {k: value}}`` as produced by ``evaluate_sansa``.
    """
    grouped_test, sparse_train, sparse_test = get_train_val_test_matrices(train_df, valid_df)

    model = get_sansa_model()
    model.fit(sparse_train)

    # Catalogue size for coverage: the number of distinct training items.
    train_item_ids = train_df.select("item_id").unique().sort("item_id")["item_id"].to_torch().to(device)

    return evaluate_sansa(
        num_items_for_metrics=train_item_ids.shape[0],
        model=model,
        device=device,
        report_metrics=report_metrics,
        grouped_test=grouped_test,
        sparse_train=sparse_train,
        sparse_test=sparse_test,
    )


def get_train_val_test_matrices(
    flat_train: pl.DataFrame,
    flat_test: pl.DataFrame
) -> tuple[pl.DataFrame, sp.csr_matrix, sp.csr_matrix]:
    """Index users/items of the training data and build the sparse train/test matrices.

    Args:
        flat_train: training interactions (``uid``, ``item_id``).
        flat_test: held-out interactions; its users must all occur in ``flat_train``.

    Returns:
        ``(grouped_test, sparse_train, sparse_test)``: the test interactions grouped
        per (re-indexed) user, and the user x item matrices of both splits.
    """
    all_uids = set(flat_train.get_column("uid").to_list())
    all_items = set(flat_train.get_column("item_id").to_list())

    # Create mapping to create sparse matrix
    # NOTE(review): positions follow set iteration order, not sorted id order.
    uid_to_idx = {uid: i for i, uid in enumerate(all_uids)}
    item_id_to_idx = {item_id: i for i, item_id in enumerate(all_items)}

    sparse_train, _ = get_sparse_data(flat_train, uid_to_idx, item_id_to_idx)
    sparse_test, grouped_test = get_sparse_data(flat_test, uid_to_idx, item_id_to_idx)

    return grouped_test, sparse_train, sparse_test


def get_sparse_data(
    df: pl.DataFrame, uid_to_idx: dict[int, int], item_id_to_idx: dict[int, int]
) -> tuple[sp.csr_matrix, pl.DataFrame]:
    """Re-index interactions and build their user x item matrix.

    Users are mapped strictly (an unknown ``uid`` raises); unknown items are
    mapped to the extra last column ``len(item_id_to_idx)``.

    Args:
        df: interactions with ``uid`` and ``item_id`` columns (eager frame).
        uid_to_idx: ``uid`` -> row position.
        item_id_to_idx: ``item_id`` -> column position.

    Returns:
        ``(user_item_data, grouped_df)``: a float32 CSR matrix of shape
        ``(n_users, n_items + 1)`` in which every interaction adds 1 to its
        cell, and the re-indexed interactions grouped per user
        (columns ``uid``, ``item_id``, ``actions``).
    """
    df = df.with_columns(
        pl.col("uid").replace_strict(uid_to_idx).alias("uid"),
        pl.col("item_id").replace_strict(item_id_to_idx, default=len(item_id_to_idx)).alias("item_id"),
        pl.lit(1).alias("action"),
    )

    grouped_df = df.group_by('uid', maintain_order=True).agg(
        [pl.col('item_id').alias('item_id'), pl.col('action').alias('actions')]
    )

    # The COO triplets are exactly the re-indexed rows, so read them straight
    # from the columns instead of re-flattening the grouped lists in Python.
    rows = df.get_column("uid").to_numpy()
    cols = df.get_column("item_id").to_numpy()
    values = np.ones(rows.shape[0], dtype=np.float32)

    user_item_data = sp.csr_matrix(
        (values, (rows, cols)),
        dtype=np.float32,
        shape=(len(uid_to_idx), len(item_id_to_idx) + 1),  # +1 for default unknown test items
    )

    return user_item_data, grouped_df


def get_sansa_model():
    """Create an unfitted SANSA model with the hyper-parameters used in this notebook."""
    return SANSA(
        SANSAConfig(
            l2=20.0,  # regularization strength
            weight_matrix_density=5e-5,  # desired density of weights
            # factorizer configuration
            gramian_factorizer_config=ICFGramianFactorizerConfig(
                factorization_shift_step=1e-3,
                factorization_shift_multiplier=2.0,
            ),
            # inverter configuration
            lower_triangle_inverter_config=UMRUnitLowerTriangleInverterConfig(
                scans=1,  # number of scans through all columns of the matrix
                finetune_steps=5,  # number of finetuning steps, targeting worst columns
            ),
        )
    )


def evaluate_sansa(
    num_items_for_metrics: int,
    model: SANSA,
    device: str,
    report_metrics: list[str],
    grouped_test: pl.DataFrame,
    sparse_train: sp.csr_matrix,
    sparse_test: sp.csr_matrix,
) -> dict[str, Any]:
    """Rank items with a fitted SANSA model and compute the requested metrics.

    Args:
        num_items_for_metrics: catalogue size used as the coverage denominator.
        model: fitted model; ``model.forward(sparse_train)`` must return a sparse
            matrix exposing ``indptr`` / ``indices`` / ``data``.
        device: torch device used for the metric computation.
        report_metrics: ``"name@k"`` specs.
        grouped_test: test interactions grouped per user (``uid``, ``item_id`` lists).
        sparse_train: user x item training matrix fed to the model.
        sparse_test: user x item test matrix; only used to find users with test data.

    Returns:
        ``{metric name: {k: value}}`` as produced by ``calc_metrics``.
    """
    test_targets = Targets.from_sequential(grouped_test, device=device)

    predictions = model.forward(sparse_train)
    num_users = predictions.shape[0]

    # 150 candidates per user as before, widened only if a larger cut-off is requested.
    num_items_k = max([150] + [int(name.split("@")[1]) for name in report_metrics])

    # 0 if there is no such item
    # NOTE(review): index 0 is also a real item position, so short rows are padded
    # with that item; kept unchanged to preserve the reported numbers.
    top_items_idx = np.full((num_users, num_items_k), 0, dtype=int)

    # -1 score if there is no such item
    top_items_score = np.full((num_users, num_items_k), -1, dtype=predictions.data.dtype)

    for row in tqdm(range(num_users)):
        start, end = predictions.indptr[row], predictions.indptr[row + 1]
        if start == end:
            continue

        row_scores = predictions.data[start:end]
        row_cols = predictions.indices[start:end]

        # A stable sort of the negated scores gives a descending order in which ties
        # keep their original order -- the same result as heapq.nlargest keyed on the
        # score, without building Python tuples per entry.
        best = np.argsort(-row_scores, kind="stable")[:num_items_k]

        top_items_idx[row, : best.shape[0]] = row_cols[best]
        top_items_score[row, : best.shape[0]] = row_scores[best]

    user_ids = torch.arange(num_users, dtype=torch.int32, device="cpu")
    scores = torch.as_tensor(top_items_score, dtype=torch.float32, device="cpu")
    scores_indices = torch.as_tensor(top_items_idx, dtype=torch.long, device="cpu")

    # Users with at least one non-zero test entry, computed on the sparse matrix:
    # the dense users x items array is never materialised.
    has_test_items = torch.from_numpy(np.asarray(sparse_test.astype(bool).sum(axis=1)).ravel() > 0)

    test_ranked = Ranked(
        user_ids=user_ids[has_test_items].to(device),
        scores=scores[has_test_items].to(device),
        item_ids=scores_indices[has_test_items].to(device),
        num_item_ids=num_items_for_metrics,
    )

    return calc_metrics(test_ranked, test_targets, report_metrics)

In [18]:
# SANSA: fit on the training likes, evaluate on the test likes.
sansa_metrics = sansa(
    train_df=train_likes_df,
    valid_df=test_likes_df,
    device='cuda',
    report_metrics=Constants.METRICS
)

100%|██████████| 8273/8273 [00:00<00:00, 139952.80it/s]
100%|██████████| 1301/1301 [00:00<00:00, 915568.71it/s]
100%|██████████| 8273/8273 [03:01<00:00, 45.57it/s]
Making target mask: 100%|██████████| 1301/1301 [00:00<00:00, 13027.62it/s]


In [19]:
# Print each measured SANSA metric (rounded) next to its reference value.
for metric_name, values in sansa_metrics.items():
    for k, value in values.items():
        label = f'{metric_name}@{k}'
        print(label, round(value, 4), Constants.IDEAL_METRICS['sansa'][label])

ndcg@10 0.0071 0.0068
ndcg@100 0.0236 0.0203
recall@10 0.0117 0.0105
recall@100 0.0648 0.0616
coverage@10 0.0161 0.0194
coverage@100 0.1165 0.1601
