In [None]:
!pip install --quiet transformers accelerate torch

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_ID = "Nadhari/Sara-1.5-4B-it"

# Load the checkpoint in bfloat16; device_map="auto" hands device placement
# to transformers/accelerate.
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Model loaded: {MODEL_ID}\nDevice: {model.device}")

In [None]:
# Tool-calling instructions prepended to every question. This text is model
# input: the wording, blank lines, and the literal "{api_base}" placeholders
# (not an f-string) are part of the prompt format, so change it with care.
SYSTEM_PROMPT = "\n".join([
    "You are an expert in using FHIR functions to assist medical professionals. You are given a question and a set of possible functions. Based on the question, you will need to make one or more function/tool calls to achieve the purpose.",
    "",
    "1. If you decide to invoke a GET function, you MUST put it in the format of",
    "GET url?param_name1=param_value1&param_name2=param_value2...",
    "",
    "2. If you decide to invoke a POST function, you MUST put it in the format of",
    "POST url",
    "[your payload data in JSON format]",
    "",
    "3. If you have got answers for all the questions and finished all the requested tasks, you MUST call to finish the conversation in the format of",
    "FINISH([answer1, answer2, ...])",
    "",
    "Your response must be in the format of one of the three cases, and you can call only one function each time. You SHOULD NOT include any other text in the response.",
    "",
    "Available FHIR endpoints (use http://localhost:8080/fhir/ as api_base):",
    "- GET {api_base}/Patient?given=&family=&birthdate= (Search patients)",
    "- GET {api_base}/Observation?code=&patient=&date= (Query labs/vitals)",
    "- POST {api_base}/Observation (Record vitals)",
    "- POST {api_base}/MedicationRequest (Order medications)",
    "- POST {api_base}/ServiceRequest (Order referrals/labs)",
    "",
])

In [None]:
def generate_response(messages, max_new_tokens=512):
    """Generate the model's next turn for a chat conversation.

    Decoding is deterministic (``do_sample=False``).

    Args:
        messages: Chat history as a list of ``{"role": ..., "content": ...}``
            dicts, rendered with the tokenizer's chat template.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded continuation only (prompt tokens excluded, special tokens
        removed), with surrounding whitespace stripped.
    """
    # Render the conversation to text and open the model's turn.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered template already carries the special tokens it needs (e.g.
    # BOS); tokenizing with the default add_special_tokens=True would add them
    # a second time.
    inputs = tokenizer(
        input_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    # Prefer the stop ids from the model's generation config: chat checkpoints
    # may list several (e.g. an end-of-turn token), and passing only the
    # tokenizer's eos id would override that list. Fall back to the tokenizer.
    generation_config = getattr(model, "generation_config", None)
    eos_token_id = getattr(generation_config, "eos_token_id", None)
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Some tokenizers define no pad token; reuse the (first) eos id so that
    # generate() receives an explicit id rather than None.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        if isinstance(eos_token_id, (list, tuple)):
            pad_token_id = eos_token_id[0] if eos_token_id else None
        else:
            pad_token_id = eos_token_id

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    response = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
    return response.strip()

In [None]:
# Task: Find patient MRN by name and DOB
question = (
    "What's the MRN of the patient with name Maria Garcia and DOB of 1990-07-22?"
)

# Single user turn: tool-calling instructions followed by the question.
messages = [{"role": "user", "content": f"{SYSTEM_PROMPT}\n\nQuestion: {question}"}]
response = generate_response(messages)

print("\n".join([
    "Task: Patient MRN Lookup",
    "-" * 50,
    f"Question: {question}",
    f"Agent Response:\n{response}",
]))

In [None]:
# Task: Record blood pressure vital signs
question = """Record the blood pressure of patient S1234567.
The systolic is 120 and the diastolic is 80.
Use the current datetime 2024-01-15T10:30:00+00:00."""

# Single user turn: tool-calling instructions followed by the question.
messages = [{"role": "user", "content": f"{SYSTEM_PROMPT}\n\nQuestion: {question}"}]
response = generate_response(messages)

print("\n".join([
    "Task: Record Blood Pressure",
    "-" * 50,
    f"Question: {question}",
    f"Agent Response:\n{response}",
]))

In [None]:
# Simulate a full multi-turn agent workflow
print("Task: Full Agent Workflow - Patient Lookup")
print("=" * 60)

# Turn 1: Initial question
question = "What's the MRN of the patient with name John Smith and DOB of 1985-03-15?"
conversation = [{"role": "user", "content": SYSTEM_PROMPT + "\n\nQuestion: " + question}]

print(f"\n[Turn 1] User Question:\n{question}")
agent_response = generate_response(conversation)
print(f"\n[Turn 1] Agent Response:\n{agent_response}")

# The canned FHIR reply below answers a Patient search. Flag it when the agent
# asked for something else, because the simulated turn would then be misleading.
if not agent_response.startswith("GET"):
    print("\n[Warning] Expected a GET request in turn 1; the simulated FHIR reply may not match it.")

# Record the agent's turn under the standard "assistant" role used by HF chat
# templates (Gemma-family templates, for example, render it as their "model"
# turn -- confirm against this checkpoint's template if in doubt).
conversation.append({"role": "assistant", "content": agent_response})

# Simulate FHIR server response
fhir_response = """Here is the response from the GET request:
{
  "resourceType": "Bundle",
  "type": "searchset",
  "total": 1,
  "entry": [{
    "resource": {
      "resourceType": "Patient",
      "id": "S6534835",
      "identifier": [{"type": {"coding": [{"code": "MR"}]}, "value": "S6534835"}],
      "name": [{"family": "Smith", "given": ["John"]}],
      "birthDate": "1985-03-15"
    }
  }]
}. Please call FINISH if you have got answers for all the questions."""

conversation.append({"role": "user", "content": fhir_response})
print("\n[Turn 2] FHIR Response: (Bundle with 1 patient)")

# Turn 2: Agent extracts answer
agent_response = generate_response(conversation)
print(f"\n[Turn 2] Agent Response:\n{agent_response}")