In [1]:
from dotenv import load_dotenv
import os

In [2]:

# Load environment variables from a local .env file (if one exists) and make
# sure the OpenAI key is present before anything else in the notebook runs.
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    # os.environ only accepts string values, so assigning a missing key (None)
    # would fail with an opaque TypeError. Fail with an actionable message.
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Add it to your .env file or export it "
        "in the environment before running this notebook."
    )
os.environ["OPENAI_API_KEY"] = api_key

In [15]:
from langgraph.graph import StateGraph, END
from typing import Literal
from typing import TypedDict, Optional, Literal, List, Dict
from langchain_community.document_loaders import Docx2txtLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
import re

In [None]:
# Chat model settings, kept as named constants so they are easy to tune.
LLM_MODEL_NAME = "gpt-4o-mini"
LLM_TEMPERATURE = 0.3

# Single chat model instance referenced by the node functions in this notebook.
llm = ChatOpenAI(model=LLM_MODEL_NAME, temperature=LLM_TEMPERATURE)


In [17]:
# ==================
# 1. State Definition
# ==================
# Allowed values of the conversation phase marker.
_Phase = Literal["symptoms", "report_qa", "verification", "reporting"]


class ChatState(TypedDict):
    """State dictionary handed to, and returned by, every workflow node."""

    # Conversation transcript; the nodes in this notebook read and append
    # entries shaped like {"role": ..., "content": ...}.
    messages: List[Dict[str, str]]
    # Current stage of the conversation; nodes compare against it in their guards.
    phase: _Phase
    # Text captured after the "SUMMARY:" marker by symptom_collector, if any.
    symptoms_summary: Optional[str]
    # Concatenated text of the loaded report document, if one has been loaded.
    report_text: Optional[str]
    # Questions produced by verification_handler.
    verification_questions: List[str]
    # Answers keyed by 1-based question number stored as a string ("1", "2", ...).
    verification_answers: Dict[str, str]
    # Paths of uploaded documents; report_processor loads the first entry.
    uploaded_files: List[str]


