In [None]:
pip install docx2txt python-docx

In [22]:
from dotenv import load_dotenv
import os

In [23]:

# Load environment variables from a .env file (if present) into os.environ.
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
# load_dotenv already populates os.environ, so there is nothing to copy back.
# Fail early with a clear message instead of the TypeError that
# `os.environ[...] = None` would raise when the key is missing.
if not api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file."
    )

In [24]:
from typing import TypedDict, Optional, Literal, List, Dict
from langchain_community.document_loaders import Docx2txtLoader
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
import re


In [25]:
# Chat model shared by every node; tune these two values to change behaviour.
LLM_MODEL = "gpt-4o-mini"
LLM_TEMPERATURE = 0.3
llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)


In [26]:
# ==================
# 1. State Definition
# ==================
# Conversation stages, in the order the node functions advance through them.
Phase = Literal[
    "symptom_interview",
    "report_processing",
    "verification",
    "completed",
]


class MedicalState(TypedDict):
    """State dict shared by all nodes of the interview graph."""

    messages: List[Dict[str, str]]  # chat history of {"role": ..., "content": ...} dicts
    phase: Phase  # current stage of the conversation
    symptoms_summary: Optional[str]  # text captured after 'SUMMARY:'
    report_text: Optional[str]  # plain text loaded from the report document
    verification_questions: List[str]  # questions generated from the report
    verification_answers: Dict[int, str]  # question index -> patient's answer


In [27]:
# ==================
# 2. Graph Nodes
# ==================
def symptom_interview(state: MedicalState):
    """Run one turn of the symptom interview.

    The whole conversation so far is sent to the LLM, which either asks one
    more question or finishes with a ``SUMMARY:`` line. The assistant reply is
    always appended to ``state["messages"]``. When a non-empty summary is found
    it is stored in ``state["symptoms_summary"]`` and the phase moves to
    ``"report_processing"``.

    Args:
        state: Current graph state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    if state["phase"] != "symptom_interview":
        return state

    system_prompt = (
        "You are a medical intake assistant. Ask the patient ONE question at a "
        "time about their symptoms. When you have enough information, stop "
        "asking questions and reply with 'SUMMARY: <concise summary of the "
        "symptoms>'."
    )
    # Send the full history so the model remembers earlier answers. The
    # messages are passed as data, not as prompt templates, so braces in the
    # patient's text cannot break formatting.
    conversation = [{"role": "system", "content": system_prompt}, *state["messages"]]
    ai_message = llm.invoke(conversation).content
    state["messages"].append({"role": "assistant", "content": ai_message})

    # The summary runs from 'SUMMARY:' up to the first blank line (or the end).
    match = re.search(r"SUMMARY:\s*(.*?)(?=\n\n|\Z)", ai_message, re.DOTALL)
    summary = match.group(1).strip() if match else ""
    # An empty summary keeps the interview going rather than advancing the
    # phase without anything to hand on.
    if summary:
        state["symptoms_summary"] = summary
        state["phase"] = "report_processing"

    return state

In [39]:
# Path of the patient's report document; point this at the real file.
REPORT_PATH = "report.docx"


def process_report(state: MedicalState):
    """Load the medical report and generate verification questions from it.

    If ``state["report_text"]`` is empty, the document at ``REPORT_PATH`` is
    loaded into it. The LLM is asked for three numbered questions. Up to
    three parsed questions are stored in ``state["verification_questions"]``
    and the patient is asked the first one.

    The phase becomes ``"verification"``. If no question could be parsed from
    the reply, it becomes ``"completed"`` instead and an apology is posted.

    Args:
        state: Current graph state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    if state["phase"] != "report_processing":
        return state

    if not state.get("report_text"):
        docs = Docx2txtLoader(REPORT_PATH).load()
        state["report_text"] = "\n".join(d.page_content for d in docs)

    prompt = ChatPromptTemplate.from_template("""
    Analyze this medical report:
    {report}
    
    Generate 3 verification questions numbered 1-3:
    1. ...
    2. ...
    3. ...
    """)

    raw = (prompt | llm).invoke({"report": state["report_text"]}).content

    # Accept "1. q", "1) q", "1.q" or "1 q" (optionally indented). Requiring a
    # separator right after the digit rejects lines such as "12 ..." or "1990 ...".
    questions = [
        m.group(1).strip()
        for m in re.finditer(
            r"^[ \t]*[1-3](?:[.)][ \t]*|[ \t]+)(\S.*)$", raw, re.MULTILINE
        )
    ][:3]
    state["verification_questions"] = questions

    if not questions:
        state["phase"] = "completed"
        state["messages"].append({
            "role": "assistant",
            "content": "Sorry, I could not generate verification questions from the report.",
        })
        return state

    state["phase"] = "verification"
    numbered = "\n".join(f"{i}. {q}" for i, q in enumerate(questions, start=1))
    state["messages"].append({
        "role": "assistant",
        "content": (
            "Please answer these verification questions one at a time:\n"
            f"{numbered}\n\nQuestion 1: {questions[0]}"
        ),
    })

    return state

