The

# Libraries

In [None]:
# Standard libraries
import random
from math import floor
from collections import Counter, defaultdict

# Data handling
import pandas as pd
import numpy as np
from datasets import Dataset
from tqdm import tqdm

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# ML libraries
import torch
import torch.nn.functional as F
import tensorflow as tf
from sklearn.model_selection import KFold

# Transformers & Sentence Transformers
from transformers import EarlyStoppingCallback
from sentence_transformers import (
    SentenceTransformer,
    CrossEncoder,
    InputExample,
    util
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData
)


In [None]:
from google.colab import drive #run this if you run this notebook in Colab
drive.mount('/content/drive')

# Data and helper functions

Replace this part with your own way of loading the data.

In [None]:
my_path = "/content/drive/MyDrive/Colab Notebooks/FUNIX DS LAB/COURSE 5/"

In [None]:
# Load the three source tables from `my_path`/Data; the first CSV column is
# used as the index of each frame.
# Missing values in every content column become empty strings.
content_df = pd.read_csv(
    my_path + "Data/content.csv",
    index_col = 0).fillna("")

# Only the topic text columns are filled; other topic columns keep their NaNs.
topics_df = pd.read_csv(
    my_path + "Data/topics.csv",
    index_col = 0).fillna({"title": "", "description": ""})

correlations_df = pd.read_csv(
    my_path + "Data/correlations.csv",
    index_col = 0)

In [None]:
# Make the project's helper folder importable, then import the helper names
# used throughout the notebook.
import sys
sys.path.append(my_path + 'Helper_functions')
from ultis import print_markdown, Topic, ContentItem, macro_f2, remove_prefix, sem_score
import ultis
# Inject the loaded frames as module-level globals of `ultis`.
# NOTE(review): presumably Topic / ContentItem read these globals, so this must
# run before either class is used -- confirm against the helper module.
ultis.correlations_df = correlations_df
ultis.topics_df = topics_df
ultis.content_df = content_df

# Split data

In [None]:
# Get English data: topics that have content and are in English, and English
# content items. The ID lists keep the row order of the filtered frames; later
# cells rely on that order when mapping FAISS row indices back to content IDs.
en_topic_df = topics_df[(topics_df.has_content) & (topics_df.language == "en")]
en_content_df = content_df[content_df.language == "en"]
en_topic_ids = en_topic_df.index.tolist()
en_content_ids = en_content_df.index.tolist()

Split the English topics into training and validation sets at an 8:2 ratio (five folds; the last fold is held out for validation).

In [None]:
def cv_split(topic_ids, n_folds = 5, seed = 42):
    """
    Partition topic IDs into `n_folds` validation folds.

    Args:
        topic_ids (list): IDs to split.
        n_folds (int): Number of folds.
        seed (int): Random state given to the shuffled KFold splitter.

    Returns:
        list[list]: One list of IDs per fold, in the order KFold yields them.
    """
    id_array = np.array(topic_ids)
    splitter = KFold(n_splits=n_folds, shuffle=True, random_state=seed)

    # Only the held-out indices of each split are kept; the training indices
    # are implied by the remaining folds.
    return [
        id_array[held_out_idx].tolist()
        for _, held_out_idx in splitter.split(id_array)
    ]

In [None]:
# Five folds with the default arguments: the first four form the training set,
# the last one is the validation set (the 8:2 split).
folds = cv_split(en_topic_ids)
# Note: the training IDs are a NumPy array, the validation IDs a plain list.
train_en_topic_ids = np.concatenate(folds[:-1])
val_en_topic_ids = folds[-1]

# Data processing

In [None]:
def topic_preprocessing(topic_id):
    """
    Build the breadcrumb text of a topic.

    The breadcrumb string from `Topic.get_breadcrumbs()` is split on ">>",
    reversed, each part is passed through `remove_prefix`, and the parts are
    joined with ". ". Runs of whitespace are collapsed to single spaces.

    Args:
        topic_id: ID of the topic.

    Returns:
        str: Whitespace-normalised breadcrumb text.
    """
    crumbs = Topic(topic_id).get_breadcrumbs().split(">>")
    cleaned_crumbs = [remove_prefix(crumb) for crumb in reversed(crumbs)]
    joined = ". ".join(cleaned_crumbs)
    return " ".join(joined.split())

