# RAG Generation Pipeline - LangGraph Workflow Orchestration

**Notebook ID:** `06_generation_v3`  
**Description:** Optimized RAG pipeline for diabetes knowledge management

---

## Overview

This notebook implements the **RAG (Retrieval-Augmented Generation) pipeline** that powers the chat interface. The system provides accurate, cited answers about diabetes management based on the Kenya National Clinical Guidelines.

### Key Features

- **Query Classification**: Automatically categorizes each query (greeting, question about the assistant, substantive, irrelevant, or unsafe), answers non-substantive queries directly, and screens for unsafe requests
- **Semantic Retrieval**: Finds relevant information from the knowledge base
- **Cited Responses**: Generates answers with numbered citations that link to source documents
- **Safety Guardrails**: Refuses to answer irrelevant questions or provide personalized medical advice
- **Streaming Responses**: Provides real-time status updates and token-by-token answer generation

### Workflow

The pipeline follows a streamlined process:

1. **Classification**: Analyzes the user's query to determine intent and safety
2. **Retrieval**: Searches the knowledge base for relevant clinical information
3. **Generation**: Creates a comprehensive answer with inline citations
4. **Validation**: Ensures only cited sources are returned to the user

### Citation System

The system uses numbered citations that reference specific chunks from the knowledge base. Only sources that are actually cited in the response are shown to users, ensuring accuracy and relevance.

---


In [None]:
# CELL_ID: 06_generation_v3_imports
# ============================================================================
# IMPORT DEPENDENCIES
# ============================================================================

# %pip install langchain langchain-ollama langgraph pydantic chromadb gradio --quiet

import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal, TypedDict, Union, Annotated, Sequence
from enum import Enum

# LangChain imports
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, AnyMessage, BaseMessage
from langchain.tools import tool, ToolRuntime

# LangGraph imports
from langgraph.graph import StateGraph, END, START, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import create_react_agent
from langgraph.types import RetryPolicy
from langgraph.config import get_stream_writer
from langchain.agents.middleware import wrap_tool_call

# LangChain agents (v1 - replaces create_react_agent)
from langchain.agents import create_agent

# Pydantic for structured outputs
from pydantic import BaseModel, Field

# ChromaDB
import chromadb
from chromadb.config import Settings

# Gradio for interface
import gradio as gr

print("✓ Imports loaded successfully")


In [None]:
# CELL_ID: 06_generation_v3_chromadb_reader
# ============================================================================
# CHROMADB READER CLASS (REUSED FROM 06_generation_v1)
# ============================================================================

class ChromaDBReader:
    """
    Read-only semantic search over a persisted ChromaDB collection.

    The client and collection are opened lazily: ``search`` calls
    ``initialize`` on first use if no collection has been loaded yet.

    Args:
        chroma_db_path: Directory of the persistent ChromaDB store.
        collection_name: Name of the collection to query.
        embedding_function: Optional ChromaDB-compatible embedding function
            (e.g. ``JinaEmbeddingFunction``). When given, it is attached to the
            collection so ``query_texts`` are embedded with it rather than with
            ChromaDB's default embedding function.
    """

    def __init__(
        self,
        chroma_db_path: str = "./chroma_db",
        collection_name: str = "diabetes_guidelines_v1",
        embedding_function=None
    ):
        self.chroma_db_path = Path(chroma_db_path)
        self.collection_name = collection_name
        self.embedding_function = embedding_function
        self.client = None
        self.collection = None

    def initialize(self):
        """Create the ChromaDB client (once) and load the collection.

        Raises:
            Exception: If the collection cannot be loaded. The underlying
                ChromaDB error is chained as ``__cause__``.
        """
        if self.client is None:
            self.client = chromadb.PersistentClient(
                path=str(self.chroma_db_path),
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            print(f"✓ ChromaDB client initialized: {self.chroma_db_path}")

        # Attach the caller's embedding function on the FIRST attempt. Loading the
        # collection without it would make ChromaDB embed query_texts with its
        # default model, which need not match the vectors stored in the collection.
        get_kwargs = {"name": self.collection_name}
        if self.embedding_function is not None:
            get_kwargs["embedding_function"] = self.embedding_function

        try:
            self.collection = self.client.get_collection(**get_kwargs)
        except Exception as e:
            raise Exception(
                f"Collection '{self.collection_name}' could not be loaded ({e}). "
                f"Make sure you've run 04_vector_store_v1.ipynb first."
            ) from e

        print(f"✓ Loaded collection: {self.collection_name}")
        print(f"  • Total chunks: {self.collection.count()}")

    def _unflatten_metadata(self, flat_metadata: Optional[Dict]) -> Dict:
        """Unflatten metadata by parsing JSON-looking strings back into objects.

        String values starting with ``[`` or ``{`` are decoded with ``json.loads``.
        Values that fail to decode, and all other values, are kept unchanged.
        ``None`` (a chunk stored without metadata) yields an empty dict.
        """
        unflattened = {}
        for key, value in (flat_metadata or {}).items():
            if isinstance(value, str) and value.startswith(('[', '{')):
                try:
                    unflattened[key] = json.loads(value)
                    continue
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    pass
            unflattened[key] = value
        return unflattened

    def search(self, query: str, n_results: int = 5, where: Dict = None, min_similarity: float = 0.4) -> List[Dict]:
        """
        Search the collection with semantic search.

        Args:
            query: Search query text
            n_results: Number of results to request from ChromaDB
            where: Optional metadata filter
            min_similarity: Minimum relevance score (0-1), default 0.4

        Returns:
            List of dicts with ``chunk_id``, ``content``, ``metadata``,
            ``relevance_score`` and ``distance``, in ChromaDB's result order.
            Only chunks with relevance_score >= min_similarity are returned,
            and duplicate chunk ids are dropped (first occurrence wins).
        """
        if self.collection is None:
            self.initialize()

        results = self.collection.query(
            query_texts=[query],
            n_results=n_results,
            where=where,
            include=['documents', 'metadatas', 'distances']
        )

        # ChromaDB returns one inner list per query text; we sent exactly one.
        rows = zip(
            results['ids'][0],
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0],
        )

        formatted_results = []
        seen_chunk_ids = set()
        for chunk_id, document, metadata, distance in rows:
            # NOTE(review): 1 - distance is a similarity in [0, 1] only for a
            # cosine-distance collection — confirm the collection's space.
            relevance_score = 1 - distance

            if relevance_score < min_similarity or chunk_id in seen_chunk_ids:
                continue

            formatted_results.append({
                'chunk_id': chunk_id,
                'content': document,
                'metadata': self._unflatten_metadata(metadata),
                'relevance_score': relevance_score,
                'distance': distance
            })
            seen_chunk_ids.add(chunk_id)

        return formatted_results