In [40]:
def handle_verification(state: "MedicalState"):
    """Record the patient's answer and ask the next verification question.

    The latest message is stored as the answer to the first unanswered
    question. The next question is then posted. Once every question has an
    answer, the phase becomes ``"completed"`` and a closing message is
    posted instead.

    Args:
        state: Current graph state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    if state["phase"] != "verification":
        return state

    questions = state["verification_questions"]
    answers = state["verification_answers"]
    current = len(answers)  # index of the first unanswered question

    if current < len(questions):
        answers[current] = state["messages"][-1]["content"]
        current += 1

    if current < len(questions):
        # Ask the next unanswered question, not the one just answered.
        state["messages"].append({
            "role": "assistant",
            "content": f"Question {current + 1}: {questions[current]}",
        })
    else:
        state["phase"] = "completed"
        state["messages"].append({
            "role": "assistant",
            "content": "Thank you, all verification questions have been answered.",
        })

    return state


In [None]:
# ==================
# 3. Simplified Workflow
# ==================
# Each invoke() performs ONE conversational turn: the entry router picks the
# node for the current phase, the node runs, and the graph stops at END so the
# caller can collect the patient's next reply. Self-loops would instead re-run
# a node inside a single invoke() with no new user input.
workflow = StateGraph(MedicalState)

workflow.add_node("symptoms", symptom_interview)
workflow.add_node("report", process_report)
workflow.add_node("verify", handle_verification)
workflow.add_node("end", lambda state: state)


def _route_by_phase(state: MedicalState) -> str:
    """Map the current phase to the node that handles it."""
    return {
        "symptom_interview": "symptoms",
        "report_processing": "report",
        "verification": "verify",
    }.get(state["phase"], "end")


workflow.set_conditional_entry_point(
    _route_by_phase,
    {"symptoms": "symptoms", "report": "report", "verify": "verify", "end": "end"},
)

# When the interview has produced its summary, analyse the report in the same
# turn; otherwise stop and wait for the patient's answer.
workflow.add_conditional_edges(
    "symptoms",
    lambda s: "report" if s["phase"] == "report_processing" else "wait",
    {"report": "report", "wait": END},
)

workflow.add_edge("report", END)
workflow.add_edge("verify", END)
workflow.add_edge("end", END)



In [None]:
from IPython.display import Image, display

# Compile here: `agent` only exists as a local inside run_medical_interview(),
# so referencing it in this cell raised NameError.
try:
    display(Image(workflow.compile().get_graph().draw_mermaid_png()))
except Exception as e:
    # Rendering requires some extra dependencies and is optional.
    print(e)

In [None]:
# ==================
# 4. Interactive Execution
# ==================
def run_medical_interview(opening_message: str = "I need medical help"):
    """Drive the interview from the console until the workflow completes.

    Each loop iteration runs the compiled workflow on the current state,
    prints the latest assistant message and appends the patient's typed reply.
    Typing ``quit`` or ``exit`` ends the session early.

    Args:
        opening_message: First user message that starts the conversation.

    Returns:
        The final state dict.
    """
    agent = workflow.compile()

    state = {
        "messages": [{"role": "user", "content": opening_message}],
        "phase": "symptom_interview",
        "symptoms_summary": None,
        "report_text": None,
        "verification_questions": [],
        "verification_answers": {},
    }

    while True:
        state = agent.invoke(state)

        # A default avoids StopIteration if no assistant message exists yet.
        last_ai = next(
            (m for m in reversed(state["messages"]) if m["role"] == "assistant"),
            None,
        )
        if last_ai is not None:
            print(f"AI: {last_ai['content']}")

        if state["phase"] == "completed":
            print("\nGenerating final report...")
            break

        user_input = input("Patient: ")
        if user_input.strip().lower() in {"quit", "exit"}:
            print("Interview ended by user.")
            break
        state["messages"].append({"role": "user", "content": user_input})

    return state

if __name__ == "__main__":
    run_medical_interview()
