# Self-CRAG Graph Verification

This notebook verifies that the LangGraph Self-CRAG architecture works correctly.
It tests:
1. **Node Execution**: Each node is visited
2. **Edge Conditions**: Conditional routing works
3. **Retry Loop**: Hallucination detection triggers retries (informational — a retry is only observed if a hallucination is actually detected)
4. **Fallback**: `I_DONT_KNOW` is returned when appropriate

In [None]:
import os, sys

# --- Project Root ---
# Choose the directory to put on sys.path, checked in order:
#   1. the current directory, if it contains a `src` entry;
#   2. the relative path `llm-semeval-task8`, if it exists here;
#   3. otherwise the absolute path of the parent directory.
PROJECT_ROOT = (
    os.getcwd() if os.path.exists("src")
    else "llm-semeval-task8" if os.path.exists("llm-semeval-task8")
    else os.path.abspath("..")
)

# Prepend only once so re-running the cell does not grow sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print(f"Project Root: {PROJECT_ROOT}")

---

## 1. Patch Graph with Debug Logging

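We wrap each node function so that every call is printed and recorded in an execution trace. This cell must run before the graph is initialized (Section 2), so that the graph is built with the wrapped functions.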

In [None]:
from src import graph as graph_module
from functools import wraps

# --- Execution Trace (GLOBAL - accumulates across all tests) ---
# `current_test_trace` is cleared at the start of each test cell;
# `all_visited_nodes` accumulates across every test and feeds the summary cell.
all_visited_nodes = set()
current_test_trace = []

def trace_node(name, original_fn):
    """Wrap a graph node function so each call is logged and recorded.

    Args:
        name: Label printed in the log and recorded in the trace collections.
        original_fn: The node function to wrap.

    Returns:
        A wrapper that prints the call, appends ``name`` to
        ``current_test_trace``, adds it to ``all_visited_nodes``, forwards all
        arguments to ``original_fn`` and returns its result unchanged. The
        wrapped function is kept on ``wrapped._traced_original`` so the patch
        below can be re-run without stacking wrappers.
    """
    @wraps(original_fn)
    def wrapped(state, *args, **kwargs):
        print(f"  [NODE] {name} called")
        current_test_trace.append(name)
        all_visited_nodes.add(name)  # Track globally
        # Forward any extra arguments: @wraps makes the wrapper advertise the
        # original's signature, so a caller may pass more than `state`.
        result = original_fn(state, *args, **kwargs)
        # Nodes normally return a partial-state mapping; don't let the log line
        # crash if one returns something else (e.g. None).
        summary = list(result.keys()) if hasattr(result, "keys") else type(result).__name__
        print(f"  [NODE] {name} -> {summary}")
        return result
    wrapped._traced_original = original_fn
    return wrapped

def _untraced(fn):
    """Return the undecorated node if `fn` is a `trace_node` wrapper, else `fn`."""
    return getattr(fn, "_traced_original", fn)

# --- Patch All Nodes ---
# Unwrap first: without this, re-running the cell would wrap the already-traced
# nodes again and every call would be logged and recorded twice.
original_rewrite = _untraced(graph_module.rewrite_node)
original_retrieve = _untraced(graph_module.retrieve_node)
original_grade = _untraced(graph_module.grade_documents_node)
original_generate = _untraced(graph_module.generate_node)
original_hallucination = _untraced(graph_module.hallucination_check_node)
original_increment = _untraced(graph_module.increment_retry_node)
original_fallback = _untraced(graph_module.fallback_node)

graph_module.rewrite_node = trace_node("rewrite", original_rewrite)
graph_module.retrieve_node = trace_node("retrieve", original_retrieve)
graph_module.grade_documents_node = trace_node("grade_docs", original_grade)
graph_module.generate_node = trace_node("generate", original_generate)
graph_module.hallucination_check_node = trace_node("hallucination_check", original_hallucination)
graph_module.increment_retry_node = trace_node("increment_retry", original_increment)
graph_module.fallback_node = trace_node("fallback", original_fallback)

print("Nodes patched with debug logging.")

---

## 2. Initialize Graph

In [None]:
# Build the graph only after the patch cell has run, so the graph is
# presumably wired with the traced node wrappers (verify via the [NODE] logs).
print("Initializing Graph...")
app = graph_module.initialize_graph()
print("Graph Ready.")

---

## 3. Test Case 1: Normal Flow (Happy Path)

A question that should retrieve relevant documents and generate a grounded answer.

In [None]:
from langchain_core.messages import HumanMessage, AIMessage

current_test_trace.clear()
print("="*60)
print("TEST 1: Happy Path (Normal Flow)")
print("="*60)

# Dummy Chat History: prior turns passed to the graph via the `messages` key.
history = [
    HumanMessage(content="Hello, who are you?"),
    AIMessage(content="I am an AI assistant."),
    HumanMessage(content="Can we talk about the environment?")
]

print("Invoking graph with chat history...")
result = app.invoke({
    "question": "What is climate change?", 
    "domain": "govt",
    "messages": history
})

# `or "N/A"` also covers a generation that is present but None, which would
# otherwise make the slice below raise TypeError.
generation = str(result.get("generation") or "N/A")
print(f"\nExecution Trace: {current_test_trace}")
print(f"Final Generation: {generation[:200]}...")

# --- Assertions ---
# Every node on the happy path must appear in this test's trace.
expected_nodes = {
    "rewrite": "Rewrite node not called!",
    "retrieve": "Retrieve node not called!",
    "grade_docs": "Grade docs node not called!",
    "generate": "Generate node not called!",
    "hallucination_check": "Hallucination check not called!",
}
for node, failure_msg in expected_nodes.items():
    assert node in current_test_trace, failure_msg
print("\nTEST 1 PASSED: Full happy path executed (including generate & hallucination_check).")

---

## 4. Test Case 2: Fallback (No Relevant Documents)

A nonsensical question that should trigger the fallback.

In [None]:
current_test_trace.clear()
print("="*60)
print("TEST 2: Fallback (Nonsense Query)")
print("="*60)

result = app.invoke({"question": "asdfghjkl zxcvbnm qwertyuiop 12345", "domain": "govt"})

# Treat a missing *or* None generation as empty, so the membership test below
# cannot raise TypeError.
gen = str(result.get("generation") or "")
print(f"\nExecution Trace: {current_test_trace}")
print(f"Final Generation: {gen or 'N/A'}")

# --- Check for I_DONT_KNOW or fallback ---
# Informational only: either the sentinel answer or a visit to the fallback
# node counts as the fallback path.
if "I_DONT_KNOW" in gen or "fallback" in current_test_trace:
    print("\nTEST 2 PASSED: Fallback triggered correctly.")
else:
    print("\nTEST 2 INFO: Model attempted to answer. Check if grading is too lenient.")

---

## 5. Test Case 3: Retry Loop Verification

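We use a query that may provoke a hallucination in order to exercise the retry mechanism. This check is informational: if no hallucination is detected, or if grading filters out the documents and the graph takes the fallback path, no retry occurs and the test does not fail.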

In [None]:
current_test_trace.clear()
print("="*60)
print("TEST 3: Retry Loop (Force Hallucination)")
print("="*60)

# This test checks if increment_retry is ever called during normal execution.
# If the model hallucinates, it should retry; a question with no grounded
# answer makes that more likely, but it is not guaranteed.
result = app.invoke({"question": "Tell me something completely made up that is not in any document.", "domain": "govt"})

# `or "N/A"` also covers a present-but-None generation, which would make the
# slice below raise TypeError.
generation = str(result.get("generation") or "N/A")
print(f"\nExecution Trace: {current_test_trace}")
print(f"Final Generation: {generation[:200]}...")

retry_count = current_test_trace.count("increment_retry")
if retry_count:
    print(f"\nTEST 3 PASSED: Retry loop was triggered ({retry_count} retries).")
else:
    print("\nTEST 3 INFO: No retry triggered. This is OK if grading filtered out irrelevant docs (fallback path).")

---

## 6. Summary

In [None]:
print("="*60)
print("GRAPH VERIFICATION SUMMARY")
print("="*60)
print("\nNodes visited ACROSS ALL TESTS:")

all_nodes = ["rewrite", "retrieve", "grade_docs", "generate", "hallucination_check", "increment_retry", "fallback"]

for node in all_nodes:
    # Previously both branches were "", so visited and unvisited nodes printed
    # identically. ASCII markers survive any output encoding.
    status = "[x]" if node in all_visited_nodes else "[ ]"
    print(f"  {status} {node}")

# Final verdict
# Every node except the retry step is expected to be hit by tests 1-3; the
# retry step only runs if a hallucination is detected.
core_nodes = set(all_nodes) - {"increment_retry"}
if core_nodes.issubset(all_visited_nodes):
    print("\nALL CORE NODES VERIFIED!")
else:
    missing = core_nodes - all_visited_nodes
    # Sorted so the report is stable across runs (set order is not).
    print(f"\nMissing core nodes: {sorted(missing)}")

if "increment_retry" in all_visited_nodes:
    print("Retry loop was exercised.")
else:
    print("Retry loop was NOT exercised (hallucination never detected, which is fine).")

print("\n" + "="*60)