## Setup

In [1]:
import sys

# Make the project's source tree importable (presumably where `mediqa_oe` lives).
# Guarded so that re-running this cell does not stack duplicate sys.path entries.
SRC_DIR = "../src/"
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

In [2]:
from load_dotenv import load_dotenv

# Load environment configuration so that OPENAI_API_BASE / OPENAI_API_KEY,
# read via os.getenv when the LM is constructed, are available.
# NOTE(review): the common python-dotenv package is imported as
# `from dotenv import load_dotenv`; confirm `load_dotenv` is the intended dependency.
load_dotenv()

True

In [None]:
import json
import os
import re
from dataclasses import dataclass
from pprint import pprint
from typing import Any, Dict, List, Optional, Tuple

from IPython.display import Markdown
from loguru import logger

import numpy as np
import pandas as pd

import torch
from datasets import Dataset, load_dataset

from mediqa_oe.data import MedicalOrderDataLoader
from mediqa_oe.lm import OrderExtractionLM
from mediqa_oe.lm.base import BaseAgent

  from .autonotebook import tqdm as notebook_tqdm


## Load Data and LM

In [None]:
@dataclass
class TranscriptTurn:
    """A single turn of the doctor-patient conversation."""

    turn_id: int     # identifier of the turn; referenced by order provenance
    speaker: str     # speaker label, e.g. "DOCTOR" or "PATIENT"
    transcript: str  # the text spoken in this turn


@dataclass
class MedicalOrder:
    """One extracted medical order together with its justification."""

    order_type: str        # medication / lab / imaging / followup
    description: str       # the order as phrased in the conversation
    reason: str            # clinical justification ("" when none was found)
    provenance: List[int]  # turn ids that support this order


In [None]:
# Placeholder: set this to the test-set JSON path before running the example-usage cell.
input_json_path = "<input_json_path for test>"

In [4]:
data_loader = MedicalOrderDataLoader(
    trs_json_path="../data/orders_data_transcript.json",
)

# The two dataset splits exposed by the loader (ds_val is iterated later for evaluation).
ds = data_loader.ds
ds_val = data_loader.ds_val

In [5]:
# Endpoint and credentials come from the environment (loaded in the setup cell).
openai_api_base = os.getenv("OPENAI_API_BASE")
openai_api_key = os.getenv("OPENAI_API_KEY")

lm = OrderExtractionLM(
    backend="openai",
    # NOTE(review): empty model name — presumably the backend resolves the served model; confirm.
    model_name_or_path="",
    api_base=openai_api_base,
    api_key=openai_api_key,
)

## Methods Demo

In [6]:
# List the model ids served by the configured endpoint.
served_models = lm.impl.client.models.list().data
print([entry.id for entry in served_models])

['google/medgemma-27b-text-it']


In [7]:
# Shows where inference runs; for a remote backend this is the endpoint URL,
# so review the output before sharing the notebook.
device_info = lm.get_device_info()
device_info

'Remote: https://e307wui0v6xrqf-8000.proxy.runpod.net/v1/'

In [8]:
# Minimal smoke test of the LM: a system instruction plus one user question.
test_msg = [
    {
        "role": "system",
        "content": "You are a medical AI assistant who answers in one sentence.",
    },
    {
        "role": "user",
        "content": "Hi, what kind of assistant are you?",
    },
]

out = lm.infer(
    messages=test_msg
)

Markdown(out)

I am a medical AI assistant designed to provide concise, one-sentence answers to health-related questions.


In [9]:
# Same prompt through the streaming API; each chunk is printed as soon as it arrives.
stream = lm.infer_stream(messages=test_msg)
for piece in stream:
    print(piece, end="", flush=True)

I am a medical AI assistant designed to provide information and answer questions related to health and medicine in a single sentence.


## Agentic Approach

