# Recommender Systems YSDA Course
## Practice №3 Ranking and diversity

Similarly to practice №2, in this notebook we are going to implement recommender system components: late-stage rankers and diversity-control modules (rerankers).

In [3]:
import duckdb
import numpy as np
import polars as pl
from sklearn.model_selection import train_test_split
from catboost import CatBoostRanker, CatBoostClassifier, Pool

In [4]:
from grocery.recommender.primitives import Candidate

In [5]:
# Sentinel written into numeric feature columns in place of NULL/NaN by the feature SQL.
FILL_VALUE = -999999999.0
# Directory holding train.parquet and the derived feature file (relative to the notebook).
DATA_DIR = '../data/lavka'

In [None]:
# Build the per-action training table.
# Every view / click / cart-update row of train.parquet gets a row number (`idx`) and is
# joined with user, item and user x item aggregates. Each aggregate is computed over a
# window ordered by day that ends one day before the action's own day, so a row never
# sees same-day (or future) actions. NULL/NaN feature values are replaced by FILL_VALUE
# and the result is written to train_with_features.parquet.
# NOTE(review): `cart_update_timestamps` collects the timestamps of *all* action types in
#   the window, not only AT_CartUpdate rows, so `u2i_mean_time_between_cartupdates` is the
#   mean day gap between any consecutive interactions - confirm this is intended.
# NOTE(review): `u2i_days_since_last_cart_update` is computed but not selected at the end.
# NOTE(review): `a.*` already contains source_type / product_category, so the explicit
#   casts yield duplicate column names - verify which copy lands in the parquet file.
duckdb.sql(
f"""
WITH
    actions as (
        SELECT
        user_id,
        source_type,
        product_category,
        product_id as item_id,
        request_id,
        action_type,
        position_in_request,
        make_timestamp(timestamp * 1000000) as timestamp,
        date_trunc('day', make_timestamp(timestamp * 1000000)) as day,
        row_number() over () as idx
    FROM '{DATA_DIR}/train.parquet'
    WHERE action_type IN ('AT_CartUpdate', 'AT_View', 'AT_Click')
    ),
    user_windows as (
        SELECT
            idx,
            sum(case when action_type = 'AT_CartUpdate' then 1 else 0 end) over rolling_w as cart_updates,
            sum(case when action_type = 'AT_View' then 1 else 0 end) over rolling_w as views,
            sum(case when action_type = 'AT_Click' then 1 else 0 end) over rolling_w as clicks,
            count(distinct (case when action_type = 'AT_CartUpdate' then request_id end)) over rolling_w as num_requests_with_cart_update,
            count(distinct request_id) over rolling_w as total_requests,
        FROM actions
        WINDOW rolling_w as (
            PARTITION BY "user_id"
            ORDER BY "day" ASC
            RANGE BETWEEN UNBOUNDED PRECEDING AND INTERVAL 1 DAYS PRECEDING
        )
    ),
    user_features as (
        SELECT
            idx,
            cart_updates / IF(views = 0, 1, views) as user_cart_update_turn_rate,
            clicks / IF(views = 0, 1, views) as user_click_turn_rate,
            num_requests_with_cart_update / IF(total_requests = 0, 1, total_requests) as user_conversion_rate,
        FROM user_windows
    ),
    item_windows as (
        SELECT
            idx,
            sum(case when action_type = 'AT_CartUpdate' then 1 else 0 end) over rolling_w as cart_updates,
            sum(case when action_type = 'AT_View' then 1 else 0 end) over rolling_w as views,
            sum(case when action_type = 'AT_Click' then 1 else 0 end) over rolling_w as clicks,
            count(distinct (case when action_type = 'AT_CartUpdate' then request_id end)) over rolling_w as num_requests_with_cart_update,
            count(distinct request_id) over rolling_w as total_requests,
        FROM actions
        WINDOW rolling_w as (
            PARTITION BY "item_id"
            ORDER BY "day" ASC
            RANGE BETWEEN UNBOUNDED PRECEDING AND INTERVAL 1 DAYS PRECEDING
        )
    ),
    item_features as (
        SELECT
            idx,
            cart_updates / IF(views = 0, 1, views) as item_cart_update_turn_rate,
            clicks / IF(views = 0, 1, views) as item_click_turn_rate,
            num_requests_with_cart_update / IF(total_requests = 0, 1, total_requests) as item_conversion_rate,
        FROM item_windows
    ),
    user2item_windows as (
        SELECT
            idx,
            timestamp,
            max(case when action_type = 'AT_CartUpdate' then timestamp end) over rolling_w as last_cart_update_timestamp,
            sum(case when action_type = 'AT_CartUpdate' then 1 else 0 end) over rolling_w as cart_updates,
            array_agg(timestamp) over rolling_w as cart_update_timestamps,
        FROM actions
        WINDOW rolling_w as (
            PARTITION BY "user_id", "item_id"
            ORDER BY "day" ASC
            RANGE BETWEEN UNBOUNDED PRECEDING AND INTERVAL 1 DAYS PRECEDING
        )
    ),
    user2item_features as (
        SELECT
            idx,
            cart_updates as u2i_cart_updates,
            datepart('day', timestamp - last_cart_update_timestamp) as u2i_days_since_last_cart_update,
            list_avg(list_transform(list_zip(
                array_pop_back(list_sort(cart_update_timestamps)),
                array_pop_front(list_sort(cart_update_timestamps))
            ), x -> datepart('day', x[2] - x[1]))) as u2i_mean_time_between_cartupdates
        FROM user2item_windows
    )
    SELECT
        a.*,
        cast(a.source_type as string) as source_type,
        cast(a.product_category as string) as product_category,
        coalesce(if(isnan(uf.user_cart_update_turn_rate), null, uf.user_cart_update_turn_rate), {FILL_VALUE}) as user_cart_update_turn_rate,
        coalesce(if(isnan(uf.user_click_turn_rate), null, uf.user_click_turn_rate), {FILL_VALUE}) as user_click_turn_rate,
        coalesce(if(isnan(uf.user_conversion_rate), null, uf.user_conversion_rate), {FILL_VALUE}) as user_conversion_rate,
        coalesce(if(isnan(if.item_cart_update_turn_rate), null, if.item_cart_update_turn_rate), {FILL_VALUE}) as item_cart_update_turn_rate,
        coalesce(if(isnan(if.item_click_turn_rate), null, if.item_click_turn_rate), {FILL_VALUE}) as item_click_turn_rate,
        coalesce(if(isnan(if.item_conversion_rate), null, if.item_conversion_rate), {FILL_VALUE}) as item_conversion_rate,
        coalesce(if(isnan(u2if.u2i_cart_updates), null, u2if.u2i_cart_updates), {FILL_VALUE}) as u2i_cart_updates,
        coalesce(if(isnan(u2if.u2i_mean_time_between_cartupdates), null, u2if.u2i_mean_time_between_cartupdates), {FILL_VALUE}) as u2i_mean_time_between_cartupdates,
    FROM actions a
    LEFT JOIN user_features uf
    ON a.idx = uf.idx
    LEFT JOIN item_features if
    ON a.idx = if.idx
    LEFT JOIN user2item_features u2if
    ON a.idx = u2if.idx
"""
).to_parquet(f"{DATA_DIR}/train_with_features.parquet")

