In [1]:
import json
from typing import Callable
from pydantic import BaseModel, Field
from langchain_ollama.chat_models import ChatOllama

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

from langgraph.graph import START, MessagesState, StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.message import add_messages, Annotated
from langchain_core.runnables import RunnableConfig
from langgraph.errors import GraphRecursionError

from utils import display_graph, fprint


In [2]:
class DoctorResponse(BaseModel):
    question: str = Field(description="The question asked by the doctor")
    diagnoses: list[str] = Field(description="The diagnoses the doctor is considering (max: 3 items)")

In [71]:
MODEL_NAME = "aya-expanse:32b"

PATIENT_PROMPT = """
You are a patient with a certain psychiatric diagnosis.
The description of your case and the diagnoses you might have are as follows:
DIAGNOSIS:
{diagnosis}

CASE DESCRIPTION:
{case}

You are being interviewed by a psychiatrist for a detailed diagnostic evaluation.

Your task is to answer the last question from the psychiatrist based on your life experiences and current symptoms, which have to be derived from the case description and the diagnosis.
The tone and language of your answer should be compatible with the case description and the diagnosis.
Do not describe your behavior or attitude, just give the answer in less than 100 words.
"""


DOCTOR_PROMPT = """
You are a psychiatrist interviewing a patient.
Your task is to systematically ask the patient questions to gather sufficient information for a DSM-5 diagnosis. The first question should be nonspecific, like 'How can I help you today?'.
But, as the interview progresses, formulate more focused questions that help narrow down the range of possible diagnoses.

Be gentle, empathetic and responsive to the patient.
If the conversation drifts away too far, bring the conversation back to reach the diagnosis.

Formulate a question in response to the patient's last answer and provide the candidate diagnoses you can consider.
The response should be a JSON object with the following keys:
- 'question': The question you formulated in less than 100 words.
- 'diagnoses': The candidate diagnoses you are considering (list of strings, max: 5 items)
If you do not consider any candidate diagnoses yet, just respond with an empty list.

Example of a response: 
('question': 'Can you describe your symptoms in more detail?', 
 'diagnoses': ['Depressive Disorder', 'Anxiety Disorder'])

If you have reached a diagnosis, the question should be 'TERMINATE' like this example:
('question': 'TERMINATE', 'diagnoses': ['Depressive Disorder'])
"""


In [72]:
llm_patient = ChatOllama(model=MODEL_NAME)
llm_doctor = ChatOllama(model=MODEL_NAME).with_structured_output(DoctorResponse)


In [73]:
cases = json.load(open("../Data/DSM-5-TR Clinical Cases.json", "r"))

In [74]:
class DxMessageState(MessagesState):
    messages: Annotated[list[str], add_messages]
    diagnoses: Annotated[list[str], add_messages]



In [84]:
def extract_messages(app: CompiledStateGraph, thread_id: str):
    config = RunnableConfig(configurable={"thread_id": thread_id})
    history = next(app.get_state_history(config=config))
    return history.values["messages"]


def print_message(output: list[str]):
    for i, message in enumerate(output[1:]):
        if i % 2 == 0:
            fprint(f"Doctor:\n {message.content}")
        else:
            fprint(f"Patient:\n {message.content}")
        print("-" * 80)

def prepare_patient_node(case: dict[str, str], llm: ChatOllama) -> Callable:
    diagnosis = case.get("diagnosis", "")
    case_text = case.get("text", "")

    if not diagnosis or not case_text:
        raise ValueError("Diagnosis and case description are required")

    patient_prompt = PATIENT_PROMPT.format(diagnosis=diagnosis, case=case_text)

    # Define the function that calls the model
    def patient_node(state: MessagesState):
        print("Patient thinking")
        # last message is from the doctor
        last_message = state["messages"][-1].copy()
        question = json.loads(last_message.content).get("question", "")
        last_message = HumanMessage(content=question)
        
        
        # messages = [SystemMessage(content=patient_prompt)] + state["messages"][:-1] + [last_message]
        messages = [SystemMessage(content=patient_prompt)] + [last_message]
        print(f"MESSAGES: {messages}")
        response = llm.invoke(messages)
        print(f"PATIENTRESPONSE: {response}")
        return {"messages": [response]}

    return patient_node


def prepare_doctor_node(llm: ChatOllama) -> Callable:
    
    doctor_prompt = DOCTOR_PROMPT

    def doctor_node(state: MessagesState):
        print("Doctor thinking")
        last_message = state["messages"][-1]
        if isinstance(last_message, AIMessage):
            last_message = HumanMessage(content=last_message.content)

        messages = [SystemMessage(content=doctor_prompt)] + state["messages"]
        response = llm.invoke(messages)
        if not response:
            response = {'question': "Can you elaborate on that?", 'diagnoses': []}
            return {"messages": [AIMessage(content=json.dumps(response))]}
        
        print(f"DOCTOR RESPONSE: {response}\n\n")
            
        return {"messages": [AIMessage(content=json.dumps(response.model_dump()))]}
        
    return doctor_node


