# Benchmark MedGemma 4B on Entity Resolution

Evaluate **MedGemma 4B** (medical-domain Gemma) on the same 338 entity resolution test pairs
used for Gemma 1B benchmarks. Goal: does domain-specific pretraining help without any fine-tuning?

**Setup:**
- Model: `medgemma:1.5-4b` via Ollama (F16); the model emits a thinking block, which is stripped before parsing
- Ollama's OpenAI-compatible API at `http://localhost:11434/v1/`
- Test set: 338 pairs (169 match + 169 non-match) from HF Hub

**Note:** This run uses `medgemma:1.5-4b`, which emits a thinking block
(`<unused94>thought … <unused95>`) before its answer; `extract_answer` strips that block so
only the final True/False is parsed. (The `-fast` variant suppresses thinking tokens and
responds directly with True/False.) The Gemma 1B baseline used transformers directly,
so this isn't perfectly apples-to-apples — but MedGemma 4B in float16 doesn't work
on MPS (produces only pad tokens), so Ollama is the practical option.

## 1. Setup & Imports

In [1]:
# !pip install openai datasets scikit-learn

import os
import re
import time

import numpy as np
from openai import OpenAI
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
)

OLLAMA_MODEL = "medgemma:1.5-4b"
client = OpenAI(base_url="http://localhost:11434/v1/", api_key="ollama")

# Smoke test: confirm the Ollama endpoint answers before the long benchmark run.
# The full reply is printed so any thinking-token wrapper around the answer is visible.
smoke_messages = [{"role": "user", "content": "Is 2+2=4? Answer True or False only."}]
resp = client.chat.completions.create(model=OLLAMA_MODEL, messages=smoke_messages, max_tokens=1024)
smoke_reply = resp.choices[0].message.content
print(f"Model: {OLLAMA_MODEL}")
print(f"Smoke test response ({len(smoke_reply)} chars):")
print(smoke_reply)

Model: medgemma:1.5-4b
Smoke test response (508 chars):
<unused94>thought
1.  **Identify the core question:** The user is asking if the mathematical statement "2 + 2 = 4" is true or false.
2.  **Recall basic arithmetic:** Access knowledge about the rules of addition.
3.  **Evaluate the statement:** 2 plus 2 equals 4. This is a fundamental, universally accepted mathematical fact.
4.  **Determine the truth value:** The statement is correct.
5.  **Format the answer:** The user requested "True or False only".
6.  **Provide the final answer:** True.<unused95>True


## 2. Load Test Set from HF Hub

In [2]:
DATASET_REPO = "abicyclerider/entity-resolution-pairs"

print(f"Loading dataset from {DATASET_REPO}...")
dataset = load_dataset(DATASET_REPO)

# Each test example carries a list of chat messages: the first message's content is
# the prompt, and the second message's content equal to "True" marks a match.
test_split = dataset["test"]
test_prompts = [example["messages"][0]["content"] for example in test_split]
test_labels = [example["messages"][1]["content"] == "True" for example in test_split]

n_match = sum(test_labels)
n_non = len(test_labels) - n_match
print("\nDataset splits:")
print(f"  Train: {len(dataset['train'])}")
print(f"  Eval:  {len(dataset['eval'])}")
print(f"  Test:  {len(test_split)} ({n_match} match + {n_non} non-match)")

Loading dataset from abicyclerider/entity-resolution-pairs...

Dataset splits:
  Train: 1568
  Eval:  336
  Test:  338 (169 match + 169 non-match)


## 3. Helper Functions

In [3]:
from tqdm.notebook import tqdm


def extract_answer(text):
    """Return the model's final answer with MedGemma thinking blocks removed.

    MedGemma outputs ``<unused94>thought...`` then ``<unused95>`` before the
    actual answer. Complete thinking blocks are removed. If a block is opened
    but never closed (the response was truncated), everything from
    ``<unused94>`` onward is dropped, since no answer follows it.

    Args:
        text: Raw completion text. ``None`` (the OpenAI client types
            ``message.content`` as optional) is treated as an empty response.

    Returns:
        The remaining text with surrounding whitespace stripped; ``""`` when
        the response was empty or contained only (possibly truncated) thinking.
    """
    if not text:
        return ""
    # First strip every complete thinking block (non-greedy, spans newlines).
    cleaned = re.sub(r'<unused94>.*?<unused95>', '', text, flags=re.DOTALL)
    # Any <unused94> still present opened a block that was never closed.
    cleaned = re.sub(r'<unused94>.*', '', cleaned, flags=re.DOTALL)
    return cleaned.strip()


