## Load Dataset

In [1]:
from rag.load import load_benchmark_corpus, corpus_to_texts_metadatas


benchmark, corpus = load_benchmark_corpus()
texts, metadatas = corpus_to_texts_metadatas(corpus)

## Split Into Chunks

In [2]:
from langchain_text_splitters import RecursiveCharacterTextSplitter


text_splitter = RecursiveCharacterTextSplitter(
    separators=['\n\n', '\n', '!', '?', '.', ':', ';', ',', ' ', ''],
    chunk_size=500,
    chunk_overlap=0,
    add_start_index=True,
)

documents = text_splitter.create_documents(texts, metadatas=metadatas)

## Embed Chunks

In [3]:
import torch
from sentence_transformers import SentenceTransformer
from transformers import BitsAndBytesConfig

from rag.embed import compute_similarities, get_query_strings, get_document_contents
from rag.util import cleanup

similarity_cache_path = "data/cache/03_similarities.pt"
try:
    similarities = torch.load(similarity_cache_path)
except:
    similarities = compute_similarities(
        "Qwen/Qwen3-Embedding-8B",
        queries=get_query_strings(benchmark),
        documents=get_document_contents(documents),
    )
    torch.save(similarities, similarity_cache_path)

In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from rag.util import cleanup


def format_prompts(query: str, instruction: str, documents: list[str]) -> list[str]:
    """Format query and documents into prompts for reranking."""
    if instruction:
        instruction = f" {instruction}"
    prompts = []
    for doc in documents:
        prompt = f"Check whether a given document contains information helpful to answer the query.\n<Document> {doc}\n<Query> {query}{instruction} ??"
        prompts.append(prompt)
    return prompts

def infer_w_hf(model, tokenizer, query: str, instruction: str, documents: list[str]):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    prompts = format_prompts(query, instruction, documents)
    enc = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    cleanup()

    next_logits = out.logits[:, -1, :]  # [batch, vocab]

    scores_bf16 = next_logits[:, 0].to(torch.bfloat16)
    scores = scores_bf16.float().tolist()

    # Sort by score (descending)
    results = sorted([(s, i, documents[i]) for i, s in enumerate(scores)], key=lambda x: x[0], reverse=True)
    return results

In [5]:
from tqdm.notebook import tqdm

from rag.metrics import similarities_to_ranks
from rag.embed import get_query_strings, get_document_contents


TOP_K = 32
model_path = "ContextualAI/ctxl-rerank-v2-instruct-multilingual-2b"

# Load the reranker
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # so -1 is the real last token for all prompts

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    dtype=dtype,
)
model.eval()

# Run the evaluation
ranks = similarities_to_ranks(similarities)
results = []
for idx, (test, doc_idxs) in tqdm(
    enumerate(zip(benchmark, ranks)),
    total=min(len(benchmark), len(ranks))
):
    top_documents = [documents[doc_idx] for doc_idx in doc_idxs[:TOP_K]]
    result = infer_w_hf(
        model, tokenizer,
        query=get_query_strings([benchmark[idx]])[0],
        instruction='',
        documents=get_document_contents(top_documents),
    )
    results.append(result)

reranks = []
for idx, result in enumerate(results):
    top_documents = ranks[idx, :TOP_K]
    base_document_idxs = [int(top_documents[relative_idx]) for score, relative_idx, content in result]
    reranks.append(base_document_idxs)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]



In [6]:
from rag.metrics import print_evaluations

print("Baseline evaluation")
print_evaluations(benchmark, documents, ranks)

print("\nReranked evaluation")
print_evaluations(benchmark, documents, reranks)

Baseline evaluation
precision @ 1 :  0.2450, recall @ 1 :  0.1492
precision @ 2 :  0.2265, recall @ 2 :  0.2895
precision @ 4 :  0.1758, recall @ 4 :  0.3731
precision @ 8 :  0.1408, recall @ 8 :  0.5243
precision @ 16:  0.1009, recall @ 16:  0.7081
precision @ 32:  0.0685, recall @ 32:  0.8412
precision @ 64:  0.0460, recall @ 64:  0.9538
AUC: 0.10639778453742046

Reranked evaluation
precision @ 1 :  0.3098, recall @ 1 :  0.1903
precision @ 2 :  0.2659, recall @ 2 :  0.2901
precision @ 4 :  0.2127, recall @ 4 :  0.4384
precision @ 8 :  0.1548, recall @ 8 :  0.5639
precision @ 16:  0.1100, recall @ 16:  0.7283
precision @ 32:  0.0685, recall @ 32:  0.8412
precision @ 64:  0.0685, recall @ 64:  0.8412
AUC: 0.12047913911889209