print("✓ ChromaDBReader class loaded")


In [None]:
# CELL_ID: 03_generation_v3_jina_embedding
# ============================================================================
# JINA EMBEDDING FUNCTION
# ============================================================================

import os
import requests
import time
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

class JinaEmbeddingFunction:
    """
    Custom embedding function for ChromaDB using the Jina embeddings HTTP API.

    Texts are sent in batches of ``batch_size``. Each batch is retried with
    exponential backoff (1s, 2s, 4s, ...) on ``requests`` network/HTTP errors.

    Args:
        api_key: Jina API key; falls back to the ``JINA_API_KEY`` environment variable.
        model: Model name sent in each request.
        task: Task name sent in each request.
        api_url: Embeddings endpoint.
        batch_size: Number of texts per request.
        max_retries: Attempts per batch (values below 1 are treated as 1).

    Raises:
        ValueError: If no API key is supplied or found in the environment.
    """

    def __init__(
        self,
        api_key: str = None,
        model: str = "jina-embeddings-v4",
        task: str = "text-matching",
        api_url: str = "https://api.jina.ai/v1/embeddings",
        batch_size: int = 10,
        max_retries: int = 3
    ):
        self.api_key = api_key or os.getenv("JINA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "JINA_API_KEY environment variable is required. "
                "Set it in your .env file or environment."
            )
        self.model = model
        self.task = task
        self.api_url = api_url
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }

    def name(self) -> str:
        # Fixed identifier, returned regardless of self.model.
        return "jina-embeddings-v4"

    def __call__(self, input):
        """Embed a string or a list of strings; returns one vector per text."""
        texts = [input] if isinstance(input, str) else input

        if not texts:
            return []

        all_embeddings = []
        for i in range(0, len(texts), self.batch_size):
            all_embeddings.extend(self._embed_batch(texts[i:i + self.batch_size]))

        return all_embeddings

    @staticmethod
    def _extract_embeddings(result: Any, expected: int) -> List[List[float]]:
        """Pull embedding vectors out of a decoded Jina API response.

        Items are ordered by their ``index`` field when present (falling back to
        response order), so output position i corresponds to input text i.

        Raises:
            ValueError: If the response has no ``data`` field, or the number of
                embeddings it contains differs from ``expected``.
        """
        if not isinstance(result, dict) or 'data' not in result:
            raise ValueError(f"Unexpected API response format: {result}")

        indexed = []
        for position, item in enumerate(result['data']):
            if isinstance(item, dict) and 'embedding' in item:
                index = item.get('index')
                indexed.append((position if index is None else index, item['embedding']))
        indexed.sort(key=lambda pair: pair[0])

        embeddings = [embedding for _, embedding in indexed]
        # A short result would silently misalign vectors with their texts.
        if len(embeddings) != expected:
            raise ValueError(
                f"Expected {expected} embeddings from Jina API, got {len(embeddings)}: {result}"
            )
        return embeddings

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed one batch via the Jina API, retrying on request failures.

        Raises:
            Exception: After ``max_retries`` failed requests (cause chained).
            ValueError: If the response is malformed (not retried).
        """
        data = {
            "model": self.model,
            "task": self.task,
            "input": [{"text": text} for text in texts]
        }

        # Always make at least one attempt; max_retries <= 0 used to return [] silently.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=data,
                    timeout=60
                )
                response.raise_for_status()
                result = response.json()
            except requests.exceptions.RequestException as e:
                if attempt < attempts - 1:
                    wait_time = 2 ** attempt
                    print(f"⚠ API request failed (attempt {attempt + 1}/{attempts}), retrying in {wait_time}s...")
                    time.sleep(wait_time)
                    continue
                raise Exception(f"Failed to get embeddings after {attempts} attempts: {e}") from e

            return self._extract_embeddings(result, len(texts))

        return []  # unreachable: every iteration returns, continues, or raises

print("✓ JinaEmbeddingFunction class loaded")


In [None]:
# CELL_ID: 06_generation_v3_llm_setup
# ============================================================================
# LLM CONFIGURATION (OLLAMA)
# ============================================================================

# Ollama configuration (both values can be overridden via environment variables;
# the defaults match the previously hard-coded settings)
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "kimi-k2-thinking:cloud")

print("=" * 60)
print("LLM CONFIGURATION")
print("=" * 60)
print(f"Ollama Base URL: {OLLAMA_BASE_URL}")
print(f"Model: {OLLAMA_MODEL}")
print("=" * 60)

# Initialize Ollama LLM (shared by the classifier and generator nodes)
llm = ChatOllama(
    model=OLLAMA_MODEL,
    base_url=OLLAMA_BASE_URL,
    temperature=0.1,  # Low temperature for consistent structured outputs
    # num_ctx=4096,  # Context window (uncomment to override the model default)
)

# Smoke-test the connection at load time. A failure is reported but does not
# stop the notebook, so the remaining cells can still be defined.
try:
    test_response = llm.invoke("Say 'OK' if you can read this.")
    print(f"✓ LLM connection successful: {test_response.content[:50]}")
except Exception as e:
    print(f"⚠ LLM connection failed: {e}")
    print("  Make sure Ollama is running and model is installed")


In [None]:
# CELL_ID: 06_generation_v4_pydantic_models
# ============================================================================
# OPTIMIZED PYDANTIC MODELS FOR STRUCTURED OUTPUTS
# ============================================================================

class Source(BaseModel):
    """Source citation for generated response."""
    title: str = Field(description="Title of the source section")
    url: str = Field(description="URL path to the source")
    chunk_id: str = Field(description="Chunk ID from ChromaDB")

class ClassifierOutput(BaseModel):
    """Single LLM call output for all classification logic"""
    # Query understanding
    intent: Optional[str] = Field(None, description="Contextually-aware rephrased query for retrieval (only for substantive queries)")

    # Classification
    query_type: Literal["greeting", "about_system", "substantive", "irrelevant", "unsafe"] = Field(
        description="Type of user query"
    )
    is_relevant: bool = Field(description="Is query about diabetes management/care")
    is_safe: bool = Field(description="Is safe to answer without personalized medical advice")

    # Direct response for non-substantive queries
    direct_response: Optional[str] = Field(None, description="Complete response for greetings/about_system/irrelevant/unsafe queries")

    # Routing
    should_generate: bool = Field(description="Whether to proceed to generator node")

    # User feedback (for streaming). Defaulted (same text as the classifier's
    # markdown fallback) so an LLM reply that omits this cosmetic field still
    # validates instead of being discarded and re-parsed as "substantive".
    status_message: str = Field(
        "Processing query...",
        description="Status update for user (e.g., 'Understanding your query...')",
    )

class GeneratorOutput(BaseModel):
    """Generator node structured output"""
    response: str = Field(description="Final answer with inline citations")
    has_sufficient_info: bool = Field(description="Whether sufficient chunks were found")
    sources_used: List[str] = Field(default_factory=list, description="List of source URLs used")

print("✓ Optimized Pydantic models defined")
print("  • ClassifierOutput: Unified classification with intent rephrasing")
print("  • GeneratorOutput: Structured generation with citations")
print("  • Source: Citation metadata")



In [None]:
# CELL_ID: 06_generation_v4_state_schema
# ============================================================================
# OPTIMIZED STATE SCHEMA
# ============================================================================

class ChatState(MessagesState):
    """Graph state shared by the classifier, retrieval and generator nodes.

    Extends LangGraph's ``MessagesState`` (which supplies the ``messages`` key)
    with the intermediate outputs written by each node.
    """
    # Classifier outputs: written by classify_query_unified
    classifier_output: Optional[ClassifierOutput]

    # Retrieval (programmatic): chunk dicts as returned by chroma_reader.search
    # in retrieval_node
    retrieved_chunks: List[Dict]

    # Generator outputs
    generator_output: Optional[GeneratorOutput]
    sources: List[Source]

    # Final response: text shown to the user. The classifier sets it directly
    # (to its direct_response) for non-substantive queries.
    final_response: Optional[str]

print("✓ Optimized ChatState schema defined")



In [None]:
# CELL_ID: 06_generation_v4_unified_classifier
# ============================================================================
# UNIFIED CLASSIFIER NODE (LLM CALL #1)
# ============================================================================

# Allowed ClassifierOutput.query_type values.
_CLASSIFIER_QUERY_TYPES = ("greeting", "about_system", "substantive", "irrelevant", "unsafe")


def _coerce_query_type(value: Any) -> str:
    """Normalize a raw query_type (e.g. ``"GREETING"``) to an allowed value.

    Unknown or missing values map to ``"substantive"``, the same default the
    markdown fallback uses, instead of failing pydantic's Literal validation.
    """
    normalized = str(value).strip().strip('"\'').lower() if value is not None else ""
    return normalized if normalized in _CLASSIFIER_QUERY_TYPES else "substantive"


def _parse_classifier_output(text: str) -> ClassifierOutput:
    """Parse the classifier LLM's raw text into a ``ClassifierOutput``.

    Tries, in order:
      1. the first flat ``{...}`` object mentioning ``"query_type"``;
      2. the whole text as JSON;
      3. the span from the first ``{`` to the last ``}`` (covers code-fenced
         JSON and string values containing braces);
      4. a ``**field**: value`` markdown layout.

    NOTE(review): when nothing is recognised, the result fails OPEN — a
    substantive, relevant, safe query with ``should_generate=True``. This
    matches the original behaviour but may deserve a safer default.
    """
    import re

    candidates = []
    flat_match = re.search(r'\{[^{}]*"query_type"[^{}]*\}', text)
    if flat_match:
        candidates.append(flat_match.group(0))
    candidates.append(text.strip())
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if not isinstance(data, dict):
            continue
        if "query_type" in data:
            data["query_type"] = _coerce_query_type(data["query_type"])
        try:
            return ClassifierOutput(**data)
        except (ValueError, TypeError):  # pydantic's ValidationError subclasses ValueError
            continue

    # Fallback: markdown-formatted output such as "**is_safe**: True".
    def _flag(name: str, default: bool) -> bool:
        match = re.search(rf'\*\*{name}\*\*:\s*(true|false)', text, re.IGNORECASE)
        return match.group(1).lower() == "true" if match else default

    def _text(name: str) -> Optional[str]:
        match = re.search(rf'\*\*{name}\*\*:\s*(.+?)(?=\n\*\*|\Z)', text, re.DOTALL)
        if not match:
            return None
        value = match.group(1).strip()
        # The prompt shows "direct_response: None"; treat it as missing.
        return None if value.lower() in ("none", "null") else value

    query_type_match = re.search(r'\*\*query_type\*\*:\s*"?(\w+)', text)

    return ClassifierOutput(
        query_type=_coerce_query_type(query_type_match.group(1)) if query_type_match else "substantive",
        is_relevant=_flag("is_relevant", True),
        is_safe=_flag("is_safe", True),
        should_generate=_flag("should_generate", True),
        status_message=_text("status_message") or "Processing query...",
        intent=_text("intent"),
        direct_response=_text("direct_response"),
    )


def classify_query_unified(state: ChatState) -> ChatState:
    """
    Single LLM call handles all classification logic:
    - Greetings
    - Questions about system
    - Intent understanding/rephrasing
    - Relevance check
    - Safety check
    - Routing decision

    Writes ``classifier_output`` to the state. For non-substantive queries it
    also sets ``final_response`` and, when a direct response was produced,
    appends it to ``messages`` as an ``AIMessage``. A status message is sent
    as a ``classifier_status`` custom stream event when a writer is available.
    """

    classifier_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are the classification system for a Diabetes Knowledge Management Assistant based on Kenya National Clinical Guidelines.

Your job is to analyze user queries and determine the appropriate response path.

## Query Types & Responses:

1. **GREETING** (e.g., "Hi", "Hello", "Hey there")
   - query_type: "greeting"
   - direct_response: "Hello! I'm a diabetes knowledge assistant based on Kenya National Clinical Guidelines for Diabetes Management. I can help healthcare providers with questions about diabetes diagnosis, treatment, management, and prevention. How can I assist you today?"
   - is_relevant: False
   - is_safe: True
   - should_generate: False
   - status_message: "Processing greeting..."

2. **ABOUT_SYSTEM** (e.g., "What can you do?", "How do you work?", "What are you?")
   - query_type: "about_system"
   - direct_response: "I'm a specialized AI assistant for healthcare providers focused on diabetes management. I provide information based on the Kenya National Clinical Guidelines for the Management of Diabetes. I can answer questions about:\\n\\n- Diabetes diagnosis and screening\\n- Treatment options and medications\\n- Management strategies\\n- Complications and prevention\\n- Clinical protocols\\n\\nI cannot provide patient-specific medical advice. How can I help with your diabetes-related question?"
   - is_relevant: False
   - is_safe: True
   - should_generate: False
   - status_message: "Explaining system capabilities..."

3. **IRRELEVANT** (Not about diabetes)
   - query_type: "irrelevant"
   - is_relevant: False
   - is_safe: True
   - direct_response: "I'm sorry, but I'm specifically designed to answer questions about diabetes management based on the Kenya National Clinical Guidelines. Your query doesn't appear to be related to diabetes. Please ask me about diabetes diagnosis, treatment, management, or prevention."
   - should_generate: False
   - status_message: "Analyzing query relevance..."

4. **UNSAFE** (Patient-specific medical advice, diagnoses, prognoses)
   - query_type: "unsafe"
   - is_relevant: True
   - is_safe: False
   - direct_response: "I cannot provide patient-specific medical advice, diagnoses, or treatment recommendations. This type of question requires a healthcare provider who can evaluate the full clinical context and provide personalized guidance.\\n\\nI can help with general questions about diabetes management guidelines and protocols. Would you like to rephrase your question in a more general way?"
   - should_generate: False
   - status_message: "Evaluating query safety..."

5. **SUBSTANTIVE** (Safe diabetes questions)
   - query_type: "substantive"
   - is_relevant: True
   - is_safe: True
   - intent: <Rephrase query with full context for retrieval>
   - direct_response: None
   - should_generate: True
   - status_message: "Understanding your query and searching knowledge base..."

## Intent Rephrasing (for SUBSTANTIVE queries only):
- If follow-up question, incorporate context from conversation history
- If standalone question, ensure it's clear and complete
- Make it suitable for semantic search - use natural medical/clinical language
- DO NOT add phrases like "guidelines", "Kenya National Clinical Guidelines", "according to guidelines" - the knowledge base contains clinical content, not meta-references
- Focus on the actual medical/clinical concepts and terms
- Example: User says "What about that?" after discussing Type 2 diabetes → intent: "What are the management approaches for Type 2 diabetes?"
- Example: User says "How is diabetes diagnosed?" → intent: "What are the diagnostic criteria and screening procedures for diabetes?"
- Example: User says "What about treatment?" → intent: "What are the treatment options for Type 2 diabetes?"

## Safety Examples:
**UNSAFE:**
- "My patient has diabetes and blood pressure, will they die?"
- "Should I stop taking my insulin?"
- "What dose of metformin should I give this patient?"

**SAFE:**
- "What are the general treatment options for Type 2 diabetes?"
- "What are the guidelines for insulin therapy initiation?"
- "What are the recommended HbA1c targets in diabetes management?"

## Important:
- status_message should be user-friendly and informative
- direct_response should be complete, helpful, and professional
- intent should be self-contained and context-aware for retrieval
- Always be polite and helpful even when declining

## Output Format:
You MUST respond with a valid JSON object with the following structure:
{{
  "query_type": "greeting|about_system|substantive|irrelevant|unsafe",
  "is_relevant": true/false,
  "is_safe": true/false,
  "should_generate": true/false,
  "status_message": "Status message for user",
  "intent": "Rephrased query (only for substantive queries, null otherwise)",
  "direct_response": "Complete response (only for non-substantive queries, null otherwise)"
}}

Return ONLY valid JSON, no markdown formatting or additional text."""),
        MessagesPlaceholder(variable_name="messages"),
    ])

    # Plain-text LLM call plus lenient parsing: Ollama models do not always
    # return clean JSON, so _parse_classifier_output handles JSON and markdown.
    chain = classifier_prompt | llm | StrOutputParser()
    raw_output = chain.invoke({"messages": state["messages"]})
    result = _parse_classifier_output(raw_output)

    state["classifier_output"] = result

    # Non-substantive queries are answered directly by the classifier.
    if not result.should_generate:
        state["final_response"] = result.direct_response
        # Add response to messages for consistency
        if result.direct_response:
            state["messages"] = state.get("messages", []) + [AIMessage(content=result.direct_response)]

    # Stream status update; substantive queries echo the rephrased intent.
    writer = get_stream_writer()
    if writer:
        if result.should_generate and result.intent:
            status_msg = f"I am getting the relevant resources to answer: {result.intent}"
        else:
            status_msg = result.status_message
        writer({"type": "classifier_status", "message": status_msg})

    print(f"✓ Classified as: {result.query_type}")
    print(f"  Relevant={result.is_relevant}, Safe={result.is_safe}, Generate={result.should_generate}")
    if result.intent:
        print(f"  Intent: {result.intent[:80]}...")

    return state