def topic_input(topic_id):
    """
    Build the query text for a topic: breadcrumb text followed by the topic
    description (after `remove_prefix`).

    Newlines become spaces, trailing "." / " " characters are stripped, and
    runs of whitespace are collapsed.

    Args:
        topic_id: ID of the topic.

    Returns:
        str: Normalised topic text.
    """
    breadcrumb_text = topic_preprocessing(topic_id)
    description = remove_prefix(Topic(topic_id).description)

    raw_text = ". ".join([breadcrumb_text, description])
    # The trailing-punctuation strip must happen after newlines are replaced,
    # otherwise a final "\n" would shield a trailing ". ".
    trimmed = raw_text.replace("\n", " ").rstrip(". ").strip()
    return " ".join(trimmed.split())

def content_input(content_id):
    """
    Build the passage text for a content item: title (after `remove_prefix`),
    description and text, joined with ". ".

    Newlines become spaces, trailing "." / " " characters are stripped, and
    runs of whitespace are collapsed.

    Args:
        content_id: ID of the content item.

    Returns:
        str: Normalised content text.
    """
    # Build the ContentItem once instead of once per field; this function is
    # called for every item of the corpus.
    item = ContentItem(content_id)
    content_text = (remove_prefix(item.title)
                    + ". " + item.description
                    + ". " + item.text)
    content_text = content_text.replace("\n", " ")
    content_text = content_text.rstrip(". ").strip()
    content_text = " ".join(content_text.split())
    return content_text

# Bi-encoder Training

## Create train dataset

In [None]:
def create_train_dataset(train_topic_ids, over_sampling=None):
    """
    Build a training dataset of (topic, content) pairs.

    Only English content items (those in `en_content_ids`) are used as
    positives. With oversampling, every pair of a topic is repeated according
    to how many English positives the topic has:
    <= 2 positives -> 5 copies, <= 5 positives -> 2 copies, otherwise 1 copy.

    Args:
        train_topic_ids (list): List of topic IDs to include.
        over_sampling (bool, optional): Whether to apply oversampling based
        on number of positive pairs.

    Returns:
        Dataset: HuggingFace Dataset object with 'topic' and 'content' columns.
    """
    # `en_content_ids` is a list; testing membership against it for every
    # content item of every topic is a linear scan each time. A set built once
    # gives the same answers in constant time.
    en_content_id_set = set(en_content_ids)

    query_texts = []
    passage_texts = []

    # Progress bar description
    oversample_msg = "with oversampling" if over_sampling else "without oversampling"
    for topic_id in tqdm(train_topic_ids,
                         desc=f"Building dataset {oversample_msg}"):
        # English positives of this topic, computed once and used both for the
        # repeat factor and for building the pairs.
        pos_content_ids = [c.id for c in Topic(topic_id).content
                           if c.id in en_content_id_set]
        if not pos_content_ids:
            continue  # no pair to emit for this topic

        count = len(pos_content_ids)

        # Determine repetition count
        if not over_sampling:
            repeat = 1
        elif count <= 2:
            repeat = 5
        elif count <= 5:
            repeat = 2
        else:
            repeat = 1

        query_text = topic_input(topic_id)
        for content_id in pos_content_ids:
            passage_text = content_input(content_id)
            query_texts.extend([query_text] * repeat)
            passage_texts.extend([passage_text] * repeat)

    return Dataset.from_dict({
        "topic": query_texts,
        "content": passage_texts
    })

In [None]:
train_dataset = create_train_dataset(train_en_topic_ids, over_sampling = True)

Building dataset with oversampling: 100%|██████████| 22443/22443 [01:23<00:00, 269.42it/s]


## Evaluator

In [None]:
def dev_evaluator_custom(valuation_ids):
    """
    Create an Information Retrieval evaluator for a given list of topic IDs.

    Queries are the topic texts, the corpus is every English content item, and
    the relevant documents of a topic are the IDs of all of its content items.

    Args:
        valuation_ids (list): List of topic IDs used for validation.

    Returns:
        InformationRetrievalEvaluator: Evaluator for IR metrics (precision@k, recall@k).
    """
    queries = {}
    relevant_docs = {}
    for topic_id in valuation_ids:
        # topic_id -> query text
        queries[topic_id] = topic_input(topic_id)
        # topic_id -> set of relevant content IDs.
        # NOTE(review): this set is not restricted to `en_content_ids`, so it
        # may contain IDs that are absent from the corpus below -- confirm
        # that this is intended.
        relevant_docs[topic_id] = {item.id for item in Topic(topic_id).content}

    # content_id -> content text for the whole English corpus
    corpus = {}
    for content_id in en_content_ids:
        corpus[content_id] = content_input(content_id)

    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        batch_size=1024,
        name='eval-ir',
        show_progress_bar=True,
        precision_recall_at_k=[5, 10, 25, 50, 100]
    )


