In [4]:
import os
import pickle
import json
from pathlib import Path
from typing import Any, Dict, Optional

from dotenv import load_dotenv
from pydantic import BaseModel

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.memory import ConversationBufferMemory
from langchain_community.document_loaders import JSONLoader
from langchain.document_loaders import PyPDFLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_astradb import AstraDBVectorStore
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.prompts import PromptTemplate
from langgraph.graph import StateGraph, START, END

# ----------------------------
# Load environment variables
# ----------------------------
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
ASTRA_API_KEY = os.getenv("ASTRA_API_KEY")
DB_ENDPOINT = os.getenv("DB_ENDPOINT")

llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    api_key=GEMINI_API_KEY,
    temperature=0.2,
)

# ----------------------------
# Disk Memory
# ----------------------------
class DiskConversationMemory:
    def __init__(self, filename: str = "chat_memory.pkl"):
        self.filename = Path(filename)
        self.memory = ConversationBufferMemory(return_messages=True)
        self._load()

    def _load(self):
        if self.filename.exists():
            try:
                with open(self.filename, "rb") as f:
                    self.memory = pickle.load(f)
                print(f"Loaded memory from {self.filename}")
            except Exception as e:
                print("Failed to load memory, starting fresh:", e)

    def persist(self):
        try:
            with open(self.filename, "wb") as f:
                pickle.dump(self.memory, f)
                print(f"Persisted memory to {self.filename}")
        except Exception as e:
            print("Failed to persist memory:", e)

# ----------------------------
# State Model
# ----------------------------
class HealthGraphState(BaseModel):
    twilio_payload: Optional[Dict[str, Any]] = None
    user_message: Optional[str] = None
    user_meta: Optional[Dict[str, Any]] = None
    vaccination_docs: Optional[Any] = None
    outbreak_docs: Optional[Any] = None
    local_vectorstore: Optional[Any] = None
    disk_memory: Optional[Any] = None
    route_decision: Optional[Dict[str, str]] = None
    response: Optional[str] = None
    vaccination_json_path: Optional[str] = (
        r"/Users/aashutoshkumar/Documents/Projects/healthgraph-assistant/data/vaccination_schedule.json"
    )
    outbreak_pdf_path: Optional[str] = (
        r"/Users/aashutoshkumar/Documents/Projects/healthgraph-assistant/latest_weekly_outbreak/31st_weekly_outbreak.pdf"
    )
    index_dir: Optional[str] = (
        r"/Users/aashutoshkumar/Documents/Projects/healthgraph-assistant/exp/faiss_index/index.faiss"
    )

# ----------------------------
# Node Functions
# ----------------------------
def node_twilio_ingress(state: HealthGraphState) -> HealthGraphState:
    payload = state.twilio_payload
    if not payload:
        return state

    text = payload.get("Body") or payload.get("Message") or payload.get("text")
    sender = payload.get("From") or payload.get("from")

    state.user_message = text
    state.user_meta = {"sender": sender, "raw_payload": payload}
    return state

def node_load_vaccination_json(state: HealthGraphState) -> HealthGraphState:
    loader = JSONLoader(file_path=state.vaccination_json_path)
    docs = loader.load()
    state.vaccination_docs = docs
    return state

def node_load_outbreak_pdf(state: HealthGraphState) -> HealthGraphState:
    loader = PyPDFLoader(state.outbreak_pdf_path)
    docs = loader.load_and_split()
    state.outbreak_docs = docs
    return state

def node_build_faiss_index(state: HealthGraphState) -> HealthGraphState:
    hf_embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    docs = state.outbreak_docs or []
    if not docs:
        print("No outbreak documents found to index.")
        return state

    index_dir = Path(state.index_dir).parent
    index_dir.mkdir(parents=True, exist_ok=True)

    vectorstore = FAISS.from_documents(docs, embedding=hf_embedding)
    vectorstore.save_local(str(index_dir))

    state.local_vectorstore = vectorstore
    print(f"FAISS index built with {len(docs)} docs and saved to {index_dir}")
    return state

def node_router(state: HealthGraphState) -> HealthGraphState:
    message = state.user_message or ""
    if not message:
        state.route_decision = {"route": "general_query", "reason": "empty_message"}
        return state

    lower = message.lower()
    if any(w in lower for w in ["urgent", "emergency", "outbreak"]):
        state.route_decision = {"route": "emergency_outbreak", "reason": "keyword_match"}
        return state
    if any(w in lower for w in ["symptom", "fever", "cough"]):
        state.route_decision = {"route": "symptom", "reason": "keyword_match"}
        return state
    if any(w in lower for w in ["vaccine", "vaccination", "schedule"]):
        state.route_decision = {"route": "vaccination_schedule", "reason": "keyword_match"}
        return state

    state.route_decision = {"route": "general_query", "reason": "default"}
    return state

