

# Med SARATHI - Medical Diagnosis Assistant: An AI Agent Workflow

## Overview

Welcome to this Kaggle notebook implementing a **Socially Impactful Medical Diagnosis Assistant**. This project explores the potential of using advanced AI techniques – specifically a multi-agent system orchestrated by LangGraph and powered by Google's Gemini LLM – to provide preliminary diagnostic insights based on patient symptoms. It integrates Retrieval-Augmented Generation (RAG) using LlamaIndex and FAISS to ground responses in medical knowledge and incorporates explicit checks for potential biases in the AI's reasoning. The goal is to create a system that is not only intelligent but also equitable and accessible.

This notebook contains the complete implementation, from data handling and agent definition to workflow orchestration and a user-friendly interface built with Gradio.

## Why This Matters: The Motivation

Access to timely, reliable, and unbiased medical information remains a critical challenge worldwide. Patients often struggle to understand their symptoms or navigate complex healthcare systems, while clinicians face increasing workloads. Furthermore, unconscious biases can sometimes influence diagnostic pathways.

This project aims to address these challenges by:

1.  **Improving Accessibility:** Offering users initial guidance and educational material based on their symptoms.
2.  **Promoting Equity:** Actively identifying and flagging potential demographic, socioeconomic, or cultural biases in the diagnostic suggestions.
3.  **Exploring AI for Social Good:** Demonstrating how sophisticated AI workflows can be applied responsibly to complex, real-world problems with significant social impact.
4.  **Potential Support for Healthcare:** Investigating foundational technology that could, in the future, assist in optimizing patient flow (e.g., guiding patients to the correct outpatient department, or OPD).

## How It Works: Technology & Approach

This assistant utilizes a **multi-agent workflow** where different AI agents collaborate to analyze symptoms and generate a comprehensive report:

1.  **Symptoms Input:** User provides symptoms via a simple UI (Gradio).
2.  **RAG Context:** The system retrieves relevant information from a knowledge base (initially dummy data, extensible to PubMed, etc.) using LlamaIndex/FAISS.
3.  **Agentic Analysis:** The symptoms and context flow through specialized agents built with LangGraph and Gemini:
    * **Diagnostician:** Generates initial diagnoses.
    * **Validator:** Critiques the initial diagnosis.
    * **Bias Checker:** Analyzes for potential biases.
    * **Educator:** Creates patient-friendly explanations.
4.  **Structured Report:** The final output is a formatted report consolidating insights from all agents.
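The four-stage flow above can be pictured as a chain of functions that each read a shared state dict and return only the keys they update. The sketch below is a plain-Python stand-in, not LangGraph's API (the notebook itself wires its nodes with `StateGraph`). The stub agents and their outputs are hypothetical placeholders for the Gemini-backed nodes, and the sketch assumes the pipeline should stop as soon as any agent sets `error_message`; whether the notebook's graph does this depends on edges defined later in the notebook.

```python
from typing import Any, Callable, Dict, List, Optional, TypedDict


class PipelineState(TypedDict, total=False):
    """Subset of the notebook's AgentState used by this sketch."""
    symptoms_text: str
    initial_diagnosis: Optional[Dict[str, Any]]
    validation_results: Optional[Dict[str, Any]]
    bias_analysis: Optional[Dict[str, Any]]
    patient_education: Optional[Dict[str, Any]]
    error_message: Optional[str]


# Each node reads the state and returns only the keys it wants to update.
Node = Callable[[PipelineState], Dict[str, Any]]


# Hypothetical stub agents; the notebook's versions call Gemini instead.
def diagnostician(state: PipelineState) -> Dict[str, Any]:
    if not state.get("symptoms_text"):
        return {"error_message": "Diagnostician failed: Symptoms missing."}
    return {"initial_diagnosis": {"primary_diagnosis": "Condition A",
                                  "primary_confidence": 0.6,
                                  "alternative_diagnoses": ["Condition B"]}}


def validator(state: PipelineState) -> Dict[str, Any]:
    return {"validation_results": {"validation_status": "Flagged for Review"}}


def bias_checker(state: PipelineState) -> Dict[str, Any]:
    return {"bias_analysis": {"potential_biases": []}}


def educator(state: PipelineState) -> Dict[str, Any]:
    return {"patient_education": {"summary": "Plain-language explanation goes here."}}


PIPELINE: List[Node] = [diagnostician, validator, bias_checker, educator]


def run_pipeline(symptoms: str) -> PipelineState:
    state: PipelineState = {"symptoms_text": symptoms, "error_message": None}
    for node in PIPELINE:
        # Merge the node's partial update into the shared state
        state = {**state, **node(state)}  # type: ignore[assignment]
        if state.get("error_message"):
            break  # Short-circuit: later agents never see a failed upstream result
    return state
```

For example, `run_pipeline("fever and chills")` passes through all four stubs, while `run_pipeline("")` stops after the diagnostician with `error_message` set. Returning partial updates keeps each agent small: it only owns the keys it writes.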

**Core Technologies:** `Python`, `LangGraph`, `Google Gemini`, `LlamaIndex`, `FAISS`, `Gradio`.

## Key Features Implemented

* Multi-Agent Workflow (Diagnostician, Validator, Bias Checker, Educator)
* Retrieval-Augmented Generation (RAG) for Grounding
* Structured JSON Output Generation & Processing
* Explicit Bias Detection Analysis
* Patient Education Component
* Interactive Gradio Web UI (Text Input)
* Error Handling & Robust API Calls
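Structured JSON output hinges on pulling one JSON object out of free-form LLM text and checking its shape before other agents rely on it. A minimal sketch, assuming the diagnostician's three-key schema; the helper names are hypothetical. It uses `json.JSONDecoder.raw_decode` rather than a greedy regex from the first `{` to the last `}`, so surrounding prose, markdown fences, or a second trailing object do not corrupt the parse.

```python
import json
from typing import Any, Dict, Optional

REQUIRED_KEYS = ("primary_diagnosis", "primary_confidence", "alternative_diagnoses")


def extract_first_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in `text`, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores trailing text
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            pass  # Not valid JSON from this brace; try the next one
        start = text.find("{", start + 1)
    return None


def validate_diagnosis(obj: Dict[str, Any]) -> Dict[str, Any]:
    """Check the diagnostician's schema; raise ValueError on any violation."""
    missing = [k for k in REQUIRED_KEYS if k not in obj]
    if missing:
        raise ValueError(f"Missing required keys: {missing}")
    conf = obj["primary_confidence"]
    # bool is a subclass of int, so reject it explicitly
    if isinstance(conf, bool) or not isinstance(conf, (int, float)):
        raise ValueError("'primary_confidence' must be a number")
    if not 0.0 <= conf <= 1.0:
        raise ValueError("'primary_confidence' must lie between 0.0 and 1.0")
    alts = obj["alternative_diagnoses"]
    if not isinstance(alts, list) or not all(isinstance(a, str) for a in alts):
        raise ValueError("'alternative_diagnoses' must be a list of strings")
    return obj
```

The 0.0 to 1.0 range check mirrors the range the diagnostician prompt asks for; it is an extra guard, not something the notebook's own parsing enforces.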


## Future Scope

The underlying architecture holds significant potential for expansion. One promising direction is developing this into a **Pre-Admission OPD Agent**. Such an application could analyze patient symptoms to suggest the correct hospital department, potentially screen for urgency (serious vs. non-serious cases), and help optimize patient flow, reducing waiting times and administrative overhead in busy OPD centers.

## Disclaimer

**This notebook and the resulting application are a proof-of-concept demonstrator for exploring AI techniques. It is NOT a certified medical device and should NOT be used as a substitute for professional medical advice, diagnosis, or treatment.** Always seek the advice of your physician or other qualified health provider with any questions you may have regarding a medical condition.

---

*This notebook contains the code implementation for the system described above.*

In [1]:
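# Install necessary libraries
# (reportlab creates the demo PDF; llama-index-embeddings-huggingface backs the embedding fallback)
!pip install -q langgraph llama-index google-generativeai faiss-cpu gradio langdetect python-dotenv llama-index-vector-stores-faiss llama-index-embeddings-gemini llama-index-embeddings-huggingface llama-index-readers-file sentence-transformers torch reportlab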


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m981.5/981.5 kB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.5/43.5 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m148.2/148.2 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m51.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.9/46.9 MB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m322.6/322.6 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

In [2]:
import os
import json
import time
import logging
from typing import TypedDict, List, Dict, Optional, Any
import warnings
import re

In [3]:
# Import core libraries
import google.generativeai as genai
from google.api_core import exceptions as google_exceptions
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, StorageContext # ServiceContext is deprecated; Settings (imported in Cell 2) replaces it
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.gemini import GeminiEmbedding # Use Gemini for embeddings
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.indices.prompt_helper import PromptHelper
import faiss # Explicit import for FAISS index creation
import gradio as gr
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv # For API key management

In [4]:
# Suppress common warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


In [5]:
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


In [6]:
# --- Kaggle-Specific Setup & API Key ---
# Best practice: Use Kaggle Secrets for your API key.
# 1. Click "Add-ons" -> "Secrets" in the Kaggle editor.
# 2. Add a secret named "GOOGLE_API_KEY" with your actual Gemini API key.
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    GOOGLE_API_KEY = user_secrets.get_secret("GOOGLE_API_KEY")
    os.environ['GOOGLE_API_KEY'] = GOOGLE_API_KEY
    logger.info("Successfully loaded GOOGLE_API_KEY from Kaggle Secrets.")
except Exception as e: # ImportError off Kaggle, or an error if the secret is not attached to this notebook
    logger.warning(f"Kaggle Secrets not available ({e}). Attempting to load from .env file or environment variables.")
    # Fallback for local development (create a .env file with GOOGLE_API_KEY=YOUR_KEY)
    load_dotenv()
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
    if not GOOGLE_API_KEY:
        logger.error("GOOGLE_API_KEY not found in Kaggle Secrets, .env file, or environment variables. Please set it up.")
        # In a real notebook, you might raise an error or stop execution here.
        # For this example, we'll proceed but API calls will fail.
        # raise ValueError("API Key not configured.") # Uncomment to enforce key presence
    else:
        logger.info("Successfully loaded GOOGLE_API_KEY from .env file or environment.")


In [7]:
# Configure the Google Generative AI client
if GOOGLE_API_KEY:
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        logger.info("Google Generative AI configured successfully.")
    except Exception as e:
        logger.error(f"Error configuring Google Generative AI: {e}")
        # Handle configuration error (e.g., invalid key format)
else:
    logger.warning("Google Generative AI not configured due to missing API key.")



In [8]:

# --- GPU Memory Optimization Check (Kaggle T4) ---
import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # GB
    logger.info(f"GPU detected: {gpu_name}")
    logger.info(f"Total GPU Memory: {total_memory:.2f} GB")
    # Add any T4-specific optimizations here if needed later (e.g., half-precision)
    # torch.set_float32_matmul_precision('high') # Example optimization
else:
    logger.info("No GPU detected. Running on CPU.")

In [9]:
# --- Constants ---
# Gemini model used by all agents. Preview model IDs such as this one can be retired;
# if calls fail with a "model not found" error, pick a current model from genai.list_models().
LLM_MODEL_NAME = "gemini-2.5-flash-preview-04-17"
EMBEDDING_MODEL_NAME = "models/text-embedding-004" # Recommended embedding model
# API Rate Limiting (Example: Max 5 requests per minute)
API_MAX_CALLS = 5
API_TIME_WINDOW = 60  # seconds
api_call_timestamps = []


In [10]:
# Placeholder for PDF data - In a real scenario, upload your PDF(s) to Kaggle Dataset
PDF_DIR = "./pubmed_data/"
PDF_FILENAME = "pubmed_papers.pdf" # The user specified this name
PDF_FILEPATH = os.path.join(PDF_DIR, PDF_FILENAME)


In [11]:
# Create dummy directory and PDF for demonstration purposes if it doesn't exist
if not os.path.exists(PDF_DIR):
    os.makedirs(PDF_DIR)
    logger.info(f"Created directory: {PDF_DIR}")

if not os.path.exists(PDF_FILEPATH):
    try:
        # Create a simple dummy PDF content
        from reportlab.pdfgen import canvas
        from reportlab.lib.pagesizes import letter
        c = canvas.Canvas(PDF_FILEPATH, pagesize=letter)
        c.drawString(100, 750, "Sample PubMed Abstract: Malaria Diagnosis")
        c.drawString(100, 735, "Background: Early diagnosis of Plasmodium falciparum malaria is crucial.")
        c.drawString(100, 720, "Methods: Rapid diagnostic tests (RDTs) and microscopy were compared.")
        c.drawString(100, 705, "Results: RDTs showed high sensitivity (95%) and specificity (90%).")
        c.drawString(100, 690, "Conclusion: RDTs are effective tools in resource-limited settings.")
        c.save()
        logger.info(f"Created dummy PDF: {PDF_FILEPATH}")
    except ImportError:
        logger.warning("ReportLab not found. Cannot create dummy PDF. Please upload 'pubmed_papers.pdf' manually.")
        # Create a dummy text file instead if reportlab fails
        with open(PDF_FILEPATH.replace('.pdf', '.txt'), 'w') as f:
            f.write("Sample PubMed Abstract: Malaria Diagnosis\n")
            f.write("Background: Early diagnosis of Plasmodium falciparum malaria is crucial.\n")
            f.write("Methods: Rapid diagnostic tests (RDTs) and microscopy were compared.\n")
            f.write("Results: RDTs showed high sensitivity (95%) and specificity (90%).\n")
            f.write("Conclusion: RDTs are effective tools in resource-limited settings.")
        logger.info(f"Created dummy text file as fallback: {PDF_FILEPATH.replace('.pdf', '.txt')}")
        # PDF_FILEPATH is left unchanged: SimpleDirectoryReader reads .txt files natively,
        # and the data-loading cell falls back to this file whenever the PDF is missing.
        if os.path.exists(PDF_FILEPATH.replace('.pdf', '.txt')):
            logger.info("Using .txt fallback for RAG data source.")


In [12]:
# --- Helper Function for API Rate Limiting ---
def check_api_rate_limit():
    """Checks and enforces a simple time-window based rate limit."""
    global api_call_timestamps
    now = time.time()
    # Remove timestamps older than the time window
    api_call_timestamps = [t for t in api_call_timestamps if now - t < API_TIME_WINDOW]
    if len(api_call_timestamps) >= API_MAX_CALLS:
        wait_time = API_TIME_WINDOW - (now - api_call_timestamps[0])
        logger.warning(f"Rate limit reached ({API_MAX_CALLS} req/{API_TIME_WINDOW}s). Waiting for {wait_time:.2f} seconds.")
        time.sleep(wait_time + 0.1) # Add a small buffer
    api_call_timestamps.append(time.time())
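The helper above implements a sliding-window limit: keep the timestamps of recent calls and, once `API_MAX_CALLS` of them fall inside the last `API_TIME_WINDOW` seconds, sleep until the oldest one expires. The class below is an alternative sketch of the same idea, not the notebook's code. It uses a `deque`, `time.monotonic` (so wall-clock adjustments cannot distort the window), and injectable `clock`/`sleep` functions so the behaviour can be checked without real waiting; the class name is hypothetical.

```python
import time
from collections import deque
from typing import Callable, Deque


class SlidingWindowRateLimiter:
    """Allow at most `max_calls` calls within any `window_s`-second window."""

    def __init__(self, max_calls: int, window_s: float,
                 clock: Callable[[], float] = time.monotonic,
                 sleep: Callable[[float], None] = time.sleep) -> None:
        self.max_calls = max_calls
        self.window_s = window_s
        self._clock = clock
        self._sleep = sleep
        self._stamps: Deque[float] = deque()

    def _prune(self, now: float) -> None:
        # Drop timestamps that have left the window (oldest first)
        while self._stamps and now - self._stamps[0] >= self.window_s:
            self._stamps.popleft()

    def acquire(self) -> float:
        """Block until a call is allowed; return the seconds spent waiting."""
        now = self._clock()
        self._prune(now)
        waited = 0.0
        if len(self._stamps) >= self.max_calls:
            # Wait exactly until the oldest recorded call expires
            waited = self.window_s - (now - self._stamps[0])
            self._sleep(waited)
            now = self._clock()
            self._prune(now)
        self._stamps.append(now)
        return waited
```

With `max_calls=2` and a 60 s window, calls at t=0 and t=10 pass immediately and a third call at t=20 waits 40 s, until the t=0 call leaves the window.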

In [13]:
# --- Robust LLM Call Function with Error Handling & Retries ---
def generate_gemini_content_with_retry(model_name: str, prompt: str, max_retries=3, initial_delay=2):
    """
    Calls the Gemini API with exponential backoff retry mechanism.
    Handles common API errors.
    """
    if not GOOGLE_API_KEY:
        logger.error("Cannot call Gemini API: API Key not configured.")
        # Fallback: Return a predefined error message or raise an exception
        return "Error: Gemini API key not configured."
        # Or potentially switch to a local model if implemented:
        # return call_local_model(prompt)

    llm = genai.GenerativeModel(model_name)
    delay = initial_delay
    for attempt in range(max_retries):
        check_api_rate_limit() # Rate-limit every attempt, including retries
        try:
            response = llm.generate_content(prompt)
            # A blocked prompt returns no candidates. Check this before touching
            # response.parts / response.text: those accessors raise ValueError
            # (which hasattr() does not swallow) when there is nothing to return.
            if not response.candidates:
                block_reason = getattr(response.prompt_feedback, "block_reason", None)
                if block_reason:
                    logger.warning(f"Gemini API call blocked. Reason: {block_reason}")
                    return f"Error: Content generation blocked by API. Reason: {block_reason}"
                logger.warning("Gemini API call returned no candidates, unknown reason.")
                return "Error: Content generation failed (No candidates received)."

            # Join the text of every part of the first candidate
            candidate = response.candidates[0]
            text = "".join(part.text for part in candidate.content.parts if getattr(part, "text", ""))
            if text:
                return text
            logger.warning(f"Gemini API returned no text (finish_reason: {candidate.finish_reason}).")
            return "Error: Content generation failed (No text received)."

        except (google_exceptions.ResourceExhausted,
                google_exceptions.ServiceUnavailable,
                google_exceptions.DeadlineExceeded,
                google_exceptions.InternalServerError) as e:
            logger.warning(f"Gemini API error (Attempt {attempt + 1}/{max_retries}): {e}. Retrying in {delay} seconds...")
            time.sleep(delay)
            delay *= 2 # Exponential backoff
        except google_exceptions.InvalidArgument as e:
             logger.error(f"Gemini API Invalid Argument error: {e}. Prompt was likely malformed. Aborting retries.")
             return f"Error: Invalid argument passed to Gemini API. {e}"
        except google_exceptions.PermissionDenied as e:
             logger.error(f"Gemini API Permission Denied error: {e}. Check API key and permissions. Aborting retries.")
             return f"Error: API Permission Denied. {e}"
        except Exception as e:
            # Catch any other unexpected errors
            logger.error(f"An unexpected error occurred during Gemini API call (Attempt {attempt + 1}/{max_retries}): {e}")
            time.sleep(delay)
            delay *= 2
    logger.error(f"Gemini API call failed after {max_retries} attempts.")
    # Fallback strategy: Could return a canned error, or try a local model
    # return call_local_model(prompt) # If local fallback exists
    return "Error: Failed to get response from Gemini API after multiple retries."


print("\n--- Cell 1: Setup Complete ---")


--- Cell 1: Setup Complete ---
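`generate_gemini_content_with_retry` retries transient errors with exponential backoff (sleeping 2 s, 4 s, then 8 s with the defaults) and gives up immediately on `InvalidArgument` and `PermissionDenied`. The generic sketch below isolates that pattern; `call_with_retry` and `backoff_delays` are hypothetical names. In the notebook, `retryable` would be the tuple of `google_exceptions` classes caught above; built-in exceptions stand in here so the block runs on its own. Unlike the original loop, it does not sleep after the final failed attempt, and it re-raises the last error instead of returning an error string.

```python
import time
from typing import Callable, List, Tuple, Type, TypeVar

T = TypeVar("T")


def backoff_delays(count: int, initial_delay: float, factor: float = 2.0) -> List[float]:
    """Return `count` exponentially growing delays: initial_delay * factor**i."""
    return [initial_delay * factor ** i for i in range(count)]


def call_with_retry(fn: Callable[[], T],
                    retryable: Tuple[Type[BaseException], ...],
                    max_retries: int = 3,
                    initial_delay: float = 2.0,
                    sleep: Callable[[float], None] = time.sleep) -> T:
    """Call `fn`, retrying only `retryable` exceptions with exponential backoff.

    Other exceptions propagate immediately. After the last attempt fails,
    the final retryable exception is re-raised instead of being swallowed.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    # One delay between each pair of consecutive attempts
    delays = backoff_delays(max_retries - 1, initial_delay)
    for attempt in range(max_retries):
        try:
            return fn()
        except retryable:
            if attempt == max_retries - 1:
                raise  # Out of attempts: surface the real error
            sleep(delays[attempt])
    raise AssertionError("unreachable")
```

Injecting `sleep` keeps the backoff testable; in production the default `time.sleep` is used.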


In [14]:
# NOTEBOOK CELL 2: RAG System (LlamaIndex + FAISS + Gemini Embeddings)
# =====================================================================
# This cell sets up the Retrieval Augmented Generation system.
# It loads PubMed data (dummy PDF for now), creates embeddings using
# Gemini, stores them in a FAISS vector store, and prepares a query engine.

from llama_index.core import Settings # Use Settings for global configuration

logger.info("Starting RAG System Setup...")

# --- 1. Configure Embedding Model ---
try:
    # Using the specified Gemini embedding model
    embed_model = GeminiEmbedding(model_name=EMBEDDING_MODEL_NAME, api_key=GOOGLE_API_KEY)
    Settings.embed_model = embed_model # Set globally via Settings
    logger.info(f"Gemini Embedding Model ({EMBEDDING_MODEL_NAME}) initialized.")
except Exception as e:
    logger.error(f"Failed to initialize Gemini Embedding Model: {e}")
    logger.warning("Falling back to HuggingFace embeddings (requires llama-index-embeddings-huggingface and sentence-transformers).")
    # Fallback embedding model if Gemini Embedding fails
    try:
        # Import inside the try so a missing package degrades gracefully instead of crashing the cell
        from llama_index.embeddings.huggingface import HuggingFaceEmbedding
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
        Settings.embed_model = embed_model
        logger.info("Using fallback HuggingFace Embedding Model.")
    except Exception as hf_error:
        logger.critical(f"Failed to initialize fallback HuggingFace Embedding Model: {hf_error}")
        # If embeddings fail completely, RAG cannot function.
        embed_model = None # Indicate failure

In [15]:
# --- 2. Load Data ---
documents = None
if os.path.exists(PDF_FILEPATH) or os.path.exists(PDF_FILEPATH.replace('.pdf', '.txt')):
    try:
        # Use SimpleDirectoryReader, more robust for different file types in a directory
        # Use the .txt fallback if the PDF couldn't be created
        actual_file_path = PDF_FILEPATH if os.path.exists(PDF_FILEPATH) else PDF_FILEPATH.replace('.pdf', '.txt')
        reader = SimpleDirectoryReader(input_files=[actual_file_path])
        documents = reader.load_data()
        if documents:
            logger.info(f"Successfully loaded {len(documents)} document(s) from {actual_file_path}.")
            # Log sample content
            logger.info(f"Sample document content (first 100 chars): {documents[0].text[:100]}...")
        else:
            logger.error(f"No documents were loaded from {actual_file_path}, though the file exists.")
    except Exception as e:
        logger.error(f"Error loading data from {PDF_FILEPATH} (or fallback): {e}")
else:
    logger.error(f"Data file not found: {PDF_FILEPATH}. Cannot proceed with RAG setup.")


In [16]:
# --- 3. Setup FAISS Vector Store ---
index = None
query_engine = None
if documents and embed_model: # Proceed only if documents and embeddings are ready
    try:
        logger.info("Setting up FAISS Vector Store...")
        # Determine the embedding dimension by embedding a short probe string.
        # Embedding classes do not reliably expose an `embed_dim` attribute, and a
        # hard-coded default breaks the fallback model (bge-small-en-v1.5 outputs
        # 384 dimensions, text-embedding-004 outputs 768).
        try:
            d = len(Settings.embed_model.get_text_embedding("dimension probe"))
            logger.info(f"Embedding dimension detected: {d}")
        except Exception as dim_error:
            d = 768 # Fallback dimension; must match your embedding model's output size
            logger.warning(f"Could not determine embedding dimension ({dim_error}). Using default: {d}")

        faiss_index = faiss.IndexFlatL2(d) # Use L2 distance for similarity

        # Create the FaissVectorStore instance
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        logger.info("FAISS index created.")

        # Define storage context (where the vector store lives)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        logger.info("Storage context created.")

        # Configure chunking via the global Settings object (the successor to the
        # deprecated ServiceContext). Settings.chunk_size / chunk_overlap drive the
        # default SentenceSplitter node parser used during indexing.
        Settings.chunk_size = 512 # Adjust chunk size based on model context window and content
        Settings.chunk_overlap = 50
        # Explicitly disable the LlamaIndex LLM (it resolves to a MockLLM). Otherwise
        # building a query engine falls back to the default OpenAI LLM and fails when
        # no OpenAI key is set. Diagnoses are generated by Gemini in the agents instead.
        Settings.llm = None

        logger.info(f"Using chunk size: {Settings.chunk_size}, overlap: {Settings.chunk_overlap}")

        # --- 4. Create Index ---
        logger.info("Creating vector store index from documents...")
        index = VectorStoreIndex.from_documents(documents, storage_context=storage_context, show_progress=True)

        logger.info("Vector Store Index created successfully.")

        # --- 5. Create Query Engine ---
        # similarity_top_k: how many relevant chunks to retrieve.
        # response_mode="no_text" skips answer synthesis; the agents only read
        # rag_response.source_nodes.
        query_engine = index.as_query_engine(similarity_top_k=3, response_mode="no_text")
        logger.info("RAG Query Engine created.")

    except Exception as e:
        logger.error(f"Failed during FAISS/Indexing setup: {e}", exc_info=True) # Log traceback
        if index is None:
            logger.error("Index creation failed. RAG will not be available.")
        # Graceful degradation: The system might continue without RAG,
        # but agents needing it should handle its absence.

else:
    logger.error("Cannot proceed with RAG setup: Documents or Embedding Model not available.")


print("\n--- Cell 2: RAG System Setup Complete ---")
# You can test the query engine here (optional):
# if query_engine:
#     try:
#         test_response = query_engine.query("What are the symptoms of Malaria?")
#         logger.info(f"RAG Test Query Response: {test_response}")
#     except Exception as e:
#         logger.error(f"Error testing RAG query engine: {e}")



--- Cell 2: RAG System Setup Complete ---
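`faiss.IndexFlatL2` performs exact, exhaustive search: every stored chunk embedding is compared with the query embedding, and the `similarity_top_k=3` closest by squared Euclidean distance are returned. The pure-Python function below is only an illustration of that computation, not FAISS itself, and `l2_top_k` is a hypothetical name. Its dimension check shows why `d` must match the embedding model's output size.

```python
import heapq
from typing import List, Sequence, Tuple


def l2_top_k(query: Sequence[float],
             vectors: Sequence[Sequence[float]],
             k: int = 3) -> List[Tuple[int, float]]:
    """Exhaustive nearest-neighbour search by squared L2 distance.

    Returns (row_id, squared_distance) pairs, closest first. Like IndexFlatL2,
    it compares the query with every stored vector and reports squared
    (not square-rooted) distances.
    """
    if any(len(v) != len(query) for v in vectors):
        raise ValueError("all vectors must share the query's dimension")
    dists = ((i, sum((q - x) ** 2 for q, x in zip(query, v)))
             for i, v in enumerate(vectors))
    # Keep only the k smallest distances without sorting everything
    return heapq.nsmallest(k, dists, key=lambda pair: pair[1])
```

Flat search is exact but scales linearly with the number of chunks, which is a reasonable trade-off for a small demonstration corpus like the dummy abstract used here.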


In [17]:
# NOTEBOOK CELL 3: LangGraph Agents & Workflow
# ============================================
# This cell defines the state, agents (starting with Diagnostician),
# and the graph structure for the multi-agent workflow.



logger.info("Setting up LangGraph Agents and Workflow...")

# --- 1. Define Agent State ---
# This TypedDict defines the data structure that flows through the graph.
class AgentState(TypedDict):
    original_input: Any # Could be text or audio data path
    input_language: Optional[str] # Detected language
    symptoms_text: str # Processed symptoms text
    rag_context: Optional[List[str]] # Retrieved context from PubMed
    initial_diagnosis: Optional[Dict[str, Any]] # Output from Diagnostician
    validation_results: Optional[Dict[str, Any]] # Output from Validator
    final_diagnosis_report: Optional[Dict[str, Any]] # Final structured output
    patient_education: Optional[Dict[str, Any]] # Output from Educator
    bias_analysis: Optional[Dict[str, Any]] # Output from Bias Checker
    error_message: Optional[str] # To capture errors during the flow

# --- 2. Initialize LLM for Agents ---
# We'll use the robust function defined earlier for API calls
agent_llm = None
if GOOGLE_API_KEY:
    try:
        # Test basic connectivity to the model specified
        agent_llm = genai.GenerativeModel(LLM_MODEL_NAME)
        # Perform a small test generation to check API key validity etc.
        agent_llm.generate_content("test", generation_config=genai.types.GenerationConfig(candidate_count=1))
        logger.info(f"Agent LLM ({LLM_MODEL_NAME}) initialized and tested successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize or test Agent LLM ({LLM_MODEL_NAME}): {e}")
        agent_llm = None # Ensure it's None if initialization fails
else:
    logger.error("Agent LLM cannot be initialized: API Key not configured.")




# --- 3. Define Agent Nodes ---

# == Diagnostician Agent Node ==
def diagnostician_node(state: AgentState) -> AgentState:
    """
    Generates initial differential diagnosis based on symptoms and RAG context.
    Handles potential errors during LLM calls and RAG queries.
    """
    logger.info("Entering Diagnostician Node...")
    symptoms = state.get("symptoms_text")
    if not symptoms:
        logger.error("Diagnostician Error: No symptoms provided in state.")
        return {**state, "error_message": "Diagnostician failed: Symptoms missing."}

    rag_context_str = ""
    # --- RAG Integration ---
    if query_engine: # Check if RAG system is available
        try:
            logger.info(f"Querying RAG with symptoms: {symptoms[:100]}...")
            rag_response = query_engine.query(symptoms)
            # Extract context from response nodes
            retrieved_docs = [node.get_content() for node in rag_response.source_nodes]
            state["rag_context"] = retrieved_docs # Store retrieved docs in state
            rag_context_str = "\n\nRelevant Medical Context:\n" + "\n---\n".join(retrieved_docs)
            logger.info(f"Retrieved {len(retrieved_docs)} context snippets from RAG.")
        except Exception as e:
            logger.error(f"RAG Query Error in Diagnostician: {e}", exc_info=True)
            rag_context_str = "\n\nRelevant Medical Context: [Error retrieving context]"
            state["rag_context"] = ["[Error retrieving context]"] # Update state with error
    else:
        logger.warning("Diagnostician Warning: RAG query engine not available. Proceeding without PubMed context.")
        rag_context_str = "\n\nRelevant Medical Context: [Not Available]"
        state["rag_context"] = ["[Not Available]"]

    # --- LLM Call for Diagnosis ---
    if not agent_llm:
         logger.error("Diagnostician Error: Agent LLM not available.")
         return {**state, "error_message": "Diagnostician failed: LLM not initialized."}

    prompt = f"""Act as a medical diagnosis assistant. Based ONLY on the provided symptoms and relevant medical context (if available), generate a differential diagnosis.

Patient Symptoms:
{symptoms}
{rag_context_str}

Instructions:
1. Analyze the symptoms and context.
2. Generate a list of possible diagnoses (differentials).
3. For each diagnosis, provide a confidence score (0.0 to 1.0) indicating your certainty based *only* on the provided information. Higher scores mean higher likelihood.
4. Identify the most likely primary diagnosis.
5. Structure your output as a JSON object with the following EXACT keys: "primary_diagnosis", "primary_confidence", "alternative_diagnoses" (which should be a list of strings).

Example Output Format:
{{
  "primary_diagnosis": "Example Condition A",
  "primary_confidence": 0.85,
  "alternative_diagnoses": ["Example Condition B", "Example Condition C"]
}}

Provide ONLY the JSON object in your response. Do not include any introductory text, explanations, or markdown formatting around the JSON.
"""
    logger.info("Generating initial diagnosis with Gemini...")
    llm_response_text = generate_gemini_content_with_retry(LLM_MODEL_NAME, prompt)
    # --- Process LLM Response ---
    if llm_response_text and llm_response_text.startswith("Error:"):
        logger.error(f"Diagnostician Error: LLM call failed. {llm_response_text}")
        return {**state, "error_message": f"Diagnostician LLM Error: {llm_response_text}"}

    # Robust JSON parsing
    diagnosis_json = None
    try:
        # Try to find JSON within potentially messy output
        json_match = re.search(r'\{.*\}', llm_response_text, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            diagnosis_json = json.loads(json_str)
            logger.info(f"Successfully parsed JSON diagnosis: {diagnosis_json}")

            # Validate required keys
            required_keys = ["primary_diagnosis", "primary_confidence", "alternative_diagnoses"]
            if not all(key in diagnosis_json for key in required_keys):
                raise ValueError(f"Parsed JSON missing required keys: {required_keys}")
            if not isinstance(diagnosis_json["alternative_diagnoses"], list):
                raise ValueError("Parsed JSON 'alternative_diagnoses' is not a list.")
            if not isinstance(diagnosis_json["primary_confidence"], (float, int)):
                 raise ValueError("Parsed JSON 'primary_confidence' is not a number.")

        else:
            logger.error(f"Diagnostician Error: Could not find valid JSON in LLM response: {llm_response_text}")
            return {**state, "error_message": "Diagnostician failed: Invalid JSON response from LLM."}

    except json.JSONDecodeError as e:
        logger.error(f"Diagnostician JSON Parsing Error: {e}. Response was: {llm_response_text}")
        return {**state, "error_message": f"Diagnostician failed: JSON Decode Error - {e}"}
    except ValueError as e:
         logger.error(f"Diagnostician JSON Validation Error: {e}. Parsed JSON: {diagnosis_json}")
         return {**state, "error_message": f"Diagnostician failed: JSON Validation Error - {e}"}
    except Exception as e:
        logger.error(f"Diagnostician Error processing LLM response: {e}", exc_info=True)
        return {**state, "error_message": f"Diagnostician failed: Unexpected error processing response - {e}"}

    # Update state
    logger.info("Diagnostician Node completed.")
    return {**state, "initial_diagnosis": diagnosis_json, "error_message": None} # Clear previous errors if successful





# == Validator Agent Node (Implementation) ==
def validator_node(state: AgentState) -> AgentState:
    """
    Critiques the initial diagnosis based on symptoms and context,
    simulating guideline cross-checking using the LLM.
    """
    logger.info("Entering Validator Node...")
    initial_diagnosis = state.get("initial_diagnosis")
    symptoms = state.get("symptoms_text")
    rag_context = state.get("rag_context", [])

    if not initial_diagnosis or not symptoms:
         logger.warning("Validator skipping: Initial diagnosis or symptoms missing.")
         # Pass through state, maybe add a note to validation results
         return {**state, "validation_results": {"status": "Skipped", "reason": "Missing diagnosis or symptoms."}}

    if not agent_llm:
         logger.error("Validator Error: Agent LLM not available.")
         return {**state, "error_message": "Validator failed: LLM not initialized."}

    primary_diag = initial_diagnosis.get("primary_diagnosis", "N/A")
    confidence = initial_diagnosis.get("primary_confidence", "N/A")
    alternatives = initial_diagnosis.get("alternative_diagnoses", [])
    rag_context_str = "\n---\n".join(rag_context) if rag_context else "[Not Available]"

    prompt = f"""Act as a clinical reviewer simulating a check against established medical guidelines (like NICE, but using general medical knowledge).
You are given an initial diagnosis generated by another AI based on patient symptoms and some retrieved medical context.

Patient Symptoms:
{symptoms}

Retrieved Medical Context (from PubMed abstracts):
{rag_context_str}

Initial AI Diagnosis:
Primary: {primary_diag} (Confidence: {confidence})
Alternatives: {', '.join(alternatives) if alternatives else 'None'}

Your Task:
Critically evaluate the initial diagnosis based *only* on the provided symptoms and context.
1. Does the primary diagnosis seem reasonable given the symptoms and context?
2. Are there any obvious contradictions or inconsistencies?
3. Are there other highly probable diagnoses based on the provided info that were missed in the alternatives?
4. Based on your critique, would you tentatively 'Confirm', 'Flag for Review', or 'Suggest Revision' for the primary diagnosis?

Provide your output as a JSON object with the following keys:
- "validation_status": (string, one of "Confirmed", "Flagged for Review", "Revision Suggested")
- "critique": (string, your reasoning and evaluation based on the questions above)
- "missed_alternatives": (list of strings, other possible diagnoses you identified, if any)

Example Output Format:
{{
  "validation_status": "Flagged for Review",
  "critique": "The primary diagnosis of 'X' seems plausible, but the provided context mentions symptom 'Y' which strongly points towards 'Z'. Confidence score seems high given the limited context.",
  "missed_alternatives": ["Condition Z", "Condition W"]
}}

Provide ONLY the JSON object in your response.
"""
    logger.info(f"Performing validation critique for: {primary_diag}")
    llm_response_text = generate_gemini_content_with_retry(LLM_MODEL_NAME, prompt)

    # --- Process LLM Response ---
    if llm_response_text and llm_response_text.startswith("Error:"):
        logger.error(f"Validator Error: LLM call failed. {llm_response_text}")
        return {**state, "error_message": f"Validator LLM Error: {llm_response_text}"}

    validation_json = None
    try:
        json_match = re.search(r'\{.*\}', llm_response_text, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            validation_json = json.loads(json_str)
            logger.info(f"Successfully parsed JSON validation: {validation_json}")
            # Basic validation of structure
            if not all(k in validation_json for k in ["validation_status", "critique", "missed_alternatives"]):
                raise ValueError("Validation JSON missing required keys.")
            if not isinstance(validation_json["missed_alternatives"], list):
                raise ValueError("'missed_alternatives' must be a list.")
        else:
            logger.error(f"Validator Error: Could not find valid JSON in LLM response: {llm_response_text}")
            # Still proceed, but mark validation as failed in the state
            validation_json = {"status": "Failed", "reason": "Invalid JSON response from LLM critique.", "critique": "", "missed_alternatives": []}

    except json.JSONDecodeError as e:
        logger.error(f"Validator JSON Parsing Error: {e}. Response was: {llm_response_text}")
        validation_json = {"status": "Failed", "reason": f"JSON Decode Error - {e}", "critique": "", "missed_alternatives": []}
    except ValueError as e:
        logger.error(f"Validator JSON Validation Error: {e}. Parsed JSON: {validation_json}")
        validation_json = {"status": "Failed", "reason": f"JSON Validation Error - {e}", "critique": "", "missed_alternatives": []}
    except Exception as e:
        logger.error(f"Validator Error processing LLM response: {e}", exc_info=True)
        validation_json = {"status": "Failed", "reason": f"Unexpected error - {e}", "critique": "", "missed_alternatives": []}

    logger.info(f"Validator Node completed. Status: {validation_json.get('validation_status', 'Unknown')}")
    # Even if parsing failed, the failure reason is stored in validation_results.
    # error_message is left as-is so that errors from upstream nodes are not lost.
    return {**state, "validation_results": validation_json}
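The validator, educator, and bias-check nodes all repeat the same steps: pull the first `{...}` span out of the LLM reply, `json.loads` it, then check required keys and list-typed fields. A minimal sketch of that pattern as one reusable function follows. `parse_llm_json` is a hypothetical helper, not something the notebook defines, and its failure dicts only imitate the `{"status": "Failed", "reason": ...}` convention used by the nodes.

```python
import json
import re
from typing import Any, Dict, Iterable, Optional


# Hypothetical helper, not part of the original notebook.
def parse_llm_json(
    response_text: Optional[str],
    required_keys: Iterable[str],
    list_keys: Iterable[str] = (),
) -> Dict[str, Any]:
    """Extract and sanity-check a JSON object from an LLM reply.

    Returns the parsed dict on success, or {"status": "Failed", "reason": ...}
    on any problem, mirroring the failure dicts used by the agent nodes.
    """
    if not response_text:
        return {"status": "Failed", "reason": "Empty LLM response."}

    # Same greedy pattern as the nodes: first '{' through last '}'.
    match = re.search(r"\{.*\}", response_text, re.DOTALL)
    if not match:
        return {"status": "Failed", "reason": "No JSON object found in LLM response."}

    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError as e:
        return {"status": "Failed", "reason": f"JSON Decode Error - {e}"}

    missing = [key for key in required_keys if key not in parsed]
    if missing:
        return {"status": "Failed", "reason": f"Missing required keys: {missing}"}

    # list_keys is expected to be a subset of required_keys.
    for key in list_keys:
        if not isinstance(parsed.get(key), list):
            return {"status": "Failed", "reason": f"'{key}' must be a list."}

    return parsed
```

With a helper like this, the validator's `try`/`except` block could shrink to `validation_json = parse_llm_json(llm_response_text, ["validation_status", "critique", "missed_alternatives"], ["missed_alternatives"])`. One caveat applies equally to the original regex: the greedy `\{.*\}` match runs from the first `{` to the last `}`, so a reply that contains two separate JSON objects, or stray braces in trailing prose, will fail to decode.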


# == Educator Agent Node (Implementation) ==
def educator_node(state: AgentState) -> AgentState:
    """
    Generates patient education materials (explanation, medication info from context,
    next steps) based on the diagnosis. Visuals remain placeholders.
    """
    logger.info("Entering Educator Node...")
    # For now the educator always explains the initial diagnosis. A stricter
    # version could consult validation_results and switch diagnoses when the
    # validator suggests a revision.
    diagnosis_info = state.get("initial_diagnosis")
    rag_context = state.get("rag_context", [])

    if not diagnosis_info or not diagnosis_info.get("primary_diagnosis"):
        logger.warning("Educator skipping: No primary diagnosis found.")
        return {**state, "patient_education": {"status": "Skipped", "reason": "Missing diagnosis."}}

    if not agent_llm:
        logger.error("Educator Error: Agent LLM not available.")
        return {**state, "error_message": "Educator failed: LLM not initialized."}

    primary_diag = diagnosis_info.get("primary_diagnosis")
    rag_context_str = "\n---\n".join(rag_context) if rag_context else "[Not Available]"

    prompt = f"""Act as a patient educator AI. You are given a medical diagnosis and relevant context.

Diagnosis: {primary_diag}

Relevant Medical Context (from PubMed abstracts):
{rag_context_str}

Your Task: Generate patient education material based *only* on the provided diagnosis and context.
1.  **Explanation:** Provide a simple, patient-friendly explanation of what '{primary_diag}' is (approx. 2-3 sentences). Avoid jargon.
2.  **Medication Info:** Scan the 'Relevant Medical Context'. If specific medications for treating '{primary_diag}' are mentioned, list them. If not, state "Consult your physician for medication options." Do NOT invent medications.
3.  **Next Steps/Lifestyle:** Suggest 2-3 general, safe next steps or lifestyle considerations relevant to this type of condition (e.g., follow-up appointments, rest, hydration, seeking professional advice for specifics). Emphasize consulting a healthcare professional.
4.  **Visual Placeholder:** Generate a descriptive filename for a hypothetical explanatory visual (e.g., 'Animation_showing_{primary_diag.replace(' ','_')}.mp4').

Provide your output as a JSON object with the following keys:
- "explanation": (string) Patient-friendly explanation.
- "medication_info": (string) Mentioned medications or consultation advice.
- "next_steps": (list of strings) General advice points.
- "visual_placeholder_filename": (string) Generated filename for the visual.

Example Output Format:
{{
  "explanation": "Malaria is an illness caused by a parasite transmitted through mosquito bites. It often causes fever and flu-like symptoms.",
  "medication_info": "The context mentions Artemether-Lumefantrine as a possible treatment. Consult your physician for medication options.",
  "next_steps": ["Follow your doctor's treatment plan carefully.", "Rest well and stay hydrated.", "Prevent future mosquito bites."],
  "visual_placeholder_filename": "Animation_showing_Malaria_lifecycle.mp4"
}}

Provide ONLY the JSON object in your response.
"""
    logger.info(f"Generating patient education for: {primary_diag}")
    llm_response_text = generate_gemini_content_with_retry(LLM_MODEL_NAME, prompt)

    # --- Process LLM Response ---
    if llm_response_text and llm_response_text.startswith("Error:"):
        logger.error(f"Educator Error: LLM call failed. {llm_response_text}")
        return {**state, "error_message": f"Educator LLM Error: {llm_response_text}"}

    education_json = None
    try:
        json_match = re.search(r'\{.*\}', llm_response_text, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            education_json = json.loads(json_str)
            logger.info(f"Successfully parsed JSON education material: {education_json}")
            # Basic validation
            if not all(k in education_json for k in ["explanation", "medication_info", "next_steps", "visual_placeholder_filename"]):
                raise ValueError("Education JSON missing required keys.")
            if not isinstance(education_json["next_steps"], list):
                raise ValueError("'next_steps' must be a list.")
        else:
            logger.error(f"Educator Error: Could not find valid JSON in LLM response: {llm_response_text}")
            education_json = {"status": "Failed", "reason": "Invalid JSON response from educator LLM."}

    except json.JSONDecodeError as e:
        logger.error(f"Educator JSON Parsing Error: {e}. Response was: {llm_response_text}")
        education_json = {"status": "Failed", "reason": f"JSON Decode Error - {e}"}
    except ValueError as e:
        logger.error(f"Educator JSON Validation Error: {e}. Parsed JSON: {education_json}")
        education_json = {"status": "Failed", "reason": f"JSON Validation Error - {e}"}
    except Exception as e:
        logger.error(f"Educator Error processing LLM response: {e}", exc_info=True)
        education_json = {"status": "Failed", "reason": f"Unexpected error - {e}"}

    logger.info("Educator Node completed.")
    # Store the whole dictionary. Parse failures are recorded inside it, so leave
    # any error_message set by an upstream node untouched (clearing it here would
    # hide, for example, a validator LLM error from the final report).
    return {**state, "patient_education": education_json}
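The educator always explains `initial_diagnosis`, and its comment notes that it could consult `validation_results` instead. One possible policy, sketched below as the hypothetical `select_diagnosis_for_education`, keeps the initial diagnosis unless the validator returned `"Revision Suggested"` and named at least one missed alternative. This is an illustrative assumption, not clinical guidance; a real system would more likely present both diagnoses than silently swap one for the other.

```python
from typing import Any, Dict, Optional


# Hypothetical policy helper, not part of the original notebook.
def select_diagnosis_for_education(state: Dict[str, Any]) -> Optional[str]:
    """Choose which diagnosis label the educator should explain."""
    # `or {}` also covers keys that are present but set to None.
    initial = state.get("initial_diagnosis") or {}
    validation = state.get("validation_results") or {}

    primary = initial.get("primary_diagnosis")
    if not primary:
        return None  # Nothing to explain; the educator would skip.

    if validation.get("validation_status") == "Revision Suggested":
        # Assumption: defer to the validator's first missed alternative, if any.
        missed = validation.get("missed_alternatives") or []
        if missed:
            return missed[0]

    # Confirmed, flagged for review, failed, or skipped validation:
    # keep the initial diagnosis.
    return primary
```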



# == Bias Check Node (Implementation) ==
def bias_check_node(state: AgentState) -> AgentState:
    """
    Analyzes the diagnosis process for potential demographic biases using the LLM.
    """
    logger.info("Entering Bias Check Node...")
    initial_diagnosis = state.get("initial_diagnosis")
    symptoms = state.get("symptoms_text")
    # Optionally include rag_context if relevant for bias check
    # rag_context = state.get("rag_context", [])

    # Basic check if there's anything to analyze
    if not initial_diagnosis or not symptoms:
        logger.warning("Bias Check skipping: Initial diagnosis or symptoms missing.")
        return {**state, "bias_analysis": {"status": "Skipped", "reason": "Missing diagnosis or symptoms."}}

    if not agent_llm:
        logger.error("Bias Check Error: Agent LLM not available.")
        return {**state, "error_message": "Bias Check failed: LLM not initialized."}

    # Prepare the input for the bias check prompt
    diagnosis_summary = f"Primary: {initial_diagnosis.get('primary_diagnosis', 'N/A')}, Confidence: {initial_diagnosis.get('primary_confidence', 'N/A')}, Alternatives: {initial_diagnosis.get('alternative_diagnoses', [])}"

    # Using the exact prompt structure requested by the user
    prompt = f"""Analyze the following diagnosis information for potential biases. Focus specifically on:
1.  **Gender/racial stereotypes:** Does the diagnosis or the way it might have been reached rely on assumptions about specific genders or races?
2.  **Socioeconomic assumptions:** Do the potential diagnostic path or the suggested alternatives implicitly assume a certain socioeconomic status (e.g., access to specific tests, lifestyle factors)?
3.  **Cultural competency:** Could the symptom presentation or its interpretation be influenced by cultural factors not accounted for? Are there potential cultural adaptations needed for communication or treatment?

Patient Symptoms:
{symptoms}

AI-Generated Diagnosis Summary:
{diagnosis_summary}

Instructions:
- Critically evaluate based on the three points above.
- Provide a qualitative assessment. Note specific concerns if any.
- Suggest potential cultural adaptations if relevant (e.g., language considerations, culturally sensitive explanations).
- Assign a hypothetical bias risk score from 0.0 (very low risk) to 1.0 (high risk detected). This is subjective based on your analysis.
- Structure your output as a JSON object with keys: "bias_risk_score" (float), "potential_biases_identified" (list of strings describing concerns), "suggested_cultural_adaptations" (list of strings).

Example Output Format:
{{
  "bias_risk_score": 0.2,
  "potential_biases_identified": ["Symptom description might be interpreted differently across cultures.", "Consider if access to mentioned diagnostic tests is universal."],
  "suggested_cultural_adaptations": ["Provide explanation in local language if possible.", "Explore culturally relevant analogies for the condition."]
}}

Provide ONLY the JSON object in your response.
"""
    logger.info("Performing bias analysis...")
    llm_response_text = generate_gemini_content_with_retry(LLM_MODEL_NAME, prompt)

    # --- Process LLM Response ---
    if llm_response_text and llm_response_text.startswith("Error:"):
        logger.error(f"Bias Check Error: LLM call failed. {llm_response_text}")
        return {**state, "error_message": f"Bias Check LLM Error: {llm_response_text}"}

    bias_json = None
    try:
        json_match = re.search(r'\{.*\}', llm_response_text, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            bias_json = json.loads(json_str)
            logger.info(f"Successfully parsed JSON bias analysis: {bias_json}")
            # Basic validation
            if not all(k in bias_json for k in ["bias_risk_score", "potential_biases_identified", "suggested_cultural_adaptations"]):
                raise ValueError("Bias analysis JSON missing required keys.")
            if not isinstance(bias_json["bias_risk_score"], (float, int)):
                raise ValueError("'bias_risk_score' must be a number.")
            if not isinstance(bias_json["potential_biases_identified"], list):
                raise ValueError("'potential_biases_identified' must be a list.")
            if not isinstance(bias_json["suggested_cultural_adaptations"], list):
                raise ValueError("'suggested_cultural_adaptations' must be a list.")
        else:
            logger.error(f"Bias Check Error: Could not find valid JSON in LLM response: {llm_response_text}")
            bias_json = {"status": "Failed", "reason": "Invalid JSON response from bias check LLM."}

    except json.JSONDecodeError as e:
        logger.error(f"Bias Check JSON Parsing Error: {e}. Response was: {llm_response_text}")
        bias_json = {"status": "Failed", "reason": f"JSON Decode Error - {e}"}
    except ValueError as e:
        logger.error(f"Bias Check JSON Validation Error: {e}. Parsed JSON: {bias_json}")
        bias_json = {"status": "Failed", "reason": f"JSON Validation Error - {e}"}
    except Exception as e:
        logger.error(f"Bias Check Error processing LLM response: {e}", exc_info=True)
        bias_json = {"status": "Failed", "reason": f"Unexpected error - {e}"}

    logger.info("Bias Check Node completed.")
    # Store the whole dictionary (including any failure info). Parse failures are
    # recorded here, so leave any error_message set by an upstream node untouched.
    return {**state, "bias_analysis": bias_json}
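The bias-check prompt asks for a score between 0.0 and 1.0, but the node only checks that `bias_risk_score` is a number. As a result `1.5` or even `true` (Python's `bool` is a subclass of `int`) pass the check, while a numeric string such as `"0.3"` makes the whole analysis fail. A small normalization step could run after parsing; it is sketched here as the hypothetical `normalize_bias_score`. Clamping out-of-range values, rather than rejecting them, is a design choice assumed for this sketch.

```python
import math
from typing import Any, Optional


# Hypothetical helper, not part of the original notebook.
def normalize_bias_score(raw: Any) -> Optional[float]:
    """Coerce an LLM-reported bias score into [0.0, 1.0]; None if unusable."""
    # bool is a subclass of int, so reject it before the numeric check.
    if isinstance(raw, bool):
        return None
    # Accept numeric strings such as "0.3".
    if isinstance(raw, str):
        try:
            raw = float(raw.strip())
        except ValueError:
            return None
    if not isinstance(raw, (int, float)):
        return None
    score = float(raw)
    # json.loads accepts NaN by default, so guard against it explicitly.
    if math.isnan(score):
        return None
    # Clamp rather than reject out-of-range values (a design assumption).
    return min(1.0, max(0.0, score))
```

Inside `bias_check_node`, the `isinstance(..., (float, int))` check could then become `score = normalize_bias_score(bias_json["bias_risk_score"])`, raising `ValueError` when it returns `None` and storing the normalized value otherwise.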


# == Output Formatting Node (Updated) ==
def format_output_node(state: AgentState) -> AgentState:
    """
    Consolidates information from the implemented nodes into the final JSON report.
    """
    logger.info("Entering Output Formatting Node...")

    # Extract each node's output. Use `or {}` rather than a .get() default: the
    # initial state sets these keys to None explicitly, and dict.get only falls
    # back to its default when the key is absent.
    initial_diag = state.get("initial_diagnosis") or {}
    validation = state.get("validation_results") or {}  # Validator output
    education = state.get("patient_education") or {}    # Educator output
    bias_info = state.get("bias_analysis") or {}        # Bias check output

    # Determine primary diagnosis and confidence (could adjust based on validation later)
    primary_diagnosis = initial_diag.get("primary_diagnosis", "N/A")
    confidence = initial_diag.get("primary_confidence", 0.0)
    alternatives = initial_diag.get("alternative_diagnoses", [])

    # Build Diagnosis part
    diagnosis_part = {
        "primary": primary_diagnosis,
        "confidence": float(confidence) if isinstance(confidence, (int, float)) else 0.0,
        "alternatives": alternatives if isinstance(alternatives, list) else [],
        # Add validation status if available
        "validation_status": validation.get("validation_status", "Pending/Skipped")
    }

    # Build Education part using educator output
    education_part = {
        "visual": education.get("visual_placeholder_filename", "visual_pending.mp4"),
        "explanation": education.get("explanation", "Explanation pending."),
        "medication": education.get("medication_info", "Medication info pending."),
        "next_steps": education.get("next_steps", ["Next steps pending."])
    }
    # Check if educator failed
    if education.get("status") == "Failed":
        education_part["status"] = "Failed: " + education.get("reason", "Unknown")


    # Build Equity Check part using bias check output
    equity_part = {
        "bias_score": bias_info.get("bias_risk_score", -1.0), # -1 indicates pending/failed
        "potential_biases": bias_info.get("potential_biases_identified", ["Pending analysis"]),
        "cultural_adaptations": bias_info.get("suggested_cultural_adaptations", ["Pending analysis"])
    }
    # Check if bias check failed
    if bias_info.get("status") == "Failed":
        equity_part["status"] = "Failed: " + bias_info.get("reason", "Unknown")


    # Assemble the final report
    final_report = {
        "patient_id": f"ANON-{int(time.time()) % 10000}",
        "diagnosis": diagnosis_part,
        "education": education_part,
        "equity_check": equity_part,
        "debug_info": {
            # rag_context may still be None if retrieval never ran, so guard len().
            "rag_context_snippets_count": len(state.get("rag_context") or []),
            "validator_critique": validation.get("critique", "N/A")  # Include critique
        }
    }

    # Handle workflow errors
    error = state.get("error_message")
    if error:
        logger.error(f"Workflow ended with error: {error}")
        final_report["workflow_status"] = "Error"
        final_report["error_details"] = error
    elif any(node_output.get("status") == "Failed" for node_output in [validation, education, bias_info]):
        final_report["workflow_status"] = "Completed with Errors in Nodes"
    else:
        final_report["workflow_status"] = "Success"


    logger.info("Final report generated.")
    # Store the report in the state under "final_diagnosis_report"; app.invoke()
    # returns the final state, so callers read the report from that key.
    return {**state, "final_diagnosis_report": final_report}
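Because the test cells initialise `validation_results`, `patient_education`, and `bias_analysis` to `None`, reading them with `state.get(key, {})` is not enough: `dict.get` only falls back to its default when the key is absent, so a node that returned early leaves `None` behind, and the next `.get(...)` on it raises `AttributeError`. The sketch below isolates a None-safe accessor and the three-way status rule used at the end of `format_output_node`. Both function names are hypothetical.

```python
from typing import Any, Dict, Mapping

# State keys written by the validator, educator, and bias-check nodes.
NODE_SECTIONS = ("validation_results", "patient_education", "bias_analysis")


# Hypothetical helpers, not part of the original notebook.
def as_section(state: Mapping[str, Any], key: str) -> Dict[str, Any]:
    """Return state[key] if it is a dict; otherwise (missing or None) an empty dict."""
    value = state.get(key)
    return value if isinstance(value, dict) else {}


def derive_workflow_status(state: Mapping[str, Any]) -> str:
    """Same three-way rule as the end of format_output_node."""
    if state.get("error_message"):
        return "Error"
    if any(as_section(state, key).get("status") == "Failed" for key in NODE_SECTIONS):
        return "Completed with Errors in Nodes"
    return "Success"
```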


# --- 4. Build LangGraph Workflow (Edges Unchanged) ---
logger.info("Building the LangGraph workflow with implemented nodes...")
workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("diagnostician", diagnostician_node)
workflow.add_node("validator", validator_node)
workflow.add_node("bias_checker", bias_check_node)
workflow.add_node("educator", educator_node)
workflow.add_node("output_formatter", format_output_node)

# Define edges
workflow.set_entry_point("diagnostician")
workflow.add_edge("diagnostician", "validator")
workflow.add_edge("validator", "bias_checker")
workflow.add_edge("bias_checker", "educator")
workflow.add_edge("educator", "output_formatter")
workflow.add_edge("output_formatter", END)
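The edges above form a strictly linear chain, so if the diagnostician sets `error_message`, every later node still runs. LangGraph's `StateGraph.add_conditional_edges(source, path, path_map)` can route around that. The sketch shows only the routing function, `route_after_node` (a hypothetical name), with the wiring left as comments, because it would replace the plain `add_edge("diagnostician", "validator")` call rather than sit alongside it, and it must be registered before `workflow.compile()`. For the `"abort"` path to work, `format_output_node` must tolerate sections that are still `None`.

```python
from typing import Any, Mapping


# Hypothetical routing function, not part of the original notebook.
def route_after_node(state: Mapping[str, Any]) -> str:
    """Return 'abort' when an upstream node recorded an error, else 'continue'."""
    return "abort" if state.get("error_message") else "continue"


# Possible wiring, replacing workflow.add_edge("diagnostician", "validator"):
# workflow.add_conditional_edges(
#     "diagnostician",
#     route_after_node,
#     {"continue": "validator", "abort": "output_formatter"},
# )
```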

# Compile the graph
try:
    app = workflow.compile()
    logger.info("LangGraph workflow compiled successfully with implemented nodes.")
    # Optional: Visualize
    # try:
    #     from IPython.display import Image, display
    #     display(Image(app.get_graph().draw_png()))
    # except Exception as viz_error:
    #     logger.warning(f"Could not visualize graph: {viz_error}.")
except Exception as e:
    logger.error(f"Error compiling LangGraph workflow: {e}", exc_info=True)
    app = None


print("\n--- Cell 3: LangGraph Agents & Workflow Implementation Complete ---")




--- Cell 3: LangGraph Agents & Workflow Implementation Complete ---
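Both test cells that follow build the same `AgentState` by hand, listing every field. A small factory keeps those fields in one place. This sketch assumes `AgentState` is a `TypedDict` (calling a `TypedDict` with keyword arguments simply returns a plain `dict`), which the keyword-argument construction suggests but this section does not show; `make_initial_state` is a hypothetical helper.

```python
from typing import Any, Dict

# Field names copied from the AgentState(...) calls in the test cells.
STATE_FIELDS = (
    "original_input", "input_language", "symptoms_text", "rag_context",
    "grounding_metadata", "initial_diagnosis", "validation_results",
    "final_diagnosis_report", "patient_education", "bias_analysis",
    "error_message",
)


# Hypothetical factory, not part of the original notebook.
def make_initial_state(symptoms: str, language: str = "en") -> Dict[str, Any]:
    """Build a fresh workflow state: input fields filled, everything else None."""
    state: Dict[str, Any] = {field: None for field in STATE_FIELDS}
    state.update(
        original_input=symptoms,
        input_language=language,
        symptoms_text=symptoms,
    )
    return state
```

A test cell could then call `final_state = app.invoke(make_initial_state(test_symptoms))`.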


In [18]:
# --- Example Test Case ---
if app:
    logger.info("\n--- Running Example Test Case ---")
    # More specific symptoms likely to trigger context from the dummy PDF
    # test_symptoms = "Patient presents with high fever, chills, headache, and fatigue. Recently returned from a trip to a tropical region known for mosquito-borne illnesses. Rapid diagnostic tests were mentioned in literature."
    # test_symptoms = "65-year-old patient with cough for 5 days, producing yellow phlegm. Mild fever (38.1 C), shortness of breath on exertion, and chest tightness. History of smoking for 20 years."
    # test_symptoms = "Sudden onset of watery diarrhea (5 times in last 12 hours), nausea, and cramping abdominal pain. No fever. Ate questionable leftovers yesterday evening."
    # test_symptoms = "Persistent fatigue and feeling generally unwell for 3 weeks. Low-grade fever noted occasionally in the evenings. Generalized muscle aches and loss of appetite. Patient reports feeling 'run down'."
    test_symptoms = "Patient reports sudden onset of severe headache, worse than any previous headache. Describes it as 'thunderclap'. Also experiencing sensitivity to light and neck stiffness."
    # test_symptoms = "45-year-old male complains of lower back pain radiating down the left leg after lifting a heavy box two days ago. Pain is worse when sitting or bending forward. Reports some tingling in the foot."
    
    initial_state = AgentState(
        original_input=test_symptoms,
        input_language="en", # Assuming English for now
        symptoms_text=test_symptoms,
        rag_context=None,
        grounding_metadata=None,
        initial_diagnosis=None,
        validation_results=None,
        final_diagnosis_report=None,
        patient_education=None,
        bias_analysis=None,
        error_message=None
    )

    try:
        # Stream events to see the flow (optional)
        # for event in app.stream(initial_state):
        #     print("\n--- Workflow Event ---")
        #     print(event)
        #     print("----------------------\n")

        # Or just invoke and get the final state
        final_state = app.invoke(initial_state)

        logger.info("\n--- Final Workflow State ---")
        # Pretty print the final state, especially the report
        import pprint
        pprint.pprint(final_state)

        logger.info("\n--- Final Diagnosis Report ---")
        if final_state and final_state.get("final_diagnosis_report"):
            # Pretty print the final JSON report
            print(json.dumps(final_state["final_diagnosis_report"], indent=2))
        else:
            print("Final report not generated or workflow failed.")

    except Exception as e:
        logger.error(f"Error running workflow test case: {e}", exc_info=True)
else:
    logger.error("Workflow application not compiled. Cannot run test case.")


{'bias_analysis': {'bias_risk_score': 0.7,
                   'potential_biases_identified': ['The diagnostic pathway '
                                                   'suggested by the AI '
                                                   '(implying urgent imaging '
                                                   'like CT/CTA, potentially '
                                                   'LP) inherently assumes '
                                                   'access to advanced medical '
                                                   'resources and facilities, '
                                                   'reflecting a socioeconomic '
                                                   'bias in its applicability '
                                                   'across different '
                                                   'healthcare settings or '
                                                   'patient financial '
                             

In [19]:
# --- Example Test Case (with Debug Prints) ---
if 'app' in globals() and app: # Check if app exists and compiled
    logger.info("\n--- Running Example Test Case ---")
    print("DEBUG: Workflow 'app' seems to be compiled. Proceeding with test case.") # DEBUG PRINT

    # Select a test case
    test_symptoms = "Patient presents with high fever, chills, headache, and fatigue. Recently returned from a trip to a tropical region known for mosquito-borne illnesses. Rapid diagnostic tests were mentioned in literature."
    # test_symptoms = "45-year-old male complains of lower back pain radiating down the left leg after lifting a heavy box two days ago. Pain is worse when sitting or bending forward. Reports some tingling in the foot."
    logger.info(f"Using symptoms: {test_symptoms}")

    initial_state = AgentState(
        original_input=test_symptoms,
        input_language="en", # Assuming English for now
        symptoms_text=test_symptoms,
        # Initialize new/modified state fields to None or default
        rag_context=None,
        grounding_metadata=None,
        initial_diagnosis=None,
        validation_results=None,
        final_diagnosis_report=None,
        patient_education=None,
        bias_analysis=None,
        error_message=None
    )
    print(f"DEBUG: Initial state prepared: {str(initial_state)[:200]}...") # DEBUG PRINT

    try:
        print("DEBUG: About to call app.invoke()...") # DEBUG PRINT
        final_state = app.invoke(initial_state)
        print(f"DEBUG: app.invoke() finished. Type of final_state: {type(final_state)}") # DEBUG PRINT

        logger.info("\n--- Final Workflow State (Full) ---")
        import pprint
        print("DEBUG: About to pprint final_state...") # DEBUG PRINT
        pprint.pprint(final_state)
        print("DEBUG: Finished pprint final_state.") # DEBUG PRINT

        logger.info("\n--- Final Diagnosis Report (Formatted JSON) ---")
        if final_state and isinstance(final_state, dict) and final_state.get("final_diagnosis_report"):
            print("DEBUG: Found final_diagnosis_report, printing JSON...") # DEBUG PRINT
            print(json.dumps(final_state["final_diagnosis_report"], indent=2))
        else:
            print("DEBUG: Final report not found in final_state or final_state is not a dict.") # DEBUG PRINT
            print("Final report not generated or workflow failed.")

    except Exception as e:
        print(f"DEBUG: Exception occurred during app.invoke() or result processing: {e}") # DEBUG PRINT
        logger.error(f"Error running workflow test case: {e}", exc_info=True)
else:
    print("DEBUG: Workflow application ('app') is None or not found. Cannot run test case.") # DEBUG PRINT
    logger.error("Workflow application ('app') not compiled successfully in the previous cell. Cannot run test case.")

    

DEBUG: Workflow 'app' seems to be compiled. Proceeding with test case.
DEBUG: Initial state prepared: {'original_input': 'Patient presents with high fever, chills, headache, and fatigue. Recently returned from a trip to a tropical region known for mosquito-borne illnesses. Rapid diagnostic tests were ...
DEBUG: About to call app.invoke()...
DEBUG: app.invoke() finished. Type of final_state: <class 'langgraph.pregel.io.AddableValuesDict'>
DEBUG: About to pprint final_state...
{'bias_analysis': {'bias_risk_score': 0.4,
                   'potential_biases_identified': ["The mention of 'rapid "
                                                   "diagnostic tests' "
                                                   'implicitly assumes access '
                                                   'to specific healthcare '
                                                   'technologies, which can be '
                                                   'tied to socioeconomic '
               

In [None]:
# NOTEBOOK CELL 4: Gradio User Interface (Simplified Text Input + Markdown Output)
# ===============================================================================
# This cell creates a simple web UI for interacting with the
# diagnosis workflow using only text input and displaying results
# in a user-friendly Markdown format.

import gradio as gr
import time
import json
import warnings
import logging # Ensure logger is available if running this cell independently

# If running cells out of order, ensure logger is configured
if 'logger' not in globals():
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

logger.info("Setting up Simplified Gradio Interface with Markdown Output...")

# --- Helper Function to Format Report JSON to Markdown ---
def format_report_to_markdown(report_json):
    """Converts the final report JSON dictionary into a formatted Markdown string."""
    if not report_json:
        return "### Report Generation Failed\nNo report data available."

    # Check for overall workflow failure first
    status = report_json.get("workflow_status", "Unknown Status")
    error_details = report_json.get("error_details", "No details provided.")

    # Start building the Markdown string
    md_string = "# Diagnosis Report\n"
    if report_json.get('patient_id'):
        md_string += f"*(Patient ID: {report_json.get('patient_id')})*\n\n"

    # If workflow failed entirely, show only that.
    if status not in ["Success", "Completed with Errors in Nodes"]:
        md_string += f"## Workflow Status: {status}\n"
        md_string += f"**Error Details:** {error_details}\n"
        return md_string

    # --- Diagnosis Section ---
    diag = report_json.get("diagnosis", {})
    md_string += "## 🩺 Diagnosis\n"
    md_string += f"- **Primary Diagnosis:** {diag.get('primary', 'N/A')}\n"
    try:
        # Format confidence as percentage if possible
        confidence = float(diag.get('confidence', 0.0))
        md_string += f"- **Confidence:** {confidence:.1%}\n"
    except (ValueError, TypeError):
        md_string += f"- **Confidence:** {diag.get('confidence', 'N/A')}\n" # Fallback if not a number

    alternatives = diag.get('alternatives', [])
    if alternatives:
        md_string += "- **Alternatives Considered:**\n"
        for alt in alternatives:
            md_string += f"  * {alt}\n" # Use Markdown list item
    md_string += f"- **Validation Status:** *{diag.get('validation_status', 'N/A')}*\n\n"

    # --- Patient Education Section ---
    edu = report_json.get("education", {})
    # format_output_node stores an educator failure as "status": "Failed: <reason>",
    # so match on the "Failed" prefix instead of the exact string "Failed".
    edu_status = str(edu.get("status", ""))
    edu_failed = edu_status.startswith("Failed")
    md_string += "## 📖 Patient Education\n"
    if edu_failed:
        edu_reason = edu_status.partition(":")[2].strip() or "Unknown"
        md_string += f"**Status:** *Education material generation failed: {edu_reason}*\n\n"
    elif not edu:
        md_string += "*No education information generated.*\n\n"
    else:
        md_string += f"**Explanation:**\n{edu.get('explanation', 'N/A')}\n\n"
        md_string += f"**Medication Info:** {edu.get('medication', 'N/A')}\n\n" # Key from format_output_node
        next_steps = edu.get('next_steps', [])
        if next_steps:
            md_string += "**Suggested Next Steps:**\n"
            for step in next_steps:
                md_string += f"* {step}\n" # Use Markdown list item
        # Add visual placeholder info if available
        visual = edu.get('visual') # Key from format_output_node
        if visual:
             md_string += f"\n*Visual Aid Placeholder:* `{visual}`\n"
        md_string += "\n"


    # --- Equity & Bias Check Section ---
    equity = report_json.get("equity_check", {})
    # Bias-check failures are stored as "status": "Failed: <reason>" by format_output_node.
    equity_status = str(equity.get("status", ""))
    equity_failed = equity_status.startswith("Failed")
    md_string += "## ⚖️ Equity & Bias Check\n"
    if equity_failed:
        equity_reason = equity_status.partition(":")[2].strip() or "Unknown"
        md_string += f"**Status:** *Bias analysis failed: {equity_reason}*\n\n"
    elif not equity:
        md_string += "*No equity/bias analysis performed.*\n\n"
    else:
        bias_score = equity.get('bias_score', None) # Key from format_output_node
        # Check if score is valid (-1 might indicate pending/failed)
        if isinstance(bias_score, (float, int)) and bias_score >= 0.0:
            md_string += f"**Bias Risk Score (Subjective):** {bias_score:.2f}\n"
        else:
            md_string += "**Bias Risk Score (Subjective):** N/A or Not Calculated\n"

        potential_biases = equity.get('potential_biases', []) # Key from format_output_node
        if potential_biases:
            md_string += "**Potential Biases Identified:**\n"
            for bias in potential_biases:
                md_string += f"* {bias}\n" # Use Markdown list item
        cultural_adaptations = equity.get('cultural_adaptations', []) # Key from format_output_node
        if cultural_adaptations:
            md_string += "**Suggested Cultural Adaptations:**\n"
            for adapt in cultural_adaptations:
                md_string += f"* {adapt}\n" # Use Markdown list item
        md_string += "\n"


    # --- Debug Info / Validation Critique (Optional but helpful) ---
    debug = report_json.get("debug_info") or {}  # Tolerate a None value as well as a missing key
    critique = debug.get('validator_critique')
    if critique and critique not in ["N/A", "", "Pending/Skipped"]: # Display if critique exists and is meaningful
        md_string += "---\n"
        md_string += "### Validation Critique\n"
        md_string += f"> {critique}\n\n" # Use Markdown blockquote

    # --- Add overall status note if workflow had issues ---
    if status == "Completed with Errors in Nodes":
        md_string += "---\n**Note:** The workflow completed, but errors occurred in some steps. Please review the report sections carefully."

    return md_string
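

# --- Sketch (hypothetical helpers, not called by format_report_to_markdown) ---
# Why: in CommonMark, a text line placed directly after a "* item" line is read
# as a lazy continuation of that item, and two text lines separated by a single
# newline merge into one paragraph. A minimal sketch, assuming each report
# section is a bold title plus a bullet list, of building sections with explicit
# blank lines, plus a bias-score formatter that rejects negative sentinels
# (e.g. -1) and booleans. `_md_bullet_section` and `_format_bias_score` are
# hypothetical names introduced for illustration only.
def _md_bullet_section(title, items):
    """Return '**title:**' followed by a bullet list, or '' if nothing to show."""
    cleaned = [str(item).strip() for item in (items or []) if item is not None]
    cleaned = [item for item in cleaned if item]  # Drop blank entries
    if not cleaned:
        return ""
    lines = [f"**{title}:**", ""]  # Blank line: the title stays its own paragraph
    lines += [f"* {item}" for item in cleaned]
    return "\n".join(lines) + "\n\n"  # Trailing blank line closes the list


def _format_bias_score(score):
    """Format a bias score to two decimals, or a placeholder if it is invalid."""
    # bool is a subclass of int, so reject it explicitly; NaN fails the >= 0 check
    if isinstance(score, (int, float)) and not isinstance(score, bool) and score >= 0:
        return f"{score:.2f}"
    return "N/A or Not Calculated"

# Usage (illustrative):
#   _md_bullet_section("Potential Biases Identified", ["Age", ""])
#   -> "**Potential Biases Identified:**\n\n* Age\n\n"
#   _format_bias_score(-1) -> "N/A or Not Calculated"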


# --- Gradio Interface Function (Simplified) ---
def run_diagnosis_workflow_simple(symptom_text_input):
    """
    Takes user text input, runs the compiled LangGraph app,
    and returns the results formatted as Markdown for the Gradio UI.
    """
    logger.info("Gradio function triggered (Text only).")
    # --- Handle Input ---
    # Treat None, "" and whitespace-only input the same way
    symptoms_to_process = (symptom_text_input or "").strip()
    if not symptoms_to_process:
        logger.warning("No symptom text provided.")
        # Return empty Markdown and an error message for the status box
        return "", "Error: Please enter symptoms in the text box."

    logger.info(f"Processing text input: {symptoms_to_process[:100]}...")

    # --- Check if LangGraph App is Compiled ---
    # Ensure 'app' exists and is compiled (from Cell 3)
    if 'app' not in globals() or app is None:
        logger.error("LangGraph app not compiled. Cannot run workflow.")
        # Return empty Markdown and error message
        return "", "Error: Backend workflow is not ready. Please check server logs or run previous cells."

    # --- Run the LangGraph Workflow ---
    logger.info("Running workflow with symptoms...")
    # Prepare the initial state for the LangGraph app
    initial_state = AgentState(
        original_input=symptoms_to_process,
        input_language="en", # Assuming English for simplicity
        symptoms_text=symptoms_to_process,
        rag_context=None,
        initial_diagnosis=None,
        validation_results=None,
        final_diagnosis_report=None,
        patient_education=None,
        bias_analysis=None,
        error_message=None
    )

    final_state = None
    report_json = None # Initialize report_json
    status_message = "Starting..." # Initial status

    try:
        # Invoke the compiled LangGraph application
        final_state = app.invoke(initial_state)

        # --- Process Workflow Results ---
        if final_state:
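            # The key exists with a None value if the formatting node never ran,
            # so a .get() default would not apply; fall back to {} explicitly
            report_json = final_state.get("final_diagnosis_report") or {}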
            status = report_json.get("workflow_status", "Unknown Status")
            # Get error message from state first, then fallback to report
            error_details = final_state.get("error_message") or report_json.get("error_details")

            # Always format whatever report we got (even if partial/failed)
            formatted_markdown_report = format_report_to_markdown(report_json)

            if status == "Success":
                logger.info("Workflow completed successfully.")
                status_message = "Workflow Completed Successfully."
            elif status == "Completed with Errors in Nodes":
                logger.warning("Workflow completed with errors in some nodes.")
                status_message = "Workflow completed with errors in some nodes. Check report details."
            else: # Handle explicit Error status or other unknown statuses
                logger.error(f"Workflow failed. Status: {status}. Error: {error_details}")
                status_message = f"Workflow Error: {error_details or status}"

            # Return formatted markdown and the determined status message
            return formatted_markdown_report, status_message

        else:
            # Handle case where app.invoke returns None unexpectedly
            logger.error("Workflow invocation returned None state.")
            # Return empty Markdown and error message
            return "", "Error: Workflow execution failed unexpectedly (returned None)."

    except Exception as e:
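        # Catch any other exception raised while running or formatting the workflow
        logger.error(f"Unhandled exception during workflow invocation: {e}", exc_info=True)
        # If app.invoke() itself raised, final_state is still None and there is no
        # partial report; one exists only if the failure happened afterwards.
        partial_report = final_state.get("final_diagnosis_report") if final_state else None
        try:
            partial_report_md = format_report_to_markdown(partial_report) if partial_report else ""
        except Exception:
            # Formatting may be what failed in the first place; don't raise a second time
            partial_report_md = ""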
        # Return partial/empty Markdown and the critical error message
        return partial_report_md, f"Critical Error: An unexpected error occurred: {e}"
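

# --- Sketch (hypothetical helpers, not called by run_diagnosis_workflow_simple) ---
# Why: initial_state sets final_diagnosis_report=None, and dict.get(key, {})
# falls back to {} only when the key is *missing*, not when its value is None.
# A minimal sketch, assuming the final state is a plain dict shaped like
# AgentState and that "workflow_status" uses the strings checked above, of
# normalizing the input, the report, and the status message as pure functions
# that can be tested without running LangGraph or Gradio.
# `_clean_symptoms`, `_extract_report` and `_status_message` are hypothetical names.
def _clean_symptoms(text):
    """Return stripped symptom text, or None for None/empty/whitespace input."""
    cleaned = (text or "").strip()
    return cleaned or None


def _extract_report(final_state):
    """Return the report dict, treating a missing state or a None report as {}."""
    if not final_state:
        return {}
    return final_state.get("final_diagnosis_report") or {}


def _status_message(report, final_state=None):
    """Map the report's workflow_status to the message shown in the UI."""
    status = report.get("workflow_status", "Unknown Status")
    if status == "Success":
        return "Workflow Completed Successfully."
    if status == "Completed with Errors in Nodes":
        return "Workflow completed with errors in some nodes. Check report details."
    # Prefer the state's error_message, then fall back to the report's error_details
    error_details = (final_state or {}).get("error_message") or report.get("error_details")
    return f"Workflow Error: {error_details or status}"

# Usage (illustrative):
#   state = {"final_diagnosis_report": None, "error_message": "RAG timeout"}
#   _extract_report(state)                   -> {}
#   _status_message(_extract_report(state), state) -> "Workflow Error: RAG timeout"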


# --- Define Simplified Gradio Interface ---
# Use gr.Blocks for better layout control
with gr.Blocks(theme=gr.themes.Soft()) as iface_simple:
    gr.Markdown("# 🩺 MEDISARATHI: Medical Diagnosis Assistant")
    gr.Markdown(
        """
        Enter patient symptoms below using the text box.
        The AI agents will analyze the input, consult medical knowledge (via RAG),
        validate the findings, check for biases, and generate a preliminary report.
        *Disclaimer: This is a proof-of-concept demonstrator and NOT a substitute for professional medical advice.*
        """
    )
    # Arrange inputs and outputs
    with gr.Row():
        with gr.Column(scale=1):
            # Input Textbox
            symptom_text_input = gr.Textbox(
                lines=10,  # Tall enough for a multi-sentence case description
                label="Enter Symptoms Here",
                placeholder="e.g., Patient presents with high fever, chills, headache, and fatigue. Recently returned from a trip to a tropical region..."
            )
            # Submit Button
            submit_btn = gr.Button("Get Diagnosis Report", variant="primary")
        with gr.Column(scale=2):
            # Output Components
            output_status = gr.Textbox(label="Workflow Status / Messages", interactive=False, lines=2)
            # Use Markdown component for formatted output
            output_report_markdown = gr.Markdown(label="Diagnosis Report")

    # Define the action when the button is clicked
    submit_btn.click(
        fn=run_diagnosis_workflow_simple, # Link to the simplified function
        inputs=[symptom_text_input],       # Pass the text input
        outputs=[output_report_markdown, output_status] # Display results in Markdown and status Textbox
    )

    # Add examples for easier testing
    gr.Markdown("---")
    gr.Markdown("### Example Inputs:")
    gr.Examples(
        examples=[
            ["Patient presents with high fever, chills, headache, and fatigue. Recently returned from a trip to a tropical region known for mosquito-borne illnesses. Rapid diagnostic tests were mentioned in literature."],
            ["65-year-old patient with cough for 5 days, producing yellow phlegm. Mild fever (38.1 C), shortness of breath on exertion, and chest tightness. History of smoking for 20 years."],
            ["Sudden onset of watery diarrhea (5 times in last 12 hours), nausea, and cramping abdominal pain. No fever. Ate questionable leftovers yesterday evening."],
            ["Persistent fatigue and feeling generally unwell for 3 weeks. Low-grade fever noted occasionally in the evenings. Generalized muscle aches and loss of appetite. Patient reports feeling 'run down'."],
        ],
        inputs=[symptom_text_input] # Map examples to the text input component
    )


# --- Launch the Simplified Interface ---
logger.info("Launching Simplified Gradio Interface with Markdown Output...")
# share=True creates a public link usable for ~72 hours (useful on Kaggle/Colab)
# debug=True blocks the main thread and prints errors in the cell output,
# so the print() below runs only after the server is stopped
iface_simple.launch(share=True, debug=True)

print("\n--- Cell 4: Simplified Gradio UI (Markdown Output) Setup and Launch Complete ---")
# Keep this cell running in Kaggle to use the interface.
# Click the public URL link (e.g., https://....gradio.live) generated above.