In [None]:
dev_evaluator = dev_evaluator_custom(val_en_topic_ids)

## Training parameters

In [None]:
# Placeholders: fill in before running this cell.
model_src = ""  # path or model name
run_name = ""

# Load model on GPU and set max token length
bi_encoder = SentenceTransformer(model_src, device="cuda")
bi_encoder.max_seq_length = 128

# Training configuration
batch_size = 128
num_epoch = 3
# The training dataset holds (topic, content) pairs only, with no explicit
# negative column.
loss = MultipleNegativesRankingLoss(model=bi_encoder)

# === Training arguments ===
args = SentenceTransformerTrainingArguments(
    output_dir="your output dir",             # path to save model checkpoints
    num_train_epochs=num_epoch,               # total training epochs
    per_device_train_batch_size=batch_size,   # batch size for training
    per_device_eval_batch_size=batch_size,    # batch size for evaluation
    learning_rate=2e-5,                       # optimizer learning rate
    warmup_ratio=0.05,                        # warmup over the first 5% of steps
    fp16=True,                                # use FP16 mixed precision
    bf16=False,                               # disable bfloat16
    batch_sampler="no_duplicates",            # avoid duplicate samples in a batch

    # Evaluation strategy
    eval_strategy="steps",
    eval_steps=0.2,                           # float < 1: fraction of the TOTAL training steps
    load_best_model_at_end=True,
    # "eval_" prefix + evaluator name ('eval-ir') + metric name
    metric_for_best_model="eval_eval-ir_cosine_recall@50",
    greater_is_better=True,

    # Saving strategy
    save_strategy="steps",
    save_steps=0.2,                           # same fraction as eval_steps
    save_total_limit=2,                       # limit the number of checkpoints kept on disk

    # Logging
    logging_steps=0.2,                        # same fraction as eval_steps
    logging_first_step=True,

    # Optional logging backend
    run_name=run_name,
    report_to="none"                          # set to 'wandb' or 'tensorboard' if needed
)


In [None]:
# Fine-tune the model with the defined trainer; early stopping (patience of 2
# evaluations) is set to limit overfitting.
trainer = SentenceTransformerTrainer(
    model = bi_encoder,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    evaluator=dev_evaluator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)
trainer.train()

# Retrieval Inference

In [None]:
# install and import faiss for fast retrieving
%pip install faiss-cpu
import faiss

In [None]:
# === Load bi-encoder model ===
# Note: this rebinds `bi_encoder` to a published checkpoint, replacing the
# model object created in the training section.
bi_encoder = SentenceTransformer("GphaHoa/academic_all_mpnet_base_v2",
                                 device='cuda')
bi_encoder.max_seq_length = 128  # Truncate long inputs

# === Prepare raw text for encoding ===
# Each list keeps the order of its ID list, so row i of an embedding matrix
# belongs to the i-th ID. The retrieval cell depends on this alignment.
corpus = [content_input(content_id) for content_id in en_content_ids]
train_queries = [topic_input(topic_id) for topic_id in train_en_topic_ids]
val_queries = [topic_input(topic_id) for topic_id in val_en_topic_ids]

# === Encode corpus and queries into dense vectors ===
corpus_emb = bi_encoder.encode(
    corpus,
    convert_to_numpy=True,
    normalize_embeddings=True,  # Unit-length vectors: inner product == cosine similarity
    show_progress_bar=True
)

train_query_emb = bi_encoder.encode(
    train_queries,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True
)

val_query_emb = bi_encoder.encode(
    val_queries,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True
)