In [6]:
# Load the per-action feature table from train_with_features.parquet.
dataset = pl.read_parquet(f"{DATA_DIR}/train_with_features.parquet")

In [7]:
# Numeric model inputs (user, item and user x item aggregates).
NUM_FEATURE_COLUMNS = [
    "user_cart_update_turn_rate",
    "user_click_turn_rate",
    "user_conversion_rate",
    "item_cart_update_turn_rate",
    "item_click_turn_rate",
    "item_conversion_rate",
    "u2i_cart_updates",
    "u2i_mean_time_between_cartupdates",
]

# Categorical model inputs; source_type is currently disabled.
CAT_FEATURE_COLUMNS = [
    # "source_type",
    "product_category",
]

# Column order used when building CatBoost pools: numeric features first, categorical last.
FEATURE_COLUMNS = NUM_FEATURE_COLUMNS + CAT_FEATURE_COLUMNS

In [None]:
import functools

# Summary statistics over rows where every numeric feature is populated,
# i.e. none of them carries the FILL_VALUE sentinel.
is_populated = [pl.col(name) != FILL_VALUE for name in NUM_FEATURE_COLUMNS]
conditions = functools.reduce(lambda acc, flag: acc & flag, is_populated)

dataset.filter(conditions).select(NUM_FEATURE_COLUMNS).describe()

In [8]:
def build_target(action_type: str):
    """Map an action type to a binary relevance label.

    Returns 1.0 for "AT_CartUpdate", 0.0 for "AT_View" and None for any
    other action type.
    """
    labels = {"AT_CartUpdate": 1.0, "AT_View": 0.0}
    return labels.get(action_type)

# Binary-target ranking dataset: one group per request; an item's label is 1.0 if it was
# added to the cart anywhere in that request, otherwise 0.0.
cbm_dataset = (
    dataset
    .filter(pl.col("request_id").is_not_null())
    .filter(pl.col("action_type").is_in(["AT_CartUpdate", "AT_View"]))
    .with_columns(
        pl.col("request_id").cast(str).alias("group_id"),
        pl.col("action_type").map_elements(build_target, return_dtype=float).alias("target")
    )
    .with_columns(
        # Propagate the strongest label to every row of the same (request, item) pair.
        target=pl.col('target').max().over(partition_by=[pl.col('group_id'), pl.col('item_id')])
    )
    .unique()
    # Sort AFTER unique(): unique() does not preserve row order by default, and the
    # chronological hold-out below relies on the rows being ordered by time.
    .sort("timestamp")
)

