In [None]:
import asyncio
import functools
import json
import logging
import os
import pickle
import re
from pathlib import Path
from typing import Any, Dict

from fastapi import FastAPI, Request
from pydantic import BaseModel
from dotenv import load_dotenv

# LangChain + LangGraph imports
from langgraph.graph import StateGraph, START, END
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.memory import ConversationBufferMemory
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_community.document_loaders import JSONLoader, PyPDFLoader


# Populate os.environ from a local .env file (python-dotenv) before reading config.
load_dotenv()

# Gemini API key; None when GOOGLE_API_KEY is not set in the environment.
GEMINI_KEY = os.environ.get("GOOGLE_API_KEY")

# Shared chat model used by every reply handler below.
chat_model = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.3,
    api_key=GEMINI_KEY,
)

class FileMemory:
    """Per-user conversation memory persisted as a pickle file in the working directory.

    On construction the ``ConversationBufferMemory`` stored in
    ``<user>_session.pkl`` is restored when that file exists and unpickles
    cleanly; otherwise a fresh, empty buffer is used. Call :meth:`save` to
    write the current buffer back to disk.
    """

    # ConversationalRetrievalChain takes prior turns from its "chat_history"
    # input, so the buffer must expose its messages under that key (the
    # ConversationBufferMemory default is "history").
    MEMORY_KEY = "chat_history"

    def __init__(self, user: str = "default"):
        self.path = self._session_path(user)
        self.buffer = ConversationBufferMemory(
            memory_key=self.MEMORY_KEY, return_messages=True
        )
        if self.path.exists():
            try:
                # SECURITY: unpickling can execute arbitrary code. This is only
                # acceptable because session files are written by save() on this
                # server; never load a file whose contents a client controls.
                with open(self.path, "rb") as fh:
                    loaded = pickle.load(fh)
            except Exception:
                # Best-effort restore: a corrupt or incompatible session file must
                # not break the request, so keep the fresh buffer but record why.
                logging.getLogger(__name__).warning(
                    "Ignoring unreadable session file %s", self.path, exc_info=True
                )
            else:
                self.buffer = loaded
                # Buffers pickled with the default key would make the retrieval
                # chain reject its inputs; migrate them to the expected key.
                if getattr(self.buffer, "memory_key", self.MEMORY_KEY) != self.MEMORY_KEY:
                    self.buffer.memory_key = self.MEMORY_KEY

    @staticmethod
    def _session_path(user: str) -> Path:
        """Return the session file path for *user*, confined to the working directory.

        ``user`` may come straight from a request payload. Every character that is
        not a word character, ``.`` or ``-`` is replaced with ``_``. That includes
        path separators, so a crafted id cannot escape the directory. Ids made only
        of those allowed characters map to the same file name as before.
        """
        safe = re.sub(r"[^\w.-]", "_", str(user))
        return Path(f"{safe}_session.pkl")

    def save(self):
        """Pickle the current buffer to this user's session file, overwriting it."""
        with open(self.path, "wb") as fh:
            pickle.dump(self.buffer, fh)


class SessionState(BaseModel):
    """LangGraph state for one /chat request, passed through and mutated by each node."""

    # Raw user message and caller id taken from the request payload.
    user_text: str = ""
    user_id: str = "anon"
    # Per-user FileMemory wrapper, or None. Typed Any, so pydantic does not validate it.
    memory: Any = None
    # Final answer text; set by classify_and_reply.
    reply: str | None = None

    # Data source locations; relative paths resolve against the working directory.
    vax_file: str = "data/vaccination_schedule.json"
    outbreak_file: str = "data/outbreak_latest.pdf"
    # Directory that build_index passes to FAISS.save_local.
    index_folder: str = "data/faiss_idx"

    # Populated by the loader/index nodes: vax_data by fetch_vaccines,
    # outbreak_chunks by fetch_outbreak, faiss_store by build_index.
    # Each stays None when its loader fails or has not run.
    vax_data: Any = None
    outbreak_chunks: Any = None
    faiss_store: Any = None


