In [None]:
# Imports
import os
import re
from collections import Counter
from typing import cast
from datasets import load_dataset, Dataset
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from pydantic import Field, conlist, create_model
import torch
import string
from dotenv import load_dotenv

In [None]:
# Load environment variables (e.g. OPENAI_API_KEY) via python-dotenv's load_dotenv().
# Alternative: set the key directly in the process environment instead
# (never commit a real key to the notebook):
# os.environ["OPENAI_API_KEY"] = ""
load_dotenv()

In [None]:
# Config: chunking / retrieval knobs and model identifiers used throughout the notebook
CHUNK_SIZE, CHUNK_OVERLAP = 1000, 200  # passed to the text splitter as chunk_size / chunk_overlap
K_RETRIEVE = 5                         # documents requested from the vector store per query
MAX_HOPS = 3                           # realistically only 2 hops are needed for these datasets
MAX_DOCS_PER_HOP = 3                   # retrieved passages kept for each hop
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
CHAT_MODEL = "gpt-4o-mini"

# Prefer the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")



In [None]:
# Load HotpotQA (distractor configuration): the validation questions are used for
# evaluation, while the context paragraphs of both splits go into the retrieval corpus.
print("Loading HotpotQA...")
ds_train, ds_val = (
    # cast(...) only narrows the static type to Dataset so pylance stops complaining
    cast(Dataset, load_dataset("hotpot_qa", "distractor", split=split_name, streaming=False))
    for split_name in ("train", "validation")
)

print(f"Train size: {len(ds_train)}")
print(f"Validation size: {len(ds_val)}")



In [None]:
# Build the retrieval corpus from the context paragraphs of the training and validation sets.
# Paragraphs are de-duplicated in a single pass on the key
# (title, whitespace-collapsed + lower-cased text); the first occurrence wins and is stored
# with its original (un-normalised) text, so first-seen order is preserved.
seen_keys = set()
corpus_rows = []
for split in (ds_train, ds_val):  # train first, then validation
    for example in split:
        context = example["context"]
        # context["title"][i] is paired with context["sentences"][i] (a list of sentence strings)
        for title, sents in zip(context["title"], context["sentences"]):
            paragraph_text = " ".join(sents)
            key = (title, re.sub(r"\s+", " ", paragraph_text).strip().lower())
            if key in seen_keys:
                continue
            seen_keys.add(key)
            corpus_rows.append({"title": title, "text": paragraph_text})

print("Paragraphs:", len(corpus_rows))



In [None]:
# Split every corpus paragraph with RecursiveCharacterTextSplitter; each resulting
# chunk gets its own metadata dict carrying the title of the paragraph it came from.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)

texts = []
metas = []
for row in corpus_rows:
    pieces = text_splitter.split_text(row["text"])
    texts += pieces
    # one fresh dict per chunk so texts[i] and metas[i] stay aligned and unshared
    metas += [{"title": row["title"]} for _ in pieces]

print(f"Chunks indexed: {len(texts)}")



In [None]:
# Build or load the FAISS vector store
# (TODO: move this and the preceding corpus/chunking code to a separate script to reuse later)
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
    model_kwargs={"device": device},  # run the encoder on the GPU if available
    encode_kwargs={"normalize_embeddings": True},
    # show_progress=True
)

index_dir = "faiss_hotpotqa"
if not os.path.exists(index_dir):
    print("Creating new FAISS vector store...")
    vector_store = FAISS.from_texts(texts, embedding_model, metadatas=metas)

    # Persist the index so later runs can skip the embedding step
    vector_store.save_local(index_dir)
    print(f"FAISS vector store saved to {index_dir}")
else:
    # NOTE(review): an existing directory is reused as-is, even if CHUNK_SIZE,
    # CHUNK_OVERLAP or EMBEDDING_MODEL changed since it was written; delete it to rebuild.
    # NOTE(security): allow_dangerous_deserialization=True opts in to loading pickled
    # data; only load index directories you created yourself.
    print(f"Loading existing FAISS vector store from {index_dir}...")
    vector_store = FAISS.load_local(index_dir, embedding_model, allow_dangerous_deserialization=True)