# shuffle=False keeps the time order, so the most recent 20% of rows become the test set.
train, test = train_test_split(cbm_dataset, test_size=0.2, shuffle=False)

# Keep the rows of each request adjacent before they are handed to CatBoost as groups.
train = train.sort("group_id")
test = test.sort("group_id")

In [32]:
# Wrap both splits in CatBoost pools that share one feature schema;
# group_id ties every row to its request.
train_pool, test_pool = (
    Pool(
        split.select(FEATURE_COLUMNS).to_numpy(),
        label=split["target"].to_numpy(),
        group_id=split["group_id"].to_numpy(),
        cat_features=CAT_FEATURE_COLUMNS,
        feature_names=FEATURE_COLUMNS,
    )
    for split in (train, test)
)

In [None]:
# Display the feature names registered in the test pool.
test_pool.get_feature_names()

In [None]:
# Baseline: a classifier on the binary target, monitored with NDCG@10 on the hold-out pool
# and stopped early after 20 rounds without improvement.
model = CatBoostClassifier(iterations=100, eval_metric="NDCG:top=10")
model.fit(train_pool, eval_set=test_pool, early_stopping_rounds=20, plot=True, verbose=False)

# Persist the fitted model; the file is loaded again by the ranker wrapper later on.
model.save_model("model.cbm")

In [None]:
# Feature importances of the fitted model, most important first.
pl.DataFrame({
    "importance": model.feature_importances_,
    "feature": model.feature_names_
}).sort("importance", descending=True)

In [None]:
# Ranking metrics of the current model on the hold-out pool.
# NOTE(review): np.mean averages every value eval_metrics returns for a metric;
# if those are per-iteration values, the final one may be what is wanted - confirm.
metrics = model.eval_metrics(test_pool, ["AUC:type=Ranking", "NDCG", "NDCG:top=10"])
for name, values in metrics.items():
    print(name, np.mean(values), sep="\n")

In [114]:
def category_diversity_at_k(test_dataset, k=10):
    """Average share of distinct product categories among each group's top-k items.

    Within every `group_id` the rows are ordered by `score` (highest first), the
    first `k` are kept and the number of distinct `product_category` values is
    divided by `k` (also when the group holds fewer than `k` rows). The per-group
    values are then averaged.

    Args:
        test_dataset: polars DataFrame with `score`, `group_id` and
            `product_category` columns.
        k: cut-off position.

    Returns:
        The mean per-group diversity as a Python scalar.

    Raises:
        AssertionError: if a required column is missing.
    """
    for required in ("score", "group_id", "product_category"):
        assert required in test_dataset.columns, f"test_dataset must contain '{required}' column"

    top_k_categories = pl.col("product_category").sort_by("score", descending=True).head(k)
    per_group_diversity = (top_k_categories.n_unique() / k).over("group_id", mapping_strategy="explode")
    return test_dataset.select(per_group_diversity).mean().item()

In [115]:
def serendipity_at_k(test_dataset, train_dataset, k=10):
    """Average share of serendipitous items among each group's top-k recommendations.

    A row is serendipitous when it is relevant (`target > 0`) AND its
    (user_id, item_id) pair does not occur in `train_dataset`, i.e. the user has
    no recorded history with the item.

    Args:
        test_dataset: polars DataFrame with `score`, `target`, `group_id`,
            `user_id` and `item_id` columns.
        train_dataset: polars DataFrame with `user_id` and `item_id` columns; its
            distinct pairs define the interaction history.
        k: cut-off position inside each group (ordered by `score`, highest first).

    Returns:
        Mean over groups of the fraction of serendipitous rows in the top-k.

    Raises:
        AssertionError: if a required column is missing.
    """
    assert "score" in test_dataset.columns, "test_dataset must contain 'score' column"
    assert "target" in test_dataset.columns, "test_dataset must contain 'target' column"
    assert "group_id" in test_dataset.columns, "test_dataset must contain 'group_id' column"
    for key in ("user_id", "item_id"):
        assert key in test_dataset.columns, f"test_dataset must contain '{key}' column"
        assert key in train_dataset.columns, f"train_dataset must contain '{key}' column"

    history = (
        train_dataset
        .select(
            pl.col("user_id"),
            pl.col("item_id"),
            pl.lit(True).alias("is_historic"),
        )
        .unique()
    )
    scored = (
        test_dataset
        # Left join: pairs without history must stay in the frame (the default inner
        # join dropped them, leaving only historic rows and making fill_null a no-op).
        .join(history, on=["user_id", "item_id"], how="left")
        .with_columns(
            is_historic=pl.col("is_historic").fill_null(False)
        )
        .with_columns(
            # Serendipitous = relevant and NOT previously seen by the user.
            is_serendipitous=(~pl.col("is_historic") & (pl.col("target") > 0))
        )
    )
    return scored.select(
        pl.col("is_serendipitous")
        .sort_by("score", descending=True)
        .head(k)
        .mean()
        .over("group_id", mapping_strategy="explode")
    ).mean().item()