# Matches a standalone "true"/"false" word (not, e.g., the "true" inside "untrue").
_VERDICT_RE = re.compile(r"\b(true|false)\b")


def _parse_verdict(answer):
    """Return True/False for the first standalone true/false word in ``answer``, else None."""
    found = _VERDICT_RE.search(answer.lower())
    if found is None:
        return None
    return found.group(1) == "true"


def predict_match(client, model, prompt):
    """Predict whether two medical records match via the Ollama chat API.

    The prompt is sent as a single user message with ``temperature=0``.
    Thinking tokens are stripped with ``extract_answer`` and the verdict is the
    *first* standalone "true"/"false" word in what remains, so an answer such
    as "False - it is not true that ..." is read as False rather than True.

    Args:
        client: OpenAI-compatible client.
        model: Model name to query.
        prompt: Record-pair prompt text.

    Returns:
        ``(prediction, raw)`` where ``prediction`` is True/False, or None when
        no verdict word is found, and ``raw`` is the unmodified completion text
        (``""`` if the API returned no content).
    """
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        # Presumably sized to leave room for the thinking block that precedes the answer.
        max_tokens=2048,
        temperature=0,
    )
    # message.content is Optional in the OpenAI client; normalise to "" so callers
    # can always slice and search the raw text.
    raw = resp.choices[0].message.content or ""
    return _parse_verdict(extract_answer(raw)), raw


def evaluate_predictions(labels, preds):
    """Score boolean predictions against gold labels.

    Args:
        labels: Gold booleans (True = match).
        preds: Predicted booleans aligned with ``labels``.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall`` and ``f1``. Precision,
        recall and F1 use ``zero_division=0``, so an undefined ratio is 0.
    """
    scorers = (
        ('accuracy', accuracy_score, {}),
        ('precision', precision_score, {'zero_division': 0}),
        ('recall', recall_score, {'zero_division': 0}),
        ('f1', f1_score, {'zero_division': 0}),
    )
    return {name: scorer(labels, preds, **extra) for name, scorer, extra in scorers}


def run_evaluation(client, model, test_prompts, test_labels, label="Model"):
    """Run ``model`` over every prompt and score the parseable predictions.

    Each prompt goes through ``predict_match``. Responses with no True/False
    verdict are counted as unparseable and excluded from the metrics; their raw
    text is still kept in ``raw_responses``.

    Args:
        client: OpenAI-compatible client.
        model: Model name passed to the chat-completions API.
        test_prompts: Prompt strings, one per record pair.
        test_labels: Gold booleans aligned with ``test_prompts``.
        label: Display name for the progress bar and printed report.

    Returns:
        ``(metrics, preds, labels, indices, raw_responses)``. ``preds`` and
        ``labels`` cover only parseable pairs, ``indices`` maps them back to
        positions in ``test_prompts``, and ``raw_responses`` has one entry per
        prompt. ``metrics`` values are NaN when nothing was parseable.
    """
    preds, labels, indices, raw_responses = [], [], [], []
    unparseable = 0

    pairs = zip(test_prompts, test_labels)
    for i, (prompt, true_label) in enumerate(tqdm(pairs, total=len(test_prompts), desc=label)):
        pred, raw = predict_match(client, model, prompt)
        raw_responses.append(raw)
        if pred is None:
            unparseable += 1
            continue
        preds.append(pred)
        labels.append(true_label)
        indices.append(i)

    if preds:
        metrics = evaluate_predictions(labels, preds)
    else:
        # Metrics are undefined on an empty sample; report NaN rather than
        # calling sklearn with no data.
        metrics = {m: float("nan") for m in ("accuracy", "precision", "recall", "f1")}

    print(f"\n{label} ({len(preds)} parseable / {len(test_prompts)} total, {unparseable} unparseable):")
    for m, v in metrics.items():
        print(f"  {m:>10s}: {v:.3f}")
    print("\nConfusion matrix (rows=actual, cols=predicted):")
    # Fix the label order so the matrix is always 2x2 [[TN, FP], [FN, TP]],
    # even when only one class appears among the parseable pairs.
    print(confusion_matrix(labels, preds, labels=[False, True]))

    return metrics, preds, labels, indices, raw_responses