In [None]:
# LLM question decomposition
# Build a structured LLM that enforces: 2..max_subqs subquestions
def make_decomposer(llm, max_subqs=MAX_HOPS):
    """Build a ``decompose(question)`` callable backed by strict structured output.

    Args:
        llm: Chat model exposing ``bind_tools(...)``. The bound model's ``invoke``
            result is expected to carry the parsed schema instance under
            ``additional_kwargs["parsed"]``.
        max_subqs: Upper bound on the number of sub-questions. The lower bound
            is fixed at 2 by the schema.

    Returns:
        A function that maps a question string to the ordered list of
        sub-question strings produced by the model.

    Raises:
        ValueError: If ``max_subqs`` is smaller than 2, since the schema
            (min_length=2) could then never be satisfied.
    """
    # Fail early instead of building a schema with min_length > max_length.
    if max_subqs < 2:
        raise ValueError(f"max_subqs must be at least 2, got {max_subqs}")

    # Dynamic schema so max_subqs can be chosen at runtime
    DecompSchema = create_model(
        "DecompSchema",
        subquestions=(conlist(str, min_length=2, max_length=max_subqs),
                      Field(description="Ordered sub-questions to solve the original in sequence."))
    )

    # Use strict structured output (no heuristics, no fallback)
    structured_llm = llm.bind_tools(
        tools=[],                     # no tools needed; we just want the schema
        response_format=DecompSchema, # pydantic schema
        strict=True
    )

    SYSTEM = "You have to break complex questions into concise, sequential sub-questions."
    # Only the first fragment is an f-string (max_subqs is baked in here); the
    # "{q}" placeholder in the last fragment is filled per call via str.format.
    USER_TMPL = (
        f"- Produce BETWEEN 2 and {max_subqs} sub-questions that will help answer the main question.\n"
        "- Each sub-question MUST be under 18 words.\n"
        "- Each sub-question must be specific and answerable using a retrieval system.\n"
        "- Each sub-question must contribute useful information towards answering the main question.\n"
        "- Answers to sub-questions must solve the main question when combined.\n"
        "- Order the list so answering in order solves the original question.\n"
        "- No extra keys. No commentary. No markdown.\n\n"
        "QUESTION: {q}"
    )

    def decompose(question):
        """Return the ordered sub-questions the model proposes for ``question``."""
        msgs = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": USER_TMPL.format(q=question)},
        ]
        resp = structured_llm.invoke(msgs)
        # The parsed DecompSchema instance (a Pydantic object, not a dict) is read
        # from additional_kwargs; a response without it fails on this lookup.
        parsed = resp.additional_kwargs["parsed"]
        return parsed.subquestions

    return decompose

In [None]:
# Helper functions for multi-hop QA
def compose_query(llm, question, subq, hops):
    """Rewrite a sub-question into a search query that folds in earlier hop answers.

    Args:
        llm: Object with ``invoke(messages)`` returning something with a ``.content`` string.
        question: The original (multi-hop) question.
        subq: The sub-question to rewrite.
        hops: Earlier hops; each entry must provide the ``'subq'`` and ``'answer'`` keys.

    Returns:
        The model's reply text with surrounding whitespace stripped.
    """
    system_msg = {
        "role": "system",
        "content": "You are to rewrite questions into focused search queries for retrieval.",
    }

    # Numbered "sub-question -> answer" memory of the hops solved so far
    history = [f"{n}. {hop['subq']} -> {hop['answer']}" for n, hop in enumerate(hops, start=1)]
    prior_answers = "\n".join(history) if history else "None yet."

    # NOTE(review): "pronuns" is a typo in the prompt; kept verbatim so model input is unchanged.
    rules = "\n".join([
        "Rewrite the following sub-question into a concise search query for document retrieval.",
        "- use entities/names filled by PRIOR ANSWERS when relevant.",
        "- if prior answers don't help, keep the original sub-question details.",
        "- Keep it under 18 words. No pronuns like this/that/it.",
        "- The query must remain as a question.",
        "- Be specific and retrieval-friendly (names, years, titles).",
    ])
    sections = [
        rules,
        f"ORIGINAL QUESTION: {question}",
        f"PRIOR ANSWERS:\n{prior_answers}",
        f"SUB-QUESTION: {subq}",
        "SEARCH QUERY:",
    ]
    user_msg = {"role": "user", "content": "\n\n".join(sections)}

    reply = llm.invoke([system_msg, user_msg])
    return reply.content.strip()

