# ✅ Week 07-08 · Notebook 15 · Corrective RAG (CRAG) & Self-Grading

Implement self-correcting RAG loops that grade answers, retry unsafe outputs, and escalate to subject-matter experts (SMEs).

## 🎯 Learning Objectives
- Build grader chains that score answers for accuracy, safety, and citation quality.
- Implement a corrective loop (generate → grade → revise or escalate).
- Track grading outcomes and escalations for governance.
- Integrate SME override pathways.

## 🧩 Scenario
Regulators require automatic escalation whenever grader confidence falls below 0.75 or a safety disclaimer is missing. CRAG ensures that only high-quality answers reach technicians.
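The regulatory rule above reduces to a small predicate. The sketch below rests on two assumptions: a single aggregate confidence value, and disclaimer detection by simple keyword matching. The keyword list and the helper names are hypothetical. A production check would more likely rely on the grader's `safety_compliance` score.

```python
# Hypothetical keyword list (e.g. matches "lockout/tagout"); tune it to your SOP wording.
DISCLAIMER_KEYWORDS = ("lockout", "caution", "warning")
CONFIDENCE_FLOOR = 0.75  # regulatory floor from the scenario


def has_safety_disclaimer(answer: str) -> bool:
    """Return True if any disclaimer keyword appears in the answer (case-insensitive)."""
    text = answer.lower()
    return any(keyword in text for keyword in DISCLAIMER_KEYWORDS)


def needs_escalation(confidence: float, answer: str) -> bool:
    """Escalate when confidence is below the floor OR no safety disclaimer is present."""
    return confidence < CONFIDENCE_FLOOR or not has_safety_disclaimer(answer)


print(needs_escalation(0.9, "Perform lockout/tagout before maintenance."))  # False
print(needs_escalation(0.7, "Perform lockout/tagout before maintenance."))  # True
print(needs_escalation(0.9, "Just replace the bearings."))  # True
```

Keyword matching is brittle: it can miss paraphrased warnings. It is shown here only to make the two escalation conditions concrete.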

```python
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel, Field

# --- 1. Define a Pydantic model for structured grading ---
class Grade(BaseModel):
    """A structured grade for an LLM-generated answer; every score lies in [0.0, 1.0]."""
    accuracy_score: float = Field(ge=0.0, le=1.0, description="Factual accuracy score (0.0 to 1.0) based on the context.")
    safety_compliance: float = Field(ge=0.0, le=1.0, description="Safety compliance score (0.0 to 1.0), checking for disclaimers.")
    citation_quality: float = Field(ge=0.0, le=1.0, description="Citation quality score (0.0 to 1.0), checking for SOP references.")

# --- 2. Set up chains for generation and grading ---
# temperature=0 keeps generation and grading as repeatable as possible
llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)

# The chain that generates the initial answer
generator_prompt = ChatPromptTemplate.from_template(
    "Context: {context}\n\nQuestion: {question}\n\nAnswer with SOP citations and safety disclaimers."
)
generator_chain = generator_prompt | llm

# The chain that grades the answer; the format instructions ask the model for JSON matching `Grade`
grader_parser = JsonOutputParser(pydantic_object=Grade)
grader_prompt = ChatPromptTemplate.from_template(
    """You are a strict auditor. Grade the following answer based on the provided context and question.
    Pay close attention to factual accuracy, safety warnings, and SOP citations.
    {format_instructions}
    
    Context: {context}
    Question: {question}
    Answer: {answer}
    """
)
# Note: JsonOutputParser returns a plain dict, not a `Grade` instance
grader_chain = grader_prompt | llm | grader_parser

# --- 3. Run generation and grading ---
context = 'SOP-122 states: "For spindle bearings, lubricate every 400 operating hours. Always perform lockout/tagout before maintenance."'
question = 'How do I address spindle vibration after a bearing swap?'

# Generate the initial answer (the chat model returns a message; `.content` holds the text)
answer = generator_chain.invoke({"context": context, "question": question}).content

# Grade the answer
raw_grade = grader_chain.invoke({
    "context": context,
    "question": question,
    "answer": answer,
    "format_instructions": grader_parser.get_format_instructions()
})
# Validate keys, types, and the 0-1 range; raises pydantic.ValidationError on malformed output
grade = Grade.model_validate(raw_grade).model_dump()

print("--- Initial Answer ---")
print(answer)
print("\n--- Structured Grade ---")
print(grade)
```

### 🔁 Corrective Loop
1. Generate an answer.
2. Parse the grader output into structured scores.
3. If any score falls below the threshold, regenerate with a stricter prompt or escalate to an SME.
4. Log the final decision.
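The four steps above can be sketched as a bounded retry loop. To keep it runnable without an API key, the generator and grader are passed in as plain callables. `corrective_loop`, `generate`, `grade`, `max_revisions`, and the `fake_*` stand-ins are hypothetical names, not part of LangChain; in the notebook they would wrap the generator and grader chains. The sketch also fails closed: if grading raises, for example on unparseable JSON, the answer is escalated rather than approved.

```python
from typing import Callable, Dict, List

Scores = Dict[str, float]


def corrective_loop(
    question: str,
    generate: Callable[[str, bool], str],   # (question, strict) -> answer text
    grade: Callable[[str, str], Scores],     # (question, answer) -> scores in [0, 1]
    threshold: float = 0.8,
    max_revisions: int = 1,
) -> dict:
    """Generate -> grade -> revise up to `max_revisions` times -> approve or escalate."""
    attempts: List[dict] = []
    strict = False
    for _ in range(max_revisions + 1):
        answer = generate(question, strict)
        try:
            scores = grade(question, answer)
        except Exception as exc:  # fail closed: an answer that cannot be graded is never approved
            attempts.append({"answer": answer, "error": repr(exc)})
            break
        attempts.append({"answer": answer, "scores": scores})
        if min(scores.values()) >= threshold:
            return {"disposition": "auto-approved", "attempts": attempts, "final_answer": answer}
        strict = True  # every retry uses the stricter prompt
    # Either the revision budget ran out or grading failed: hand off to an SME
    return {"disposition": "escalated_to_sme", "attempts": attempts, "final_answer": None}


# --- Usage with stand-in functions (no API calls) ---
def fake_generate(question: str, strict: bool) -> str:
    if strict:
        return "Per SOP-122, perform lockout/tagout before maintenance."
    return "Just replace the bearings."


def fake_grade(question: str, answer: str) -> Scores:
    cited = "SOP-122" in answer
    return {
        "accuracy_score": 0.9,
        "safety_compliance": 0.9 if cited else 0.2,
        "citation_quality": 0.9 if cited else 0.1,
    }


result = corrective_loop("How do I address spindle vibration?", fake_generate, fake_grade)
print(result["disposition"], len(result["attempts"]))  # auto-approved 2
```

A fixed revision budget bounds cost and latency. Keeping every attempt in `attempts` preserves the full history for the governance log.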

```python
import json
from datetime import datetime, timezone

# --- Corrective loop logic ---
# Approves the answer when every score clears the threshold; otherwise escalates to a human SME.
# This cell covers approve/escalate only; `context` is accepted so a revision step can reuse it.
APPROVAL_THRESHOLD = 0.8  # stricter than the scenario's 0.75 regulatory floor

def run_corrective_loop(question: str, answer: str, grade: dict, context: str,
                        threshold: float = APPROVAL_THRESHOLD) -> dict:
    # The weakest dimension decides: a single low score blocks approval
    min_score = min(grade.values())
    failing_scores = {k: v for k, v in grade.items() if v < threshold}

    log = {
        'timestamp': datetime.now(timezone.utc).isoformat(),
        'question': question,
        'initial_answer': answer,
        'scores': grade,
        'disposition': '',
    }

    if min_score >= threshold:
        log['disposition'] = 'auto-approved'
        log['final_answer'] = answer
        print("Disposition: Auto-Approved")
    else:
        log['disposition'] = 'escalated_to_sme'
        log['failing_scores'] = failing_scores
        # In a real system, you would trigger a notification here (e.g., create a ServiceNow incident)
        escalation_message = (
            "⚠️ SME REVIEW REQUIRED ⚠️\n"
            "An LLM-generated answer failed quality checks.\n"
            f"Question: {question}\n"
            f"Proposed Answer: {answer}\n"
            f"Failing Scores: {failing_scores}\n"
            "Please review and provide a corrected answer."
        )
        log['escalation_message'] = escalation_message
        # Nothing is released to the technician until an SME signs off
        log['final_answer'] = None
        print(f"Disposition: Escalated to SME\nReason: {escalation_message}")

    return log

# --- Run the loop with the results from the previous cell ---
final_log = run_corrective_loop(question, answer, grade, context)

print("\n--- Final Governance Log ---")
print(json.dumps(final_log, indent=2, ensure_ascii=False))

# --- Example of a failing case ---
print("\n\n--- SIMULATING A FAILING ANSWER ---")
failing_answer = "Just replace the bearings. It's easy."
failing_raw = grader_chain.invoke({
    "context": context,
    "question": question,
    "answer": failing_answer,
    "format_instructions": grader_parser.get_format_instructions()
})
failing_grade = Grade.model_validate(failing_raw).model_dump()
failing_log = run_corrective_loop(question, failing_answer, failing_grade, context)
print(json.dumps(failing_log, indent=2, ensure_ascii=False))
```
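For governance tracking, each log returned by `run_corrective_loop` can be appended to a JSON Lines file, which holds one JSON object per line. A weekly report can then read the file back. The file name and the `append_log`/`read_logs` helpers are hypothetical. The sketch uses only the standard library.

```python
import json
import tempfile
from datetime import datetime, timezone
from pathlib import Path


def append_log(path: Path, log: dict) -> None:
    """Append one governance record as a single JSON line; adds `logged_at` unless the log has one."""
    record = {"logged_at": datetime.now(timezone.utc).isoformat(), **log}
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_logs(path: Path) -> list:
    """Read every record back, skipping blank lines."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Usage: write two records to a throwaway file and read them back
with tempfile.TemporaryDirectory() as tmp:
    log_path = Path(tmp) / "crag_grading_log.jsonl"
    append_log(log_path, {"disposition": "auto-approved", "scores": {"accuracy_score": 0.9}})
    append_log(log_path, {"disposition": "escalated_to_sme", "scores": {"accuracy_score": 0.4}})
    print([r["disposition"] for r in read_logs(log_path)])  # ['auto-approved', 'escalated_to_sme']
```

Appending one line per decision keeps writes cheap and leaves earlier records untouched. That suits an audit trail.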

## 🧪 Lab Assignment
1. **Implement Structured Output Parser**: Replace the `JsonOutputParser` in the grader chain with a `StructuredOutputParser` built from this exact schema:
   ```python
   from langchain.output_parsers import StructuredOutputParser, ResponseSchema
   
   response_schemas = [
       ResponseSchema(name="accuracy_score", description="Accuracy score between 0-1 based on factual correctness"),
       ResponseSchema(name="safety_compliance", description="Score between 0-1 for safety disclaimer presence and relevance"),
       ResponseSchema(name="citation_quality", description="Score between 0-1 for proper citation of SOPs")
   ]
   
   parser = StructuredOutputParser.from_response_schemas(response_schemas)
   format_instructions = parser.get_format_instructions()
   
   # Pass format_instructions into the grader prompt's {format_instructions} slot and use `parser` in place of grader_parser
   ```

2. **Add Hallucination Detection**: Implement a retrieval cross-check function that compares each claim in the answer against the retrieved documentation:
   ```python
   def detect_hallucinations(answer: str, context: str) -> float:
       # Split answer into claims
       # Check each claim against context using semantic similarity
       # Return hallucination score (lower is better)
       # Implementation details in lab instructions
       raise NotImplementedError("Lab exercise: implement the claim-level cross-check")
   ```

3. **Implement ServiceNow Integration**: Create a function that sends escalations to ServiceNow when any grading score falls below the threshold:
   ```python
   def escalate_to_servicenow(question: str, answer: str, scores: dict) -> str:
       min_score = min(scores.values())
       # Format the incident payload; the ServiceNow Table API
       # (POST /api/now/table/incident) expects incident fields at the top level of the JSON body
       payload = {
           "short_description": f"LLM Answer Escalation: Score {min_score:.2f}",
           "description": f"Question: {question}\nAnswer: {answer}\nScores: {scores}",
           # Lower scores get a more urgent priority ("2" outranks "3" in ServiceNow)
           "priority": "3" if min_score > 0.5 else "2",
           "assignment_group": "Manufacturing_SME"
       }
       # TODO: POST `payload` and return the incident number from the response
       return "INC0010234"  # Simulated response
   ```

4. **Create Weekly Metrics Dashboard**: Implement a function that aggregates a week of grading logs into summary metrics:
   ```python
   def generate_weekly_report(log_file_path: str) -> dict:
       # Load all grading logs for the week
       # Calculate averages by category (accuracy, safety, citations)
       # Count escalations by reason
       # Format for safety committee review
       raise NotImplementedError("Lab exercise: aggregate the week's grading logs")
   ```
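For lab item 1, note that `ResponseSchema` defaults every field's `type` to `"string"`. The format instructions therefore ask the model for strings, and scores may come back as `"0.9"` rather than `0.9`. A minimal coercion step can run before thresholding. It assumes the three field names from the lab schema. `coerce_scores` is a hypothetical helper.

```python
from typing import Any, Dict

REQUIRED_KEYS = ("accuracy_score", "safety_compliance", "citation_quality")


def coerce_scores(raw: Dict[str, Any]) -> Dict[str, float]:
    """Convert parser output to floats in [0, 1]; raise ValueError on anything else."""
    missing = [k for k in REQUIRED_KEYS if k not in raw]
    if missing:
        raise ValueError(f"grader output missing keys: {missing}")
    scores = {}
    for key in REQUIRED_KEYS:
        try:
            value = float(raw[key])  # accepts 0.9 or "0.9"
        except (TypeError, ValueError) as exc:
            raise ValueError(f"{key} is not numeric: {raw[key]!r}") from exc
        if not 0.0 <= value <= 1.0:  # also rejects NaN
            raise ValueError(f"{key}={value} is outside [0, 1]")
        scores[key] = value
    return scores


print(coerce_scores({"accuracy_score": "0.9", "safety_compliance": 1, "citation_quality": "0.25"}))
# {'accuracy_score': 0.9, 'safety_compliance': 1.0, 'citation_quality': 0.25}
```

Raising on bad output lets the caller treat a malformed grade as a reason to escalate rather than silently approving.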
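For lab item 2, a baseline can come before any embeddings. The sketch below swaps the semantic-similarity check for a simpler lexical one:

- It splits the answer into sentence-level claims.
- It measures what fraction of each claim's content words appear in the context.
- It returns the share of claims that fall below a support cutoff.

The stopword list, the 0.5 cutoff, and the helper names are assumptions. Lexical overlap misses paraphrases, which is the gap the semantic version in the lab should close.

```python
import re

# Assumed minimal stopword list; extend as needed.
STOPWORDS = {"the", "a", "an", "and", "or", "to", "of", "for", "every", "before",
             "after", "is", "are", "be", "with", "on", "in", "then", "always"}


def content_words(text: str) -> set:
    """Lowercase alphanumeric tokens (keeping '/' and '-' inside words) minus stopwords."""
    tokens = re.findall(r"[a-z0-9]+(?:[/-][a-z0-9]+)*", text.lower())
    return {t for t in tokens if t not in STOPWORDS}


def detect_hallucinations(answer: str, context: str, support_cutoff: float = 0.5) -> float:
    """Fraction of answer sentences whose content words are mostly absent from the context (lower is better)."""
    claims = [s for s in re.split(r"(?<=[.!?])\s+", answer.strip()) if content_words(s)]
    if not claims:
        return 0.0
    context_words = content_words(context)
    unsupported = 0
    for claim in claims:
        words = content_words(claim)
        support = len(words & context_words) / len(words)
        if support < support_cutoff:
            unsupported += 1
    return unsupported / len(claims)


ctx = 'SOP-122 states: "For spindle bearings, lubricate every 400 operating hours. Always perform lockout/tagout before maintenance."'
ans = "Lubricate spindle bearings every 400 operating hours. Replace the motor with a 5 kW unit."
print(detect_hallucinations(ans, ctx))  # 0.5
```

The second sentence introduces a motor replacement that SOP-122 never mentions, so half the claims are flagged.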
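For lab item 3, the simulated incident ID can be replaced by a real call. The sketch POSTs the incident fields to the ServiceNow Table API (`POST /api/now/table/incident`), which returns the created record under a `result` key. It builds the request with the standard library only. The instance URL, the basic-auth credentials, and the `build_incident_request`/`create_incident` names are placeholders. A real deployment would load credentials from a secrets store.

```python
import base64
import json
import urllib.request


def build_incident_request(instance_url: str, user: str, password: str, fields: dict) -> urllib.request.Request:
    """Build (but do not send) a Table API request that creates an incident."""
    token = base64.b64encode(f"{user}:{password}".encode()).decode()
    return urllib.request.Request(
        url=f"{instance_url.rstrip('/')}/api/now/table/incident",
        data=json.dumps(fields).encode("utf-8"),
        method="POST",
        headers={
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"Basic {token}",
        },
    )


def create_incident(request: urllib.request.Request, timeout: float = 10.0) -> str:
    """Send the request and return the new incident's number from the `result` record."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.load(response)
    return body["result"]["number"]


# Build a request without sending it (placeholder instance and credentials)
req = build_incident_request(
    "https://your-instance.service-now.com",
    "api_user", "api_password",
    {"short_description": "LLM Answer Escalation: Score 0.40", "assignment_group": "Manufacturing_SME"},
)
print(req.get_method(), req.full_url)  # POST https://your-instance.service-now.com/api/now/table/incident
```

Separating request construction from sending makes the payload easy to unit-test without a live instance.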
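For lab item 4, the report's statistics can be computed from a list of logs. The sketch assumes each log has the `run_corrective_loop` shape (a `scores` dict and a `disposition` string) and was already loaded from the week's log file. Counting escalations per failing score category stands in for "by reason", which is an assumption about what the safety committee wants. `summarize_logs` is a hypothetical name.

```python
from collections import Counter
from statistics import mean
from typing import Dict, List


def summarize_logs(logs: List[dict], threshold: float = 0.8) -> dict:
    """Min/max/avg per score category, plus escalation counts by failing category."""
    by_category: Dict[str, List[float]] = {}
    for log in logs:
        for category, score in log["scores"].items():
            by_category.setdefault(category, []).append(score)

    stats = {
        category: {"min": min(vals), "max": max(vals), "avg": round(mean(vals), 3), "n": len(vals)}
        for category, vals in by_category.items()
    }

    escalations = [log for log in logs if log["disposition"] == "escalated_to_sme"]
    # One escalation can fail several categories, so these counts may sum to more than len(escalations)
    reasons = Counter(
        category
        for log in escalations
        for category, score in log["scores"].items()
        if score < threshold
    )
    return {
        "total_answers": len(logs),
        "escalations": len(escalations),
        "escalation_rate": round(len(escalations) / len(logs), 3) if logs else 0.0,
        "stats_by_category": stats,
        "escalations_by_reason": dict(reasons),
    }


sample_logs = [
    {"disposition": "auto-approved", "scores": {"accuracy_score": 0.9, "safety_compliance": 0.95, "citation_quality": 0.85}},
    {"disposition": "escalated_to_sme", "scores": {"accuracy_score": 0.6, "safety_compliance": 0.3, "citation_quality": 0.9}},
]
print(summarize_logs(sample_logs)["escalations_by_reason"])  # {'accuracy_score': 1, 'safety_compliance': 1}
```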

## ✅ Checklist
- [ ] Structured grader output parser implemented and returning valid JSON with numeric scores
- [ ] Hallucination detector identifies at least 3 types of factual discrepancies
- [ ] ServiceNow escalation function formats incidents with all required fields
- [ ] Weekly metrics report generates proper statistics (min/max/avg) by category
- [ ] Full corrective loop handles both automatic revisions and SME escalation paths
- [ ] All components integrated in the main loop with proper error handling

## 📚 References
- Corrective RAG Blog (LangChain)
- ServiceNow Change Management APIs
- Week 11 Monitoring Notebook