## Load Dataset

In [1]:
import json
import os
import random

from lib.load import extract_data, LEGALBENCH_RAG_PATH


extract_data()

def load_benchmark_corpus(subset="maud"):
    with open(os.path.join(LEGALBENCH_RAG_PATH, "benchmarks", f"{subset}.json")) as f:
        benchmark = json.load(f)['tests']
    
    corpus = {}
    corpus_path = os.path.join(LEGALBENCH_RAG_PATH, "corpus", subset)

    random.seed(42)
    
    for document in random.sample(os.listdir(corpus_path), 5):
        with open(os.path.join(corpus_path, document)) as f:
            corpus[document] = f.read()

    benchmark_sample = []
    for test in benchmark:
        file_path = test["snippets"][0]["file_path"]
        filename = os.path.basename(file_path)
        if filename in corpus:
            benchmark_sample.append(test)
            
    return benchmark_sample, corpus
benchmark, corpus = load_benchmark_corpus()

## Split Into Chunks

In [2]:
from langchain_text_splitters import RecursiveCharacterTextSplitter


text_splitter = RecursiveCharacterTextSplitter(
    separators=['\n\n', '\n', '!', '?', '.', ':', ';', ',', ' ', ''],
    chunk_size=500,
    chunk_overlap=0,
    add_start_index=True,
)

names, texts = zip(*corpus.items())
metadatas = [
    {"source_file": name}
    for idx, name in enumerate(names)
]

documents = text_splitter.create_documents(corpus.values(), metadatas=metadatas)
documents[:3]

[Document(metadata={'source_file': 'Domtar Corporation_Paper Excellence Canada Group.txt', 'start_index': 0}, page_content='\ufeffExhibit 2.1 \n\n\nExecution Version     AGREEMENT AND PLAN OF MERGER \n\n\namong \n\n\nDOMTAR CORPORATION, \n\n\nKARTA HALTEN B.V., \n\n\nand \n\n\nPEARL MERGER SUB INC. \n\n\nand \n\n\nPAPER EXCELLENCE B.V. \n\n\nand \n\n\nHERVEY INVESTMENTS B.V. \n\n\nDated as of May 10, 2021    \n\n\n\n\n\n\n\n\n________________'),
 Document(metadata={'source_file': 'Domtar Corporation_Paper Excellence Canada Group.txt', 'start_index': 278}, page_content='TABLE OF CONTENTS         Page  ARTICLE I    DEFINITIONS    Section 1.1   Definitions    6  Section 1.2   Table of Definitions    20  Section 1.3   Other Definitional and Interpretative Provisions    22  ARTICLE II    THE MERGER; EFFECT ON THE CAPITAL STOCK; PAYMENT    Section 2.1   The Merger    23  Section 2.2   Closing    23  Section 2.3   Effective Time    23  Section 2.4   Surviving Corporation Matters    24  Sectio

## Embed Chunks

In [3]:
import gc
import os
import torch

from sentence_transformers import SentenceTransformer
from transformers import BitsAndBytesConfig


def compute_similarities(benchmark, documents):
    # Load model
    model = SentenceTransformer(
        "Qwen/Qwen3-Embedding-8B",
        model_kwargs={"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}
    )
    # Compute embeddings
    document_embeddings = model.encode(
        [f"{document.metadata["source_file"]}: {document.page_content}" for document in documents],
        show_progress_bar=True,
    )
    query_embeddings = model.encode(
        [test['query'] for test in benchmark],
        prompt_name="query",
        show_progress_bar=True,
    )
    # Compute similarity
    similarities = model.similarity(query_embeddings, document_embeddings)
    # Cleanup
    del model
    cleanup()

    return similarities

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()


similarity_cache_path = "data/cache/04_similarities_maud.pt"
try:
    similarities = torch.load(similarity_cache_path)
except:
    similarities = compute_similarities(benchmark, documents)
    torch.save(similarities, similarity_cache_path)

In [4]:
ranks = torch.argsort(similarities, descending=True)

In [5]:
from tqdm.notebook import tqdm

from lib.rerank import Reranker


model_path = "ContextualAI/ctxl-rerank-v2-instruct-multilingual-2b"
TOP_K = 32

def rerank(benchmark, documents, ranks, model_path=model_path, topk=TOP_K):
    reranker = Reranker(model_path)
    
    results = []
    for idx, (test, doc_idxs) in tqdm(
        enumerate(zip(benchmark, ranks)),
        total=min(len(benchmark), len(ranks))
    ):
        result = reranker(
            query=benchmark[idx]['query'],
            instruction='',
            documents=[
                f"{documents[document_idx].metadata["source_file"]}: {documents[document_idx].page_content}" 
                for document_idx in ranks[idx, :TOP_K]
            ],
        )
        results.append(result)
    
        if idx % 8 == 0:
            cleanup()
    
    reranks = []
    for idx, result in enumerate(results):
        top_documents = ranks[idx, :TOP_K]
        base_document_idxs = [int(top_documents[relative_idx]) for score, relative_idx, content in result]
        reranks.append(base_document_idxs)
    return reranks

reranks = rerank(benchmark, documents, ranks)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]



In [6]:
from lib.metrics import print_evaluations

print("Baseline Evaluation")
print_evaluations(benchmark, documents, ranks)

print("\nReranked evaluation")
print_evaluations(benchmark, documents, reranks)

Baseline Evaluation
precision @ 1 :  0.1935, recall @ 1 :  0.0918
precision @ 2 :  0.1687, recall @ 2 :  0.1278
precision @ 4 :  0.1368, recall @ 4 :  0.2127
precision @ 8 :  0.1173, recall @ 8 :  0.3382
precision @ 16:  0.0766, recall @ 16:  0.4198
precision @ 32:  0.0522, recall @ 32:  0.5222
precision @ 64:  0.0352, recall @ 64:  0.6287
AUC: 0.05023092285649827

Reranked evaluation
precision @ 1 :  0.2991, recall @ 1 :  0.1622
precision @ 2 :  0.1922, recall @ 2 :  0.1920
precision @ 4 :  0.1533, recall @ 4 :  0.2744
precision @ 8 :  0.1145, recall @ 8 :  0.3628
precision @ 16:  0.0754, recall @ 16:  0.4222
precision @ 32:  0.0522, recall @ 32:  0.5222
precision @ 64:  0.0522, recall @ 64:  0.5222
AUC: 0.06668000103732792