In [None]:
class Agent1_Extractor(BaseAgent):
    """Agent 1: Extract descriptions and reasons from clinical transcripts.

    Every DOCTOR turn is sent to the LLM on its own and the model is asked for a
    JSON object ``{"descriptions": [...], "reasons": [...]}``. When the reply is
    empty or cannot be parsed into that shape, a keyword-regex fallback over the
    raw turn text is used instead.
    """

    def process(self, transcript: List[TranscriptTurn]) -> Tuple[Dict[int, List[str]], Dict[int, List[str]]]:
        """Extract medical descriptions and reasons from transcript.

        Args:
            transcript: the conversation turns.

        Returns:
            ``(description_dict, reason_dict)``, each mapping ``turn_id`` to the
            phrases extracted from that turn. Turns with nothing extracted are
            omitted from the corresponding dict.
        """

        system_prompt = """
                    You are a clinically trained assistant tasked with extracting only *precise, actionable medical orders* and their justifying reasons from a clinical conversation.

                    Your goal is to extract:

                    1. DESCRIPTIONS: Only *clear, clinically actionable medical orders* related to:
                    - Medications (e.g., "start metformin", "continue omeprazole 20 milligrams daily")
                    - Lab tests (e.g., "order a hemoglobin a1c", "check white blood cell count")
                    - Imaging studies (e.g., "schedule a chest x-ray", "get an MRI of the brain")
                    - Follow-ups or Referrals (e.g., "see endocrinologist", "come back in 2 weeks")

                    2. REASONS: Only *clinically meaningful problems, diagnoses, or symptoms* that clearly explain why the above order is needed (e.g., "for diabetes", "due to shortness of breath")

                    ---

                    🔒 STRICT RULES (DO NOT VIOLATE):
                    - ❌ DO NOT extract vague suggestions ("weigh yourself", "call me", "go to the ER")
                    - ❌ DO NOT extract non-actionable observations or physical findings ("vital signs", "crackles", "edema", "murmur", "exam looks good")
                    - ❌ DO NOT include commands to medical staff ("show me labs", "pull up x-ray")
                    - ✅ ONLY extract direct clinical actions involving patient care
                    - ✅ EXTRACT AT MOST 10 DESCRIPTIONS AND 10 REASONS PER TURN

                    ---

                    🧠 SELF-CHECK:
                    - ✅ Is the phrase a *clear medical order* related to drugs, labs, imaging, or follow-up?
                    - ✅ Is the reason a *disease, symptom, or clinical condition* directly tied to the action?
                    - ❌ If unsure or borderline — DO NOT extract it.

                    ---

                    🔍 FORMAT:
                    Return a single *valid JSON object*:
                    {
                    "descriptions": [list of exact phrases, each ≤30 tokens],
                    "reasons": [list of exact phrases, each ≤30 tokens]
                    }

                    ---

                    ✅ GOOD EXAMPLES (KEEP):
                    "order a hemoglobin a1c"
                    "put you on some lasix 40 milligrams a day"
                    "continue you on the omeprazole 20 milligrams a day"
                    "refer you to psychiatry"
                    "follow up with your endocrinologist"
                    "for your type i diabetes"
                    "due to acute heart failure exacerbation"

                    ❌ BAD EXAMPLES (REMOVE):
                    "show me the vital signs"
                    "labs look okay"
                    "oxygenation level"
                    "reviewed the results"
                    "weigh yourself every day"
                    "call 911"
                    "neck exam"
                    "it looks good"
                    "pulse ox"

                    ---

                    📌 Final reminder:
                    Only return *fully valid, concise, medically actionable* JSON. If there is *nothing valid*, return:
                    {"descriptions": [], "reasons": []}
             """

        description_dict = {}
        reason_dict = {}

        for turn in transcript:
            # Focus on medically relevant speakers
            if turn.speaker not in ["DOCTOR"]:
                continue

            user_input = f"""Turn {turn.turn_id}: "{turn.transcript}"

                Extract medical descriptions and reasons from this turn. Return only JSON:
                {{"descriptions": ["exact phrase 1", "exact phrase 2"], "reasons": ["reason 1", "reason 2"]}}"""

            messages = self._create_messages(system_prompt, user_input)
            response = self.llm.infer(messages, max_new_tokens=500)

            parsed = self._parse_response(response)
            if parsed is None:
                # The LLM failed or did not answer with the requested JSON object:
                # fall back to regex extraction over the raw turn text.
                descriptions = self._extract_with_regex(turn.transcript, "description")
                reasons = self._extract_with_regex(turn.transcript, "reason")
            else:
                descriptions, reasons = parsed

            if descriptions:
                description_dict[turn.turn_id] = descriptions
            if reasons:
                reason_dict[turn.turn_id] = reasons

        return description_dict, reason_dict

    @staticmethod
    def _parse_response(response: Optional[str]) -> Optional[Tuple[List[str], List[str]]]:
        """Parse an LLM reply into ``(descriptions, reasons)``.

        Returns None when the reply is empty, contains no ``{...}`` span, or that
        span is not valid JSON. A bare string where a list was requested is wrapped
        in a list; non-string and blank entries are dropped so downstream agents
        only ever see lists of phrases.
        """
        if not response:
            return None
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if not json_match:
            return None
        try:
            parsed = json.loads(json_match.group())
        except json.JSONDecodeError:
            return None

        def _phrases(value: Any) -> List[str]:
            if isinstance(value, str):
                value = [value]
            if not isinstance(value, list):
                return []
            return [item for item in value if isinstance(item, str) and item.strip()]

        return _phrases(parsed.get("descriptions")), _phrases(parsed.get("reasons"))

    def _extract_with_regex(self, text: str, extract_type: str) -> List[str]:
        """Fallback regex extraction if LLM fails.

        Args:
            text: raw text of one transcript turn.
            extract_type: ``"description"`` extracts order-like phrases; any other
                value extracts reason-like phrases.

        Returns:
            Up to 3 distinct phrases captured after a trigger word, in the order the
            patterns find them.
        """
        # Word boundaries keep triggers such as "do" / "for" from firing inside
        # longer words ("undo", "therefore").
        if extract_type == "description":
            patterns = [
                r"\b(?:order|prescribe|schedule|book|perform|do|give|take)\s+([^.!?]{10,100})",
                r"\b(?:need|want|should|will)\s+(?:to\s+)?([^.!?]{10,100})",
                r"\b(?:let's|we'll)\s+([^.!?]{10,100})",
            ]
        else:  # reason
            patterns = [
                r"\b(?:for|due to|because of|to treat|regarding)\s+([^.!?]{10,100})",
                r"\b(?:pain|symptoms?|condition|problem|difficulty|trouble)\s+([^.!?]{10,100})",
            ]

        results: List[str] = []
        for pattern in patterns:
            for match in re.findall(pattern, text, re.IGNORECASE):
                phrase = match.strip()
                # Patterns overlap, so the same phrase can be captured twice; keep it once.
                if len(phrase) > 5 and phrase not in results:
                    results.append(phrase)

        return results[:3]  # Limit to 3 per turn


