## Setup

In [1]:
import sys
sys.path.insert(0, "../src/")

In [2]:
from load_dotenv import load_dotenv

load_dotenv()

True

In [3]:
import os
import json
from pprint import pprint
from IPython.display import Markdown

import numpy as np
import pandas as pd

import torch
from datasets import Dataset, load_dataset

from mediqa_oe.data import MedicalOrderDataLoader
from mediqa_oe.lm import OrderExtractionLM

  from .autonotebook import tqdm as notebook_tqdm


## Load Data and LM

In [4]:
data_loader = MedicalOrderDataLoader(trs_json_path="../data/orders_data_transcript.json")

ds, ds_val = data_loader.ds, data_loader.ds_val

In [5]:
lm = OrderExtractionLM(
    backend="openai",
    model_name_or_path="",
    api_base=os.getenv("OPENAI_API_BASE"),
    api_key=os.getenv("OPENAI_API_KEY"),
)

## Methods Demo

In [25]:
print([model.id for model in lm.impl.client.models.list().data])

['google/medgemma-27b-text-it']


In [7]:
lm.get_device_info()

'Remote: https://e307wui0v6xrqf-8000.proxy.runpod.net/v1/'

In [8]:
test_msg = [
    {
        "role": "system",
        "content": "You are a medical AI assistant how answers in one sentence.",
    },
    {
        "role": "user",
        "content": "Hi, what kind of assistant are you?",
    },
]

out = lm.infer(
    messages=test_msg
)

Markdown(out)

I am a medical AI assistant designed to provide information and answer questions related to health and medicine in a single sentence.


In [None]:
for chunk in lm.infer_stream(messages=test_msg):
    print(chunk, end="", flush=True)

I am a medical AI assistant designed to provide information and answer questions related to health and medicine in a single sentence.


## Prompts

In [10]:
SYSTEM_PROMPT = """You are a medical AI assistant specialized in extracting medical orders from doctor-patient conversations.

Your task is to identify and extract all medical orders mentioned by the doctor, including:
1. Medications (prescriptions, dosage changes)
2. Laboratory tests 
3. Imaging studies
4. Follow-up appointments
5. Referrals

For each order, extract:
- order_type: "medication", "lab", "imaging", "followup", or "referral"
- description: Clear description of what is being ordered
- reason: Medical condition or symptom being addressed
- provenance: Turn numbers where this order is mentioned

Return the results as a JSON list of objects."""


INSTRUCTION_TEMPLATE = """Please extract all medical orders from the following doctor-patient conversation:

CONVERSATION:
{conversation}

Extract all medical orders and return them as a JSON list with the following format:
[
  {{
    "order_type": "medication|lab|imaging|followup|referral",
    "description": "specific description of the order",
    "reason": "medical condition or reason for the order", 
    "provenance": [list of turn numbers where this order appears]
  }}
]

Focus on explicit orders given by the doctor. Be precise with medical terminology."""

In [11]:
def _format_conv(turns, max_turns=-1, only_last_n=False):
    formatted = []

    if max_turns > 0:
        turns = turns[-max_turns:] if only_last_n else turns[:max_turns]

    for turn in turns:
        speaker = turn['speaker']
        text = turn['transcript']
        turn_id = turn['turn_id']
        formatted.append(f"Turn {turn_id} - {speaker}: {text}")
    
    return "\n".join(formatted)


def format_messages(conv):
    instruction = INSTRUCTION_TEMPLATE.format(
        conversation=conv,
    )
    instruction = f"""EXAMPLE CONVERSATION:
Turn 126 - DOCTOR: so, for your first problem of your shortness of breath i think that you are in an acute heart failure exacerbation.
Turn 127 - DOCTOR: i want to go ahead and, uh, put you on some lasix, 40 milligrams a day.
Turn 138 - DOCTOR: for your second problem of your type i diabetes, um, let's go ahead... i wanna order a hemoglobin a1c for, um, uh, just in a, like a month or so.

EXPECTED OUTPUT:
[
  {{
    "order_type": "medication",
    "description": "lasix 40 milligrams a day",
    "reason": "shortness of breath acute heart failure exacerbation",
    "provenance": [126, 127]
  }},
  {{
    "order_type": "lab", 
    "description": "hemoglobin a1c",
    "reason": "type i diabetes",
    "provenance": [138]
  }}
]

NOW EXTRACT FROM THIS CONVERSATION:

---

{instruction}
"""

    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT,
        },
        {
            "role": "user",
            "content": instruction,
        }
    ]

    return messages