In [None]:
print("Serendipity@5")
# `model` is a CatBoostClassifier here, and its predict() returns hard class labels by
# default. Rank by the raw model score instead so the top-5 order inside a request is
# not decided by ties.
baseline_scores = model.predict(test_pool, prediction_type="RawFormulaVal")
print(f"{serendipity_at_k(test.with_columns(score=baseline_scores), train, k=5):.5f}")

## Ranking tricks

### Undersampling

In [None]:
# Average number of cart updates, views and clicks per request.
action_columns = (
    ("cart_updates", "AT_CartUpdate"),
    ("views", "AT_View"),
    ("clicks", "AT_Click"),
)

(
    dataset
    .group_by("request_id")
    .agg([
        (pl.col("action_type") == action).sum().alias(column)
        for column, action in action_columns
    ])
    .select([
        pl.col(column).mean().alias(f"{column}_mean")
        for column, _ in action_columns
    ])
)

In [118]:
resampled_dataset = dataset.with_row_index()


def _leading_rows(action_type, limit, flag_name):
    """Row indices of the first `limit` rows of `action_type` per request.

    Rows are ordered by `position_in_request` inside each `request_id`; the
    result also carries a constant True column named `flag_name`.
    """
    return (
        resampled_dataset
        .filter(pl.col("action_type") == action_type)
        .select(
            pl.col("index").sort_by("position_in_request").head(limit).over("request_id", mapping_strategy="explode"),
            pl.lit(True).alias(flag_name),
        )
    )


# Keep at most 10 views and 3 clicks per request.
request_views = _leading_rows("AT_View", 10, "view_is_ok")
request_clicks = _leading_rows("AT_Click", 3, "click_is_ok")

# Requests that contain at least one cart update.
requests_with_cartupdate = (
    resampled_dataset
    .filter(pl.col("action_type") == "AT_CartUpdate")
    .select("request_id")
    .unique()
)

In [119]:
# Undersample: keep only requests with a cart update and, inside them, every cart update
# plus the retained views / clicks.
# Start from `dataset.with_row_index()` rather than from `resampled_dataset` itself so the
# cell is idempotent: re-running it no longer joins the flag columns a second time. The
# row index is identical to the one the view / click index sets were built from.
resampled_dataset = (
    dataset
    .with_row_index()
    .filter(pl.col("request_id").is_not_null())
    .join(requests_with_cartupdate, on="request_id", how="inner")
    .join(request_views, on="index", how="left")
    .join(request_clicks, on="index", how="left")
    .filter((pl.col("action_type") == "AT_CartUpdate") | pl.col("view_is_ok") | pl.col("click_is_ok"))
)

In [120]:
# Binary-target ranking dataset built from the undersampled requests.
resampled_cbm_dataset = (
    resampled_dataset
    .filter(pl.col("action_type").is_in(["AT_CartUpdate", "AT_View"]))
    .with_columns(
        group_id=pl.col("request_id").cast(str),
        target=pl.col("action_type").map_elements(build_target, return_dtype=float),
    )
    # Every row of a (request, item) pair receives the pair's strongest label.
    .with_columns(pl.col("target").max().over("group_id", "item_id"))
    .unique()
    .sort("timestamp")
)

# Chronological 80/20 split; each part is then ordered by group for CatBoost.
resampled_train, resampled_test = (
    part.sort("group_id")
    for part in train_test_split(resampled_cbm_dataset, test_size=0.2, shuffle=False)
)

In [121]:
# CatBoost pools for the undersampled splits, sharing the common feature schema.
resampled_train_pool, resampled_test_pool = (
    Pool(
        split.select(FEATURE_COLUMNS).to_numpy(),
        label=split["target"].to_numpy(),
        group_id=split["group_id"].to_numpy(),
        cat_features=CAT_FEATURE_COLUMNS,
        feature_names=FEATURE_COLUMNS,
    )
    for split in (resampled_train, resampled_test)
)

In [None]:
# Classifier trained and early-stopped on the undersampled pools; overwrites `model`.
model = CatBoostClassifier(iterations=200, eval_metric="NDCG:top=10")
model.fit(resampled_train_pool, eval_set=resampled_test_pool, early_stopping_rounds=30, plot=True, verbose=False)

In [None]:
# Ranking quality of the classifier trained on undersampled requests, measured on the
# undersampled hold-out and on the original hold-out.
print("Resampled requests")
metrics = model.eval_metrics(resampled_test_pool, ["AUC:type=Ranking", "NDCG", "NDCG:top=10"])
for metric in metrics:
    print(metric)
    print(np.mean(metrics[metric]))


print("-" * 50)
print("Original requests")
metrics = model.eval_metrics(test_pool, ["AUC:type=Ranking", "NDCG", "NDCG:top=10"])
for metric in metrics:
    print(metric)
    print(np.mean(metrics[metric]))