In [None]:
class Agent2_Mapper(BaseAgent):
    """Agent 2: Map descriptions to reasons using transcript context.

    Sends every extracted description/reason (with the turn it came from) plus the
    full transcript to the LLM and asks for a JSON array of
    ``{"description", "reason", "provenance"}`` mappings. If the reply cannot be
    parsed into that shape, a heuristic scorer is used instead (turn distance,
    patient-to-doctor flow, shared words).
    """

    def process(self, description_dict: Dict[int, List[str]], reason_dict: Dict[int, List[str]],
                transcript: List[TranscriptTurn]) -> List[Dict[str, Any]]:
        """Map descriptions to reasons using clinical context and transcript.

        Args:
            description_dict: ``turn_id`` -> description phrases (from Agent 1).
            reason_dict: ``turn_id`` -> reason phrases (from Agent 1).
            transcript: the full conversation.

        Returns:
            A list of dicts with exactly the keys ``description`` (str), ``reason``
            (str, possibly empty) and ``provenance`` (list of turn ids).
        """
        if not description_dict:
            # Nothing to map; an LLM call here could only invent orders.
            return []

        system_prompt = """

                You are a clinical expert validating the correctness of mappings between medical orders ("descriptions") and clinical justifications ("reasons") using the actual doctor-patient transcript.

                You are given:
                - EXTRACTED DESCRIPTIONS: [{"turn_id": ..., "text": ...}, ...] (candidate medical orders)
                - EXTRACTED REASONS: [{"turn_id": ..., "text": ...}, ...] (candidate clinical justifications)
                - TRANSCRIPT CONTEXT: the full conversation, one line per turn as "Turn <turn_id> (<speaker>): <text>"

                Your job is to pair each description with its best-supported reason (or an empty string if none applies), then strictly verify each resulting mapping.

                For each mapping:
                1. ✅ The *description* must be a real, actionable clinical order clearly stated in one of the provenance turns. Examples: prescribing a medication, ordering a test, requesting follow-up, imaging, or referral.
                2. ✅ The *reason* must be a clearly stated clinical problem, symptom, or diagnosis — also found within the provenance turns or in immediately adjacent context.
                3. ✅ The *provenance* must include turn IDs where both description and reason are mentioned or inferred from context.
                4. ✅ The mapping must be *clinically logical* — would a real doctor give that order for that reason?

                Strict Rejection Rules:
                - ❌ Reject mappings if description is vague or general (e.g., "labs look okay", "exam", "oxygenation level")
                - ❌ Reject mappings if the reason is unclear, non-clinical, or just an observation (e.g., "pulse ox", "reviewed x-ray")
                - ❌ Reject if either description or reason does NOT appear in the provenance
                - ❌ Reject hallucinated or inferred phrases that are not clearly spoken in transcript

                Format of Your Output:
                Return a single valid JSON array of corrected and verified mappings only:
                [
                {
                    "description": "verified description from transcript",
                    "reason": "verified clinical reason from transcript or empty string",
                    "provenance": [verified_turn_ids]
                },
                ...
                ]

                If a mapping is not fully valid based on the transcript, DO NOT include it in the output.

                ✅ Examples of accepted output:
                [
                {
                    "description": "put you on some lasix 40 milligrams a day",
                    "reason": "acute heart failure exacerbation",
                    "provenance": [125, 126, 127]
                },
                {
                    "description": "order a hemoglobin a1c",
                    "reason": "type i diabetes",
                    "provenance": [138, 139]
                }
                ]

                Final Reminders:
                - Be strict. Do NOT let vague, non-actionable, or unrelated mappings through.
                - Do NOT invent or hallucinate content.
                - ONLY include verified mappings grounded in the transcript and supported clinically.
                """

        # Build context from transcript
        transcript_context = [f"Turn {turn.turn_id} ({turn.speaker}): {turn.transcript}" for turn in transcript]

        # Prepare structured input
        all_descriptions = [
            {"turn_id": turn_id, "text": desc}
            for turn_id, descriptions in description_dict.items()
            for desc in descriptions
        ]
        all_reasons = [
            {"turn_id": turn_id, "text": reason}
            for turn_id, reasons in reason_dict.items()
            for reason in reasons
        ]

        user_input = f"""TRANSCRIPT CONTEXT:
                    {chr(10).join(transcript_context)}

                    EXTRACTED DESCRIPTIONS:
                    {json.dumps(all_descriptions, indent=2)}

                    EXTRACTED REASONS:
                    {json.dumps(all_reasons, indent=2)}

                    Map each description to the most clinically relevant reason using the full conversation context.
                    Look for patterns where patient mentions symptoms/conditions and doctor responds with orders.
                    Return JSON array only:
                    [{{"description": "exact text", "reason": "exact text or empty", "provenance": [turn_ids]}}]
                """

        messages = self._create_messages(system_prompt, user_input)
        response = self.llm.infer(messages, max_new_tokens=2000)

        mappings = self._parse_mappings(response)
        if mappings is not None:
            return mappings

        # Fallback: context-aware mapping
        return self._fallback_mapping(description_dict, reason_dict, transcript)

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Coerce a provenance value to a list (``None`` -> ``[]``, scalar -> ``[scalar]``)."""
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def _parse_mappings(self, response: Optional[str]) -> Optional[List[Dict[str, Any]]]:
        """Parse and normalize the LLM's mapping array.

        Returns None (so the caller falls back) when the reply is empty, has no
        ``[...]`` span, is not valid JSON, or is a non-empty array in which no item
        has a usable description. An explicit ``[]`` is respected as "no valid
        mappings".
        """
        if not response:
            return None
        json_match = re.search(r'\[.*\]', response, re.DOTALL)
        if not json_match:
            return None
        try:
            parsed = json.loads(json_match.group())
        except json.JSONDecodeError:
            return None

        mappings = []
        for item in parsed:
            if not isinstance(item, dict) or not item.get("description"):
                continue
            mappings.append({
                "description": str(item["description"]),
                "reason": str(item.get("reason") or ""),
                "provenance": self._as_list(item.get("provenance")),
            })

        if parsed and not mappings:
            return None
        return mappings

    def _fallback_mapping(self, description_dict: Dict[int, List[str]],
                         reason_dict: Dict[int, List[str]],
                         transcript: List[TranscriptTurn]) -> List[Dict[str, Any]]:
        """Context-aware fallback mapping using transcript flow.

        Every description is kept; it is paired with the highest-scoring reason
        (see ``_calculate_mapping_score``), or with ``""`` when no reason scores
        above zero.
        """
        mappings = []

        # Create turn lookup
        turn_lookup = {turn.turn_id: turn for turn in transcript}

        for desc_turn, descriptions in description_dict.items():
            for desc in descriptions:
                best_reason = ""
                best_turn = None
                best_score = 0

                # Look for reasons in nearby turns, prioritizing patient statements
                for reason_turn, reasons in reason_dict.items():
                    for reason in reasons:
                        score = self._calculate_mapping_score(
                            desc_turn, reason_turn, desc, reason, turn_lookup
                        )
                        if score > best_score:
                            best_score = score
                            best_reason = reason
                            best_turn = reason_turn

                provenance = [desc_turn]
                # `is not None`: turn id 0 is a valid turn and must not be dropped.
                if best_turn is not None and best_turn != desc_turn:
                    provenance.append(best_turn)

                mappings.append({
                    "description": desc,
                    "reason": best_reason,
                    "provenance": sorted(provenance)
                })

        return mappings

    def _calculate_mapping_score(self, desc_turn: int, reason_turn: int,
                               desc: str, reason: str, turn_lookup: Dict[int, TranscriptTurn]) -> float:
        """Calculate mapping score based on context.

        Score = max(0, 10 - turn distance)
              + 5 if the reason comes from an earlier PATIENT turn and the
                description from a DOCTOR/PHYSICIAN turn
              + 2 per word shared by description and reason (case-insensitive).
        """
        score = 0

        # Distance penalty (closer turns get higher score)
        distance = abs(desc_turn - reason_turn)
        score += max(0, 10 - distance)

        # Patient-to-doctor flow bonus. Turn ids missing from the transcript have no
        # speaker (getattr on None) instead of raising AttributeError.
        reason_speaker = getattr(turn_lookup.get(reason_turn), "speaker", None)
        desc_speaker = getattr(turn_lookup.get(desc_turn), "speaker", None)
        if (reason_turn < desc_turn and
                reason_speaker == "PATIENT" and
                desc_speaker in ["DOCTOR", "PHYSICIAN"]):
            score += 5

        # Semantic similarity bonus (basic keyword matching)
        desc_words = set(desc.lower().split())
        reason_words = set(reason.lower().split())
        score += len(desc_words & reason_words) * 2

        return score


In [None]:
class Agent3_Classifier(BaseAgent):
    """Agent 3: Classify orders into types (medication, lab, imaging, followup).

    The LLM labels each mapped pair. Unparseable replies fall back to keyword
    scoring, as do individual items whose label is not one of the four types.
    """

    # The only order types produced by this agent.
    _ORDER_TYPES = ("medication", "lab", "imaging", "followup")

    # Keyword lists for the rule-based classifier.
    _CLASSIFICATION_PATTERNS = {
        'medication': [
            'medication', 'prescription', 'dose', 'mg', 'pills', 'tablet', 'capsule',
            'drug', 'pharmacy', 'take', 'prescribe', 'antibiotic', 'pain relief'
        ],
        'lab': [
            'blood', 'test', 'a1c', 'screening', 'examination', 'lab', 'sample',
            'otoscopy', 'urine', 'stool', 'culture', 'panel', 'workup'
        ],
        'imaging': [
            'x-ray', 'ct', 'mri', 'ultrasound', 'scan', 'imaging', 'radiology',
            'mammogram', 'echo', 'nuclear', 'pet'
        ],
        'followup': [
            'follow-up', 'follow up', 'appointment', 'visit', 'come back', 'schedule', 'return',
            'referral', 'see', 'consult', 'specialist'
        ]
    }

    def process(self, mapped_pairs: List[Dict[str, Any]]) -> List[MedicalOrder]:
        """Classify each description-reason pair into order types.

        Args:
            mapped_pairs: dicts with ``description``, ``reason`` and ``provenance``.

        Returns:
            One ``MedicalOrder`` per classified item. An empty input returns ``[]``
            without calling the LLM.
        """
        if not mapped_pairs:
            return []

        system_prompt = """You are a medical expert classifying orders into types.
                        Categories:
                        - medication: Drugs, prescriptions, dosages, pharmacy instructions
                        - lab: Blood tests, A1C, screenings, examinations, diagnostic tests
                        - imaging: X-rays, CT, MRI, ultrasound scans, radiological studies
                        - followup: Referrals, appointments, return visits, scheduling

                        Rules:
                        - Use clinical intent, not just keywords
                        - Each order gets exactly one type
                        - Maintain original description and reason text
                        - Return valid JSON array only: [{"order_type": "...", "description": "...", "reason": "...", "provenance": [...]}]"""

        user_input = f"""Classify these medical orders:
                        {json.dumps(mapped_pairs, indent=2)}
                        Return JSON array with order_type added to each item:
                        [{{"order_type": "medication|lab|imaging|followup", "description": "...", "reason": "...", "provenance": [...]}}]"""

        messages = self._create_messages(system_prompt, user_input)
        response = self.llm.infer(messages, max_new_tokens=2000)

        orders = self._parse_classified(response)
        if orders:
            return orders

        # Fallback classification (also used when the LLM returns no usable items:
        # classification should never drop orders).
        return self._fallback_classification(mapped_pairs)

    @staticmethod
    def _as_provenance(value: Any) -> List[Any]:
        """Coerce a provenance value to a list (``None`` -> ``[]``, scalar -> ``[scalar]``)."""
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def _normalize_order_type(self, value: Any) -> Optional[str]:
        """Map LLM labels such as "Follow-up" onto a known type; None if unknown."""
        if not isinstance(value, str):
            return None
        key = re.sub(r"[\s_-]", "", value.lower())
        return key if key in self._ORDER_TYPES else None

    def _parse_classified(self, response: Optional[str]) -> Optional[List[MedicalOrder]]:
        """Parse the LLM reply into orders, tolerating extra/missing keys.

        Items without a description are skipped; items with an unknown
        ``order_type`` are classified by keyword rules. Returns None when the reply
        is empty, has no ``[...]`` span, or is not valid JSON.
        """
        if not response:
            return None
        json_match = re.search(r'\[.*\]', response, re.DOTALL)
        if not json_match:
            return None
        try:
            classified = json.loads(json_match.group())
        except json.JSONDecodeError:
            return None

        orders = []
        for item in classified:
            if not isinstance(item, dict) or not item.get("description"):
                continue
            description = str(item["description"])
            reason = str(item.get("reason") or "")
            order_type = self._normalize_order_type(item.get("order_type"))
            if order_type is None:
                order_type = self._classify_text(description, reason)
            orders.append(MedicalOrder(
                order_type=order_type,
                description=description,
                reason=reason,
                provenance=self._as_provenance(item.get("provenance")),
            ))
        return orders

    @staticmethod
    def _count_keyword_hits(text: str, keywords: List[str]) -> int:
        """Count keywords occurring in ``text`` as whole words (optionally pluralized).

        Letter boundaries (digits are allowed, so "40mg" still hits "mg") stop short
        keywords such as "ct" or "lab" from firing inside "doctor" or "label".
        """
        return sum(
            1 for keyword in keywords
            if re.search(rf"(?<![a-z]){re.escape(keyword)}(?:e?s)?(?![a-z])", text)
        )

    def _classify_text(self, description: str, reason: str) -> str:
        """Rule-based order type: score the description first, then description+reason.

        The reason is only consulted when the description has no keyword hits, so a
        reason like "high blood pressure" does not turn a medication into a lab order.
        Ties go to the earliest category; no hits at all defaults to 'followup'.
        """
        for text in (description, f"{description} {reason}"):
            lowered = text.lower()
            scores = {
                category: self._count_keyword_hits(lowered, keywords)
                for category, keywords in self._CLASSIFICATION_PATTERNS.items()
            }
            if max(scores.values()) > 0:
                return max(scores, key=scores.get)
        return 'followup'

    def _fallback_classification(self, mapped_pairs: List[Dict[str, Any]]) -> List[MedicalOrder]:
        """Rule-based fallback classification; keeps every input pair."""
        orders = []
        for pair in mapped_pairs:
            description = pair.get('description') or ''
            reason = pair.get('reason') or ''
            orders.append(MedicalOrder(
                order_type=self._classify_text(description, reason),
                description=description,
                reason=reason,
                provenance=self._as_provenance(pair.get('provenance'))
            ))
        return orders


In [None]:
class Agent4_Validator(BaseAgent):
    """Agent 4: Validate and format final output.

    The LLM is used only as a filter: each item it keeps is matched back to an
    input order, and the original input values are returned. This enforces the
    prompt's "do not invent or change any values" rule. Unusable replies fall back
    to rule-based validation.
    """

    def process(self, classified_orders: List[MedicalOrder]) -> List[Dict[str, Any]]:
        """Validate medical orders and format final output.

        Args:
            classified_orders: orders produced by Agent 3.

        Returns:
            Dicts with ``order_type``, ``description``, ``reason`` and ``provenance``.
            An empty input returns ``[]`` without calling the LLM.
        """
        if not classified_orders:
            return []

        system_prompt = """You are a clinical validation expert tasked with reviewing a list of extracted medical orders from a doctor-patient conversation.

                    Each order includes:
                    - "order_type": The general clinical category of the order (e.g., medication, lab, imaging, follow-up)
                    - "description": A specific, actionable medical instruction (e.g., "start you on metformin", "order a hemoglobin a1c")
                    - "reason": A clinical justification for the order (e.g., "type i diabetes", "shortness of breath")
                    - "provenance": A list of transcript turn IDs where the description and/or reason were spoken

                    ---

                    Your job is to *strictly validate* each entry based on the following criteria:

                    ### ✅ VALIDATION CRITERIA

                    1. *Description Quality*
                    - Must be a *clear, specific, and actionable medical order*
                    - ❌ Reject vague or generic statements like "labs look good", "check your vitals", "follow up"
                    - ✅ Must reflect real medical decision-making (e.g., drug initiation, lab/imaging request, follow-up instructions)

                    2. *Clinical Logic*
                    - The *order_type*, *description*, and *reason* must be *clinically consistent and meaningful together*
                    - Example:
                        - ✅ "order_type": "lab", "description": "order a hemoglobin a1c", "reason": "type i diabetes"
                        - ❌ "order_type": "imaging", "description": "get an x-ray", "reason": "depression"

                    3. *No Duplication*
                    - Remove duplicate or near-duplicate orders (e.g., “order hemoglobin a1c” and “check a1c” in same context)
                    - Keep only the most precise or complete phrasing

                    4. *Safety and Appropriateness*
                    - Reject orders that are medically unsafe, contradictory, or do not make sense for the stated reason
                    - ❌ Examples: "give insulin" for "depression", or "refer to psychiatry" for "abdominal pain"

                    ---

                    ### ⚠ STRICT RULES

                    - ❌ Do NOT invent or change any values
                    - ❌ Do NOT create new orders
                    - ✅ ONLY remove invalid or weak entries
                    - ✅ Preserve the exact wording of description, reason, and order_type for valid items

                    ---

                    ### ✅ OUTPUT FORMAT

                    Return a *valid JSON array* of cleaned, verified orders:

                    ```json
                    [
                    {
                        "order_type": "medication",
                        "description": "put you on some lasix 40 milligrams a day",
                        "reason": "acute heart failure exacerbation",
                        "provenance": [125, 126, 127]
                    },
                    {
                        "order_type": "lab",
                        "description": "order a hemoglobin a1c",
                        "reason": "type i diabetes",
                        "provenance": [138, 139]
                    }
                    ]
                    ```
