# Gemma 1B Base Model Benchmark with Strategy D Summaries

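Benchmark **Gemma 3 1B without any task fine-tuning** (the instruction-tuned `google/gemma-3-1b-it` checkpoint) on entity resolution using Strategy D (Structured Diff-Friendly) summaries. "Base" in this notebook means "before our fine-tuning", not the pretrained-only checkpoint.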

**Why?** The summarizer optimization notebook showed Strategy D achieves F1=0.936 with Opus as the judge at ~760 tok/summary. Before fine-tuning Gemma 1B on this task, we need a baseline: how well does the base model perform?

This establishes the "before" number that fine-tuning is meant to improve on. The existing `entity_resolution_fine_tuning.ipynb` only benchmarked base Gemma on condensed summaries (~220 tok), where it scored poorly; even Opus reached only F1=0.558 on that format. Strategy D summaries carry much more detail and should give a fairer baseline.

In [1]:
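import os
import sys

# Set these before importing torch so they are in place when PyTorch reads them
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"         # fall back to CPU for ops MPS doesn't support
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"  # 0.0 disables the MPS allocator's memory cap (risks system-wide OOM)

import torch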

# Add project root so shared/ imports work
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Add llm-entity-resolution for summarizer imports
_llm_er_root = os.path.join(PROJECT_ROOT, "llm-entity-resolution")
if _llm_er_root not in sys.path:
    sys.path.insert(0, _llm_er_root)

print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Project root: {PROJECT_ROOT}")

PyTorch version: 2.10.0
MPS available: True
Using device: mps
Project root: /Users/alex/repos/Kaggle/SyntheticMass


In [2]:
from dotenv import load_dotenv
from huggingface_hub import login, whoami

load_dotenv()

token = os.environ.get("HF_TOKEN")
if not token:
    raise ValueError("HF_TOKEN not found. Create a .env file with: HF_TOKEN=hf_your_token_here")

login(token=token)
print(f"Logged in as: {whoami()['name']}")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Logged in as: abicyclerider


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "google/gemma-3-1b-it"

print(f"Loading tokenizer and model from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.float32,
    device_map=device,  # follows the MPS/CPU choice from the first cell
)

print(f"Model loaded on: {model.device}")
print(f"Parameters: {model.num_parameters():,}")

Loading tokenizer and model from google/gemma-3-1b-it...



Model loaded on: mps:0
Parameters: 999,885,952
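The parameter count gives a quick sanity check on memory. Below is a minimal sketch, using a hypothetical helper `weight_memory_gib`, that multiplies parameter count by bytes per parameter. It counts weights only, so activations, the KV cache and allocator overhead come on top:

```python
# Rough weight-memory estimate: parameter count x bytes per parameter.
# Weights only; activations, the KV cache and MPS allocator overhead are extra.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Approximate size of the model weights in GiB for a given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


n_params = 999_885_952  # from model.num_parameters() above
for dt in ("float32", "bfloat16"):
    print(f"{dt:>9s}: {weight_memory_gib(n_params, dt):.2f} GiB")
```

Under this estimate the `float32` weights take about 3.72 GiB. Loading in `bfloat16` would roughly halve that.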


## Data Loading & Benchmark Pair Selection

Reproduce the exact same 50 benchmark pairs (25 match + 25 non-match) from `summarizer_optimization.ipynb` using the same random seed and pair generation logic. This ensures results are directly comparable.
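One caveat on reproducing "the exact same" pairs: Python randomizes `str` hashes per process unless `PYTHONHASHSEED` is set. Iterating a `set` of record-id tuples can therefore produce a different order in each kernel session, and `random.shuffle` then yields different pairs despite `random.seed(42)`.

The sketch below shows an order-independent variant. It uses a hypothetical helper, `shuffled_pairs`, that sorts the pairs before shuffling and draws from a local `random.Random`. Adopting it would select different pairs than the existing notebooks do, so it would have to be applied to all of them together:

```python
import random


def shuffled_pairs(match_pairs, non_match_pairs, seed=42):
    """Label and shuffle candidate pairs reproducibly.

    Sorting first removes any dependence on set iteration order (which, for
    str elements, varies with PYTHONHASHSEED), so the same seed always yields
    the same shuffled list.
    """
    rng = random.Random(seed)  # local RNG, unaffected by other random.* calls
    pairs = ([(r1, r2, True) for r1, r2 in sorted(match_pairs)] +
             [(r1, r2, False) for r1, r2 in sorted(non_match_pairs)])
    rng.shuffle(pairs)
    return pairs


# Same contents, different insertion order -> identical result
a = shuffled_pairs({("a1", "b1"), ("a2", "b2")}, {("a1", "b2"), ("a2", "b1")})
b = shuffled_pairs({("a2", "b2"), ("a1", "b1")}, {("a2", "b1"), ("a1", "b2")})
print(a == b)
```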

In [4]:
import random
import numpy as np
import pandas as pd
from shared.data_loader import load_facility_patients
from shared.ground_truth import (
    load_ground_truth,
    add_record_ids_to_ground_truth,
    generate_true_pairs_from_ground_truth,
)
from shared.medical_records import load_medical_records, get_patient_records

RUN_DIR = os.path.join(PROJECT_ROOT, "output", "augmented", "run_20260203_071928")

# Load patients and create record_ids
patients_df = load_facility_patients(RUN_DIR)
patients_df['record_id'] = patients_df['facility_id'] + '_' + patients_df['id'].astype(str)

# Load ground truth and add record_ids
ground_truth_df = load_ground_truth(RUN_DIR)
ground_truth_df = add_record_ids_to_ground_truth(ground_truth_df, patients_df)

# Generate true match pairs
true_pairs = generate_true_pairs_from_ground_truth(ground_truth_df)

# Build record_id -> (patient_uuid, facility_id) mapping
record_map = {}
for _, row in patients_df.iterrows():
    record_map[row['record_id']] = (row['id'], row['facility_id'])

# Load medical records
medical_records = load_medical_records(RUN_DIR)

# Generate non-match pairs (same code + seed as fine-tuning notebook)
rid_to_true_id = (ground_truth_df.dropna(subset=['record_id'])
                  .set_index('record_id')['true_patient_id'].to_dict())
all_record_ids = list(rid_to_true_id.keys())

random.seed(42)
non_match_pairs = set()
target = len(true_pairs)
attempts = 0
while len(non_match_pairs) < target and attempts < target * 20:
    r1, r2 = random.sample(all_record_ids, 2)
    if rid_to_true_id.get(r1) != rid_to_true_id.get(r2):
        non_match_pairs.add(tuple(sorted([r1, r2])))
    attempts += 1

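# Reproduce train/eval/test split (same as fine-tuning notebook).
# Caveat: iterating the set `non_match_pairs` (and `true_pairs`, if it is also a set)
# follows str hash order, which changes between sessions unless PYTHONHASHSEED is
# fixed, so this shuffle is only reproducible across sessions under a fixed hash seed.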
all_pairs = ([(r1, r2, True) for r1, r2 in true_pairs] +
             [(r1, r2, False) for r1, r2 in non_match_pairs])
random.shuffle(all_pairs)

# Filter to pairs with records
all_pairs = [(r1, r2, l) for r1, r2, l in all_pairs
             if r1 in record_map and r2 in record_map]

matches = [p for p in all_pairs if p[2]]
non_matches = [p for p in all_pairs if not p[2]]

n_total = min(len(matches), len(non_matches))
n_train = min(500, int(n_total * 0.70))
n_eval = min(100, int(n_total * 0.15))
n_test = min(50, n_total - n_train - n_eval)

test_pairs = (matches[n_train+n_eval:n_train+n_eval+n_test] +
              non_matches[n_train+n_eval:n_train+n_eval+n_test])
random.shuffle(test_pairs)

# Same 50-pair selection as summarizer_optimization notebook
benchmark_matches = [p for p in test_pairs if p[2]][:25]
benchmark_non_matches = [p for p in test_pairs if not p[2]][:25]
benchmark_pairs = benchmark_matches + benchmark_non_matches
random.shuffle(benchmark_pairs)

print(f"True pairs: {len(true_pairs)}, Non-match pairs: {len(non_match_pairs)}")
print(f"Test pairs: {len(test_pairs)} ({sum(1 for _,_,l in test_pairs if l)} match + "
      f"{sum(1 for _,_,l in test_pairs if not l)} non-match)")
print(f"Benchmark pairs: {len(benchmark_pairs)} "
      f"({len(benchmark_matches)} match + {len(benchmark_non_matches)} non-match)")

True pairs: 1121, Non-match pairs: 1121
Test pairs: 100 (50 match + 50 non-match)
Benchmark pairs: 50 (25 match + 25 non-match)
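The split sizes in the cell above are capped, so they stop depending on the dataset size once it is large enough. The sketch below mirrors that arithmetic. The helper `split_sizes` is hypothetical, and it assumes `n_total` is the per-class pair count after filtering:

```python
def split_sizes(n_total, max_train=500, max_eval=100, max_test=50):
    """Per-class split sizes, mirroring the caps and ratios in the cell above."""
    n_train = min(max_train, int(n_total * 0.70))
    n_eval = min(max_eval, int(n_total * 0.15))
    n_test = min(max_test, n_total - n_train - n_eval)
    return n_train, n_eval, n_test


n_train, n_eval, n_test = split_sizes(1121)  # 1121 pairs per class, per the output above
test_slice = slice(n_train + n_eval, n_train + n_eval + n_test)
print(n_train, n_eval, n_test, test_slice)  # 500 100 50 slice(600, 650, None)
```

For any `n_total` of at least 715, all three caps bind (500/100/50). The test pairs are then always the slice `[600:650]` of each class, which gives the 50 + 50 test pairs printed above.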


## Strategy D: Structured Diff-Friendly Summarizer

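From `summarizer_optimization.ipynb`. It organizes clinical data by category and reduces most dates to years:
- **Conditions** grouped by onset year, with ongoing conditions marked `*`
- **Medications** as drug (start–end or start–ongoing, by year)
- **Allergies** as a sorted flat list
- **Key observations**: latest 2 values for each of eight tracked metrics (height, weight, BMI, systolic/diastolic blood pressure, HbA1c, glucose, total cholesterol)
- **Procedures** with the sorted years each was performed, listed in order of first appearance in the records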

Achieves F1=0.936 with Opus at ~760 tok/summary (~1508 tok/pair).
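To make the layout concrete, here is a stdlib-only sketch of the CONDITIONS and MEDICATIONS formatting rules, run on made-up rows. The helpers `format_conditions` and `format_med_period` are hypothetical simplifications of `summarize_diff_friendly` below. They take plain tuples instead of DataFrames:

```python
from collections import defaultdict


def format_conditions(conditions):
    """conditions: iterable of (description, onset_year, is_ongoing).

    Groups by onset year in ascending order and marks ongoing entries with ' *',
    mirroring the CONDITIONS layout of Strategy D.
    """
    by_year = defaultdict(list)
    for desc, year, ongoing in conditions:
        by_year[year].append(desc + (" *" if ongoing else ""))
    lines = ["CONDITIONS:"]
    for year in sorted(by_year):
        lines.append(f"  {year}: {'; '.join(by_year[year])}")
    return "\n".join(lines)


def format_med_period(start_year, end_year, ongoing):
    """Render a medication period as 'start–end', 'start–ongoing' or 'unknown'."""
    if ongoing:
        return f"{start_year}\u2013ongoing" if start_year is not None else "ongoing"
    if start_year is not None and end_year is not None:
        return f"{start_year}\u2013{end_year}"
    return "unknown"


# Made-up example rows
print(format_conditions([
    ("Prediabetes", 2019, True),
    ("Acute bronchitis", 2019, False),
    ("Medication review due", 2025, False),
]))
# CONDITIONS:
#   2019: Prediabetes *; Acute bronchitis
#   2025: Medication review due
print("- Metformin (" + format_med_period(2019, None, True) + ")")
# - Metformin (2019–ongoing)
```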

In [5]:
def summarize_diff_friendly(patient_id, facility_id, medical_records):
    """Structured for pairwise comparison, grouped by year. Target ~800 tokens."""
    records = get_patient_records(patient_id, facility_id, medical_records)
    sections = []

    # CONDITIONS - grouped by onset year
    cond_df = records.get('conditions')
    if cond_df is not None:
        cond_df = cond_df.copy()
        cond_df['year'] = pd.to_datetime(cond_df['START'], errors='coerce').dt.year
        cond_df['is_ongoing'] = cond_df['STOP'].isna() | (cond_df['STOP'].astype(str).str.strip() == '')
        lines = ["CONDITIONS:"]
        for year, grp in sorted(cond_df.groupby('year')):
            if pd.isna(year):
                continue
            descs = []
            for _, row in grp.iterrows():
                status = " *" if row['is_ongoing'] else ""
                descs.append(f"{row['DESCRIPTION']}{status}")
            lines.append(f"  {int(year)}: {'; '.join(descs)}")
        sections.append("\n".join(lines))

    # MEDICATIONS - drug (start_year-end_year or ongoing)
    meds_df = records.get('medications')
    if meds_df is not None:
        meds_df = meds_df.copy()
        lines = ["MEDICATIONS:"]
        for desc, grp in meds_df.groupby('DESCRIPTION', sort=False):
            start_dt = pd.to_datetime(grp['START'], errors='coerce').min()
            is_current = grp['STOP'].isna().any() | (grp['STOP'].astype(str).str.strip() == '').any()
            if is_current:
                period = f"{start_dt.year}\u2013ongoing" if pd.notna(start_dt) else "ongoing"
            else:
                end_dt = pd.to_datetime(grp['STOP'], errors='coerce').max()
                if pd.notna(start_dt) and pd.notna(end_dt):
                    period = f"{start_dt.year}\u2013{end_dt.year}"
                else:
                    period = "unknown"
            lines.append(f"- {desc} ({period})")
        sections.append("\n".join(lines))

    # ALLERGIES - flat list
    allg_df = records.get('allergies')
    if allg_df is not None:
        names = sorted(allg_df['DESCRIPTION'].unique())
        sections.append("ALLERGIES: " + "; ".join(names))

    # KEY OBSERVATIONS - latest 2 values per metric
    obs_df = records.get('observations')
    if obs_df is not None:
        obs_df = obs_df.copy()
        obs_df['date_dt'] = pd.to_datetime(obs_df['DATE'], errors='coerce')
        key_obs = [
            'Body Height', 'Body Weight', 'Body Mass Index',
            'Systolic Blood Pressure', 'Diastolic Blood Pressure',
            'Hemoglobin A1c/Hemoglobin.total in Blood',
            'Glucose', 'Total Cholesterol',
        ]
        lines = ["OBSERVATIONS:"]
        for obs_name in key_obs:
            match = obs_df[obs_df['DESCRIPTION'].str.contains(obs_name, case=False, na=False)]
            if match.empty:
                continue
            recent = match.sort_values('date_dt').tail(2)
            vals = []
            for _, row in recent.iterrows():
                v = row.get('VALUE', '')
                u = row.get('UNITS', '')
                d = str(row.get('DATE', ''))[:10]
                if pd.notna(v) and str(v).strip():
                    u_str = f" {u}" if pd.notna(u) and u else ""
                    vals.append(f"{v}{u_str} ({d})")
            if vals:
                lines.append(f"- {obs_name}: {', '.join(vals)}")
        sections.append("\n".join(lines))

    # PROCEDURES - each with its sorted list of years (in order of first appearance)
    proc_df = records.get('procedures')
    if proc_df is not None:
        proc_df = proc_df.copy()
        proc_df['year'] = pd.to_datetime(proc_df['START'], errors='coerce').dt.year
        lines = ["PROCEDURES:"]
        for desc, grp in proc_df.groupby('DESCRIPTION', sort=False):
            years = sorted(grp['year'].dropna().unique())
            year_strs = [str(int(y)) for y in years]
            lines.append(f"- {desc} ({', '.join(year_strs)})")
        sections.append("\n".join(lines))

    return "\n\n".join(sections) if sections else "No clinical records available."


# Build summary cache for all unique record_ids in benchmark_pairs
unique_rids = set()
for r1, r2, _ in benchmark_pairs:
    unique_rids.add(r1)
    unique_rids.add(r2)

print(f"Generating Strategy D summaries for {len(unique_rids)} unique records...")
summary_cache = {}
for rid in sorted(unique_rids):
    if rid in record_map:
        pid, fid = record_map[rid]
        summary_cache[rid] = summarize_diff_friendly(pid, fid, medical_records)

# Token length stats
token_lengths = [len(tokenizer.encode(s)) for s in summary_cache.values()]
pair_lengths = []
for r1, r2, _ in benchmark_pairs:
    if r1 in summary_cache and r2 in summary_cache:
        combined = summary_cache[r1] + "\n\n" + summary_cache[r2]
        pair_lengths.append(len(tokenizer.encode(combined)))

print(f"\nSingle summary token lengths:")
print(f"  Mean: {np.mean(token_lengths):.0f}, Median: {np.median(token_lengths):.0f}, Max: {max(token_lengths)}")
print(f"\nPair token lengths (summaries only):")
print(f"  Mean: {np.mean(pair_lengths):.0f}, Median: {np.median(pair_lengths):.0f}, Max: {max(pair_lengths)}")

# Show one example
example_rid = sorted(summary_cache.keys())[0]
print(f"\n{'='*60}")
print(f"Example summary ({example_rid}):")
print(f"{'='*60}")
print(summary_cache[example_rid][:1500])

Generating Strategy D summaries for 95 unique records...

Single summary token lengths:
  Mean: 722, Median: 554, Max: 2857

Pair token lengths (summaries only):
  Mean: 1424, Median: 1325, Max: 3447

Example summary (facility_001_2fc3ff4f-6a9b-b756-f92e-eb6f1f3a00a8):
CONDITIONS:
  2025: Medication review due (situation)

OBSERVATIONS:
- Body Height: 56.2 cm (2025-11-02)
- Body Weight: 4.6 kg (2025-11-02)
- Systolic Blood Pressure: 130.0 mm[Hg] (2025-11-02)
- Diastolic Blood Pressure: 78.0 mm[Hg] (2025-11-02)

PROCEDURES:
- Medication Reconciliation (procedure) (2025)
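The OBSERVATIONS rule keeps the two most recent readings per metric. Below is a stdlib sketch of that top-k-by-date selection on made-up readings. The helper `latest_values` is hypothetical. It uses exact metric names rather than the notebook's case-insensitive substring match, and it assumes ISO-8601 date strings, which sort chronologically:

```python
from collections import defaultdict


def latest_values(observations, k=2):
    """observations: iterable of (metric, iso_date, value, units).

    Returns {metric: [(value, units, date), ...]} with the k most recent
    readings per metric, oldest first, analogous to sort_values(...).tail(2).
    """
    by_metric = defaultdict(list)
    for metric, date, value, units in observations:
        by_metric[metric].append((date, value, units))
    out = {}
    for metric, rows in by_metric.items():
        rows.sort(key=lambda r: r[0])  # sort by ISO date string only
        out[metric] = [(v, u, d) for d, v, u in rows[-k:]]
    return out


# Made-up readings
readings = [
    ("Body Weight", "2023-05-01", 70.2, "kg"),
    ("Body Weight", "2025-11-02", 71.0, "kg"),
    ("Body Weight", "2024-06-10", 70.8, "kg"),
    ("Body Height", "2025-11-02", 175.0, "cm"),
]
latest = latest_values(readings)
for metric, vals in latest.items():
    print(f"- {metric}: " + ", ".join(f"{v} {u} ({d})" for v, u, d in vals))
# - Body Weight: 70.8 kg (2024-06-10), 71.0 kg (2025-11-02)
# - Body Height: 175.0 cm (2025-11-02)
```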


## Benchmark: Base Gemma 1B on Strategy D Summaries

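Run the model, without fine-tuning, on all 50 benchmark pairs with greedy decoding (`do_sample=False`). `max_length=4096` accommodates Strategy D pairs: the longest pair is 3,447 tokens for the two summaries alone, before the instruction text and chat-template tokens are added.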
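Because `predict_match` tokenizes with `truncation=True`, an over-long prompt would be cut silently rather than raising an error. The check below is a minimal sketch of the budget. Both the helper `fits_context` and the 150-token overhead for instruction text plus chat template are assumptions, not measured values; logging `inputs["input_ids"].shape[1]` inside `predict_match` would give the real figure:

```python
def fits_context(pair_summary_tokens, prompt_overhead, max_length=4096):
    """True if the full prompt fits without truncation.

    With truncation=True, an over-long prompt is cut from the side given by
    tokenizer.truncation_side (typically "right"), which would drop the end of
    Record B, the question and the generation prompt.
    """
    return pair_summary_tokens + prompt_overhead <= max_length


LONGEST_PAIR = 3447  # max pair length (summaries only) reported above
OVERHEAD = 150       # ASSUMED budget for instruction text + chat-template tokens
print(fits_context(LONGEST_PAIR, OVERHEAD))  # True
print(4096 - LONGEST_PAIR - OVERHEAD)        # 499 tokens of headroom under that assumption
```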

In [6]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix

INSTRUCTION = (
    "You are a medical record matching expert. Compare these two patient "
    "medical records and determine if they belong to the same patient based "
    "only on their clinical history.\n\n"
    "Record A:\n{summary_a}\n\n"
    "Record B:\n{summary_b}\n\n"
    "Are these the same patient? Answer only 'True' or 'False'."
)


def predict_match(model, tokenizer, summary_a, summary_b):
    """Predict whether two medical records are from the same patient."""
    prompt = INSTRUCTION.format(summary_a=summary_a, summary_b=summary_b)
    messages = [{"role": "user", "content": prompt}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
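    # Caveat: Gemma's chat template already begins with <bos>, and tokenizer()
    # prepends another by default (add_special_tokens=True), so this prompt starts
    # with two BOS tokens; add_special_tokens=False would avoid that.
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=4096
    ).to(model.device)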

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=8, do_sample=False)

    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    ).strip().lower()

    if "true" in response:
        return True
    elif "false" in response:
        return False
    return None


print(f"Benchmarking base Gemma 1B on {len(benchmark_pairs)} pairs (Strategy D summaries)...")
print(f"  ({sum(1 for _,_,l in benchmark_pairs if l)} match + "
      f"{sum(1 for _,_,l in benchmark_pairs if not l)} non-match)\n")

preds, labels, unparseable = [], [], 0
for i, (r1, r2, label) in enumerate(benchmark_pairs):
    pred = predict_match(model, tokenizer, summary_cache[r1], summary_cache[r2])
    if pred is not None:
        preds.append(pred)
        labels.append(label)
    else:
        unparseable += 1
    if (i + 1) % 10 == 0:
        correct = sum(p == l for p, l in zip(preds, labels))
        print(f"  {i+1}/{len(benchmark_pairs)}... ({correct}/{len(preds)} correct)")

# Metrics
metrics = {
    'accuracy': accuracy_score(labels, preds),
    'precision': precision_score(labels, preds, zero_division=0),
    'recall': recall_score(labels, preds, zero_division=0),
    'f1': f1_score(labels, preds, zero_division=0),
}

print(f"\nBase Gemma 1B + Strategy D Results ({len(preds)} parseable / "
      f"{unparseable} unparseable / {len(benchmark_pairs)} total):")
for metric, value in metrics.items():
    print(f"  {metric:>10s}: {value:.3f}")
print(f"\nConfusion Matrix (rows=actual, cols=predicted):")
print(confusion_matrix(labels, preds))

# Prediction distribution
print(f"\nPrediction distribution: {sum(preds)} True / {len(preds) - sum(preds)} False")
print(f"Actual distribution:     {sum(labels)} True / {len(labels) - sum(labels)} False")

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Benchmarking base Gemma 1B on 50 pairs (Strategy D summaries)...
  (25 match + 25 non-match)

  10/50... (5/10 correct)
  20/50... (9/20 correct)
  30/50... (16/30 correct)
  40/50... (20/40 correct)
  50/50... (24/50 correct)

Base Gemma 1B + Strategy D Results (50 parseable / 0 unparseable / 50 total):
    accuracy: 0.480
   precision: 0.486
      recall: 0.680
          f1: 0.567

Confusion Matrix (rows=actual, cols=predicted):
[[ 7 18]
 [ 8 17]]

Prediction distribution: 35 True / 15 False
Actual distribution:     25 True / 25 False
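The confusion matrix follows scikit-learn's layout for labels `[False, True]`: rows are actual labels and columns are predictions, i.e. `[[TN, FP], [FN, TP]]`. Here that gives TN=7, FP=18, FN=8, TP=17. The model calls 18 of the 25 non-matches a match.

Below is a small stdlib sketch, using a hypothetical helper `metrics_from_confusion`, that recomputes the printed metrics from those four counts. F1 reduces to 2*TP / (2*TP + FP + FN) = 34/60:

```python
def metrics_from_confusion(cm):
    """Binary metrics from a 2x2 matrix laid out like sklearn's confusion_matrix
    with labels [False, True]: rows = actual, cols = predicted,
    i.e. [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    total = tn + fp + fn + tp
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / total, "precision": precision,
            "recall": recall, "f1": f1}


for name, value in metrics_from_confusion([[7, 18], [8, 17]]).items():
    print(f"{name:>10s}: {value:.3f}")
```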


## Results & Context

### Comparison with other configurations

| Model | Summary format | Tokens/summary | F1 |
|-------|----------------|----------------|----|
| Opus | Condensed | ~200 | 0.558 |
| Opus | Full summary | ~1019 | 0.889 |
| Opus | Strategy D | ~760 | **0.936** |
| Opus | Raw records | ~6069 | 0.980 |
| **Base Gemma 1B** | **Strategy D** | **~760** | **0.567** |

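This is the pre-fine-tuning baseline for Gemma 1B with Strategy D summaries: F1=0.567 at 0.480 accuracy. That accuracy is below the 0.5 chance level on this balanced 25/25 set. The model answered True for 35 of 50 pairs (18 of them non-matches), so the F1 mainly reflects a bias toward True; a classifier that always answers True would score F1=0.667. Fine-tuning has to close the gap from this near-chance baseline to Opus's F1=0.936.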
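To judge whether 0.567 shows any signal, compare it with a classifier that ignores its input and answers True with probability q. The model answered True for 35 of 50 pairs, so q = 0.7. The sketch below works under that assumption, using a hypothetical helper `expected_f1_random`. It computes F1 from the expected confusion counts, which approximates the expected F1 but is not exactly equal to it:

```python
def expected_f1_random(q, n_pos, n_neg):
    """F1 of the expected confusion counts for a classifier that predicts
    True with probability q regardless of input (an approximation of, not
    exactly, the expected F1)."""
    tp = q * n_pos
    fp = q * n_neg
    fn = (1 - q) * n_pos
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


for q in (0.5, 0.7, 1.0):
    print(f"q={q:.1f}: F1 ~ {expected_f1_random(q, 25, 25):.3f}")
```

On the balanced 25/25 set this gives about 0.500, 0.583 and 0.667 for q = 0.5, 0.7 and 1.0. Such a guesser's expected accuracy stays at 0.5 for every q. The observed F1 of 0.567 sits at or below these input-blind baselines, which is consistent with the base model not yet separating matches from non-matches.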