## Retrieval and Answer Generation (RAG)

This notebook implements the 3 main RAG types. (1) BM25 (2) Vector (3) Hybrid using Weaviate Vector Database.

In [8]:
import os
import json
import numpy as np
import pandas as pd
from typing import Dict, List
from tqdm import tqdm
from re import compile
import math
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize
from contractions import fix as fix_contractions
from functools import partial
from sklearn.feature_extraction import text as sk_text

from sentence_transformers import SentenceTransformer

# Weaviate v4 imports
import weaviate
from weaviate.classes.config import Configure, Property, DataType
from weaviate.classes.query import MetadataQuery

In [9]:

sentence_transformer_model = SentenceTransformer(
    "BAAI/bge-large-en-v1.5",
    device="cuda"  # or "cpu"
)

## Dataset preparation

In [10]:
def load_qrels(docs_dir: str, fqrels: str) -> Dict[str, Dict[str, int]]:

    with open(fqrels, encoding='utf-8') as f:
        data = json.load(f)

    qrels = {}
    
    for e in data:
        qid = e["QuestionID"]
        for psg in e["Passages"]:
            qrels.setdefault(qid, {})
            pid = f"{psg['DocumentID']}-{psg['PassageID']}"
            qrels[qid][pid] = 1

    return qrels


file_type = 'test'
qrels = load_qrels("", "QnA_complete_fixed.json.")

with open("./data/qrels", "w") as f:
    for qid, rels in qrels.items():
        for pid, rel in rels.items():
            line = f"{qid} Q0 {pid} {rel}"
            f.write(line + "\n") 


with open('../../all_data.json', 'r', encoding='utf-8') as f:
    all_data = json.load(f)

collection = []
seen = set()

for q in all_data:
    for psg in q['Passages']:
        psg_id = f"{psg['DocumentID']}-{psg['PassageID']}"
        if psg_id not in seen:
            passage_text = psg['PassageID'] + " " + psg['Passage']
            if len(passage_text) > 100:
                collection.append(
                    dict(
                        text=passage_text,
                        ID=psg_id,
                        DocumentId=psg['DocumentID'],
                        PassageId=psg['PassageID'],
                    )
                )
                seen.add(psg_id)

In [11]:
stop_words = set(stopwords.words('english'))
stop_words = sk_text.ENGLISH_STOP_WORDS.union(stop_words)
stemmer = SnowballStemmer(language='english')

pattern_newline = compile(r'[\n\t\u200e]') 
pattern_multiple_spaces = compile(r' +') 
pattern_non_alphanumeric = compile(r'[^a-z0-9]') 

def clean_text(text: str) -> str:

    cln_text = fix_contractions(text)
    
    cln_text = cln_text.lower()
    
    cln_text = pattern_newline.sub(' ', cln_text)
    
    cln_text = pattern_non_alphanumeric.sub(' ', cln_text)
    
    tokens = [stemmer.stem(word) for word in word_tokenize(cln_text) if word not in stop_words]
    
    cln_text = ' '.join(tokens)
    
    cln_text = pattern_multiple_spaces.sub(' ', cln_text).strip()
    
    return cln_text

def simple_cleaning(query: str) -> str:

    cln_query = pattern_newline.sub(' ', query)
    cln_query = pattern_multiple_spaces.sub(' ', cln_query).strip()
    return cln_query

def tokenizer(text:str)-> list:
   
    tokens = text.split()
    
    unigrams = tokens
    
    bigrams = [f"{tokens[i]} {tokens[i + 1]}" for i in range(len(tokens) - 1)]
    
    return unigrams + bigrams

In [12]:
tokenized_corpus = [tokenizer(clean_text(doc['text'])) for doc in collection]

In [13]:
collection_array = np.array(collection)

len(tokenized_corpus)

32810

In [14]:
# Connect to local Weaviate (Docker)
client = weaviate.connect_to_local()  # defaults: localhost:8080 / 50051

