# 🏥 MedGemma Function Calling with In-Context Learning

**No fine-tuning!** Just clever prompting with few-shot examples.

This tests whether the base MedGemma model can produce structured function calls when given good examples in the prompt.

In [None]:
# Install the libraries used by the cells below.
# NOTE: "!" is an IPython/Jupyter shell escape, so this line only runs in a notebook.
!pip install -q transformers torch accelerate bitsandbytes huggingface_hub

In [None]:
# Authenticate with the Hugging Face Hub before downloading the model.
# NOTE(review): presumably required because the MedGemma checkpoint is gated — confirm.
# With no token argument, login() asks for one interactively.
from huggingface_hub import login
login()

In [None]:
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

import os
import tempfile

# Best-effort single-process torch.distributed initialisation.
# NOTE(review): presumably a workaround for a library that queries
# torch.distributed state while loading — confirm it is still needed.
try:
    if not dist.is_initialized():
        # The file:// init method needs a brand-new file on every initialisation;
        # reusing a stale path (e.g. one left behind by a restarted kernel) can
        # deadlock or fail, so point it at a fresh path in a new temp directory.
        _store_path = os.path.join(tempfile.mkdtemp(prefix="icl_test_"), "store")
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{_store_path}",
            rank=0,
            world_size=1,
        )
except Exception as e:
    # Still best effort, but unlike a bare `except:` this lets
    # KeyboardInterrupt/SystemExit through and reports why init was skipped.
    print(f"Skipping torch.distributed init: {e}")

MODEL_ID = "google/medgemma-4b-it"

print("Loading BASE MedGemma (no adapter)...")
# 4-bit NF4 weight quantisation with bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# device_map={"": 0} places the whole model on device 0 (requires a GPU).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Only fall back to EOS for padding if the tokenizer does not define its own pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("✅ Base MedGemma loaded (NO adapter)!")

In [None]:
def _first_call_line(text):
    """Return the first line of *text* that can hold a function call, or "".

    Blank lines and Markdown code-fence lines (anything starting with ```) are
    skipped, because the model may wrap its answer in a fence. Surrounding
    inline backticks are removed, and a leading "Output:" (echoing the
    few-shot format) is dropped.
    """
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("```"):
            continue
        line = line.strip("`").strip()
        if line.startswith("Output:"):
            line = line[len("Output:"):].strip()
        if line:
            return line
    return ""


