In [None]:
# Community Health Agent - Kaggle-ready single-file implementation
# Paste this entire file into a Kaggle notebook cell and run.

from __future__ import annotations
import os
import time
import json
import uuid
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple

try:
    import google.generativeai as genai  # type: ignore
    GEMINI_SDK_AVAILABLE = True
except Exception:
    genai = None  # type: ignore
    GEMINI_SDK_AVAILABLE = False

try:
    from kaggle_secrets import UserSecretsClient  # type: ignore
    KAGGLE_SECRETS_AVAILABLE = True
except Exception:
    UserSecretsClient = None  # type: ignore
    KAGGLE_SECRETS_AVAILABLE = False

def safe_print(*args, **kwargs):
    """Forward every positional and keyword argument unchanged to ``print``.

    Kept as a single choke point for console output so it can be redirected
    or hardened later without touching call sites.
    """
    emit = print
    emit(*args, **kwargs)

def _load_api_key() -> Optional[str]:
    names = ["AI_STUDIO_API_KEY", "CAPSTONE_PROJECT", "Capstone_Project"]
    for n in names:
        v = os.environ.get(n)
        if v:
            return v
    if KAGGLE_SECRETS_AVAILABLE:
        try:
            client = UserSecretsClient()
            for key_name in ["Capstone_Project", "CAPSTONE_PROJECT", "AI_STUDIO_API_KEY"]:
                try:
                    v = client.get_secret(key_name)
                    if v:
                        return v
                except Exception:
                    continue
        except Exception:
            pass
    return None

# Resolve credentials once at import time; the SDK availability flag below
# is the single switch every client reads to choose online vs. mock mode.
API_KEY = _load_api_key()

if not GEMINI_SDK_AVAILABLE:
    safe_print("[LLM] Gemini SDK not installed; running in mock mode.")
elif not API_KEY:
    safe_print("[LLM] API key not found; running in mock mode.")
    GEMINI_SDK_AVAILABLE = False
else:
    try:
        genai.configure(api_key=API_KEY)
        safe_print("[LLM] Gemini SDK configured.")
    except Exception as exc:
        # Any configuration failure downgrades the whole module to mock mode.
        safe_print("[LLM] Failed to configure Gemini SDK, falling back to mock mode:", exc)
        GEMINI_SDK_AVAILABLE = False

# Model identifiers and default generation settings shared by the clients below.
# NOTE(review): for Gemini 2.5 "thinking" models the output-token budget may
# also cover internal reasoning, so 512 could produce empty replies — confirm.
ADVISOR_MODEL: str = "gemini-2.5-pro"
VALIDATOR_MODEL: str = "gemini-2.5-pro"
EMBEDDING_MODEL: str = "text-embedding-004"
DEFAULT_TEMPERATURE: float = 0.2
DEFAULT_MAX_OUTPUT_TOKENS: int = 512

