In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ClearBox for Ranking Tuning

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/search/custom-ranking/clearbox.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fsearch%2Fcustom-ranking%2Fclearbox.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/search/custom-ranking/clearbox.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/search/custom-ranking/clearbox.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/custom-ranking/clearbox.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/custom-ranking/clearbox.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/custom-ranking/clearbox.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/custom-ranking/clearbox.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/custom-ranking/clearbox.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>

| Author |
| --- |
| [Andrei Papou](https://github.com/andrei-papou) |

## Overview

In this notebook we will use the [ClearBox](https://github.com/GoogleCloudPlatform/clearbox) library to improve the default [Vertex AI Search](https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction) ranking for a given query set.

## Get Started

### Install ClearBox from GitHub

In [None]:
# Install ClearBox straight from its GitHub repository (default branch, unpinned).
# NOTE(review): for reproducible runs consider pinning a tag or commit, e.g. "git+...clearbox@<ref>".
%pip install "git+https://github.com/GoogleCloudPlatform/clearbox"

## Ranking Tuning

### Imports

In [None]:
import math
import random
import warnings

from clearbox import features as F
from clearbox.features import signals as S
from clearbox.metrics.recall import RecallAtK
from clearbox.models.probabilistic import (
    BayesianRidgeModel,
    BayesOptLinearModel,
)
from clearbox.models.probabilistic import ExpectedImprovement as EI
from clearbox.models.regression import LinRegModel
from clearbox.training import Trainer
from clearbox.visualization import Visualizer
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
warnings.filterwarnings("ignore", category=DeprecationWarning)

### Pin the random seed

In [None]:
def seed_everything(seed: int) -> None:
    """Seed the global random number generators used in this notebook.

    Args:
        seed: Value passed to both Python's `random.seed` and `np.random.seed`.
    """
    for reseed in (random.seed, np.random.seed):
        reseed(seed)


seed_everything(42)

### Load and explore the dataset

In [None]:
# Download the query-set CSV; judging by later cells, each row presumably pairs a
# query with one retrieved document and its ranking signals.
qs_df = pd.read_csv(
    filepath_or_buffer="https://storage.googleapis.com/github-repo/search/clearbox/beir_fiqa_50.csv",
)

In [None]:
# Peek at the first three rows to see the available columns (e.g. `query`, `title`, `targets`, signal scores).
qs_df.head(3)

### Preprocess the dataset and generate features

In [None]:
# Column that will hold the integer-encoded query; rows are grouped by it.
COL_QUERY_CODE: str = "query_code"
# Binary label column: 1 when a row's document is one of its query's targets.
COL_TARGET: str = "is_match"

In [None]:
# Integer-encode the raw query strings. The trainer groups rows by this column
# (`query_col=COL_QUERY_CODE`) and the RR features use it via `group_by=S.query_code`.
qs_df[COL_QUERY_CODE] = F.encode(qs_df["query"])

In [None]:
def _generate_is_match_col(df: pd.DataFrame) -> None:
    """Add the binary `COL_TARGET` label column to `df` in place.

    A row gets label 1 when its `title` (compared as a string) is one of the
    target documents listed for its query, and 0 otherwise. Target entries
    look like `gs://beir_fiqa/<title>.txt`; the prefix and suffix are stripped
    before matching.

    Args:
        df: Frame with `COL_QUERY_CODE`, `title` and `targets` columns. It is
            modified in place (the `COL_TARGET` column is overwritten if present);
            nothing is returned.
    """
    # Each query's target list is read from its first row; later rows of the
    # same query are assumed to carry the same `targets` value.
    first_rows = df.drop_duplicates(subset=COL_QUERY_CODE)
    # `split()` (rather than `split(" ")`) ignores repeated/trailing whitespace,
    # which would otherwise yield spurious empty "" titles. A missing (non-str)
    # `targets` value means "no matches" instead of raising AttributeError.
    query_to_matched_titles = {
        query_code: {
            target.removeprefix("gs://beir_fiqa/").removesuffix(".txt")
            for target in (
                targets_str.split() if isinstance(targets_str, str) else ()
            )
        }
        for query_code, targets_str in zip(
            first_rows[COL_QUERY_CODE], first_rows["targets"]
        )
    }
    # One pass over the rows instead of a full-frame boolean mask per query.
    # `.get` keeps rows whose query code is not a usable dict key (e.g. NaN)
    # labeled 0.
    is_match = [
        title in query_to_matched_titles.get(query_code, ())
        for query_code, title in zip(df[COL_QUERY_CODE], df["title"].astype(str))
    ]
    df[COL_TARGET] = np.asarray(is_match, dtype=np.uint8)


_generate_is_match_col(qs_df)

In [None]:
# Re-inspect the first rows, now including the derived `is_match` label column.
qs_df.head(3)

In [None]:
# Number of rows labeled as matches (`is_match == 1`), counted without copying the frame.
int((qs_df[COL_TARGET] == 1).sum())

### Explore the dataset

In [None]:
def _plot_signal_distribution(
    df: pd.DataFrame, signal_list: list[str], n_per_row: int = 3
):
    """Draw a 50-bin histogram for each signal column, laid out in a grid.

    NaN values are excluded from each histogram; when a column has any, the
    number excluded is shown in that subplot's title. Grid cells left over
    after the last signal are hidden.

    Args:
        df: Frame containing every column listed in `signal_list`.
        signal_list: Names of the columns to plot, filled in row by row.
        n_per_row: Number of subplots per grid row (must be >= 1).

    Returns:
        The matplotlib Figure containing the grid.

    Raises:
        ValueError: If `n_per_row` is smaller than 1.
    """
    if n_per_row < 1:
        raise ValueError(f"n_per_row must be >= 1, got {n_per_row}")
    n = len(signal_list)
    # Keep at least one row: `plt.subplots` rejects a zero-row grid.
    n_rows = max(1, math.ceil(n / n_per_row))
    f, axes = plt.subplots(n_rows, n_per_row, squeeze=False)
    f.set_figheight(5 * n_rows)
    f.set_figwidth(7 * n_per_row)
    # Row-major flattening: flat index i is axes[i // n_per_row][i % n_per_row].
    flat_axes = axes.ravel()
    for ax, signal in zip(flat_axes, signal_list):
        values = df[signal]
        n_nan = int(values.isna().sum())
        ax.set_title(signal if n_nan == 0 else f"{signal} ({n_nan} NaNs excluded)")
        ax.hist(values.dropna(), bins=50)
        ax.set_xlabel("value")
        ax.set_ylabel("count")
    # Hide the unused trailing cells of the last row (e.g. 5 signals on a 2x3 grid).
    for ax in flat_axes[n:]:
        ax.set_axis_off()
    return f


_plot_signal_distribution(
    qs_df,
    [
        "gecko_score",
        "bm25_score",
        "jetstream_score",
        "freshness_rank",
        "base_rank",
    ],
);

In [None]:
# Report the percentage of missing values for every signal column that has any.
for signal_name in (
    "gecko_score",
    "bm25_score",
    "jetstream_score",
    "freshness_rank",
    "base_rank",
):
    nan_mask = qs_df[signal_name].isna()
    if not nan_mask.any():
        continue
    nan_pct = nan_mask.mean() * 100
    print(f"{signal_name}: {nan_pct:.2f}% are NaNs")

It looks like the `bm25_score` column contains some NaN values. We'll handle them later.

In [None]:
# The label only takes the values 0 and 1, so use one bin centered on each value
# instead of 50 mostly-empty bins. The trailing `;` hides the return-value repr.
fig, ax = plt.subplots()
ax.hist(qs_df[COL_TARGET], bins=[-0.5, 0.5, 1.5])
ax.set_xticks([0, 1])
ax.set(title="Target distribution", xlabel=COL_TARGET, ylabel="count");

### Train the model

First, we create instances of the `Trainer` and `Visualizer` classes. We use `trainer` to train and validate the model, and `visualizer` to explore the results.

The `Trainer` has a number of parameters to customize the training procedure; let's look at some of them.

First of all, under the hood `Trainer` performs a number of random cross-validation splits and trains and validates a new model on each of them. The `seeds` argument specifies how many CV splits to make and which random seed to use for each. The `n_folds` argument specifies the number of folds generated for each split.

Once a new model is trained for a given seed and fold combination, the `Trainer` computes each metric in the `metrics` list on both the train and validation parts of the fold. All those metric values are then aggregated and included in the return value of the `train` method; you can access them via its `.metrics` attribute.

`target_col` specifies the column of the `qs_df` data frame to use as the target. `query_col` identifies the query column; note that we pass the integer-encoded version of the query we generated above.

In [None]:
# Cross-validation trainer: for each seed it makes a random `n_folds`-fold split,
# trains a fresh model per fold, and computes every metric in `metrics` on both
# the train and validation parts of each fold.
trainer = Trainer(
    df=qs_df,
    seeds=[7, 15, 21, 42, 81],  # 5 random CV splits, one per seed
    n_folds=3,
    metrics=[RecallAtK(k) for k in [1, 3, 5]],  # recall at cutoffs 1, 3 and 5
    target_col=COL_TARGET,  # binary `is_match` label built above
    query_col=COL_QUERY_CODE,  # integer-encoded query, not the raw string
)
# The visualizer reports the same metrics the trainer computes.
visualizer = Visualizer(metrics=trainer.metrics)

In ranking, it's good practice to compare the model against each individual signal, to make sure the model is an improvement over the features it is trained on. Also, since the ranking signals Vertex AI exposes capture different aspects of query-document relevance (e.g. topicality, relevance, semantic similarity, etc.), looking at the baselines is a good way to understand which of the signals matter most for your query set.

In [None]:
# One single-signal baseline per ranking signal. Each feature is oriented so that
# larger means better: rank signals are negated, and missing BM25 scores become 0.0.
baseline_list = [
    (signal_name, trainer.get_feature_baseline(feature=feature))
    for signal_name, feature in [
        ("base_rank", -S.base_rank),
        ("gecko_score", S.gecko_score),
        ("bm25_score", F.FillNaN(S.bm25_score, F.Constant(0.0))),
        ("jetstream_score", S.jetstream_score),
        ("freshness_rank", -S.freshness_rank),
    ]
]

Now we will train a linear regression model on the reciprocal ranks of the input signals we've explored above.

The `features` argument of the `train` method accepts a list of input features to train the model on. Notice how the features are created:

- We use `clearbox.features` module (conveniently imported as `F` alias) to do all the feature engineering we need. Doing it this way guarantees we'll be able to seamlessly deploy the formula to production.

- We use the `S` utility object to create a signal node. `S.gecko_score` tells the trainer to read the `gecko_score` column of the input data frame. Under the hood, it's just syntactic sugar for `F.Signal("gecko_score")`.

- We use the `F.RR` class to generate reciprocal ranks. For some signals (like `gecko_score` or `bm25_score`) we have raw values, not just ranks, but it's still always recommended to train the model on reciprocal ranks. This makes the formula more stable across signal updates (e.g. when a new version of the Jetstream model is released and the score distribution shifts slightly).

- When computing reciprocal ranks, we make sure that a higher input value always means a better match (i.e. the signal increases with relevance). Ranks work the other way around (a lower rank means a better match), so we negate them first: `-S.base_rank`.

- We use the `F.FillNaN` class to replace `NaN` values with `0`. It's important to do this with the `F` module utilities, because this guarantees the same logic is applied to any NaNs encountered while serving the formula in production.

In [None]:
# Fit a linear regression on reciprocal-rank features of the five signals, each
# oriented so that larger means better; `group_by=S.query_code` presumably ranks
# rows within each query. NOTE(review): 40.0 is presumably the RR constant k in
# 1 / (k + rank) — confirm against the ClearBox docs.
reg_training_results = trainer.train(
    LinRegModel(),
    features=[
        F.RR(signal, 40.0, group_by=S.query_code)
        for signal in (
            -S.base_rank,
            S.gecko_score,
            F.FillNaN(S.bm25_score, F.Constant(0.0)),
            S.jetstream_score,
            -S.freshness_rank,
        )
    ],
    num_parallel_workers=4,
    print_progress=True,
)
visualizer.visualize_training_results(reg_training_results, baseline_list);

Looks like we've improved over the baselines quite a bit!

The training results object returned by the `train` method (here `reg_training_results`) provides a number of useful properties:

- `metrics` data frame contains all the metric values computed on train and validation parts for each seed & fold pair.
- `ranking_formula` represents the final ranking model we've trained.

The ranking formula object has a number of useful methods, but the most important one is `serialize_to_ranking_expression`, which serializes the formula into the deployment format. To deploy the formula, put the string returned by this method into the `ranking_expression` field of the request and make sure the `ranking_expression_backend` field is set to `CLEARBOX`. That's it: the formula is now used to rerank the results of that particular request.

Now let's look at the formula we've just tuned.

In [None]:
# Print the deployable string for the request's `ranking_expression` field
# (used together with `ranking_expression_backend` set to `CLEARBOX`).
print(reg_training_results.ranking_formula.serialize_to_ranking_expression())

Now let's try a different kind of model and see if we can beat the current result. Intuitively, optimizing a regression model against a 0/1 target doesn't seem like a great idea.

What we'll do instead is optimize the recall directly. The metric is non-differentiable, of course, but we can still try different combinations of weights and see which ones give better results. However, even trying just 10 values from 0.0 to 1.0 for each of the 5 signals would require computing recall 5 seeds × 3 folds × 10^5 weight combinations = 1,500,000 times. Instead, we'll use Bayesian optimization to pick the most promising weight combinations. `BayesOptLinearModel` does just that.

In [None]:
# Search the linear weights with Bayesian optimization, maximizing recall@1
# directly instead of fitting a regression; the features match the regression run.
bayes_training_results = trainer.train(
    BayesOptLinearModel(
        surrogate_model=BayesianRidgeModel(),
        acquisition_function=EI(xi=1e-6),
        metric=RecallAtK(1),  # objective optimized by the search
        grid_size=21,  # NOTE(review): presumably candidate values per weight — confirm
        seed_batch_size=512,
        batch_size=32,
        n_opt_steps=8,
    ),
    features=[
        F.RR(signal, 40.0, group_by=S.query_code)
        for signal in (
            -S.base_rank,
            S.gecko_score,
            F.FillNaN(S.bm25_score, F.Constant(0.0)),
            S.jetstream_score,
            -S.freshness_rank,
        )
    ],
    num_parallel_workers=4,
    print_progress=True,
)
# Compare against the single-signal baselines and the regression model above.
visualizer.visualize_training_results(
    bayes_training_results,
    [*baseline_list, ("regression_model", reg_training_results.metrics)],
)

In [None]:
# Print the deployable ranking expression for the Bayesian-optimized formula.
print(bayes_training_results.ranking_formula.serialize_to_ranking_expression())

The intuition paid off: we were able to improve all 3 recall metrics, and a `~0.04` improvement in `recall@1` looks decent. Notice that the formula is very close to the previous one; the only substantial difference is that it uses `base_rank` instead of `freshness_rank`.