def node_emergency_outbreak(state: HealthGraphState) -> HealthGraphState:
    if not state.user_message:
        state.response = "No message provided."
        return state
    if not state.local_vectorstore:
        state.response = "Outbreak data not indexed."
        return state

    retriever = state.local_vectorstore.as_retriever(search_kwargs={"k": 5})
    conv = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=state.disk_memory.memory if state.disk_memory else None,
        return_source_documents=False,
    )
    result = conv.run(question=state.user_message)
    state.response = result
    if state.disk_memory:
        state.disk_memory.persist()
    return state

def node_symptom(state: HealthGraphState) -> HealthGraphState:
    hf_embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = AstraDBVectorStore(
        embedding=hf_embedding,
        api_endpoint=DB_ENDPOINT,
        namespace="default_keyspace",
        token=ASTRA_API_KEY,
        collection_name="medical_v2",
    )
    retriever = vector_store.as_retriever()
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="You are a medical assistant.\nContext:\n{context}\n\nQuestion:\n{question}\nAnswer clearly."
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
    )
    state.response = qa_chain.run(state.user_message)
    return state

def node_vaccination_schedule(state):
    message = state.user_message
    if not message:
        return state

    vaccination_json_path = state.vaccination_json_path

    try:
        with open(vaccination_json_path, "r", encoding="utf-8") as f:
            schedule_data = json.load(f)
        docs_json_str = json.dumps(schedule_data, ensure_ascii=False, indent=2)
    except Exception as e:
        state.response = f"(unable to load vaccination schedule JSON: {e})"
        return state

    prompt = f"""
        You are an assistant that knows how to infer vaccination due-dates from a vaccination schedule JSON.
        Use the provided schedule to answer the question as precisely as possible and, when appropriate, return a short checklist.

        SCHEDULE_JSON:
        {docs_json_str}

        QUESTION:
        {message}
    """

    answer = llm.invoke(prompt)
    state.response = answer.content  # store only the human-readable text
    return state


def node_general_query(state: HealthGraphState) -> HealthGraphState:
    resp = llm.invoke([{"role": "user", "content": state.user_message}]).content
    state.response = resp
    return state

def node_route_dispatcher(state: HealthGraphState) -> HealthGraphState:
    route = state.route_decision["route"] if state.route_decision else "general_query"
    if route == "emergency_outbreak":
        return node_emergency_outbreak(state)
    if route == "symptom":
        return node_symptom(state)
    if route == "vaccination_schedule":
        return node_vaccination_schedule(state)
    return node_general_query(state)

# ----------------------------
# Build Workflow
# ----------------------------
workflow = StateGraph(HealthGraphState)
workflow.add_node("twilio_ingress", node_twilio_ingress)
workflow.add_node("load_vaccination_json", node_load_vaccination_json)
workflow.add_node("load_outbreak_pdf", node_load_outbreak_pdf)
workflow.add_node("build_faiss_index", node_build_faiss_index)
workflow.add_node("router", node_router)
workflow.add_node("route_dispatcher", node_route_dispatcher)
workflow.add_node("emergency_outbreak", node_emergency_outbreak)
workflow.add_node("symptom", node_symptom)
workflow.add_node("vaccination_schedule", node_vaccination_schedule)
workflow.add_node("general_query", node_general_query)

workflow.add_edge(START, "twilio_ingress")
workflow.add_edge("twilio_ingress", "router")
workflow.add_edge("router", "route_dispatcher")
workflow.add_edge("load_vaccination_json", "build_faiss_index")
workflow.add_edge("load_outbreak_pdf", "build_faiss_index")
workflow.add_edge("build_faiss_index", "route_dispatcher")
workflow.add_edge("route_dispatcher", END)

app = workflow.compile()
print("✅ Workflow compiled successfully.")

# ----------------------------
# Local Test Runner
# ----------------------------



✅ Workflow compiled successfully.


E0000 00:00:1758361823.364028  725200 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


In [7]:
# Initialize test state
disk_mem = DiskConversationMemory()
test_state = HealthGraphState(
    twilio_payload={"Body": "I have fever what should i do?"}, 
    disk_memory=disk_mem
)

# Run the workflow
final_state_dict = app.invoke(test_state)

# Print only the human-readable response
print("Final Response:\n", final_state_dict["response"])


  state.response = qa_chain.run(state.user_message)


Final Response:
 Okay, I understand you have a fever. Based on the information provided, here's what I can suggest, keeping in mind I am a medical assistant and cannot provide medical advice:

1.  **Consider your symptoms:** Do you have any other symptoms besides fever, such as headache, muscle aches, loss of appetite, nausea, vomiting, abdominal pain, or jaundice? Have you traveled recently? These details are important for diagnosis.

2.  **If your fever persists:** The provided text mentions "Pyrexia >3wks with no identified cause after evaluation in hospital for 3d or ≥3 out-patient visits." If your fever has been present for an extended period (weeks) and you haven't found a cause, it's important to seek medical attention.

3.  **If you have been hospitalized:** The text also mentions "Nosocomial PUO: Patient hospitalized for >48h with no infection at admission." If you developed a fever after being in the hospital for more than 48 hours, this could be a specific type of fever that