class LLMClient:
    """Text-generation client for Gemini with a deterministic offline fallback.

    ``generate`` always returns ``{"text": str, "structured": dict | None}``.
    Online replies carry ``structured=None``. The keyword-based mock, used
    when offline or when an SDK call fails or yields no text, also fills
    ``structured`` with ``severity`` and ``action`` keys.
    """

    def __init__(self, model: str, temperature: float = DEFAULT_TEMPERATURE, max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS):
        self.model = model
        self.temperature = temperature
        self.max_output_tokens = max_output_tokens
        # Snapshot of the module-level availability flag at construction time.
        self.online = GEMINI_SDK_AVAILABLE

    def generate(self, prompt: str) -> Dict[str, Any]:
        """Generate a reply to *prompt*.

        When ``self.online`` is true the SDK is tried first. Any exception, an
        SDK without a supported entry point, or an empty reply falls back to
        ``_mock_response`` so callers always get a usable dict.
        """
        if not self.online:
            return self._mock_response(prompt)
        try:
            text = self._generate_online(prompt)
        except Exception as e:
            # Best effort: an SDK/network error must not break the agent loop.
            safe_print("[LLM] SDK call failed, using mock response:", e)
            return self._mock_response(prompt)
        if not text:
            safe_print("[LLM] SDK returned no usable text, using mock response.")
            return self._mock_response(prompt)
        return {"text": str(text), "structured": None}

    def _generate_online(self, prompt: str) -> Optional[str]:
        """Call the first supported SDK entry point; return its text or None."""
        model_cls = getattr(genai, "GenerativeModel", None)
        if model_cls is not None:
            # Documented google-generativeai path for Gemini models.
            response = model_cls(self.model).generate_content(
                prompt,
                generation_config={
                    "temperature": self.temperature,
                    "max_output_tokens": self.max_output_tokens,
                },
            )
            try:
                return response.text
            except ValueError:
                # The `.text` accessor raises when the reply has no text parts
                # (for example when it was blocked); treat that as "no text".
                return None
        # Alternative/legacy SDK surfaces, kept for compatibility.
        responses_api = getattr(genai, "responses", None)
        if responses_api is not None and hasattr(responses_api, "create"):
            r = responses_api.create(model=self.model, input=prompt, temperature=self.temperature, max_output_tokens=self.max_output_tokens)
            return self._extract_responses_text(r) or getattr(r, "output_text", None) or str(r)
        if hasattr(genai, "generate_text"):
            r = genai.generate_text(model=self.model, prompt=prompt, temperature=self.temperature, max_output_tokens=self.max_output_tokens)
            return getattr(r, "text", None) or str(r)
        if hasattr(genai, "generate"):
            r = genai.generate(model=self.model, prompt=prompt)
            return getattr(r, "content", None) or getattr(r, "text", None) or str(r)
        return None

    @staticmethod
    def _extract_responses_text(r: Any) -> Optional[str]:
        """Return ``output[0].content[0].text`` from a responses-style result, or None."""
        try:
            output = getattr(r, "output", None)
            if not output:
                return None
            first = output[0]
            if isinstance(first, dict):
                content = first.get("content")
                if content and isinstance(content[0], dict) and "text" in content[0]:
                    return content[0]["text"]
                return None
            return first.content[0].text  # type: ignore
        except Exception:
            return None

    @staticmethod
    def _has_persistent_duration(lower: str) -> bool:
        """Return True if *lower* mentions a duration of >= 3 days or >= 72 hours."""
        import re
        for amount, unit in re.findall(r"(\d+)\s*(days?|hours?|hrs?)\b", lower):
            n = int(amount)
            if unit.startswith("day"):
                if n >= 3:
                    return True
            elif n >= 72:
                return True
        return False

    def _mock_response(self, prompt: str) -> Dict[str, Any]:
        """Keyword-based stand-in reply used when the SDK is unavailable or fails."""
        lower = prompt.lower()
        if "chest pain" in lower or "difficulty breathing" in lower:
            return {"text": "I detect potentially serious symptoms. Please seek urgent medical attention.", "structured": {"severity": "high", "action": "escalate"}}
        if "fever" in lower and self._has_persistent_duration(lower):
            return {"text": "Persistent fever (>72 hours) should be evaluated by a clinician.", "structured": {"severity": "medium", "action": "see_clinic"}}
        return {"text": "Based on your description: rest, hydrate, monitor symptoms. Seek care if worsening.", "structured": {"severity": "low", "action": "self_care"}}

class EmbeddingClient:
    """Text-embedding client with a deterministic 32-dimensional offline fallback.

    NOTE(review): online vectors are presumably much longer than the 32-dim
    mock vectors (768 for text-embedding-004 — verify). Mixing both kinds in
    one index makes cosine scores meaningless.
    """

    def __init__(self, model: str = EMBEDDING_MODEL):
        self.model = model
        # Snapshot of the module-level availability flag at construction time.
        self.online = GEMINI_SDK_AVAILABLE

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return exactly one vector per entry of *texts*.

        The SDK is used when online. If no supported SDK call exists, the call
        raises, or the result does not contain one vector per text, every text
        gets a ``_mock_vector`` instead.
        """
        if not texts:
            return []
        if self.online:
            try:
                vectors = self._embed_online(texts)
            except Exception as e:
                safe_print("[Embed] SDK call failed, using mock vectors:", e)
                vectors = None
            if vectors is not None and len(vectors) == len(texts):
                return vectors
        return [self._mock_vector(t) for t in texts]

    def _embed_online(self, texts: List[str]) -> Optional[List[List[float]]]:
        """Try each known SDK embedding entry point in turn; None if none produced vectors."""
        embed_content = getattr(genai, "embed_content", None)
        if embed_content is not None:
            # google-generativeai addresses models as "models/<name>".
            name = self.model if "/" in self.model else f"models/{self.model}"
            result = embed_content(model=name, content=texts)
            raw = result.get("embedding") if isinstance(result, dict) else getattr(result, "embedding", None)
            if raw:
                # A batch request yields a list of vectors; a single one yields a flat vector.
                if not isinstance(raw[0], (list, tuple)):
                    raw = [raw]
                return [list(v) for v in raw]
        embeddings_api = getattr(genai, "embeddings", None)
        if embeddings_api is not None and hasattr(embeddings_api, "create"):
            vectors = self._vectors_from_data(embeddings_api.create(model=self.model, input=texts))
            if vectors:
                return vectors
        if hasattr(genai, "generate_embeddings"):
            vectors = self._vectors_from_data(genai.generate_embeddings(model=self.model, input=texts))
            if vectors:
                return vectors
        return None

    @staticmethod
    def _vectors_from_data(r: Any) -> Optional[List[List[float]]]:
        """Read ``data[i].embedding`` from a dict- or attribute-style result."""
        data = getattr(r, "data", None) or (r.get("data") if isinstance(r, dict) else None)
        if not data:
            return None
        vectors = []
        for d in data:
            vec = d.get("embedding") if isinstance(d, dict) else getattr(d, "embedding", None)
            vectors.append(list(vec))
        return vectors

    def _mock_vector(self, text: str) -> List[float]:
        """Deterministic 32-dim pseudo-embedding from the first 32 characters' code points."""
        vec = [float((ord(c) % 97) / 97.0) for c in text[:64]]
        vec = (vec + [0.0] * 32)[:32]
        return vec