In [None]:
def _collect_top_k(query_ids, corpus_ids, scores, indices, desc):
    """
    Convert FAISS search output into {query_id: [(content_id, score), ...]}.

    Args:
        query_ids (sequence): Query IDs, aligned with the rows of `scores`.
        corpus_ids (list): Content IDs, aligned with the rows added to the index.
        scores (np.ndarray): Similarity scores, one row per query.
        indices (np.ndarray): Corpus row indices, one row per query.
        desc (str): Progress bar description.

    Returns:
        dict: query_id -> list of (content_id, score) in FAISS rank order.
    """
    results = {}
    rows = zip(query_ids, scores, indices)
    # `total` is passed explicitly because zip() has no length.
    for query_id, row_scores, row_indices in tqdm(rows,
                                                  total=len(query_ids),
                                                  desc=desc):
        results[query_id] = [
            (corpus_ids[content_idx], float(score))
            for content_idx, score in zip(row_indices, row_scores)
            # FAISS fills missing neighbours with index -1; indexing a list
            # with -1 would silently return the last corpus item.
            if content_idx >= 0
        ]
    return results


# Create FAISS index for inner product
index_flat = faiss.IndexFlatIP(corpus_emb.shape[1])
index_flat.add(corpus_emb)

top_k = 50  # Number of top results to retrieve

# === Retrieve top-k for validation queries ===
scores, indices = index_flat.search(val_query_emb, top_k)
retrieval_val_dict = _collect_top_k(
    val_en_topic_ids, en_content_ids, scores, indices,
    desc=f"Retrieving top {top_k} with Exact Faiss")

# === Retrieve top-k for training queries ===
scores, indices = index_flat.search(train_query_emb, top_k)
retrieval_train_dict = _collect_top_k(
    train_en_topic_ids, en_content_ids, scores, indices,
    desc=f"Retrieving top {top_k} with Exact Faiss")