def answer_subq(llm, question, subq, passages, hops):
    """Ask the model for a short answer to one sub-question given retrieved passages.

    Args:
        llm: Object with ``invoke(messages)`` returning something with a ``.content`` string.
        question: The original (multi-hop) question, included for context.
        subq: The sub-question being answered in this hop.
        passages: Retrieved passage texts, numbered from 1 in the prompt.
        hops: Earlier hops; each entry must provide the ``'subq'`` and ``'answer'`` keys.

    Returns:
        The model's reply text with surrounding whitespace stripped.
    """
    system_msg = {
        "role": "system",
        "content": "You are a precise QA assistant. Return only the short answer phrase. No explanation, no full sentences.",
    }

    history = [f"{n}. {hop['subq']} -> {hop['answer']}" for n, hop in enumerate(hops, start=1)]
    prior_answers = "\n".join(history) if history else "None."
    context = "\n\n".join(f"PASSAGE {n}:\n{text}" for n, text in enumerate(passages, start=1))

    body = "\n\n".join([
        f"ORIGINAL QUESTION:\n{question}",
        f"PRIOR ANSWERS:\n{prior_answers}",
        f"CURRENT SUB-QUESTION:\n{subq}",
        f"CONTEXT:\n{context}",
        "Answer (short phrase only):",
    ])
    # The instruction line is followed by a single newline (not a blank line)
    user_msg = {
        "role": "user",
        "content": "Answer the CURRENT SUB-QUESTION using the CONTEXT passages provided.\n" + body,
    }

    reply = llm.invoke([system_msg, user_msg])
    return reply.content.strip()

def get_final_answer(llm, question, hops):
    """Produce the final short answer from all hop answers and their passages.

    Args:
        llm: Object with ``invoke(messages)`` returning something with a ``.content`` string.
        question: The original (multi-hop) question.
        hops: Completed hops; each entry must provide the ``'subq'``, ``'answer'``
            and ``'passages'`` keys.

    Returns:
        The model's reply text with surrounding whitespace stripped.
    """
    system_msg = {
        "role": "system",
        "content": "You are a precise QA assistant. Return only the short answer phrase (in some cases 1-2 words will suffice). No explanation, no full sentences.",
    }

    history = [f"{n}. {hop['subq']} -> {hop['answer']}" for n, hop in enumerate(hops, start=1)]
    sub_answers = "\n".join(history) if history else "None."

    # Flatten every hop's passages, labelled with 1-based hop and passage numbers
    context = "\n\n".join(
        f"HOP {hop_no} PASSAGE {passage_no}:\n{text}"
        for hop_no, hop in enumerate(hops, start=1)
        for passage_no, text in enumerate(hop["passages"], start=1)
    )

    body = "\n\n".join([
        f"ORIGINAL QUESTION:\n{question}",
        f"SUB-QUESTION ANSWERS:\n{sub_answers}",
        f"SUPPORTING CONTEXT:\n{context}",
        "Final answer (short phrase only):",
    ])
    # The instruction line is followed by a single newline (not a blank line)
    user_msg = {
        "role": "user",
        "content": (
            "Using the prior sub-questions and their answers along with the supporting context, answer the ORIGINAL QUESTION.\n"
            + body
        ),
    }

    reply = llm.invoke([system_msg, user_msg])
    return reply.content.strip()

In [None]:
# Multi-hop QA Pipeline
# Chat model created with temperature=0; multi_hop_qa reads this module-level `llm`.
llm = ChatOpenAI(model=CHAT_MODEL, temperature=0)
# Decomposer bound to the same model, with MAX_HOPS as the cap on sub-questions.
decomposer = make_decomposer(llm, MAX_HOPS)

def multi_hop_qa(question, max_docs_per_hop=MAX_DOCS_PER_HOP, k=K_RETRIEVE):
    """Answer a question by decomposing it and solving the sub-questions in order.

    Relies on the module-level ``llm``, ``decomposer`` and ``vector_store``.

    Args:
        question: The original question.
        max_docs_per_hop: How many of the retrieved passages are kept per hop.
        k: Number of documents requested from the vector store per hop.

    Returns:
        ``(final_answer, hops)`` where each hop is a dict with the keys
        ``"subq"``, ``"composed"``, ``"passages"`` and ``"answer"``.
    """
    hops = []

    for subq in decomposer(question):
        # The first hop has no prior answers, so its sub-question is searched verbatim;
        # later hops are rewritten by the LLM using the answers collected so far.
        composed_q = compose_query(llm, question, subq, hops) if hops else subq

        # Retrieve k candidates, then keep only the first max_docs_per_hop passages
        retrieved = vector_store.similarity_search(composed_q, k=k)
        passages = [doc.page_content for doc in retrieved][:max_docs_per_hop]

        # Answer this hop (sees only the earlier hops), then record it
        hop_answer = answer_subq(llm, question, subq, passages, hops)
        hops.append(
            {"subq": subq, "composed": composed_q, "passages": passages, "answer": hop_answer}
        )

    # Combine the answers and passages from every hop into the final answer
    return get_final_answer(llm, question, hops), hops