print("✓ Unified classifier node defined")



In [None]:
# CELL_ID: 06_generation_v3_search_tool
# ============================================================================
# SEMANTIC RETRIEVAL TOOL FOR AGENT
# ============================================================================

@tool
def search_semantic_only(
    query: str,
    n_results: int = 5,
    min_similarity: float = 0.4,
    runtime: ToolRuntime = None
) -> List[Dict[str, Any]]:
    """
    Search using semantic similarity only. Use when you need to retrieve information from the knowledge base.
    Only chunks with relevance_score >= min_similarity (0.4) are returned.
    
    Args:
        query: Search query text
        n_results: Number of results to return (default: 5)
        min_similarity: Minimum relevance score threshold (0-1, default: 0.4)
    
    Returns:
        List of chunk dictionaries with content, metadata, and relevance_score
    """
    # NOTE: LangChain's @tool derives the agent-facing tool description and
    # argument schema from the docstring and signature above, so editing either
    # changes what an agent sees. `runtime` is presumably injected by LangChain
    # at tool-call time; the None default lets this run without it.

    # Emit progress as a custom stream event when a stream writer is available.
    if runtime and runtime.stream_writer:
        runtime.stream_writer({"type": "tool_progress", "message": f"Semantic search: {query[:50]}..."})
    
    try:
        # `chroma_reader` is not defined in this cell; assumes a module-level
        # ChromaDBReader instance created elsewhere in the notebook.
        chunks = chroma_reader.search(
            query=query,
            n_results=n_results,
            min_similarity=min_similarity,
            where=None  # No metadata filtering
        )
        
        if runtime and runtime.stream_writer:
            runtime.stream_writer({"type": "tool_progress", "message": f"Found {len(chunks)} chunks via semantic search"})
        
        return chunks
    except Exception as e:
        # Report the failure as a single placeholder "chunk" instead of raising,
        # so the caller receives the error text as the tool result.
        error_msg = f"Semantic search failed: {str(e)}"
        if runtime and runtime.stream_writer:
            runtime.stream_writer({"type": "tool_error", "message": error_msg})
        return [{"error": error_msg, "chunk_id": None, "content": "", "metadata": {}, "relevance_score": 0.0}]

