## Setup

In [1]:
import sys
sys.path.insert(0, "../src/")

In [2]:
from load_dotenv import load_dotenv

load_dotenv()

True

In [None]:
import os
import json
from pprint import pprint
from IPython.display import Markdown

import numpy as np
import pandas as pd

import torch
from datasets import Dataset, load_dataset

from mediqa_oe.data import MedicalOrderDataLoader
from mediqa_oe.lm import OrderExtractionLM

  from .autonotebook import tqdm as notebook_tqdm


## Load Data and LM

In [None]:
input_json_path = '<input_json_path for test>'

In [None]:
data_loader = MedicalOrderDataLoader(trs_json_path=input_json_path)

ds, ds_val = data_loader.ds, data_loader.ds_val

In [5]:
lm = OrderExtractionLM(
    backend="openai",
    model_name_or_path="",
    api_base=os.getenv("OPENAI_API_BASE"),
    api_key=os.getenv("OPENAI_API_KEY"),
)

## Methods Demo

In [6]:
print([model.id for model in lm.impl.client.models.list().data])

['google/medgemma-27b-text-it']


In [None]:
lm.get_device_info()

In [8]:
test_msg = [
    {
        "role": "system",
        "content": "You are a medical AI assistant how answers in one sentence.",
    },
    {
        "role": "user",
        "content": "Hi, what kind of assistant are you?",
    },
]

out = lm.infer(
    messages=test_msg
)

Markdown(out)

I am a medical AI assistant designed to provide concise, one-sentence answers to health-related questions.


In [9]:
for chunk in lm.infer_stream(messages=test_msg):
    print(chunk, end="", flush=True)

I am a medical AI assistant designed to provide information and answer questions related to health and medicine in a single sentence.


## Prompts

In [None]:
SYSTEM_PROMPT = """
You are a medical AI assistant specialized in extracting EXPLICIT medical orders from doctor-patient conversations.

CRITICAL RULES:
1. Extract ONLY orders explicitly stated by the doctor
2. Do NOT infer or assume orders that aren't clearly mentioned
3. Provenance must be EXACT turn numbers where orders appear
4. Be balanced - i.e precision and recall on level terms
5. If the doctor orders multiple DISTINCT items (e.g., 'get a covid test and blood test'), create separate order objects for each item - never merge them into one combined description.

Order Types:
- medication: Prescriptions, dosage instructions, medication changes
- lab: Blood tests, urine tests, specific diagnostic tests
- imaging: X-rays, MRI, CT scans, ultrasounds
- followup: Scheduled return visits, check-ups (these must be explicitly stated by the doctor)

For each order extract:
- order_type: One of the 4 types above
- description: EXACT medical terminology used by doctor
- reason: Specific condition/symptom mentioned by doctor
- provenance: ONLY turn numbers where this exact order is mentioned"""


INSTRUCTION_TEMPLATE = """Please extract all medical orders from the following doctor-patient conversation:

CONVERSATION:
{conversation}

Extract all medical orders and return them as a JSON list with the following format:
[
  {{
    "order_type": "medication|lab|imaging|followup|referral",
    "description": "specific description of the order",
    "reason": "medical condition or reason for the order", 
    "provenance": [list of turn numbers where this order appears]
  }}
]

Focus on explicit orders given by the doctor. Be precise with medical terminology."""

In [11]:
def _format_conv(turns, max_turns=-1, only_last_n=False):
    formatted = []

    if max_turns > 0:
        turns = turns[-max_turns:] if only_last_n else turns[:max_turns]

    for turn in turns:
        speaker = turn['speaker']
        text = turn['transcript']
        turn_id = turn['turn_id']
        formatted.append(f"Turn {turn_id} - {speaker}: {text}")
    
    return "\n".join(formatted)


def format_messages(conv):
    instruction = INSTRUCTION_TEMPLATE.format(
        conversation=conv,
    )
    instruction = f"""EXAMPLE CONVERSATION:
Turn 126 - DOCTOR: so, for your first problem of your shortness of breath i think that you are in an acute heart failure exacerbation.
Turn 127 - DOCTOR: i want to go ahead and, uh, put you on some lasix, 40 milligrams a day.
Turn 138 - DOCTOR: for your second problem of your type i diabetes, um, let's go ahead... i wanna order a hemoglobin a1c for, um, uh, just in a, like a month or so.

EXPECTED OUTPUT:
[
  {{
    "order_type": "medication",
    "description": "lasix 40 milligrams a day",
    "reason": "shortness of breath acute heart failure exacerbation",
    "provenance": [126, 127]
  }},
  {{
    "order_type": "lab", 
    "description": "hemoglobin a1c",
    "reason": "type i diabetes",
    "provenance": [138]
  }}
]

NOW EXTRACT FROM THIS CONVERSATION:

---

{instruction}
"""

    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT,
        },
        {
            "role": "user",
            "content": instruction,
        }
    ]

    return messages


