## Retrieval and Answer Generation (RAG)

This notebook implements the 3 main RAG types. (1) BM25 (2) Vector (3) Hybrid using Weaviate Vector Database.

In [1]:
import os
import json
import numpy as np
import pandas as pd
from typing import Dict, List
from tqdm import tqdm
from re import compile
import math
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize
from contractions import fix as fix_contractions
from functools import partial
from sklearn.feature_extraction import text as sk_text
from sentence_transformers import SentenceTransformer
import weaviate
from weaviate.classes.config import Configure, Property, DataType
from weaviate.classes.query import MetadataQuery
import matplotlib.pyplot as plt


In [2]:

sentence_transformer_model = SentenceTransformer(
    "BAAI/bge-large-en-v1.5",
    device="cuda"  # or "cpu"
)

## Dataset preparation

In [3]:
def load_qrels(docs_dir: str, fqrels: str) -> Dict[str, Dict[str, int]]:

    with open(fqrels, encoding='utf-8') as f:
        data = json.load(f)

    qrels = {}
    
    for e in data:
        qid = e["QuestionID"]
        for psg in e["Passages"]:
            qrels.setdefault(qid, {})
            pid = f"{psg['DocumentID']}-{psg['PassageID']}"
            qrels[qid][pid] = 1

    return qrels


file_type = 'test'
qrels = load_qrels("", "QnA_complete.json.")

with open("./data/qrels", "w") as f:
    for qid, rels in qrels.items():
        for pid, rel in rels.items():
            line = f"{qid} Q0 {pid} {rel}"
            f.write(line + "\n") 


with open('../all_data.json', 'r', encoding='utf-8') as f:
    all_data = json.load(f)

collection = []
seen = set()

for q in all_data:
    for psg in q['Passages']:
        psg_id = f"{psg['DocumentID']}-{psg['PassageID']}"
        if psg_id not in seen:
            passage_text = psg['PassageID'] + " " + psg['Passage']
            if len(passage_text) > 100:
                collection.append(
                    dict(
                        text=passage_text,
                        ID=psg_id,
                        DocumentId=psg['DocumentID'],
                        PassageId=psg['PassageID'],
                    )
                )
                seen.add(psg_id)

In [4]:
stop_words = set(stopwords.words('english'))
stop_words = sk_text.ENGLISH_STOP_WORDS.union(stop_words)
stemmer = SnowballStemmer(language='english')

pattern_newline = compile(r'[\n\t\u200e]') 
pattern_multiple_spaces = compile(r' +') 
pattern_non_alphanumeric = compile(r'[^a-z0-9]') 

def clean_text(text: str) -> str:

    cln_text = fix_contractions(text)
    
    cln_text = cln_text.lower()
    
    cln_text = pattern_newline.sub(' ', cln_text)
    
    cln_text = pattern_non_alphanumeric.sub(' ', cln_text)
    
    tokens = [stemmer.stem(word) for word in word_tokenize(cln_text) if word not in stop_words]
    
    cln_text = ' '.join(tokens)
    
    cln_text = pattern_multiple_spaces.sub(' ', cln_text).strip()
    
    return cln_text

def simple_cleaning(query: str) -> str:

    cln_query = pattern_newline.sub(' ', query)
    cln_query = pattern_multiple_spaces.sub(' ', cln_query).strip()
    return cln_query

def tokenizer(text:str)-> list:
   
    tokens = text.split()
    
    unigrams = tokens
    
    bigrams = [f"{tokens[i]} {tokens[i + 1]}" for i in range(len(tokens) - 1)]
    
    return unigrams + bigrams

In [5]:
tokenized_corpus = [tokenizer(clean_text(doc['text'])) for doc in collection]

In [6]:
collection_array = np.array(collection)

len(tokenized_corpus)

32810

In [7]:
results = {}

In [3]:
# Connect to local Weaviate (Docker)
client = weaviate.connect_to_local()  # defaults: localhost:8080 / 50051

print("Weaviate ready:", client.is_ready())