def function_call_icl(user_input):
    """
    Convert a free-text clinical note into a single function-call string
    using few-shot in-context learning (no fine-tuning).

    Uses the notebook-level ``model`` and ``tokenizer`` loaded earlier and
    decodes greedily.

    Args:
        user_input: Natural-language note or question to convert.

    Returns:
        The first call-like line of the model's completion, e.g.
        ``record_vitals(systolic=120, diastolic=80)``, or "" if nothing usable
        was produced.
    """
    prompt = f"""<start_of_turn>user
You are a clinical documentation AI. Convert natural language clinical notes into structured function calls.

## Available Functions:
1. record_vitals(systolic, diastolic, heart_rate, temp_c) - Record patient vital signs
2. administer_medication(drug_name, dose, route) - Log medication administration
3. search_nmc_standards(query) - Search NMC nursing guidelines

## Examples:

Input: "BP is 120/80, pulse 72, temp 37.2"
Output: record_vitals(systolic=120, diastolic=80, heart_rate=72, temp_c=37.2)

Input: "Gave Paracetamol 1g orally"
Output: administer_medication(drug_name="Paracetamol", dose="1g", route="PO")

Input: "Blood pressure 145/95, heart rate 88"
Output: record_vitals(systolic=145, diastolic=95, heart_rate=88)

Input: "Administered Morphine 5mg IV"
Output: administer_medication(drug_name="Morphine", dose="5mg", route="IV")

Input: "What does NMC say about confidentiality?"
Output: search_nmc_standards(query="confidentiality")

## Your Task:
Convert the following input into a function call. Output ONLY the function call, nothing else.

Input: "{user_input}"
Output:<end_of_turn>
<start_of_turn>model
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # An explicit eos_token_id replaces the stop tokens from the model's
    # generation config, so also stop on <end_of_turn> (the turn terminator used
    # in the prompt above). Otherwise the model keeps generating until
    # max_new_tokens.
    stop_ids = [tokenizer.eos_token_id]
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    if end_of_turn_id is not None and end_of_turn_id != tokenizer.unk_token_id:
        stop_ids.append(end_of_turn_id)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=60,
            do_sample=False,
            eos_token_id=stop_ids,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_length = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][input_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    return _first_call_line(response)

In [None]:
# Smoke test: run function_call_icl on hand-written notes covering all three
# functions, including phrasings that differ from the few-shot examples
# (e.g. "140 over 90", or the route given before the drug name).
test_cases = [
    "BP is 110/70, pulse 68",
    "Patient's BP 130/85, heart rate 78, temperature 36.9",
    "Gave Paracetamol 500mg orally",
    "Administered Morphine 10mg IV",
    "What does NMC say about duty of candour?",
    "Heart rate 92, blood pressure 140 over 90",
    "IV Flucloxacillin 1g given",
    "Find NMC guidance on delegation"
]

print("🧪 Testing BASE MedGemma with In-Context Learning")
print("=" * 60)

for test in test_cases:
    result = function_call_icl(test)
    print(f"\nInput: {test}")
    print(f"Output: {result}")

## 📊 Evaluation

In [None]:
import json
import re

# Score function_call_icl against a labelled eval set stored on Google Drive.
# Loading is best effort (needs Colab, Drive access and the JSON file). The
# scoring loop runs outside that try, so real bugs surface as tracebacks
# instead of being reported as "could not run".
N_EVAL = 20   # items scored; each one is a model call, so keep it small for speed
N_SHOWN = 5   # items whose per-example breakdown is printed
EVAL_PATH = '/content/drive/MyDrive/nmc_brain/data/function_eval_dataset.json'

# "=" followed (optionally after spaces) by a digit or a quoted string.
_ARG_VALUE_RE = re.compile(r"""=\s*(?:\d|"[^"]+"|'[^']+')""")


def call_name(call):
    """Return the function name of a call string: the text before the first '(', stripped."""
    return call.split('(', 1)[0].strip()


def normalize_call(call):
    """Lower-case *call* and remove all whitespace, for a lenient exact-match comparison."""
    return re.sub(r'\s+', '', call.lower())


def has_argument_values(call):
    """Return True if *call* has at least one argument set to a number or a quoted string.

    Whitespace around "=" is allowed, and strings may use single or double quotes.
    """
    return _ARG_VALUE_RE.search(call) is not None


eval_data = None
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)

    with open(EVAL_PATH, 'r', encoding='utf-8') as f:
        eval_data = json.load(f)
except Exception as e:  # top-level best effort: not on Colab, no Drive, missing file, bad JSON
    print(f"Could not run full evaluation: {e}")

if eval_data:
    subset = eval_data[:N_EVAL]
    n = len(subset)
    print(f"\n📊 Running evaluation on {n} of {len(eval_data)} test cases...")
    print("=" * 60)

    correct = 0
    func_name_correct = 0
    has_values = 0

    for i, item in enumerate(subset):
        predicted = function_call_icl(item['input'])
        expected = item['expected']

        # The function name must match exactly, not merely appear somewhere in the output.
        name_ok = call_name(predicted) == call_name(expected)
        if name_ok:
            func_name_correct += 1

        if has_argument_values(predicted):
            has_values += 1

        # Exact match after lower-casing and removing whitespace.
        if normalize_call(expected) == normalize_call(predicted):
            correct += 1

        if i < N_SHOWN:
            status = "✅" if name_ok else "❌"
            text = item['input']
            preview = text if len(text) <= 50 else text[:50] + "..."
            print(f"\n{status} Example {i+1}:")
            print(f"   Input: {preview}")
            print(f"   Expected: {expected}")
            print(f"   Got: {predicted}")

    print("\n" + "=" * 60)
    print(f"📊 RESULTS (first {n} examples):")
    print(f"   Function Name Accuracy: {func_name_correct/n*100:.1f}%")
    print(f"   Value Extraction Rate: {has_values/n*100:.1f}%")
    print(f"   Exact Match Accuracy: {correct/n*100:.1f}%")
elif eval_data is not None:
    # Avoid dividing by zero when the dataset file holds an empty list.
    print("Could not run full evaluation: dataset is empty")