In [None]:
# Smoke test: run the pipeline on one example and print every hop
query = "What is the nickname of the city where Darling's Waterfront Pavilion is located?"
pred, hops = multi_hop_qa(query, max_docs_per_hop=MAX_DOCS_PER_HOP, k=K_RETRIEVE)
print(f"Question: {query}")
for hop_no, hop in enumerate(hops, start=1):
    print(f"Sub-question {hop_no}: {hop['subq']}")
    print(f"Composed Query: {hop['composed']}")
    for passage_no, passage in enumerate(hop["passages"], start=1):
        print(f"Passage {passage_no}:\n{passage}\n")
    print(f"Answer: {hop['answer']}\n")

print(f"Original Question: {query}")
print(f"Predicted Answer: {pred}")

In [None]:
# EM/F1 evaluation
def normalize_answer(s):
    """Canonicalise an answer string for EM/F1 comparison.

    Steps, in order: lower-case, delete every ``string.punctuation`` character,
    replace the standalone words a/an/the with a space, collapse whitespace.
    """
    lowered = s.lower()
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return " ".join(without_articles.split())


def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and the ground truth after normalisation.

    Returns the integer 0 when a yes/no answer on either side does not match
    exactly, or when the two answers share no tokens; otherwise returns the
    harmonic mean of token precision and recall as a float.
    """
    pred_norm = normalize_answer(prediction)
    gold_norm = normalize_answer(ground_truth)

    # yes/no answers get credit only on exact agreement, never on partial overlap
    binary_answers = ('yes', 'no')
    if pred_norm != gold_norm and (pred_norm in binary_answers or gold_norm in binary_answers):
        return 0

    pred_tokens = pred_norm.split()
    gold_tokens = gold_norm.split()

    # Multiset intersection: each shared token counts min(occurrences) times
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """Return 1.0 when both answers are equal after normalisation, else 0.0."""
    if normalize_answer(prediction) == normalize_answer(ground_truth):
        return 1.0
    return 0.0

In [None]:
def eval(ds, n, k=K_RETRIEVE):
    """Run the multi-hop pipeline on the first ``n`` examples and report mean EM / F1.

    NOTE(review): this name shadows the built-in ``eval`` within the notebook.

    Args:
        ds: Indexable dataset whose items provide ``"question"`` and ``"answer"``.
        n: Number of leading examples to evaluate (capped at ``len(ds)``).
        k: Number of documents retrieved per hop, forwarded to ``multi_hop_qa``.

    Returns:
        Dict with ``"n"`` (examples evaluated), ``"k"``, ``"EM"`` and ``"F1"``.
    """
    indices = range(min(n, len(ds)))  # first n examples

    exact_matches = []
    f1_values = []
    for idx in indices:
        example = ds[idx]
        question, gold = example["question"], example["answer"]

        # Prediction from the multi-hop system (the per-hop trace is discarded)
        pred, _ = multi_hop_qa(question, MAX_DOCS_PER_HOP, k=k)
        print(f"Q: {question}")
        print(f"Pred: {pred}")
        print(f"Ground Truth: {gold}")

        exact_matches.append(exact_match_score(pred, gold))
        f1_values.append(f1_score(pred, gold))

    evaluated = len(indices)
    denominator = evaluated or 1  # avoid dividing by zero when nothing was evaluated
    return {
        "n": evaluated,
        "k": k,
        "EM": sum(exact_matches) / denominator,
        "F1": sum(f1_values) / denominator,
    }

# Run eval on the first 100 validation examples
# TODO: raise n to the size of ds_val for the full evaluation later
metrics = eval(ds_val, n=100, k=K_RETRIEVE)
print(f"Metrics: {metrics}")