print("Serendipity@5")
# `model` is a CatBoostClassifier here, and its predict() returns hard class labels by
# default. Rank by the raw model score instead so the top-5 order inside a request is
# not decided by ties.
classifier_scores = model.predict(test_pool, prediction_type="RawFormulaVal")
print(f"{serendipity_at_k(test.with_columns(score=classifier_scores), train, k=5):.5f}")

In [None]:
# Ranking-loss model (CatBoostRanker) on the undersampled pools; overwrites `model`.
model = CatBoostRanker(iterations=200, eval_metric="NDCG:top=10")
model.fit(resampled_train_pool, eval_set=resampled_test_pool, early_stopping_rounds=30, plot=True, verbose=False)

In [None]:
# Ranking quality of the current model on the undersampled and the original hold-out.
print("Resampled requests")
metrics = model.eval_metrics(resampled_test_pool, ["AUC:type=Ranking", "NDCG", "NDCG:top=10"])
for name, values in metrics.items():
    print(name, np.mean(values), sep="\n")

print("-" * 50)
print("Original requests")
metrics = model.eval_metrics(test_pool, ["AUC:type=Ranking", "NDCG", "NDCG:top=10"])
for name, values in metrics.items():
    print(name, np.mean(values), sep="\n")
print("Serendipity@5")
scored_test = test.with_columns(score=model.predict(test_pool))
print(f"{serendipity_at_k(scored_test, train, k=5):.5f}")

### Non-binary target for ranking loss

In [130]:
def build_target(action_type: str):
    """Map an action type to a graded relevance label.

    Returns 1.0 for "AT_CartUpdate", 0.8 for "AT_Click", 0.0 for "AT_View"
    and None for any other action type.
    """
    grades = {"AT_CartUpdate": 1.0, "AT_Click": 0.8, "AT_View": 0.0}
    return grades.get(action_type)

In [131]:
# Graded-target ranking dataset: clicks are kept and labelled through build_target.
# Any Pool built from the previous split has to be rebuilt to pick these labels up.
resampled_cbm_dataset = (
    resampled_dataset
    .with_columns(
        group_id=pl.col("request_id").cast(str),
        target=pl.col("action_type").map_elements(build_target, return_dtype=float),
    )
    # Every row of a (request, item) pair receives the pair's strongest label.
    .with_columns(pl.col("target").max().over("group_id", "item_id"))
    .unique()
    .sort("timestamp")
)

# Chronological 80/20 split; each part is then ordered by group for CatBoost.
resampled_train, resampled_test = (
    part.sort("group_id")
    for part in train_test_split(resampled_cbm_dataset, test_size=0.2, shuffle=False)
)

In [None]:
# The pools created for the binary-target experiment still hold the old rows and labels.
# Build fresh pools from the current (graded-target) split before fitting; otherwise this
# run would train on exactly the same data as the previous ranker.
graded_train_pool, graded_test_pool = (
    Pool(
        split.select(FEATURE_COLUMNS).to_numpy(),
        feature_names=FEATURE_COLUMNS,
        cat_features=CAT_FEATURE_COLUMNS,
        label=split["target"].to_numpy(),
        group_id=split["group_id"].to_numpy(),
    )
    for split in (resampled_train, resampled_test)
)

model = CatBoostRanker(iterations=200, eval_metric="NDCG:top=10")
model.fit(graded_train_pool, eval_set=graded_test_pool, early_stopping_rounds=30, plot=True, verbose=False)

In [None]:
# Ranking quality of the current model on the original hold-out.
print("Original requests")
metrics = model.eval_metrics(test_pool, ["AUC:type=Ranking", "NDCG", "NDCG:top=10"])
for name, values in metrics.items():
    print(name, np.mean(values), sep="\n")
print("Serendipity@5")
scored_test = test.with_columns(score=model.predict(test_pool))
print(f"{serendipity_at_k(scored_test, train, k=5):.5f}")

## Reranker class

In [35]:
# Exclusive upper bound (calendar day) for the actions used in the feature snapshot query.
last_day = '2023-11-02'