"""

        orders_json = [
            {
                "order_type": order.order_type,
                "description": order.description,
                "reason": order.reason,
                "provenance": order.provenance
            }
            for order in classified_orders
        ]

        user_input = f"""Validate these medical orders:
                {json.dumps(orders_json, indent=2)}
                Remove any invalid orders and return the cleaned JSON array:
                [{{"order_type": "...", "description": "...", "reason": "...", "provenance": [...]}}]"""

        messages = self._create_messages(system_prompt, user_input)
        response = self.llm.infer(messages, max_new_tokens=2000)

        validated = self._parse_validated(response, orders_json)
        if validated is not None:
            return validated

        # Fallback validation
        return self._fallback_validation(orders_json)

    @staticmethod
    def _norm(value: Any) -> str:
        """Case/whitespace-insensitive comparison key; ``None`` becomes ``""``."""
        return str(value or "").strip().lower()

    def _match_input_order(self, item: Dict[str, Any], orders: List[Dict[str, Any]],
                           used: set) -> Optional[int]:
        """Index of the not-yet-used input order that ``item`` refers to, or None.

        Matches on description, preferring an input whose reason also matches.
        """
        description = self._norm(item.get("description"))
        if not description:
            return None
        candidates = [
            index for index, order in enumerate(orders)
            if index not in used and self._norm(order.get("description")) == description
        ]
        reason = self._norm(item.get("reason"))
        for index in candidates:
            if self._norm(orders[index].get("reason")) == reason:
                return index
        return candidates[0] if candidates else None

    def _parse_validated(self, response: Optional[str],
                         orders: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
        """Map the LLM's kept items back onto the input orders.

        Returns ``[]`` when the LLM explicitly rejects everything. Returns None
        (so the caller falls back) when the reply is empty, has no ``[...]`` span,
        is not valid JSON, or none of its items matches an input order, which would
        mean the model rewrote or invented orders.
        """
        if not response:
            return None
        json_match = re.search(r'\[.*\]', response, re.DOTALL)
        if not json_match:
            return None
        try:
            parsed = json.loads(json_match.group())
        except json.JSONDecodeError:
            return None
        if not parsed:
            return []

        kept = []
        used = set()
        for item in parsed:
            if not isinstance(item, dict):
                continue
            index = self._match_input_order(item, orders, used)
            if index is not None:
                used.add(index)
                kept.append(dict(orders[index]))
        return kept if kept else None

    def _fallback_validation(self, orders: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Basic validation rules.

        Drops orders whose description is shorter than 10 characters, that have no
        provenance, that duplicate an earlier (description, reason) pair
        case-insensitively, or that fail ``_is_medically_logical``.
        """
        valid_orders = []
        seen_orders = set()

        for order in orders:
            description = (order.get('description') or '').strip()
            reason = order.get('reason') or ''

            # Skip vague or too short descriptions
            if len(description) < 10:
                continue

            # Skip if no meaningful provenance
            if not order.get('provenance'):
                continue

            # Remove duplicates
            order_key = (description.lower(), reason.lower())
            if order_key in seen_orders:
                continue
            seen_orders.add(order_key)

            # Basic medical logic checks
            if self._is_medically_logical(order):
                valid_orders.append(order)

        return valid_orders

    def _is_medically_logical(self, order: Dict[str, Any]) -> bool:
        """Reject a small hard-coded list of (description, reason) keyword combinations."""
        desc = (order.get('description') or '').lower()
        reason = (order.get('reason') or '').lower()

        # Basic illogical combinations
        illogical_pairs = [
            ('x-ray', 'headache'),  # unless trauma mentioned
            ('blood test', 'broken bone'),
            ('antibiotic', 'viral infection')
        ]

        for desc_pattern, reason_pattern in illogical_pairs:
            if desc_pattern in desc and reason_pattern in reason:
                if desc_pattern == 'x-ray' and 'trauma' in reason:
                    continue  # This is actually logical
                return False

        return True


