# In [None]:
from typing import Optional, TypedDict

from fastapi import FastAPI, Request
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.runnables import RunnableLambda
from langgraph.graph import END, StateGraph
from pydantic import BaseModel

# FastAPI application object; the POST /healthbot route is registered on it.
app = FastAPI()

# Request body for POST /healthbot. Every field is optional; the endpoint
# decides what to do based on which fields are set.
class PatientRequest(BaseModel):
    # Health topic to search, summarise and build a quiz question for.
    topic: Optional[str] = None
    # When true, the client only wants a quiz question for `topic`.
    ready_for_quiz: Optional[bool] = False
    # The patient's answer to the quiz question for `topic`, to be graded.
    patient_answer: Optional[str] = None
    # When true, the client asks to restart with a new topic.
    continue_session: Optional[bool] = False

# Blank conversation state: every key the graph nodes read or write, each
# starting as None. Values are all None, so a shallow .copy() per request
# is enough to get an independent state.
initial_state = dict.fromkeys(
    (
        "topic",
        "search_results",
        "summary",
        "quiz_question",
        "patient_answer",
        "evaluation",
    )
)

# Tools: a single Tavily web-search tool, shared by the search node and by
# the endpoint's direct search calls.
search_tool = TavilySearchResults()

# Nodes
def ask_topic(state):
    """Graph entry node: a placeholder that passes *state* through unchanged.

    The topic is expected to be in ``state["topic"]`` already.
    """
    return state

def search_medical_info(state):
    """Search the web for ``state["topic"]`` and record the raw results.

    Stores the return value of ``search_tool.run`` in
    ``state["search_results"]``, mutating *state* in place, and returns it.
    """
    query = state["topic"]
    state["search_results"] = search_tool.run(query)
    return state

def summarize_info(state):
    """Build a short summary of ``state["topic"]`` from the first search hit.

    The summary is ``"Summary of <topic>: "`` followed by the first 300
    characters of the first result's ``content`` and then ``"..."``. It is
    stored in ``state["summary"]`` (mutating *state* in place), and *state*
    is returned.

    Falls back to ``"Summary of <topic>: No information found."`` instead of
    raising ``IndexError``/``TypeError`` in these cases:

    * ``search_results`` is missing, ``None`` or an empty list;
    * ``search_results`` is not a list of dicts (for instance an error string);
    * the first hit has no usable ``content``.
    """
    results = state.get("search_results")
    first = results[0] if isinstance(results, list) and results else None
    topic = state["topic"]

    if isinstance(first, dict) and first.get("content"):
        snippet = str(first["content"])[:300]
        state["summary"] = f"Summary of {topic}: {snippet}..."
    else:
        state["summary"] = f"Summary of {topic}: No information found."
    return state

def present_info(state):
    """Placeholder node for showing the summary; passes *state* through unchanged."""
    return state

def generate_quiz(state):
    """Attach a fixed-template quiz question about ``state["topic"]``.

    Writes ``state["quiz_question"]`` (mutating *state* in place) and returns
    *state*.
    """
    topic = state["topic"]
    state.update(quiz_question=f"What is one key point about {topic}?")
    return state

def present_quiz(state):
    """Placeholder node for showing the quiz question; passes *state* through unchanged."""
    return state

def evaluate_answer(state):
    """Grade ``state["patient_answer"]`` and store feedback in ``state["evaluation"]``.

    The feedback reads ``"Your answer is Yes|No. Refer to: <first 150 chars
    of the summary>..."``. *state* is mutated in place and returned.

    A missing or ``None`` answer is treated as an empty answer, which grades
    "No". A missing or ``None`` summary is treated as empty. Neither case
    raises.

    NOTE(review): grading is a placeholder heuristic. An answer counts as
    correct when it contains the substring "key" (case-insensitive), which is
    unrelated to the topic content. Replace with real grading.
    """
    answer = state.get("patient_answer") or ""
    summary = state.get("summary") or ""
    verdict = "Yes" if "key" in answer.lower() else "No"
    state["evaluation"] = f"Your answer is {verdict}. Refer to: {summary[:150]}..."
    return state

def present_results(state):
    """Placeholder node for showing the evaluation; passes *state* through unchanged."""
    return state