In [None]:
# Serving-time feature snapshot: one row per (user, item) pair with user, item and
# user x item aggregates over all actions strictly before `last_day`, written to
# latest_features.parquet.
# The rate formulas guard zero denominators exactly like the training feature query
# (IF(x = 0, 1, x)) so that training and serving features are computed the same way, and
# missing values use the shared FILL_VALUE sentinel instead of a repeated literal.
# NOTE(review): the FULL JOIN on item_features can emit rows whose user_id is NULL -
#   confirm downstream consumers expect that.
duckdb.sql(
f"""
WITH
    actions as (
        SELECT
            user_id,
            product_id as item_id,
            product_category,
            request_id,
            action_type,
            position_in_request,
            make_timestamp(timestamp * 1000000) as timestamp,
            date_trunc('day', make_timestamp(timestamp * 1000000)) as day,
        FROM '{DATA_DIR}/train.parquet'
        WHERE action_type IN ('AT_CartUpdate', 'AT_View', 'AT_Click')
        AND date_trunc('day', make_timestamp(timestamp * 1000000)) < date '{last_day}'
    ),
    user_windows as (
        SELECT
            user_id,
            sum(case when action_type = 'AT_CartUpdate' then 1 else 0 end) as cart_updates,
            sum(case when action_type = 'AT_View' then 1 else 0 end) as views,
            sum(case when action_type = 'AT_Click' then 1 else 0 end) as clicks,
            count(distinct (case when action_type = 'AT_CartUpdate' then request_id end)) as num_requests_with_cart_update,
            count(distinct request_id) as total_requests,
        FROM actions
        GROUP BY user_id
    ),
    user_features as (
        SELECT
            user_id,
            cart_updates / IF(views = 0, 1, views) as user_cart_update_turn_rate,
            clicks / IF(views = 0, 1, views) as user_click_turn_rate,
            num_requests_with_cart_update / IF(total_requests = 0, 1, total_requests) as user_conversion_rate,
        FROM user_windows
    ),
    item_windows as (
        SELECT
            item_id,
            cast(max(product_category) as string) as product_category,
            sum(case when action_type = 'AT_CartUpdate' then 1 else 0 end) as cart_updates,
            sum(case when action_type = 'AT_View' then 1 else 0 end) as views,
            sum(case when action_type = 'AT_Click' then 1 else 0 end) as clicks,
            count(distinct (case when action_type = 'AT_CartUpdate' then request_id end)) as num_requests_with_cart_update,
            count(distinct request_id) as total_requests,
        FROM actions
        GROUP BY item_id
    ),
    item_features as (
        SELECT
            item_id,
            product_category,
            cart_updates / IF(views = 0, 1, views) as item_cart_update_turn_rate,
            clicks / IF(views = 0, 1, views) as item_click_turn_rate,
            num_requests_with_cart_update / IF(total_requests = 0, 1, total_requests) as item_conversion_rate,
        FROM item_windows
    ),
    user2item_windows as (
        SELECT
            user_id,
            item_id,
            max(case when action_type = 'AT_CartUpdate' then timestamp end) as last_cart_update_timestamp,
            sum(case when action_type = 'AT_CartUpdate' then 1 else 0 end) as cart_updates,
            array_agg(timestamp) as cart_update_timestamps,
        FROM actions
        GROUP BY item_id, user_id
    ),
    user2item_features as (
        SELECT
            user_id,
            item_id,
            cart_updates as u2i_cart_updates,
            list_avg(list_transform(list_zip(
                array_pop_back(list_sort(cart_update_timestamps)),
                array_pop_front(list_sort(cart_update_timestamps))
            ), x -> datepart('day', x[2] - x[1]))) as u2i_mean_time_between_cartupdates
        FROM user2item_windows
    )
    SELECT
        u2if.user_id as user_id,
        u2if.item_id as item_id,
        if.product_category as product_category,
        coalesce(if(isnan(uf.user_cart_update_turn_rate), null, uf.user_cart_update_turn_rate), {FILL_VALUE}) as user_cart_update_turn_rate,
        coalesce(if(isnan(uf.user_click_turn_rate), null, uf.user_click_turn_rate), {FILL_VALUE}) as user_click_turn_rate,
        coalesce(if(isnan(uf.user_conversion_rate), null, uf.user_conversion_rate), {FILL_VALUE}) as user_conversion_rate,
        coalesce(if(isnan(if.item_cart_update_turn_rate), null, if.item_cart_update_turn_rate), {FILL_VALUE}) as item_cart_update_turn_rate,
        coalesce(if(isnan(if.item_click_turn_rate), null, if.item_click_turn_rate), {FILL_VALUE}) as item_click_turn_rate,
        coalesce(if(isnan(if.item_conversion_rate), null, if.item_conversion_rate), {FILL_VALUE}) as item_conversion_rate,
        coalesce(if(isnan(u2if.u2i_cart_updates), null, u2if.u2i_cart_updates), {FILL_VALUE}) as u2i_cart_updates,
        coalesce(if(isnan(u2if.u2i_mean_time_between_cartupdates), null, u2if.u2i_mean_time_between_cartupdates), {FILL_VALUE}) as u2i_mean_time_between_cartupdates
    FROM user2item_features u2if
    LEFT JOIN user_features uf
    ON u2if.user_id = uf.user_id
    FULL JOIN item_features if
    ON u2if.item_id = if.item_id
""").to_parquet("latest_features.parquet")

In [9]:
from grocery.recommender.features import FeatureManager, StaticFeatureExtractor, FeatureStorage

In [10]:
# Feature names served by each extractor; the storages below are filled from these lists
# so the stored features cannot drift from what the extractors request.
user_feature_names = ["user_cart_update_turn_rate", "user_click_turn_rate", "user_conversion_rate"]
item_feature_names = ["item_cart_update_turn_rate", "item_click_turn_rate", "item_conversion_rate", "product_category"]
user2item_feature_names = ["u2i_cart_updates", "u2i_mean_time_between_cartupdates"]