print("✓ Semantic retrieval tool created: search_semantic_only")


In [None]:
# CELL_ID: 06_generation_v4_retrieval_node
# ============================================================================
# PROGRAMMATIC RETRIEVAL NODE (NO LLM)
# ============================================================================

# Retrieval settings used by retrieval_node (previously inline magic numbers).
RETRIEVAL_N_RESULTS = 5
RETRIEVAL_MIN_SIMILARITY = 0.4


def retrieval_node(state: ChatState) -> ChatState:
    """
    Programmatic retrieval based on the classifier's rephrased intent.
    No LLM calls - pure Python logic.

    Reads ``state["classifier_output"]`` and writes ``state["retrieved_chunks"]``
    (the list returned by ``chroma_reader.search``). The list is reset to ``[]``
    whenever retrieval is skipped or fails, so chunks from an earlier turn are
    never carried into this one. A ``retrieval_status`` / ``retrieval_error``
    custom stream event is emitted when a stream writer is available.
    """
    classifier_output = state["classifier_output"]
    writer = get_stream_writer()

    # Reset first: if the graph state is checkpointed across turns, the previous
    # turn's chunks would otherwise linger when retrieval is skipped or fails.
    state["retrieved_chunks"] = []

    if not classifier_output.should_generate:
        # Skip retrieval for non-substantive queries
        return state

    intent = classifier_output.intent
    if not intent:
        print("⚠ No intent available for retrieval")
        return state

    # Always do fresh retrieval with rephrased intent; only the search itself
    # is guarded so logging/streaming problems are not reported as retrieval errors.
    try:
        chunks = chroma_reader.search(
            query=intent,
            n_results=RETRIEVAL_N_RESULTS,
            min_similarity=RETRIEVAL_MIN_SIMILARITY
        )
    except Exception as e:
        import traceback

        print(f"⚠ Retrieval error: {e}")
        traceback.print_exc()
        if writer:
            writer({"type": "retrieval_error", "message": "Error during retrieval. Continuing..."})
        return state

    state["retrieved_chunks"] = chunks

    if writer:
        if chunks:
            writer({"type": "retrieval_status", "message": f"Found {len(chunks)} relevant sources. Generating answer..."})
        else:
            writer({"type": "retrieval_status", "message": "No sources found with sufficient relevance. Responding..."})

    print(f"✓ Retrieved {len(chunks)} chunks (similarity >= {RETRIEVAL_MIN_SIMILARITY})")
    if chunks:
        print(f"  Top relevance: {chunks[0]['relevance_score']:.3f}")

    return state