## Test with Sample Data

In [21]:
sample_data = ds[1]

sample_conv = _format_conv(sample_data["transcript"])
prompt = format_messages(conv=sample_conv)

In [22]:
lm.token_count(prompt[-1]['content'])

2081

In [23]:
response = ""

for chunk in lm.infer_stream(prompt):
    response += chunk
    print(chunk, end="", flush=True)

```json
[
  {
    "order_type": "lab",
    "description": "pulmonary function test (pft)",
    "reason": "check and baseline for lung function",
    "provenance": [27]
  },
  {
    "order_type": "imaging",
    "description": "pet ct",
    "reason": "determine if the lung nodule is metabolically active",
    "provenance": [27]
  },
  {
    "order_type": "followup",
    "description": "continue to follow up with your rheumatologist",
    "reason": "rheumatoid arthritis",
    "provenance": [27]
  },
  {
    "order_type": "medication",
    "description": "continue your medication therapy",
    "reason": "rheumatoid arthritis",
    "provenance": [27]
  }
]
```

## Run Inference

In [33]:
def infer_sample(sample, max_seqlen=8192):
    sample["pred"] = None
    if not sample["transcript"]:
        print(f"Transcript is None, skipping...")
        return sample
    
    sample_conv = _format_conv(sample["transcript"])
    prompt = format_messages(conv=sample_conv)

    token_count = lm.token_count(prompt[-1]['content'])
    if token_count > 0.9 * max_seqlen:
        print(f"Token length {token_count} exceeded max_seqlen {max_seqlen}, skipping...")
        return sample
    
    try:
        out = lm.infer(messages=prompt, max_new_tokens=4096)
        sample["pred"] = out
    except Exception as e:
        print(f"Error in LLM call -> {e}")
        return sample

    return sample

In [34]:
ds = ds.map(infer_sample, num_proc=4)

Map (num_proc=4):   3%|▎         | 2/63 [00:17<08:23,  8.25s/ examples]

Transcript is None, skipping...
Transcript is None, skipping...


Map (num_proc=4):  14%|█▍        | 9/63 [00:32<02:09,  2.41s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  22%|██▏       | 14/63 [01:01<03:56,  4.83s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  25%|██▌       | 16/63 [01:13<04:09,  5.30s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  33%|███▎      | 21/63 [01:36<03:46,  5.39s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  40%|███▉      | 25/63 [01:40<01:53,  2.99s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  43%|████▎     | 27/63 [01:56<02:53,  4.82s/ examples]

Error in LLM call -> Error code: 400 - {'object': 'error', 'message': "This model's maximum context length is 8192 tokens. However, you requested 9341 tokens (5245 in the messages, 4096 in the completion). Please reduce the length of the messages or completion. None", 'type': 'BadRequestError', 'param': None, 'code': 400}


Map (num_proc=4):  46%|████▌     | 29/63 [02:03<02:25,  4.29s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  54%|█████▍    | 34/63 [02:21<01:51,  3.86s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  73%|███████▎  | 46/63 [03:15<01:16,  4.49s/ examples]

Transcript is None, skipping...
Error in LLM call -> Error code: 400 - {'object': 'error', 'message': "This model's maximum context length is 8192 tokens. However, you requested 8434 tokens (4338 in the messages, 4096 in the completion). Please reduce the length of the messages or completion. None", 'type': 'BadRequestError', 'param': None, 'code': 400}


Map (num_proc=4):  75%|███████▍  | 47/63 [03:16<00:55,  3.47s/ examples]

Transcript is None, skipping...


Map (num_proc=4):  78%|███████▊  | 49/63 [03:31<01:08,  4.92s/ examples]

Transcript is None, skipping...
Transcript is None, skipping...


Map (num_proc=4):  89%|████████▉ | 56/63 [03:53<00:27,  3.88s/ examples]

Transcript is None, skipping...


Map (num_proc=4): 100%|██████████| 63/63 [05:36<00:00,  5.34s/ examples]


In [43]:
ds.filter(lambda x: x["transcript"] is None).num_rows

14

In [44]:
ds.filter(lambda x: x["pred"] is None).num_rows

16

In [None]:
ds.to_json("../data/outputs/starter_outputs.jsonl", ind)

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 86.80ba/s]


545220