## Load Dataset

In [1]:
from lib.load import extract_data, load_benchmark_corpus


extract_data()
benchmark, corpus = load_benchmark_corpus()

## Split Into Chunks

In [2]:
from langchain_text_splitters import RecursiveCharacterTextSplitter


text_splitter = RecursiveCharacterTextSplitter(
    separators=['\n\n', '\n', '!', '?', '.', ':', ';', ',', ' ', ''],
    chunk_size=500,
    chunk_overlap=0,
    add_start_index=True,
)

names, texts = zip(*corpus.items())
metadatas = [
    {"source_file": name}
    for idx, name in enumerate(names)
]

documents = text_splitter.create_documents(corpus.values(), metadatas=metadatas)
documents[:3]

[Document(metadata={'source_file': 'Fiverr.txt', 'start_index': 2}, page_content='At Fiverr we care about your privacy.\nWe do not sell or rent your personal information to third parties for their direct marketing purposes without your explicit consent.'),
 Document(metadata={'source_file': 'Fiverr.txt', 'start_index': 173}, page_content='We do not disclose it to others except as disclosed in this Policy or required to provide you with the services of the Site and mobile applications, meaning - to allow you to buy, sell, share the information you want to share on the Site; to contribute on the forum; pay for products; post reviews and so on; or where we have a legal obligation to do so.'),
 Document(metadata={'source_file': 'Fiverr.txt', 'start_index': 530}, page_content='We collect information that you provide us or voluntarily share with other users, and also some general technical information that is automatically gathered by our systems, such as IP address, browser information and 

## Embed Chunks

In [3]:
import gc
import torch

from sentence_transformers import SentenceTransformer
from transformers import BitsAndBytesConfig


def compute_similarities(benchmark, documents):
    # Load model
    model = SentenceTransformer(
        "Qwen/Qwen3-Embedding-8B",
        model_kwargs={"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}
    )
    # Compute embeddings
    document_embeddings = model.encode(
        [f"{document.metadata["source_file"]}: {document.page_content}" for document in documents],
        show_progress_bar=True,
    )
    query_embeddings = model.encode(
        [test['query'] for test in benchmark],
        prompt_name="query",
        show_progress_bar=True,
    )
    # Compute similarity
    similarities = model.similarity(query_embeddings, document_embeddings)
    # Cleanup
    del model
    cleanup()

    return similarities

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()

if True:
    similarities = torch.load("sim_cache")
else:
    similarities = compute_similarities(benchmark, documents)

In [4]:
import random

random.seed(1996)
idxs = random.sample(range(len(benchmark)), 20)

benchmark = [benchmark[idx] for idx in idxs]
similarities = similarities[idxs]

In [5]:
import gc

from tqdm.notebook import trange

import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


def rerank(queries, documents, scores, topk=32, batchsize=32):
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-8B", padding_side='left')
    
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-Reranker-0.6B",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        attn_implementation="flash_attention_2"
    ).eval()


    pairs = []
    extras = []
    for idx, (query, document_idxs) in enumerate(zip(queries, torch.argsort(scores, descending=True)[:, :topk])):
        pairs += [(query, documents[document_idx]) for document_idx in document_idxs]
        extras += [(idx, document_idx) for document_idx in document_idxs]

    new_scores = []
    new_extras = []
    for i in trange(0, len(pairs), batchsize):    
        new_scores += rescore_qwen(tokenizer, model, pairs[i:i+batchsize])
        new_extras += extras[i:i+batchsize]

    del model
    gc.collect()
    torch.cuda.empty_cache()
    
    return new_scores, new_extras
    

def rescore_qwen(tokenizer, model, pairs):
    def format_instruction(instruction, query, doc):
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(instruction=instruction,query=query, doc=doc)
        return output
    
    def process_inputs(pairs):
        inputs = tokenizer(
            pairs, padding=False, truncation='longest_first',
            return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
        )
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
        inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(model.device)
        return inputs
    
    @torch.no_grad()
    def compute_logits(inputs, **kwargs):
        batch_scores = model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, token_true_id]
        false_vector = batch_scores[:, token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores
    
    token_false_id = tokenizer.convert_tokens_to_ids("no")
    token_true_id = tokenizer.convert_tokens_to_ids("yes")
    max_length = 8192
    
    prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
    suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
            
    task = 'Given a web search query, retrieve relevant passages that answer the query'
    
    pairs = [format_instruction(task, query, doc) for query, doc in pairs]
    
    # Tokenize the input texts
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)
    
    return scores


new_score, new_extras = rerank(
    [test['query'] for test in benchmark],
    [f"{document.metadata["source_file"]}: {document.page_content}" for document in documents],
    similarities,
)

  0%|          | 0/20 [00:00<?, ?it/s]

You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [7]:
y = new_score
document_idxs = torch.argsort(similarities, descending=True)[:, :32].numpy()

blocks = []
for i in range(0, len(y), 64):
    block = list(zip(y[i:i+64], document_idxs[i // 64]))
    block.sort(reverse=True)
    if i == 0:
        print(block)
    blocks.append([int(idx) for _, idx in block])

[(0.9990234375, np.int64(188)), (0.99853515625, np.int64(233)), (0.99853515625, np.int64(187)), (0.998046875, np.int64(225)), (0.998046875, np.int64(192)), (0.99755859375, np.int64(250)), (0.99755859375, np.int64(218)), (0.99755859375, np.int64(191)), (0.99755859375, np.int64(189)), (0.9970703125, np.int64(258)), (0.9970703125, np.int64(240)), (0.9970703125, np.int64(222)), (0.99658203125, np.int64(249)), (0.99658203125, np.int64(174)), (0.99609375, np.int64(176)), (0.99560546875, np.int64(267)), (0.99560546875, np.int64(256)), (0.99560546875, np.int64(251)), (0.9951171875, np.int64(186)), (0.994140625, np.int64(273)), (0.99365234375, np.int64(163)), (0.9931640625, np.int64(184)), (0.99169921875, np.int64(223)), (0.9912109375, np.int64(195)), (0.98876953125, np.int64(190)), (0.98681640625, np.int64(246)), (0.98583984375, np.int64(169)), (0.9853515625, np.int64(175)), (0.984375, np.int64(198)), (0.97314453125, np.int64(199)), (0.96435546875, np.int64(211)), (0.93310546875, np.int64(161)

In [8]:
from lib.metrics import precision_recall

def evaluate_rag(benchmark, documents, similarities, topk):
    document_idxs_by_rank = torch.argsort(similarities, descending=True)[:, :topk]
    return evaluate_rag_reranked(benchmark, documents, document_idxs_by_rank, topk)


def evaluate_rag_reranked(benchmark, documents, document_idxs_by_rank, topk):
    precision = recall = 0
    count = 0
    for test, document_idxs in zip(benchmark, document_idxs_by_rank):
        document_idxs = document_idxs[:topk]
        # Compute spans
        spans_true = []
        for snippet in test["snippets"]:
            spans_true.append(snippet["span"])
        spans_pred = []
        for idx in document_idxs:
            document = documents[idx]
            start = document.metadata["start_index"]
            length = len(document.page_content)
            spans_pred.append((start, start + length))
        # Compute precision and recall
        p, r = precision_recall(spans_true, spans_pred)
        # Update accumulators
        precision += p
        recall += r
        count += 1
    return precision / count, recall / count

In [9]:
evaluate_rag_reranked(benchmark, documents, torch.argsort(similarities, descending=True), 4)

(0.24589244289064832, 0.3285857284815846)

In [11]:
evaluate_rag_reranked(benchmark, documents, blocks, 4)

(0.1364112818056034, 0.09770221601774717)

In [12]:
blocks[0]

[188,
 233,
 187,
 225,
 192,
 250,
 218,
 191,
 189,
 258,
 240,
 222,
 249,
 174,
 176,
 267,
 256,
 251,
 186,
 273,
 163,
 184,
 223,
 195,
 190,
 246,
 169,
 175,
 198,
 199,
 211,
 161]

In [16]:
len(set(torch.argsort(similarities, descending=True)[0].tolist()) & set(blocks[0]))

32