# Corrective RAG (CRAG) Implementation

In this notebook, I implement Corrective RAG - an advanced approach that dynamically evaluates retrieved information and corrects the retrieval process when necessary, using web search as a fallback.

CRAG improves on traditional RAG by:

- Evaluating retrieved content before using it
- Dynamically switching between knowledge sources based on relevance
- Correcting the retrieval with web search when local knowledge is insufficient
- Combining information from multiple sources when appropriate

## Setting Up the Environment
We begin by importing necessary libraries.

In [51]:
import os
import numpy as np
import json
import google.generativeai as genai
import re
import pickle
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PyPDF2 import PdfReader

In [52]:

import fitz
import os
import google.generativeai as genai
from dotenv import load_dotenv


## Enitre Code

In [50]:
import os
import numpy as np
import json
import google.generativeai as genai
import re
import pickle
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PyPDF2 import PdfReader

# Configure your API key here
# genai.configure(api_key="YOUR_GOOGLE_API_KEY")

class SimpleVectorStore:
    def __init__(self) -> None:
        self.embeddings: List[np.ndarray] = []
        self.documents: List[Dict[str, Any]] = []
    
    def add_item(self, *, text: str, embedding: Union[np.ndarray, List[float]], metadata: Optional[Dict[str, Any]] = None) -> None:
        doc = {"text": text, "metadata": (metadata or {})}
        self._add_core(embedding, doc)

    def add(self, embedding: Union[np.ndarray, List[float]], document: Union[Dict[str, Any], str]) -> None:
        if isinstance(document, str):
            doc = {"text": document, "metadata": {}}
        else:
            doc = {"text": document.get("text", ""), "metadata": document.get("metadata", {})}
        self._add_core(embedding, doc)

    def similarity_search(self, query_embedding: Union[np.ndarray, List[float]], k: int = 5, filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None, return_scores: bool = False):
        if not self.embeddings:
            return []
        
        q = self._as_unit_vector(query_embedding)
        sims: List[Tuple[int, float]] = []
        for i, emb in enumerate(self.embeddings):
            doc = self.documents[i]
            if filter_func is not None and not filter_func(doc):
                continue
            sim = float(np.dot(q, emb))
            sims.append((i, sim))

        if not sims:
            return []
        
        sims.sort(key=lambda t: t[1], reverse=True)
        top = sims[:max(0, int(k))]
        
        if return_scores:
            return [(self.documents[i], score) for i, score in top]
        else:
            return [self.documents[i] for i, _ in top]
    
    def save(self, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump({"embeddings": self.embeddings, "documents": self.documents}, f)

    @classmethod
    def load(cls, path: str) -> "SimpleVectorStore":
        with open(path, "rb") as f:
            data = pickle.load(f)
        store = cls()
        store.embeddings = [np.asarray(e, dtype=float).reshape(-1) for e in data.get("embeddings", [])]
        store.documents = data.get("documents", [])
        store.embeddings = [store._as_unit_vector(e) for e in store.embeddings]
        return store

    def _add_core(self, embedding: Union[np.ndarray, List[float]], doc: Dict[str, Any]) -> None:
        emb = self._as_unit_vector(embedding)
        self.embeddings.append(emb)
        self.documents.append(doc)

    def _as_unit_vector(self, x: Union[np.ndarray, List[float]]) -> np.ndarray:
        v = np.asarray(x, dtype=float).reshape(-1)
        n = np.linalg.norm(v)
        if n == 0:
            return v
        return v / n

def _safe_get_page(doc: Dict[str, Any]):
    meta = doc.get("metadata") or {}
    for k in ("page", "page_num", "page_number", "pg"):
        if k in meta:
            return meta[k]
    return None

def extract_text_pages(pdf_path: str) -> List[str]:
    pages = []
    try:
        with open(pdf_path, "rb") as f:
            reader = PdfReader(f)
            for p in reader.pages:
                pages.append(p.extract_text() or "")
    except FileNotFoundError:
        print(f"Error: The file at {pdf_path} was not found.")
        return []
    return pages

def chunk_text(text: str, chunk_size=1000, overlap=200) -> List[Tuple[str, int]]:
    chunks = []
    start = 0
    n = len(text)
    while start < n:
        end = min(n, start + chunk_size)
        chunks.append((text[start:end], start))
        if end == n: break
        start = max(end - overlap, start + 1)
    return chunks

def create_embeddings(texts: Union[str, List[str]], model="models/embedding-001", max_retries=3, sleep_sec=1.5):
    if not texts:
        return [] if isinstance(texts, list) else None

    input_texts = texts if isinstance(texts, list) else [texts]
    vectors = []
    for t in input_texts:
        if t is None:
            vectors.append([])
            continue
        text = str(t)
        
        last_err = None
        for attempt in range(1, max_retries + 1):
            try:
                resp = genai.embed_content(model=model, content=text)
                if isinstance(resp, dict):
                    emb = resp.get("embedding", {})
                    values = emb.get("values", [])
                else:
                    values = getattr(getattr(resp, "embedding", None), "values", [])
                vectors.append(list(values))
                break
            except Exception as e:
                last_err = e
                if attempt < max_retries:
                    time.sleep(sleep_sec * attempt)
                else:
                    vectors.append([])
        if last_err and vectors[-1] == []:
            pass

    return vectors if isinstance(texts, list) else vectors[0]

def generate_page_summary(page_text: str) -> str:
    system_prompt = "You are an AI assistant. Summarize the following text in a concise paragraph."
    max_chars = 3000
    truncated_text = page_text[:max_chars] if len(page_text) > max_chars else page_text
    model = genai.GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(
        f"{system_prompt}\n\n{truncated_text}",
        generation_config={"temperature": 0.3}
    )
    return response.text.strip()

def process_document_hierarchically(pdf_path: str, chunk_size=1000, chunk_overlap=200):
    pages = extract_text_pages(pdf_path)
    if not pages:
        return SimpleVectorStore(), SimpleVectorStore()
    
    detailed_store = SimpleVectorStore()
    detailed_text_items = []
    for page_idx, page_text in enumerate(pages):
        page_text = page_text.strip()
        if not page_text:
            continue
        detailed_text_items.append((page_idx, page_text))
        for chunk_text_str, start_idx in chunk_text(page_text, chunk_size, chunk_overlap):
            emb = create_embeddings(chunk_text_str)
            if emb:
                detailed_store.add_item(
                    text=chunk_text_str,
                    embedding=emb,
                    metadata={"page": page_idx, "char_start": start_idx}
                )

    summary_store = SimpleVectorStore()
    for page_idx, full_text in detailed_text_items:
        summary_text = full_text[:350] if len(full_text) > 350 else full_text
        emb = create_embeddings(summary_text)
        if emb:
            summary_store.add_item(
                text=summary_text,
                embedding=emb,
                metadata={"page": page_idx}
            )
    return summary_store, detailed_store

def retrieve_hierarchically(query: str, summary_store: SimpleVectorStore, detailed_store: SimpleVectorStore, k_summaries=5, k_chunks=10):
    query_emb = create_embeddings(query)
    if not query_emb:
        return []
    
    summary_docs = summary_store.similarity_search(query_emb, k=k_summaries, return_scores=False)
    relevant_pages = {p for p in (_safe_get_page(d) for d in summary_docs) if p is not None and p != -1}

    if not relevant_pages:
        return detailed_store.similarity_search(query_emb, k=max(1, k_chunks), return_scores=False)

    def page_filter(doc):
        return _safe_get_page(doc) in relevant_pages

    detailed_docs = detailed_store.similarity_search(
        query_emb,
        k=max(1, k_chunks * max(1, len(relevant_pages))),
        filter_func=page_filter,
        return_scores=False
    )
    return detailed_docs

def generate_response(query: str, retrieved_docs: List[Dict[str, Any]], max_chars=1200):
    context = ""
    for d in retrieved_docs:
        t = d.get("text", "")
        if not t: continue
        space_left = max_chars - len(context)
        if space_left <= 0: break
        context += (t[:space_left] + "\n")
    answer = f"Based on the document, key points include:\n{context.strip()}"
    return answer

def hierarchical_rag(query, pdf_path, chunk_size=1000, chunk_overlap=200, k_summaries=5, k_chunks=10, regenerate=False):
    base = os.path.splitext(os.path.basename(pdf_path))[0]
    summary_pkl = f"{base}__summary_store.pkl"
    detailed_pkl = f"{base}__detailed_store.pkl"

    if regenerate or not (os.path.exists(summary_pkl) and os.path.exists(detailed_pkl)):
        print("Processing document and creating vector stores...")
        summary_store, detailed_store = process_document_hierarchically(pdf_path, chunk_size, chunk_overlap)
        with open(summary_pkl, "wb") as f: pickle.dump(summary_store, f)
        with open(detailed_pkl, "wb") as f: pickle.dump(detailed_store, f)
    else:
        print("Loading existing vector stores...")
        with open(summary_pkl, "rb") as f: summary_store = pickle.load(f)
        with open(detailed_pkl, "rb") as f: detailed_store = pickle.load(f)

    retrieved = retrieve_hierarchically(query, summary_store, detailed_store, k_summaries, k_chunks)
    response = generate_response(query, retrieved)

    return {"response": response, "chunks_used": retrieved}

def standard_rag(query: str, pdf_path: str, k=15):
    base = os.path.splitext(os.path.basename(pdf_path))[0]
    detailed_pkl = f"{base}__detailed_store.pkl"
    if not os.path.exists(detailed_pkl):
        _ = hierarchical_rag("warmup", pdf_path, regenerate=True)
    with open(detailed_pkl, "rb") as f:
        detailed_store = pickle.load(f)

    query_emb = create_embeddings(query)
    docs = detailed_store.similarity_search(query_emb, k=k, return_scores=False)
    return {"response": generate_response(query, docs), "chunks_used": docs}

def _cosine(a, b):
    a = np.asarray(a, dtype=float).reshape(-1)
    b = np.asarray(b, dtype=float).reshape(-1)
    na = np.linalg.norm(a)
    nb = np.linalg.norm(b)
    if na == 0 or nb == 0: return 0.0
    return float(np.dot(a, b) / (na * nb))

def run_evaluation(pdf_path: str, test_queries: List[str], reference_answers: List[str]):
    results = []
    for q in test_queries:
        hier = hierarchical_rag(q, pdf_path)
        base = standard_rag(q, pdf_path)

        hier_emb = create_embeddings(hier["response"])
        base_emb = create_embeddings(base["response"])

        ref_scores_h = []
        ref_scores_b = []
        for ref in reference_answers:
            ref_emb = create_embeddings(ref)
            ref_scores_h.append(_cosine(hier_emb, ref_emb))
            ref_scores_b.append(_cosine(base_emb, ref_emb))

        results.append({
            "query": q,
            "hierarchical_score": max(ref_scores_h) if ref_scores_h else 0.0,
            "standard_score": max(ref_scores_b) if ref_scores_b else 0.0,
            "hierarchical_response": hier["response"],
            "standard_response": base["response"],
        })

    avg_h = sum(r["hierarchical_score"] for r in results) / max(1, len(results))
    avg_b = sum(r["standard_score"] for r in results) / max(1, len(results))
    verdict = "Hierarchical > Standard" if avg_h > avg_b else ("Standard > Hierarchical" if avg_b > avg_h else "Tie")

    return {
        "per_query": results,
        "overall": {"avg_hierarchical": avg_h, "avg_standard": avg_b, "verdict": verdict},
        "overall_analysis": f"Avg(Hierarchical)={avg_h:.3f}, Avg(Standard)={avg_b:.3f} → {verdict}"
    }

# Path to the PDF document containing AI information
pdf_path = "/Users/kekunkoya/Desktop/ISEM 770 Class Project/Homelessness.pdf"

# Example query about AI for testing the hierarchical RAG approach
query = "What have been done to prevent homelessness?"
result = hierarchical_rag(query, pdf_path)

print("\n=== Response ===")
print(result["response"])

# Test query for formal evaluation (using only one query as requested)
test_queries = [
     "What are the strategies to prevent homelessness?"
]

# Reference answer for the test query to enable comparison
reference_answers = [
    "Prevent new incidences of homelessness through early intervention and support services.",  "Address and mitigate the underlying causes of homelessness, such as poverty, unemployment, and lack of affordable housing.", "Reduce the overall number of people experiencing homelessness via rapid rehousing and housing-first models.", "Minimize the negative social, health, and economic impacts on individuals and families currently experiencing homelessness.", "Ensure formerly homeless people maintain permanent, independent housing through ongoing support and follow-up services."
]


# Run the evaluation comparing hierarchical and standard RAG approaches
evaluation_results = run_evaluation(
    pdf_path=pdf_path,
    test_queries=test_queries,
    reference_answers=reference_answers
)

# Print the overall analysis of the comparison
print("\n=== OVERALL ANALYSIS ===")
print(evaluation_results["overall_analysis"])

Loading existing vector stores...

=== Response ===
Based on the document, key points include:

Loading existing vector stores...

=== OVERALL ANALYSIS ===
Avg(Hierarchical)=0.000, Avg(Standard)=0.000 → Tie