def _utc_timestamp() -> str:
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())


def _new_record_id() -> str:
    """Return a fresh random UUID4 rendered as a string."""
    return str(uuid.uuid4())


@dataclass
class MemoryRecord:
    """One stored conversation item belonging to a single user.

    ``created_at`` has one-second resolution; ``id`` is a random UUID4 string.
    """

    user_id: str
    content: str
    created_at: str = field(default_factory=_utc_timestamp)
    # Free-form annotations (the agents store e.g. a "role" key here).
    metadata: Dict[str, Any] = field(default_factory=dict)
    id: str = field(default_factory=_new_record_id)

class MemoryAgent:
    """Per-user conversation memory kept in process, with optional embedding search.

    ``store`` holds records in insertion (chronological) order. When
    embeddings are enabled, ``embeddings_index[i]`` is the vector for
    ``store[i]``; ``add`` keeps the two lists the same length.
    """

    def __init__(self, use_embeddings: bool = True):
        self.store: List[MemoryRecord] = []
        # Embedding search is only attempted when the Gemini SDK is configured.
        self.use_embeddings = use_embeddings and GEMINI_SDK_AVAILABLE
        self.embedding_client = EmbeddingClient() if self.use_embeddings else None
        self.embeddings_index: List[List[float]] = []

    def add(self, user_id: str, content: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """Store *content* for *user_id* and return the new record's id."""
        rec = MemoryRecord(user_id=user_id, content=content, metadata=metadata or {})
        self.store.append(rec)
        if self.use_embeddings and self.embedding_client:
            vecs = self.embedding_client.embed([content])
            # Append a zero vector on failure so the index stays aligned with ``store``.
            self.embeddings_index.append(vecs[0] if vecs else [0.0] * 32)
        safe_print(f"[MemoryAgent] Stored record {rec.id} for user {user_id}")
        return rec.id

    def query(self, user_id: str, keywords: List[str], k: int = 5) -> List[MemoryRecord]:
        """Return up to *k* of *user_id*'s records most relevant to *keywords*.

        With embeddings, every record of the user is ranked by cosine
        similarity to the space-joined keywords. Otherwise records are scored
        by how many keywords occur (case-insensitively) in their content, and
        only records with at least one hit are returned.
        """
        res: Optional[List[MemoryRecord]] = None
        if self.use_embeddings and keywords and self.embedding_client and self.embeddings_index:
            res = self._query_by_embedding(user_id, keywords, k)
        if res is None:
            res = self._query_by_keywords(user_id, keywords, k)
        safe_print(f"[MemoryAgent] Query for {keywords} returned {len(res)} records")
        return res

    def _query_by_embedding(self, user_id: str, keywords: List[str], k: int) -> Optional[List[MemoryRecord]]:
        """Rank the user's records by embedding similarity; None if no query vector was produced."""
        qvecs = self.embedding_client.embed([" ".join(keywords)])
        if not qvecs:
            return None
        qv = qvecs[0]
        scored = [
            (self._cosine_similarity(qv, vec), rec)
            for rec, vec in zip(self.store, self.embeddings_index)
            if rec.user_id == user_id
        ]
        scored.sort(key=lambda x: x[0], reverse=True)
        return [rec for _, rec in scored[:k]]

    def _query_by_keywords(self, user_id: str, keywords: List[str], k: int) -> List[MemoryRecord]:
        """Rank the user's records by number of keyword substring hits (hits > 0 only)."""
        klower = [w.lower() for w in keywords]
        scored: List[Tuple[float, MemoryRecord]] = []
        for rec in self.store:
            if rec.user_id != user_id:
                continue
            txt = rec.content.lower()
            score = sum(1 for kw in klower if kw in txt)
            if score > 0:
                scored.append((score, rec))
        scored.sort(key=lambda x: x[0], reverse=True)
        return [rec for _, rec in scored[:k]]

    def get_recent(self, user_id: str, k: int = 5) -> List[MemoryRecord]:
        """Return up to *k* of *user_id*'s records, newest first."""
        # ``created_at`` has one-second resolution, so many records can tie.
        # Walking ``store`` newest-first and relying on sort stability makes
        # the later-inserted record win each tie.
        recs = [r for r in reversed(self.store) if r.user_id == user_id]
        recs.sort(key=lambda r: r.created_at, reverse=True)
        return recs[:k]

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Cosine similarity of *a* and *b*; 0.0 if either has zero norm.

        Vectors of different lengths are compared over their common prefix
        (``zip`` truncates) in the dot product only.
        """
        import math
        num = sum(x * y for x, y in zip(a, b))
        denom_a = math.sqrt(sum(x * x for x in a))
        denom_b = math.sqrt(sum(y * y for y in b))
        if denom_a == 0 or denom_b == 0:
            return 0.0
        return num / (denom_a * denom_b)

class SymptomCheckerTool:
    """Minimal rule-based triage lookup.

    Each rule lists condition names. ``check`` picks the rule that shares the
    most conditions with the reported symptoms; the earliest rule wins ties.
    """

    def __init__(self):
        # (conditions, diagnosis, advice, severity); list order is the tie-breaker.
        rows = [
            (["fever", "cough"], "Viral respiratory infection", "rest, hydrate, paracetamol if needed", "low"),
            (["fever", "rash"], "Possible measles or viral exanthem", "seek clinic evaluation", "medium"),
            (["chest pain", "sweating"], "Possible cardiac event", "seek emergency care", "high"),
            (["shortness of breath", "wheezing"], "Possible asthma exacerbation", "seek urgent care", "high"),
            (["diarrhea", "dehydration"], "Gastroenteritis", "oral rehydration", "medium"),
        ]
        self.knowledge = [
            {"conditions": conds, "diagnosis": dx, "advice": advice, "severity": sev}
            for conds, dx, advice, sev in rows
        ]

    def check(self, symptoms: List[str]) -> Dict[str, Any]:
        """Return diagnosis, advice and severity for the best-matching rule.

        Symptom names are compared case-insensitively. When nothing matches
        (or no symptoms are given) an "Unknown" result is returned.
        """
        reported = {s.lower() for s in symptoms} if symptoms else set()
        best_rule = None
        best_overlap = 0
        for rule in self.knowledge:
            overlap = len(reported.intersection(rule["conditions"]))
            # Strict ">" keeps the earliest rule among equal overlaps.
            if overlap > best_overlap:
                best_rule, best_overlap = rule, overlap
        if best_rule is None:
            return {"diagnosis": "Unknown", "advice": "Provide more details or seek clinician", "severity": "unknown"}
        return {key: best_rule[key] for key in ("diagnosis", "advice", "severity")}

class ValidatorAgent:
    """Safety gate that decides whether a user message must be escalated.

    Checks run in order and the first decisive one wins:

    1. a red-flag phrase appears in the lower-cased text;
    2. the rule-based symptom tool reports ``"high"`` severity;
    3. when the LLM client is online, a model-based severity classification.
    """

    # Substrings that always escalate, regardless of tool or model output.
    _RED_FLAG_PHRASES = ("chest pain", "severe chest pain", "crushing pain", "difficulty breathing")
    _SEVERITY_LEVELS = ("low", "medium", "high")

    def __init__(self, symptom_tool: SymptomCheckerTool, model: str = VALIDATOR_MODEL):
        self.tool = symptom_tool
        self.client = LLMClient(model)

    def validate(self, user_text: str, proposed_action: Dict[str, Any]) -> Dict[str, Any]:
        """Decide whether *user_text* may be answered with advice.

        Args:
            user_text: The raw user message.
            proposed_action: The advisor's structured proposal. Currently
                unused; kept for interface compatibility.

        Returns:
            A dict with ``ok`` (False means escalate), a human-readable
            ``reason``, and a ``recommendation`` of ``"escalate"``,
            ``"see_clinic"`` or ``"safe"``.
        """
        lower = user_text.lower()
        if any(flag in lower for flag in self._RED_FLAG_PHRASES):
            return {"ok": False, "reason": "High-risk symptom detected (heuristic)", "recommendation": "escalate"}
        found = self._known_symptoms_in(lower)
        if found and self.tool.check(found).get("severity") == "high":
            return {"ok": False, "reason": "Tool suggests high severity", "recommendation": "escalate"}
        if self.client.online:
            prompt = "Classify the following patient text into severity: low, medium, high.\nText: " + user_text + "\nReturn a JSON object like: {\"severity\": \"low|medium|high\", \"note\": \"...\" }"
            severity = self._model_severity(self.client.generate(prompt))
            if severity == "high":
                return {"ok": False, "reason": "Validator model suggests high severity", "recommendation": "escalate"}
            if severity == "medium":
                return {"ok": True, "reason": "Validator model suggests medium severity", "recommendation": "see_clinic"}
        return {"ok": True, "reason": "No high-risk indicators", "recommendation": "safe"}

    def _known_symptoms_in(self, lower: str) -> List[str]:
        """Return tool conditions that occur as whole (possibly multi-word) phrases in *lower*."""
        tokens = [t.strip(".,!?()") for t in lower.split()]
        # Pad with spaces so " shortness of breath " only matches whole words.
        padded = " " + " ".join(tokens) + " "
        known = {c for rule in self.tool.knowledge for c in rule["conditions"]}
        return sorted(c for c in known if f" {c} " in padded)

    @classmethod
    def _model_severity(cls, resp: Dict[str, Any]) -> Optional[str]:
        """Extract a severity label from an ``LLMClient.generate`` result.

        Preference order: ``structured["severity"]``; then a ``"severity"`` key
        in a JSON object embedded in the text; then whole-word matching, which
        only reports "high" or "medium" when exactly one of them appears.
        Returns None when undecided.
        """
        structured = resp.get("structured")
        if isinstance(structured, dict):
            sev = str(structured.get("severity", "")).strip().lower()
            if sev in cls._SEVERITY_LEVELS:
                return sev
        text = resp.get("text") or ""
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            try:
                obj = json.loads(text[start:end + 1])
            except ValueError:
                obj = None
            if isinstance(obj, dict):
                sev = str(obj.get("severity", "")).strip().lower()
                if sev in cls._SEVERITY_LEVELS:
                    return sev
        import re
        # Whole words only, so "slow"/"follow"/"highly" are not misread as labels.
        words = set(re.findall(r"[a-z]+", text.lower()))
        has_high = "high" in words
        has_medium = "medium" in words
        if has_high and not has_medium:
            return "high"
        if has_medium and not has_high:
            return "medium"
        return None

class SchedulerAgent:
    """In-memory reminder queue; reminders fire only when ``run_pending`` is called.

    ``tasks`` maps a task id to ``{"user_id", "message", "run_at"}``, where
    ``run_at`` is an absolute ``time.time()`` timestamp.
    """

    def __init__(self):
        self.tasks: Dict[str, Dict[str, Any]] = {}

    def schedule(self, user_id: str, message: str, delay_seconds: float) -> str:
        """Queue *message* for *user_id* to fire ``delay_seconds`` from now.

        Negative delays are clamped to 0; fractional delays are preserved.
        Returns the new task id.
        """
        task_id = str(uuid.uuid4())
        run_at = time.time() + max(0.0, float(delay_seconds))
        self.tasks[task_id] = {"user_id": user_id, "message": message, "run_at": run_at}
        safe_print(f"[Scheduler] Scheduled reminder {task_id} in {delay_seconds}s")
        return task_id

    def _due_task_ids(self, now: float) -> List[str]:
        """Return ids of tasks with ``run_at <= now``, earliest first (ties keep insertion order)."""
        due = [tid for tid, task in self.tasks.items() if task["run_at"] <= now]
        due.sort(key=lambda tid: self.tasks[tid]["run_at"])
        return due

    def run_pending(self, current_time: Optional[float] = None) -> int:
        """Run all pending tasks that are due. Returns count of executed tasks.

        Due tasks fire in due-time order and are removed from the queue.
        ``current_time`` overrides the clock (useful for simulated time).
        """
        now = time.time() if current_time is None else current_time
        executed = 0
        for tid in self._due_task_ids(now):
            task = self.tasks.pop(tid, None)
            if task is not None:
                safe_print(f"[Reminder for {task['user_id']}] {task['message']}")
                executed += 1
        return executed

    def next_run_in(self) -> Optional[float]:
        """Seconds until the earliest pending task is due (0.0 if overdue); None if empty."""
        if not self.tasks:
            return None
        now = time.time()
        next_run = min(t["run_at"] for t in self.tasks.values())
        return max(0.0, next_run - now)

class AdvisorAgent:
    """Front-line conversational agent.

    For each message the agent either asks one clarifying question or does
    the following: looks up matching rules in the symptom tool, asks the LLM
    for advice, runs the validator, and then escalates or returns advice.
    Every user input and agent reply is written to ``memory``.
    """

    # Substrings that suppress clarifying questions so high-risk reports go
    # straight through to validation.
    _URGENT_PHRASES = ("chest pain", "difficulty breathing", "severe chest pain", "crushing pain", "sweating")

    def __init__(self, memory: MemoryAgent, validator: ValidatorAgent, scheduler: SchedulerAgent, symptom_tool: SymptomCheckerTool, model: str = ADVISOR_MODEL):
        self.memory = memory
        self.validator = validator
        self.scheduler = scheduler
        self.symptom_tool = symptom_tool
        self.client = LLMClient(model)

    @staticmethod
    def _tokens(text: str) -> List[str]:
        """Lower-case whitespace tokens with surrounding punctuation stripped."""
        return [t.strip(".,!?()") for t in text.lower().split()]

    def _extract_keywords(self, text: str) -> List[str]:
        """Return up to 10 tokens longer than two characters, in message order."""
        return [t for t in self._tokens(text) if len(t) > 2][:10]

    @staticmethod
    def _mentions_duration(lower: str) -> bool:
        """True if *lower* states a duration ("1 day", "12 hours", "72hrs", "a week", "yesterday", ...)."""
        import re
        pattern = r"\d+\s*(?:h|hrs?|hours?|d|days?|weeks?)\b|\b(?:days?|hours?|hrs?|weeks?|duration|yesterday)\b"
        return re.search(pattern, lower) is not None

    def _ask_clarifying(self, user_text: str) -> Optional[str]:
        """Return a follow-up question if key details are missing, else None."""
        lower = user_text.lower()
        if any(p in lower for p in self._URGENT_PHRASES):
            return None
        # Look at every token, not just the first few keywords, so late
        # mentions of "pain"/"fever" are still noticed.
        words = set(self._tokens(user_text))
        if "pain" in words and "where" not in lower:
            return "Can you tell me where the pain is located and how severe it is (1-10)?"
        if "fever" in words and not self._mentions_duration(lower):
            return "How long have you had the fever? (e.g., 2 days, 12 hours)"
        return None

    def _detect_symptoms(self, user_text: str) -> List[str]:
        """Return tool conditions present as whole (possibly multi-word) phrases in *user_text*."""
        # Padding with spaces makes " shortness of breath " match only whole words.
        padded = " " + " ".join(self._tokens(user_text)) + " "
        known = {c for rule in self.symptom_tool.knowledge for c in rule["conditions"]}
        return sorted(c for c in known if f" {c} " in padded)

    @staticmethod
    def _parse_llm_json(text: str) -> Optional[Dict[str, Any]]:
        """Parse *text* as a JSON object, tolerating a surrounding ``` / ```json fence."""
        candidate = (text or "").strip()
        if candidate.startswith("```"):
            # Drop the opening fence line and everything after the closing fence.
            candidate = candidate.split("\n", 1)[1] if "\n" in candidate else ""
            candidate = candidate.rsplit("```", 1)[0]
        try:
            parsed = json.loads(candidate)
        except ValueError:
            return None
        return parsed if isinstance(parsed, dict) else None

    def handle(self, user_id: str, user_text: str) -> Dict[str, Any]:
        """Process one user message.

        Returns one of:
            ``{"type": "clarify", "message"}``,
            ``{"type": "escalate", "message", "recommendation"}``, or
            ``{"type": "advice", "message", "diagnosis", "severity", "timestamp"}``.
        """
        safe_print(f"[Advisor] Received from {user_id}: {user_text}")
        self.memory.add(user_id=user_id, content=user_text, metadata={"role": "user_input"})
        clar = self._ask_clarifying(user_text)
        if clar:
            self.memory.add(user_id=user_id, content=clar, metadata={"role": "clarifying_question"})
            return {"type": "clarify", "message": clar}
        symptoms = self._detect_symptoms(user_text)
        symptom_result = self.symptom_tool.check(symptoms) if symptoms else {"diagnosis": "Unknown", "advice": "Please provide more details", "severity": "unknown"}
        safe_print(f"[Advisor] Symptom tool result: {symptom_result}")
        prompt = (
            "User reported: " + user_text + "\n"
            + "Tool says: " + json.dumps(symptom_result) + "\n"
            + "Provide concise, empathetic advice and next steps. Return a short JSON with fields: message, severity, action."
        )
        llm_resp = self.client.generate(prompt)
        # Online replies carry structured=None; always hand the validator a dict.
        validation = self.validator.validate(user_text, llm_resp.get("structured") or {})
        safe_print(f"[Advisor] Validation: {validation}")
        if not validation.get("ok"):
            message = "I detect potentially serious symptoms. Please seek urgent medical attention or call emergency services."
            self.memory.add(user_id=user_id, content=message, metadata={"role": "escalation"})
            return {"type": "escalate", "message": message, "recommendation": validation.get("recommendation")}
        text = llm_resp.get("text", "")
        parsed = self._parse_llm_json(text) or {}
        final = {
            "type": "advice",
            # Prefer the model's JSON fields; fall back to raw text / tool severity.
            "message": parsed.get("message") or text,
            "diagnosis": symptom_result.get("diagnosis"),
            "severity": parsed.get("severity") or symptom_result.get("severity"),
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        }
        self.memory.add(user_id=user_id, content=json.dumps(final), metadata={"role": "agent_response"})
        return final

    def set_reminder(self, user_id: str, message: str, delay_seconds: int) -> str:
        """Schedule a reminder via the shared scheduler; returns its task id."""
        return self.scheduler.schedule(user_id=user_id, message=message, delay_seconds=delay_seconds)

def demo_conversation():
    """Run a scripted conversation, then schedule and fire follow-up reminders.

    When the agent asks a clarifying question, the scripted user answers with
    a canned reply chosen to match that question. Reminder delays are in
    seconds even though their texts mention hours. The scheduler is then run
    against a simulated far-future clock, so every reminder fires at once.
    """
    safe_print("Starting demo...")

    memory = MemoryAgent(use_embeddings=True)
    symptom_tool = SymptomCheckerTool()
    validator = ValidatorAgent(symptom_tool)
    scheduler = SchedulerAgent()
    advisor = AdvisorAgent(memory, validator, scheduler, symptom_tool)

    user_id = "user_001"

    user_msgs = [
        "I have a fever and cough.",
        "It's been about 2 days and I also feel a bit tired.",
        "I also have chest pain sometimes while breathing.",
    ]

    # Canned answers keyed by a word found in the clarifying question; first match wins.
    clarify_answers = {
        "fever": "I've had the fever for about 2 days.",
        "pain": "Pain is in my chest and rating 5/10",
    }
    fallback_answer = "I'm not sure."

    def show(resp: Dict[str, Any]) -> None:
        # Label every agent reply by type so clarify/escalate turns stand out.
        if resp["type"] == "clarify":
            safe_print(f"Agent (clarify): {resp['message']}")
        elif resp["type"] == "escalate":
            safe_print(f"Agent (ESCALATE): {resp['message']}")
        else:
            safe_print(f"Agent: {resp['message']}")

    for msg in user_msgs:
        safe_print(f"\n> User: {msg}")
        resp = advisor.handle(user_id, msg)
        show(resp)
        if resp["type"] == "clarify":
            question = resp["message"].lower()
            followup = next((answer for key, answer in clarify_answers.items() if key in question), fallback_answer)
            safe_print(f"User replies to clarifying Q: {followup}")
            show(advisor.handle(user_id, followup))

    safe_print("\n\nScheduling follow-up reminders...")

    r1 = advisor.set_reminder(
        user_id,
        "Please check your temperature again in 2 hours.",
        delay_seconds=2
    )

    r2 = advisor.set_reminder(
        user_id,
        "Monitor your breathing and report any worsening symptoms in 6 hours.",
        delay_seconds=6
    )

    r3 = advisor.set_reminder(
        user_id,
        "24-hour follow-up: How are your symptoms today?",
        delay_seconds=24
    )

    safe_print("Reminders scheduled:", r1, r2, r3)

    safe_print("\n\nRunning scheduler (simulated time)...\n")

    # Jump the clock far ahead so every scheduled reminder is due.
    future_time = time.time() + 999999
    executed_count = scheduler.run_pending(current_time=future_time)
    safe_print(f"\nExecuted {executed_count} reminders")

    safe_print("\nDemo finished.")

def test_scheduler_behavior():
    """Check that run_pending fires each due reminder exactly once.

    Uses the ``current_time`` override instead of sleeping. This keeps the
    test fast and prevents flakiness: with real time, a slow first
    ``run_pending`` could fire both tasks and make the second check fail.
    """
    safe_print("\n\nRunning scheduler tests...")
    sched = SchedulerAgent()
    id1 = sched.schedule("u1", "Immediate reminder", delay_seconds=0)
    id2 = sched.schedule("u1", "Delayed reminder", delay_seconds=1)
    delayed_at = sched.tasks[id2]["run_at"]
    # Half a second before the delayed task: only the immediate one is due.
    executed1 = sched.run_pending(current_time=delayed_at - 0.5)
    assert executed1 == 1, "Expected exactly the immediate task to execute"
    assert id1 not in sched.tasks and id2 in sched.tasks, "Only the immediate task should be consumed"
    safe_print("[Test] First run_pending executed:", executed1)
    executed2 = sched.run_pending(current_time=delayed_at)
    assert executed2 == 1, "Expected delayed task to execute"
    assert not sched.tasks, "Expected no pending tasks after both ran"
    assert sched.run_pending() == 0, "Nothing left to run on the real clock"
    safe_print("[Test] Second run_pending executed:", executed2)
    safe_print("Scheduler tests passed.")

def test_advisor_flow():
    """Smoke-test AdvisorAgent: a mild report yields advice or a clarifying
    question, while chest pain with sweating must escalate."""
    safe_print("\n\nRunning advisor flow test...")
    tool = SymptomCheckerTool()
    advisor = AdvisorAgent(
        MemoryAgent(use_embeddings=False),
        ValidatorAgent(tool),
        SchedulerAgent(),
        tool,
    )

    mild = advisor.handle("tester", "I have fever and cough for 1 day")
    assert mild["type"] in ("advice", "clarify"), "Advisor should return advice or ask to clarify"
    safe_print("[Test] Advisor returned type:", mild["type"])

    urgent = advisor.handle("tester", "I have chest pain and sweating")
    assert urgent["type"] == "escalate", "Advisor should escalate on chest pain"
    safe_print("[Test] Advisor escalated as expected")

    safe_print("Advisor tests passed.")

if __name__ == "__main__":
    # Run the scripted demo first, then the lightweight self-checks, in order.
    for entry_point in (demo_conversation, test_scheduler_behavior, test_advisor_flow):
        entry_point()

[LLM] Gemini SDK configured.
Starting demo...

> User: I have a fever and cough.
[Advisor] Received from user_001: I have a fever and cough.
[MemoryAgent] Stored record 371541ce-c0bf-4fe8-802b-95fa878ab42d for user user_001
[MemoryAgent] Stored record 7a8cad3b-03e6-4c0b-b022-f8010a952701 for user user_001
Agent (clarify): How long have you had the fever? (e.g., 2 days, 12 hours)
User replies to clarifying Q: Pain is in my chest and rating 5/10
[Advisor] Received from user_001: Pain is in my chest and rating 5/10
[MemoryAgent] Stored record d795287e-b3e9-4cbe-97c2-354ac96ea0a8 for user user_001
[MemoryAgent] Stored record 9588cda7-ab2b-4a24-bf5a-33bed9fbdabf for user user_001
Agent: Can you tell me where the pain is located and how severe it is (1-10)?

> User: It's been about 2 days and I also feel a bit tired.
[Advisor] Received from user_001: It's been about 2 days and I also feel a bit tired.
[MemoryAgent] Stored record 3eaf507b-032e-4e6b-9227-1f9a5fad0a8c for user user_001
[Advisor