features = pl.read_parquet("latest_features.parquet")
user_features = features.select("user_id", *user_feature_names).unique()
item_features = features.select("item_id", *item_feature_names).unique()
user2item_features = features.select("user_id", "item_id", *user2item_feature_names).unique()

user_storage = FeatureStorage()
item_storage = FeatureStorage()
user2item_storage = FeatureStorage()


def _feature_lookup(frame, key_columns, feature_key):
    """Return a {key: feature value} dict built from `frame`.

    With a single key column the dict is keyed by that column's values; with
    several key columns it is keyed by tuples of their values. When a key
    repeats, the last row wins.
    """
    if len(key_columns) == 1:
        keys = frame[key_columns[0]].to_list()
    else:
        keys = frame.select(key_columns).rows()
    return dict(zip(keys, frame[feature_key].to_list()))


for feature_key in user_feature_names:
    user_storage.add_feature(
        feature_key,
        _feature_lookup(user_features, ["user_id"], feature_key),
        FILL_VALUE,
    )

for feature_key in item_feature_names:
    # The categorical feature falls back to the "EMPTY" category; numeric features fall
    # back to the shared numeric sentinel.
    default = "EMPTY" if feature_key == "product_category" else FILL_VALUE
    item_storage.add_feature(
        feature_key,
        _feature_lookup(item_features, ["item_id"], feature_key),
        default,
    )

for feature_key in user2item_feature_names:
    user2item_storage.add_feature(
        feature_key,
        _feature_lookup(user2item_features, ["user_id", "item_id"], feature_key),
        FILL_VALUE,
    )

def user_key(user_id, item_id):
    """Storage key for user-level features: the user id (item_id is ignored)."""
    return user_id

def item_key(user_id, item_id):
    """Storage key for item-level features: the item id (user_id is ignored)."""
    return item_id

def user_item_key(user_id, item_id):
    """Storage key for user x item features: the (user_id, item_id) tuple."""
    return (user_id, item_id)

# One extractor per feature family; each pairs a storage with the function that derives
# the storage key from (user_id, item_id).
manager = FeatureManager([
    StaticFeatureExtractor(user_feature_names, user_storage, key=user_key),
    StaticFeatureExtractor(item_feature_names, item_storage, key=item_key),
    StaticFeatureExtractor(user2item_feature_names, user2item_storage, key=user_item_key),
])

In [None]:
from abc import ABC, abstractmethod
import heapq


class Ranker(ABC):
    """Interface for components that order candidates and keep the best `n`."""

    @abstractmethod
    def __init__(self):
        """Subclasses must provide their own constructor."""

    @abstractmethod
    def rank(self, object_id: int, candidates: list[Candidate], n: int) -> list[Candidate]:
        """Return up to `n` candidates for `object_id`, best first."""

    @staticmethod
    def select_top_n(candidates: list[Candidate], feature: str, n: int, descending: bool = True):
        """Pick the `n` candidates with the extreme values of `feature`.

        Args:
            candidates: objects exposing a `features` mapping.
            feature: key in `features` to order by.
            n: maximum number of candidates to return.
            descending: largest values first when true, smallest first otherwise.

        Returns:
            A list of at most `n` candidates in the requested order.
        """
        pick = heapq.nlargest if descending else heapq.nsmallest
        return pick(n, candidates, key=lambda candidate: candidate.features[feature])

In [14]:
from catboost import FeaturesData


class GroceryCatboostRanker(Ranker):
    """Ranker that scores candidates with a saved CatBoost model and keeps the best `n`.

    Args:
        model_path: path of the model file passed to `load_model`.
        num_feature_schema: numeric feature names, in the order the model expects.
        cat_feature_schema: categorical feature names (placed after the numeric ones).
        score_feature_name: key under which the model score is stored on a candidate.
        fill_value: value used for a numeric feature missing on a candidate. The
            default equals the -999999999.0 sentinel the training features use for
            missing values, so missing inputs look the same at training and serving.
    """

    def __init__(self,
                 model_path: str,
                 num_feature_schema: list[str],
                 cat_feature_schema: list[str] | None = None,
                 score_feature_name: str = "cbm_relevance",
                 fill_value: float = -999999999.0,
                 ):
        super().__init__()
        self.model = CatBoostRanker()
        self.model.load_model(fname=model_path)
        self.num_feature_schema = num_feature_schema
        self.cat_feature_schema = cat_feature_schema or []
        self.score_feature_name = score_feature_name
        self.fill_value = fill_value

    def build_cbm_features(self, candidates: list[Candidate]) -> FeaturesData:
        """Pack candidate features into a FeaturesData, one row per candidate.

        Missing numeric features become `self.fill_value`; missing categorical
        features become the string "EMPTY".
        """
        num_feature_array = np.array([
            [candidate.features.get(feature, self.fill_value) for feature in self.num_feature_schema]
            for candidate in candidates
        ], dtype=np.float32)
        cat_feature_array = np.array([
            [candidate.features.get(feature, "EMPTY") for feature in self.cat_feature_schema]
            for candidate in candidates
        ], dtype=object)
        return FeaturesData(
            num_feature_data=num_feature_array,
            cat_feature_data=cat_feature_array,
            num_feature_names=self.num_feature_schema,
            cat_feature_names=self.cat_feature_schema,
        )

    def rank(self, object_id: int, candidates: list[Candidate], n: int) -> list[Candidate]:
        """Score every candidate and return the `n` highest-scoring ones.

        The score is also written to `candidate.features[self.score_feature_name]`
        on every candidate, including those that are not returned.
        """
        # Nothing to score: avoid building zero-row feature arrays for the model.
        if not candidates:
            return []
        features = self.build_cbm_features(candidates)
        scores = self.model.predict(features)
        for candidate, score in zip(candidates, scores):
            candidate.features[self.score_feature_name] = score
        return self.select_top_n(candidates, self.score_feature_name, n)

