# Self-CRAG Graph Verification

This notebook checks that the LangGraph Self-CRAG graph routes as intended by tracing which nodes run for each query.
It tests:
1. **Node Execution**: The core nodes run on a normal query
2. **Edge Conditions**: Conditional routing works
3. **Retry Loop**: Hallucination detection triggers retries
4. **Fallback**: I_DONT_KNOW is returned when appropriate

In [None]:
import os, sys

# --- Project Root ---
if os.path.exists("src"):
    PROJECT_ROOT = os.getcwd()
elif os.path.exists("llm-semeval-task8"):
    # Use an absolute path so imports keep working if the working directory changes.
    PROJECT_ROOT = os.path.abspath("llm-semeval-task8")
else:
    PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print(f"Project Root: {PROJECT_ROOT}")

---

## 1. Patch Graph with Debug Logging

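We wrap each node function so that every call is printed and appended to `execution_trace`. The patch is applied before the graph is initialized, so the graph is built from the wrapped functions.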

In [None]:
from src import graph as graph_module
from functools import wraps

# --- Execution Trace ---
execution_trace = []

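def trace_node(name, original_fn):
    """Wrap a node function so each call is printed and recorded in execution_trace."""
    @wraps(original_fn)
    def wrapped(state, *args, **kwargs):
        # Forward any extra arguments (e.g. a config) in case the node accepts them.
        print(f"  [NODE] {name} called")
        execution_trace.append(name)
        result = original_fn(state, *args, **kwargs)
        # Nodes normally return a dict of state updates; guard against other return values.
        keys = list(result.keys()) if isinstance(result, dict) else result
        print(f"  [NODE] {name} -> {keys}")
        return result
    return wrapped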

# --- Patch All Nodes ---
original_rewrite = graph_module.rewrite_node
original_retrieve = graph_module.retrieve_node
original_grade = graph_module.grade_documents_node
original_generate = graph_module.generate_node
original_hallucination = graph_module.hallucination_check_node
original_increment = graph_module.increment_retry_node
original_fallback = graph_module.fallback_node

graph_module.rewrite_node = trace_node("rewrite", original_rewrite)
graph_module.retrieve_node = trace_node("retrieve", original_retrieve)
graph_module.grade_documents_node = trace_node("grade_docs", original_grade)
graph_module.generate_node = trace_node("generate", original_generate)
graph_module.hallucination_check_node = trace_node("hallucination_check", original_hallucination)
graph_module.increment_retry_node = trace_node("increment_retry", original_increment)
graph_module.fallback_node = trace_node("fallback", original_fallback)

print("Nodes patched with debug logging.")
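
The wrapper can be sanity-checked on a stand-in node before it touches the real graph. The sketch below is self-contained and does not import `src.graph`. `make_tracer`, `dummy_node`, and the `question` state key are hypothetical names used only for illustration. It shows that a wrapped function records its name and passes the result through unchanged.

```python
from functools import wraps

def make_tracer(trace):
    """Return a trace_node-style factory that appends node names to `trace`."""
    def trace_node(name, original_fn):
        @wraps(original_fn)
        def wrapped(state, *args, **kwargs):
            trace.append(name)
            return original_fn(state, *args, **kwargs)
        return wrapped
    return trace_node

def dummy_node(state):
    """Hypothetical node: upper-cases the question."""
    return {"question": state["question"].upper()}

demo_trace = []
trace_node_demo = make_tracer(demo_trace)
traced_dummy = trace_node_demo("dummy", dummy_node)

demo_result = traced_dummy({"question": "hello"})
print(demo_trace, demo_result)
```

Because `functools.wraps` stores the original function on `__wrapped__`, the patch can be undone later either through that attribute or by reassigning the saved `original_*` references.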

---

## 2. Initialize Graph

In [None]:
print("Initializing Graph...")
app = graph_module.initialize_graph()
print("Graph Ready.")

---

## 3. Test Case 1: Normal Flow (Happy Path)

A question that should retrieve relevant documents and generate a grounded answer.

In [None]:
execution_trace.clear()
print("="*60)
print("TEST 1: Happy Path (Normal Flow)")
print("="*60)

result = app.invoke({"question": "What is climate change?", "domain": "govt"})

print(f"\nExecution Trace: {execution_trace}")
print(f"Final Generation: {(result.get('generation') or 'N/A')[:200]}...")

# --- Assertions ---
assert "rewrite" in execution_trace, "Rewrite node not called!"
assert "retrieve" in execution_trace, "Retrieve node not called!"
assert "grade_docs" in execution_trace, "Grade docs node not called!"
print("\n✅ TEST 1 PASSED: Core nodes executed.")
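
The assertions above only check membership, so they would still pass if the nodes ran in the wrong order. A minimal sketch of an order-aware check follows. `appears_in_order` is a hypothetical helper, and `sample_trace` is a made-up trace rather than output from the real graph.

```python
def appears_in_order(trace, expected):
    """True if `expected` is a subsequence of `trace` (other nodes may interleave)."""
    remaining = iter(trace)
    # `name in remaining` consumes the iterator up to the first match,
    # so each later name must be found after the previous one.
    return all(name in remaining for name in expected)

sample_trace = ["rewrite", "retrieve", "grade_docs", "generate", "hallucination_check"]

in_order = appears_in_order(sample_trace, ["rewrite", "retrieve", "grade_docs"])
out_of_order = appears_in_order(sample_trace, ["retrieve", "rewrite"])
print(in_order)      # True
print(out_of_order)  # False
```

In the notebook this could be used as `assert appears_in_order(execution_trace, ["rewrite", "retrieve", "grade_docs"])`.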

---

## 4. Test Case 2: Fallback (No Relevant Documents)

A nonsensical question that should trigger the fallback.

In [None]:
execution_trace.clear()
print("="*60)
print("TEST 2: Fallback (Nonsense Query)")
print("="*60)

result = app.invoke({"question": "asdfghjkl zxcvbnm qwertyuiop 12345", "domain": "govt"})

print(f"\nExecution Trace: {execution_trace}")
print(f"Final Generation: {result.get('generation', 'N/A')}")

# --- Check for I_DONT_KNOW ---
gen = result.get('generation') or ''  # treat a missing or None generation as empty
if "I_DONT_KNOW" in gen or "fallback" in execution_trace:
    print("\n✅ TEST 2 PASSED: Fallback triggered correctly.")
else:
    print("\n⚠️ TEST 2 INFO: Model attempted to answer. Check if grading is too lenient.")

---

## 5. Test Case 3: Retry Loop Verification

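This test uses an adversarial prompt to provoke a hallucination and checks whether the retry loop runs. The hallucination is not actually forced, so a run without retries is reported as informational rather than as a failure.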

In [None]:
execution_trace.clear()
print("="*60)
print("TEST 3: Retry Loop (Adversarial Prompt)")
print("="*60)

# Check whether increment_retry is called during a normal run.
# If the hallucination check rejects the answer, the graph should retry.
# The prompt below asks for ungrounded content to make that more likely.
result = app.invoke({"question": "Tell me something completely made up that is not in any document.", "domain": "govt"})

print(f"\nExecution Trace: {execution_trace}")
print(f"Final Generation: {(result.get('generation') or 'N/A')[:200]}...")

if "increment_retry" in execution_trace:
    print(f"\n✅ TEST 3 PASSED: Retry loop was triggered ({execution_trace.count('increment_retry')} retries).")
else:
    print("\n⚠️ TEST 3 INFO: No retry triggered. Model may have been grounded or skipped hallucination check.")
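
Because the prompt only makes a retry likely, a deterministic check needs a stubbed hallucination grader. The sketch below is a toy stand-in, not the project's graph. `make_flaky_checker`, `run_toy_loop`, the `max_retries` parameter, and the node bodies are all hypothetical. It illustrates the pattern of a checker that reports "not grounded" for the first N calls, which exercises both the retry path and the fallback path.

```python
def make_flaky_checker(fail_times):
    """Return a hallucination check that fails on the first `fail_times` calls."""
    calls = {"n": 0}
    def check(state):
        calls["n"] += 1
        return calls["n"] > fail_times  # True means "grounded"
    return check

def run_toy_loop(check, max_retries=2):
    """Toy generate -> check -> retry loop; returns (answer, trace)."""
    trace, retries = [], 0
    while True:
        trace.append("generate")
        state = {"generation": f"draft {retries}"}
        trace.append("hallucination_check")
        if check(state):
            return state["generation"], trace
        if retries >= max_retries:
            # Retry budget exhausted: give up instead of looping forever.
            trace.append("fallback")
            return "I_DONT_KNOW", trace
        trace.append("increment_retry")
        retries += 1

answer_ok, trace_ok = run_toy_loop(make_flaky_checker(fail_times=1))
answer_fb, trace_fb = run_toy_loop(make_flaky_checker(fail_times=10))
print(answer_ok, trace_ok.count("increment_retry"))  # draft 1 1
print(answer_fb, trace_fb.count("increment_retry"))  # I_DONT_KNOW 2
```

The analogous move in this notebook would be to replace `graph_module.hallucination_check_node` with a stub before calling `initialize_graph()`. The state keys such a stub must return are not shown here and would have to be taken from `src/graph`.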

---

## 6. Summary

In [None]:
print("="*60)
print("GRAPH VERIFICATION SUMMARY")
print("="*60)
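# execution_trace is cleared at the start of each test, so it only holds the last run.
print("\nNodes exercised in the most recent test (the trace is cleared before each test):")

all_nodes = ["rewrite", "retrieve", "grade_docs", "generate", "hallucination_check", "increment_retry", "fallback"]
visited = set(execution_trace)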

for node in all_nodes:
    status = "✅" if node in visited else "❌"
    print(f"  {status} {node}")

print("\n" + "="*60)
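
Since `execution_trace` is cleared at the start of every test, reporting coverage across all tests needs a separate accumulator. A minimal sketch follows, assuming a hypothetical `record_run` helper that is called after each test with that test's trace. The two traces below are made-up samples, not output from the real graph.

```python
ALL_NODES = ["rewrite", "retrieve", "grade_docs", "generate",
             "hallucination_check", "increment_retry", "fallback"]

visited_all = set()

def record_run(trace):
    """Fold one test's trace into the cumulative set of visited nodes."""
    visited_all.update(trace)

# Sample traces standing in for two separate test runs.
record_run(["rewrite", "retrieve", "grade_docs", "generate", "hallucination_check"])
record_run(["rewrite", "retrieve", "grade_docs", "fallback"])

missing = [node for node in ALL_NODES if node not in visited_all]
print(missing)  # ['increment_retry']
```

Calling `record_run(execution_trace)` at the end of each test cell would let the summary iterate over `visited_all` instead of the last test's trace.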