# Reranking Training

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import (
    CrossEncoderClassificationEvaluator,
    CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.evaluation import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives

## Create truncated dataset

In [None]:
# Truncate topic/content to 64 tokens each so that a (topic, content) pair fits
# the 128-token input of the cross-encoder.
# NOTE(review): 64 + 64 leaves no room for the special tokens the model adds
# around a pair, so the pair may still be trimmed slightly -- confirm.

model_name = "cross-encoder/ms-marco-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def truncate_text(text, max_tokens=64):
    """
    Cut a text down to at most `max_tokens` tokens of the module-level
    `tokenizer` and turn the kept tokens back into a string.

    Args:
        text (str): The input text.
        max_tokens (int): Maximum number of tokens to retain.

    Returns:
        str: Truncated text as rebuilt by the tokenizer.
    """
    all_tokens = tokenizer.tokenize(text)
    kept_tokens = all_tokens[:max_tokens]
    return tokenizer.convert_tokens_to_string(kept_tokens)


def truncate_batch(batch):
    """
    Truncate the 'topic' and 'content' columns of a batch to 64 tokens each.

    Args:
        batch (dict): Batch with 'topic' and 'content' lists.

    Returns:
        dict: New dict with the truncated 'topic' and 'content' lists.
    """
    truncated = {}
    for column in ("topic", "content"):
        truncated[column] = [truncate_text(value, 64) for value in batch[column]]
    return truncated


def truncate_topic_batch(batch):
    """
    Truncate the 'topic' column of a batch with the default token limit of
    `truncate_text`.

    Args:
        batch (dict): Batch with a 'topic' list.

    Returns:
        dict: Dict holding only the truncated 'topic' list.
    """
    truncated_topics = list(map(truncate_text, batch["topic"]))
    return {"topic": truncated_topics}


def truncate_content_batch(batch):
    """
    Truncate the 'content' column of a batch with the default token limit of
    `truncate_text`.

    Args:
        batch (dict): Batch with a 'content' list.

    Returns:
        dict: Dict holding only the truncated 'content' list.
    """
    truncated_contents = list(map(truncate_text, batch["content"]))
    return {"content": truncated_contents}


In [None]:
# === Build HuggingFace Datasets for topics and contents ===
# Raw (untruncated) texts, kept in the same order as their ID lists.
topic_dict = dict(
    topic_id=en_topic_ids,
    topic=list(map(topic_input, en_topic_ids)),
)
content_dict = dict(
    content_id=en_content_ids,
    content=list(map(content_input, en_content_ids)),
)

# Convert to HuggingFace Datasets, then truncate long texts so that a
# (topic, content) pair fits max_seq_length=128 in the cross-encoder.
topic_dataset = Dataset.from_dict(topic_dict).map(
    truncate_topic_batch,
    batched=True,
    batch_size=2048
)
content_dataset = Dataset.from_dict(content_dict).map(
    truncate_content_batch,
    batched=True,
    batch_size=2048
)

# ID -> truncated text lookups for later use (e.g., formatting cross-encoder input)
topic_lookup = dict(zip(topic_dataset["topic_id"], topic_dataset["topic"]))
content_lookup = dict(zip(content_dataset["content_id"], content_dataset["content"]))


## Prepare train dataset

In [None]:
# Labeled (topic, content, label) records for cross-encoder training. The
# records are collected in a plain list and converted once at the end, so
# `all_train_dataset` only ever names the final Dataset.
train_records = []

# Dictionary to store oversampling repeat count per topic
topic_repeat_counts = {}

# `en_content_ids` is a list; a set gives the same membership answers without
# a linear scan per lookup.
en_content_id_set = set(en_content_ids)

for topic_id, content_score in retrieval_train_dict.items():
    # All ground-truth content IDs of the topic (any language), used for labels.
    pos_content_ids = {content.id for content in Topic(topic_id).content}

    # The repeat factor is based on the number of English positives only.
    count = sum(1 for cid in pos_content_ids if cid in en_content_id_set)
    if count < 5:
        repeat = 5  # Oversample heavily when positives are very few
    elif count < 10:
        repeat = 3
    else:
        repeat = 1  # No oversampling when there are enough positives
    topic_repeat_counts[topic_id] = repeat

    topic_text = topic_lookup.get(topic_id, "")

    # Every retrieved candidate becomes one example; retrieved positives are
    # labeled 1 and everything else 0.
    for content_id, _ in content_score:
        label = 1 if content_id in pos_content_ids else 0
        reps = repeat if label == 1 else 1  # Oversample only positive pairs
        record = {
            "topic": topic_text,
            "content": content_lookup.get(content_id, ""),
            "label": label
        }
        train_records.extend([record] * reps)

# Convert to HuggingFace Dataset
all_train_dataset = Dataset.from_list(train_records)

## Prepare dev set

In [None]:
# === Build lookup tables from topic/content datasets ===
topic_lookup = {tid: text for tid, text in zip(topic_dataset["topic_id"],
                                               topic_dataset["topic"])}
content_lookup = {cid: text for cid, text in zip(content_dataset["content_id"],
                                                 content_dataset["content"])}

samples = []

# Build evaluation samples for CrossEncoder reranking
for topic_id, retrieved in retrieval_val_dict.items():
    # Get top-50 retrieved content IDs from bi-encoder
    top_50_ids = [cid for cid, _ in retrieved[:50]]

    # Get ground-truth positive content IDs for this topic
    pos_ids = [content.id for content in Topic(topic_id).content]

    # Create a reranking sample: query, positives, and full list of documents.
    # NOTE(review): positives that are missing from `content_lookup` (built
    # from English content only) appear as empty strings here -- confirm that
    # this is the intended way to count them.
    samples.append({
        "query": topic_lookup.get(topic_id, ""),
        "positive": [content_lookup.get(cid, "") for cid in pos_ids],
        "documents": [content_lookup.get(cid, "") for cid in top_50_ids]
    })

# === Create reranking evaluator for CrossEncoder ===
reranking_evaluator = CrossEncoderRerankingEvaluator(
    samples=samples,
    batch_size=4096,
    name="reranking-dev",
    always_rerank_positives=False,  # Only rerank retrieved docs
    show_progress_bar=True,
    at_k=50
)

# Final dev evaluator (can be passed to .evaluate())
dev_evaluator = reranking_evaluator

## Cross encoder training

In [None]:
from collections import Counter

In [None]:
from sentence_transformers import CrossEncoder, losses
from torch import tensor
from collections import Counter
from math import floor

# === Config ===
model_name = "cross-encoder/ms-marco-MiniLM-L6-v2"
run_name = "cross_encoder_ms_macro"
train_batch_size = 128
num_epochs = 1

# === Class balance info ===
label_counts = Counter(all_train_dataset["label"])
num_neg = label_counts[0]
num_pos = label_counts[1]
if num_pos == 0:
    # The positive-class weight below divides by num_pos.
    raise ValueError("all_train_dataset contains no positive (label == 1) examples")

# === Number of full batches per epoch ===
# Kept for reference; the training arguments evaluate once per epoch and do
# not read this value.
eval_step = floor(len(all_train_dataset) / train_batch_size)

# === Initialize CrossEncoder ===
cross_encoder = CrossEncoder(
    model_name,
    num_labels=1,          # Single relevance score per (topic, content) pair
    max_length=128
)

# === Weighted BCE loss to handle class imbalance ===
# Use the cross-encoder BinaryCrossEntropyLoss imported in the
# "Reranking Training" import cell; `sentence_transformers.losses` holds the
# bi-encoder losses and does not provide this class.
loss = BinaryCrossEntropyLoss(
    model=cross_encoder,
    pos_weight=tensor(num_neg / num_pos)  # Higher weight for positive class
)

In [None]:
# Define the training arguments.
# Note: evaluation, saving and logging all happen once per epoch. With
# num_epochs = 1 there is a single evaluation, so the early-stopping callback
# below (patience 2) cannot trigger unless num_epochs is raised.
args = CrossEncoderTrainingArguments(
        # Required parameter:
        # Relative path: resolved against the current working directory.
        output_dir=f"content/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.05,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        # Optional tracking/debugging parameters:
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        # "eval_" prefix + evaluator name ("reranking-dev") + metric at at_k=50
        metric_for_best_model = 'eval_reranking-dev_ndcg@50',
        logging_strategy="epoch",
        logging_first_step=True,
        seed=12,
        report_to="none"
)

# Create the trainer & start training
trainer = CrossEncoderTrainer(
        model=cross_encoder,
        args=args,
        train_dataset=all_train_dataset,
        loss=loss,
        evaluator=dev_evaluator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
    )
trainer.train()

# Full inference and result

In [None]:
import time

In [None]:
def dense_search(model_name, query_ids, corpus_ids, top_k=100):
    """
    Perform dense retrieval using a bi-encoder and FAISS exact search (cosine similarity).

    Queries and documents are encoded with inputs truncated to 128 tokens, the
    same limit used when the bi-encoder was fine-tuned. Embeddings are
    L2-normalised, so the inner product used by the index equals cosine
    similarity.

    Args:
        model_name (str): Pretrained SentenceTransformer model name or path.
        query_ids (list): List of topic IDs (queries).
        corpus_ids (list): List of content IDs (documents).
        top_k (int): Number of top documents to retrieve for each query.
            Capped at the corpus size.

    Returns:
        dict: Mapping from query_id to list of (content_id, score) tuples,
        in descending score order.
    """
    import time  # local import keeps the function usable without the timing cell

    # Load bi-encoder
    bi_encoder = SentenceTransformer(model_name, device="cuda")
    # `max_seq_length` is the attribute that controls truncation; assigning
    # `max_length` only creates an unused attribute and leaves the limit at
    # the model default.
    bi_encoder.max_seq_length = 128

    print("Embedding topics and contents...")

    # Encode queries
    query_emb = bi_encoder.encode(
        [topic_input(topic_id) for topic_id in query_ids],
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=True,
        batch_size=256
    )

    # Encode corpus
    corpus_emb = bi_encoder.encode(
        [content_input(content_id) for content_id in corpus_ids],
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=True,
        batch_size=256
    )

    print("Searching with exact FAISS (IndexFlatIP)...")
    d = corpus_emb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(corpus_emb)

    # Asking for more neighbours than there are documents makes FAISS pad the
    # result with index -1, which would wrongly map to the last corpus ID.
    effective_k = min(top_k, len(corpus_ids))

    # Perform search
    start = time.time()
    scores, indices = index.search(query_emb, effective_k)
    print("Search time: ", time.time() - start)

    # Build retrieval dict
    retrieval_dict = {}
    print("Retrieving with FAISS...")
    for topic_id, row_scores, row_indices in zip(query_ids, scores, indices):
        retrieval_dict[topic_id] = [
            (corpus_ids[content_idx], float(score))
            for content_idx, score in zip(row_indices, row_scores)
        ]

    return retrieval_dict


In [None]:
def ndcg_at_k(y_true, k):
    """
    Compute NDCG@k for one ranked relevance list.

    The ideal ranking is obtained by sorting `y_true` itself in descending
    order, so the score is relative to the best ordering of the given list.

    Args:
        y_true (list): Binary relevance list (1=relevant, 0=not), in rank order.
        k (int): Cutoff rank.

    Returns:
        float: NDCG score at rank k, or 0.0 when the ideal DCG is zero.
    """
    top_gains = y_true[:k]
    best_order = sorted(y_true, reverse=True)[:k]

    # log2(rank + 1) for ranks 1..len(top_gains)
    ranks_plus_one = np.arange(2, len(top_gains) + 2)
    discounts = np.log2(ranks_plus_one)

    dcg = np.sum(top_gains / discounts)
    idcg = np.sum(best_order / discounts)

    if idcg > 0:
        return dcg / idcg
    return 0.0


def mrr_at_k(y_true, k):
    """
    Compute the reciprocal rank within the top k.

    Args:
        y_true (list): Relevance list in rank order.
        k (int): Cutoff rank.

    Returns:
        float: 1 / rank of the first relevant item in the top k, or 0.0 when
        there is none.
    """
    first_hit_rank = next(
        (position for position, rel in enumerate(y_true[:k], start=1) if rel > 0),
        None,
    )
    if first_hit_rank is None:
        return 0.0
    return 1 / first_hit_rank


def precision_at_k(y_true, k):
    """
    Compute precision@k.

    The denominator is always k, even when fewer than k items were retrieved.

    Args:
        y_true (list): Relevance list in rank order.
        k (int): Cutoff rank.

    Returns:
        float: Number of relevant items in the top k, divided by k.
    """
    hits_in_top_k = sum(y_true[:k])
    return hits_in_top_k / k


def recall_at_k(y_true, total_relevant, k):
    """
    Compute recall@k.

    Args:
        y_true (list): Relevance list in rank order.
        total_relevant (int): Total relevant items in ground truth.
        k (int): Cutoff rank.

    Returns:
        float: Relevant items in the top k divided by `total_relevant`, or 0.0
        when `total_relevant` is not positive.
    """
    if not total_relevant > 0:
        return 0.0
    hits_in_top_k = sum(y_true[:k])
    return hits_in_top_k / total_relevant


def metrics_at_k(topic_ids, retrieval_results, list_k=(100,)):
    """
    Evaluate IR metrics at multiple cutoff ranks (k).

    Args:
        topic_ids (list): List of topic IDs to evaluate.
        retrieval_results (dict): Mapping topic_id->list of (content_id, score).
            Topics missing from the mapping are scored as having no results.
        list_k (iterable): Cutoff values (e.g. [10, 50, 100]).

    Returns:
        dict: Mean precision, recall, ndcg, and mrr at each k, keyed as
        "precision@k", "recall@k", "ndcg@k" and "mrr@k".
    """
    results = {}

    # The relevance list of a topic does not depend on k, so it is built once
    # per topic instead of once per (topic, k).
    judged = []
    for tid in topic_ids:
        relevant = set(content.id for content in Topic(tid).content)
        y_true = [1 if cid in relevant else 0
                  for cid, _ in retrieval_results.get(tid, [])]
        judged.append((y_true, len(relevant)))

    for k in list_k:
        ndcgs, mrrs, p_at_k, r_at_k = [], [], [], []

        for y_true, total_relevant in tqdm(judged, desc=f"Evaluating@{k}"):
            ndcgs.append(ndcg_at_k(y_true, k))
            mrrs.append(mrr_at_k(y_true, k))
            p_at_k.append(precision_at_k(y_true, k))
            r_at_k.append(recall_at_k(y_true, total_relevant, k))

        results[f"precision@{k}"] = np.mean(p_at_k)
        results[f"recall@{k}"] = np.mean(r_at_k)
        results[f"ndcg@{k}"] = np.mean(ndcgs)
        results[f"mrr@{k}"] = np.mean(mrrs)

    return results


In [None]:
def rerank_from_retrieval(cross_encoder,
                          val_result_dict,
                          topic_lookup,
                          content_lookup,
                          top_k=50):
    """
    Re-score the first `top_k` bi-encoder candidates of every topic with a
    cross-encoder.

    Args:
        cross_encoder (CrossEncoder): Trained SentenceTransformers CrossEncoder.
        val_result_dict (dict): Dict mapping topic_id -> list of (content_id, score) from bi-encoder.
        topic_lookup (dict): topic_id -> topic text.
        content_lookup (dict): content_id -> content text.
        top_k (int): Number of top documents to rerank per query.

    Returns:
        dict: topic_id -> list of (content_id, cross-encoder score), in the
        order returned by `cross_encoder.rank`.
    """
    reranked_dict = {}

    progress = tqdm(val_result_dict.items(), desc="CrossEncoder rerank")
    for topic_id, retrieved in progress:
        # Candidate IDs and their texts, kept in parallel lists; IDs or
        # topics without a stored text fall back to an empty string.
        candidate_ids = []
        candidate_texts = []
        for cid, _ in retrieved[:top_k]:
            candidate_ids.append(cid)
            candidate_texts.append(content_lookup.get(cid, ""))

        ranking = cross_encoder.rank(topic_lookup.get(topic_id, ""),
                                     candidate_texts)

        # "corpus_id" is a position in `candidate_texts`; map it back to the
        # content ID at the same position.
        reranked = []
        for hit in ranking:
            reranked.append((candidate_ids[hit["corpus_id"]], float(hit["score"])))
        reranked_dict[topic_id] = reranked

    return reranked_dict


## Retrieving with bi-encoder

In [None]:
# Placeholder: set to the path or hub name of your saved fine-tuned bi-encoder.
# Note: from here on `bi_encoder` is a string, not the model object used in
# the cells above; `dense_search` loads the model itself.
bi_encoder = ""       #link to your saved finetuned bi-encoder
retrieval_val_dict = dense_search(bi_encoder, val_en_topic_ids,
                                  en_content_ids,
                                  top_k = 50)

In [None]:
#View the result metrics if wanted
params_list = [50,25,10,5]
result_metrics = metrics_at_k(val_en_topic_ids,
                              retrieval_val_dict,
                              list_k = params_list)
result_metrics

## Reranking

In [None]:
cross_encoder_src = ""        #link to your saved finetuned cross-encoder
# `rerank_from_retrieval` calls `.rank()` on its first argument, so it needs a
# loaded CrossEncoder rather than the path string itself. max_length matches
# the value used for training.
cross_encoder = CrossEncoder(cross_encoder_src, max_length=128)
reranked_val_dict = rerank_from_retrieval(cross_encoder,
                                          retrieval_val_dict,
                                          topic_lookup,
                                          content_lookup)

## Final inference

Set up your inference with different thresholds to produce the final results.

In [None]:
def inference_by_threshold(reranked_result, k=0.7):
    """
    Filter reranked results by a dynamic score threshold based on the best score.

    For a positive best score the threshold is `k * best`. A best score that
    is zero or negative cannot be scaled meaningfully (`k * best` would lie
    above the best score and drop every item), so in that case only the items
    tied with the best score are kept.

    Args:
        reranked_result (dict): Mapping from query_id -> list of (content_id, score).
            The list is normally sorted by score desc, but this is not required.
        k (float): Threshold ratio relative to max score (e.g., 0.7 keeps all scores >= 70% of max).

    Returns:
        dict: Filtered results for each query_id, in input order. Queries with
        an empty result list are omitted.
    """
    filtered_result = {}

    for qid, cid_score_list in reranked_result.items():
        if not cid_score_list:
            continue

        # Take the true maximum rather than trusting the first element.
        max_score = max(score for _, score in cid_score_list)

        if max_score > 0:
            min_required = k * max_score  # dynamic threshold based on top-1
        else:
            min_required = max_score      # keep the best-scoring item(s) only

        # Keep only content with score at or above the threshold
        filtered_result[qid] = [
            (cid, score) for cid, score in cid_score_list if score >= min_required
        ]

    return filtered_result


In [None]:
# Sweep the relative score threshold and record the mean F2 for each value.
# NOTE(review): np.arange with a float step can include or exclude the value
# nearest the stop (0.9) depending on rounding, and the printed thresholds may
# show floating-point noise -- confirm the intended grid.
thresholds = np.arange(0.4, 0.9, 0.025)
f2_scores = []
for k in thresholds:
    print(f"Reranking with threshold = {k} \n")
    final_results = inference_by_threshold(reranked_val_dict, k = k)
    # macro_f2 returns four values; only the last one is used here.
    _,_,_,f2 = macro_f2(val_en_topic_ids, final_results)
    f2_scores.append(np.mean(f2))