In [None]:
# Wrap the model stored in model.cbm, using the same feature schema as in training.
# NOTE(review): model.cbm is saved from a CatBoostClassifier, while GroceryCatboostRanker
# loads it through CatBoostRanker - confirm predict() yields the intended scores.
ranker = GroceryCatboostRanker(
    model_path="model.cbm",
    num_feature_schema=NUM_FEATURE_COLUMNS,
    cat_feature_schema=CAT_FEATURE_COLUMNS,
)

In [29]:
# Pick one training user and rank the items that user has interacted with in `train`.
# The shuffle is seeded so that Restart & Run All always selects the same user.
sample_object_id = train["user_id"].shuffle(seed=42)[-1]
sample_history = (
    train
    .filter(pl.col("user_id") == sample_object_id)
    ["item_id"]
    .unique()
    .to_list()
)
candidates = [Candidate(id=item_id) for item_id in sample_history]
# Attach the stored user / item / user x item features to every candidate.
candidates = list(manager.extract(sample_object_id, candidates))
ranked = ranker.rank(sample_object_id, candidates, 10)

### Softmax Sampling

In [151]:
class SoftmaxSampler(Ranker):
    """Stochastic reranker: perturbs relevance scores with Gumbel noise and keeps the top `n`.

    Args:
        temperature: scale of the Gumbel noise added to each relevance score;
            larger values make the resulting order more random.
        relevance_feature_name: candidate feature holding the input relevance.
        sampled_rank_feature_name: candidate feature receiving the perturbed score.
        random_state: seed, or an existing `RandomState` / `Generator` to draw the
            noise from. Existing generator objects are used as-is; anything else
            is passed to `np.random.default_rng` as the seed.
    """

    def __init__(self,
                 temperature: float = 0.1,
                 relevance_feature_name: str = "cbm_relevance",
                 sampled_rank_feature_name: str = "sampled_cbm_relevance",
                 random_state: int | np.random.RandomState | np.random.Generator | None = 42
                 ):
        super().__init__()
        self.relevance_feature_name = relevance_feature_name
        self.sampled_rank_feature_name = sampled_rank_feature_name
        self.temperature = temperature
        # The signature advertises RandomState, which default_rng() does not accept on
        # every NumPy version; both generator types expose gumbel(size=...), so keep
        # them directly instead of re-seeding.
        if isinstance(random_state, (np.random.RandomState, np.random.Generator)):
            self.rng = random_state
        else:
            self.rng = np.random.default_rng(seed=random_state)

    def gumbel_max_trick(self, relevances: np.ndarray) -> np.ndarray:
        """Return `relevances` plus temperature-scaled Gumbel noise of the same shape."""
        noise = self.rng.gumbel(size=relevances.shape)
        return relevances + noise * self.temperature

    def rank(self, object_id: int, candidates: list[Candidate], n: int) -> list[Candidate]:
        """Return the `n` candidates with the highest noise-perturbed relevance.

        The perturbed score is written to
        `candidate.features[self.sampled_rank_feature_name]` on every candidate.
        Raises KeyError if a candidate lacks the relevance feature.
        """
        relevances = np.array([candidate.features[self.relevance_feature_name] for candidate in candidates])
        perturbed = self.gumbel_max_trick(relevances)
        for candidate, value in zip(candidates, perturbed):
            candidate.features[self.sampled_rank_feature_name] = value
        return self.select_top_n(candidates, self.sampled_rank_feature_name, n)

In [152]:
# Sampler with the default settings (temperature 0.1, seed 42).
sampler = SoftmaxSampler()

In [None]:
# Re-rank the model-ranked candidates by their Gumbel-perturbed relevance (top 10).
sampler.rank(sample_object_id, ranked, 10)