def ask_continue(state):
    """Placeholder node for offering another round; passes *state* through unchanged."""
    return state

# Build the graph
class HealthBotState(TypedDict, total=False):
    """State shared by the HealthBot graph nodes (same keys as ``initial_state``)."""

    topic: Optional[str]
    search_results: object  # raw return value of search_tool.run
    summary: Optional[str]
    quiz_question: Optional[str]
    patient_answer: Optional[str]
    evaluation: Optional[str]


def _route_after_quiz(state):
    """Continue to grading only when the state already carries a patient answer."""
    return "evaluate" if state.get("patient_answer") else END


# StateGraph needs a state schema; constructing it with no arguments raises
# in current LangGraph releases.
graph = StateGraph(HealthBotState)
for node_name, node_fn in (
    ("ask_topic", ask_topic),
    ("search_info", search_medical_info),
    ("summarize", summarize_info),
    ("present_info", present_info),
    ("generate_quiz", generate_quiz),
    ("present_quiz", present_quiz),
    ("evaluate", evaluate_answer),
    ("present_results", present_results),
    ("ask_continue", ask_continue),
):
    graph.add_node(node_name, RunnableLambda(node_fn))

# Define edges
graph.set_entry_point("ask_topic")
graph.add_edge("ask_topic", "search_info")
graph.add_edge("search_info", "summarize")
graph.add_edge("summarize", "present_info")
graph.add_edge("present_info", "generate_quiz")
graph.add_edge("generate_quiz", "present_quiz")
# With no answer yet there is nothing to grade, so the run ends after the
# quiz is presented; otherwise it goes on to evaluation.
graph.add_conditional_edges(
    "present_quiz",
    _route_after_quiz,
    {"evaluate": "evaluate", END: END},
)
graph.add_edge("evaluate", "present_results")
graph.add_edge("present_results", "ask_continue")
# Each invoke() handles one pass. An unconditional edge back to ask_topic
# would form a loop with no exit and run until LangGraph's recursion limit.
# A new topic arrives as a new API request instead.
graph.add_edge("ask_continue", END)

# Compile the graph
healthbot_app = graph.compile()

def _prepare_quiz_state(state, topic):
    """Fill *state* with search results, a summary and a quiz question for *topic*.

    Runs the search -> summarize -> generate-quiz steps directly, without
    invoking the graph, and returns the updated state.
    """
    state["topic"] = topic
    state["search_results"] = search_tool.run(topic)
    state = summarize_info(state)
    return generate_quiz(state)


# API endpoint
@app.post("/healthbot")
def healthbot_interaction(request: PatientRequest):
    """Handle one HealthBot request.

    Each call starts from a fresh copy of ``initial_state``; no state is kept
    between requests. Dispatch order, most specific first:

    1. ``patient_answer`` and ``topic``: rebuild the quiz for the topic, grade
       the answer, return ``{"evaluation": ...}``.
    2. ``ready_for_quiz`` and ``topic``: return ``{"quiz_question": ...}``.
    3. ``topic``: run the graph, return ``summary`` and ``quiz_question``.
    4. ``continue_session``: ask the client for a new topic.
    5. Otherwise: ask the client for a topic.

    Declared with plain ``def`` rather than ``async def`` because
    ``search_tool.run`` and ``healthbot_app.invoke`` are blocking calls.
    FastAPI runs sync path operations in a worker thread instead of on the
    event loop.
    """
    state = initial_state.copy()
    topic = request.topic

    # Answer and quiz requests are checked before the plain-topic case. If the
    # topic case came first, any request with a topic would return there, and
    # these branches would only ever run with topic=None.
    if request.patient_answer and topic:
        state = _prepare_quiz_state(state, topic)
        state["patient_answer"] = request.patient_answer
        state = evaluate_answer(state)
        return {"evaluation": state["evaluation"]}

    if request.ready_for_quiz and topic:
        state = _prepare_quiz_state(state, topic)
        return {"quiz_question": state["quiz_question"]}

    if topic:
        state["topic"] = topic
        state = healthbot_app.invoke(state)
        return {
            "summary": state["summary"],
            "quiz_question": state["quiz_question"],
        }

    if request.continue_session:
        return {"message": "Session restarted. Please provide a new topic."}

    return {"message": "Please provide a topic to begin."}