print("✓ Retrieval node defined")



In [None]:
# CELL_ID: 06_generation_v4_generator_node
# ============================================================================
# DIRECT GENERATOR NODE (LLM CALL #2)
# ============================================================================

import re

# Fixed reply used when retrieval produced no chunks to ground an answer in.
_NO_INFO_RESPONSE = (
    "I don't have sufficient information in my knowledge base to answer this "
    "question accurately. You may want to:\n"
    "- Rephrase your question with more specific terms\n"
    "- Ask about a different aspect of diabetes management\n"
    "- Consult the full clinical guidelines directly"
)

# Relevance score at/above which a chunk counts as "sufficient information"
# (same threshold as the retrieval node's min_similarity).
_SUFFICIENT_RELEVANCE = 0.4

# Numbered citations: [1], [12], or grouped [1, 2] / [1; 3]. The negative
# lookahead skips Markdown link text such as [1](http://...).
_NUMBERED_CITATION_RE = re.compile(r"\[(\d+(?:\s*[,;]\s*\d+)*)\](?!\()")

# Markdown links [Title](url): fallback for when the LLM links a source
# directly instead of using a numbered citation.
_MARKDOWN_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^\)]+)\)")

# A "Sources" heading of any Markdown level that consists of that word alone
# (optionally followed by ":"). Content headings such as
# "## Sources of Glucose" do not match.
_SOURCES_HEADING_RE = re.compile(
    r"^[ \t]*#{1,6}[ \t]*Sources[ \t]*:?[ \t\r]*$",
    re.MULTILINE | re.IGNORECASE,
)


