In [1]:
from configs import models, env
from backend.llm.api import load_llm_langchain

# LLM backend used by every agent node; change these two values to swap provider/model.
# The loader maps the model alias to a concrete model id
# (the recorded run resolved "LLaMA-3" to 'llama-3.1-8b-instant' on groq).
LLM_SOURCE = "groq"
LLM_MODEL_NAME = "LLaMA-3"

# Bundle the project's model config and environment settings into the loader's config dict.
config_loaded = dict(model_config=models, env=env)
llm = load_llm_langchain(
    source=LLM_SOURCE,
    model_name=LLM_MODEL_NAME,
    config=config_loaded,
)

[LLM Loader] Successfully initialized model 'llama-3.1-8b-instant' from 'groq'.


In [4]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional, List
from backend.agents.nodes import (
    symptom_node,
    ehr_node,
    literature_node,
    drug_node,
    treatment_node,
)


# Step 1: Define the schema
class AgentState(TypedDict):
    symptoms: Optional[str]
    ehr_text: Optional[str]
    medications: Optional[List[str]]
    question: Optional[str]
    diagnosis: Optional[str]
    summary: Optional[str]
    answer: Optional[str]
    plan: Optional[str]
    drug_warnings: Optional[str]
    patient_profile: Optional[dict]


graph = StateGraph(state_schema=AgentState)

# # Add nodes
# graph.add_node("symptom_checker", symptom_node)
# graph.add_node("ehr_summarizer", ehr_node)
# graph.add_node("literature_qa", literature_node)
# graph.add_node("drug_checker", drug_node)
# graph.add_node("treatment_planner", treatment_node)

# Bind graph nodes using lambdas to pass llm
graph.add_node("symptom_checker", lambda state: symptom_node(state, llm))
graph.add_node("ehr_summarizer", lambda state: ehr_node(state, llm))
graph.add_node("literature_qa", lambda state: literature_node(state, llm))
graph.add_node("drug_checker", lambda state: drug_node(state, llm))
graph.add_node("treatment_planner", lambda state: treatment_node(state, llm))



# Set flow
graph.set_entry_point("symptom_checker")
graph.add_edge("symptom_checker", "ehr_summarizer")
graph.add_edge("ehr_summarizer", "literature_qa")
graph.add_edge("literature_qa", "drug_checker")
graph.add_edge("drug_checker", "treatment_planner")
graph.add_edge("treatment_planner", END)

# ## 🚀 Run Graph with Enriched Inputs

app_graph = graph.compile()

In [5]:
initial_state = {
    "symptoms": "shortness of breath, chest pain, fatigue",
    "ehr_text": "Patient has a history of hypertension and presents with elevated troponins...",
    "question": "What is the current recommendation for NSTEMI management in elderly patients?",
    "medications": ["Aspirin", "Warfarin"],
    "patient_profile": {
        "diagnosis": "NSTEMI",
        "age": 72,
        "sex": "Male",
        "comorbidities": ["Hypertension", "Atrial Fibrillation"]
    }
}

final_state = app_graph.invoke(initial_state)

for key, value in final_state.items():
    print(f"\n📌 {key.upper()}:")
    if isinstance(value, str):
        print(value)
    elif hasattr(value, "content"):
        print(value.content)
    elif isinstance(value, list):
        # Print each item or summarize the list
        for i, item in enumerate(value, 1):
            print(f"  {i}. {item}")
    else:
        # fallback, print repr
        print(repr(value))
    print()


🔑 Keys in state at symptom_node: ['symptoms', 'ehr_text', 'medications', 'question', 'patient_profile']
🔑 Keys in state at ehr_node: ['symptoms', 'ehr_text', 'medications', 'question', 'diagnosis', 'patient_profile']
🔑 Keys in state at literature_node: ['symptoms', 'ehr_text', 'medications', 'question', 'diagnosis', 'summary', 'patient_profile']
🔑 Keys in state at drug_node: ['symptoms', 'ehr_text', 'medications', 'question', 'diagnosis', 'summary', 'patient_profile']
🔑 Keys in state at treatment_node: ['symptoms', 'ehr_text', 'medications', 'question', 'diagnosis', 'summary', 'patient_profile']

📌 SYMPTOMS:
shortness of breath, chest pain, fatigue


📌 EHR_TEXT:
Patient has a history of hypertension and presents with elevated troponins...


📌 MEDICATIONS:
  1. Aspirin
  2. Warfarin


📌 QUESTION:
What is the current recommendation for NSTEMI management in elderly patients?


📌 DIAGNOSIS:
**Ranked List of Likely Diagnoses:**

1. **Acute Coronary Syndrome (ACS)**: High likelihood due to c