def router(state: MessagesState):
    print("Router")
    print(state["messages"][-1])
    
    last_message = json.loads(state["messages"][-1].content).get("question", "")
    if "TERMINATE" in last_message:
        return END
    return "continue"


def setup_workflow(case: dict[str, str], llm_patient: ChatOllama, llm_doctor: ChatOllama) -> CompiledStateGraph:
    
    workflow = StateGraph(MessagesState)
    
    # Define the nodes
    patient_node = prepare_patient_node(case, llm_patient)
    doctor_node = prepare_doctor_node(llm_doctor)
    
    # Add the nodes
    workflow.add_node("patient", patient_node)
    workflow.add_node("doctor", doctor_node)
            
    # Add the edges
    workflow.add_edge(START, "doctor")
    workflow.add_conditional_edges(
        "doctor",
        router,
        {
            "continue": "patient",
            END: END,
        },
    )
    workflow.add_edge("patient", "doctor")

    app = workflow.compile()

    return app


def invoke_app(app: CompiledStateGraph, recursion_limit: int = 20):
    config = RunnableConfig(recursion_limit=recursion_limit)
    try:
        events = app.stream(HumanMessage(content="start the interview"), config=config)

        for event in events:
            for k, v in event.items():
                print(v["messages"].content)
            print("-------")
    except GraphRecursionError as e:
        print(e)


In [85]:
# llm = ChatOllama(model="llama3.1")
# cases = json.load(open("../Data/DSM-5-TR Clinical Cases.json", "r"))


def main(case: dict[str, str], llm: ChatOllama, recursion_limit: int = 6):
    app = setup_workflow(case, llm)
    invoke_app(app, recursion_limit)
    
    return extract_messages(app)


In [86]:
app = setup_workflow(cases[0], llm_patient, llm_doctor)
app.invoke({"messages": [HumanMessage(content="patient is a 16 year female student")]})


Doctor thinking
DOCTOR RESPONSE: question='How can I help you today?' diagnoses=[]


Router
content='{"question": "How can I help you today?", "diagnoses": []}' additional_kwargs={} response_metadata={} id='53aa7b1a-80bc-4fb0-8fc9-a937fbdc3895'
Patient thinking
MESSAGES: [SystemMessage(content="\nYou are a patient with a certain psychiatric diagnosis.\nThe description of your case and the diagnoses you might have are as follows:\nDIAGNOSIS:\n['Intellectual developmental disorder (intellectual disability), severe', 'Autism spectrum disorder, with accompanying intellectual and language']\n\nCASE DESCRIPTION:\n Ashley, age 17, was referred for a diagnostic reevaluation after having carried diagnoses\n\n of autism and intellectual disability for almost all of her life. She was recently found to\n\n have Kleefstra syndrome, and the family wanted to reconfirm the earlier diagnoses and\n\n assess the genetic risk to the future children of her older sisters.\n\n At the time of the reevaluation

{'messages': [HumanMessage(content='patient is a 16 year female student', additional_kwargs={}, response_metadata={}, id='7026677c-4e38-4967-9598-e674e3685d30'),
  AIMessage(content='{"question": "How can I help you today?", "diagnoses": []}', additional_kwargs={}, response_metadata={}, id='53aa7b1a-80bc-4fb0-8fc9-a937fbdc3895'),
  AIMessage(content='"Help me understand why I struggle with things that others seem to find easy, like understanding what I read or remembering how to do new tasks. Also, sometimes I feel really upset and don\'t know how to calm down without hurting myself. I want to learn better ways to express my feelings."', additional_kwargs={}, response_metadata={'model': 'aya-expanse:32b', 'created_at': '2024-11-17T09:12:07.982369Z', 'message': {'role': 'assistant', 'content': ''}, 'done_reason': 'stop', 'done': True, 'total_duration': 14152272167, 'load_duration': 16125875, 'prompt_eval_count': 1026, 'prompt_eval_duration': 9569000000, 'eval_count': 60, 'eval_duration'

In [87]:
doctor = prepare_doctor_node(llm_doctor)
doctor(state)



Doctor thinking


{'messages': [AIMessage(content='{"question": "", "diagnoses": []}', additional_kwargs={}, response_metadata={})]}

In [47]:
response = Out[45]

In [51]:
question = json.loads(response['messages'][0].content).get("question", "")
question


'How can I help you today?'

In [43]:
doctor = prepare_doctor_node(llm_doctor)

In [None]:
state

In [None]:
response = doctor(state)

In [None]:
response