# 🧠 Self-RAG Implementation

## Self-Reflective Retrieval-Augmented Generation

This notebook demonstrates a **Self-RAG system** with:
- 🤔 **Automatic Retrieval Decisions** - Model decides when to retrieve
- 🔍 **Self-Reflection Tokens** - Special tokens for internal reasoning
- 📊 **Multi-Stage Evaluation** - Relevance, support, and utility assessment
- 🎯 **Adaptive Processing** - Dynamic response generation based on self-assessment
- 🧠 **Meta-Cognitive Awareness** - Understanding of its own knowledge limitations

### Key Innovation: Self-Reflection Tokens
- **[Retrieve]**: Decision to retrieve additional information
- **[ISREL]**: Relevance assessment of retrieved documents
- **[ISSUP]**: Support evaluation - does retrieval support the answer?
- **[ISUSE]**: Utility judgment - is the generated response useful?
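
If a model emits these tokens inline in its output (as in the original Self-RAG setup, where the generator is trained to produce them), they can be pulled back out with a regular expression. The sketch below assumes the exact bracket formats defined in `ReflectionToken` later in this notebook. `parse_reflection_tokens` is a hypothetical helper and is not part of this notebook's pipeline, which computes the tokens with heuristics instead.

```python
import re
from typing import List, Optional, Tuple

# Matches the bracketed formats used in this notebook, e.g.
# "[Retrieve]", "[No Retrieve]", "[ISREL: Yes]", "[ISSUP: Partially Supported]"
TOKEN_PATTERN = re.compile(r"\[(Retrieve|No Retrieve|ISREL|ISSUP|ISUSE)(?::\s*([^\]]+))?\]")

def parse_reflection_tokens(text: str) -> List[Tuple[str, Optional[str]]]:
    """Return (token_type, value) pairs in the order they appear; value is None for [Retrieve]/[No Retrieve]."""
    return [(m.group(1), m.group(2)) for m in TOKEN_PATTERN.finditer(text)]

trace = (
    "[Retrieve] CRISPR-Cas9 uses guide RNAs. [ISREL: Yes] "
    "It creates double-strand breaks. [ISSUP: Fully Supported] [ISUSE: High Utility]"
)
print(parse_reflection_tokens(trace))
```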

### Benefits of Self-RAG
- **Intelligent Retrieval**: Retrieves only when necessary
- **Quality Control**: Self-evaluates response quality
- **Transparency**: Reasoning process is visible
- **Adaptability**: Adjusts strategy based on confidence

In [None]:
# Install required packages
!pip install sentence-transformers faiss-cpu google-generativeai rank-bm25 transformers scikit-learn numpy python-dotenv torch

In [None]:
# Import libraries
import numpy as np
import json
import re
import os
import time
import uuid
import random
from typing import List, Dict, Tuple, Optional, Any, Union
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from enum import Enum
from datetime import datetime, timedelta

from sentence_transformers import SentenceTransformer
import faiss
import google.generativeai as genai
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv

load_dotenv()
print("📚 Libraries imported successfully!")

## 🎯 Self-Reflection Token System

Define the self-reflection tokens, the query classification enums (confidence, complexity, domain), and the data structures passed between pipeline stages:

In [None]:
# Self-Reflection Tokens
class ReflectionToken(Enum):
    RETRIEVE = "[Retrieve]"
    NO_RETRIEVE = "[No Retrieve]"
    ISREL_YES = "[ISREL: Yes]"
    ISREL_PARTIAL = "[ISREL: Partial]"
    ISREL_NO = "[ISREL: No]"
    ISSUP_FULL = "[ISSUP: Fully Supported]"
    ISSUP_PARTIAL = "[ISSUP: Partially Supported]"
    ISSUP_NO = "[ISSUP: Not Supported]"
    ISUSE_HIGH = "[ISUSE: High Utility]"
    ISUSE_MEDIUM = "[ISUSE: Medium Utility]"
    ISUSE_LOW = "[ISUSE: Low Utility]"

# Query complexity and confidence levels
class ConfidenceLevel(Enum):
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"

class QueryComplexity(Enum):
    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"

class Domain(Enum):
    TECHNOLOGY = "technology"
    SCIENCE = "science"
    MEDICINE = "medicine"
    BUSINESS = "business"
    GENERAL = "general"

# Core data structures
@dataclass
class SelfRAGQuery:
    id: str
    text: str
    domain: Domain
    complexity: QueryComplexity
    user_id: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.now)

@dataclass
class RetrievalDecision:
    should_retrieve: bool
    confidence: float
    reasoning: str
    token: ReflectionToken

@dataclass
class RelevanceAssessment:
    relevance_score: float
    token: ReflectionToken
    reasoning: str

@dataclass
class SupportEvaluation:
    support_level: float
    token: ReflectionToken
    evidence: List[str]
    reasoning: str

@dataclass
class UtilityJudgment:
    utility_score: float
    token: ReflectionToken
    reasoning: str
    improvement_suggestions: List[str]

@dataclass
class Document:
    id: str
    title: str
    content: str
    domain: Domain
    embedding: Optional[np.ndarray] = None
    keywords: List[str] = field(default_factory=list)

@dataclass
class SelfRAGResponse:
    query: SelfRAGQuery
    retrieval_decision: RetrievalDecision
    retrieved_documents: List[Document]
    relevance_assessment: Optional[RelevanceAssessment]
    generated_answer: str
    support_evaluation: SupportEvaluation
    utility_judgment: UtilityJudgment
    processing_time: float
    reflection_chain: List[str]

print("🎯 Self-Reflection token system defined!")
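
The critique tokens become useful for ranking candidate answers once each one is mapped to a number. The sketch below is a hypothetical scoring scheme, loosely in the spirit of Self-RAG's weighted segment scoring: the grades (1.0 / 0.5 / 0.0) and the per-type weights are assumptions chosen for illustration and are not used anywhere else in this notebook.

```python
from enum import Enum
from typing import Iterable

class ReflectionToken(Enum):
    # Critique tokens only, with the same string values as the notebook's enum
    ISREL_YES = "[ISREL: Yes]"
    ISREL_PARTIAL = "[ISREL: Partial]"
    ISREL_NO = "[ISREL: No]"
    ISSUP_FULL = "[ISSUP: Fully Supported]"
    ISSUP_PARTIAL = "[ISSUP: Partially Supported]"
    ISSUP_NO = "[ISSUP: Not Supported]"
    ISUSE_HIGH = "[ISUSE: High Utility]"
    ISUSE_MEDIUM = "[ISUSE: Medium Utility]"
    ISUSE_LOW = "[ISUSE: Low Utility]"

# Hypothetical grade per token: best = 1.0, middle = 0.5, worst = 0.0
TOKEN_GRADE = {
    ReflectionToken.ISREL_YES: 1.0, ReflectionToken.ISREL_PARTIAL: 0.5, ReflectionToken.ISREL_NO: 0.0,
    ReflectionToken.ISSUP_FULL: 1.0, ReflectionToken.ISSUP_PARTIAL: 0.5, ReflectionToken.ISSUP_NO: 0.0,
    ReflectionToken.ISUSE_HIGH: 1.0, ReflectionToken.ISUSE_MEDIUM: 0.5, ReflectionToken.ISUSE_LOW: 0.0,
}

# Hypothetical weight per critique type (support weighted highest)
TYPE_WEIGHT = {"ISREL": 1.0, "ISSUP": 2.0, "ISUSE": 0.5}

def critique_score(tokens: Iterable[ReflectionToken]) -> float:
    """Weighted sum of token grades; higher means a better-grounded candidate."""
    score = 0.0
    for token in tokens:
        token_type = token.value[1:].split(":")[0]  # "[ISREL: Yes]" -> "ISREL"
        score += TYPE_WEIGHT[token_type] * TOKEN_GRADE[token]
    return score

candidates = {
    "answer_a": [ReflectionToken.ISREL_YES, ReflectionToken.ISSUP_PARTIAL, ReflectionToken.ISUSE_HIGH],
    "answer_b": [ReflectionToken.ISREL_PARTIAL, ReflectionToken.ISSUP_FULL, ReflectionToken.ISUSE_MEDIUM],
}
best = max(candidates, key=lambda name: critique_score(candidates[name]))
print(best, {name: critique_score(toks) for name, toks in candidates.items()})
```

With these weights the fully supported but only partially relevant answer (`answer_b`) outranks the highly relevant but only partially supported one, which is the behavior a support-first weighting is meant to produce.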

## 📚 Enhanced Knowledge Base

Create a small multi-domain knowledge base (technology, medicine, science, business) for testing Self-RAG:

In [None]:
# Multi-domain knowledge base: technology, medicine, science, and business documents
knowledge_base = [
    {
        "id": "tech_001",
        "title": "Large Language Models and Transformers",
        "domain": "technology",
        "content": "Large Language Models (LLMs) like GPT, BERT, and T5 are based on the Transformer architecture introduced in 'Attention Is All You Need'. These models use self-attention mechanisms to process sequences and have revolutionized natural language processing. Key components include multi-head attention, position encodings, and feed-forward networks. Training involves massive datasets and requires significant computational resources.",
        "keywords": ["LLM", "transformer", "attention", "BERT", "GPT", "NLP"]
    },
    {
        "id": "med_001",
        "title": "COVID-19 Symptoms and Treatment",
        "domain": "medicine",
        "content": "COVID-19 symptoms include fever, cough, shortness of breath, fatigue, muscle aches, headache, loss of taste or smell, sore throat, and congestion. Severe cases may develop pneumonia, acute respiratory distress syndrome (ARDS), or multi-organ failure. Treatment approaches include supportive care, antiviral medications like Paxlovid, monoclonal antibodies, and in severe cases, corticosteroids and mechanical ventilation. Vaccination remains the primary prevention strategy.",
        "keywords": ["COVID-19", "symptoms", "treatment", "vaccine", "pneumonia", "antiviral"]
    },
    {
        "id": "sci_001",
        "title": "Quantum Entanglement and Bell's Theorem",
        "domain": "science",
        "content": "Quantum entanglement is a phenomenon where particles become correlated in such a way that the quantum state of each particle cannot be described independently. Bell's theorem demonstrates that no physical theory based on local hidden variables can reproduce all the predictions of quantum mechanics. Bell test experiments have consistently violated Bell inequalities, supporting quantum mechanics over local realism. This has implications for quantum computing, cryptography, and our understanding of reality.",
        "keywords": ["quantum", "entanglement", "Bell theorem", "locality", "hidden variables"]
    },
    {
        "id": "med_002",
        "title": "Diabetes Management and Blood Sugar Control",
        "domain": "medicine",
        "content": "Diabetes management focuses on maintaining blood glucose levels within target ranges. Type 1 diabetes requires insulin therapy, while Type 2 may be managed with lifestyle modifications, oral medications (metformin, sulfonylureas), or insulin. Continuous glucose monitoring (CGM) and insulin pumps have improved management. Complications include diabetic retinopathy, nephropathy, neuropathy, and cardiovascular disease. Regular HbA1c testing monitors long-term glucose control.",
        "keywords": ["diabetes", "insulin", "glucose", "HbA1c", "metformin", "CGM"]
    },
    {
        "id": "tech_002",
        "title": "Retrieval-Augmented Generation (RAG)",
        "domain": "technology",
        "content": "Retrieval-Augmented Generation (RAG) combines parametric knowledge from pre-trained language models with non-parametric knowledge from external sources. The system retrieves relevant documents from a knowledge base and uses them to augment the generation process. RAG variants include Self-RAG, which adds self-reflection capabilities, and GraphRAG, which uses graph structures. Benefits include factual accuracy, reduced hallucinations, and ability to incorporate up-to-date information.",
        "keywords": ["RAG", "retrieval", "generation", "knowledge base", "hallucination"]
    },
    {
        "id": "bus_001",
        "title": "Artificial Intelligence in Healthcare Business",
        "domain": "business",
        "content": "AI in healthcare represents a $45 billion market with applications in diagnostics, drug discovery, personalized medicine, and operational efficiency. Key technologies include machine learning for medical imaging, natural language processing for clinical documentation, and predictive analytics for patient outcomes. Challenges include regulatory approval, data privacy (HIPAA compliance), integration with existing systems, and physician adoption. Success factors include clinical validation, workflow integration, and demonstrable ROI.",
        "keywords": ["AI healthcare", "medical imaging", "drug discovery", "HIPAA", "clinical validation"]
    },
    {
        "id": "sci_002",
        "title": "CRISPR Gene Editing Technology",
        "domain": "science",
        "content": "CRISPR-Cas9 is a revolutionary gene editing tool that allows precise modification of DNA sequences. The system uses guide RNAs to direct the Cas9 nuclease to specific genomic locations where it creates double-strand breaks. This enables gene knockout, knock-in, or correction of mutations. Applications include treating genetic diseases, improving crops, and basic research. Ethical considerations include germline editing, off-target effects, and equitable access to therapies.",
        "keywords": ["CRISPR", "gene editing", "Cas9", "guide RNA", "genetic disease"]
    },
    {
        "id": "med_003",
        "title": "Immunotherapy and Cancer Treatment",
        "domain": "medicine",
        "content": "Immunotherapy harnesses the immune system to fight cancer. Key approaches include checkpoint inhibitors (PD-1, PD-L1, CTLA-4 antibodies), CAR-T cell therapy, and cancer vaccines. Checkpoint inhibitors remove brakes on T-cells, allowing them to attack tumors. CAR-T involves modifying patient T-cells to recognize cancer antigens. Success rates vary by cancer type, with remarkable responses in melanoma, lung cancer, and certain blood cancers. Side effects include immune-related adverse events.",
        "keywords": ["immunotherapy", "checkpoint inhibitors", "CAR-T", "PD-1", "cancer vaccine"]
    }
]

print(f"📚 Enhanced knowledge base created with {len(knowledge_base)} documents")
print(f"🌍 Domains: {set(doc['domain'] for doc in knowledge_base)}")
print("📄 Document types: Medical, Scientific, Technology, Business")
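
`SelfRAGRetriever.index_documents` converts each `doc['domain']` with `Domain(...)`, which raises `ValueError` for an unknown domain string, and nothing checks for duplicate ids. A small pre-indexing check can report such problems up front. The sketch below uses a hypothetical `validate_knowledge_base` helper and redefines `Domain` with the same values so it runs on its own.

```python
from enum import Enum
from typing import Dict, List

class Domain(Enum):
    # Same values as the notebook's Domain enum
    TECHNOLOGY = "technology"
    SCIENCE = "science"
    MEDICINE = "medicine"
    BUSINESS = "business"
    GENERAL = "general"

REQUIRED_FIELDS = ("id", "title", "domain", "content")

def validate_knowledge_base(docs: List[Dict]) -> List[str]:
    """Return a list of problems; an empty list means the entries can be indexed."""
    problems = []
    seen_ids = set()
    valid_domains = {d.value for d in Domain}
    for i, doc in enumerate(docs):
        missing = [name for name in REQUIRED_FIELDS if not doc.get(name)]
        if missing:
            problems.append(f"entry {i}: missing {missing}")
            continue
        if doc["id"] in seen_ids:
            problems.append(f"entry {i}: duplicate id {doc['id']!r}")
        seen_ids.add(doc["id"])
        if doc["domain"] not in valid_domains:
            problems.append(f"entry {i}: unknown domain {doc['domain']!r}")
        if not isinstance(doc.get("keywords", []), list):
            problems.append(f"entry {i}: keywords must be a list")
    return problems

sample = [
    {"id": "tech_001", "title": "LLMs", "domain": "technology", "content": "..."},
    {"id": "tech_001", "title": "Dup", "domain": "technology", "content": "..."},
    {"id": "x_001", "title": "Bad", "domain": "finance", "content": "..."},
]
for problem in validate_knowledge_base(sample):
    print(problem)
```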

## 🧠 Self-Reflection Decision Engine

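The core component: it decides whether to retrieve, then scores relevance ([ISREL]), support ([ISSUP]), and utility ([ISUSE]) with lightweight lexical heuristics: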

In [None]:
class SelfReflectionEngine:
    def __init__(self):
        self.confidence_threshold = 0.7
        self.complexity_weights = {
            QueryComplexity.SIMPLE: 0.3,
            QueryComplexity.MODERATE: 0.6,
            QueryComplexity.COMPLEX: 0.9
        }
        
        # Domain-specific retrieval patterns
        self.domain_retrieval_bias = {
            Domain.MEDICINE: 0.8,  # High retrieval bias for medical queries
            Domain.SCIENCE: 0.7,   # High for scientific queries
            Domain.TECHNOLOGY: 0.6,
            Domain.BUSINESS: 0.5,
            Domain.GENERAL: 0.4
        }
        
        print("🧠 Self-Reflection Engine initialized")
    
    def should_retrieve(self, query: SelfRAGQuery, initial_confidence: float) -> RetrievalDecision:
        """
        Decide whether to retrieve additional information based on:
        - Initial confidence in answering
        - Query complexity
        - Domain-specific requirements
        - Potential risk of incorrect information
        """
        
        # Calculate retrieval probability
        complexity_factor = self.complexity_weights[query.complexity]
        domain_factor = self.domain_retrieval_bias[query.domain]
        confidence_factor = 1.0 - initial_confidence
        
        # Risk assessment for medical/scientific queries
        risk_factor = 1.0
        if query.domain in [Domain.MEDICINE, Domain.SCIENCE]:
            risk_factor = 1.2  # Higher weight for high-stakes domains
        
        retrieval_score = (
            complexity_factor * 0.3 +
            domain_factor * 0.3 +
            confidence_factor * 0.4
        ) * risk_factor
        
        should_retrieve = retrieval_score > 0.5
        
        # Generate reasoning
        if should_retrieve:
            reasoning = self._generate_retrieval_reasoning(query, initial_confidence, retrieval_score)
            token = ReflectionToken.RETRIEVE
        else:
            reasoning = (f"Retrieval score {retrieval_score:.2f} <= 0.5: existing knowledge "
                         f"(confidence {initial_confidence:.2f}) judged sufficient for "
                         f"{query.complexity.value} {query.domain.value} query")
            token = ReflectionToken.NO_RETRIEVE
        
        return RetrievalDecision(
            should_retrieve=should_retrieve,
            confidence=retrieval_score,
            reasoning=reasoning,
            token=token
        )
    
    def _generate_retrieval_reasoning(self, query: SelfRAGQuery, confidence: float, score: float) -> str:
        reasons = []
        
        if confidence < 0.6:
            reasons.append(f"Low initial confidence ({confidence:.2f})")
        
        if query.complexity in [QueryComplexity.MODERATE, QueryComplexity.COMPLEX]:
            reasons.append(f"{query.complexity.value.capitalize()} query requiring detailed information")
        
        if query.domain in [Domain.MEDICINE, Domain.SCIENCE]:
            reasons.append(f"High-stakes {query.domain.value} domain requiring accuracy")
        
        if "latest" in query.text.lower() or "recent" in query.text.lower():
            reasons.append("Query requests current information")
        
        # Fall back to the raw score so the reasoning is never an empty string
        return "; ".join(reasons) or f"Retrieval score {score:.2f} exceeds threshold"
    
    def assess_relevance(self, query: SelfRAGQuery, documents: List[Document]) -> RelevanceAssessment:
        """
        Assess how relevant retrieved documents are to the query
        """
        if not documents:
            return RelevanceAssessment(
                relevance_score=0.0,
                token=ReflectionToken.ISREL_NO,
                reasoning="No documents retrieved"
            )
        
        # Simple relevance scoring: keyword overlap (punctuation stripped) plus domain matching
        relevance_scores = []
        query_words = set(re.findall(r"\w+", query.text.lower()))
        
        for doc in documents:
            doc_words = set(re.findall(r"\w+", doc.content.lower()))
            keyword_overlap = len(query_words.intersection(doc_words)) / max(len(query_words), 1)
            
            domain_match = 1.0 if doc.domain == query.domain else 0.5
            doc_score = keyword_overlap * 0.7 + domain_match * 0.3
            relevance_scores.append(doc_score)
        
        avg_relevance = np.mean(relevance_scores)
        
        if avg_relevance >= 0.7:
            token = ReflectionToken.ISREL_YES
            reasoning = f"High relevance ({avg_relevance:.2f}) - documents closely match query"
        elif avg_relevance >= 0.4:
            token = ReflectionToken.ISREL_PARTIAL
            reasoning = f"Partial relevance ({avg_relevance:.2f}) - some useful information found"
        else:
            token = ReflectionToken.ISREL_NO
            reasoning = f"Low relevance ({avg_relevance:.2f}) - documents don't match query well"
        
        return RelevanceAssessment(
            relevance_score=avg_relevance,
            token=token,
            reasoning=reasoning
        )
    
    def evaluate_support(self, answer: str, documents: List[Document]) -> SupportEvaluation:
        """
        Evaluate how well the retrieved documents support the generated answer
        """
        if not documents:
            return SupportEvaluation(
                support_level=0.0,
                token=ReflectionToken.ISSUP_NO,
                evidence=[],
                reasoning="No retrieved documents to provide support"
            )
        
        # Extract key facts from answer
        answer_sentences = [s.strip() for s in answer.split('.') if s.strip()]
        evidence = []
        support_scores = []
        
        for sentence in answer_sentences[:3]:  # Check first 3 sentences
            sentence_words = set(re.findall(r"\w+", sentence.lower()))
            max_support = 0.0
            supporting_doc = None
            
            for doc in documents:
                doc_words = set(re.findall(r"\w+", doc.content.lower()))
                overlap = len(sentence_words.intersection(doc_words)) / max(len(sentence_words), 1)
                if overlap > max_support:
                    max_support = overlap
                    supporting_doc = doc
            
            if max_support > 0.3 and supporting_doc:
                evidence.append(f"'{sentence}' supported by {supporting_doc.title}")
                support_scores.append(max_support)
        
        avg_support = np.mean(support_scores) if support_scores else 0.0
        
        if avg_support >= 0.7:
            token = ReflectionToken.ISSUP_FULL
            reasoning = f"Answer is fully supported by retrieved documents ({avg_support:.2f})"
        elif avg_support >= 0.4:
            token = ReflectionToken.ISSUP_PARTIAL
            reasoning = f"Answer is partially supported by retrieved documents ({avg_support:.2f})"
        else:
            token = ReflectionToken.ISSUP_NO
            reasoning = f"Answer lacks strong support from retrieved documents ({avg_support:.2f})"
        
        return SupportEvaluation(
            support_level=avg_support,
            token=token,
            evidence=evidence,
            reasoning=reasoning
        )
    
    def judge_utility(self, query: SelfRAGQuery, answer: str, support_eval: SupportEvaluation) -> UtilityJudgment:
        """
        Judge the overall utility and quality of the generated response
        """
        utility_factors = []
        
        # Answer length and completeness
        length_score = min(len(answer.split()) / 50, 1.0)  # Normalize to 50 words
        utility_factors.append(length_score * 0.2)
        
        # Support from retrieved documents
        utility_factors.append(support_eval.support_level * 0.4)
        
        # Domain appropriateness
        domain_keywords = {
            Domain.MEDICINE: ['treatment', 'symptoms', 'diagnosis', 'therapy', 'medical'],
            Domain.SCIENCE: ['research', 'study', 'experiment', 'theory', 'scientific'],
            Domain.TECHNOLOGY: ['system', 'algorithm', 'software', 'technology', 'technical'],
            Domain.BUSINESS: ['market', 'business', 'revenue', 'strategy', 'commercial']
        }
        
        if query.domain in domain_keywords:
            domain_match = sum(1 for word in domain_keywords[query.domain] 
                             if word in answer.lower()) / len(domain_keywords[query.domain])
            utility_factors.append(domain_match * 0.2)
        else:
            utility_factors.append(0.1)  # Default for general domain
        
        # Specificity and detail
        specific_indicators = ['specifically', 'exactly', 'precisely', 'including', 'such as']
        specificity_score = sum(1 for indicator in specific_indicators 
                              if indicator in answer.lower()) / len(specific_indicators)
        utility_factors.append(specificity_score * 0.2)
        
        utility_score = sum(utility_factors)
        
        # Generate improvement suggestions
        suggestions = []
        if length_score < 0.5:
            suggestions.append("Provide more detailed explanation")
        if support_eval.support_level < 0.5:
            suggestions.append("Better utilize retrieved information")
        if specificity_score < 0.3:
            suggestions.append("Include more specific examples and details")
        
        # Determine utility level
        if utility_score >= 0.7:
            token = ReflectionToken.ISUSE_HIGH
            reasoning = f"High utility response ({utility_score:.2f}) - comprehensive and well-supported"
        elif utility_score >= 0.4:
            token = ReflectionToken.ISUSE_MEDIUM
            reasoning = f"Medium utility response ({utility_score:.2f}) - adequate but could be improved"
        else:
            token = ReflectionToken.ISUSE_LOW
            reasoning = f"Low utility response ({utility_score:.2f}) - needs significant improvement"
        
        return UtilityJudgment(
            utility_score=utility_score,
            token=token,
            reasoning=reasoning,
            improvement_suggestions=suggestions
        )

# Initialize self-reflection engine
reflection_engine = SelfReflectionEngine()
print("✅ Self-Reflection Engine ready!")
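
To see how `should_retrieve` behaves, it helps to evaluate its formula by hand. The sketch re-implements the weighted sum with plain dictionaries (weights copied from `SelfReflectionEngine`) and feeds it the initial confidence that `SelfAwareGenerator.estimate_confidence` would produce for two query types: 0.8 * 0.7 for a simple general query and 0.4 * 0.3 for a complex medical one, assuming neither contains an uncertainty word.

```python
# Weights copied from SelfReflectionEngine.__init__ and should_retrieve
COMPLEXITY_WEIGHT = {"simple": 0.3, "moderate": 0.6, "complex": 0.9}
DOMAIN_BIAS = {"medicine": 0.8, "science": 0.7, "technology": 0.6,
               "business": 0.5, "general": 0.4}
HIGH_STAKES = {"medicine", "science"}
THRESHOLD = 0.5  # should_retrieve retrieves when the score is strictly above this

def retrieval_score(complexity: str, domain: str, initial_confidence: float) -> float:
    """Same weighted sum as SelfReflectionEngine.should_retrieve."""
    risk_factor = 1.2 if domain in HIGH_STAKES else 1.0
    return (COMPLEXITY_WEIGHT[complexity] * 0.3
            + DOMAIN_BIAS[domain] * 0.3
            + (1.0 - initial_confidence) * 0.4) * risk_factor

# Initial confidences as estimate_confidence would compute them (no uncertainty words)
cases = [("simple", "general", 0.8 * 0.7), ("complex", "medicine", 0.4 * 0.3)]
for complexity, domain, confidence in cases:
    score = retrieval_score(complexity, domain, confidence)
    token = "[Retrieve]" if score > THRESHOLD else "[No Retrieve]"
    print(f"{complexity} {domain}: confidence={confidence:.2f} score={score:.3f} {token}")
```

The simple general query scores about 0.39 and skips retrieval, while the complex medical query scores about 1.03. Because `risk_factor` multiplies the whole sum, scores in high-stakes domains can exceed 1.0, so `RetrievalDecision.confidence` is best read as an unbounded retrieval score rather than a probability.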

## 🔍 Enhanced Retrieval System

Hybrid retriever that combines dense (FAISS) and sparse (BM25) scores and boosts documents from the query's domain:

In [None]:
class SelfRAGRetriever:
    def __init__(self):
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.documents = []
        self.semantic_index = None
        self.bm25_index = None
        
        print("🔍 Self-RAG Retriever initialized")
    
    def index_documents(self, documents: List[Dict]):
        print(f"📚 Indexing {len(documents)} documents for Self-RAG...")
        
        # Convert to Document objects
        self.documents = [
            Document(
                id=doc['id'],
                title=doc['title'],
                content=doc['content'],
                domain=Domain(doc['domain']),
                keywords=doc.get('keywords', [])
            )
            for doc in documents
        ]
        
        # Build semantic index
        self._build_semantic_index()
        
        # Build keyword index
        self._build_keyword_index()
        
        print("✅ Documents indexed successfully!")
    
    def _build_semantic_index(self):
        doc_texts = [f"{doc.title} {doc.content}" for doc in self.documents]
        # FAISS needs a C-contiguous float32 matrix; normalize it in place so that
        # inner product (IndexFlatIP) equals cosine similarity
        embeddings = np.ascontiguousarray(self.embedding_model.encode(doc_texts), dtype='float32')
        faiss.normalize_L2(embeddings)
        
        # Store the normalized embedding on each document
        for doc, embedding in zip(self.documents, embeddings):
            doc.embedding = embedding
        
        # Create FAISS index
        dimension = embeddings.shape[1]
        self.semantic_index = faiss.IndexFlatIP(dimension)
        self.semantic_index.add(embeddings)
    
    def _build_keyword_index(self):
        doc_texts = [f"{doc.title} {doc.content}" for doc in self.documents]
        tokenized_docs = [text.lower().split() for text in doc_texts]
        self.bm25_index = BM25Okapi(tokenized_docs)
    
    def retrieve(self, query: SelfRAGQuery, top_k: int = 5) -> List[Document]:
        """
        Retrieve relevant documents using hybrid approach
        """
        # Semantic retrieval
        query_embedding = np.ascontiguousarray(
            self.embedding_model.encode([query.text]), dtype='float32'
        )
        faiss.normalize_L2(query_embedding)
        
        semantic_scores, semantic_indices = self.semantic_index.search(
            query_embedding, min(top_k * 2, len(self.documents))
        )
        
        # Keyword retrieval
        query_tokens = query.text.lower().split()
        keyword_scores = self.bm25_index.get_scores(query_tokens)
        keyword_indices = np.argsort(keyword_scores)[::-1][:top_k * 2]
        
        # Combine and rank
        combined_scores = {}
        
        # Add semantic scores
        for score, idx in zip(semantic_scores[0], semantic_indices[0]):
            if 0 <= idx < len(self.documents):  # FAISS pads missing results with -1
                combined_scores[idx] = {'semantic': float(score), 'keyword': 0.0}
        
        # Add keyword scores (normalized)
        max_keyword_score = max(keyword_scores) if max(keyword_scores) > 0 else 1.0
        for idx in keyword_indices:
            if idx < len(self.documents):  # Valid index check
                normalized_score = keyword_scores[idx] / max_keyword_score
                if idx in combined_scores:
                    combined_scores[idx]['keyword'] = normalized_score
                else:
                    combined_scores[idx] = {'semantic': 0.0, 'keyword': normalized_score}
        
        # Calculate final scores with domain boost
        final_scores = []
        for idx, scores in combined_scores.items():
            doc = self.documents[idx]
            
            # Domain matching boost
            domain_boost = 1.2 if doc.domain == query.domain else 1.0
            
            final_score = (scores['semantic'] * 0.7 + scores['keyword'] * 0.3) * domain_boost
            final_scores.append((final_score, idx))
        
        # Sort and return top documents
        final_scores.sort(reverse=True)
        retrieved_docs = [self.documents[idx] for _, idx in final_scores[:top_k]]
        
        return retrieved_docs

# Initialize retriever
retriever = SelfRAGRetriever()
retriever.index_documents(knowledge_base)
print("✅ Self-RAG Retriever ready!")
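
The fusion step in `retrieve` is easiest to check on toy numbers. The sketch below reproduces its arithmetic (0.7 * cosine similarity + 0.3 * BM25 score divided by the maximum BM25 score, then * 1.2 when the document's domain matches the query's), using hand-picked scores instead of real FAISS and BM25 output. `fuse_scores` is a hypothetical standalone helper, not a method of `SelfRAGRetriever`.

```python
from typing import Dict, List, Tuple

def fuse_scores(semantic: Dict[int, float], bm25: Dict[int, float],
                doc_domains: Dict[int, str], query_domain: str,
                top_k: int = 3) -> List[Tuple[float, int]]:
    """Hypothetical helper mirroring the fusion arithmetic in SelfRAGRetriever.retrieve."""
    # Max-normalize BM25 so the best keyword match gets 1.0 (guard against all-zero scores)
    peak = max(bm25.values(), default=0.0)
    max_bm25 = peak if peak > 0 else 1.0
    fused = []
    for idx in set(semantic) | set(bm25):
        # A document absent from one candidate list contributes 0 for that signal
        score = 0.7 * semantic.get(idx, 0.0) + 0.3 * bm25.get(idx, 0.0) / max_bm25
        if doc_domains[idx] == query_domain:
            score *= 1.2  # domain-match boost
        fused.append((score, idx))
    fused.sort(reverse=True)
    return fused[:top_k]

# Toy inputs: cosine similarities, raw BM25 scores, and document domains
semantic = {0: 0.62, 1: 0.55, 2: 0.10}
bm25 = {1: 4.0, 2: 2.0, 3: 1.0}
domains = {0: "technology", 1: "medicine", 2: "medicine", 3: "science"}

ranked = fuse_scores(semantic, bm25, domains, query_domain="medicine")
for score, idx in ranked:
    print(f"doc {idx}: {score:.3f}")
```

Document 1 ranks first because it scores on both signals and receives the domain boost; document 0 has the highest cosine similarity but no keyword hits and no boost.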

## 🤖 Self-Aware Generator

Generation module that estimates its own confidence before retrieval and falls back to template-based answers when no Gemini API key is available:

In [None]:
class SelfAwareGenerator:
    def __init__(self):
        # Try to initialize Gemini
        api_key = os.getenv('GEMINI_API_KEY')
        if api_key:
            try:
                genai.configure(api_key=api_key)
                self.model = genai.GenerativeModel('gemini-1.5-flash')
                self.has_llm = True
                print("🤖 Gemini API configured for Self-RAG")
            except Exception as e:
                print(f"⚠️ Gemini error: {e}. Using fallback generation.")
                self.has_llm = False
        else:
            print("⚠️ No Gemini API key. Using template-based generation.")
            self.has_llm = False
    
    def estimate_confidence(self, query: SelfRAGQuery) -> float:
        """
        Estimate initial confidence for answering query without retrieval
        """
        # Base confidence on complexity and domain
        complexity_confidence = {
            QueryComplexity.SIMPLE: 0.8,
            QueryComplexity.MODERATE: 0.6,
            QueryComplexity.COMPLEX: 0.4
        }
        
        domain_confidence = {
            Domain.GENERAL: 0.7,
            Domain.TECHNOLOGY: 0.6,
            Domain.BUSINESS: 0.6,
            Domain.SCIENCE: 0.4,
            Domain.MEDICINE: 0.3  # Lower confidence for medical queries
        }
        
        base_confidence = complexity_confidence[query.complexity] * domain_confidence[query.domain]
        
        # Adjust for specific query patterns
        uncertainty_indicators = ['latest', 'recent', 'current', 'new', 'specific', 'exact']
        if any(indicator in query.text.lower() for indicator in uncertainty_indicators):
            base_confidence *= 0.7
        
        return min(base_confidence, 0.95)  # Cap at 95%
    
    def generate_answer(self, query: SelfRAGQuery, documents: Optional[List[Document]] = None,
                        retrieval_decision: Optional[RetrievalDecision] = None) -> str:
        """
        Generate answer with self-awareness of limitations
        """
        if self.has_llm:
            return self._generate_with_llm(query, documents, retrieval_decision)
        else:
            return self._generate_with_template(query, documents, retrieval_decision)
    
    def _generate_with_llm(self, query: SelfRAGQuery, documents: List[Document], 
                          retrieval_decision: RetrievalDecision) -> str:
        """
        Generate answer using Gemini with self-reflection context
        """
        # Prepare context
        context_parts = []
        if documents:
            for doc in documents:
                context_parts.append(f"**{doc.title}**\n{doc.content}")
        
        context = "\n\n".join(context_parts) if context_parts else "No additional context provided."
        
        # Create self-aware prompt
        prompt = f"""You are an AI assistant with self-reflection capabilities. 

**Query Analysis:**
- Question: {query.text}
- Domain: {query.domain.value}
- Complexity: {query.complexity.value}
- Retrieval Decision: {retrieval_decision.token.value if retrieval_decision else 'Not specified'}
- Reasoning: {retrieval_decision.reasoning if retrieval_decision else 'Not provided'}

**Retrieved Context:**
{context}

**Instructions:**
1. Provide a comprehensive answer based on the available information
2. If using retrieved context, clearly integrate the information
3. Be explicit about any limitations or uncertainties
4. For medical/scientific queries, emphasize the need for professional consultation
5. Maintain appropriate confidence level based on available evidence

**Answer:**"""
        
        try:
            response = self.model.generate_content(prompt)
            return response.text
        except Exception as e:
            return f"Error generating response: {str(e)}"
    
    def _generate_with_template(self, query: SelfRAGQuery, documents: List[Document], 
                               retrieval_decision: RetrievalDecision) -> str:
        """
        Generate answer using template-based approach
        """
        if not documents:
            answer = f"Based on my general knowledge about {query.domain.value}, "
            
            if query.domain == Domain.MEDICINE:
                answer += "I can provide general information, but please consult healthcare professionals for medical advice. "
            elif query.domain == Domain.SCIENCE:
                answer += "I can share scientific concepts, but specific research may require current literature. "
            
            answer += f"Regarding '{query.text}': This is a {query.complexity.value} question that "
            
            if retrieval_decision and not retrieval_decision.should_retrieve:
                answer += "I feel confident addressing with existing knowledge."
            else:
                answer += "would benefit from additional specialized resources."
            
            return answer
        
        # Use retrieved documents
        answer = f"Based on the retrieved information about {query.domain.value}, here's what I found regarding '{query.text}':\n\n"
        
        # Incorporate information from top documents
        for doc in documents[:2]:
            answer += f"**{doc.title}**: {doc.content[:300]}...\n\n"
        
        # Add appropriate disclaimers
        if query.domain == Domain.MEDICINE:
            answer += "\n**Important**: This information is for educational purposes only. Please consult qualified healthcare professionals for medical advice."
        elif query.domain == Domain.SCIENCE:
            answer += "\n**Note**: Scientific understanding evolves. For the latest research, consult current peer-reviewed literature."
        
        return answer

# Initialize generator
generator = SelfAwareGenerator()
print("✅ Self-Aware Generator ready!")
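
`estimate_confidence` multiplies a complexity prior by a domain prior and applies a 0.7 penalty when the query contains an uncertainty word. The check is a substring test, so "new" also fires inside words such as "renewable". The sketch below reproduces the calculation with the notebook's tables and adds a hypothetical `whole_words` option that matches the indicators against whole tokens only.

```python
import re

# Priors copied from SelfAwareGenerator.estimate_confidence
COMPLEXITY_CONFIDENCE = {"simple": 0.8, "moderate": 0.6, "complex": 0.4}
DOMAIN_CONFIDENCE = {"general": 0.7, "technology": 0.6, "business": 0.6,
                     "science": 0.4, "medicine": 0.3}
UNCERTAINTY_INDICATORS = ("latest", "recent", "current", "new", "specific", "exact")

def estimate_confidence(text: str, complexity: str, domain: str,
                        whole_words: bool = False) -> float:
    """Prior confidence before retrieval; `whole_words` is a hypothetical variant."""
    confidence = COMPLEXITY_CONFIDENCE[complexity] * DOMAIN_CONFIDENCE[domain]
    text_lower = text.lower()
    if whole_words:
        # Compare against whole tokens, so "renewable" no longer matches "new"
        tokens = set(re.findall(r"\w+", text_lower))
        uncertain = any(ind in tokens for ind in UNCERTAINTY_INDICATORS)
    else:
        # Substring test, as in the notebook
        uncertain = any(ind in text_lower for ind in UNCERTAINTY_INDICATORS)
    if uncertain:
        confidence *= 0.7
    return min(confidence, 0.95)

query = "What is renewable energy?"
print(estimate_confidence(query, "simple", "general"))
print(estimate_confidence(query, "simple", "general", whole_words=True))
```

The largest possible prior is 0.8 * 0.7 = 0.56 (simple, general), so with these tables the 0.95 cap never takes effect.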

## 📝 Query Processing and Classification

Rule-based classification of each query's complexity and domain for Self-RAG:

In [None]:
class SelfRAGQueryProcessor:
    def __init__(self):
        # Complexity indicators
        self.complexity_patterns = {
            QueryComplexity.SIMPLE: [
                r"what is", r"define", r"explain", r"tell me about"
            ],
            QueryComplexity.MODERATE: [
                r"how does", r"why", r"compare", r"difference between", 
                r"advantages", r"disadvantages"
            ],
            QueryComplexity.COMPLEX: [
                r"analyze", r"evaluate", r"assess", r"relationship between",
                r"impact of", r"implications", r"comprehensive"
            ]
        }
        
        # Domain keywords
        self.domain_keywords = {
            Domain.MEDICINE: [
                "disease", "treatment", "symptoms", "diagnosis", "therapy", 
                "medication", "clinical", "patient", "medical", "health",
                "cancer", "diabetes", "COVID", "vaccine", "drug"
            ],
            Domain.SCIENCE: [
                "research", "study", "experiment", "theory", "hypothesis",
                "quantum", "physics", "chemistry", "biology", "genetics",
                "CRISPR", "evolution", "molecular", "scientific"
            ],
            Domain.TECHNOLOGY: [
                "AI", "algorithm", "software", "computing", "system",
                "machine learning", "neural network", "programming",
                "database", "cloud", "blockchain", "cybersecurity"
            ],
            Domain.BUSINESS: [
                "market", "revenue", "profit", "strategy", "business",
                "investment", "economy", "finance", "management",
                "startup", "company", "industry", "commercial"
            ]
        }
        
        print("📝 Self-RAG Query Processor initialized")
    
    def process_query(self, query_text: str, user_id: Optional[str] = None) -> SelfRAGQuery:
        """
        Process and classify query for Self-RAG
        """
        query_id = str(uuid.uuid4())[:8]
        
        # Detect complexity
        complexity = self._detect_complexity(query_text)
        
        # Detect domain
        domain = self._detect_domain(query_text)
        
        return SelfRAGQuery(
            id=query_id,
            text=query_text,
            domain=domain,
            complexity=complexity,
            user_id=user_id
        )
    
    def _detect_complexity(self, text: str) -> QueryComplexity:
        text_lower = text.lower()
        
        # Check for complex patterns first
        for complexity, patterns in reversed(list(self.complexity_patterns.items())):
            for pattern in patterns:
                if re.search(pattern, text_lower):
                    return complexity
        
        # Fallback based on length and structure
        word_count = len(text.split())
        if word_count > 15:
            return QueryComplexity.COMPLEX
        elif word_count > 8:
            return QueryComplexity.MODERATE
        else:
            return QueryComplexity.SIMPLE
    
    def _detect_domain(self, text: str) -> Domain:
        text_lower = text.lower()
        domain_scores = {}
        
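        for domain, keywords in self.domain_keywords.items():
            # Anchor each keyword at a word start so short keywords such as "AI"
            # do not match inside words like "explain" or "certain"
            score = sum(
                1 for keyword in keywords
                if re.search(rf"\b{re.escape(keyword.lower())}", text_lower)
            )
            domain_scores[domain] = score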
        
        if max(domain_scores.values()) > 0:
            return max(domain_scores, key=domain_scores.get)
        return Domain.GENERAL

# Initialize query processor
query_processor = SelfRAGQueryProcessor()

# Test the processor
test_query = query_processor.process_query("What are the latest treatments for diabetes?")
print(f"\n🔍 Test query processed:")
print(f"   Text: {test_query.text}")
print(f"   Domain: {test_query.domain.value}")
print(f"   Complexity: {test_query.complexity.value}")
print("✅ Query Processor ready!")

## 🧩 Complete Self-RAG System

Integration of all components into the complete Self-RAG system:
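
Stripped of logging and metrics, `SelfRAGSystem.process_query` is a fixed sequence in which retrieval is conditional. The sketch below reproduces that control flow with injected callables. `self_rag_pipeline`, `PipelineResult`, and the fixed `retrieve_threshold` cut-off are hypothetical stand-ins (in the notebook the decision comes from `reflection_engine.should_retrieve`), and the relevance, support, and utility steps are omitted for brevity.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class PipelineResult:
    """Hypothetical, simplified stand-in for SelfRAGResponse."""
    answer: str
    retrieved: List[str]
    reflection_chain: List[str] = field(default_factory=list)


def self_rag_pipeline(
    query: str,
    estimate_confidence: Callable[[str], float],
    retrieve: Callable[[str], List[str]],
    generate: Callable[[str, List[str]], str],
    retrieve_threshold: float = 0.7,  # assumed cut-off, not the notebook's rule
) -> PipelineResult:
    chain: List[str] = []

    # Self-confidence is estimated before the retriever is touched
    confidence = estimate_confidence(query)
    chain.append(f"Initial confidence: {confidence:.2f}")

    # Retrieval happens only when confidence falls below the cut-off
    docs: List[str] = []
    if confidence < retrieve_threshold:
        docs = retrieve(query)
        chain.append(f"[Retrieve]: retrieved {len(docs)} documents")
    else:
        chain.append("[No Retrieve]: answering without retrieval")

    # Generation receives whatever evidence was (or was not) retrieved
    answer = generate(query, docs)
    chain.append(f"Generated answer ({len(answer.split())} words)")
    return PipelineResult(answer, docs, chain)


# Usage with toy callables: "latest" lowers confidence and forces retrieval
result = self_rag_pipeline(
    "What are the latest COVID-19 treatments?",
    estimate_confidence=lambda q: 0.4 if "latest" in q.lower() else 0.9,
    retrieve=lambda q: ["doc-a", "doc-b"],
    generate=lambda q, docs: f"Answer grounded in {len(docs)} documents",
)
for step in result.reflection_chain:
    print(step)
```

Injecting the components as callables keeps the control flow testable without loading embedding models or calling an LLM API.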

In [None]:
class SelfRAGSystem:
    def __init__(self):
        print("🧩 Initializing Self-RAG System...")
        
        # Initialize components
        self.query_processor = query_processor
        self.reflection_engine = reflection_engine
        self.retriever = retriever
        self.generator = generator
        
        # System metrics
        self.total_queries = 0
        self.retrieval_decisions = {'retrieve': 0, 'no_retrieve': 0}
        self.avg_processing_time = 0.0
        self.quality_scores = []
        
        print("✅ Self-RAG System initialized!")
        print("🎯 Ready for self-reflective question answering")
    
    def process_query(self, query_text: str, user_id: Optional[str] = None) -> SelfRAGResponse:
        """
        Process query through complete Self-RAG pipeline
        """
        start_time = time.time()
        reflection_chain = []
        
        try:
            print(f"\n🧠 Self-RAG Processing: '{query_text}'")
            print("=" * 60)
            
            # Step 1: Query Processing
            print("📝 Step 1: Query analysis and classification...")
            query = self.query_processor.process_query(query_text, user_id)
            reflection_chain.append(f"Query classified as {query.complexity.value} {query.domain.value} question")
            
            print(f"   Domain: {query.domain.value}")
            print(f"   Complexity: {query.complexity.value}")
            
            # Step 2: Initial Confidence Assessment
            print("🎯 Step 2: Self-confidence assessment...")
            initial_confidence = self.generator.estimate_confidence(query)
            reflection_chain.append(f"Initial confidence: {initial_confidence:.2f}")
            
            print(f"   Initial confidence: {initial_confidence:.2f}")
            
            # Step 3: Retrieval Decision (Self-Reflection)
            print("🤔 Step 3: Retrieval decision with self-reflection...")
            retrieval_decision = self.reflection_engine.should_retrieve(query, initial_confidence)
            reflection_chain.append(f"{retrieval_decision.token.value}: {retrieval_decision.reasoning}")
            
            print(f"   Decision: {retrieval_decision.token.value}")
            print(f"   Reasoning: {retrieval_decision.reasoning}")
            
            # Update decision metrics
            if retrieval_decision.should_retrieve:
                self.retrieval_decisions['retrieve'] += 1
            else:
                self.retrieval_decisions['no_retrieve'] += 1
            
            # Step 4: Conditional Retrieval
            retrieved_documents = []
            relevance_assessment = None
            
            if retrieval_decision.should_retrieve:
                print("🔍 Step 4: Document retrieval...")
                retrieved_documents = self.retriever.retrieve(query, top_k=5)
                reflection_chain.append(f"Retrieved {len(retrieved_documents)} documents")
                
                print(f"   Retrieved {len(retrieved_documents)} documents")
                for i, doc in enumerate(retrieved_documents[:3], 1):
                    print(f"   {i}. {doc.title} ({doc.domain.value})")
                
                # Step 5: Relevance Assessment
                print("📊 Step 5: Relevance assessment...")
                relevance_assessment = self.reflection_engine.assess_relevance(query, retrieved_documents)
                reflection_chain.append(f"{relevance_assessment.token.value}: {relevance_assessment.reasoning}")
                
                print(f"   Assessment: {relevance_assessment.token.value}")
                print(f"   Score: {relevance_assessment.relevance_score:.2f}")
            else:
                print("🚫 Step 4: Skipping retrieval based on self-assessment")
                reflection_chain.append("Retrieval skipped - high confidence in existing knowledge")
            
            # Step 6: Answer Generation
            print("🤖 Step 6: Self-aware answer generation...")
            generated_answer = self.generator.generate_answer(query, retrieved_documents, retrieval_decision)
            reflection_chain.append(f"Generated answer ({len(generated_answer.split())} words)")
            
            print(f"   Generated answer ({len(generated_answer.split())} words)")
            
            # Step 7: Support Evaluation
            print("🔬 Step 7: Support evaluation...")
            support_evaluation = self.reflection_engine.evaluate_support(generated_answer, retrieved_documents)
            reflection_chain.append(f"{support_evaluation.token.value}: {support_evaluation.reasoning}")
            
            print(f"   Support: {support_evaluation.token.value}")
            print(f"   Evidence count: {len(support_evaluation.evidence)}")
            
            # Step 8: Utility Judgment
            print("⚖️ Step 8: Utility judgment...")
            utility_judgment = self.reflection_engine.judge_utility(query, generated_answer, support_evaluation)
            reflection_chain.append(f"{utility_judgment.token.value}: {utility_judgment.reasoning}")
            
            print(f"   Utility: {utility_judgment.token.value}")
            print(f"   Score: {utility_judgment.utility_score:.2f}")
            
            # Create response
            processing_time = time.time() - start_time
            
            response = SelfRAGResponse(
                query=query,
                retrieval_decision=retrieval_decision,
                retrieved_documents=retrieved_documents,
                relevance_assessment=relevance_assessment,
                generated_answer=generated_answer,
                support_evaluation=support_evaluation,
                utility_judgment=utility_judgment,
                processing_time=processing_time,
                reflection_chain=reflection_chain
            )
            
            # Update system metrics
            self._update_metrics(response)
            
            print(f"\n✅ Self-RAG processing completed in {processing_time:.2f}s")
            return response
            
        except Exception as e:
            print(f"❌ Error in Self-RAG processing: {str(e)}")
            # Return error response
            error_query = SelfRAGQuery("error", query_text, Domain.GENERAL, QueryComplexity.SIMPLE, user_id)
            error_decision = RetrievalDecision(False, 0.0, "Error occurred", ReflectionToken.NO_RETRIEVE)
            error_support = SupportEvaluation(0.0, ReflectionToken.ISSUP_NO, [], "Error occurred")
            error_utility = UtilityJudgment(0.0, ReflectionToken.ISUSE_LOW, "Error occurred", [])
            
            return SelfRAGResponse(
                query=error_query,
                retrieval_decision=error_decision,
                retrieved_documents=[],
                relevance_assessment=None,
                generated_answer=f"I encountered an error while processing your query: {str(e)}",
                support_evaluation=error_support,
                utility_judgment=error_utility,
                processing_time=time.time() - start_time,
                reflection_chain=[f"Error: {str(e)}"]
            )
    
    def _update_metrics(self, response: SelfRAGResponse):
        self.total_queries += 1
        
        # Update running average processing time
        self.avg_processing_time = ((self.avg_processing_time * (self.total_queries - 1)) + 
                                   response.processing_time) / self.total_queries
        
        # Track quality scores
        self.quality_scores.append(response.utility_judgment.utility_score)
    
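    def get_system_statistics(self) -> Dict:
        # Use the number of recorded decisions as the denominator: a query that fails
        # after its retrieval decision is counted there but not in total_queries
        total_decisions = sum(self.retrieval_decisions.values())
        return {
            'total_queries': self.total_queries,
            'retrieval_rate': self.retrieval_decisions['retrieve'] / max(total_decisions, 1),
            'avg_processing_time': self.avg_processing_time,
            'avg_quality_score': float(np.mean(self.quality_scores)) if self.quality_scores else 0.0,
            'retrieval_decisions': dict(self.retrieval_decisions)
        }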

# Initialize complete Self-RAG system
self_rag = SelfRAGSystem()
print("\n🚀 Complete Self-RAG System ready!")

## 🧪 Comprehensive Self-RAG Testing

Test the Self-RAG system with various query types:
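
Each test case below carries an `expected_retrieval` label. One way to summarize how well the `[Retrieve]` decisions agree with those labels is a small confusion-matrix report. `retrieval_decision_report` is a hypothetical helper, and the `actual` values in the usage lines are made-up inputs for illustration, not results from this notebook.

```python
from typing import Dict, List


def retrieval_decision_report(expected: List[bool], actual: List[bool]) -> Dict[str, float]:
    """Compare actual [Retrieve] decisions with hand-labelled expectations."""
    if len(expected) != len(actual):
        raise ValueError("expected and actual must have the same length")
    pairs = list(zip(expected, actual))
    tp = sum(1 for e, a in pairs if e and a)      # retrieved when it should
    fp = sum(1 for e, a in pairs if not e and a)  # retrieved unnecessarily
    fn = sum(1 for e, a in pairs if e and not a)  # skipped a needed retrieval
    correct = sum(1 for e, a in pairs if e == a)
    return {
        "accuracy": correct / len(pairs) if pairs else 0.0,
        "precision": tp / (tp + fp) if tp + fp else 0.0,  # share of retrievals that were wanted
        "recall": tp / (tp + fn) if tp + fn else 0.0,     # share of wanted retrievals that happened
    }


# Labels from the five test cases; the decisions are hypothetical
expected = [True, True, False, True, True]
actual = [True, True, True, True, False]
print(retrieval_decision_report(expected, actual))
# {'accuracy': 0.6, 'precision': 0.75, 'recall': 0.75}
```

Inside `run_self_rag_tests`, the inputs would be the `expected_retrieval` labels of the test cases and the `retrieval_used` flags collected in `results`, provided every case completes.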

In [None]:
def run_self_rag_tests():
    print("\n" + "🧪" * 20 + " SELF-RAG TEST SUITE " + "🧪" * 20)
    
    test_cases = [
        {
            "name": "Simple Medical Query",
            "query": "What are the symptoms of diabetes?",
            "expected_retrieval": True,  # Medical queries should trigger retrieval
            "user_id": "medical_user_1"
        },
        {
            "name": "Complex Scientific Analysis",
            "query": "Analyze the implications of quantum entanglement for computing applications",
            "expected_retrieval": True,
            "user_id": "scientist_1"
        },
        {
            "name": "Simple Technology Question",
            "query": "What is machine learning?",
            "expected_retrieval": False,  # Might not need retrieval
            "user_id": "tech_user_1"
        },
        {
            "name": "Current Medical Treatment",
            "query": "What are the latest treatments for COVID-19?",
            "expected_retrieval": True,  # "Latest" should trigger retrieval
            "user_id": "medical_user_2"
        },
        {
            "name": "Business Strategy Query",
            "query": "How is AI transforming healthcare business models?",
            "expected_retrieval": True,
            "user_id": "business_user_1"
        }
    ]
    
    results = []
    total_start_time = time.time()
    
    for i, test_case in enumerate(test_cases, 1):
        print(f"\n{'=' * 70}")
        print(f"🔬 TEST CASE {i}: {test_case['name']}")
        print(f"❓ Query: '{test_case['query']}'")
        print("=" * 70)
        
        try:
            response = self_rag.process_query(
                test_case['query'], 
                test_case['user_id']
            )
            
            # Analyze results
            print(f"\n🎯 **FINAL ANSWER:**")
            print(f"   {response.generated_answer[:300]}{'...' if len(response.generated_answer) > 300 else ''}")
            
            print(f"\n🧠 **SELF-REFLECTION ANALYSIS:**")
            print(f"   • Retrieval Decision: {response.retrieval_decision.token.value}")
            print(f"   • Retrieved Documents: {len(response.retrieved_documents)}")
            if response.relevance_assessment:
                print(f"   • Relevance: {response.relevance_assessment.token.value}")
            print(f"   • Support: {response.support_evaluation.token.value}")
            print(f"   • Utility: {response.utility_judgment.token.value}")
            
            print(f"\n📊 **METRICS:**")
            print(f"   • Processing Time: {response.processing_time:.2f}s")
            print(f"   • Quality Score: {response.utility_judgment.utility_score:.2f}")
            print(f"   • Support Score: {response.support_evaluation.support_level:.2f}")
            
            print(f"\n🔗 **REFLECTION CHAIN:**")
            for j, reflection in enumerate(response.reflection_chain, 1):
                print(f"   {j}. {reflection}")
            
            if response.utility_judgment.improvement_suggestions:
                print(f"\n💡 **IMPROVEMENT SUGGESTIONS:**")
                for suggestion in response.utility_judgment.improvement_suggestions:
                    print(f"   • {suggestion}")
            
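            # Compare the [Retrieve] decision against the test case's label
            retrieval_used = response.retrieval_decision.should_retrieve
            retrieval_as_expected = retrieval_used == test_case['expected_retrieval']
            print(f"\n🔎 Retrieval expectation {'met' if retrieval_as_expected else 'NOT met'} "
                  f"(expected {test_case['expected_retrieval']}, got {retrieval_used})")
            
            results.append({
                'test_case': test_case['name'],
                'success': True,
                'retrieval_used': retrieval_used,
                'retrieval_as_expected': retrieval_as_expected,
                'quality_score': response.utility_judgment.utility_score,
                'processing_time': response.processing_time
            })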
            
        except Exception as e:
            print(f"❌ ERROR: {str(e)}")
            results.append({
                'test_case': test_case['name'],
                'success': False,
                'error': str(e)
            })
        
        print(f"\n{'🔸' * 35} END TEST CASE {i} {'🔸' * 35}")
    
    total_time = time.time() - total_start_time
    
    # Generate comprehensive test report
    print(f"\n" + "=" * 70)
    print("📊 SELF-RAG TEST REPORT")
    print("=" * 70)
    
    successful_tests = [r for r in results if r['success']]
    
    print(f"\n🎯 **OVERALL PERFORMANCE:**")
    print(f"   • Tests Completed: {len(successful_tests)}/{len(results)} ({len(successful_tests)/len(results)*100:.1f}%)")
    print(f"   • Total Processing Time: {total_time:.2f}s")
    print(f"   • Avg Time per Query: {total_time/len(results):.2f}s")
    
    if successful_tests:
        retrieval_used = sum(1 for r in successful_tests if r.get('retrieval_used', False))
        avg_quality = sum(r['quality_score'] for r in successful_tests) / len(successful_tests)
        avg_time = sum(r['processing_time'] for r in successful_tests) / len(successful_tests)
        
        print(f"\n📈 **SELF-REFLECTION METRICS:**")
        print(f"   • Retrieval Rate: {retrieval_used}/{len(successful_tests)} ({retrieval_used/len(successful_tests)*100:.1f}%)")
        print(f"   • Average Quality Score: {avg_quality:.3f}")
        print(f"   • Average Processing Time: {avg_time:.3f}s")
    
    # System statistics
    stats = self_rag.get_system_statistics()
    print(f"\n🏆 **SYSTEM STATISTICS:**")
    print(f"   • Total System Queries: {stats['total_queries']}")
    print(f"   • Overall Retrieval Rate: {stats['retrieval_rate']:.1%}")
    print(f"   • System Avg Quality: {stats['avg_quality_score']:.3f}")
    print(f"   • System Avg Time: {stats['avg_processing_time']:.3f}s")
    
    print(f"\n🧠 **SELF-RAG CAPABILITIES DEMONSTRATED:**")
    capabilities = [
        "✅ Intelligent retrieval decision-making with [Retrieve] tokens",
        "✅ Relevance assessment with [ISREL] tokens",
        "✅ Support evaluation with [ISSUP] tokens", 
        "✅ Utility judgment with [ISUSE] tokens",
        "✅ Self-awareness of knowledge limitations",
        "✅ Adaptive processing based on query complexity",
        "✅ Domain-specific confidence assessment",
        "✅ Transparent reflection chain tracking"
    ]
    
    for capability in capabilities:
        print(f"   {capability}")
    
    return results

# Run the comprehensive test suite
test_results = run_self_rag_tests()

## 🎮 Interactive Self-RAG Demo

Experience Self-RAG with real-time self-reflection:
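
The demo blocks on `input()`, which makes it awkward to exercise in a non-interactive run. A sketch of driving the same command dispatch with scripted input, using `unittest.mock.patch` on `builtins.input`; `classify_command` and `scripted_session` are hypothetical helpers that mirror the demo's dispatch order, not functions from the notebook.

```python
from typing import Iterable, List
from unittest.mock import patch


def classify_command(user_input: str) -> str:
    """Mirror the demo's dispatch order: quit words, 'stats', 'help', empty input, else a query."""
    text = user_input.strip().lower()
    if text in ("quit", "exit", "q"):
        return "quit"
    if text in ("stats", "help"):
        return text
    if not text:
        return "empty"
    return "query"


def scripted_session(lines: Iterable[str]) -> List[str]:
    """Feed canned lines through input() and record how each one is handled."""
    handled: List[str] = []
    with patch("builtins.input", side_effect=list(lines)):
        while True:
            try:
                kind = classify_command(input("Your question: "))
            except StopIteration:  # the mock raises this once the script runs out
                break
            handled.append(kind)
            if kind == "quit":
                break
    return handled


print(scripted_session(["help", "", "What is AI?", "stats", "quit", "never read"]))
# ['help', 'empty', 'query', 'stats', 'quit']
```

The same patch can drive `interactive_self_rag_demo()` itself, e.g. `with patch("builtins.input", side_effect=["What is AI?", "quit"]): interactive_self_rag_demo()`. End the script with `"quit"`: once the list is exhausted the mock raises `StopIteration`, which the demo's generic `except Exception` handler catches, so the loop would never exit.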

In [None]:
def interactive_self_rag_demo():
    print("\n" + "🎮" * 20 + " SELF-RAG INTERACTIVE DEMO " + "🎮" * 20)
    print("🧠 **SELF-REFLECTIVE AI ASSISTANT**")
    print("🎮" * 60)
    print("Experience AI that thinks about its own thinking! Watch how the system")
    print("decides when to retrieve information and evaluates its own responses.")
    print("\nType 'quit' to exit, 'stats' for system statistics, 'help' for examples")
    print("-" * 60)
    
    demo_user = "demo_user_selfrag"
    query_count = 0
    
    example_queries = {
        "Medical (High Retrieval)": [
            "What are the latest COVID-19 treatments?",
            "Explain diabetes management strategies",
            "What is immunotherapy for cancer?"
        ],
        "Scientific (Complex)": [
            "How does CRISPR gene editing work?",
            "Analyze quantum entanglement applications",
            "Explain the relationship between AI and scientific discovery"
        ],
        "Technology (Moderate)": [
            "What is retrieval-augmented generation?",
            "Compare different machine learning approaches",
            "How do large language models work?"
        ],
        "Simple (Low Retrieval)": [
            "What is AI?",
            "Define machine learning",
            "What is cloud computing?"
        ]
    }
    
    while True:
        try:
            user_input = input("\n🧠 Your question for Self-RAG: ").strip()
            
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("\n👋 Thank you for exploring Self-RAG!")
                print("🧠 Remember: The best AI systems know when they don't know!")
                break
            elif user_input.lower() == 'stats':
                stats = self_rag.get_system_statistics()
                print(f"\n📊 **SELF-RAG SYSTEM STATISTICS:**")
                print(f"   • Total Queries Processed: {stats['total_queries']}")
                print(f"   • Retrieval Decision Rate: {stats['retrieval_rate']:.1%}")
                print(f"   • Average Quality Score: {stats['avg_quality_score']:.3f}")
                print(f"   • Average Processing Time: {stats['avg_processing_time']:.3f}s")
                print(f"   • Retrieve Decisions: {stats['retrieval_decisions']['retrieve']}")
                print(f"   • No-Retrieve Decisions: {stats['retrieval_decisions']['no_retrieve']}")
                continue
            elif user_input.lower() == 'help':
                print(f"\n💡 **EXAMPLE QUERIES BY CATEGORY:**")
                for category, queries in example_queries.items():
                    print(f"\n🔸 **{category}:**")
                    for query in queries:
                        print(f"   • {query}")
                continue
            elif not user_input:
                print("Please enter a question or command.")
                continue
            
            query_count += 1
            print(f"\n🤔 Processing query #{query_count} with Self-RAG...")
            
            response = self_rag.process_query(user_input, demo_user)
            
            print(f"\n🤖 **SELF-RAG RESPONSE:**")
            print(f"   {response.generated_answer}")
            
            print(f"\n🧠 **SELF-REFLECTION TOKENS:**")
            print(f"   • {response.retrieval_decision.token.value}")
            if response.relevance_assessment:
                print(f"   • {response.relevance_assessment.token.value}")
            print(f"   • {response.support_evaluation.token.value}")
            print(f"   • {response.utility_judgment.token.value}")
            
            print(f"\n📊 **PERFORMANCE METRICS:**")
            print(f"   • Query Type: {response.query.complexity.value} {response.query.domain.value}")
            print(f"   • Retrieval Used: {'Yes' if response.retrieval_decision.should_retrieve else 'No'}")
            print(f"   • Documents Retrieved: {len(response.retrieved_documents)}")
            print(f"   • Quality Score: {response.utility_judgment.utility_score:.2f}")
            print(f"   • Processing Time: {response.processing_time:.2f}s")
            
            print(f"\n🔗 **REFLECTION REASONING:**")
            print(f"   • Retrieval: {response.retrieval_decision.reasoning}")
            if response.relevance_assessment:
                print(f"   • Relevance: {response.relevance_assessment.reasoning}")
            print(f"   • Support: {response.support_evaluation.reasoning}")
            print(f"   • Utility: {response.utility_judgment.reasoning}")
            
            if response.utility_judgment.improvement_suggestions:
                print(f"\n💡 **SELF-IMPROVEMENT SUGGESTIONS:**")
                for suggestion in response.utility_judgment.improvement_suggestions:
                    print(f"   • {suggestion}")
        
        except (KeyboardInterrupt, EOFError):
            print("\n\n👋 Demo interrupted. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ Error: {str(e)}")
    
    if query_count > 0:
        print(f"\n📈 **DEMO SESSION SUMMARY:**")
        print(f"   • Queries Processed: {query_count}")
        print(f"   • User Profile: {demo_user}")
        print(f"   • Self-Reflection Demonstrated: ✅")

print("\n💡 **To start the interactive Self-RAG demo, uncomment the call in the next cell**")
print("🧠 Experience AI that truly reflects on its own capabilities!")

In [None]:
# Uncomment the line below to start the interactive Self-RAG demo
# interactive_self_rag_demo()