In [18]:
# ==================
# 2. Node Implementations
# ==================
def symptom_collector(state: ChatState):
    """Run one turn of the symptom interview.

    Sends the latest message in ``state["messages"]`` to the model together
    with the interviewer system prompt. If the reply contains a ``SUMMARY:``
    marker, the text after it is stored in ``state["symptoms_summary"]`` and
    the phase moves to ``"report_qa"``; otherwise the reply is appended to the
    transcript as an ``"ai"`` message.

    Args:
        state: Conversation state. Mutated in place.

    Returns:
        The same ``state`` object. It is returned untouched when the phase is
        not ``"symptoms"``.
    """
    if state["phase"] != "symptoms":
        return state

    # NOTE(review): only the most recent message is sent, so the model has no
    # view of earlier answers when deciding whether it has "enough
    # information" - consider passing the whole transcript.
    last_message = state["messages"][-1]["content"]

    # The patient's text is supplied as a template *variable* rather than used
    # as the template itself, so braces in free-form input are not parsed as
    # prompt placeholders (which would make invoke() fail on a missing key).
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You're a medical assistant conducting a symptom interview. 
         Ask ONE question at a time. When enough information is gathered, write:
         'SUMMARY: [structured summary]'"""),
        ("human", "{user_input}")
    ])

    chain = prompt | llm
    ai_message = chain.invoke({"user_input": last_message}).content

    # Allow any (or no) whitespace after the marker: models often put the
    # summary on the next line, which the previous "SUMMARY: " pattern missed
    # and then crashed on a None match.
    summary_match = re.search(r"SUMMARY:\s*(.*)", ai_message, re.DOTALL)
    if summary_match:
        state["symptoms_summary"] = summary_match.group(1).strip()
        state["phase"] = "report_qa"
    else:
        state["messages"].append({"role": "ai", "content": ai_message})

    return state


In [19]:
def report_processor(state: ChatState):
    """Answer the latest message as a question about the uploaded report.

    On first use the first path in ``state["uploaded_files"]`` is loaded with
    ``Docx2txtLoader`` and its text cached in ``state["report_text"]``. The
    latest message is then sent to the model alongside the report text and the
    reply is appended to the transcript as an ``"ai"`` message.

    If no report text is available (nothing uploaded, or the document is
    empty) the model is not called; a short notice is appended instead so the
    model is never asked to explain a report that does not exist.

    Args:
        state: Conversation state. Mutated in place.

    Returns:
        The same ``state`` object. It is returned untouched when the phase is
        not ``"report_qa"``.
    """
    if state["phase"] != "report_qa":
        return state

    # Load document only once
    if not state.get("report_text") and state.get("uploaded_files"):
        loader = Docx2txtLoader(state["uploaded_files"][0])
        docs = loader.load()
        state["report_text"] = "\n".join(doc.page_content for doc in docs)

    if not state.get("report_text"):
        # Previously the prompt was rendered with the report shown as "None",
        # inviting a fabricated answer. Say so explicitly instead.
        state["messages"].append({
            "role": "ai",
            "content": "No medical report is available, so I can't answer questions about it.",
        })
        return state

    # Handle report Q&A
    last_message = state["messages"][-1]["content"]

    prompt = ChatPromptTemplate.from_template("""
    Medical Report:
    {report}
    
    Question: {question}
    Answer concisely in layman's terms:
    """)

    chain = prompt | llm
    ai_message = chain.invoke({
        "report": state["report_text"],
        "question": last_message
    }).content

    # NOTE(review): this node never changes state["phase"], so nothing here
    # moves the conversation on to "verification" - confirm where that
    # transition is meant to happen.
    state["messages"].append({"role": "ai", "content": ai_message})
    return state

In [20]:
# A numbered list line such as "1. text", "2) text", "3: text" or "4 text".
_NUMBERED_LINE = re.compile(r"^\s*\d+(?:\s*[.):-]\s*|\s+)(.+?)\s*$")


def _parse_numbered_questions(text):
    """Return the text of every numbered line in ``text``, in order."""
    matches = (_NUMBERED_LINE.match(line) for line in text.split("\n"))
    return [m.group(1) for m in matches if m]


def verification_handler(state: ChatState):
    """Ask the verification questions one at a time and record the answers.

    First call in the ``"verification"`` phase (no questions stored yet): the
    model generates questions from ``state["report_text"]``, up to three are
    kept, and "Question 1" is asked. No answer is recorded on that call,
    because the latest message predates the question.

    Later calls: the latest message is stored as the answer to the next
    unanswered question (keys are 1-based numbers as strings), then either the
    next question or a thank-you message is appended. Once every stored
    question has an answer the phase becomes ``"reporting"``.

    Args:
        state: Conversation state. Mutated in place.

    Returns:
        The same ``state`` object. It is returned untouched when the phase is
        not ``"verification"``.
    """
    if state["phase"] != "verification":
        return state

    # Generate verification questions if none exist yet
    if not state.get("verification_questions"):
        prompt = ChatPromptTemplate.from_template("""
        Based on these test findings:
        {report}
        
        Generate 3 verification questions. Number each question.
        """)

        chain = prompt | llm
        raw_questions = chain.invoke({"report": state["report_text"]}).content
        questions = _parse_numbered_questions(raw_questions)[:3]
        state["verification_questions"] = questions

        if not questions:
            # Nothing parseable came back. Move on rather than regenerating
            # (and re-billing) on every subsequent call.
            state["phase"] = "reporting"
            return state

        # Ask the first question and wait for the reply before recording
        # anything; the current last message is not an answer to it.
        state["messages"].append({
            "role": "ai",
            "content": f"Question 1: {questions[0]}"
        })
        return state

    # Process answers
    questions = state["verification_questions"]
    last_message = state["messages"][-1]["content"]
    answered = len(state["verification_answers"])

    if answered < len(questions):
        state["verification_answers"][str(answered + 1)] = last_message
        answered += 1
        if answered < len(questions):
            reply = f"Question {answered + 1}: {questions[answered]}"
        else:
            reply = "Thank you for the answers!"
        state["messages"].append({"role": "ai", "content": reply})

    # Finish when every stored question is answered (not a hard-coded 3, which
    # never triggered when fewer questions were parsed).
    if len(state["verification_answers"]) >= len(questions):
        state["phase"] = "reporting"

    return state


In [21]:
def report_generator(state: ChatState):
    """Produce the final doctor report and append it to the transcript.

    The symptom summary, report text and verification answers stored in
    ``state`` are rendered into a single prompt. The model's reply is appended
    as an ``"ai"`` message prefixed with ``DOCTOR REPORT:`` and the phase is
    set to ``"reporting"``.

    Args:
        state: Conversation state. Mutated in place.

    Returns:
        The same ``state`` object.
    """
    report_template = """
    Create a final doctor report with:
    - Patient symptoms: {symptoms}
    - Test findings: {report}
    - Verification answers: {answers}
    
    Include urgency level and recommendations.
    """
    chain = ChatPromptTemplate.from_template(report_template) | llm

    template_inputs = {
        "symptoms": state["symptoms_summary"],
        "report": state["report_text"],
        "answers": state["verification_answers"],
    }
    report_body = chain.invoke(template_inputs).content

    final_entry = {"role": "ai", "content": f"DOCTOR REPORT:\n{report_body}"}
    state["messages"].append(final_entry)
    state["phase"] = "reporting"
    return state


In [None]:
# ==================
# 3. Workflow Setup
# ==================
# Routing functions. A graph run handles exactly one patient turn: whenever a
# node has just asked the patient something, the run must END so the caller
# can collect the reply and invoke the graph again. Routing a node back to
# itself would re-run it immediately with no new input, repeating LLM calls
# until LangGraph's recursion limit aborts the run.
def _route_after_symptoms(state):
    """Continue to report processing once a summary exists; otherwise wait for input."""
    return "report_processing" if state.get("symptoms_summary") else END


def _route_after_report(state):
    """Go to verification when report text is available; otherwise straight to reporting."""
    return "verification" if state.get("report_text") else "reporting"


def _route_after_verification(state):
    """Produce the report once verification is complete; otherwise wait for input."""
    finished = (
        state.get("phase") == "reporting"
        or len(state.get("verification_answers", {})) >= 3
    )
    return "reporting" if finished else END


workflow = StateGraph(ChatState)

workflow.add_node("symptoms", symptom_collector)
workflow.add_node("report_processing", report_processor)
workflow.add_node("verification", verification_handler)
workflow.add_node("reporting", report_generator)

# Every run starts here; nodes whose phase does not match return the state
# unchanged, so later turns fall through to the stage that is active.
workflow.set_entry_point("symptoms")

workflow.add_conditional_edges(
    "symptoms",
    _route_after_symptoms,
    {"report_processing": "report_processing", END: END}
)

workflow.add_conditional_edges(
    "report_processing",
    _route_after_report,
    {"verification": "verification", "reporting": "reporting"}
)

workflow.add_conditional_edges(
    "verification",
    _route_after_verification,
    {"reporting": "reporting", END: END}
)

workflow.add_edge("reporting", END)


In [None]:
from IPython.display import Image, display

try:
    # Compile the workflow here: no module-level `agent` is defined in the
    # visible cells (the one in run_conversation() is local to that function),
    # so referencing it only ever printed a NameError.
    display(Image(workflow.compile().get_graph().draw_mermaid_png()))
except Exception as e:
    # Rendering requires some extra dependencies and is optional
    print(e)

In [None]:
# ==================
# 4. Execution Loop
# ==================
def run_conversation():
    """Drive an interactive console conversation through the workflow.

    Compiles ``workflow``, seeds an initial state, then repeatedly: runs the
    graph, prints the latest AI message, reads the patient's reply from stdin
    and appends it to the transcript. The loop stops once the state is in the
    ``"reporting"`` phase and the last message contains ``DOCTOR REPORT``, or
    when input is closed or interrupted.
    """
    agent = workflow.compile()

    state = {
        "messages": [{"role": "user", "content": "I'm having health issues"}],
        "phase": "symptoms",
        "symptoms_summary": None,
        "report_text": None,
        "verification_questions": [],
        "verification_answers": {},
        "uploaded_files": ["report.docx"]  # Replace with actual file path
    }

    while True:
        # Run the graph for this turn and take its final state. invoke()
        # returns the complete state, not a per-node update that may be partial.
        new_state = agent.invoke(state)

        # Exit condition
        if new_state["phase"] == "reporting" and "DOCTOR REPORT" in new_state["messages"][-1]["content"]:
            print("\nFinal Report:")
            print(new_state["messages"][-1]["content"])
            break

        # Show AI response (if any; a default avoids StopIteration when the
        # transcript holds no AI message yet)
        last_ai_msg = next(
            (m for m in reversed(new_state["messages"]) if m["role"] == "ai"),
            None,
        )
        if last_ai_msg is not None:
            print(f"\nAI: {last_ai_msg['content']}")

        # Get user input; stop cleanly if stdin is closed or interrupted
        try:
            user_input = input("\nPatient: ")
        except (EOFError, KeyboardInterrupt):
            print("\nConversation ended.")
            break
        new_state["messages"].append({"role": "user", "content": user_input})

        # Update state for next iteration
        state = new_state


if __name__ == "__main__":
    run_conversation()