def _strip_sources_section(text: str) -> str:
    """Remove everything from the first standalone "Sources" heading onward.

    Text without such a heading is returned unchanged.
    """
    match = _SOURCES_HEADING_RE.search(text)
    return text[: match.start()].strip() if match else text


def _extract_cited_numbers(text: str, max_num: int) -> tuple[set[int], list[int]]:
    """Return ``(valid, invalid)`` numbered citations found in *text*.

    ``valid`` holds the distinct numbers with ``1 <= n <= max_num``.
    ``invalid`` lists every out-of-range number in order of appearance,
    including repeats, for logging.
    """
    valid: set[int] = set()
    invalid: list[int] = []
    for group in _NUMBERED_CITATION_RE.findall(text):
        for num_str in re.split(r"\s*[,;]\s*", group):
            num = int(num_str)
            if 1 <= num <= max_num:
                valid.add(num)
            else:
                invalid.append(num)
    return valid, invalid


def _extract_linked_urls(text: str) -> list[str]:
    """Return the URLs of Markdown links in *text*, deduplicated, in order of first appearance."""
    return list(dict.fromkeys(url for _title, url in _MARKDOWN_LINK_RE.findall(text)))


def _select_cited_sources(cited_numbers, linked_urls, chunk_to_source, url_to_source) -> list:
    """Build the list of cited sources, deduplicated by URL.

    Sources from numbered citations come first, in ascending chunk number.
    They are followed by sources matched through Markdown-link URLs, in the
    order the links appear. Links to URLs that were not retrieved are ignored.
    """
    cited_sources = []
    seen_urls = set()
    for num in sorted(cited_numbers):
        source = chunk_to_source.get(num)
        if source is not None and source.url not in seen_urls:
            cited_sources.append(source)
            seen_urls.add(source.url)
    for url in linked_urls:
        if url not in seen_urls and url in url_to_source:
            cited_sources.append(url_to_source[url])
            seen_urls.add(url)
    return cited_sources


def _build_context(chunks):
    """Format retrieved chunks as numbered "Source N" blocks for the prompt.

    Returns ``(context, chunk_to_source, url_to_source)``:
    - ``chunk_to_source`` maps each 1-based chunk number to its Source.
    - ``url_to_source`` maps each non-empty URL to the Source of the first
      chunk carrying it.
    """
    context_parts = []
    chunk_to_source = {}
    url_to_source = {}
    for i, chunk in enumerate(chunks, 1):  # numbering starts at 1, matching [1], [2], ...
        metadata = chunk.get("metadata") or {}
        title = metadata.get("title", "Unknown")
        url = metadata.get("url", "")
        content = chunk.get("content", "")
        relevance = chunk.get("relevance_score") or 0

        context_parts.append(f"--- Source {i}: {title} (Relevance: {relevance:.2f}) ---\nURL: {url}\n\n{content}")

        source = Source(
            title=title,
            url=url,
            chunk_id=chunk.get("chunk_id", "")
        )
        chunk_to_source[i] = source
        if url and url not in url_to_source:
            url_to_source[url] = source
    return "\n\n".join(context_parts), chunk_to_source, url_to_source


def _build_generator_prompt():
    """Return the generator chat prompt.

    Parts: system guardrails, then the conversation history, then the query
    with the retrieved context.
    """
    return ChatPromptTemplate.from_messages([
        ("system", """You are a diabetes specialist assistant for healthcare providers, based on Kenya National Clinical Guidelines for Diabetes Management.

## Your Task:
Answer the user's query using the provided context from the knowledge base.

## CRITICAL GUARDRAILS:
1. **USE the provided context** - The context contains relevant clinical information. If context is provided, you MUST use it to answer the question
2. **Be factual and truthful** - Base your answer on the context provided
3. **No personalized medical advice** - Provide general clinical information only
4. **Use numbered citations** - Format: [1], [2], [3] etc. when referencing specific information from the context
5. **DO NOT add a Sources section** - The frontend will handle displaying sources automatically
6. **Be comprehensive** - Extract and use ALL relevant information from the context to provide a thorough answer

## Critical Instructions:
- **If context is provided above, you MUST answer the question** - Do NOT say "insufficient information"
- The context contains relevant clinical information that was retrieved specifically for this query
- Extract and synthesize information from ALL provided context chunks
- Only mention "insufficient information" if the context section explicitly says "No relevant information found"
- When context is provided, your job is to answer comprehensively using that information
- **Use numbered citations [1], [2], [3] etc.** - Reference sources by their number from the available sources list
- **Only cite sources you actually use** - If you mention information from Source 1, use [1]. If from Source 2, use [2], etc.
- **DO NOT include a "## Sources" section at the end** - The system will automatically extract and display only the sources you actually cited
- **DO NOT use [Title](url) format** - Use only numbered references like [1], [2], [3]

## Response Format:
- Clear, clinical language for healthcare providers
- Numbered citations throughout: [1], [2], [3] when referencing specific information
- Structured with headers if appropriate
- NO Sources section - just use numbered citations
- Be comprehensive - use all relevant information from the context"""),
        MessagesPlaceholder(variable_name="messages"),
        ("human", """User Query: {intent}

Relevant Information from Knowledge Base:
{context}

IMPORTANT CITATION INSTRUCTIONS:
- Each chunk in the context above is labeled as "Source 1", "Source 2", "Source 3", etc.
- When you reference information from a chunk, cite it using its Source number: [1], [2], [3], etc.
- For example, if you use information from "Source 1", cite it as [1]
- If you use information from "Source 2", cite it as [2]
- Only cite sources (chunks) that you actually use in your answer
- Do NOT cite sources you don't use

CRITICAL: 
- The context above contains relevant clinical information. Use this information to provide a comprehensive answer to the user's query.
- Include specific details, recommendations, and numbered citations [1], [2], [3] etc. when referencing information from specific chunks.
- Only cite chunks you actually use in your answer.

Provide your answer following all guardrails above.""")
    ])