## 4. Benchmark

In [4]:
# Full run on all 338 test pairs; lowering N gives a smaller trial run.
N = len(test_prompts)
print(f"Benchmarking {OLLAMA_MODEL} on all {N} test pairs...")
medgemma_results = run_evaluation(
    client,
    OLLAMA_MODEL,
    test_prompts[:N],
    test_labels[:N],
    "MedGemma 4B",
)
medgemma_metrics, medgemma_preds, medgemma_labels, medgemma_indices, medgemma_raw = medgemma_results

Benchmarking medgemma:1.5-4b on all 338 test pairs...


MedGemma 4B:   0%|          | 0/338 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [5]:
# Diagnostic: for each pair, show whether the thinking block was closed and what
# extract_answer leaves behind, to see why responses do or don't parse.
rule = "=" * 70
for pair_no, (raw, gold) in enumerate(zip(medgemma_raw, test_labels[:N]), start=1):
    answer = extract_answer(raw)
    was_truncated = "<unused95>" not in raw
    print(f"\n{rule}")
    print(f"Pair {pair_no} | True label: {gold} | Truncated: {was_truncated} | len={len(raw)}")
    print(rule)
    print(f"Raw (last 200 chars): ...{raw[-200:]}")
    print(f"After extract_answer: [{answer[:100]}]")

NameError: name 'medgemma_raw' is not defined

## 5. Results Comparison

In [None]:
import matplotlib.pyplot as plt

# Gemma 1B metrics, hardcoded from the fine-tuning notebook.
gemma_base = {'accuracy': 0.527, 'precision': 0.523, 'recall': 0.675, 'f1': 0.589}
gemma_ft = {'accuracy': 0.571, 'precision': 0.544, 'recall': 0.917, 'f1': 0.683}


def _fmt_row(name, params, m, notes, width):
    """Format one comparison-table row: model name, parameter count, four metrics, notes."""
    return (
        f"{name:<{width}s}  {params:>6s}  {m['accuracy']:>6.3f}  {m['precision']:>6.3f}  "
        f"{m['recall']:>6.3f}  {m['f1']:>6.3f}  {notes}"
    )


# Comparison table
print(f"{'Model':<25s}  {'Params':>6s}  {'Acc':>6s}  {'Prec':>6s}  {'Rec':>6s}  {'F1':>6s}  {'Notes'}")
print("-" * 85)
print(_fmt_row('Gemma 1B (base)', '1B', gemma_base, 'float32, transformers', 25))
print(_fmt_row('Gemma 1B (fine-tuned)', '1B', gemma_ft, 'LoRA, 3 epochs', 25))
print(_fmt_row('MedGemma 4B', '4B', medgemma_metrics, 'F16, Ollama', 25))

# Bar chart of F1 scores: (tick label, F1, bar colour) per model
bar_specs = [
    ('Gemma 1B\n(base)', gemma_base['f1'], '#4285f4'),
    ('Gemma 1B\n(fine-tuned)', gemma_ft['f1'], '#34a853'),
    ('MedGemma 4B', medgemma_metrics['f1'], '#ea4335'),
]
models, f1_scores, colors = (list(column) for column in zip(*bar_specs))

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(models, f1_scores, color=colors, width=0.5)
ax.set_ylabel('F1 Score')
ax.set_title('Entity Resolution F1 Score Comparison')
ax.set_ylim(0, 1)
ax.axhline(y=0.936, color='gray', linestyle='--', alpha=0.5, label='Opus ceiling (0.936)')
ax.legend()