def intake(state: SessionState) -> SessionState:
    """Graph entry node: hand the incoming state on untouched."""
    incoming = state
    return incoming

def fetch_vaccines(state: SessionState) -> SessionState:
    """Load the vaccination schedule file into ``state.vax_data`` as LangChain documents.

    This step is best-effort. If loading fails, ``vax_data`` is set to ``None``
    and the error is logged, so the rest of the graph still runs.
    """
    try:
        # JSONLoader requires a jq_schema; "." selects the whole document.
        # text_content=False because the schedule is structured JSON, not a
        # string, and the loader would otherwise reject non-string content.
        state.vax_data = JSONLoader(
            file_path=state.vax_file, jq_schema=".", text_content=False
        ).load()
    except Exception:
        logging.getLogger(__name__).exception(
            "Could not load vaccination schedule from %s", state.vax_file
        )
        state.vax_data = None
    return state

def fetch_outbreak(state: SessionState) -> SessionState:
    """Load the outbreak PDF and split it into document chunks on ``state.outbreak_chunks``.

    This step is best-effort. If loading fails, ``outbreak_chunks`` is set to
    ``None`` and the error is logged. In that case build_index skips indexing.
    """
    try:
        state.outbreak_chunks = PyPDFLoader(state.outbreak_file).load_and_split()
    except Exception:
        # Narrowed from a bare except so Ctrl-C / SystemExit still propagate,
        # and the failure is no longer silent.
        logging.getLogger(__name__).exception(
            "Could not load outbreak report from %s", state.outbreak_file
        )
        state.outbreak_chunks = None
    return state

_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


@functools.lru_cache(maxsize=4)
def _get_embedder(model_name: str = _EMBEDDING_MODEL) -> HuggingFaceEmbeddings:
    """Return a process-wide embedding model for *model_name*, loading it only once."""
    # NOTE(review): the instance is shared across calls; if requests run
    # concurrently, confirm the embedding backend tolerates concurrent use.
    return HuggingFaceEmbeddings(model_name=model_name)


def build_index(state: SessionState) -> SessionState:
    """Embed ``state.outbreak_chunks`` into a FAISS store, persist it, and attach it to the state.

    Returns the state unchanged when there are no outbreak chunks. Otherwise the
    index is saved to ``state.index_folder`` and stored on ``state.faiss_store``.
    The embedding model is cached, so its weights are not reloaded on every call.
    """
    if not state.outbreak_chunks:
        return state
    faiss_db = FAISS.from_documents(state.outbreak_chunks, _get_embedder())
    faiss_db.save_local(state.index_folder)
    state.faiss_store = faiss_db
    return state

def classify_and_reply(state: SessionState) -> SessionState:
    """Route the user's message to a reply handler by keyword and store the result in ``state.reply``.

    Matching is case-insensitive and the first match wins:

    * contains "outbreak"               -> outbreak_reply
    * contains "vaccin" or "schedule"   -> vaccine_reply
    * anything else                     -> generic_reply
    """
    msg = state.user_text.lower()

    if "outbreak" in msg:
        state.reply = outbreak_reply(state)
    # Match the stem "vaccin" so "vaccination"/"vaccinated" route here as well;
    # those words do not contain the substring "vaccine".
    elif "vaccin" in msg or "schedule" in msg:
        state.reply = vaccine_reply(state)
    else:
        state.reply = generic_reply(state)

    return state


def outbreak_reply(state: SessionState) -> str:
    """Answer the user's question by retrieval over the indexed outbreak report.

    Returns a fixed notice when ``state.faiss_store`` is not set. When
    ``state.memory`` is set, its buffer is attached to the chain as memory. The
    buffer is then persisted with ``save()`` after the answer is produced.
    """
    if not state.faiss_store:
        return "Outbreak data not yet available."

    memory = state.memory.buffer if state.memory else None
    retriever = state.faiss_store.as_retriever(search_kwargs={"k": 4})
    convo = ConversationalRetrievalChain.from_llm(
        llm=chat_model,
        retriever=retriever,
        memory=memory,
    )

    # The chain's inputs are "question" and "chat_history". An attached memory
    # supplies the history (its memory_key must be "chat_history"). Without
    # memory, a lone question is rejected, so pass an empty history explicitly.
    inputs = {"question": state.user_text}
    if memory is None:
        inputs["chat_history"] = []
    answer = convo.invoke(inputs)["answer"]

    if state.memory:
        state.memory.save()
    return answer