def generator_node(state: "ChatState") -> "ChatState":
    """
    Generate the final cited answer (LLM call #2).

    Reads ``retrieved_chunks``, ``classifier_output`` and ``messages`` from
    *state*. Writes ``generator_output``, ``sources`` and ``final_response``
    back into it; the same dict is returned.

    - Missing classifier output or intent: only ``final_response`` is set,
      to an error string.
    - No retrieved chunks: the fixed "insufficient information" reply is
      used without calling the LLM, since there is nothing to ground a
      clinical answer in.
    - Otherwise the chunks are presented to the LLM as "Source 1..N".
      Any standalone "Sources" heading the LLM appends is stripped.
      ``sources`` receives only the sources the answer actually cites,
      deduplicated by URL. Citations count when they are numbered ([2],
      [1, 3]) or Markdown links to a retrieved URL.

    Any exception is logged and converted into a user-facing error message
    in ``final_response``.

    The annotations are quoted (forward references), so this cell can be
    defined without ChatState being resolvable at definition time.
    """
    try:
        chunks = state.get("retrieved_chunks") or []
        classifier_output = state.get("classifier_output")

        if not classifier_output:
            state["final_response"] = "Error: No classifier output available."
            return state

        intent = classifier_output.intent
        if not intent:
            state["final_response"] = "Error: No intent available for generation."
            return state

        has_sufficient_info = any(
            (chunk.get("relevance_score") or 0) >= _SUFFICIENT_RELEVANCE for chunk in chunks
        )

        if not chunks:
            # Nothing retrieved: reply with the fixed message instead of an
            # LLM answer that could only be ungrounded.
            final_response = _NO_INFO_RESPONSE
            cited_sources = []
            print("  No chunks retrieved - returning insufficient-information reply")
        else:
            context, chunk_to_source, url_to_source = _build_context(chunks)

            chain = _build_generator_prompt() | llm | StrOutputParser()
            response = chain.invoke({
                "messages": state["messages"],
                "intent": intent,
                "context": context
            })
            final_response = response if isinstance(response, str) else response.content

            # The frontend handles sources, so drop any Sources section the LLM added
            final_response = _strip_sources_section(final_response)

            cited_numbers, invalid_numbers = _extract_cited_numbers(final_response, len(chunks))
            for num in invalid_numbers:
                print(f"  ⚠ Invalid citation [{num}] - out of range (valid: 1-{len(chunks)})")
            linked_urls = _extract_linked_urls(final_response)

            # No fallback to "all chunks": uncited chunks are never returned as sources
            cited_sources = _select_cited_sources(
                cited_numbers, linked_urls, chunk_to_source, url_to_source
            )

            if not cited_numbers and not linked_urls:
                print("  ⚠ WARNING: No citations found in response!")
                print(f"     Response length: {len(final_response)} chars")
                print(f"     Chunks provided: {len(chunks)}")
                print("     This means the LLM used information without citing it properly")

            print(f"  Citations found: {sorted(cited_numbers)}")
            print(f"  URLs cited: {linked_urls}")
            print(f"  Chunks provided: {len(chunks)}")
            print(f"  Sources returned: {len(cited_sources)}")
            if len(cited_sources) != len(cited_numbers) + len(linked_urls):
                print("  ⚠ Note: Some citations may reference the same source (deduplicated)")

        result = GeneratorOutput(
            response=final_response,
            has_sufficient_info=has_sufficient_info,
            sources_used=[source.url for source in cited_sources]
        )

        # Only cited sources are stored, not everything that was retrieved
        state["generator_output"] = result
        state["sources"] = cited_sources
        state["final_response"] = final_response

        print(f"✓ Generated response: {len(final_response)} chars")
        print(f"  Sufficient info: {result.has_sufficient_info}")
        print(f"  Retrieved chunks: {len(chunks)}")
        print(f"  Cited sources: {len(cited_sources)}")

        return state

    except Exception as e:
        error_msg = f"Error in generator node: {str(e)[:200]}"
        print(f"❌ {error_msg}")
        import traceback
        traceback.print_exc()
        state["final_response"] = f"I encountered an error while generating the response: {str(e)[:200]}. Please try rephrasing your question."
        return state

print("✓ Generator node defined")



In [None]:
# CELL_ID: 06_generation_v4_build_graph
# ============================================================================
# BUILD OPTIMIZED LANGGRAPH WORKFLOW
# ============================================================================

def route_after_classifier(state: ChatState) -> str:
    """Choose the node that follows classification.

    Returns "retrieval" when a classifier output exists and its
    ``should_generate`` flag is truthy. Otherwise returns END, which
    terminates the run.
    """
    decision = state.get("classifier_output")
    if not decision or not decision.should_generate:
        return END
    return "retrieval"

# Build graph
workflow = StateGraph(ChatState)

# Nodes: classifier (LLM call #1), programmatic retrieval, generator (LLM call #2)
workflow.add_node("classifier", classify_query_unified)
workflow.add_node("retrieval", retrieval_node)
workflow.add_node("generator", generator_node)

# Set entry point
workflow.set_entry_point("classifier")

# Substantive queries continue to retrieval. Everything else ends here; for
# those, the classifier's direct_response is the reply shown in the UI.
workflow.add_conditional_edges(
    "classifier",
    route_after_classifier,
    {
        "retrieval": "retrieval",
        END: END
    }
)

# Linear path for substantive queries
workflow.add_edge("retrieval", "generator")
workflow.add_edge("generator", END)

# Compile
graph = workflow.compile()

print("✓ Graph built and compiled")
print("\nOptimized workflow:")
print("  START → classifier → [retrieval → generator | END]")
print("\nExpected LLM calls:")
print("  • Greetings/About/Irrelevant/Unsafe: 1 LLM call")
print("  • Substantive queries: 2 LLM calls")



In [None]:
# CELL_ID: 06_generation_v3_visualize_graph
# ============================================================================
# VISUALIZE GRAPH STRUCTURE
# ============================================================================

