# Cortex Search Reranker Fine-tuning

Most modern search systems work in two stages. First, a retrieval step quickly sifts through millions of documents to gather a set of promising candidates, trading some accuracy for speed. Then a **reranker**, a more expensive model, examines the top candidates closely and reorders them so that the most relevant results appear first. A generic, one-size-fits-all reranker, however, can miss nuances that matter in a specialized domain. In this tutorial, we'll see how fine-tuning adapts a reranker to our specific data and notion of relevance.
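
To make the two-stage idea concrete, here is a minimal, self-contained sketch. It is not how Cortex Search works internally: both scorers are toy stand-ins (token overlap for retrieval, Jaccard similarity for reranking), and the corpus, the function names, and the `retrieve_k` / `final_k` parameters are all made up for illustration.

```python
from typing import List


def tokenize(text: str) -> set:
    """Lowercase whitespace tokenizer shared by both toy scorers."""
    return set(text.lower().split())


def cheap_score(query: str, doc: str) -> float:
    """Stage 1 stand-in: count of shared tokens (fast, crude)."""
    return float(len(tokenize(query) & tokenize(doc)))


def careful_score(query: str, doc: str) -> float:
    """Stage 2 stand-in: Jaccard similarity, which penalizes long, diluted documents.
    In a real system this would be a cross-encoder reranker."""
    q, d = tokenize(query), tokenize(doc)
    union = q | d
    return len(q & d) / len(union) if union else 0.0


def search(query: str, corpus: List[str], retrieve_k: int = 3, final_k: int = 2) -> List[str]:
    # Stage 1: score every document cheaply and keep only retrieve_k candidates.
    candidates = sorted(corpus, key=lambda doc: cheap_score(query, doc), reverse=True)[:retrieve_k]
    # Stage 2: re-score just those candidates with the more careful scorer.
    return sorted(candidates, key=lambda doc: careful_score(query, doc), reverse=True)[:final_k]


# Hypothetical mini-corpus, for illustration only.
corpus = [
    "asthma trial for adult patients with many other unrelated conditions such as diabetes and hypertension",
    "adult asthma inhaler trial",
    "knee surgery recovery study",
    "pediatric asthma education",
]
results = search("adult asthma trial", corpus)
print(results[0])
```

The first two documents tie under the cheap scorer (three shared tokens each), so retrieval alone cannot separate them; the second stage prefers the short, focused document. Fine-tuning targets that second stage.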

We will be using a complex search dataset called [TREC Clinical Trials](https://www.trec-cds.org/2021.html). This dataset represents an interesting and challenging case for fine-tuning because:

- The search queries are paragraph-long patient case descriptions (PCDs), in contrast to keyword-based or single-sentence queries.
- It contains a large amount of domain-specific language and terminology, on which general search models have limited expertise.
- The notion of relevance differs significantly from that in general web search.

In [None]:
# install ir_datasets for downloading sample data
!pip install ir_datasets --quiet

In [None]:
# most python libraries we need in this tutorial

import os
import json
import math
import torch
import datetime
import ir_datasets
import pandas as pd

from tqdm.auto import tqdm
from dataclasses import dataclass
from snowflake.snowpark.context import get_active_session
from snowflake.ml.data.sharded_data_connector import ShardedDataConnector
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from snowflake.ml.modeling.distributors.pytorch import (
    PyTorchDistributor, PyTorchScalingConfig, WorkerResourceConfig,
)
from snowflake.ml.modeling.distributors.pytorch.context import get_context


session = get_active_session()

# Print the current role, warehouse, and database/schema
print(f"role: {session.get_current_role()} | WH: {session.get_current_warehouse()} | DB.SCHEMA: {session.get_fully_qualified_current_schema()}")

# Stage 1: Data Preparation

In this stage, we will:

- Collect the documents in the corpus and store them in a Snowflake table (if not there already);
- Create a Cortex Search service for evaluation and data mining (if not there already);
- Collect some test queries.

In [None]:
# download data and store into a Snowflake table

CORPUS_TABLE_NAME = "TRECCT_DOCUMENT_CORPUS"

dataset = ir_datasets.load("clinicaltrials/2021/trec-ct-2021")
corpus_rows = []
for doc in tqdm(dataset.docs_iter()):  # iterate through the corpus
    # document processing logic
    docid = doc.doc_id
    title = doc.title
    condition = doc.condition
    summary = doc.summary
    detailed_description = doc.detailed_description
    eligibility = doc.eligibility

    # for simplicity we concatenate four fields together to form document texts
    doc_text = f"TITLE: {title}\nELIGIBILITY: {eligibility}\nSUMMARY: {summary}\nDETAILED_DESCRIPTION: {detailed_description}"
    corpus_rows.append({
        "DOCID": docid,
        "TEXT": doc_text
    })

df = pd.DataFrame(corpus_rows)

session.write_pandas(
    df,
    table_name=CORPUS_TABLE_NAME,
    auto_create_table=True,   # infer column names/types from pandas
    overwrite=True           # set True to replace the table contents
)

In [None]:
-- quickly check if all docs are uploaded successfully
SELECT COUNT(*) FROM TRECCT_DOCUMENT_CORPUS;

In [None]:
# collect test queries and write to table
# though we have test queries from two years (2021 and 2022), we only use the 2022 ones for evaluation
# queries from 2021 will be lightly used to adjust the prompts for data generation and evaluation

PROMPT_QUERIES_TABLE_NAME = "TRECCT_QUERIES_FOR_PROMPT"
TEST_QUERIES_TABLE_NAME = "TRECCT_TEST_QUERIES"
test_queries_2021, test_queries_2022 = [], []

for query in ir_datasets.load("clinicaltrials/2021/trec-ct-2021").queries_iter():
    qid, query_text = query.query_id, query.text
    qid = "2021_" + qid
    test_queries_2021.append({
        "QID": qid,
        "TEXT": query_text
    })

for query in ir_datasets.load("clinicaltrials/2021/trec-ct-2022").queries_iter():
    qid, query_text = query.query_id, query.text
    qid = "2022_" + qid
    test_queries_2022.append({
        "QID": qid,
        "TEXT": query_text
    })

df_prompt = pd.DataFrame(test_queries_2021)
df_test = pd.DataFrame(test_queries_2022)

session.write_pandas(
    df_prompt,
    table_name=PROMPT_QUERIES_TABLE_NAME,
    auto_create_table=True,
    overwrite=True
)

session.write_pandas(
    df_test,
    table_name=TEST_QUERIES_TABLE_NAME,
    auto_create_table=True,
    overwrite=True
)

print(f"Collected {len(df_test)} test queries!")

In [None]:
-- Let's also create a Cortex Search service
-- for evaluation and to help us efficiently build fine-tuning data

CREATE CORTEX SEARCH SERVICE IF NOT EXISTS CSS_TRECCT
  ON TEXT
  ATTRIBUTES DOCID, TEXT
  WAREHOUSE = 'SEARCH_L'  -- replace this with your available warehouse
  TARGET_LAG = '30 days'
  EMBEDDING_MODEL = 'snowflake-arctic-embed-l-v2.0-8k'
  INITIALIZE = ON_CREATE
  AS SELECT * FROM TRECCT_DOCUMENT_CORPUS;

# Stage 2: Defining Relevance

In this stage, we need to establish a robust, coherent and reusable concept of “relevance.” **This is arguably the most important step in customizing the search model for our needs.** The essence of fine-tuning lies in translating the concept of relevance into data for model training. If the concept is incorrect or unclear, we risk training the model in unwanted directions.

Specifically, in this stage, we will:

- Start with a scoring criterion;
- Iterate on the scoring criterion until it aligns with our empirical expectations.

*Note: this step is best accomplished by, or in consultation with, domain experts.*

If the search task is similar to generic web search, a good starting point for the scoring criterion can be found in Figure 1 of [this paper](https://arxiv.org/abs/2406.06519). We won't use that criterion here because the definition of relevance in clinical trial search is more intricate.

Based on materials in [TREC Clinical Trials](https://www.trec-cds.org/2021.html), it's clear that there are three levels of relevance in this search task: eligible (2), excludes (1) and not relevant (0). A good scoring criterion must be interpretable **by both humans and our data agent** (in this case, an LLM). We can validate this by checking if the agent's relevance judgments align with those of domain experts.

Now suppose we know from [here](https://trec.nist.gov/data/trials/qrels2021.txt) that for query 1 from the 2021 split, document NCT00526812 is judged as 'excludes' (1), whereas NCT00003466 is judged as 'not relevant' (0). Let's quickly test whether a scoring criterion works with data agents.

In [None]:
SELECT 
    b.QID AS QID, 
    a.DOCID AS DOCID,
    SNOWFLAKE.CORTEX.COMPLETE('llama3.3-70b',
'[Instruction]

Determine how well the given patient case description (PCD) matches the given clinical trial description (CTD), based on the scoring criteria provided below. Output score must be wrapped in [Score] [/Score] tag pair for easy parsing. Explain the level of relevance in separate [EXP] [/EXP] tag pair.

[Scoring Criteria] 

- Score 2 (eligible): Disease type/location matches trial intent. Patient meets all major eligibility criteria (age, relapse status, prior therapy limits, etc.). 
- Score 1 (conceptually relevant but excluded) The trial is truly for the same disease entity in the same anatomical context (e.g., brain anaplastic astrocytoma if patient has brain anaplastic astrocytoma). The trial treatment could apply in principle to the patient’s disease type, but the patient fails on technical eligibility restrictions, such as: Too many prior lines of therapy; Performance status or lab values outside range; Excluded comorbidity (e.g., uncontrolled hypertension, liver failure). In other words: same disease, same setting, but patient disqualified. 
- Score 0 (not relevant): The trial is not truly applicable to the patient’s disease setting, even if the words overlap. Examples: Trial requires intracranial astrocytoma, but patient’s is spinal astrocytoma; Trial is only for first relapse, but patient is on third-line or beyond and the mechanism is specific to early relapse; Trial is targeting related but different histologies (e.g., oligoastrocytoma only, glioblastoma only, pediatric setting); Trial uses a local delivery approach (e.g., stereotactic catheter infusion) that is anatomically impossible for the patient’s tumor site.

[Input PCD]
' || b.TEXT || '

[Input CTD]
' || a.TEXT) as llm_output,
    TRY_TO_NUMBER(
      TRIM(
        REGEXP_SUBSTR(
          llm_output,
          '\\[\\s*score\\s*\\][\\s\\S]*?(\\d+)[\\s\\S]*?\\[\\s*/\\s*score\\s*\\]',
          1, 1, 'is', 1
        )
      )
    ) AS SCORE,
    TRIM(
      REGEXP_SUBSTR(
        REGEXP_REPLACE(llm_output, '.*\\[\\s*exp\\s*\\]', '', 1, 1, 'is'),
        '([\\s\\S]*?)\\[\\s*/\\s*exp\\s*\\]', 1, 1, 'is', 1
      )
    ) AS EXPLANATION,
    b.text as PCD,
    a.text as CTD
    from TRECCT_DOCUMENT_CORPUS a
    join TRECCT_QUERIES_FOR_PROMPT b
    -- where a.DOCID = 'NCT00526812' and b.QID = '2021_1';  --> should be 1 mostly
    where a.DOCID = 'NCT00003466' and b.QID = '2021_1';  --> should be 0 mostly

It seems that the data agent can comprehend the scoring criterion quite accurately! Since we will be generating more data and evaluating search quality using this criterion repeatedly, it’s worthwhile to spend additional time refining it on more query–document samples until we achieve satisfactory performance.
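
When refining the criterion on more samples, it helps to turn the spot check into a number. The sketch below parses the `[Score]` tag in Python and computes the fraction of expert-labeled pairs on which the data agent agrees. The two expert labels are the ones quoted above; the LLM output strings are made-up placeholders, and `parse_score` and `agreement_rate` are hypothetical helper names rather than part of any library.

```python
import re
from typing import Dict, Optional, Tuple

# Matches "[Score] 1 [/Score]" with flexible whitespace and casing.
SCORE_RE = re.compile(r"\[\s*score\s*\]\D*?(\d+)\D*?\[\s*/\s*score\s*\]", re.IGNORECASE)


def parse_score(llm_output: str) -> Optional[int]:
    """Return the integer inside [Score]...[/Score], or None if the tags are missing."""
    match = SCORE_RE.search(llm_output)
    return int(match.group(1)) if match else None


def agreement_rate(expert: Dict[Tuple[str, str], int],
                   judged_outputs: Dict[Tuple[str, str], str]) -> float:
    """Fraction of expert-labeled (QID, DOCID) pairs where the parsed LLM score
    equals the expert label. Missing or unparseable outputs count as disagreements."""
    if not expert:
        return 0.0
    hits = sum(
        1 for pair, label in expert.items()
        if parse_score(judged_outputs.get(pair, "")) == label
    )
    return hits / len(expert)


# Expert labels for query 1 of the 2021 split, as quoted in the text.
expert_labels = {
    ("2021_1", "NCT00526812"): 1,
    ("2021_1", "NCT00003466"): 0,
}
# Placeholder LLM outputs (not real model responses); the second one lacks the tags.
llm_outputs = {
    ("2021_1", "NCT00526812"): "[Score] 1 [/Score] [EXP] Same disease, but excluded. [/EXP]",
    ("2021_1", "NCT00003466"): "Score: 0 (not relevant)",
}
rate = agreement_rate(expert_labels, llm_outputs)
print(f"agreement: {rate:.2f}")
```

Counting unparseable outputs as disagreements is a deliberate choice here: it surfaces formatting problems in the prompt alongside judgment errors.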

# Stage 3: Synthetic Data Generation

Now it's time to use our refined relevance criterion to generate synthetic search queries. Specifically, in this stage we will:

- Sample 10K documents and generate one synthetic query from each;
- Automatically annotate the top search results for those synthetic queries, as well as for our test queries;
- Split the synthetic data into training and validation sets.
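
For the last step, the split should be made per query, so that all documents judged for one query land on the same side. One way to get a reproducible query-level split is sketched below. This is an illustrative alternative, not the method used later in this notebook (which samples validation queries randomly in SQL); `assign_split`, `val_fraction`, and `salt` are hypothetical names.

```python
import hashlib


def assign_split(qid: str, val_fraction: float = 0.1, salt: str = "trecct-split") -> str:
    """Deterministically map a query ID to 'train' or 'val'.

    The same QID always gets the same answer, so every (query, document) pair
    for that query ends up on one side of the split.
    """
    digest = hashlib.sha256(f"{salt}:{qid}".encode("utf-8")).hexdigest()
    bucket = int(digest[:8], 16) / 0x100000000  # value in [0, 1)
    return "val" if bucket < val_fraction else "train"


qids = [f"TRAINQ_{i}" for i in range(1000)]
splits = [assign_split(qid) for qid in qids]
print(splits.count("train"), splits.count("val"))
```

Because the assignment depends only on the QID and the salt, rerunning the pipeline never moves a query from one set to the other.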

In [None]:
-- To make queries *grounded*, typically we need to generate queries *based on* certain documents.
-- Let's first randomly sample some documents:
CREATE OR REPLACE TABLE TRECCT_DOCUMENT_CORPUS_RANDOM_10K AS 
SELECT * FROM TRECCT_DOCUMENT_CORPUS ORDER BY RANDOM() LIMIT 10000;

In [None]:
-- Before generating all 10K queries,
-- let's first try a few samples and eyeball whether they are realistic;
-- if not, adjust the prompt accordingly

SELECT
    'TRAINQ_' || row_number() over (order by docid) as QID,
    DOCID,
    TEXT,
    SNOWFLAKE.CORTEX.COMPLETE('llama3.3-70b', 
'[Instruction]

Given a clinical trial description (CTD), write a synthetic one-paragraph patient case description (PCD) that is deemed "eligible" (score 2) or "conceptually relevant but excluded" (score 1) according to the scoring criteria below. Follow the format of three sample PCDs I provide below. Output PCD must be wrapped in [PCD] [/PCD] tag pair for easy parsing. Explain the level of relevance in separate [EXP] [/EXP] tag pair.

[Scoring Criteria] 

- Score 2 (eligible): Disease type/location matches trial intent. Patient meets all major eligibility criteria (age, relapse status, prior therapy limits, etc.). 
- Score 1 (conceptually relevant but excluded) The trial is truly for the same disease entity in the same anatomical context (e.g., brain anaplastic astrocytoma if patient has brain anaplastic astrocytoma). The trial treatment could apply in principle to the patient’s disease type, but the patient fails on technical eligibility restrictions, such as: Too many prior lines of therapy; Performance status or lab values outside range; Excluded comorbidity (e.g., uncontrolled hypertension, liver failure). In other words: same disease, same setting, but patient disqualified. 
- Score 0 (not relevant): The trial is not truly applicable to the patient’s disease setting, even if the words overlap. Examples: Trial requires intracranial astrocytoma, but patient’s is spinal astrocytoma; Trial is only for first relapse, but patient is on third-line or beyond and the mechanism is specific to early relapse; Trial is targeting related but different histologies (e.g., oligoastrocytoma only, glioblastoma only, pediatric setting); Trial uses a local delivery approach (e.g., stereotactic catheter infusion) that is anatomically impossible for the patient’s tumor site.

[Sample PCD 1]

Patient is a 55yo woman with h/o ESRD on HD and peritoneal dialysis who presented with watery, non bloody diarrhea and weakness. She has a history of 2 prior C diff infections, the most recent just 1 month ago. Recent antibx use in the last month on prior admission. Was also txd for Cdiff at that time for 14 d. course with po vanco. Pt was initially admitted to the ICU and was septic on pressors (levophed) until the morning of [**8-26**] with leukocytosis but no fever. C diff assay positive on admission, and pt had leukocytosis consistent with C diff. Patient was placed on Vanco po, Flagyl IV and Flagyl po initially, and when patient improved she was transitioned to Vanco oral and Flagyl oral on [**8-29**]. Patient was treated with Vanco for an extended course of 6 weeks given her recurrent C diff. Pt was also encouraged to take probiotics and to bleach her home when she was discharged.

[Sample PCD 2]

A 45-year-old woman was referred to the emergency department with abdominal pain lasting about 4 days accompanied by nausea and 2 episodes of vomiting. The pain is localized to the epigastric region and radiates to the right upper quadrant. The pain is worsening after eating fatty food. The patient experienced similar pain twice in the past year. Her past medical history is remarkable for hypercholesterolemia and two C/sections. She has 2 children, and she is menopausal. She doesn\'t smoke, drink alcohol, or use illicit drugs. She is mildly febrile. Her BP is 150/85, HR 115, RR 15, T 38.2, SpO2 98% on RA. On palpation, she experiences epigastric tenderness and tenderness in the right upper quadrant without rebound. Bowel sounds are normal. Laboratory analysis is remarkable for elevated ESR and leukocytosis with a left shift. The ultrasound revealed several gallstones and biliary sludge. The largest gallstone is 0.7cm. Surgery consultation recommends elective cholecystectomy.

[Sample PCD 3]

The patient is a 33-year-old woman complained of fatigue, weight gain and abnormal spotting between menses. No hirsutism or nipple discharge was detected. Her BMI was 34. Her lab results were remarkable for high TSH level (13 mU/L) and low free T4 level (0.2 ng/dl). Her anti-TPO levels were extremely high (120 IU/ml). She was diagnosed with Hashimoto\'s thyroiditis. Her aunt, brother and mother have the same disease. After starting 250 mcg Levothyroxine per day, her symptoms improved significantly and her periods are normal. She is still overweight with BMI of 31. Her most recent thyroid profile revealed all results except for anti-TPO within the normal range: TSH: 2.35 mU/L Free T4: 2.7 ng/dl Anti-TPO: 75 IU/ml

[Input CTD]

' || TEXT) as llm_output,
REGEXP_SUBSTR(llm_output, '\\[pcd\\]([\\s\\S]*?)\\[/pcd\\]', 1, 1, 'ie') as QUERY,
REGEXP_SUBSTR(llm_output, '\\[exp\\]([\\s\\S]*?)\\[/exp\\]', 1, 1, 'ie') as EXPLANATION,
FROM TRECCT_DOCUMENT_CORPUS_RANDOM_10K 
LIMIT 5;

In [None]:
-- Now we can go ahead and generate more synthetic queries and store them in a table

CREATE OR REPLACE TABLE TRECCT_SYNTHETIC_QUERIES_RANDOM_10K AS
    SELECT
        'TRAINQ_' || row_number() over (order by docid) as QID,
        DOCID,
        TEXT,
        SNOWFLAKE.CORTEX.COMPLETE('llama3.3-70b', 
'[Instruction]

Given a clinical trial description (CTD), write a synthetic one-paragraph patient case description (PCD) that is deemed "eligible" (score 2) or "conceptually relevant but excluded" (score 1) according to the scoring criteria below. Follow the format of three sample PCDs I provide below. Output PCD must be wrapped in [PCD] [/PCD] tag pair for easy parsing. Explain the level of relevance in separate [EXP] [/EXP] tag pair.

[Scoring Criteria] 

- Score 2 (eligible): Disease type/location matches trial intent. Patient meets all major eligibility criteria (age, relapse status, prior therapy limits, etc.). 
- Score 1 (conceptually relevant but excluded) The trial is truly for the same disease entity in the same anatomical context (e.g., brain anaplastic astrocytoma if patient has brain anaplastic astrocytoma). The trial treatment could apply in principle to the patient’s disease type, but the patient fails on technical eligibility restrictions, such as: Too many prior lines of therapy; Performance status or lab values outside range; Excluded comorbidity (e.g., uncontrolled hypertension, liver failure). In other words: same disease, same setting, but patient disqualified. 
- Score 0 (not relevant): The trial is not truly applicable to the patient’s disease setting, even if the words overlap. Examples: Trial requires intracranial astrocytoma, but patient’s is spinal astrocytoma; Trial is only for first relapse, but patient is on third-line or beyond and the mechanism is specific to early relapse; Trial is targeting related but different histologies (e.g., oligoastrocytoma only, glioblastoma only, pediatric setting); Trial uses a local delivery approach (e.g., stereotactic catheter infusion) that is anatomically impossible for the patient’s tumor site.

[Sample PCD 1]

Patient is a 55yo woman with h/o ESRD on HD and peritoneal dialysis who presented with watery, non bloody diarrhea and weakness. She has a history of 2 prior C diff infections, the most recent just 1 month ago. Recent antibx use in the last month on prior admission. Was also txd for Cdiff at that time for 14 d. course with po vanco. Pt was initially admitted to the ICU and was septic on pressors (levophed) until the morning of [**8-26**] with leukocytosis but no fever. C diff assay positive on admission, and pt had leukocytosis consistent with C diff. Patient was placed on Vanco po, Flagyl IV and Flagyl po initially, and when patient improved she was transitioned to Vanco oral and Flagyl oral on [**8-29**]. Patient was treated with Vanco for an extended course of 6 weeks given her recurrent C diff. Pt was also encouraged to take probiotics and to bleach her home when she was discharged.

[Sample PCD 2]

A 45-year-old woman was referred to the emergency department with abdominal pain lasting about 4 days accompanied by nausea and 2 episodes of vomiting. The pain is localized to the epigastric region and radiates to the right upper quadrant. The pain is worsening after eating fatty food. The patient experienced similar pain twice in the past year. Her past medical history is remarkable for hypercholesterolemia and two C/sections. She has 2 children, and she is menopausal. She doesn\'t smoke, drink alcohol, or use illicit drugs. She is mildly febrile. Her BP is 150/85, HR 115, RR 15, T 38.2, SpO2 98% on RA. On palpation, she experiences epigastric tenderness and tenderness in the right upper quadrant without rebound. Bowel sounds are normal. Laboratory analysis is remarkable for elevated ESR and leukocytosis with a left shift. The ultrasound revealed several gallstones and biliary sludge. The largest gallstone is 0.7cm. Surgery consultation recommends elective cholecystectomy.

[Sample PCD 3]

The patient is a 33-year-old woman complained of fatigue, weight gain and abnormal spotting between menses. No hirsutism or nipple discharge was detected. Her BMI was 34. Her lab results were remarkable for high TSH level (13 mU/L) and low free T4 level (0.2 ng/dl). Her anti-TPO levels were extremely high (120 IU/ml). She was diagnosed with Hashimoto\'s thyroiditis. Her aunt, brother and mother have the same disease. After starting 250 mcg Levothyroxine per day, her symptoms improved significantly and her periods are normal. She is still overweight with BMI of 31. Her most recent thyroid profile revealed all results except for anti-TPO within the normal range: TSH: 2.35 mU/L Free T4: 2.7 ng/dl Anti-TPO: 75 IU/ml

[Input CTD]

' || TEXT) as llm_output,
    REGEXP_SUBSTR(llm_output, '\\[pcd\\]([\\s\\S]*?)\\[/pcd\\]', 1, 1, 'ie') as QUERY,
    REGEXP_SUBSTR(llm_output, '\\[exp\\]([\\s\\S]*?)\\[/exp\\]', 1, 1, 'ie') as EXPLANATION,
    FROM TRECCT_DOCUMENT_CORPUS_RANDOM_10K
    WHERE QUERY IS NOT NULL;

In [None]:
-- Check how many queries we get.
-- It's normal to get fewer than the number of inputs,
-- because LLMs sometimes don't follow the format exactly and we have to discard those outputs.
SELECT COUNT(*) FROM TRECCT_SYNTHETIC_QUERIES_RANDOM_10K;

When fine-tuning a search model, each training query requires both positive (relevant) and negative documents. A common assumption is that the document used to generate the query can serve as the positive example. However, this approach can lead to suboptimal results, as these query-document pairs are often trivially relevant and provide little learning value for the model.

To create a more effective dataset, we will instead use Cortex Search to find seemingly relevant documents for our training queries. A data agent will then carefully label these documents as either relevant or not relevant. This process yields more challenging hard positives and crucial hard negatives, which are essential for building a robust model.
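
As a rough sketch of what the labeling step produces, the snippet below groups judged search results into positives and hard negatives per query. It assumes rows shaped like the judged tables built in this notebook (`QID`, `DOCID`, `SCORE`) and treats a score of 1 or higher as positive, matching the label rule used later for training. The rows and document IDs are hypothetical, and `build_training_pools` is an illustrative helper, not an existing API.

```python
from collections import defaultdict
from typing import Dict, List, Optional, TypedDict


class JudgedRow(TypedDict):
    QID: str
    DOCID: str
    SCORE: Optional[int]  # None when the judge output could not be parsed


def build_training_pools(rows: List[JudgedRow], positive_threshold: int = 1) -> Dict[str, Dict[str, List[str]]]:
    """Group judged results into positives and hard negatives per query.

    Every document here was retrieved by search, so a low-scoring one is a
    *hard* negative: it looked relevant to the retriever but not to the judge.
    """
    pools: Dict[str, Dict[str, List[str]]] = defaultdict(lambda: {"positives": [], "hard_negatives": []})
    for row in rows:
        if row["SCORE"] is None:
            continue  # skip rows with unparseable judge output
        bucket = "positives" if row["SCORE"] >= positive_threshold else "hard_negatives"
        pools[row["QID"]][bucket].append(row["DOCID"])
    return dict(pools)


# Hypothetical judged rows, for illustration only.
rows: List[JudgedRow] = [
    {"QID": "TRAINQ_1", "DOCID": "DOC_A", "SCORE": 2},
    {"QID": "TRAINQ_1", "DOCID": "DOC_B", "SCORE": 0},
    {"QID": "TRAINQ_1", "DOCID": "DOC_C", "SCORE": None},
    {"QID": "TRAINQ_2", "DOCID": "DOC_D", "SCORE": 1},
]
pools = build_training_pools(rows)
print(pools["TRAINQ_1"])
```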

In [None]:
-- We will use batch cortex search (https://docs.snowflake.com/LIMITEDACCESS/cortex-search/batch-cortex-search) to get top results for those queries more efficiently
-- For 8k-10k queries on a CSS with ~300K documents, this will take under 5 minutes

CREATE OR REPLACE TABLE CSS_TOP30_SYNTHETIC_QUERIES AS
    SELECT
        q.QID as QID,
        r.DOCID as DOCID,
        r.METADATA$RANK as RANK,
        r.METADATA$REQUEST_ID
    FROM TRECCT_SYNTHETIC_QUERIES_RANDOM_10K AS q,
    LATERAL CORTEX_SEARCH_BATCH(
        service_name => 'CORTEX_SEARCH_DB.PYU.CSS_TRECCT',  -- replace with the fully qualified name of your own service
        query => q.QUERY,
        limit => 30
    ) AS r;

In [None]:
-- Let's judge those results!
-- Note that this step might take hours depending on the amount of data and compute traffic!

CREATE or REPLACE TABLE CSS_TOP30_SYNTHETIC_QUERIES_JUDGED as
SELECT
    b.QID as QID,
    b.DOCID as DOCID,
    SNOWFLAKE.CORTEX.COMPLETE('llama3.3-70b', '
[Instruction]
Determine how well the given patient case description (PCD) matches the given clinical trial description (CTD), based on the scoring criteria provided below. Output score must be wrapped in [Score] [/Score] tag pair for easy parsing. Explain the level of relevance in separate [EXP] [/EXP] tag pair.

[Scoring Criteria]
- Score 2 (eligible): Disease type/location matches trial intent. Patient meets all major eligibility criteria (age, relapse status, prior therapy limits, etc.).
- Score 1 (conceptually relevant but excluded) The trial is truly for the same disease entity in the same anatomical context (e.g., brain astrocytoma if patient has brain anaplastic astrocytoma). The trial treatment could apply in principle to the patient’s disease type, but the patient fails on technical eligibility restrictions, such as: Too many prior lines of therapy; Performance status or lab values outside range; Excluded comorbidity (e.g., uncontrolled hypertension, liver failure). In other words: same disease, same setting, but patient disqualified.
- Score 0 (not relevant): The trial is not truly applicable to the patient’s disease setting, even if the words overlap. Examples: Trial requires intracranial astrocytoma, but patient’s is spinal astrocytoma; Trial is only for first relapse, but patient is on third-line or beyond and the mechanism is specific to early relapse; Trial is targeting related but different histologies (e.g., oligoastrocytoma only, glioblastoma only, pediatric setting); Trial uses a local delivery approach (e.g., stereotactic catheter infusion) that is anatomically impossible for the patient’s tumor site.

[Input PCD]
' || c.QUERY || '

[Input CTD]
' || a.TEXT) as LLM_JUDGE_OUTPUT,
    TRY_TO_NUMBER(
      TRIM(
        REGEXP_SUBSTR(
          LLM_JUDGE_OUTPUT,
          '\\[\\s*score\\s*\\][\\s\\S]*?(\\d+)[\\s\\S]*?\\[\\s*/\\s*score\\s*\\]',
          1, 1, 'is', 1
        )
      )
    ) AS SCORE,
    TRIM(
      REGEXP_SUBSTR(
        REGEXP_REPLACE(LLM_JUDGE_OUTPUT, '.*\\[\\s*exp\\s*\\]', '', 1, 1, 'is'),
        '([\\s\\S]*?)\\[\\s*/\\s*exp\\s*\\]', 1, 1, 'is', 1
      )
    ) AS EXPLANATION
from CSS_TOP30_SYNTHETIC_QUERIES b
join TRECCT_DOCUMENT_CORPUS a on a.DOCID = b.DOCID
join TRECCT_SYNTHETIC_QUERIES_RANDOM_10K c on b.QID = c.QID;

In [None]:
-- Let's also label our test data (reusable).
-- First we search top 100 per query using cortex search.
CREATE OR REPLACE TABLE CSS_TOP100_TEST_QUERIES AS
    SELECT
        q.QID as QID,
        r.DOCID as DOCID,
        r.METADATA$RANK as RANK,
        r.METADATA$REQUEST_ID  --> in case need for batch search debugging
    FROM TRECCT_TEST_QUERIES AS q,
    LATERAL CORTEX_SEARCH_BATCH(
        service_name => 'CORTEX_SEARCH_DB.PYU.CSS_TRECCT',  -- replace with the fully qualified name of your own service
        query => q.TEXT,
        limit => 100
    ) AS r;

-- Finally we label these using an LLM the same way we label training data.
CREATE or REPLACE TABLE CSS_TOP100_TEST_QUERIES_JUDGED as
SELECT
    b.QID as QID,
    b.DOCID as DOCID,
    SNOWFLAKE.CORTEX.COMPLETE('llama3.3-70b', '
[Instruction]
Determine how well the given patient case description (PCD) matches the given clinical trial description (CTD), based on the scoring criteria provided below. Output score must be wrapped in [Score] [/Score] tag pair for easy parsing. Explain the level of relevance in separate [EXP] [/EXP] tag pair.

[Scoring Criteria]
- Score 2 (eligible): Disease type/location matches trial intent. Patient meets all major eligibility criteria (age, relapse status, prior therapy limits, etc.).
- Score 1 (conceptually relevant but excluded) The trial is truly for the same disease entity in the same anatomical context (e.g., brain astrocytoma if patient has brain anaplastic astrocytoma). The trial treatment could apply in principle to the patient’s disease type, but the patient fails on technical eligibility restrictions, such as: Too many prior lines of therapy; Performance status or lab values outside range; Excluded comorbidity (e.g., uncontrolled hypertension, liver failure). In other words: same disease, same setting, but patient disqualified.
- Score 0 (not relevant): The trial is not truly applicable to the patient’s disease setting, even if the words overlap. Examples: Trial requires intracranial astrocytoma, but patient’s is spinal astrocytoma; Trial is only for first relapse, but patient is on third-line or beyond and the mechanism is specific to early relapse; Trial is targeting related but different histologies (e.g., oligoastrocytoma only, glioblastoma only, pediatric setting); Trial uses a local delivery approach (e.g., stereotactic catheter infusion) that is anatomically impossible for the patient’s tumor site.

[Input PCD]
' || c.TEXT || '

[Input CTD]
' || a.TEXT) as LLM_JUDGE_OUTPUT,
    TRY_TO_NUMBER(
      TRIM(
        REGEXP_SUBSTR(
          LLM_JUDGE_OUTPUT,
          '\\[\\s*score\\s*\\][\\s\\S]*?(\\d+)[\\s\\S]*?\\[\\s*/\\s*score\\s*\\]',
          1, 1, 'is', 1
        )
      )
    ) AS SCORE,
    TRIM(
        REGEXP_SUBSTR(
          LLM_JUDGE_OUTPUT,
          '\\[\\s*exp\\s*\\]([\\s\\S]*?)\\[\\s*/\\s*exp\\s*\\]',
          1, 1, 'is', 1
        )
      ) AS EXP_TEXT
from CSS_TOP100_TEST_QUERIES b
join TRECCT_DOCUMENT_CORPUS a on a.DOCID = b.DOCID
join TRECCT_TEST_QUERIES c on b.QID = c.QID;

In [None]:
# Let's define an evaluation function to see how Cortex Search performs without a reranker!

def compute_ndcg_at_k(session,
                      test_table,
                      judged_table="CSS_TOP100_TEST_QUERIES_JUDGED",
                      k=10):
    """
    Compute nDCG@k per query and the mean across queries.

    Definition:
      - DCG@k: use the predicted top-k (from test_table ordered by RANK) with judged scores.
      - IDCG@k: use all labeled documents for the query (from judged_table), sorted by SCORE desc, take top-k.
      - Unlabeled documents are treated as SCORE=0 for DCG.

    Returns:
      mean_ndcg (float), ndcg_per_query (dict {qid: ndcg_at_k})
    """
    # Predicted list (used for DCG)
    test_df = (
        session.table(test_table)
        .select("QID", "DOCID", "RANK")
        .to_pandas()
    )

    # All labels (used for IDCG and to supply scores to predicted docs)
    judged_df = (
        session.table(judged_table)
        .select("QID", "DOCID", "SCORE")
        .to_pandas()
    )
    judged_df["SCORE"] = pd.to_numeric(judged_df["SCORE"], errors="coerce").fillna(0)

    def _dcg(scores):
        # Exponential gain: 2^rel - 1
        return sum((2**rel - 1) / math.log2(i + 2) for i, rel in enumerate(scores))

    # Precompute IDCG@k for each QID from all labeled docs
    idcg_by_qid = {}
    for qid, g in judged_df.groupby("QID", sort=False):
        ideal_scores = g["SCORE"].sort_values(ascending=False).tolist()[:k]
        idcg_by_qid[qid] = _dcg(ideal_scores) if ideal_scores else 0.0

    # Merge predicted with labels (unlabeled -> 0)
    merged = test_df.merge(judged_df, on=["QID", "DOCID"], how="left")
    merged["SCORE"] = merged["SCORE"].fillna(0)

    ndcg_per_query = {}
    for qid, g in merged.groupby("QID", sort=False):
        # DCG@k from predicted top-k
        topk_scores = g.sort_values("RANK").head(k)["SCORE"].tolist()
        dcg = _dcg(topk_scores)
        idcg = idcg_by_qid.get(qid, 0.0)
        ndcg_per_query[qid] = (dcg / idcg) if idcg > 0 else 0.0

    # Mean across all queries in test_table (including those with no labels -> nDCG=0)
    mean_ndcg = (sum(ndcg_per_query.values()) / len(ndcg_per_query)) if ndcg_per_query else 0.0
    return mean_ndcg, ndcg_per_query

mean_ndcg_for_css = compute_ndcg_at_k(session=session, test_table="CSS_TOP100_TEST_QUERIES", k=10)[0]
print(f"Cortex Search (without reranker) nDCG@10 on test set: {mean_ndcg_for_css:.4f}")
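
To see what the metric rewards, here is a small worked example of nDCG@k with the same exponential gain (2^rel - 1) used by the evaluation function. It is a standalone sketch on a made-up three-document ranking and does not touch any Snowflake table; `dcg` and `ndcg_at_k` are illustrative names.

```python
import math
from typing import List


def dcg(scores: List[int]) -> float:
    """Discounted cumulative gain with exponential gain 2^rel - 1."""
    return sum((2 ** rel - 1) / math.log2(rank + 2) for rank, rel in enumerate(scores))


def ndcg_at_k(predicted_scores: List[int], all_labeled_scores: List[int], k: int) -> float:
    """DCG of the predicted top-k divided by the DCG of the best possible top-k."""
    ideal = sorted(all_labeled_scores, reverse=True)[:k]
    idcg = dcg(ideal)
    return dcg(predicted_scores[:k]) / idcg if idcg > 0 else 0.0


# Made-up ranking: the 'eligible' document (2) was placed second, behind a non-relevant one (0).
predicted = [0, 2, 1]
value = ndcg_at_k(predicted, predicted, k=3)
# DCG  = 0/log2(2) + 3/log2(3) + 1/log2(4)
# IDCG = 3/log2(2) + 1/log2(3) + 0/log2(4)
print(f"nDCG@3 = {value:.3f}")  # about 0.659
```

Moving the score-2 document from rank 2 to rank 1 would bring the value to 1.0, which is the kind of reordering a reranker is meant to perform.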

# Stage 4: Training and Evaluation

Cortex Search offers a decent out-of-the-box solution for most search tasks, even ones it wasn't specifically optimized for. Let's see how much a fine-tuned reranker can lift the search quality!

In this stage, we will:

- Transform our generated training data into a format suitable for Snowflake's distributed training framework (Ray);
- Sample from the training data to keep the run short (optional, for demonstration purposes only);
- Run the training job on multiple GPUs and store the model weights in a stage;
- Evaluate all checkpoints on our test data and pick the best one.


In [None]:
-- Here we convert our training data into format compatible with our distributed training framework

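-- Clean up
DROP TABLE IF EXISTS TMP_VAL_QIDS;
DROP VIEW IF EXISTS TMP_TRAIN_QIDS;
DROP VIEW IF EXISTS RERANKER_VAL;
DROP VIEW IF EXISTS RERANKER_TRAIN;

-- 1) 50 queries for validation, rest for training.
--    The validation QIDs are materialized in a temporary table rather than a view:
--    a view containing RANDOM() is re-evaluated on every reference, so the training
--    and validation views could draw different samples and end up overlapping.
CREATE OR REPLACE TEMPORARY TABLE TMP_VAL_QIDS AS
SELECT QID
FROM (
  SELECT DISTINCT QID
  FROM TRECCT_SYNTHETIC_QUERIES_RANDOM_10K
)
QUALIFY ROW_NUMBER() OVER (ORDER BY RANDOM()) <= 50;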

CREATE TEMP VIEW TMP_TRAIN_QIDS AS
SELECT DISTINCT QID
FROM TRECCT_SYNTHETIC_QUERIES_RANDOM_10K
WHERE QID NOT IN (SELECT QID FROM TMP_VAL_QIDS);

-- 2) Materialize compact 3-column views (fixed order)
--    LABEL rule: SCORE >= 1 -> 1 else 0
CREATE TEMP VIEW RERANKER_VAL AS
SELECT
  Q.QUERY  AS QUERY_TEXT,
  D.TEXT   AS DOC_TEXT,
  IFF(L.SCORE >= 1, 1, 0) AS LABEL
FROM CSS_TOP30_SYNTHETIC_QUERIES_JUDGED L
JOIN TMP_VAL_QIDS V ON L.QID = V.QID
JOIN TRECCT_SYNTHETIC_QUERIES_RANDOM_10K Q ON L.QID = Q.QID
JOIN TRECCT_DOCUMENT_CORPUS D ON L.DOCID = D.DOCID;

CREATE TEMP VIEW RERANKER_TRAIN AS
SELECT
  Q.QUERY  AS QUERY_TEXT,
  D.TEXT   AS DOC_TEXT,
  IFF(L.SCORE >= 1, 1, 0) AS LABEL
FROM CSS_TOP30_SYNTHETIC_QUERIES_JUDGED L
JOIN TMP_TRAIN_QIDS T ON L.QID = T.QID
JOIN TRECCT_SYNTHETIC_QUERIES_RANDOM_10K Q ON L.QID = Q.QID
JOIN TRECCT_DOCUMENT_CORPUS D ON L.DOCID = D.DOCID;
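
The SQL above splits by query ID rather than by row, so every judged document for a held-out query lands in validation and none of them leak into training. A minimal pure-Python sketch of the same idea, assuming rows of `(qid, docid, score)`; the function name `split_by_query`, the toy rows, and the fixed seed are illustrative and not part of the notebook:

```python
import random
from typing import List, Tuple

Row = Tuple[str, str, int]  # (qid, docid, graded relevance score)

def split_by_query(rows: List[Row], n_val_queries: int, seed: int = 0):
    """Hold out whole queries for validation and binarize the labels."""
    qids = sorted({qid for qid, _, _ in rows})
    # Sample the held-out query IDs once, so both splits see the same set.
    val_qids = set(random.Random(seed).sample(qids, n_val_queries))

    def to_example(row: Row):
        qid, docid, score = row
        # Same label rule as the SQL: SCORE >= 1 -> 1, else 0.
        return (qid, docid, 1 if score >= 1 else 0)

    train = [to_example(r) for r in rows if r[0] not in val_qids]
    val = [to_example(r) for r in rows if r[0] in val_qids]
    return train, val

# Toy data: 4 queries with 2 judged documents each (hypothetical values).
toy_rows = [(f"q{i}", f"d{i}_{j}", (j * i) % 3) for i in range(4) for j in range(2)]
train_rows, val_rows = split_by_query(toy_rows, n_val_queries=1)
```

Splitting at the query level matters because rows from the same query share the query text; a row-level split would let the model see validation queries during training.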

In [None]:
# [IMPORTANT] Here for demonstration purposes we only use 50K (less than 20%) of the whole training set
# Feel free to remove this limit to unlock more gains
train_df = session.table("RERANKER_TRAIN").select("QUERY_TEXT", "DOC_TEXT", "LABEL").limit(50000)
val_df   = session.table("RERANKER_VAL").select("QUERY_TEXT", "DOC_TEXT", "LABEL")

train_connector = ShardedDataConnector.from_dataframe(train_df)
val_connector   = ShardedDataConnector.from_dataframe(val_df)

# Schema constants (fixed)
Q_COL = "QUERY_TEXT"
D_COL = "DOC_TEXT"
Y_COL = "LABEL"

# Count rows in Snowflake instead of pulling both tables into pandas
print(train_df.count(), val_df.count())

In [None]:
# Define wrapper classes for registry and SPCS serving

import tempfile
import torch.nn as nn
from typing import List, Optional
from snowflake.ml.model import custom_model
from snowflake.snowpark import Session


class RerankerModule(nn.Module):
    """
    Wraps tokenizer + HF model.
    - forward(queries, docs) -> list[float] for service inference
    - encode(q_list, d_list) for training batches
    - save(path) / load(path) to persist/load one cohesive artifact
    """
    def __init__(self, model_name: str, max_length: int = 1024, device: Optional[torch.device] = None):
        super().__init__()
        self.model_name = model_name
        self.max_length = int(max_length)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self._device = device or torch.device("cpu")
        self.to(self._device)

    @torch.inference_mode()
    def forward(self, queries: List[str], docs: List[str]) -> List[float]:
        if len(queries) != len(docs):
            raise ValueError("queries and docs must have the same length")
        if len(queries) == 0:
            return []
        self.model.eval()
        scores: List[float] = []
        bs = min(len(queries), 128)
        if isinstance(queries, pd.Series):
            queries = queries.to_list()
        if isinstance(docs, pd.Series):
            docs = docs.to_list()
        for i in range(0, len(queries), bs):
            enc = self.tokenizer(
                queries[i:i+bs], docs[i:i+bs],
                padding=True, truncation=True, max_length=self.max_length, return_tensors="pt"
            )
            enc = {k: v.to(self.model.device) for k, v in enc.items()}
            logits = self.model(**enc).logits.view(-1)
            scores.extend(logits.detach().cpu().tolist())
        return scores

    def encode(self, q_list: List[str], d_list: List[str]) -> dict:
        enc = self.tokenizer(
            q_list, d_list,
            padding=True, truncation=True, max_length=self.max_length, return_tensors="pt"
        )
        return {k: v.to(self._device) for k, v in enc.items()}

    def save(self, path: str):
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(path, "weights.pt"))
        self.tokenizer.save_pretrained(path)
        with open(os.path.join(path, "model_config.json"), "w") as f:
            json.dump({"model_name": self.model_name, "max_length": self.max_length}, f)


    @classmethod
    def load(cls, session: Session, stage_path: str, device: Optional[torch.device] = None) -> "RerankerModule":
        """
        Loads the reranker module from a Snowflake stage.

        Args:
            session: The active Snowpark session object.
            stage_path: The path to the model artifact folder in a Snowflake stage (e.g., '@~/.../final').
            device: The torch device to load the model onto.
        
        Returns:
            An instance of RerankerModule.
        """
        with tempfile.TemporaryDirectory() as temp_dir: 
            session.file.get(stage_path, temp_dir)
            with open(os.path.join(temp_dir, "model_config.json"), "r") as f:
                cfg = json.load(f)
            
            obj = cls(cfg["model_name"], int(cfg["max_length"]), device=device)
            obj.tokenizer = AutoTokenizer.from_pretrained(temp_dir)
            state = torch.load(os.path.join(temp_dir, "weights.pt"), map_location=device or "cpu")
            obj.model.load_state_dict(state)
            obj.to(device or torch.device("cpu"))
            obj.model.eval()
            return obj

class RerankerCustomModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        self.reranker = self.context["reranker"]

    @custom_model.inference_api
    def forward(self, input: pd.DataFrame) -> pd.DataFrame:
        scores = self.reranker.forward(
            input["queries"],
            input["docs"],
        )
        return pd.DataFrame({'scores': scores})
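
`RerankerModule.forward` returns raw logits, which is all a ranker needs: the sigmoid is monotonic, so sorting by logit and sorting by probability give the same order. If a score in (0, 1) is wanted downstream, a sketch like the following could be applied to the output; `to_probabilities` and the sample logits are hypothetical:

```python
import math
from typing import List

def to_probabilities(logits: List[float]) -> List[float]:
    """Map raw relevance logits to (0, 1) with a numerically stable sigmoid."""
    probs = []
    for x in logits:
        if x >= 0:
            probs.append(1.0 / (1.0 + math.exp(-x)))
        else:
            # For negative inputs, exp(x) cannot overflow.
            z = math.exp(x)
            probs.append(z / (1.0 + z))
    return probs

logits = [2.3, -0.7, 0.0, 5.1]  # made-up scores for four candidate documents
probs = to_probabilities(logits)

# Document indices from most to least relevant, under each scoring.
order_by_logit = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
order_by_prob = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
```

Both orderings are identical, so the service can return logits and leave any rescaling to the caller.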

In [None]:
from snowflake.ml.runtime_cluster import get_ray_dashboard_url
dashboard_url = get_ray_dashboard_url('"Cortex Search Reranker Finetuning"')
print(f"Access the Ray Dashboard here: {dashboard_url}")

# A good set of default configs on a 4-GPU machine
@dataclass
class TrainConfig:
    model_name: str = "BAAI/bge-reranker-v2-m3"
    max_length: int = 1024
    batch_size: int = 4
    num_epochs: int = 1
    lr: float = 1e-4
    warmup_ratio: float = 0.05
    weight_decay: float = 0.01
    gradient_accumulation_steps: int = 4
    # Distributed resources
    num_nodes: int = 1
    num_workers_per_node: int = 4
    num_gpus_per_worker: int = 1

CFG = TrainConfig()

CKPT_STAGE = "@~/bge_reranker_ckpts"  # replace this if necessary
RUN_ID = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")  # utcnow() is deprecated since Python 3.12
print(f"Find artifacts at {CKPT_STAGE}/run_{RUN_ID}")

def train_func():
    # --- Required imports ---
    import os
    import json  # used when writing model_config.json and metrics.json
    import time
    import datetime
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader
    from transformers import get_linear_schedule_with_warmup

    # --- DDP INITIALIZATION ---
    backend = "nccl" if torch.cuda.is_available() and CFG.num_gpus_per_worker > 0 else "gloo"
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() and CFG.num_gpus_per_worker > 0 else torch.device("cpu")
    dist.init_process_group(backend=backend)

    ctx = get_context()
    dmap = ctx.get_dataset_map()
    model_dir = ctx.get_model_dir()
    session = get_active_session()

    train_shard = dmap["train"].get_shard()
    val_shard   = dmap.get("val").get_shard() if dmap.get("val") else None

    # Build a DataLoader over the shard's torch dataset
    def make_dataloader(shard, shuffle, drop):
        if shard is None:
            return None
        dataset = shard.to_torch_dataset() 
        return DataLoader(
            dataset, 
            batch_size=CFG.batch_size, 
            shuffle=shuffle,
            drop_last=drop, 
            num_workers=0 
        )

    train_dl = make_dataloader(train_shard, shuffle=False, drop=True)
    val_dl   = make_dataloader(val_shard, shuffle=False, drop=False)

    # --- Pre-count steps ---
    total_steps_this_epoch = 0
    
    # Helper function: same loader settings as training, used only to count batches
    def make_count_dataloader(shard, drop):
        if shard is None:
            return None
        dataset = shard.to_torch_dataset() 
        return DataLoader(
            dataset, 
            batch_size=CFG.batch_size, 
            shuffle=False,
            drop_last=drop, # Keep drop_last=True for accurate batch counting
            num_workers=0 
        )

    # Use a count_dl instance to iterate and count the batches.
    count_dl = make_count_dataloader(train_shard, drop=True)

    if count_dl is not None:
        for _b in count_dl:
            total_steps_this_epoch += 1

    milestones = set()
    if total_steps_this_epoch > 0:
        for r in (0.2, 0.4, 0.6, 0.8, 1.0):
            milestones.add(max(1, int(total_steps_this_epoch * r)))
    
    # --- Encapsulated model/tokenizer ---
    reranker = RerankerModule(model_name=CFG.model_name, max_length=CFG.max_length, device=device)
    
    # Explicitly use device_ids in barrier
    if dist.is_initialized():
        if device.type == "cuda":
            dist.barrier(device_ids=[device.index])
        else:
            dist.barrier()
            
    reranker.model = DDP(reranker.model, device_ids=[local_rank] if device.type == "cuda" else None)

    optimizer = torch.optim.AdamW(reranker.model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    total_steps_est = max(CFG.num_epochs * max(total_steps_this_epoch,1) // max(CFG.gradient_accumulation_steps,1), 1)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(CFG.warmup_ratio * total_steps_est),
        num_training_steps=total_steps_est
    )

    Q_COL, D_COL, Y_COL = "QUERY_TEXT", "DOC_TEXT", "LABEL"

    def encode_batch(batch):
        q_list = list(batch[Q_COL])
        d_list = list(batch[D_COL])
        targets = batch[Y_COL].to(device=device, dtype=torch.float32).view(-1)
    
        enc = reranker.encode(q_list, d_list)
        return enc, targets, len(q_list)

    def evaluate_on_val():
        if val_dl is None: 
            return {"val_loss": None, "val_acc": None}
        reranker.model.eval()
        total_loss = 0.0
        total_n = 0
        correct = 0
        with torch.no_grad():
            for b in val_dl: 
                inputs, targets, n = encode_batch(b)
                logits = reranker.model(**inputs).logits.view(-1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)
                total_loss += float(loss.item()) * n
                total_n += n
                preds = (logits > 0).long()
                correct += (preds == targets.long()).sum().item()
        reranker.model.train()
        if total_n == 0:
            return {"val_loss": None, "val_acc": None}
        return {"val_loss": total_loss / total_n, "val_acc": correct / total_n}

    # Rank 0 writes and uploads the checkpoint; barriers keep the other ranks in step
    def save_and_upload(tag: str, extra_metrics: dict):
        if dist.is_initialized():
            if device.type == "cuda":
                dist.barrier(device_ids=[device.index])
            else:
                dist.barrier()

        if not dist.is_initialized() or dist.get_rank() == 0:
            export_dir = os.path.join(model_dir, f"export_{tag}")
            os.makedirs(export_dir, exist_ok=True)

            unwrapped_model = reranker.model.module if hasattr(reranker.model, "module") else reranker.model
            torch.save(unwrapped_model.state_dict(), os.path.join(export_dir, "weights.pt"))
            reranker.tokenizer.save_pretrained(export_dir)

            with open(os.path.join(export_dir, "model_config.json"), "w") as f:
                json.dump({"model_name": reranker.model_name, "max_length": reranker.max_length}, f)

            meta = {
                "run_id": RUN_ID,
                "tag": tag,
                "model_name": CFG.model_name,
                "max_length": CFG.max_length,
                "batch_size_per_gpu": CFG.batch_size,
                "num_workers": CFG.num_nodes * CFG.num_workers_per_node,
                "grad_accum": CFG.gradient_accumulation_steps,
                **(extra_metrics or {}),
            }
            with open(os.path.join(export_dir, "metrics.json"), "w") as f:
                json.dump(meta, f)

            stage_prefix = f"{CKPT_STAGE}/run_{RUN_ID}/{tag}"
            for fn in os.listdir(export_dir):
                p = os.path.join(export_dir, fn)
                if os.path.isfile(p):
                    session.file.put(f"file://{p}", stage_prefix, auto_compress=False, overwrite=True)

        if dist.is_initialized():
            if device.type == "cuda":
                dist.barrier(device_ids=[device.index])
            else:
                dist.barrier()

    # --- training loop ---
    start_time = time.time()
    print(f"[Rank{dist.get_rank()}]: Training started!")
    global_step = 0
    for _ in range(CFG.num_epochs):
        if train_dl is not None:
            step_in_epoch = 0
            for batch_rows in train_dl:
                inputs, targets, _ = encode_batch(batch_rows)
                logits = reranker.model(**inputs).logits.view(-1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)
                (loss / CFG.gradient_accumulation_steps).backward()

                if (global_step + 1) % CFG.gradient_accumulation_steps == 0:
                    optimizer.step(); scheduler.step(); optimizer.zero_grad()

                step_in_epoch += 1
                global_step += 1

                if step_in_epoch in milestones:
                    tag = f"step{global_step}"
                    eval_metrics = evaluate_on_val()
                    time_elapsed = int(time.time() - start_time)
                    print(f"[Rank{dist.get_rank()}]: Global step {global_step}, time elapsed {datetime.timedelta(seconds=time_elapsed)}")
                    save_and_upload(tag, {"train_loss": float(loss.item()), **eval_metrics})

    # final save + upload
    if dist.is_initialized():
        if device.type == "cuda":
            dist.barrier(device_ids=[device.index])
        else:
            dist.barrier()
            
    save_and_upload("final", evaluate_on_val())

    # DDP finalize
    if dist.is_initialized():
        if device.type == "cuda":
            dist.barrier(device_ids=[device.index])
        else:
            dist.barrier()
        dist.destroy_process_group()
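
The loop above trains the reranker as a binary classifier: each query-document pair produces one logit, and `binary_cross_entropy_with_logits` compares it with the 0/1 label. For intuition, here is a pure-Python version of the per-example loss in its standard numerically stable form; `bce_with_logits` is an illustrative helper, not something the notebook defines:

```python
import math

def bce_with_logits(logit: float, target: float) -> float:
    """Stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

uncertain = bce_with_logits(0.0, 1.0)        # logit 0 means p = 0.5, so the loss is ln 2
confident_right = bce_with_logits(4.0, 1.0)  # relevant pair scored high: small loss
confident_wrong = bce_with_logits(4.0, 0.0)  # irrelevant pair scored high: large loss
```

With its default reduction, the PyTorch function averages these per-example values over the batch, so confidently wrong pairs dominate the gradient.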

In [None]:
# Launch actual training jobs!
# Note: training on 50K samples with the default configs takes roughly 1 to 2 hours.

scaling = PyTorchScalingConfig(
    num_nodes=CFG.num_nodes,
    num_workers_per_node=CFG.num_workers_per_node,
    resource_requirements_per_worker=WorkerResourceConfig(
        num_cpus=4,
        num_gpus=CFG.num_gpus_per_worker,
    ),
)

dist_trainer = PyTorchDistributor(train_func=train_func, scaling_config=scaling)

# train_connector and val_connector were created in the data-loading cell above
result = dist_trainer.run(
    dataset_map={"train": train_connector, "val": val_connector},
)
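
It helps to work out the step counts these defaults imply, because the checkpoint tags uploaded to the stage are derived from them. The sketch below mirrors the arithmetic in `train_func`, assuming the 50,000 training rows are split evenly across the 4 workers (real shard sizes may differ slightly); `plan_training` is a hypothetical helper:

```python
def plan_training(num_rows, num_workers, batch_size, grad_accum, warmup_ratio, num_epochs=1):
    rows_per_worker = num_rows // num_workers        # assumes an even split across workers
    steps_per_epoch = rows_per_worker // batch_size  # drop_last=True discards a partial batch
    # Checkpoints are saved at 20%, 40%, ..., 100% of each epoch.
    milestones = sorted({max(1, int(steps_per_epoch * r)) for r in (0.2, 0.4, 0.6, 0.8, 1.0)})
    # One optimizer update happens every grad_accum micro-batches.
    optimizer_steps = max(num_epochs * max(steps_per_epoch, 1) // max(grad_accum, 1), 1)
    return {
        "steps_per_epoch": steps_per_epoch,
        "milestones": milestones,
        "optimizer_steps": optimizer_steps,
        "warmup_steps": int(warmup_ratio * optimizer_steps),
        "effective_batch_size": batch_size * grad_accum * num_workers,
    }

plan = plan_training(num_rows=50_000, num_workers=4, batch_size=4, grad_accum=4, warmup_ratio=0.05)
print(plan)
```

Under that assumption the milestones come out to 625, 1250, 1875, 2500, and 3125, which are the `step…` tags evaluated later in the notebook. If you change the sample size, batch size, or worker count, the tags change accordingly.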

In [None]:
-- See artifacts saved during training (replace the run ID with the one printed by the training cell)
LIST @~/bge_reranker_ckpts/run_20251015_234022;

In [None]:
def rerank_and_eval(
    session,
    qd_pairs_table: str,
    document_text_table: str,
    query_text_table: str,
    judged_table: str,
    reranker_module: RerankerModule,
    device: str = "cuda:0",
    batch_size: int = 256,
    k: int = 10,
    output_temp_table: str = "TMP_RERANK_RESULTS"
):
    """
    Rerank candidate documents per query using a reranker model,
    write results to a temp table (QID, DOCID, RANK),
    then compute mean nDCG@k and per-query nDCG using compute_ndcg_at_k().
    The temp table is dropped automatically at the end.

    Returns:
        mean_ndcg (float), ndcg_per_query (dict)
    """

    # Move the model to the evaluation device
    reranker_module.model.to(device)

    # Join query-doc pairs with text
    from snowflake.snowpark.functions import col

    qd = session.table(qd_pairs_table).select("QID", "DOCID")
    qtbl = session.table(query_text_table).select("QID", col("TEXT").alias("QUERY_TEXT"))
    dtbl = session.table(document_text_table).select("DOCID", col("TEXT").alias("DOC_TEXT"))

    joined = (
        qd.join(qtbl, qd["QID"] == qtbl["QID"])
          .join(dtbl, qd["DOCID"] == dtbl["DOCID"])
          .select(qd["QID"].alias("QID"), qd["DOCID"].alias("DOCID"), qtbl["QUERY_TEXT"], dtbl["DOC_TEXT"])
    )

    # Count total pairs for progress bar
    total_pairs = joined.count()

    # Stream scoring
    results = []
    def flush_batch(buf_qid, buf_docid, buf_query, buf_doc):
        if not buf_qid:
            return 0  # nothing scored; keeps pbar.update() from receiving None
        with torch.no_grad():
            scores = reranker_module.forward(buf_query, buf_doc)
        for qid, docid, s in zip(buf_qid, buf_docid, scores):
            results.append((qid, docid, float(s)))
        return len(buf_qid)

    buf_qid, buf_docid, buf_query, buf_doc = [], [], [], []
    with tqdm(total=total_pairs, desc="Scoring pairs", dynamic_ncols=True, mininterval=0.5) as pbar:
        for row in joined.to_local_iterator():
            qid, docid, qtxt, dtxt = row["QID"], row["DOCID"], row["QUERY_TEXT"], row["DOC_TEXT"]
            buf_qid.append(qid); buf_docid.append(docid); buf_query.append(qtxt); buf_doc.append(dtxt)
            if len(buf_qid) >= batch_size:
                n = flush_batch(buf_qid, buf_docid, buf_query, buf_doc)
                pbar.update(n)
                buf_qid.clear(); buf_docid.clear(); buf_query.clear(); buf_doc.clear()
        # final flush
        n = flush_batch(buf_qid, buf_docid, buf_query, buf_doc)
        pbar.update(n)
    # Rank per QID
    df = pd.DataFrame(results, columns=["QID", "DOCID", "score"])
    df["RANK"] = (
        df.sort_values(["QID", "score"], ascending=[True, False])
          .groupby("QID")
          .cumcount() + 1
    )
    out_df = pd.DataFrame({
        "QID": df["QID"],
        "DOCID": df["DOCID"].astype(str),
        "RANK": df["RANK"],
    })
    sf_out = session.create_dataframe(out_df)
    sf_out.write.save_as_table(output_temp_table, mode="overwrite", table_type="temporary")

    # Compute nDCG@k
    mean_ndcg, ndcg_per_query = compute_ndcg_at_k(
        session,
        test_table=output_temp_table,
        judged_table=judged_table,
        k=k
    )

    # Drop temp table
    session.sql(f"DROP TABLE IF EXISTS {output_temp_table}").collect()

    return mean_ndcg, ndcg_per_query
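
For reference, nDCG@k rewards rankings that place highly relevant documents near the top. The sketch below uses the common linear-gain formulation, where each document contributes `rel / log2(rank + 1)`; this is an assumption, and `compute_ndcg_at_k` defined earlier may differ in details such as the gain function or how the ideal ranking is built. The relevance grades are toy values:

```python
import math
from typing import Sequence

def dcg_at_k(relevances: Sequence[float], k: int) -> float:
    """Discounted cumulative gain over the top-k positions (ranks start at 1)."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances[:k], start=1))

def ndcg_at_k(ranked_relevances: Sequence[float], k: int) -> float:
    """DCG of the given ranking divided by the DCG of the best possible ordering."""
    ideal = sorted(ranked_relevances, reverse=True)
    ideal_dcg = dcg_at_k(ideal, k)
    return dcg_at_k(ranked_relevances, k) / ideal_dcg if ideal_dcg > 0 else 0.0

# Relevance grades of the same four candidates, in retrieval order and after reranking.
before = ndcg_at_k([0, 2, 1, 0], k=3)
after = ndcg_at_k([2, 1, 0, 0], k=3)
```

Reranking cannot add relevant documents to the candidate set; it only moves the ones already retrieved closer to the top, which is exactly what this metric measures.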

In [None]:
# evaluate all checkpoints!

import altair as alt
import streamlit as st

eval_results = {}

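# Step tags must match the checkpoints uploaded during training; step 0 is the un-tuned model.
# Replace the run ID in checkpoint_path below with your own.
for steps in [0, 625, 1250, 1875, 2500, 3125]: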
    if steps == 0:
        # un-tuned reranker
        reranker = RerankerModule(model_name="BAAI/bge-reranker-v2-m3")
    else:
        checkpoint_path = f"@~/bge_reranker_ckpts/run_20251015_234022/step{steps}/"
        reranker = RerankerModule.load(session=session, stage_path=checkpoint_path)
        
    
    mean_ndcg, ndcg_per_query = rerank_and_eval(
        session=session,
        qd_pairs_table="CSS_TOP100_TEST_QUERIES",
        document_text_table="TRECCT_DOCUMENT_CORPUS",
        query_text_table="TRECCT_TEST_QUERIES",
        judged_table="CSS_TOP100_TEST_QUERIES_JUDGED",
        reranker_module=reranker,
        device="cuda:0",
        batch_size=256,
        k=10
    )
    print(f"step {steps} nDCG@10: {mean_ndcg:.4f}")
    eval_results[steps] = (mean_ndcg, ndcg_per_query)



mean_ndcg_vals = {k: v[0] for k, v in eval_results.items()}

train_res_df = pd.DataFrame(
    sorted(mean_ndcg_vals.items()),
    columns=["STEP", "NDCG"]
)

# Build a flat baseline series (one point per step) for plotting convenience
baseline_df = pd.DataFrame({
    "STEP": train_res_df["STEP"],
    "NDCG": [mean_ndcg_for_css] * len(mean_ndcg_vals)
})

df_all = pd.concat([
    train_res_df.assign(Type="Reranker"),
    baseline_df.assign(Type="No Reranking")
])

ymin = max(0, mean_ndcg_for_css - 0.02)
ymax = min(1, train_res_df["NDCG"].max() + 0.02)

st.subheader("nDCG@10 over Training Steps 📈")
st.caption("Tracking reranker performance vs. baseline (no reranking)")

# Build Altair chart with axis range
chart = (
    alt.Chart(df_all)
    .mark_line(point=True)
    .encode(
        x=alt.X("STEP:Q", title="Training Step"),
        y=alt.Y("NDCG:Q", scale=alt.Scale(domain=[ymin, ymax]), title="Mean nDCG@10"),
        color=alt.Color("Type:N", title="Legend")
    )
    .properties(width=600, height=400)
)

st.altair_chart(chart, use_container_width=True)


# Show numeric summary
st.subheader("Summary 📊")
st.metric("Baseline -- no reranking", f"{mean_ndcg_for_css:.4f}")
st.metric("Baseline -- non-finetuned reranking", f"{train_res_df['NDCG'][0]:.4f}")
st.metric("Best Reranker nDCG", f"{train_res_df['NDCG'].max():.4f}")
st.metric("Best Reranker improvement over non-finetuned", f"{train_res_df['NDCG'].max()/train_res_df['NDCG'][0]-1:+.2%}")
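
The cell above plots every checkpoint, but picking the best one is a one-liner over a dictionary shaped like `eval_results`. A sketch using placeholder numbers that are not results from this run; `example_results`, `ckpt_stage`, and `run_id` are stand-ins for the real values:

```python
# Placeholder values for illustration only; substitute the real eval_results dict.
example_results = {
    0: (0.40, {}),     # step 0 is the un-tuned reranker
    625: (0.45, {}),
    1250: (0.47, {}),
    1875: (0.46, {}),
}
ckpt_stage = "@~/bge_reranker_ckpts"
run_id = "YYYYMMDD_HHMMSS"  # hypothetical; use the run ID printed by the training cell

# Pick the step with the highest mean nDCG (first element of each tuple).
best_step = max(example_results, key=lambda step: example_results[step][0])
best_ndcg = example_results[best_step][0]
# Step 0 has no checkpoint on the stage, so only build a path for fine-tuned steps.
best_checkpoint = f"{ckpt_stage}/run_{run_id}/step{best_step}/" if best_step > 0 else None
print(best_step, best_ndcg, best_checkpoint)
```

The resulting path can be passed to `RerankerModule.load` in place of the `final/` checkpoint when an intermediate step scores higher.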

# Stage 5: Register the model and create a service

We now have a reranker that improves significantly over both no reranking and a general-purpose reranker. In this final stage, we register the model and create a service so that it can be used repeatedly, within or even outside Snowflake.

*For demonstration purposes, we separate search (with Cortex Search) and reranking into two steps and services. Note that it's possible to bundle the new reranker into your Cortex Search Service*.

In [None]:
from snowflake.ml.registry.registry import Registry
# Replace the database and schema names with your own
registry = Registry(session=session, database_name="CORTEX_SEARCH_DB", schema_name="PYU")

checkpoint_path = "@~/bge_reranker_ckpts/run_20251015_234022/final/"
reranker = RerankerModule.load(session=session, stage_path=checkpoint_path)

reranker_custom_model = RerankerCustomModel(
    context=custom_model.ModelContext(
        reranker=reranker,
    )
)

sample_input_data = pd.DataFrame(
    {
        "queries": ["query1", "query1"],
        "docs": ["doc1", "doc2"],
    },
)

ref = registry.log_model(
    model=reranker_custom_model,
    model_name="bgem3_reranker",
    sample_input_data=sample_input_data,
    pip_requirements=[
        "torch",
        "transformers",
    ],
    version_name="finetuned_trecct_50k"
)

In [None]:
ref.create_service(
    service_name="bge_reranker_service_finetuned_50k",
    service_compute_pool="GPU_NV_S_COMPUTE_POOL",  # Replace with a GPU compute pool you created
    ingress_enabled=False,
    gpu_requests="1", # Model fits in GPU memory
    max_instances=4,
)

In [None]:
-- Let's review how relevant top-10 docs are without our reranker

SELECT
    a.QID,
    a.DOCID,
    a.RANK,
    b.SCORE
FROM
    CSS_TOP100_TEST_QUERIES a
JOIN
    CSS_TOP100_TEST_QUERIES_JUDGED b ON a.QID = b.QID AND a.DOCID = b.DOCID
WHERE
    a.QID = '2022_5' AND a.RANK <= 10
ORDER BY
    a.RANK;

In [None]:
# Now let's see how the new reranker service makes the top-10 results much better! 

from snowflake.snowpark.functions import col

# get query text and document texts for test samples
df_a = session.table('CSS_TOP100_TEST_QUERIES')
df_b = session.table('TRECCT_TEST_QUERIES')
df_c = session.table('TRECCT_DOCUMENT_CORPUS')

# join queries and documents, keeping only the columns the reranker needs
final_df = df_a.join(df_b, df_a['qid'] == df_b['qid']) \
               .join(df_c, df_a['docid'] == df_c['docid']) \
               .where(df_a['qid'] == '2022_5') \
               .select(
                   df_b['text'].alias('QUERY_TEXT'),
                   df_c['docid'].alias('DOCID'),
                   df_c['text'].alias('DOC_TEXT')
               ).collect()

full_pandas_df = pd.DataFrame(final_df)

# transform into reranker format
queries_docs_df = full_pandas_df[['QUERY_TEXT', 'DOC_TEXT']].rename(columns={'QUERY_TEXT': 'queries', 'DOC_TEXT': 'docs'})
queries = full_pandas_df['QUERY_TEXT'].tolist()
docs = full_pandas_df['DOC_TEXT'].tolist()
docids_list = full_pandas_df['DOCID'].tolist()

# send top 100 docs to the reranker service
scores = ref.run(
    pd.DataFrame(
        {
            "queries": queries,
            "docs": docs,
        },
    ),
    function_name="forward",
    service_name="bge_reranker_service_finetuned_50k"
)

# parse scores and display top 10 docs
scores['DOCID'] = docids_list
df_sorted = scores.sort_values(by='scores', ascending=False)
top_10 = df_sorted.head(10).copy()
top_10['RANK'] = range(1, len(top_10) + 1)  # also works if fewer than 10 rows come back
judge_df = session.table('CSS_TOP100_TEST_QUERIES_JUDGED')
final_pandas_df = judge_df.join(session.create_dataframe(top_10), on='DOCID') \
                       .where(col('QID') == '2022_5') \
                       .select(
                           col('QID'),
                           col('DOCID'),
                           col('RANK'),
                           col('SCORE')
                       ) \
                       .sort(col('RANK')) \
                       .to_pandas()

final_pandas_df