# Legal Document QA

Create Virtual Environment -> python -m venv .venv

Activate -> .venv\Scripts\Activate.ps1

If failed -> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser

Retry activation -> .venv\Scripts\Activate.ps1

Deactivate -> deactivate

In [1]:
import os
import re
import glob
import json
import nltk
import torch
import random
import faiss
import pandas as pd
import numpy as np

from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from nltk.tokenize import sent_tokenize
nltk.download("punkt_tab")

os.chdir(r"C:\Users\tejas\OneDrive\AI\LegalQA")
print("Now at:", os.getcwd())

Now at: C:\Users\tejas\OneDrive\AI\LegalQA


[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\tejas\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [2]:
#Creating slim.jsonl from raw json files
json_folder = "data/raw/json"

def normalize_case(case):
    opinions = case.get("casebody", {}).get("opinions", [])
    opinion_text = ""
    opinion_type = None
    opinion_author = None
    if opinions:
        opinion_text = opinions[0].get("text", "")
        opinion_type = opinions[0].get("type")
        opinion_author = opinions[0].get("author")

    return {
        "id": case.get("id"),
        "name": case.get("name"),
        "short_name": case.get("name_abbreviation"),
        "decision_date": case.get("decision_date"),
        "court": case.get("court", {}).get("name"),
        "citations": [c.get("cite") for c in case.get("citations", [])],
        "opinion_type": opinion_type,
        "opinion_author": opinion_author,
        "opinion_text": opinion_text
    }

output_file = "data/processed/cases_slim.jsonl"
os.makedirs(os.path.dirname(output_file), exist_ok=True)

with open(output_file, "w", encoding="utf-8") as out:
    for file in glob.glob(os.path.join(json_folder, "*.json")):
        with open(file, "r", encoding="utf-8") as f:
            case = json.load(f)
            slim = normalize_case(case)
            out.write(json.dumps(slim) + "\n")

print(f"Saved normalized cases to {output_file}") #There are 2187 unique cases

Saved normalized cases to data/processed/cases_slim.jsonl


In [3]:
#Create Chunks
def chunk_text(text, max_words=400, overlap=50):
    sentences = nltk.sent_tokenize(text)
    chunks, current_chunk, current_len = [], [], 0

    for sent in sentences:
        words = sent.split()
        if current_len + len(words) > max_words and current_chunk:
            chunks.append(" ".join(current_chunk))

            overlap_words = current_chunk[-overlap:] if overlap > 0 else []
            current_chunk = overlap_words + words
            current_len = len(current_chunk)
        else:
            current_chunk.extend(words)
            current_len += len(words)

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks

In [4]:
all_chunks = []

with open("data/processed/cases_slim.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        case = json.loads(line)
        text = case["opinion_text"]

        if not text.strip():
            continue

        for chunk in chunk_text(text, max_words=400, overlap=50):
            all_chunks.append({
                "case_id": case["id"],
                "case_name": case["name"],
                "decision_date": case["decision_date"],
                "court": case["court"],
                "citation": case["citations"],
                "chunk_text": chunk
            })

print(f"Generated {len(all_chunks)} chunks from {len(set(c['case_id'] for c in all_chunks))} cases.") #Generated 687 chunks from 70 cases.

Generated 687 chunks from 70 cases.


In [5]:
os.makedirs("data/processed", exist_ok=True)

with open("data/processed/chunks.jsonl", "w", encoding="utf-8") as f:
    for chunk in all_chunks:
        f.write(json.dumps(chunk, ensure_ascii=False) + "\n")

print(f"Saved {len(all_chunks)} chunks to data/processed/chunks.jsonl") #Saved 687 chunks

Saved 687 chunks to data/processed/chunks.jsonl


## Embedding

### bge-small-en

Install:

pip install -U sentence-transformers faiss-cpu pandas pyarrow tqdm

pip install -U transformers accelerate

In [6]:
DATA_DIR = "data/processed"
CHUNKS_PATH = os.path.join(DATA_DIR, "chunks.jsonl")
INDEX_DIR = os.path.join(DATA_DIR, "index_bge_small")
os.makedirs(INDEX_DIR, exist_ok=True)

In [7]:
rows = []
with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
    for line in f:
        rows.append(json.loads(line))
df_bge_s = pd.DataFrame(rows)

assert {"case_id", "chunk_text"}.issubset(df_bge_s.columns), "chunks.jsonl missing required fields"

In [8]:
MODEL_NAME = "BAAI/bge-small-en-v1.5"
model = SentenceTransformer(MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu")

texts = ["passage: " + t for t in df_bge_s["chunk_text"].tolist()]

In [9]:
BATCH = 8
emb_list = []

for i in tqdm(range(0, len(texts), BATCH), desc="Embedding BGE Small"):
    batch_emb = model.encode(
        texts[i:i+BATCH],
        batch_size=BATCH,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=False
    )
    emb_list.append(batch_emb)

Embedding BGE Small:   0%|          | 0/86 [00:00<?, ?it/s]

Embedding BGE Small: 100%|██████████| 86/86 [06:47<00:00,  4.74s/it]


In [10]:
emb = np.vstack(emb_list).astype("float32")
np.save("data/processed/index_bge_small/embeddings_bge_small.npy", emb)

df_bge_s.to_parquet("data/processed/index_bge_small/meta_bge_small.parquet", index=False)

In [11]:
d = emb.shape[1]
index = faiss.IndexFlatIP(d)
index.add(emb)
faiss.write_index(index, "data/processed/index_bge_small/faiss_bge_small.index")

### e5-small-v2

In [12]:
DATA_DIR = "data/processed"
CHUNKS_PATH = os.path.join(DATA_DIR, "chunks.jsonl")
INDEX_DIR_E5 = os.path.join(DATA_DIR, "index_e5_small")
os.makedirs(INDEX_DIR_E5, exist_ok=True)

MODEL_NAME = "intfloat/e5-small-v2"
model = SentenceTransformer(MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu")

In [None]:
rows = []
with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
    for line in f:
        rows.append(json.loads(line))
df_e5_s = pd.DataFrame(rows)

assert {"case_id", "chunk_text"}.issubset(df_e5_s.columns), "chunks.jsonl missing required fields"

texts = ["passage: " + t for t in df_e5_s["chunk_text"].tolist()]

BATCH = 8
emb_list = []
for i in tqdm(range(0, len(texts), BATCH), desc="Embedding E5"):
    batch_emb = model.encode(
        texts[i:i+BATCH],
        batch_size=BATCH,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=False
    )
    emb_list.append(batch_emb)

emb = np.vstack(emb_list).astype("float32")
np.save(os.path.join(INDEX_DIR_E5, "embeddings_e5_small.npy"), emb)

df_e5_s.to_parquet(os.path.join(INDEX_DIR_E5, "meta_e5.parquet"), index=False)

d = emb.shape[1]
index = faiss.IndexFlatIP(d)
index.add(emb)
faiss.write_index(index, os.path.join(INDEX_DIR_E5, "faiss_e5.index"))

print(f"Saved E5 embeddings + FAISS index at {INDEX_DIR_E5}")

Embedding E5: 100%|██████████| 86/86 [07:14<00:00,  5.05s/it]

Saved E5 embeddings + FAISS index at data/processed\index_e5_small





## Retrieval

In [14]:
#BGE
# INDEX_PATH = "data/processed/index_bge_small/faiss_bge_small.index"
# META_PATH  = "data/processed/index_bge_small/meta_bge_small.parquet"
# EMB_PATH   = "data/processed/index_bge_small/embeddings_bge_small.npy"

#e5
INDEX_PATH = "data/processed/index_e5_small/faiss_e5.index"
META_PATH  = "data/processed/index_e5_small/meta_e5.parquet"
EMB_PATH   = "data/processed/index_e5_small/embeddings_e5_small.npy"

index = faiss.read_index(INDEX_PATH)
meta  = pd.read_parquet(META_PATH)

# MODEL_NAME = "BAAI/bge-small-en-v1.5" #bge
MODEL_NAME = "intfloat/e5-small-v2" #e5
model = SentenceTransformer(MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu")

In [15]:
def search(query: str, top_k=5):
    q_emb = model.encode(["query: " + query],normalize_embeddings=True,convert_to_numpy=True).astype("float32")

    scores, idxs = index.search(q_emb, top_k)
    idxs, scores = idxs[0], scores[0]

    results = []
    for i, s in zip(idxs, scores):
        row = meta.iloc[i].to_dict()
        row["score"] = float(s)
        results.append(row)
    return results

In [19]:
# Test a query
hits = search("What did the court say about negligence in summary judgment cases?", top_k=3)

for h in hits:
    print(round(h["score"], 3), h["case_name"], h["citation"], h.get("court"))

0.822 Benjamin ROBERS, Petitioner v. UNITED STATES. ['572 U.S. 639' '188 L. Ed. 2d 885' '134 S. Ct. 1854'] Supreme Court of the United States
0.82 Robert R. TOLAN v. Jeffrey Wayne COTTON. ['572 U.S. 650' '188 L. Ed. 2d 895' '134 S. Ct. 1861'] Supreme Court of the United States
0.819 Randy WHITE, Warden, Petitioner v. Robert Keith WOODALL. ['572 U.S. 415' '188 L. Ed. 2d 698' '134 S. Ct. 1697'] Supreme Court of the United States


Observations from test query

BGE-Small top-3

Highmark v. Allcare (0.781)

White v. Woodall (0.754)

Octane Fitness v. ICON (0.753)

E5-Large top-3

White v. Woodall (0.817)

Nautilus v. Biosig (0.813)

Highmark v. Allcare (0.813)

Insights

E5 is returning slightly higher cosine scores (0.81+) vs BGE’s ~0.75–0.78.

Overlap: both models agree on Highmark and White v. Woodall.

Differences:

BGE pulled Octane Fitness v. ICON into top-3.

E5 pulled Nautilus v. Biosig instead.

This is exactly the kind of retrieval robustness check research statement calls for: some cases are stable across models, others differ.

## Rerank

In [20]:
# Add reranking - To get better results, we can rerank the top-k results using a cross-encoder model

RERANKER_NAME = "BAAI/bge-reranker-large"
tokenizer = AutoTokenizer.from_pretrained(RERANKER_NAME)
reranker  = AutoModelForSequenceClassification.from_pretrained(RERANKER_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
reranker.eval()

@torch.inference_mode()
def rerank(query, candidates, text_col="chunk_text", top_n=5):
    pairs = [(query, c[text_col]) for c in candidates]
    enc = tokenizer(pairs, padding=True, truncation=True, max_length=512, return_tensors="pt").to(reranker.device)
    scores = reranker(**enc).logits.squeeze(-1).cpu().numpy()
    
    for i, c in enumerate(candidates):
        c["rerank_score"] = float(scores[i])
    candidates = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True)
    return candidates[:top_n]

In [21]:
hits[:1]

[{'case_id': 12707000,
  'case_name': 'Benjamin ROBERS, Petitioner v. UNITED STATES.',
  'decision_date': '2014-05-05',
  'court': 'Supreme Court of the United States',
  'citation': array(['572 U.S. 639', '188 L. Ed. 2d 885', '134 S. Ct. 1854'],
        dtype=object),
  'chunk_text': 'value after the victim chooses to hold it, then that "part of the victim\'s net los[s]" is "attributable to" the victim\'s "independent decisions." Id., at 39. The defendant cannot be regarded as the "proximate cause" of that part of the loss, ibid., and so cannot be made to bear it. In such cases, I would place on the defendant the burden to show-with evidence specific to the market at issue-that a victim delayed unreasonably in selling collateral, manifesting a choice to hold the collateral. See 18 U.S.C. § 3664(e) (burden to be allocated "as justice requires"). Because Robers did not sufficiently argue below that the banks broke the chain of proximate causation by choosing to hold the homes as investm

In [22]:
def load_index_bge(index_dir, model_name):
    index = faiss.read_index(os.path.join(index_dir, "faiss_bge_small.index"))
    meta  = pd.read_parquet(os.path.join(index_dir, "meta_bge_small.parquet"))
    model = SentenceTransformer(model_name, device="cuda" if torch.cuda.is_available() else "cpu")
    return index, meta, model

def load_index_e5(index_dir, model_name):
    index = faiss.read_index(os.path.join(index_dir, "faiss_e5.index"))
    meta  = pd.read_parquet(os.path.join(index_dir, "meta_e5.parquet"))
    model = SentenceTransformer(model_name, device="cuda" if torch.cuda.is_available() else "cpu")
    return index, meta, model

def search(query, index, meta, model, top_k=5):
    q_emb = model.encode(["query: " + query],normalize_embeddings=True,convert_to_numpy=True).astype("float32")
    scores, idxs = index.search(q_emb, top_k)
    idxs, scores = idxs[0], scores[0]

    results = []
    for i, s in zip(idxs, scores):
        row = meta.iloc[i].to_dict()
        row["score"] = float(s)
        results.append(row)
    return results




# Test a query
query = "What did the court say about negligence in summary judgment cases?"

# Load BGE-small
index_bge, meta_bge, model_bge = load_index_bge("data/processed/index_bge_small", "BAAI/bge-small-en-v1.5")
hits_bge = search(query, index_bge, meta_bge, model_bge, top_k=10)



# Load E5-small
index_e5, meta_e5, model_e5 = load_index_e5("data/processed/index_e5_small", "intfloat/e5-small-v2")
hits_e5 = search(query, index_e5, meta_e5, model_e5, top_k=10)

print("\n--- BGE-small (raw) ---")
for h in hits_bge[:3]:
    print(round(h["score"], 3), h["case_name"], h.get("citation"))

print("\n--- E5-large (raw) ---")
for h in hits_e5[:3]:
    print(round(h["score"], 3), h["case_name"], h.get("citation"))



# Rerank both
reranked_bge = rerank(query, hits_bge, text_col="chunk_text", top_n=3)
reranked_e5  = rerank(query, hits_e5, text_col="chunk_text", top_n=3)

print("\n--- BGE-small (reranked) ---")
for h in reranked_bge:
    print(round(h["rerank_score"], 3), h["case_name"], h.get("citation"))

print("\n--- E5-large (reranked) ---")
for h in reranked_e5:
    print(round(h["rerank_score"], 3), h["case_name"], h.get("citation"))


--- BGE-small (raw) ---
0.809 HIGHMARK INC., Petitioner v. ALLCARE HEALTH MANAGEMENT SYSTEM, INC. ['572 U.S. 559' '188 L. Ed. 2d 829' '134 S. Ct. 1744']
0.8 Officer Vance PLUMHOFF, et al., Petitioners v. Whitne RICKARD, a Minor Child, Individually, and as Surviving Daughter of Donald Rickard, Deceased, By and Through Her Mother Samantha Rickard, as Parent and Next Friend. ['572 U.S. 765' '188 L. Ed. 2d 1056' '134 S. Ct. 2012']
0.798 Officer Vance PLUMHOFF, et al., Petitioners v. Whitne RICKARD, a Minor Child, Individually, and as Surviving Daughter of Donald Rickard, Deceased, By and Through Her Mother Samantha Rickard, as Parent and Next Friend. ['572 U.S. 765' '188 L. Ed. 2d 1056' '134 S. Ct. 2012']

--- E5-large (raw) ---
0.822 Benjamin ROBERS, Petitioner v. UNITED STATES. ['572 U.S. 639' '188 L. Ed. 2d 885' '134 S. Ct. 1854']
0.82 Robert R. TOLAN v. Jeffrey Wayne COTTON. ['572 U.S. 650' '188 L. Ed. 2d 895' '134 S. Ct. 1861']
0.819 Randy WHITE, Warden, Petitioner v. Robert Keith WO

Did reranker push a more legally relevant case higher?

Do reranked top-3 look more precise than raw retrieval?

That’s exactly the “pinpoint citation precision” part of research statement.
You need a gold evaluation dataset

## Single pass RAG

In [10]:
# Load index and metadata (example from earlier LegalQA setup)
index = faiss.read_index("data/processed/index_bge_small/faiss_bge_small.index")
meta  = pd.read_parquet("data/processed/index_bge_small/meta_bge_small.parquet")

# Load the same embedding model you used to build the index
embedder = SentenceTransformer("BAAI/bge-small-en-v1.5")

In [8]:
def retrieve(query, top_k=3):
    # Embed query
    q_vec = embedder.encode([query])
    # Search FAISS
    scores, idxs = index.search(q_vec, top_k)
    results = []
    for score, idx in zip(scores[0], idxs[0]):
        row = meta.iloc[idx].to_dict()
        row["score"] = float(score)
        results.append(row)
    return results

In [12]:
def build_prompt(query, retrieved):
    context_parts = []
    for i, r in enumerate(retrieved, 1):
        context_parts.append(f"[{i}] {r['chunk_text']} (Case: {r.get('case_name')}, Citation: {r.get('citation')})")
    context = "\n\n".join(context_parts)

    prompt = f"""
    You are a legal assistant. Use the following retrieved case law passages to answer the query.
    Always cite sources like [1], [2], etc.
    Query: {query}
    
    Retrieved passages:
    {context}
    
    Answer:
    
    """
    return prompt

In [13]:
load_dotenv(r"C:\Users\tejas\OneDrive\AI\LegalQA\.venv\.env")

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def ask_llm(prompt):
    resp = client.chat.completions.create(
        model="gpt-5-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=1
    )
    return resp.choices[0].message.content

In [23]:
query = "How did the Supreme Court interpret the Fourth Amendment in digital privacy cases?" 
#query = "What did the Supreme Court say about international child abduction?"

retrieved = retrieve(query, top_k=3)
prompt = build_prompt(query, retrieved)
answer = ask_llm(prompt)

print(answer)

Short answer
- The Court has applied the Fourth Amendment’s familiar “reasonableness” framework (a totality-of-the-circumstances, objective-officer test) to police intrusions, but has recognized that digital data demands special protection because of its unique quantity and character. See the general Fourth Amendment standard: Graham/Tennessee v. Garner balancing and totality-of-the-circumstances analysis [3].

How that played out in key digital‑privacy decisions
- Riley v. California (2014): The Court held that searching the digital contents of a cell phone incident to arrest generally requires a warrant. The Court explained that modern cell phones hold vast amounts of personal data qualitatively different from physical items, so the usual incident‑to‑arrest exception cannot be mechanically applied.
- Carpenter v. United States (2018): The Court held that the government generally must obtain a warrant supported by probable cause before acquiring historical cell‑site location informati

### Evaluator that checks if the citations in the model’s output actually match the retrieved sources?

In [24]:
def extract_citations(answer_text):
    # find things like [1], [2], [3]
    return set(map(int, re.findall(r"\[(\d+)\]", answer_text)))

def check_citations(answer_text, retrieved):
    cited_ids = extract_citations(answer_text)
    retrieved_ids = set(range(1, len(retrieved)+1))

    correct = cited_ids & retrieved_ids
    hallucinated = cited_ids - retrieved_ids
    missed = retrieved_ids - cited_ids

    return {
        "cited": cited_ids,
        "retrieved": retrieved_ids,
        "correct": correct,
        "hallucinated": hallucinated,
        "missed": missed
    }

In [25]:
result = check_citations(answer, retrieved)
print(result)

{'cited': {3}, 'retrieved': {1, 2, 3}, 'correct': {3}, 'hallucinated': set(), 'missed': {1, 2}}


cited: {1, 2, 3} → the model used citations [1][2][3].

retrieved: {1, 2, 3} → retrieved 3 chunks, so valid IDs are [1][2][3].

correct: {1, 2, 3} → all citations the model gave were from the retrieved set.

hallucinated: set() → no fake citations.

missed: set() → the model used all retrieved passages (none were ignored).

So in this case, our RAG pipeline worked perfectly — the answer is both supported and well-cited.

## Multi Agent

In [None]:
def ask_if_enough(query, retrieved):
    """
    Ask the LLM whether the retrieved passages are enough
    to confidently answer the query.
    """
    # Format retrieved snippets
    context_parts = []
    for i, r in enumerate(retrieved, 1):
        context_parts.append(
            f"[{i}] {r['chunk_text']} "
            f"(Case: {r.get('case_name')}, Citation: {r.get('citation')})"
        )
    context = "\n\n".join(context_parts)

    check_prompt = f"""
    You are a legal reasoning assistant.

    Query:
    {query}

    Retrieved passages:
    {context}

    Task:
    Decide if the retrieved passages directly address the user query with sufficient detail.
    - If they clearly answer the query, respond:
    {{"enough": true, "reason": "short explanation"}}
    - If they are vague, off-topic, or incomplete, respond:
    {{"enough": false, "reason": "short explanation", "suggestion": "more specific refined query"}}

    Important:
    - Only say "enough": true if at least one passage explicitly addresses the legal question.
    - Do not suggest refinements unless truly necessary.
    """
    
    # Call LLM
    resp = client.chat.completions.create(
        model="gpt-5-mini",
        messages=[{"role": "user", "content": check_prompt}],
        temperature=1
    )

    return resp.choices[0].message.content


In [21]:
# --- Utility: make JSON safe ---
def make_json_safe(obj):
    """Convert numpy / non-serializable objects into safe Python types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.float32, np.float64)):
        return float(obj)
    if isinstance(obj, (np.int32, np.int64)):
        return int(obj)
    return obj

# --- Iterative retrieval agent ---
def iterative_agent(query, max_hops=3, top_k=3):
    hops = []
    current_query = query

    for hop in range(1, max_hops + 1):
        retrieved = retrieve(current_query, top_k=top_k)
        check = ask_if_enough(current_query, retrieved)

        hops.append({
            "hop": hop,
            "query": current_query,
            "retrieved": [{k: make_json_safe(v) for k, v in r.items()} for r in retrieved],
            "self_check": check
        })

        try:
            check_data = json.loads(check)
        except json.JSONDecodeError:
            return None, hop

        if check_data.get("enough") is True:
            prompt = build_prompt(current_query, retrieved)
            answer = ask_llm(prompt)

            trace = {
                "original_query": query,
                "hops": hops,
                "final_answer": answer
            }
            with open("trace_log.jsonl", "a", encoding="utf-8") as f:
                f.write(json.dumps(trace) + "\n")
            with open("trace_log_pretty.json", "a", encoding="utf-8") as f:
                f.write(json.dumps(trace, indent=2) + "\n\n")

            return answer, hop

        else:
            suggestion = check_data.get("suggestion")
            if not suggestion:
                return None, hop
            current_query = suggestion

    return None, max_hops


In [28]:
query = "How did the Supreme Court interpret the Fourth Amendment in digital privacy cases?" #max hops reached without enough evidence
#query = "What did the Supreme Court say about international child abduction?" #First try enough evidence found
iterative_agent(query, max_hops=3, top_k=3)

Hop 1: refining query to → Refine to request Supreme Court Fourth Amendment rulings on digital privacy—e.g., summaries of Riley v. California (2014), Carpenter v. United States (2018), United States v. Jones (2012), Katz v. United States, and Kyllo v. United States.
Hop 2: refining query to → Refine the query to request summaries of the specified Fourth Amendment cases (e.g., "Summarize Riley v. California (2014), Carpenter v. United States (2018), United States v. Jones (2012), Katz v. United States (1967), and Kyllo v. United States (2001), focusing on holdings and digital-privacy implications").
Hop 3: refining query to → Summarize Riley v. California (2014), Carpenter v. United States (2018), United States v. Jones (2012), Katz v. United States (1967), and Kyllo v. United States (2001), focusing on each case's holding, the legal test or standard it established, and the implications for digital privacy (e.g., cell‑phone data, cell‑site location information, GPS tracking, expectation

## Gold QA

In [3]:
cases = []
with open("data/processed/cases_slim.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        cases.append(json.loads(line))

In [11]:
def make_gold_examples(cases, n=20, seed=42):
    random.seed(seed)

    # Substantive majority opinions only
    majority_cases = []
    for c in cases:
        if c.get("opinion_type") != "majority":
            continue
        text = c.get("opinion_text", "").strip()
        if not text:
            continue
        # skip denials / recusals
        if text.startswith("Justice") and "took no part" in text:
            continue
        if text.lower().startswith("application") or "is denied" in text[:120].lower():
            continue
        majority_cases.append(c)

    # Sample
    sample_cases = random.sample(majority_cases, min(n, len(majority_cases)))
    qa_set = []

    for i, c in enumerate(sample_cases, 1):
        case_name = c["name"]
        year = c["decision_date"].split("-")[0] if c.get("decision_date") else "Unknown"
        citation = c["citations"][0] if c.get("citations") else None
        text = c.get("opinion_text", "").strip()

        # Use first 3–4 sentences
        sentences = text.split(". ")
        gold_answer = ". ".join(sentences[:4]) if sentences else ""

        qa_set.append({
            "id": i,
            "query": f"What did the Supreme Court decide in {case_name} ({year})?",
            "gold_answer": gold_answer,
            "gold_citations": [citation] if citation else []
        })

    return qa_set

In [12]:
qa_set = make_gold_examples(cases, n=20)
os.makedirs("data/eval", exist_ok=True)

with open("data/eval/gold_qa.jsonl", "w", encoding="utf-8") as f:
    for ex in qa_set:
        f.write(json.dumps(ex) + "\n")

print("Saved gold QA set with", len(qa_set), "examples.")

Saved gold QA set with 20 examples.


## Evaluation

Run baseline and iterative agent against pipelines against gold set to compare:
1. Semantic accuracy (answer closeness to gold)
2. Citation correctness (precision/recall)
3. Hallucination Rate (spurious cites)
4. Hop helpfulness (did multi hop improve vs baseline)

In [28]:
sim_model = SentenceTransformer("all-MiniLM-L6-v2")

def cosine_sim(a, b):
    va = sim_model.encode([a])[0]
    vb = sim_model.encode([b])[0]
    return float(np.dot(va, vb) / (np.linalg.norm(va) * np.linalg.norm(vb)))

def extract_citations(text):
    # match like "572 U.S. 545" or "585 U.S. 285"
    pattern = r"\d+\s+U\.S\.\s+\d+"
    return re.findall(pattern, text or "")

In [29]:
def evaluate_results(results):
    metrics = []

    for ex in results:
        gold_answer = ex["gold_answer"]
        gold_cites  = set(ex["gold_citations"])
        sys_answer  = ex.get("system_answer")

        # --- Accuracy (semantic similarity) ---
        if sys_answer and gold_answer:
            sim = cosine_sim(sys_answer, gold_answer)
        else:
            sim = 0.0

        # --- Citation correctness ---
        sys_cites = set(extract_citations(sys_answer or ""))
        true_pos  = len(sys_cites & gold_cites)
        prec = true_pos / len(sys_cites) if sys_cites else 0
        rec  = true_pos / len(gold_cites) if gold_cites else 0

        # --- Hallucination ---
        hallucination = int(len(sys_cites - gold_cites) > 0)

        # --- Hop helpfulness ---
        hops_used = ex.get("hops_used", 1)

        metrics.append({
            "id": ex["id"],
            "semantic_similarity": sim,
            "citation_precision": prec,
            "citation_recall": rec,
            "hallucination": hallucination,
            "hops_used": hops_used
        })

    return metrics

In [30]:
def summarize_metrics(metrics):
    df = pd.DataFrame(metrics)
    return {
        "avg_similarity": df["semantic_similarity"].mean(),
        "avg_citation_precision": df["citation_precision"].mean(),
        "avg_citation_recall": df["citation_recall"].mean(),
        "hallucination_rate": df["hallucination"].mean(),
        "avg_hops_used": df["hops_used"].mean()
    }

In [31]:
def run_experiments(eval_set, use_iterative=False, top_k=3):
    results = []
    for ex in eval_set:
        q = ex["query"]

        if use_iterative:
            ans, hops_used = iterative_agent(q, top_k=top_k)
        else:
            retrieved = retrieve(q, top_k=top_k)
            prompt = build_prompt(q, retrieved)
            ans = ask_llm(prompt)
            hops_used = 1

        results.append({
            "id": ex["id"],
            "query": q,
            "gold_answer": ex["gold_answer"],
            "gold_citations": ex["gold_citations"],
            "system_answer": ans,
            "hops_used": hops_used
        })
    return results


In [33]:
eval_set = []
with open("data/eval/gold_qa.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        eval_set.append(json.loads(line))

eval_set = random.sample(eval_set, 5)

# Run both pipelines
baseline_results = run_experiments(eval_set, use_iterative=False)
iterative_results = run_experiments(eval_set, use_iterative=True)

# Evaluate
base_metrics = evaluate_results(baseline_results)
iter_metrics = evaluate_results(iterative_results)

print("Baseline:", summarize_metrics(base_metrics))
print("Iterative:", summarize_metrics(iter_metrics))

Baseline: {'avg_similarity': np.float64(0.5520434975624084), 'avg_citation_precision': np.float64(0.6), 'avg_citation_recall': np.float64(0.6), 'hallucination_rate': np.float64(0.0), 'avg_hops_used': np.float64(1.0)}
Iterative: {'avg_similarity': np.float64(0.3463660478591919), 'avg_citation_precision': np.float64(0.4), 'avg_citation_recall': np.float64(0.4), 'hallucination_rate': np.float64(0.0), 'avg_hops_used': np.float64(2.4)}


Baseline produces answers closer to the gold text.
Iterative retrieval sacrifices textual similarity but improves citation correctness.