print("Weaviate ready:", client.is_ready())

# Create the IrishSI collection (only if it doesn't already exist)
if "IrishSIPassage" not in client.collections.list_all():
    irishsi_coll = client.collections.create(
        name="IrishSIPassage",
        vector_config=Configure.Vectors.self_provided(),
        properties=[
            Property(name="text", data_type=DataType.TEXT),
            Property(name="documentId", data_type=DataType.TEXT),
            Property(name="passageId", data_type=DataType.TEXT),
        ],
    )
else:
    irishsi_coll = client.collections.get("IrishSIPassage")

print("Collections in Weaviate:", client.collections.list_all())

Weaviate ready: True
Collections in Weaviate: {'IrishSIPassage': _CollectionConfigSimple(name='IrishSIPassage', description=None, generative_config=None, properties=[_Property(name='text', description=None, data_type=<DataType.TEXT: 'text'>, index_filterable=True, index_range_filters=False, index_searchable=True, nested_properties=None, tokenization=<Tokenization.WORD: 'word'>, vectorizer_config=None, vectorizer=None, vectorizer_configs={}), _Property(name='documentId', description=None, data_type=<DataType.TEXT: 'text'>, index_filterable=True, index_range_filters=False, index_searchable=True, nested_properties=None, tokenization=<Tokenization.WORD: 'word'>, vectorizer_config=None, vectorizer=None, vectorizer_configs={}), _Property(name='passageId', description=None, data_type=<DataType.TEXT: 'text'>, index_filterable=True, index_range_filters=False, index_searchable=True, nested_properties=None, tokenization=<Tokenization.WORD: 'word'>, vectorizer_config=None, vectorizer=None, vectori

Only run the following cells (cells 15 and 16) once each on intital run or dataset change 

In [15]:
# # Delete existing collection completely
# if "IrishSIPassage" in client.collections.list_all():
#     client.collections.delete("IrishSIPassage")

# # Recreate the collection (same schema)
# irishsi_coll = client.collections.create(
#     name="IrishSIPassage",
#     vector_config=Configure.Vectors.self_provided(),
#     properties=[
#         Property(name="text", data_type=DataType.TEXT),
#         Property(name="documentId", data_type=DataType.TEXT),
#         Property(name="passageId", data_type=DataType.TEXT),
#     ],
# )

In [16]:
# collection_data = collection  # if you haven't already aliased it

# irishsi_coll = client.collections.get("IrishSIPassage")

# def to_str_id(value):
#     """
#     Safely convert any ID to a string for Weaviate.
#     Handles floats/ints/None/NaN from pandas/JSON.
#     """
#     if value is None:
#         return "None"
#     # Handle numpy / pandas NaN
#     try:
#         import math
#         if isinstance(value, float) and math.isnan(value):
#             return "NaN"
#     except Exception:
#         pass

#     # If it's numeric like 123.0 -> "123"
#     if isinstance(value, (int, float)):
#         # strip .0 if it's an int-ish float
#         if isinstance(value, float) and value.is_integer():
#             return str(int(value))
#         return str(value)

#     return str(value)

# with irishsi_coll.batch.dynamic() as batch:
#     for doc in tqdm(collection_data, desc="Ingesting IrishSI into Weaviate"):
#         text = doc["text"]

#         raw_document_id = doc.get("DocumentId")
#         raw_passage_id = doc.get("PassageId")

#         document_id = to_str_id(raw_document_id)
#         passage_id = to_str_id(raw_passage_id)

#         # Compute dense vector with your fine-tuned model
#         vector = sentence_transformer_model.encode(
#             text,
#             normalize_embeddings=True
#         ).tolist()

#         batch.add_object(
#             properties={
#                 "text": text,
#                 "documentId": document_id,
#                 "passageId": passage_id,
#             },
#             vector=vector,
#         )

# print("Ingestion complete.")

In [17]:
irishsi_coll = client.collections.get("IrishSIPassage")

result = irishsi_coll.query.bm25(
    query="test"
)

print("Hits:", len(result.objects))
for obj in result.objects:
    print(obj.properties)

Hits: 10
{'text': 'schedule-2 SCHEDULE 2 \n\nNew Schedule 2 to Principal Regulations \n\nRegulation 2(e) \n\n“SCHEDULE 2 \n\nRegulation 6(1) \n\nAPPLICATION FEES \n\nDescription of fee \n\nPrescribed Fee \n\n(1) \n\nApplication for authorisation as a \nCVR test operator (including 1 heavy \nCVR vehicle test lane and 1 light \nCVR vehicle test lane only) \n\nApplication for authorisation as a \nCVR test operator with more than \none 1 heavy CVR vehicle test lane \nand 1 light CVR vehicle test \n\nlane \n\n(2) \n\n€8,500 \n\n€6,000 for each additional test lane \n\nApplication for renewal of \nauthorisation as a CVR test operator \n\n€500 \n\nApplication for amendment of \nauthorisation in relation to ADR \ntesting \n\n€500 \n\nApplication for amendment of \nauthorisation as a CVR test operator \nto increase number of test lanes \n\n€6,000 for each additional test lane \n\n” \n\n \n \n \n \n \n \n \n \n \n \n \n\x0c[475] 73 \n\nRegulation 2(f)', 'documentId': 'si-2022-0475', 'passageId':

In [19]:
# Load queries from IrishSI QnA file

with open("QnA_complete_fixed.json", "r", encoding="utf-8") as f:
    qna_data = json.load(f)

queries = []
for item in qna_data:
    queries.append({
        "query_id": item["QuestionID"],   
        "query": item["Question"]         
    })

print("Number of queries:", len(queries))
print("Example:", queries[0])

Number of queries: 240
Example: {'query_id': '8d9c9c4a-2a66-4d1e-8b0b-0bbf3e3f1c2e', 'query': 'When does a person who purchases a dwelling on or after the commencement of these Regulations cease to qualify as a “relevant owner” for the scheme?'}


In [20]:
def dedupe_list(seq):
    seen = set()
    output = []
    for x in seq:
        if x not in seen:
            output.append(x)
            seen.add(x)
    return output

In [21]:
def average_precision_at_k(retrieved_ids: List[str], relevant_ids: List[str], k: int = 20):
    score = 0.0
    hits = 0
    for i, doc_id in enumerate(retrieved_ids[:k]):
        if doc_id in relevant_ids:
            hits += 1
            score += hits / (i + 1)
    return score / max(1, len(relevant_ids))

def recall_at_k(retrieved_ids: List[str], relevant_ids: List[str], k: int = 20):
    retrieved_relevant = len([d for d in retrieved_ids[:k] if d in relevant_ids])
    return retrieved_relevant / max(1, len(relevant_ids))

def ndcg_at_k(retrieved_ids: List[str], relevant_ids: List[str], k: int = 20):
    dcg = 0.0
    for i, doc_id in enumerate(retrieved_ids[:k]):
        if doc_id in relevant_ids:
            dcg += 1 / math.log2(i + 2)
    # ideal DCG
    ideal_hits = min(len(relevant_ids), k)
    idcg = sum([1 / math.log2(i + 2) for i in range(ideal_hits)])
    if idcg == 0:
        return 0.0
    return dcg / idcg

In [22]:
def evaluate_retriever(retriever_function, queries, qrels, k=20, show_progress=True):
    ap_scores = []
    recall_scores = []
    ndcg_scores = []

    iterator = tqdm(queries, desc=f"Evaluating {retriever_function.__name__}") if show_progress else queries
    
    for q in iterator:
        qid = q["query_id"]
        relevant = list(qrels.get(qid, {}).keys())

        results = retriever_function(q["query"], top_n=k)

        retrieved_ids = dedupe_list([r["ID"] for r in results])[:k]

        ap_scores.append(average_precision_at_k(retrieved_ids, relevant, k))
        recall_scores.append(recall_at_k(retrieved_ids, relevant, k))
        ndcg_scores.append(ndcg_at_k(retrieved_ids, relevant, k))

    return {
        "MAP@{}".format(k): sum(ap_scores) / len(ap_scores),
        "Recall@{}".format(k): sum(recall_scores) / len(recall_scores),
        "nDCG@{}".format(k): sum(ndcg_scores) / len(ndcg_scores),
    }

## Weaviate Retriever: BM25

Let us evaluate the weaviate retriever using BM25

In [23]:
def weaviate_bm25_search(query: str, top_n: int = 20):
    irishsi_coll = client.collections.get("IrishSIPassage")

    try:
        result = irishsi_coll.query.bm25(
            query=query
        )
    except Exception as e:
        print("BM25 error:", e)
        return []

    hits = []
    for obj in result.objects[:top_n]:
        props = obj.properties
        hits.append({
            "text": props["text"],
            "DocumentId": props["documentId"],
            "PassageId": props["passageId"],
            "ID": f"{props['documentId']}-{props['passageId']}",
        })

    return hits

In [24]:
# Evaluate BM25
print("BM25 Results:")
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(weaviate_bm25_search, queries, qrels, k=k, show_progress=(i==0))
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")

BM25 Results:


Evaluating weaviate_bm25_search: 100%|██████████| 240/240 [00:00<00:00, 246.91it/s]


  k=10: MAP=0.3092, Recall=0.5083, nDCG=0.3573
  k=20: MAP=0.3092, Recall=0.5083, nDCG=0.3573


## Weaviate Retriever: Vector

Let us evaluate the weaviate retriever using Vector search

In [25]:
def weaviate_vector_search(query: str, top_n: int = 20):

    irishsi_coll = client.collections.get("IrishSIPassage")

    query_vec = sentence_transformer_model.encode(
        query,
        normalize_embeddings=True
    ).tolist()

    result = irishsi_coll.query.near_vector(
        near_vector=query_vec,
        limit=top_n,
        return_metadata=MetadataQuery(score=True, distance=True),
    )

    hits = []
    for obj in result.objects:
        props = obj.properties
        hits.append({
            "text": props["text"],
            "DocumentId": props["documentId"],
            "PassageId": props["passageId"],
            "ID": f"{props['documentId']}-{props['passageId']}",
            "score": obj.metadata.score,
            "distance": obj.metadata.distance,
        })

    return hits

In [26]:
# Evaluate Vector
print("Vector Results:")
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(weaviate_vector_search, queries, qrels, k=k, show_progress=(i==0))
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")

Vector Results:


Evaluating weaviate_vector_search: 100%|██████████| 240/240 [00:07<00:00, 30.83it/s]


  k=10: MAP=0.3703, Recall=0.5958, nDCG=0.4247
  k=20: MAP=0.3738, Recall=0.6479, nDCG=0.4378


## Weaviate Hybrid Retriever (BM25 + Vector)

Hybrid retriever using a combination of BM25 + Vector

In [27]:
def weaviate_hybrid_search(query: str, top_n: int = 20, alpha: float = 0.5):

    irishsi_coll = client.collections.get("IrishSIPassage")

    query_vec = sentence_transformer_model.encode(
        query,
        normalize_embeddings=True
    ).tolist()

    result = irishsi_coll.query.hybrid(
        query=query,
        vector=query_vec,
        alpha=alpha,
        limit=top_n,
        return_metadata=MetadataQuery(score=True),
    )

    hits = []
    for obj in result.objects:
        props = obj.properties
        hits.append({
            "text": props["text"],
            "DocumentId": props["documentId"],
            "PassageId": props["passageId"],
            "ID": f"{props['documentId']}-{props['passageId']}",
            "score": obj.metadata.score,
        })

    return hits

In [28]:
# Evaluate Hybrid
print("Hybrid Results:")
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(weaviate_hybrid_search, queries, qrels, k=k, show_progress=(i==0))
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")

Hybrid Results:


Evaluating weaviate_hybrid_search: 100%|██████████| 240/240 [00:07<00:00, 33.13it/s]


  k=10: MAP=0.3618, Recall=0.5750, nDCG=0.4130
  k=20: MAP=0.3669, Recall=0.6479, nDCG=0.4316


In [29]:
# Strong reranker
from sentence_transformers import CrossEncoder


reranker = CrossEncoder("BAAI/bge-reranker-large", device="cuda")

In [30]:
def rerank_results(query: str, candidates: list, top_k: int = 20):
    if len(candidates) == 0:
        return []

    pairs = [[query, c["text"]] for c in candidates]
    scores = reranker.predict(pairs)

    reranked = sorted(
        zip(candidates, scores),
        key=lambda x: x[1],
        reverse=True
    )

    result = []
    for c, s in reranked[:top_k]:
        c["rerank_score"] = float(s)
        result.append(c)

    return result

In [31]:
def hybrid_plus_rerank_search(query: str, top_n: int = 20, pool_size: int = 50, alpha: float = 0.7):
    candidates = weaviate_hybrid_search(query, top_n=pool_size)

    reranked = rerank_results(query, candidates, top_k=top_n)

    return reranked

In [32]:
def rerank_retriever(retriever_fn, query: str, top_n: int = 20, pool_size: int = 50):
 
    candidates = retriever_fn(query, top_n=pool_size)

    return rerank_results(query, candidates, top_k=top_n)

In [33]:
def bm25_plus_rerank(query: str, top_n: int = 20):
    return rerank_retriever(weaviate_bm25_search, query, top_n=top_n)

def vector_plus_rerank(query: str, top_n: int = 20):
    return rerank_retriever(weaviate_vector_search, query, top_n=top_n)

def hybrid_plus_rerank(query: str, top_n: int = 20):
    return rerank_retriever(weaviate_hybrid_search, query, top_n=top_n)

In [34]:
# Evaluate BM25 + Rerank
print("BM25 + Rerank Results:")
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(bm25_plus_rerank, queries, qrels, k=k, )
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")

# Evaluate Vector + Rerank
print("Vector + Rerank Results:")
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(vector_plus_rerank, queries, qrels, k=k, )
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")

# Evaluate Hybrid + Rerank
print("Hybrid + Rerank Results:")
for i, k in enumerate([10, 20]):
    metrics = evaluate_retriever(hybrid_plus_rerank, queries, qrels, k=k, )
    print(f"  k={k}: MAP={metrics[f'MAP@{k}']:.4f}, Recall={metrics[f'Recall@{k}']:.4f}, nDCG={metrics[f'nDCG@{k}']:.4f}")

BM25 + Rerank Results:


Evaluating bm25_plus_rerank: 100%|██████████| 240/240 [02:19<00:00,  1.72it/s]


  k=10: MAP=0.4002, Recall=0.5083, nDCG=0.4266


Evaluating bm25_plus_rerank: 100%|██████████| 240/240 [02:16<00:00,  1.75it/s]


  k=20: MAP=0.4002, Recall=0.5083, nDCG=0.4266
Vector + Rerank Results:


Evaluating vector_plus_rerank: 100%|██████████| 240/240 [11:41<00:00,  2.92s/it]


  k=10: MAP=0.4375, Recall=0.6500, nDCG=0.4882


Evaluating vector_plus_rerank: 100%|██████████| 240/240 [11:24<00:00,  2.85s/it]


  k=20: MAP=0.4397, Recall=0.6813, nDCG=0.4963
Hybrid + Rerank Results:


Evaluating hybrid_plus_rerank: 100%|██████████| 240/240 [12:38<00:00,  3.16s/it]


  k=10: MAP=0.4458, Recall=0.6667, nDCG=0.4983


Evaluating hybrid_plus_rerank: 100%|██████████| 240/240 [12:36<00:00,  3.15s/it]

  k=20: MAP=0.4482, Recall=0.6979, nDCG=0.5066





In [35]:
def save_rankings_to_trec(queries, retriever_function, output_file, method_name="retriever"):

    with open(output_file, 'w') as f:
        for query in tqdm(queries, desc=f"Generating TREC rankings for {method_name}"):
            qid = query["query_id"]
            results = retriever_function(query["query"], top_n=50)  # Get top 50
            
            for rank, result in enumerate(results, start=1):
                passage_id = result["ID"]
                score = result.get("score", result.get("rerank_score", 0.0))  
                
                
                line = f"{qid} Q0 {passage_id} {rank} {score:.6f} {method_name}\n"
                f.write(line)
    
    print(f"Rankings saved to {output_file}")


print("Generating BM25 rankings...")
save_rankings_to_trec(queries, weaviate_bm25_search, "data/rankings_bm25.trec", "bm25")


print("Generating Vector rankings...")
save_rankings_to_trec(queries, weaviate_vector_search, "data/rankings_vector.trec", "vector")


print("Generating Hybrid rankings...")
save_rankings_to_trec(queries, weaviate_hybrid_search, "data/rankings_hybrid.trec", "hybrid")


print("Generating BM25 + Rerank rankings...")
save_rankings_to_trec(queries, bm25_plus_rerank, "data/rankings_bm25_rerank.trec", "bm25_rerank")


print("Generating Vector + Rerank rankings...")
save_rankings_to_trec(queries, vector_plus_rerank, "data/rankings_vector_rerank.trec", "vector_rerank")

print("Generating Hybrid + Rerank rankings...")
save_rankings_to_trec(queries, hybrid_plus_rerank, "data/rankings_hybrid_rerank.trec", "hybrid_rerank")


print("\n" + "="*60)
print("TREC FILES GENERATED:")
print("="*60)
for filename in ["rankings_bm25.trec", "rankings_vector.trec", "rankings_hybrid.trec", 
                 "rankings_bm25_rerank.trec", "rankings_vector_rerank.trec", "rankings_hybrid_rerank.trec"]:
    filepath = f"data/{filename}"
    if os.path.exists(filepath):
        with open(filepath, 'r') as f:
            num_lines = sum(1 for _ in f)
        print(f"{filename}: {num_lines} lines")
    else:
        print(f"{filename}: NOT FOUND")

Generating BM25 rankings...


Generating TREC rankings for bm25: 100%|██████████| 240/240 [00:00<00:00, 276.57it/s]


Rankings saved to data/rankings_bm25.trec
Generating Vector rankings...


Generating TREC rankings for vector: 100%|██████████| 240/240 [00:07<00:00, 30.83it/s]


Rankings saved to data/rankings_vector.trec
Generating Hybrid rankings...


Generating TREC rankings for hybrid: 100%|██████████| 240/240 [00:08<00:00, 29.78it/s]


Rankings saved to data/rankings_hybrid.trec
Generating BM25 + Rerank rankings...


Generating TREC rankings for bm25_rerank: 100%|██████████| 240/240 [02:16<00:00,  1.76it/s]


Rankings saved to data/rankings_bm25_rerank.trec
Generating Vector + Rerank rankings...


Generating TREC rankings for vector_rerank: 100%|██████████| 240/240 [11:41<00:00,  2.92s/it]


Rankings saved to data/rankings_vector_rerank.trec
Generating Hybrid + Rerank rankings...


Generating TREC rankings for hybrid_rerank: 100%|██████████| 240/240 [12:44<00:00,  3.18s/it]

Rankings saved to data/rankings_hybrid_rerank.trec

TREC FILES GENERATED:
rankings_bm25.trec: 2400 lines
rankings_vector.trec: 12000 lines
rankings_hybrid.trec: 12000 lines
rankings_bm25_rerank.trec: 2400 lines
rankings_vector_rerank.trec: 12000 lines
rankings_hybrid_rerank.trec: 12000 lines