# Best-effort rendering: any failure (missing IPython, PNG rendering error)
# only falls back to a textual description of the graph.
try:
    from IPython.display import Image, display
    graph_image = graph.get_graph().draw_mermaid_png()
    display(Image(graph_image))
    print("✓ Graph visualization displayed")
except Exception as e:
    print(f"⚠ Visualization error: {e}")
    print("Graph structure:")
    print("  START → classifier → [retrieval → generator | END]")


In [None]:
# CELL_ID: 06_generation_v4_gradio_interface
# ============================================================================
# GRADIO INTERFACE WITH STREAMING
# ============================================================================

def chat_interface_streaming(message, history):
    """
    Run the graph for *message* and yield the text to show after each stage.

    *history* is a sequence of ``(user, assistant)`` pairs. Empty entries are
    skipped when converting it into HumanMessage/AIMessage objects, and the
    new message is appended last.

    Each yielded value replaces the previous one in the UI:
    - the classifier's status message, then its direct response for
      non-substantive queries;
    - a retrieval status line;
    - the generator's final response.

    "No response generated." is yielded if nothing non-empty was produced.
    Any exception is reported as a single "Error: ..." message.
    """
    conversation = [
        message_cls(content=text)
        for user_text, bot_text in history
        for message_cls, text in ((HumanMessage, user_text), (AIMessage, bot_text))
        if text
    ]
    conversation.append(HumanMessage(content=message))

    initial_state = dict(
        messages=conversation,
        classifier_output=None,
        retrieved_chunks=[],
        generator_output=None,
        sources=[],
        final_response=None,
    )

    latest = ""

    try:
        # "updates" mode yields {node_name: state_update} after each node finishes
        for update in graph.stream(initial_state, stream_mode="updates"):
            for node, delta in update.items():
                if node == "classifier":
                    verdict = delta.get("classifier_output")
                    if not verdict:
                        continue
                    latest = verdict.status_message
                    yield latest
                    if not verdict.should_generate:
                        # The classifier answered directly; the graph ends here
                        latest = verdict.direct_response
                        yield latest
                elif node == "retrieval":
                    found = delta.get("retrieved_chunks", [])
                    latest = (
                        f"✓ Found {len(found)} relevant sources. Generating answer..."
                        if found
                        else "⚠ No sources found with sufficient relevance. Responding..."
                    )
                    yield latest
                elif node == "generator":
                    answer = delta.get("final_response")
                    if answer:
                        latest = answer
                        yield latest

        if not latest:
            yield "No response generated."

    except Exception as e:
        yield f"Error: {str(e)[:200]}"

def create_gradio_interface():
    """Create and return the Gradio chat interface (not launched).

    Layout: a header, the conversation Chatbot, a question Textbox, Submit
    and Clear buttons, and a short usage note.

    - Submitting (button click or Enter) streams each output of
      chat_interface_streaming into the newest chat turn, then clears the
      textbox.
    - Empty or whitespace-only questions leave the transcript unchanged and
      trigger no LLM calls.
    - Clear resets both the transcript and the textbox.
    """
    with gr.Blocks(title="Diabetes Knowledge Management Assistant (Optimized)", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Diabetes Knowledge Management Assistant 
        
        
        Ask questions about diabetes management, treatment, diagnosis, and related topics based on the Kenya National Clinical Guidelines.
        """)

        chatbot = gr.Chatbot(
            label="Conversation",
            height=600,
            show_copy_button=True
        )

        msg = gr.Textbox(
            label="Your Question",
            placeholder="Type your question here...",
            lines=2
        )

        with gr.Row():
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear Conversation")

        # The UI renders only the answer text, so describe the citations as
        # they actually appear (no separate sources section is shown).
        gr.Markdown("""
        ### Instructions
        - Ask questions about diabetes management, treatment, diagnosis, prevention
        - The assistant retrieves information from the knowledge base
        - Responses include numbered inline citations ([1], [2], ...) to the retrieved guideline excerpts
        - Follow-up questions are automatically contextualized
        
        """)

        # Event handlers
        def respond(message, history):
            """Stream the reply to *message* as the last [user, assistant] pair of the transcript."""
            history = history or []
            if not message or not message.strip():
                # Nothing to ask: keep the transcript instead of running the LLMs on an empty query
                yield history
                return
            for partial in chat_interface_streaming(message, history):
                yield history + [[message, partial]]

        # Button click and Enter-in-textbox share the same handler chain
        for trigger in (submit_btn.click, msg.submit):
            trigger(
                respond,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: "",
                outputs=[msg]
            )

        clear_btn.click(
            lambda: ([], ""),
            outputs=[chatbot, msg]
        )

    return demo

# Create interface
demo = create_gradio_interface()

print("✓ Gradio interface created (optimized)")
print("\nTo launch the interface, run:")
print("  demo.launch(share=True)  # For public link")
print("  demo.launch()  # For local only")



In [None]:
# CELL_ID: 07_generation_v2_initialize_chromadb
# ============================================================================
# INITIALIZE CHROMADB READER
# ============================================================================

print("=" * 60, "INITIALIZING CHROMADB READER", "=" * 60, sep="\n")

# Embedding function handed to the ChromaDB reader below
jina_embedding_fn = JinaEmbeddingFunction()
print("✓ Jina embedding function ready")

# Vector-store handle; retrieval_node looks up this global (chroma_reader.search)
# when it runs, so this cell only needs to execute before the first query.
chroma_reader = ChromaDBReader(
    collection_name="diabetes_guidelines_v1",
    chroma_db_path="./chroma_db",
    embedding_function=jina_embedding_fn,
)
chroma_reader.initialize()

print("=" * 60)


In [None]:
# CELL_ID: 06_generation_v3_launch_gradio
# ============================================================================
# LAUNCH GRADIO INTERFACE
# ============================================================================

# Launch the interface. share=True also creates a public link; use
# demo.launch() instead for local-only access.
demo.launch(share=True)
# demo.close()  # stop the server when finished

print("✓ Gradio interface launched")
print("Run demo.close() to stop the interface")