In [None]:
class MedicalOrderExtractor:
    """Main pipeline orchestrator.

    Runs the four agents in sequence on one conversation, all sharing one LM:
    extract (Agent 1) -> map (Agent 2) -> classify (Agent 3) -> validate (Agent 4).
    """

    def __init__(self, llm: OrderExtractionLM):
        self.llm = llm
        # Built in pipeline order; every agent talks to the same LM instance.
        self.agent1, self.agent2, self.agent3, self.agent4 = (
            Agent1_Extractor(llm),
            Agent2_Mapper(llm),
            Agent3_Classifier(llm),
            Agent4_Validator(llm),
        )

    def extract_orders(self, transcript_data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
        """Run the full pipeline on a single conversation.

        Args:
            transcript_data: mapping with an ``"id"`` and a ``"transcript"`` list
                whose entries are unpacked into ``TranscriptTurn``.

        Returns:
            ``{conversation_id: final_orders}`` where ``final_orders`` is Agent 4's output.
        """
        conversation_id = transcript_data["id"]
        turns = [TranscriptTurn(**raw_turn) for raw_turn in transcript_data["transcript"]]

        logger.info(f"Processing conversation: {conversation_id}")

        # Stage 1 — candidate order phrases and reason phrases, keyed by turn id.
        logger.info("Agent 1: Extracting descriptions and reasons...")
        descriptions_by_turn, reasons_by_turn = self.agent1.process(turns)
        logger.info(f"Found {len(descriptions_by_turn)} description turns, {len(reasons_by_turn)} reason turns")
        logger.info(f"description_dict-->{descriptions_by_turn}")
        logger.info(f"reason_dict-->{reasons_by_turn}")

        # Stage 2 — pair each description with a reason and its provenance turns.
        logger.info("Agent 2: Mapping descriptions to reasons with transcript context...")
        pairs = self.agent2.process(descriptions_by_turn, reasons_by_turn, turns)
        logger.info(f"Created {len(pairs)} description-reason pairs")
        logger.info(f"mapped_pairs-->{pairs}")

        # Stage 3 — attach an order type to every pair.
        logger.info("Agent 3: Classifying order types...")
        typed_orders = self.agent3.process(pairs)
        logger.info(f"Classified {len(typed_orders)} orders")
        logger.info(f"classified_orders-->{typed_orders}")

        # Stage 4 — drop invalid orders and emit plain dicts.
        logger.info("Agent 4: Validating and formatting...")
        validated_orders = self.agent4.process(typed_orders)
        logger.info(f"Final output: {len(validated_orders)} valid orders")

        return {conversation_id: validated_orders}


In [None]:
# Example usage: run the full pipeline over every validation conversation.
if __name__ == "__main__":
    # Initialize data loader (input_json_path is set in the configuration cell above).
    data_loader = MedicalOrderDataLoader(input_json_path)
    ds, ds_val = data_loader.ds, data_loader.ds_val

    # Initialize extractor
    extractor = MedicalOrderExtractor(lm)

    # One failing conversation (LLM/network error, malformed record) is logged with
    # its traceback and recorded with no orders, instead of aborting the whole run.
    final_results = []
    for sample in ds_val:
        try:
            results = extractor.extract_orders(sample)
        except Exception:
            logger.exception(f"Extraction failed for conversation: {sample.get('id')}")
            results = {sample.get("id"): []}
        final_results.append(results)

    print("\nFinal Results:")
    print(json.dumps(final_results, indent=2))