def vaccine_reply(state: SessionState) -> str:
    """Answer the user's question using the vaccination schedule file as prompt context.

    The schedule is pretty-printed into the prompt. A fixed error message is
    returned when the file cannot be read or is not valid UTF-8 JSON.
    """
    try:
        with open(state.vax_file, "r", encoding="utf-8") as fh:
            schedule = json.load(fh)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON
        # (json.JSONDecodeError) or invalid UTF-8 (UnicodeDecodeError).
        logging.getLogger(__name__).exception(
            "Could not load vaccination data from %s", state.vax_file
        )
        return "Could not load vaccination data."
    sched_str = json.dumps(schedule, indent=2)

    prompt = f"""
    You are a health assistant. 
    Use the vaccination schedule below to answer the question.

    VACCINATION DATA:
    {sched_str}

    QUESTION:
    {state.user_text}
    """
    return chat_model.invoke(prompt).content

def generic_reply(state: SessionState) -> str:
    """Send the user's message to the chat model with no extra context; return the reply text."""
    conversation = [{"role": "user", "content": state.user_text}]
    model_output = chat_model.invoke(conversation)
    return model_output.content


# Request pipeline:
#   intake -> vaccines -> outbreaks -> index -> decide -> END
# Every node must be reachable from START. Previously intake jumped straight to
# decide, so the loader/index nodes never ran and faiss_store stayed None.
# Nodes run one after another: each returns the whole SessionState, so running
# them as parallel branches would likely produce conflicting writes to the
# same keys in one step.
# NOTE(review): this re-reads the PDF and rebuilds the FAISS index on every
# invocation; consider loading from index_folder if latency matters.
graph = StateGraph(SessionState)
graph.add_node("intake", intake)
graph.add_node("vaccines", fetch_vaccines)
graph.add_node("outbreaks", fetch_outbreak)
graph.add_node("index", build_index)
graph.add_node("decide", classify_and_reply)

graph.add_edge(START, "intake")
graph.add_edge("intake", "vaccines")
graph.add_edge("vaccines", "outbreaks")
graph.add_edge("outbreaks", "index")
graph.add_edge("index", "decide")
graph.add_edge("decide", END)

app_graph = graph.compile()


api = FastAPI()


@api.post("/chat")
async def chat(req: Request):
    """Handle a chat turn: run the graph for the posted message and return its reply.

    Expects a JSON object with optional ``message`` (default ``""``) and
    ``user_id`` (default ``"anon"``). Responds with ``{"response": <reply>}``.
    """
    payload = await req.json()
    # Coerce to str: JSON numbers/bools/null would otherwise fail SessionState
    # validation (str fields) and surface as a 500.
    user_msg = payload.get("message")
    user_msg = "" if user_msg is None else str(user_msg)
    user_id = payload.get("user_id")
    user_id = "anon" if user_id is None else str(user_id)

    mem = FileMemory(user_id)
    state = SessionState(user_text=user_msg, user_id=user_id, memory=mem)

    # app_graph.invoke is synchronous (file I/O, embedding, LLM calls); run it in
    # a worker thread so it does not block the event loop for other requests.
    # NOTE(review): requests can now overlap, so concurrent turns for the same
    # user race on that user's session file (last write wins).
    result = await asyncio.to_thread(app_graph.invoke, state)

    # LangGraph returns the final state as a dict, not a SessionState instance,
    # so attribute access (result.reply) would raise AttributeError.
    if isinstance(result, dict):
        reply = result.get("reply")
    else:
        reply = getattr(result, "reply", None)
    return {"response": reply}