# Create the IrishSI collection (only if it doesn't already exist)
if "IrishSIPassage" not in client.collections.list_all():
    irishsi_coll = client.collections.create(
        name="IrishSIPassage",
        vector_config=Configure.Vectors.self_provided(),
        properties=[
            Property(name="text", data_type=DataType.TEXT),
            Property(name="documentId", data_type=DataType.TEXT),
            Property(name="passageId", data_type=DataType.TEXT),
        ],
    )
else:
    irishsi_coll = client.collections.get("IrishSIPassage")

print("Collections in Weaviate:", client.collections.list_all())

Weaviate ready: True
Collections in Weaviate: {'IrishSIPassage': _CollectionConfigSimple(name='IrishSIPassage', description=None, generative_config=None, properties=[_Property(name='text', description=None, data_type=<DataType.TEXT: 'text'>, index_filterable=True, index_range_filters=False, index_searchable=True, nested_properties=None, tokenization=<Tokenization.WORD: 'word'>, vectorizer_config=None, vectorizer=None, vectorizer_configs={}), _Property(name='documentId', description=None, data_type=<DataType.TEXT: 'text'>, index_filterable=True, index_range_filters=False, index_searchable=True, nested_properties=None, tokenization=<Tokenization.WORD: 'word'>, vectorizer_config=None, vectorizer=None, vectorizer_configs={}), _Property(name='passageId', description=None, data_type=<DataType.TEXT: 'text'>, index_filterable=True, index_range_filters=False, index_searchable=True, nested_properties=None, tokenization=<Tokenization.WORD: 'word'>, vectorizer_config=None, vectorizer=None, vectori

Only run the following cells (cells 15 and 16) once each on intital run or dataset change 

In [None]:
# # Delete existing collection completely
# if "IrishSIPassage" in client.collections.list_all():
#     client.collections.delete("IrishSIPassage")

# # Recreate the collection (same schema)
# irishsi_coll = client.collections.create(
#     name="IrishSIPassage",
#     vector_config=Configure.Vectors.self_provided(),
#     properties=[
#         Property(name="text", data_type=DataType.TEXT),
#         Property(name="documentId", data_type=DataType.TEXT),
#         Property(name="passageId", data_type=DataType.TEXT),
#     ],
# )

In [None]:
# collection_data = collection  # if you haven't already aliased it

# irishsi_coll = client.collections.get("IrishSIPassage")

# def to_str_id(value):
#     """
#     Safely convert any ID to a string for Weaviate.
#     Handles floats/ints/None/NaN from pandas/JSON.
#     """
#     if value is None:
#         return "None"
#     # Handle numpy / pandas NaN
#     try:
#         import math
#         if isinstance(value, float) and math.isnan(value):
#             return "NaN"
#     except Exception:
#         pass

#     # If it's numeric like 123.0 -> "123"
#     if isinstance(value, (int, float)):
#         # strip .0 if it's an int-ish float
#         if isinstance(value, float) and value.is_integer():
#             return str(int(value))
#         return str(value)

#     return str(value)

# with irishsi_coll.batch.dynamic() as batch:
#     for doc in tqdm(collection_data, desc="Ingesting IrishSI into Weaviate"):
#         text = doc["text"]

#         raw_document_id = doc.get("DocumentId")
#         raw_passage_id = doc.get("PassageId")

#         document_id = to_str_id(raw_document_id)
#         passage_id = to_str_id(raw_passage_id)

#         # Compute dense vector with your fine-tuned model
#         vector = sentence_transformer_model.encode(
#             text,
#             normalize_embeddings=True
#         ).tolist()

#         batch.add_object(
#             properties={
#                 "text": text,
#                 "documentId": document_id,
#                 "passageId": passage_id,
#             },
#             vector=vector,
#         )

# print("Ingestion complete.")

Ingesting IrishSI into Weaviate: 100%|██████████| 32810/32810 [20:40<00:00, 26.44it/s]


Ingestion complete.


In [11]:
irishsi_coll = client.collections.get("IrishSIPassage")

result = irishsi_coll.query.bm25(
    query="test"
)

print("Hits:", len(result.objects))
for obj in result.objects:
    print(obj.properties)

Hits: 10
{'text': 'schedule-2 SCHEDULE 2 \n\nNew Schedule 2 to Principal Regulations \n\nRegulation 2(e) \n\n“SCHEDULE 2 \n\nRegulation 6(1) \n\nAPPLICATION FEES \n\nDescription of fee \n\nPrescribed Fee \n\n(1) \n\nApplication for authorisation as a \nCVR test operator (including 1 heavy \nCVR vehicle test lane and 1 light \nCVR vehicle test lane only) \n\nApplication for authorisation as a \nCVR test operator with more than \none 1 heavy CVR vehicle test lane \nand 1 light CVR vehicle test \n\nlane \n\n(2) \n\n€8,500 \n\n€6,000 for each additional test lane \n\nApplication for renewal of \nauthorisation as a CVR test operator \n\n€500 \n\nApplication for amendment of \nauthorisation in relation to ADR \ntesting \n\n€500 \n\nApplication for amendment of \nauthorisation as a CVR test operator \nto increase number of test lanes \n\n€6,000 for each additional test lane \n\n” \n\n \n \n \n \n \n \n \n \n \n \n \n\x0c[475] 73 \n\nRegulation 2(f)', 'documentId': 'si-2022-0475', 'passageId':

In [3]:
# Load queries from IrishSI QnA file

with open("QnA_complete.json", "r", encoding="utf-8") as f:
    qna_data = json.load(f)

queries = []
for item in qna_data:
    queries.append({
        "query_id": item["QuestionID"],   
        "query": item["Question"]         
    })

print("Number of queries:", len(queries))
print("Example:", queries[0])

Number of queries: 240
Example: {'query_id': '8d9c9c4a-2a66-4d1e-8b0b-0bbf3e3f1c2e', 'query': 'When does a person who purchases a dwelling on or after the commencement of these Regulations cease to qualify as a “relevant owner” for the scheme?'}


In [4]:
def dedupe_list(seq):
    seen = set()
    output = []
    for x in seq:
        if x not in seen:
            output.append(x)
            seen.add(x)
    return output

In [5]:
def average_precision_at_k(retrieved_ids: List[str], relevant_ids: List[str], k: int = 20):
    score = 0.0
    hits = 0
    for i, doc_id in enumerate(retrieved_ids[:k]):
        if doc_id in relevant_ids:
            hits += 1
            score += hits / (i + 1)
    return score / max(1, len(relevant_ids))

def recall_at_k(retrieved_ids: List[str], relevant_ids: List[str], k: int = 20):
    retrieved_relevant = len([d for d in retrieved_ids[:k] if d in relevant_ids])
    return retrieved_relevant / max(1, len(relevant_ids))

def ndcg_at_k(retrieved_ids: List[str], relevant_ids: List[str], k: int = 20):
    dcg = 0.0
    for i, doc_id in enumerate(retrieved_ids[:k]):
        if doc_id in relevant_ids:
            dcg += 1 / math.log2(i + 2)
    # ideal DCG
    ideal_hits = min(len(relevant_ids), k)
    idcg = sum([1 / math.log2(i + 2) for i in range(ideal_hits)])
    if idcg == 0:
        return 0.0
    return dcg / idcg

In [6]:
def evaluate_retriever(retriever_function, queries, qrels, k=20, show_progress=True):
    ap_scores = []
    recall_scores = []
    ndcg_scores = []

    iterator = tqdm(queries, desc=f"Evaluating {retriever_function.__name__}") if show_progress else queries
    
    for q in iterator:
        qid = q["query_id"]
        relevant = list(qrels.get(qid, {}).keys())

        results = retriever_function(q["query"], top_n=k)

        retrieved_ids = dedupe_list([r["ID"] for r in results])[:k]

        ap_scores.append(average_precision_at_k(retrieved_ids, relevant, k))
        recall_scores.append(recall_at_k(retrieved_ids, relevant, k))
        ndcg_scores.append(ndcg_at_k(retrieved_ids, relevant, k))

    return {
        "MAP@{}".format(k): sum(ap_scores) / len(ap_scores),
        "Recall@{}".format(k): sum(recall_scores) / len(recall_scores),
        "nDCG@{}".format(k): sum(ndcg_scores) / len(ndcg_scores),
    }

## Weaviate Retriever: BM25

Let us evaluate the weaviate retriever using BM25

In [7]:
def weaviate_bm25_search(query: str, top_n: int = 20):
    irishsi_coll = client.collections.get("IrishSIPassage")

    try:
        result = irishsi_coll.query.bm25(
            query=query
        )
    except Exception as e:
        print("BM25 error:", e)
        return []

    hits = []
    for obj in result.objects[:top_n]:
        props = obj.properties
        hits.append({
            "text": props["text"],
            "DocumentId": props["documentId"],
            "PassageId": props["passageId"],
            "ID": f"{props['documentId']}-{props['passageId']}",
        })

    return hits

In [None]:
query = "If an eligible applicant later receives a payment outside this scheme for the same damage, what are their notification and repayment obligations to the local authority?"
results = weaviate_bm25_search(query, top_n=5)

df = pd.DataFrame(results)
df['text_preview'] = df['text'].str[:100] + '...'
df = df[['ID', 'DocumentId', 'PassageId', 'text_preview']]

print(f"Query: '{query}'")
display(df)

Query: 'If an eligible applicant later receives a payment outside this scheme for the same damage, what are their notification and repayment obligations to the local authority?'


Unnamed: 0,ID,DocumentId,PassageId,text_preview
0,si-2020-0025-reg-12,si-2020-0025,reg-12,reg-12 12. (1) Where an application under Regu...
1,si-2020-0025-reg-11,si-2020-0025,reg-11,reg-11 11. (1) Without prejudice to the genera...
2,si-2020-0025-reg-9,si-2020-0025,reg-9,reg-9 9. (1) An applicant who has received a c...
3,si-2024-0103-reg-26,si-2024-0103,reg-26,reg-26 26. (1) Where an awarding authority mak...
4,si-2025-0070-reg-26,si-2025-0070,reg-26,reg-26 26. (1) Where the awarding authority ma...


In [17]:
print("BM25 Results:")
results["bm25"] = {}
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(weaviate_bm25_search, queries, qrels, k=k, show_progress=(i==0))
    results["bm25"][f"MAP@{k}"] = metrics[f"MAP@{k}"]
    results["bm25"][f"Recall@{k}"] = metrics[f"Recall@{k}"]
    results["bm25"][f"nDCG@{k}"] = metrics[f"nDCG@{k}"]
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")


BM25 Results:


Evaluating weaviate_bm25_search: 100%|██████████| 240/240 [00:01<00:00, 141.57it/s]


  k=10: MAP=0.2942, Recall=0.5000, nDCG=0.3438
  k=20: MAP=0.2942, Recall=0.5000, nDCG=0.3438


## Weaviate Retriever: Vector

Let us evaluate the weaviate retriever using Vector search

In [10]:
def weaviate_vector_search(query: str, top_n: int = 20):

    irishsi_coll = client.collections.get("IrishSIPassage")

    query_vec = sentence_transformer_model.encode(
        query,
        normalize_embeddings=True
    ).tolist()

    result = irishsi_coll.query.near_vector(
        near_vector=query_vec,
        limit=top_n,
        return_metadata=MetadataQuery(score=True, distance=True),
    )

    hits = []
    for obj in result.objects:
        props = obj.properties
        hits.append({
            "text": props["text"],
            "DocumentId": props["documentId"],
            "PassageId": props["passageId"],
            "ID": f"{props['documentId']}-{props['passageId']}",
            "score": obj.metadata.score,
            "distance": obj.metadata.distance,
        })

    return hits

In [14]:
query = "If an eligible applicant later receives a payment outside this scheme for the same damage, what are their notification and repayment obligations to the local authority?"
results = weaviate_vector_search(query, top_n=5)

df = pd.DataFrame(results)
df['text_preview'] = df['text'].str[:100] + '...'
df = df[['ID', 'DocumentId', 'PassageId',  'distance', 'text_preview']]

print(f"Query: '{query}'")
display(df)


Query: 'If an eligible applicant later receives a payment outside this scheme for the same damage, what are their notification and repayment obligations to the local authority?'


Unnamed: 0,ID,DocumentId,PassageId,distance,text_preview
0,si-2023-0347-reg-54,si-2023-0347,reg-54,0.243071,reg-54 54. Where I have received or receive in...
1,si-2023-0347-reg-53,si-2023-0347,reg-53,0.249695,reg-53 53. Where a payment otherwise than unde...
2,si-2023-0347-reg-41,si-2023-0347,reg-41,0.251202,reg-41 41. Where I have received or receive in...
3,si-2023-0347-reg-40,si-2023-0347,reg-40,0.2531,reg-40 40. Where a payment otherwise than unde...
4,si-2020-0025-reg-11,si-2020-0025,reg-11,0.260702,reg-11 11. (1) Without prejudice to the genera...


In [19]:
print("Vector Results:")
results["vector"] = {}
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(weaviate_vector_search, queries, qrels, k=k, show_progress=(i==0))
    results["vector"][f"MAP@{k}"] = metrics[f"MAP@{k}"]
    results["vector"][f"Recall@{k}"] = metrics[f"Recall@{k}"]
    results["vector"][f"nDCG@{k}"] = metrics[f"nDCG@{k}"]
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")


Vector Results:


Evaluating weaviate_vector_search: 100%|██████████| 240/240 [00:06<00:00, 35.27it/s]


  k=10: MAP=0.3752, Recall=0.6000, nDCG=0.4294
  k=20: MAP=0.3834, Recall=0.6562, nDCG=0.4472


## Weaviate Hybrid Retriever (BM25 + Vector)

Hybrid retriever using a combination of BM25 + Vector

In [4]:
def weaviate_hybrid_search(query: str, top_n: int = 20, alpha: float = 0.5):

    irishsi_coll = client.collections.get("IrishSIPassage")

    query_vec = sentence_transformer_model.encode(
        query,
        normalize_embeddings=True
    ).tolist()

    result = irishsi_coll.query.hybrid(
        query=query,
        vector=query_vec,
        alpha=alpha,
        limit=top_n,
        return_metadata=MetadataQuery(score=True),
    )

    hits = []
    for obj in result.objects:
        props = obj.properties
        hits.append({
            "text": props["text"],
            "DocumentId": props["documentId"],
            "PassageId": props["passageId"],
            "ID": f"{props['documentId']}-{props['passageId']}",
            "score": obj.metadata.score,
        })

    return hits

In [18]:
# Strong reranker
from sentence_transformers import CrossEncoder


reranker = CrossEncoder("BAAI/bge-reranker-large", device="cuda")

# query = "If an eligible applicant later receives a payment outside this scheme for the same damage, what are their notification and repayment obligations to the local authority?"
query = "According to the Explanatory Note, what is the substantive effect of S.I. No. 128 of 2020 on the operation period of the Covid-19 Temporary Restrictions Regulations?"
alpha = 0.5

results = weaviate_hybrid_search(query, top_n=5, alpha=alpha)


pd.set_option('display.max_colwidth', None)  
pd.set_option('display.max_rows', None)    

df = pd.DataFrame(results)
df['text_preview'] = df['text'].str[:200] + '...'
df['rank'] = range(1, len(df) + 1)
df['original_score'] = df['score']  # Keep original score for comparison

print(f"Query: '{query}'")
print(f"Alpha: {alpha} (Balance between BM25 and Vector)")

print("Re-ranking with Cross-Encoder...")
query_passage_pairs = [[query, text] for text in df['text']]
rerank_scores = reranker.predict(query_passage_pairs)

df['rerank_score'] = rerank_scores
df = df.sort_values(by='rerank_score', ascending=False).reset_index(drop=True)
df['rank'] = range(1, len(df) + 1)

df_display = df[['rank', 'ID', 'DocumentId', 'PassageId', 'original_score', 'rerank_score', 'text_preview']]
display(df_display)

Query: 'According to the Explanatory Note, what is the substantive effect of S.I. No. 128 of 2020 on the operation period of the Covid-19 Temporary Restrictions Regulations?'
Alpha: 0.5 (Balance between BM25 and Vector)
Re-ranking with Cross-Encoder...


Unnamed: 0,rank,ID,DocumentId,PassageId,original_score,rerank_score,text_preview
0,1,si-2020-0128-explanatory-note,si-2020-0128,explanatory-note,0.748943,0.94951,explanatory-note EXPLANATORY NOTE \n\n(This note is not part of the Instrument and does not purport to be a legal \ninterpretation). \n\nThese Regulations amend the Health Act 1947 (Section 31A-Temporary \n...
1,2,si-2021-0273-explanatory-note,si-2021-0273,explanatory-note,0.718136,0.850439,explanatory-note EXPLANATORY NOTE \n\n(This note is not part of the Instrument and does not purport to be a legal \ninterpretation.) \n\nThese Regulations amend the Health Act 1947 (Section 31A – Temporary...
2,3,si-2022-0048-explanatory-note,si-2022-0048,explanatory-note,0.727485,0.842393,explanatory-note EXPLANATORY NOTE \n\n(This note is not part of the Instrument and does not purport to be a legal \ninterpretation.) \n\nThese Regulations amend the Health Act 1947 (Section 31A – Temporary...
3,4,si-2021-0513-explanatory-note,si-2021-0513,explanatory-note,0.699032,0.5286,explanatory-note EXPLANATORY NOTE \n\n(This note is not part of the Instrument and does not purport to be a legal \ninterpretation.) \n\nThese Regulations amend the Health Act 1947 (Section 31A - Temporary...
4,5,si-2021-0276-explanatory-note,si-2021-0276,explanatory-note,0.480185,0.334487,explanatory-note EXPLANATORY NOTE \n\n(This note is not part of the Instrument and does not purport to be a legal \ninterpretation.) \n\nThese Regulations extend until 19th July 2021 the effective date of ...


In [21]:
print("Hybrid Results:")
results["hybrid"] = {}
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(weaviate_hybrid_search, queries, qrels, k=k, show_progress=(i==0))
    results["hybrid"][f"MAP@{k}"] = metrics[f"MAP@{k}"]
    results["hybrid"][f"Recall@{k}"] = metrics[f"Recall@{k}"]
    results["hybrid"][f"nDCG@{k}"] = metrics[f"nDCG@{k}"]
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")


Hybrid Results:


Evaluating weaviate_hybrid_search: 100%|██████████| 240/240 [00:08<00:00, 28.09it/s]


  k=10: MAP=0.3572, Recall=0.5708, nDCG=0.4085
  k=20: MAP=0.3630, Recall=0.6521, nDCG=0.4293


In [22]:
# Strong reranker
from sentence_transformers import CrossEncoder


reranker = CrossEncoder("BAAI/bge-reranker-large", device="cuda")

In [23]:
def rerank_results(query: str, candidates: list, top_k: int = 20):
    if len(candidates) == 0:
        return []

    pairs = [[query, c["text"]] for c in candidates]
    scores = reranker.predict(pairs)

    reranked = sorted(
        zip(candidates, scores),
        key=lambda x: x[1],
        reverse=True
    )

    result = []
    for c, s in reranked[:top_k]:
        c["rerank_score"] = float(s)
        result.append(c)

    return result

In [24]:
def hybrid_plus_rerank_search(query: str, top_n: int = 20, pool_size: int = 50, alpha: float = 0.7):
    candidates = weaviate_hybrid_search(query, top_n=pool_size)

    reranked = rerank_results(query, candidates, top_k=top_n)

    return reranked

In [25]:
def rerank_retriever(retriever_fn, query: str, top_n: int = 20, pool_size: int = 50):
 
    candidates = retriever_fn(query, top_n=pool_size)

    return rerank_results(query, candidates, top_k=top_n)

In [26]:
def bm25_plus_rerank(query: str, top_n: int = 20):
    return rerank_retriever(weaviate_bm25_search, query, top_n=top_n)

def vector_plus_rerank(query: str, top_n: int = 20):
    return rerank_retriever(weaviate_vector_search, query, top_n=top_n)

def hybrid_plus_rerank(query: str, top_n: int = 20):
    return rerank_retriever(weaviate_hybrid_search, query, top_n=top_n)

In [27]:
print("BM25 + Rerank Results:")
results["bm25_rerank"] = {}
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(bm25_plus_rerank, queries, qrels, k=k)
    results["bm25_rerank"][f"MAP@{k}"] = metrics[f"MAP@{k}"]
    results["bm25_rerank"][f"Recall@{k}"] = metrics[f"Recall@{k}"]
    results["bm25_rerank"][f"nDCG@{k}"] = metrics[f"nDCG@{k}"]
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")

print("Vector + Rerank Results:")
results["vector_rerank"] = {}
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(vector_plus_rerank, queries, qrels, k=k)
    results["vector_rerank"][f"MAP@{k}"] = metrics[f"MAP@{k}"]
    results["vector_rerank"][f"Recall@{k}"] = metrics[f"Recall@{k}"]
    results["vector_rerank"][f"nDCG@{k}"] = metrics[f"nDCG@{k}"]
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")

print("Hybrid + Rerank Results:")
results["hybrid_rerank"] = {}
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(hybrid_plus_rerank, queries, qrels, k=k)
    results["hybrid_rerank"][f"MAP@{k}"] = metrics[f"MAP@{k}"]
    results["hybrid_rerank"][f"Recall@{k}"] = metrics[f"Recall@{k}"]
    results["hybrid_rerank"][f"nDCG@{k}"] = metrics[f"nDCG@{k}"]
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")


BM25 + Rerank Results:


Evaluating bm25_plus_rerank: 100%|██████████| 240/240 [02:21<00:00,  1.70it/s]


  k=10: MAP=0.3957, Recall=0.5042, nDCG=0.4220


Evaluating bm25_plus_rerank: 100%|██████████| 240/240 [02:16<00:00,  1.75it/s]


  k=20: MAP=0.3957, Recall=0.5042, nDCG=0.4220
Vector + Rerank Results:


Evaluating vector_plus_rerank: 100%|██████████| 240/240 [11:49<00:00,  2.95s/it]


  k=10: MAP=0.4458, Recall=0.6583, nDCG=0.4965


Evaluating vector_plus_rerank: 100%|██████████| 240/240 [11:35<00:00,  2.90s/it]


  k=20: MAP=0.4481, Recall=0.6896, nDCG=0.5047
Hybrid + Rerank Results:


Evaluating hybrid_plus_rerank: 100%|██████████| 240/240 [12:40<00:00,  3.17s/it]


  k=10: MAP=0.4458, Recall=0.6667, nDCG=0.4983


Evaluating hybrid_plus_rerank: 100%|██████████| 240/240 [12:47<00:00,  3.20s/it]

  k=20: MAP=0.4483, Recall=0.7021, nDCG=0.5075





In [28]:
order = ["bm25", "vector", "hybrid", "bm25_rerank", "vector_rerank", "hybrid_rerank"]
df = pd.DataFrame(results).T.loc[order]

retrievers = df.index.tolist()
x = np.arange(len(retrievers))
width = 0.35

recall_10_vals = df["Recall@10"].values
recall_20_vals = df["Recall@20"].values

plt.figure(figsize=(12,6))
plt.bar(x - width/2, recall_10_vals, width, label="Recall@10")
plt.bar(x + width/2, recall_20_vals, width, label="Recall@20")
plt.xticks(x, retrievers, rotation=45)
plt.ylabel("Recall")
plt.xlabel("retriever")
plt.title("Weaviate RAG Recall Comparison (Simple vs Rerank)")
plt.ylim(0, 1.0)
plt.grid(axis="y", linestyle="--", alpha=0.5)
for i, v in enumerate(recall_10_vals):
    plt.text(x[i] - width/2, v + 0.01, f"{v:.3f}", ha="center", fontsize=9)
for i, v in enumerate(recall_20_vals):
    plt.text(x[i] + width/2, v + 0.01, f"{v:.3f}", ha="center", fontsize=9)
plt.legend()
plt.tight_layout()
plt.savefig("weaviate_comparison_recall.png")
plt.close()

map_10_vals = df["MAP@10"].values
map_20_vals = df["MAP@20"].values

plt.figure(figsize=(12,6))
plt.bar(x - width/2, map_10_vals, width, label="MAP@10")
plt.bar(x + width/2, map_20_vals, width, label="MAP@20")
plt.xticks(x, retrievers, rotation=45)
plt.ylabel("MAP")
plt.xlabel("retriever")
plt.title("Weaviate RAG MAP Comparison (Simple vs Rerank)")
plt.ylim(0, 1.0)
plt.grid(axis="y", linestyle="--", alpha=0.5)
for i, v in enumerate(map_10_vals):
    plt.text(x[i] - width/2, v + 0.01, f"{v:.3f}", ha="center", fontsize=9)
for i, v in enumerate(map_20_vals):
    plt.text(x[i] + width/2, v + 0.01, f"{v:.3f}", ha="center", fontsize=9)
plt.legend()
plt.tight_layout()
plt.savefig("weaviate_comparison_map.png")
plt.close()


In [29]:
def save_rankings_to_trec(queries, retriever_function, output_file, method_name="retriever"):

    with open(output_file, 'w') as f:
        for query in tqdm(queries, desc=f"Generating TREC rankings for {method_name}"):
            qid = query["query_id"]
            results = retriever_function(query["query"], top_n=100)
            
            for rank, result in enumerate(results, start=1):
                passage_id = result["ID"]
                score = result.get("score", result.get("rerank_score", 0.0))  
                
                
                line = f"{qid} Q0 {passage_id} {rank} {score:.6f} {method_name}\n"
                f.write(line)
    
    print(f"Rankings saved to {output_file}")


print("Generating BM25 rankings...")
save_rankings_to_trec(queries, weaviate_bm25_search, "data/rankings_bm25.trec", "bm25")


print("Generating Vector rankings...")
save_rankings_to_trec(queries, weaviate_vector_search, "data/rankings_vector.trec", "vector")


print("Generating Hybrid rankings...")
save_rankings_to_trec(queries, weaviate_hybrid_search, "data/rankings_hybrid.trec", "hybrid")


print("Generating BM25 + Rerank rankings...")
save_rankings_to_trec(queries, bm25_plus_rerank, "data/rankings_bm25_rerank.trec", "bm25_rerank")


print("Generating Vector + Rerank rankings...")
save_rankings_to_trec(queries, vector_plus_rerank, "data/rankings_vector_rerank.trec", "vector_rerank")

print("Generating Hybrid + Rerank rankings...")
save_rankings_to_trec(queries, hybrid_plus_rerank, "data/rankings_hybrid_rerank.trec", "hybrid_rerank")



Generating BM25 rankings...


Generating TREC rankings for bm25: 100%|██████████| 240/240 [00:01<00:00, 167.94it/s]


Rankings saved to data/rankings_bm25.trec
Generating Vector rankings...


Generating TREC rankings for vector: 100%|██████████| 240/240 [00:08<00:00, 29.78it/s]


Rankings saved to data/rankings_vector.trec
Generating Hybrid rankings...


Generating TREC rankings for hybrid: 100%|██████████| 240/240 [00:09<00:00, 25.14it/s]


Rankings saved to data/rankings_hybrid.trec
Generating BM25 + Rerank rankings...


Generating TREC rankings for bm25_rerank: 100%|██████████| 240/240 [02:16<00:00,  1.76it/s]


Rankings saved to data/rankings_bm25_rerank.trec
Generating Vector + Rerank rankings...


Generating TREC rankings for vector_rerank: 100%|██████████| 240/240 [11:34<00:00,  2.89s/it]


Rankings saved to data/rankings_vector_rerank.trec
Generating Hybrid + Rerank rankings...


Generating TREC rankings for hybrid_rerank: 100%|██████████| 240/240 [12:23<00:00,  3.10s/it]

Rankings saved to data/rankings_hybrid_rerank.trec



