In [4]:
from langgraph.graph import StateGraph, END
from langchain_groq import ChatGroq

from typing import TypedDict, Annotated
from pydantic import BaseModel

In [5]:
from dotenv import load_dotenv
load_dotenv()

True

In [6]:
llm = ChatGroq(model="llama-3.3-70b-versatile")

In [7]:
# 1. Define State
class RxGuardState(TypedDict):
    raw_note: str

    patient_profile: dict
    proposed_medication: dict

    retrieved_guidelines: list[dict]

    risk_analysis: dict
    safety_flags: list[str]

    recommendation: dict
    confidence: str


In [8]:
# Context Extraction...part

In [9]:
from pydantic import BaseModel, Field
from typing import List, Optional


# Format for LLM to Extract Patient Profile
class PatientProfile(BaseModel):
    age: Optional[int]
    sex: Optional[str]
    conditions: List[str] = []
    risk_factors: List[str] = []

# Format for LLM to Extract Proposed Medication
class ProposedMedication(BaseModel):
    drug_name: Optional[str]
    dose_mg_per_unit: Optional[int]
    frequency_per_day: Optional[int]
    duration_days: Optional[int]
    total_daily_dose_mg: Optional[int]

# Format for LLM to Extract the final Extraction Result
class ExtractionResult(BaseModel):
    patient_profile: PatientProfile
    proposed_medication: ProposedMedication
    extraction_confidence: float = Field(
        description="0–1 confidence in correctness"
    )


In [10]:
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.runnables import RunnableLambda

parser = PydanticOutputParser(pydantic_object=ExtractionResult)

prompt = ChatPromptTemplate.from_messages([
    ("system",
     """You are a clinical information extraction system.

TASK:
Extract structured medical facts from the input note.

RULES:
- Do NOT provide medical advice
- Do NOT infer unstated facts
- If information is missing, use null
- Output must strictly match the JSON schema
- No explanations, no prose

NORMALIZATION RULES:
- Normalize sex to: "male" or "female"
- Normalize CKD stages to: "Chronic Kidney Disease Stage X"
- If CKD is present, include "renal impairment" in risk_factors
- Chronic diseases → conditions
- Pain, discomfort → symptoms

NOTE:
This output will be used by downstream safety systems.
"""),
    ("human",
     "Clinical note:\n{note}\n\n"
     "Return JSON matching this schema:\n{format_instructions}")
])

extractor_chain = prompt | llm | parser


In [11]:
CONFIDENCE_THRESHOLD = 0.75

def run_extraction(note: str) -> ExtractionResult:
    result = extractor_chain.invoke({
        "note": note,
        "format_instructions": parser.get_format_instructions()
    })
    
    if result.extraction_confidence < CONFIDENCE_THRESHOLD:
        raise ValueError(
            f"Extraction confidence too low: {result.extraction_confidence}"
        )

    return result


In [None]:
def clinical_context_extraction_node(state: RxGuardState) -> RxGuardState:
    extraction = run_extraction(state["raw_note"])

    state["patient_profile"] = extraction.patient_profile.model_dump()
    state["proposed_medication"] = extraction.proposed_medication.model_dump()
    state["confidence"] = extraction.extraction_confidence

    return state


In [13]:
test_state = {
    "raw_note": "65M, Stage 3 CKD, severe back pain. Plan: Ibuprofen 800mg TID x5 days.",
    "patient_profile": {},
    "proposed_medication": {},
    "retrieved_guidelines": [],
    "risk_analysis": {},
    "safety_flags": [],
    "recommendation": {},
    "confidence": ""
}

output_state = clinical_context_extraction_node(test_state)

print(output_state)


{'raw_note': '65M, Stage 3 CKD, severe back pain. Plan: Ibuprofen 800mg TID x5 days.', 'patient_profile': {'age': 65, 'sex': 'male', 'conditions': ['Chronic Kidney Disease Stage 3'], 'risk_factors': ['renal impairment']}, 'proposed_medication': {'drug_name': 'Ibuprofen', 'dose_mg_per_unit': 800, 'frequency_per_day': 3, 'duration_days': 5, 'total_daily_dose_mg': 2400}, 'retrieved_guidelines': [], 'risk_analysis': {}, 'safety_flags': [], 'recommendation': {}, 'confidence': 1.0}


In [14]:
from IPython.display import JSON

JSON(output_state)

<IPython.core.display.JSON object>

In [15]:
output_state["patient_profile"]

{'age': 65,
 'sex': 'male',
 'conditions': ['Chronic Kidney Disease Stage 3'],
 'risk_factors': ['renal impairment']}

In [16]:

output_state["proposed_medication"]

{'drug_name': 'Ibuprofen',
 'dose_mg_per_unit': 800,
 'frequency_per_day': 3,
 'duration_days': 5,
 'total_daily_dose_mg': 2400}

In [19]:
output_state["retrieved_guidelines"]

[]

In [17]:
output_state["raw_note"]

'65M, Stage 3 CKD, severe back pain. Plan: Ibuprofen 800mg TID x5 days.'

In [18]:
output_state["confidence"]

1.0