# Print each score just above its bar.
for rect, score in zip(bars, f1_scores):
    x_center = rect.get_x() + rect.get_width() / 2
    ax.text(x_center, rect.get_height() + 0.02, f'{score:.3f}',
            ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

## 6. Error Analysis

In [None]:
# Split MedGemma's parseable predictions into false positives and false negatives,
# recording each pair's index into the full test set.
scored = list(zip(medgemma_preds, medgemma_labels, medgemma_indices))
false_positives = [idx for pred, label, idx in scored if pred and not label]
false_negatives = [idx for pred, label, idx in scored if label and not pred]

print(f"False positives (predicted match, actually non-match): {len(false_positives)}")
print(f"False negatives (predicted non-match, actually match): {len(false_negatives)}")


def _show_error_examples(title, indices, limit=3, tail_chars=200):
    """Print up to ``limit`` misclassified responses under a banner.

    The raw text opens with MedGemma's thinking block (the smoke test reply
    begins with ``<unused94>thought``) and the verdict comes last, so the
    parsed answer and the *tail* of the raw response are shown instead of
    its head.
    """
    print(f"\n{'='*70}")
    print(title)
    print(f"{'='*70}")
    for idx in indices[:limit]:
        raw = medgemma_raw[idx] or ""
        print(f"\n--- Test pair #{idx} ---")
        print(f"Parsed answer: [{extract_answer(raw)[:100]}]")
        print(f"Raw response (last {tail_chars} chars): ...{raw[-tail_chars:]}")


_show_error_examples("EXAMPLE FALSE POSITIVES (predicted True, actually False)", false_positives)
_show_error_examples("EXAMPLE FALSE NEGATIVES (predicted False, actually True)", false_negatives)

# Prediction distribution (Gemma 1B counts are hardcoded, presumably from the
# fine-tuning notebook like the metrics above)
pred_true = sum(medgemma_preds)
pred_false = len(medgemma_preds) - pred_true
print(f"\n{'='*70}")
print("PREDICTION DISTRIBUTION")
print(f"{'='*70}")
print(f"MedGemma 4B: {pred_true} True, {pred_false} False (of {len(medgemma_preds)} parseable)")
print(f"Gemma 1B base: 218 True, 118 False (of 336 parseable)")
print(f"Gemma 1B FT:   285 True, 51 False (of 336 parseable)")
print(f"Actual:        169 True, 169 False")

## 7. MedGemma 4B Classifier (Fine-Tuned)

Evaluate the fine-tuned MedGemma 4B classifier adapter. This model was trained with QLoRA
on RunPod (A4000 16GB) using `Gemma3ForSequenceClassification` — a 2-class classification
head replacing the 262K-vocab LM head. Single forward pass per pair, no text generation needed.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
from tqdm.notebook import tqdm

CLS_MODEL_ID = "google/medgemma-4b-it"
CLS_ADAPTER_REPO = "abicyclerider/medgemma-4b-entity-resolution-classifier"

# Prefer MPS (the original target), then CUDA, then CPU, instead of hard-coding
# device_map="mps", which fails on machines without Apple-silicon MPS.
if torch.backends.mps.is_available():
    CLS_DEVICE = "mps"
elif torch.cuda.is_available():
    CLS_DEVICE = "cuda"
else:
    CLS_DEVICE = "cpu"
# Half precision on accelerators; float32 on CPU, where float16 kernels are
# limited and slow.
# NOTE(review): the intro reports MedGemma 4B in float16 on MPS produced only pad
# tokens during generation; if classifier logits come out NaN/constant, try
# torch.bfloat16 or torch.float32 here.
CLS_DTYPE = torch.float32 if CLS_DEVICE == "cpu" else torch.float16

# Load base model as classifier + LoRA adapter
print(f"Loading {CLS_MODEL_ID} as classifier...")
cls_tokenizer = AutoTokenizer.from_pretrained(CLS_MODEL_ID)
if cls_tokenizer.pad_token is None:
    cls_tokenizer.pad_token = cls_tokenizer.eos_token

cls_base = AutoModelForSequenceClassification.from_pretrained(
    CLS_MODEL_ID,
    num_labels=2,
    torch_dtype=CLS_DTYPE,
    device_map=CLS_DEVICE,
)
# The classification head needs a pad id to find each sequence's last real token.
cls_base.config.pad_token_id = cls_tokenizer.pad_token_id

print(f"Loading adapter from {CLS_ADAPTER_REPO}...")
cls_model = PeftModel.from_pretrained(cls_base, CLS_ADAPTER_REPO)
cls_model.eval()
print(f"Classifier loaded on {cls_model.device}")

In [None]:
# Run the classifier on every test pair: one forward pass each; class 1 (True) = match.
cls_preds = []
cls_labels_list = []
nonfinite_logits = 0  # pairs whose logits contain NaN/inf (argmax is meaningless there)

print(f"Evaluating classifier on {len(test_prompts)} test pairs...")
for prompt, true_label in tqdm(zip(test_prompts, test_labels), total=len(test_prompts), desc="Classifier"):
    inputs = cls_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(cls_model.device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = cls_model(**inputs).logits
    if not torch.isfinite(logits).all():
        nonfinite_logits += 1
    cls_preds.append(bool(logits.argmax(dim=-1).item()))
    cls_labels_list.append(true_label)

cls_metrics = evaluate_predictions(cls_labels_list, cls_preds)
# Every pair yields a prediction, so there is nothing unparseable here.
print(f"\nMedGemma 4B Classifier ({len(cls_preds)}/{len(test_prompts)} pairs, 0 unparseable):")
for m, v in cls_metrics.items():
    print(f"  {m:>10s}: {v:.3f}")
if nonfinite_logits:
    print(f"WARNING: {nonfinite_logits} pairs produced NaN/inf logits (possibly float16 "
          "overflow); their predictions are unreliable.")
print("\nConfusion matrix (rows=actual, cols=predicted):")
# Fixed label order keeps the matrix 2x2 even if only one class is predicted.
print(confusion_matrix(cls_labels_list, cls_preds, labels=[False, True]))

## 8. Updated Results Comparison

In [None]:
import matplotlib.pyplot as plt

# Gemma 1B baselines, hardcoded.
gemma_base = {'accuracy': 0.527, 'precision': 0.523, 'recall': 0.675, 'f1': 0.589}
gemma_ft = {'accuracy': 0.571, 'precision': 0.544, 'recall': 0.917, 'f1': 0.683}


def _fmt_row(name, params, m, notes, width):
    """Format one comparison-table row: model name, parameter count, four metrics, notes."""
    return (
        f"{name:<{width}s}  {params:>6s}  {m['accuracy']:>6.3f}  {m['precision']:>6.3f}  "
        f"{m['recall']:>6.3f}  {m['f1']:>6.3f}  {notes}"
    )


# Updated comparison table (4 models)
print(f"{'Model':<30s}  {'Params':>6s}  {'Acc':>6s}  {'Prec':>6s}  {'Rec':>6s}  {'F1':>6s}  {'Notes'}")
print("-" * 100)
print(_fmt_row('Gemma 1B (base)', '1B', gemma_base, 'float32, transformers', 30))
print(_fmt_row('Gemma 1B (fine-tuned)', '1B', gemma_ft, 'LoRA, 3 epochs', 30))
print(_fmt_row('MedGemma 4B (gen)', '4B', medgemma_metrics, 'F16, Ollama', 30))
print(_fmt_row('MedGemma 4B (classifier)', '4B', cls_metrics, 'QLoRA, SEQ_CLS', 30))

# Bar chart of F1 scores: (tick label, F1, bar colour) per model
bar_specs = [
    ('Gemma 1B\n(base)', gemma_base['f1'], '#4285f4'),
    ('Gemma 1B\n(fine-tuned)', gemma_ft['f1'], '#34a853'),
    ('MedGemma 4B\n(gen)', medgemma_metrics['f1'], '#ea4335'),
    ('MedGemma 4B\n(classifier)', cls_metrics['f1'], '#fbbc04'),
]
models, f1_scores, colors = (list(column) for column in zip(*bar_specs))

fig, ax = plt.subplots(figsize=(10, 5))
bars = ax.bar(models, f1_scores, color=colors, width=0.5)
ax.set_ylabel('F1 Score')
ax.set_title('Entity Resolution F1 Score Comparison')
ax.set_ylim(0, 1)
ax.axhline(y=0.936, color='gray', linestyle='--', alpha=0.5, label='Opus ceiling (0.936)')
ax.legend()

# Print each score just above its bar.
for rect, score in zip(bars, f1_scores):
    x_center = rect.get_x() + rect.get_width() / 2
    ax.text(x_center, rect.get_height() + 0.02, f'{score:.3f}',
            ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Configuration: set RUN_ID to the augmentation run to analyze.
RUN_ID = "run_20260211_063607"
BASE_DIR = "/Users/alex/repos/Kaggle/SyntheticMass"
RUN_DIR = "/".join((BASE_DIR, "output", "augmented", RUN_ID))

for caption, value in (("Analyzing run", RUN_ID), ("Run directory", RUN_DIR)):
    print(f"{caption}: {value}")