## Test with Sample Data

In [12]:
sample_data = ds[1]

sample_conv = _format_conv(sample_data["transcript"])
prompt = format_messages(conv=sample_conv)

In [13]:
lm.token_count(prompt[-1]['content'])

2081

In [14]:
response = ""

for chunk in lm.infer_stream(prompt):
    response += chunk
    print(chunk, end="", flush=True)

```json
[
  {
    "order_type": "lab",
    "description": "pulmonary function test (pft)",
    "reason": "check and baseline for lung function",
    "provenance": [27]
  },
  {
    "order_type": "imaging",
    "description": "pet ct",
    "reason": "determine if the lung nodule is metabolically active",
    "provenance": [27]
  },
  {
    "order_type": "followup",
    "description": "continue to follow up with your rheumatologist",
    "reason": "rheumatoid arthritis",
    "provenance": [27]
  },
  {
    "order_type": "medication",
    "description": "continue your medication therapy",
    "reason": "rheumatoid arthritis",
    "provenance": [27]
  }
]
```

## Run Inference

In [21]:
def infer_sample(sample, max_seqlen=8192):
    sample["pred"] = None
    if not sample["transcript"]:
        print(f"Transcript is None, skipping...")
        return sample
    
    sample_conv = _format_conv(sample["transcript"])
    prompt = format_messages(conv=sample_conv)

    token_count = lm.token_count(prompt[-1]['content'])
    if token_count > 0.9 * max_seqlen:
        print(f"Token length {token_count} exceeded max_seqlen {max_seqlen}, skipping...")
        return sample
    
    try:
        out = lm.infer(messages=prompt, max_new_tokens=2048)
        sample["pred"] = out
    except Exception as e:
        print(f"Error in LLM call -> {e}")
        return sample

    return sample

In [None]:
ds = ds.map(infer_sample, num_proc=4)

In [None]:
ds.filter(lambda x: x["transcript"] is None).num_rows

In [None]:
ds.filter(lambda x: x["pred"] is None).num_rows

In [None]:
ds.to_json("/medgemma_llm_out_train__init.jsonl")

In [22]:
ds_val = ds_val.map(infer_sample, num_proc=4)

Map (num_proc=4):   3%|▎         | 3/100 [00:20<07:56,  4.91s/ examples]

Transcript is None, skipping...


Map (num_proc=4):   7%|▋         | 7/100 [00:35<05:19,  3.44s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  13%|█▎        | 13/100 [01:18<09:05,  6.27s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  15%|█▌        | 15/100 [01:23<06:03,  4.28s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  22%|██▏       | 22/100 [01:46<04:37,  3.55s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  23%|██▎       | 23/100 [01:48<03:53,  3.03s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  32%|███▏      | 32/100 [02:37<04:32,  4.01s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  36%|███▌      | 36/100 [03:14<08:41,  8.15s/ examples]

Transcript is None, skipping...
Transcript is None, skipping...


Map (num_proc=4):  39%|███▉      | 39/100 [03:23<05:33,  5.46s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  43%|████▎     | 43/100 [03:43<04:55,  5.19s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  48%|████▊     | 48/100 [03:48<02:10,  2.50s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  53%|█████▎    | 53/100 [04:20<03:33,  4.55s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  54%|█████▍    | 54/100 [04:22<03:10,  4.13s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  56%|█████▌    | 56/100 [04:28<02:39,  3.63s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  61%|██████    | 61/100 [04:57<03:27,  5.31s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  68%|██████▊   | 68/100 [05:24<02:20,  4.39s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  75%|███████▌  | 75/100 [06:02<01:56,  4.68s/ examples]

Transcript is None, skipping...
Transcript is None, skipping...


Map (num_proc=4):  77%|███████▋  | 77/100 [06:11<01:41,  4.43s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  86%|████████▌ | 86/100 [06:44<00:37,  2.67s/ examples]

Transcript is None, skipping...
Transcript is None, skipping...


Map (num_proc=4):  98%|█████████▊| 98/100 [08:31<00:33, 16.97s/ examples]

Transcript is None, skipping...


Map (num_proc=4): 100%|██████████| 100/100 [08:52<00:00,  5.33s/ examples]


In [None]:
ds_val.filter(lambda x: x["transcript"] is None).num_rows, ds_val.filter(lambda x: x["pred"] is None).num_rows

In [None]:
ds_val.to_json("/outputs/medgemma_llm_out_val__init.jsonl")

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 41.97ba/s]


944276