# Summarizer Optimization for Entity Resolution

Benchmarking from `entity_resolution_fine_tuning.ipynb` revealed the summarizer is the bottleneck:

| Format | Tokens (mean) | Accuracy | F1 |
|--------|--------------|----------|-----|
| Condensed (no dates, names only) | ~200 | 0.620 | 0.558 |
| Full `summarize_patient_records()` | ~969 | 0.900 | 0.889 |
| Raw records (semantic tags) | ~6069 | 0.980 | 0.980 |

The full summarizer trails raw records by about 9 F1 points (0.889 vs 0.980). This notebook tests five summarizer strategies that aim to close that gap while keeping each summary within ~500–2000 tokens.
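The gap is simple arithmetic on the table. Here is a quick Python check that uses only the table's values:

```python
# Benchmark values copied from the table above (no new measurements).
full_summary = {"tokens": 969, "f1": 0.889}
raw_records = {"tokens": 6069, "f1": 0.980}

# F1 gap in percentage points, and the extra token cost of raw records.
f1_gap_points = (raw_records["f1"] - full_summary["f1"]) * 100
token_ratio = raw_records["tokens"] / full_summary["tokens"]

print(f"F1 gap: {f1_gap_points:.1f} points")                    # F1 gap: 9.1 points
print(f"Raw records cost {token_ratio:.1f}x the tokens")        # Raw records cost 6.3x the tokens
```

Raw records gain roughly nine F1 points at about six times the token cost of the full summary. The strategies below try to improve on that trade-off.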

**Strategies:**
- **A: Temporal-Enhanced** (~1200 tok) — preserve more dates, time-series, encounter details
- **B: Identity-Focused** (~600 tok) — only highly discriminative features
- **C: Compact Raw** (~2000 tok) — XML-tagged raw records, pruned
- **D: Structured Diff-Friendly** (~800 tok) — year-grouped for pairwise comparison
- **E: Haiku LLM Summarizer** (~500 tok) — Claude Haiku generates clinical summaries optimized for entity resolution

In [1]:
import os
import sys
import random
import numpy as np
import pandas as pd

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Add llm-entity-resolution for summarizer imports
_llm_er_root = os.path.join(PROJECT_ROOT, "llm-entity-resolution")
if _llm_er_root not in sys.path:
    sys.path.insert(0, _llm_er_root)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

def count_tokens(text):
    """Count tokens using Gemma tokenizer."""
    return len(tokenizer.encode(text))

print(f"Project root: {PROJECT_ROOT}")

Project root: /Users/alex/repos/Kaggle/SyntheticMass


## 1. Data Loading & Pair Generation

Rebuild the 50 benchmark pairs (25 match + 25 non-match) from the fine-tuning notebook, using the same random seed and pair-generation logic.
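This step has a reproducibility caveat. `non_match_pairs` below is a `set` of string tuples. String hashing is randomized per process unless `PYTHONHASHSEED` is fixed, so the set's iteration order can differ between kernel restarts. That changes the input to the later `random.shuffle`, and therefore the result, even with `random.seed(42)`. Below is a minimal order-stable sketch. `sample_non_match_pairs` is a hypothetical helper, and adopting it would select different pairs than the fine-tuning notebook unless both notebooks use it.

```python
import random


def sample_non_match_pairs(record_ids, true_id_of, n_pairs, seed=42, max_attempts_factor=20):
    """Rejection-sample up to n_pairs record pairs whose true patient ids differ.

    Hypothetical helper: same sampling idea as the notebook, but it returns a
    sorted list, so downstream shuffles see the same order in every process.
    """
    rng = random.Random(seed)      # private RNG, unaffected by other random calls
    ids = sorted(record_ids)       # fix the population order as well
    pairs = set()
    attempts = 0
    while len(pairs) < n_pairs and attempts < n_pairs * max_attempts_factor:
        r1, r2 = rng.sample(ids, 2)
        if true_id_of.get(r1) != true_id_of.get(r2):
            pairs.add(tuple(sorted((r1, r2))))
        attempts += 1
    return sorted(pairs)           # deterministic order, independent of str hashing


# Toy example: four records belonging to two patients.
true_id_of = {"f1_a": "p1", "f2_a": "p1", "f1_b": "p2", "f2_b": "p2"}
pairs = sample_non_match_pairs(true_id_of.keys(), true_id_of, n_pairs=3)
print(pairs)
```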

In [2]:
from shared.data_loader import load_facility_patients
from shared.ground_truth import (
    load_ground_truth,
    add_record_ids_to_ground_truth,
    generate_true_pairs_from_ground_truth,
)
from shared.medical_records import load_medical_records, get_patient_records
from src.summarize import (
    summarize_patient_records,
    _summarize_conditions, _summarize_allergies,
    _summarize_imaging, _summarize_devices, _summarize_careplans,
)

RUN_DIR = os.path.join(PROJECT_ROOT, "output", "augmented", "run_20260203_071928")

# Load patients and create record_ids
patients_df = load_facility_patients(RUN_DIR)
patients_df['record_id'] = patients_df['facility_id'] + '_' + patients_df['id'].astype(str)

# Load ground truth and add record_ids
ground_truth_df = load_ground_truth(RUN_DIR)
ground_truth_df = add_record_ids_to_ground_truth(ground_truth_df, patients_df)

# Generate true match pairs
true_pairs = generate_true_pairs_from_ground_truth(ground_truth_df)

# Build record_id -> (patient_uuid, facility_id) mapping
record_map = {}
for _, row in patients_df.iterrows():
    record_map[row['record_id']] = (row['id'], row['facility_id'])

# Load medical records
medical_records = load_medical_records(RUN_DIR)

# Generate non-match pairs (same code + seed as fine-tuning notebook)
rid_to_true_id = (ground_truth_df.dropna(subset=['record_id'])
                  .set_index('record_id')['true_patient_id'].to_dict())
all_record_ids = list(rid_to_true_id.keys())

random.seed(42)
non_match_pairs = set()
target = len(true_pairs)
attempts = 0
while len(non_match_pairs) < target and attempts < target * 20:
    r1, r2 = random.sample(all_record_ids, 2)
    if rid_to_true_id.get(r1) != rid_to_true_id.get(r2):
        non_match_pairs.add(tuple(sorted([r1, r2])))
    attempts += 1

# Reproduce train/eval/test split (same as fine-tuning notebook)
all_pairs = ([(r1, r2, True) for r1, r2 in true_pairs] +
             [(r1, r2, False) for r1, r2 in non_match_pairs])
random.shuffle(all_pairs)

# Filter to pairs with records (same as fine-tuning notebook filtering by summary_cache)
all_pairs = [(r1, r2, l) for r1, r2, l in all_pairs
             if r1 in record_map and r2 in record_map]

matches = [p for p in all_pairs if p[2]]
non_matches = [p for p in all_pairs if not p[2]]

n_total = min(len(matches), len(non_matches))
n_train = min(500, int(n_total * 0.70))
n_eval = min(100, int(n_total * 0.15))
n_test = min(50, n_total - n_train - n_eval)

test_pairs = (matches[n_train+n_eval:n_train+n_eval+n_test] +
              non_matches[n_train+n_eval:n_train+n_eval+n_test])
random.shuffle(test_pairs)

# Same 50-pair selection as fine-tuning notebook (Sonnet/Opus benchmarks)
benchmark_matches = [p for p in test_pairs if p[2]][:25]
benchmark_non_matches = [p for p in test_pairs if not p[2]][:25]
benchmark_pairs = benchmark_matches + benchmark_non_matches
random.shuffle(benchmark_pairs)

print(f"True pairs: {len(true_pairs)}, Non-match pairs: {len(non_match_pairs)}")
print(f"Test pairs: {len(test_pairs)} ({sum(1 for _,_,l in test_pairs if l)} match + "
      f"{sum(1 for _,_,l in test_pairs if not l)} non-match)")
print(f"Benchmark pairs: {len(benchmark_pairs)} "
      f"({len(benchmark_matches)} match + {len(benchmark_non_matches)} non-match)")

True pairs: 1121, Non-match pairs: 1121
Test pairs: 100 (50 match + 50 non-match)
Benchmark pairs: 50 (25 match + 25 non-match)


## 2. Shared Benchmark Helpers

Reusable `call_opus()` and `run_opus_benchmark()` functions used by all strategies.
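`call_opus()` wraps each request in retry with exponential backoff. With `retries=3` it waits 2 s after the first failure and 4 s after the second, then re-raises. Below is the same pattern as a small, SDK-independent helper. The names `retry_async` and `base_delay` are hypothetical and not part of `claude_agent_sdk`.

```python
import asyncio


async def retry_async(make_call, retries=3, base_delay=2.0):
    """Await make_call() up to `retries` times, doubling the wait after each failure.

    With the defaults the waits are 2 s and then 4 s, the same schedule as
    call_opus(). The final exception is re-raised unchanged.
    """
    for attempt in range(retries):
        try:
            return await make_call()
        except Exception:
            if attempt == retries - 1:
                raise
            await asyncio.sleep(base_delay * 2 ** attempt)


# Usage with a fake call that fails twice, then succeeds.
calls = {"n": 0}


async def flaky_call():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "True"


# In a script use asyncio.run(); inside Jupyter, `await retry_async(...)` directly.
result = asyncio.run(retry_async(flaky_call, base_delay=0.0))
print(result, calls["n"])
```

With the helpers below, a call could look like `await retry_async(lambda: call_opus(prompt, retries=1))`. Here `call_opus` makes a single attempt and the backoff lives in one place.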

In [3]:
import asyncio
from claude_agent_sdk import query, ClaudeAgentOptions
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

INSTRUCTION = (
    "You are a medical record matching expert. Compare these two patient "
    "medical records and determine if they belong to the same patient based "
    "only on their clinical history.\n\n"
    "Record A:\n{summary_a}\n\n"
    "Record B:\n{summary_b}\n\n"
    "Are these the same patient? Answer only 'True' or 'False'."
)

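# Skip pairs above this size to avoid context-window errors. Counts come from the
# Gemma tokenizer in count_tokens(), so they only approximate Claude's token usage.
MAX_PAIR_TOKENS = 50_000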


async def call_opus(prompt: str, retries: int = 3) -> str:
    """Call Claude Opus via claude-agent-sdk with retry on failure."""
    for attempt in range(retries):
        try:
            result_parts = []
            async for message in query(
                prompt=prompt,
                options=ClaudeAgentOptions(
                    model="claude-opus-4-6",
                    max_turns=1,
                    allowed_tools=[],
                    system_prompt="You are a medical record matching expert. Answer only 'True' or 'False'.",
                ),
            ):
                if hasattr(message, 'content'):
                    if isinstance(message.content, list):
                        for block in message.content:
                            if hasattr(block, 'text'):
                                result_parts.append(block.text)
                    elif isinstance(message.content, str):
                        result_parts.append(message.content)
            return "\n".join(result_parts)
        except Exception as e:
            if attempt < retries - 1:
                wait = 2 ** (attempt + 1)
                print(f"    Retry {attempt+1}/{retries-1} after error: {e!r} (waiting {wait}s)")
                await asyncio.sleep(wait)
            else:
                raise


# Store all results for final comparison
all_results = {}


async def run_opus_benchmark(name, summarizer_fn, benchmark_pairs, record_map, medical_records):
    """Score benchmark_pairs with Opus, summarizing each record with summarizer_fn. Returns (metrics, cache)."""
    # Build summary cache
    cache = {}
    for r1, r2, _ in benchmark_pairs:
        for rid in [r1, r2]:
            if rid not in cache and rid in record_map:
                pid, fid = record_map[rid]
                cache[rid] = summarizer_fn(pid, fid, medical_records)

    # Token stats
    token_lengths = [count_tokens(s) for s in cache.values()]
    pair_lengths = [count_tokens(cache[r1]) + count_tokens(cache[r2])
                    for r1, r2, _ in benchmark_pairs]

    print(f"\n{'='*60}")
    print(f"  {name}")
    print(f"{'='*60}")
    print(f"Tokens per summary: mean={np.mean(token_lengths):.0f}, "
          f"median={np.median(token_lengths):.0f}, max={max(token_lengths)}")
    print(f"Tokens per pair:    mean={np.mean(pair_lengths):.0f}, max={max(pair_lengths)}")

    # Show example
    example_rid = list(cache.keys())[0]
    print(f"\nExample summary ({example_rid}):")
    print("-" * 40)
    text = cache[example_rid]
    print(text[:1500])
    if len(text) > 1500:
        print("... (truncated)")

    # Run Opus
    print(f"\nRunning {len(benchmark_pairs)} pairs through Opus...")
    preds, labels = [], []
    skipped, errors = 0, 0
    for i, (r1, r2, label) in enumerate(benchmark_pairs):
        pair_tok = count_tokens(cache[r1]) + count_tokens(cache[r2])
        if pair_tok > MAX_PAIR_TOKENS:
            skipped += 1
            print(f"  Pair {i+1}: SKIPPED ({pair_tok:,} tokens exceeds {MAX_PAIR_TOKENS:,} limit)")
            continue

        try:
            prompt = INSTRUCTION.format(summary_a=cache[r1], summary_b=cache[r2])
            answer = await call_opus(prompt)
        except Exception as e:
            errors += 1
            print(f"  Pair {i+1}: ERROR after retries: {e!r}")
            continue

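        answer_lower = answer.strip().lower()

        # Take whichever verdict appears first: a plain `"true" in answer` test
        # would misread replies such as "False; it is not true that ..." as True.
        positions = {v: answer_lower.find(v) for v in ("true", "false")}
        found = {v: pos for v, pos in positions.items() if pos >= 0}
        if not found:
            print(f"  Pair {i+1}: unparseable: {answer!r}")
            continue
        pred = min(found, key=found.get) == "true"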

        preds.append(pred)
        labels.append(label)

        if (i + 1) % 10 == 0:
            correct = sum(p == l for p, l in zip(preds, labels))
            print(f"  {i+1}/{len(benchmark_pairs)}... ({correct}/{len(preds)} correct)")

    metrics = {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, zero_division=0),
        'recall': recall_score(labels, preds, zero_division=0),
        'f1': f1_score(labels, preds, zero_division=0),
        'tokens_mean': np.mean(token_lengths),
        'tokens_pair_mean': np.mean(pair_lengths),
    }

    print(f"\nResults ({len(preds)} scored / {skipped} skipped / {errors} errors "
          f"/ {len(benchmark_pairs)} total):")
    for m in ['accuracy', 'precision', 'recall', 'f1']:
        print(f"  {m:>10s}: {metrics[m]:.3f}")
    print("\nConfusion Matrix (rows=actual, cols=predicted):")
    # Fix the label order so the matrix is always 2x2: [[TN, FP], [FN, TP]]
    print(confusion_matrix(labels, preds, labels=[False, True]))

    all_results[name] = metrics
    return metrics, cache


print("Helpers loaded. call_opus() and run_opus_benchmark() ready.")

Helpers loaded. call_opus() and run_opus_benchmark() ready.


## 3. Baseline: Full Summary (`summarize_patient_records`)

Re-run the production summarizer as a live baseline on the same 50 pairs.
Known result from fine-tuning notebook: accuracy=0.900, F1=0.889.

In [4]:
baseline_metrics, baseline_cache = await run_opus_benchmark(
    "Baseline (full summary)",
    summarize_patient_records,
    benchmark_pairs, record_map, medical_records
)


  Baseline (full summary)
Tokens per summary: mean=1019, median=992, max=2624
Tokens per pair:    mean=2048, max=3845

Example summary (facility_004_5a9d98df-4309-437e-c91f-6d45126e6101):
----------------------------------------
=== MEDICAL HISTORY ===

CONDITIONS (active/historical):
- Received certificate of high school equivalency (finding) (onset: 1978-11-15, ongoing)
- Stress (finding) (x9) (onset: 1978-11-15, resolved 2018-05-30)
- Part-time employment (finding) (x8) (onset: 1982-11-24, resolved 2016-05-18)
- Essential hypertension (disorder) (onset: 1985-11-27, ongoing)
- Miscarriage in first trimester (onset: 1989-05-10, ongoing)
- Body mass index 30+ - obesity (finding) (onset: 1993-01-06, ongoing)
- Unhealthy alcohol drinking behavior (finding) (onset: 1993-01-06, resolved 1999-02-10)
- History of tubal ligation (situation) (onset: 1996-06-21, ongoing)
- Victim of intimate partner abuse (finding) (onset: 2011-04-20, resolved 2016-05-18)

MEDICATIONS (current/past):
- Hydroch

## 4. Strategy A: Temporal-Enhanced (~1200 tokens)

Preserve more temporal detail than the production summarizer:
- **Conditions**: Same as production (already has onset dates + status)
- **Medications**: Individual fill start dates per drug (not just "first to last" range)
- **Observations**: Last 3 values per key metric (not just latest) to reveal trends
- **Procedures**: All procedures with all dates (no top-15 cutoff)
- **Encounters**: One-liner per encounter with date + class + reason
- **Immunizations**: Each with individual date (not just counts)
- **Allergies/Imaging/Devices/Careplans**: Same as production
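The "last 3 values per key metric" rule comes down to a pandas `groupby(...).tail(3)` over date-sorted rows. Below is a small sketch on a made-up observations frame with the `DATE`, `DESCRIPTION`, `VALUE`, `UNITS` columns used in this notebook; the values are invented for illustration.

```python
import pandas as pd

# Toy observations table (invented values) with the column names used above.
obs = pd.DataFrame({
    "DATE": ["2019-01-05", "2020-02-10", "2021-03-15", "2022-04-20", "2021-03-15"],
    "DESCRIPTION": ["Body Weight"] * 4 + ["Body Height"],
    "VALUE": ["80", "82", "85", "84", "170"],
    "UNITS": ["kg", "kg", "kg", "kg", "cm"],
})
obs["date_dt"] = pd.to_datetime(obs["DATE"], errors="coerce")

# Keep the 3 most recent readings per metric, oldest first within each metric.
recent = obs.sort_values("date_dt").groupby("DESCRIPTION").tail(3)

trend_lines = []
for desc, grp in recent.groupby("DESCRIPTION", sort=True):
    vals = [f"{v} {u} ({d})" for v, u, d in zip(grp["VALUE"], grp["UNITS"], grp["DATE"])]
    arrow = " \u2192 "  # kept outside the f-string expression for pre-3.12 Pythons
    trend_lines.append(f"- {desc}: {arrow.join(vals)}")

print("\n".join(trend_lines))
```

Grouping on the exact `DESCRIPTION` is a simplification. The notebook instead matches each key metric by case-insensitive substring, which also catches name variants.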

In [5]:
def summarize_temporal_enhanced(patient_id, facility_id, medical_records):
    """Enhanced summary preserving temporal detail. Target ~1200 tokens."""
    records = get_patient_records(patient_id, facility_id, medical_records)
    sections = ["=== MEDICAL HISTORY ==="]

    # CONDITIONS - reuse production (already includes onset dates + status)
    conditions_df = records.get('conditions')
    sections.append(_summarize_conditions(conditions_df))

    # MEDICATIONS - individual fill start dates per drug
    meds_df = records.get('medications')
    if meds_df is not None:
        meds_df = meds_df.copy().sort_values('START')
        lines = ["MEDICATIONS:"]
        for desc, grp in meds_df.groupby('DESCRIPTION', sort=False):
            dates = sorted(grp['START'].astype(str).str[:10].unique())
            reasons = grp['REASONDESCRIPTION'].dropna()
            reasons = reasons[reasons.astype(str).str.strip() != '']
            reason = reasons.mode().iloc[0] if not reasons.empty else ""
            reason_str = f" for {reason}" if reason else ""
            is_current = grp['STOP'].isna().any() | (grp['STOP'].astype(str).str.strip() == '').any()
            status = " (ongoing)" if is_current else ""
            lines.append(f"- {desc}: {', '.join(dates)}{reason_str}{status}")
        sections.append("\n".join(lines))
    else:
        sections.append("MEDICATIONS: none")

    # ENCOUNTERS - one-liner per encounter
    enc_df = records.get('encounters')
    if enc_df is not None:
        enc_df = enc_df.copy().sort_values('START')
        lines = ["ENCOUNTERS:"]
        for _, row in enc_df.iterrows():
            date = str(row.get('START', ''))[:10]
            enc_class = row.get('ENCOUNTERCLASS', '')
            reason = row.get('REASONDESCRIPTION', '')
            reason_str = f" \u2014 {reason}" if pd.notna(reason) and str(reason).strip() else ""
            lines.append(f"- {date} {enc_class}{reason_str}")
        sections.append("\n".join(lines))
    else:
        sections.append("ENCOUNTERS: none")

    # OBSERVATIONS - last 3 values per key metric
    obs_df = records.get('observations')
    if obs_df is not None:
        obs_df = obs_df.copy()
        obs_df['date_dt'] = pd.to_datetime(obs_df['DATE'], errors='coerce')
        key_obs = [
            'Body Height', 'Body Weight', 'Body Mass Index',
            'Systolic Blood Pressure', 'Diastolic Blood Pressure',
            'Hemoglobin A1c/Hemoglobin.total in Blood',
            'Glucose', 'Total Cholesterol',
            'Heart rate', 'Respiratory rate',
            'Estimated Glomerular Filtration Rate',
        ]
        lines = ["KEY OBSERVATIONS (last 3 values):"]
        for obs_name in key_obs:
            # Literal substring match: some names contain '.', a regex wildcard
            match = obs_df[obs_df['DESCRIPTION'].str.contains(obs_name, case=False, na=False, regex=False)]
            if match.empty:
                continue
            recent = match.sort_values('date_dt').tail(3)
            vals = []
            for _, row in recent.iterrows():
                v = row.get('VALUE', '')
                u = row.get('UNITS', '')
                d = str(row.get('DATE', ''))[:10]
                if pd.notna(v) and str(v).strip():
                    u_str = f" {u}" if pd.notna(u) and u else ""
                    vals.append(f"{v}{u_str} ({d})")
            if vals:
                arrow = " \u2192 "  # backslash escapes inside f-string {...} need Python 3.12+
                lines.append(f"- {obs_name}: {arrow.join(vals)}")
        lines.append(f"- Total observations on file: {len(obs_df)}")
        sections.append("\n".join(lines))
    else:
        sections.append("KEY OBSERVATIONS: none")

    # PROCEDURES - all with dates (no top-15 cutoff)
    proc_df = records.get('procedures')
    if proc_df is not None:
        proc_df = proc_df.copy().sort_values('START')
        lines = ["PROCEDURES:"]
        for desc, grp in proc_df.groupby('DESCRIPTION', sort=False):
            dates = sorted(grp['START'].astype(str).str[:10].unique())
            lines.append(f"- {desc}: {', '.join(dates)}")
        sections.append("\n".join(lines))
    else:
        sections.append("PROCEDURES: none")

    # IMMUNIZATIONS - each with date
    imm_df = records.get('immunizations')
    if imm_df is not None:
        imm_df = imm_df.copy().sort_values('DATE')
        lines = ["IMMUNIZATIONS:"]
        for _, row in imm_df.iterrows():
            date = str(row.get('DATE', ''))[:10]
            desc = row.get('DESCRIPTION', '')
            lines.append(f"- {date}: {desc}")
        sections.append("\n".join(lines))
    else:
        sections.append("IMMUNIZATIONS: none")

    # ALLERGIES, IMAGING, DEVICES, CAREPLANS - reuse production
    sections.append(_summarize_allergies(records.get('allergies')))
    sections.append(_summarize_imaging(records.get('imaging_studies')))
    sections.append(_summarize_devices(records.get('devices')))
    sections.append(_summarize_careplans(records.get('careplans')))

    return "\n\n".join(s for s in sections if s)


print("summarize_temporal_enhanced() defined.")

summarize_temporal_enhanced() defined.


In [6]:
strategy_a_metrics, strategy_a_cache = await run_opus_benchmark(
    "A: Temporal-Enhanced",
    summarize_temporal_enhanced,
    benchmark_pairs, record_map, medical_records
)


  A: Temporal-Enhanced
Tokens per summary: mean=2315, median=1601, max=12148
Tokens per pair:    mean=4564, max=14439

Example summary (facility_004_5a9d98df-4309-437e-c91f-6d45126e6101):
----------------------------------------
=== MEDICAL HISTORY ===

CONDITIONS (active/historical):
- Received certificate of high school equivalency (finding) (onset: 1978-11-15, ongoing)
- Stress (finding) (x9) (onset: 1978-11-15, resolved 2018-05-30)
- Part-time employment (finding) (x8) (onset: 1982-11-24, resolved 2016-05-18)
- Essential hypertension (disorder) (onset: 1985-11-27, ongoing)
- Miscarriage in first trimester (onset: 1989-05-10, ongoing)
- Body mass index 30+ - obesity (finding) (onset: 1993-01-06, ongoing)
- Unhealthy alcohol drinking behavior (finding) (onset: 1993-01-06, resolved 1999-02-10)
- History of tubal ligation (situation) (onset: 1996-06-21, ongoing)
- Victim of intimate partner abuse (finding) (onset: 2011-04-20, resolved 2016-05-18)

MEDICATIONS:
- Hydrochlorothiazide 25
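
Strategy A overshoots its ~1200-token target by a wide margin: mean 2315, median 1601, and max 12148 tokens per summary in the output above. The likely cause is that encounters, procedures and immunizations are listed without any cap. One way to enforce a hard budget is sketched below. It assumes each section is kept as a list of lines with its header first. `render` and `fit_sections_to_budget` are hypothetical helpers, not part of `src.summarize`.

```python
def render(sections, dropped):
    """Join sections (lists of lines, header first), noting how many lines were cut."""
    parts = []
    for lines, n_dropped in zip(sections, dropped):
        body = list(lines)
        if n_dropped:
            body.append(f"  ... and {n_dropped} more")
        parts.append("\n".join(body))
    return "\n\n".join(parts)


def fit_sections_to_budget(sections, budget, count_tokens):
    """Drop trailing lines from the longest section until the text fits `budget`.

    Hypothetical helper. Headers (line 0 of each section) are never removed, so
    the result can still exceed the budget if the headers alone do.
    """
    sections = [list(s) for s in sections]  # work on copies of the caller's lists
    dropped = [0] * len(sections)
    while count_tokens(render(sections, dropped)) > budget:
        idx = max(range(len(sections)), key=lambda i: len(sections[i]))
        if len(sections[idx]) <= 1:
            break  # only headers remain; nothing left to trim
        sections[idx].pop()
        dropped[idx] += 1
    return render(sections, dropped)


def word_count(text):
    """Whitespace 'tokenizer' standing in for the notebook's count_tokens()."""
    return len(text.split())


sections = [
    ["ENCOUNTERS:", "- 2019-01-01 ambulatory", "- 2020-01-01 wellness", "- 2021-01-01 emergency"],
    ["ALLERGIES:", "- Peanut"],
]
trimmed = fit_sections_to_budget(sections, budget=12, count_tokens=word_count)
print(trimmed)
```

The helper re-counts the whole text after every dropped line. That is simple but slow for very long summaries, so trimming in larger steps would help there. Because encounter lines are sorted oldest first, dropping trailing lines removes the most recent encounters; reversing the sort would keep them instead.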

## 5. Strategy B: Identity-Focused (~600 tokens)

Only the most discriminative features for patient matching:
- **Chronic conditions only** — filter out resolved conditions and generic social/employment findings
- **All medications** with dosages (discriminative)
- **All allergies** (rare and highly patient-specific)
- **Key stable vitals**: height, weight, BMI (stable identifiers)
- **Non-screening procedures** only, capped at 20 (alphabetical); routine screenings (depression, anxiety, substance use assessments) are skipped
- **Skip**: encounters, immunizations, care plans, devices, imaging (low signal-to-noise for matching)
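The condition filter comes down to two boolean masks: "no stop date" (NaN or blank `STOP`) and "not in the generic set". Below is a self-contained sketch on invented rows, with a two-item stand-in for `GENERIC_CONDITIONS`.

```python
import pandas as pd

# Two-item stand-in for the notebook's GENERIC_CONDITIONS set.
GENERIC = {"Stress (finding)", "Part-time employment (finding)"}

conditions = pd.DataFrame({
    "DESCRIPTION": [
        "Essential hypertension (disorder)",
        "Stress (finding)",
        "Prediabetes (finding)",
        "Acute bronchitis (disorder)",
        "Essential hypertension (disorder)",
    ],
    # None and "" both mean "no stop date", i.e. the condition is ongoing.
    "STOP": [None, None, "", "2020-03-01", "2019-01-01"],
})

stop = conditions["STOP"]
is_ongoing = stop.isna() | (stop.astype(str).str.strip() == "")
chronic = conditions[is_ongoing & ~conditions["DESCRIPTION"].isin(GENERIC)]
line = "CHRONIC CONDITIONS: " + "; ".join(sorted(chronic["DESCRIPTION"].unique()))
print(line)
```

A condition with both a resolved row and an ongoing row is kept, because any ongoing row passes the mask.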

In [7]:
# Generic conditions that appear on many patients (low discriminative value)
GENERIC_CONDITIONS = {
    'Medication review due (situation)',
    'Full-time employment (finding)',
    'Part-time employment (finding)',
    'Unemployed (finding)',
    'Not in labor force (finding)',
    'Received higher education (finding)',
    'Only received primary school education (finding)',
    'Received certificate of high school equivalency (finding)',
    'Social isolation (finding)',
    'Limited social contact (finding)',
    'Reports of violence in the environment (finding)',
    'Stress (finding)',
    'Victim of intimate partner abuse (finding)',
}

# Common screening procedures (appear on nearly every patient)
COMMON_SCREENING_PROCEDURES = {
    'Assessment of health and social care needs (procedure)',
    'Assessment of anxiety (procedure)',
    'Assessment of substance use (procedure)',
    'Depression screening (procedure)',
    'Depression screening using Patient Health Questionnaire Nine Item score (procedure)',
    'Depression screening using Patient Health Questionnaire Two-Item score (procedure)',
    'Assessment using Alcohol Use Disorders Identification Test - Consumption (procedure)',
    'Medication Reconciliation (procedure)',
    'Screening for domestic abuse (procedure)',
    'Screening for drug abuse (procedure)',
}


def summarize_identity_focused(patient_id, facility_id, medical_records):
    """Identity-focused summary with only discriminative features. Target ~600 tokens."""
    records = get_patient_records(patient_id, facility_id, medical_records)
    sections = []

    # CHRONIC CONDITIONS (ongoing only, skip generic)
    cond_df = records.get('conditions')
    if cond_df is not None:
        cond_df = cond_df.copy()
        cond_df['is_ongoing'] = cond_df['STOP'].isna() | (cond_df['STOP'].astype(str).str.strip() == '')
        chronic = cond_df[cond_df['is_ongoing']]
        chronic = chronic[~chronic['DESCRIPTION'].isin(GENERIC_CONDITIONS)]
        if not chronic.empty:
            names = sorted(chronic['DESCRIPTION'].unique())
            sections.append("CHRONIC CONDITIONS: " + "; ".join(names))

    # ALL MEDICATIONS (with dosages — discriminative)
    meds_df = records.get('medications')
    if meds_df is not None:
        names = sorted(meds_df['DESCRIPTION'].unique())
        sections.append("MEDICATIONS: " + "; ".join(names))

    # ALLERGIES (rare and patient-specific)
    allg_df = records.get('allergies')
    if allg_df is not None:
        lines = []
        for _, row in allg_df.iterrows():
            desc = row.get('DESCRIPTION', '')
            reaction = row.get('DESCRIPTION1', '')
            extras = []
            if pd.notna(reaction) and str(reaction).strip():
                extras.append(str(reaction))
            extra_str = f" ({', '.join(extras)})" if extras else ""
            lines.append(f"{desc}{extra_str}")
        sections.append("ALLERGIES: " + "; ".join(lines))

    # KEY STABLE VITALS (height, weight, BMI)
    obs_df = records.get('observations')
    if obs_df is not None:
        obs_df = obs_df.copy()
        obs_df['date_dt'] = pd.to_datetime(obs_df['DATE'], errors='coerce')
        stable_metrics = ['Body Height', 'Body Weight', 'Body Mass Index']
        vitals = []
        for metric in stable_metrics:
            match = obs_df[obs_df['DESCRIPTION'].str.contains(metric, case=False, na=False)]
            if not match.empty:
                latest = match.sort_values('date_dt').iloc[-1]
                v = latest.get('VALUE', '')
                u = latest.get('UNITS', '')
                if pd.notna(v) and str(v).strip():
                    u_str = f" {u}" if pd.notna(u) and u else ""
                    vitals.append(f"{metric}: {v}{u_str}")
        if vitals:
            sections.append("VITALS: " + "; ".join(vitals))

    # NON-SCREENING PROCEDURES (first 20 alphabetically; no rarity ranking)
    proc_df = records.get('procedures')
    if proc_df is not None:
        rare = proc_df[~proc_df['DESCRIPTION'].isin(COMMON_SCREENING_PROCEDURES)]
        if not rare.empty:
            names = sorted(rare['DESCRIPTION'].unique())[:20]
            sections.append("PROCEDURES: " + "; ".join(names))

    return "\n".join(sections) if sections else "No clinical records available."


print("summarize_identity_focused() defined.")

summarize_identity_focused() defined.


In [8]:
strategy_b_metrics, strategy_b_cache = await run_opus_benchmark(
    "B: Identity-Focused",
    summarize_identity_focused,
    benchmark_pairs, record_map, medical_records
)


  B: Identity-Focused
Tokens per summary: mean=145, median=125, max=568
Tokens per pair:    mean=288, max=703

Example summary (facility_004_5a9d98df-4309-437e-c91f-6d45126e6101):
----------------------------------------
CHRONIC CONDITIONS: Body mass index 30+ - obesity (finding); Essential hypertension (disorder); History of tubal ligation (situation); Miscarriage in first trimester
MEDICATIONS: Hydrochlorothiazide 25 MG Oral Tablet; ferrous sulfate 325 MG Oral Tablet; lisinopril 10 MG Oral Tablet

Running 50 pairs through Opus...
  10/50... (10/10 correct)
  20/50... (19/20 correct)
  30/50... (27/30 correct)
  40/50... (37/40 correct)
  50/50... (46/50 correct)

Results (50 scored / 0 skipped / 0 errors / 50 total):
    accuracy: 0.920
   precision: 0.957
      recall: 0.880
          f1: 0.917

Confusion Matrix (rows=actual, cols=predicted):
[[24  1]
 [ 3 22]]
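As a sanity check, the Strategy B metrics follow directly from the confusion matrix above. Rows are actual and columns are predicted, so the matrix is `[[TN, FP], [FN, TP]]` for boolean labels.

```python
# Strategy B confusion matrix copied from the output above.
(tn, fp), (fn, tp) = [[24, 1], [3, 22]]

accuracy = (tp + tn) / (tp + tn + fp + fn)   # 46 / 50
precision = tp / (tp + fp)                   # 22 / 23
recall = tp / (tp + fn)                      # 22 / 25
f1 = 2 * precision * recall / (precision + recall)

print(f"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
# accuracy=0.920 precision=0.957 recall=0.880 f1=0.917
```

Three of the four errors are false negatives, so Strategy B's misses fall mostly on the recall side.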


## 6. Strategy C: Compact Raw (~2000 tokens)

Raw records format with semantic XML tags but aggressively pruned:
- **Keep**: conditions, medications, allergies, procedures, key observations
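- **Drop**: encounters (noisy), immunizations, care plans, devices, supplies, imaging
- **Observations**: filtered to a fixed list of discriminative types (height, weight, BMI, blood pressure, and glycemic, lipid, renal and electrolyte labs); other vitals such as heart rate and respiratory rate are dropped
- **Each section** capped at 30 rows: the earliest 30 for conditions, medications and procedures, the first 30 as stored for allergies, and the most recent 30 for observations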
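The per-section cap can be factored into one helper. This sketch differs from the notebook in one respect, offered as a suggestion rather than what was benchmarked. For tail-capped sections such as observations, it places the omission note before the kept rows, since the dropped rows are the older ones. `summarize_compact_raw()` instead appends `... and N more` after the most recent rows. `render_capped_section` is a hypothetical name.

```python
def render_capped_section(tag, rows, cap=30, keep="head"):
    """Wrap rows in <tag>...</tag>, keeping at most `cap` rows.

    Hypothetical helper. keep="head" keeps the first rows and notes how many
    follow; keep="tail" keeps the last rows and notes how many came before.
    """
    if cap < 1:
        raise ValueError("cap must be at least 1")
    if not rows:
        return ""
    kept = rows[:cap] if keep == "head" else rows[-cap:]
    omitted = len(rows) - len(kept)
    lines = [f"<{tag}>"]
    if omitted and keep == "tail":
        lines.append(f"  ... {omitted} earlier rows omitted")
    lines += [f"  {row}" for row in kept]
    if omitted and keep == "head":
        lines.append(f"  ... and {omitted} more")
    lines.append(f"</{tag}>")
    return "\n".join(lines)


# Four dated observations (invented values), capped at 2.
rows = [f"2020-0{m}-01: Glucose = {90 + m} mg/dL" for m in range(1, 5)]
tail_out = render_capped_section("observations", rows, cap=2, keep="tail")
head_out = render_capped_section("observations", rows, cap=2, keep="head")
print(tail_out)
print(head_out)
```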

In [9]:
# Observation types that are discriminative for patient matching
DISCRIMINATIVE_OBS = [
    'Body Height', 'Body Weight', 'Body Mass Index',
    'Hemoglobin A1c', 'Glucose', 'Total Cholesterol',
    'Triglycerides', 'Low Density', 'High Density',
    'Creatinine', 'Estimated Glomerular Filtration Rate',
    'Systolic Blood Pressure', 'Diastolic Blood Pressure',
    'Calcium', 'Sodium', 'Potassium', 'Chloride',
    'Carbon Dioxide', 'Urea Nitrogen',
]

SECTION_CAP = 30  # max rows per section to prevent runaway summaries


def summarize_compact_raw(patient_id, facility_id, medical_records):
    """Compact raw records with XML tags. Target ~2000 tokens."""
    records = get_patient_records(patient_id, facility_id, medical_records)
    sections = []

    # Conditions (cap at SECTION_CAP)
    cond_df = records.get('conditions')
    if cond_df is not None:
        cond_df = cond_df.sort_values('START')
        lines = ["<conditions>"]
        for _, row in cond_df.head(SECTION_CAP).iterrows():
            start = str(row.get('START', ''))[:10]
            stop_val = row.get('STOP')
            stop = str(stop_val)[:10] if pd.notna(stop_val) and str(stop_val).strip() else 'ongoing'
            desc = row.get('DESCRIPTION', '')
            lines.append(f"  {start} to {stop}: {desc}")
        if len(cond_df) > SECTION_CAP:
            lines.append(f"  ... and {len(cond_df) - SECTION_CAP} more")
        lines.append("</conditions>")
        sections.append("\n".join(lines))

    # Medications (cap at SECTION_CAP)
    meds_df = records.get('medications')
    if meds_df is not None:
        meds_df = meds_df.sort_values('START')
        lines = ["<medications>"]
        for _, row in meds_df.head(SECTION_CAP).iterrows():
            start = str(row.get('START', ''))[:10]
            stop_val = row.get('STOP')
            stop = str(stop_val)[:10] if pd.notna(stop_val) and str(stop_val).strip() else 'ongoing'
            desc = row.get('DESCRIPTION', '')
            reason = row.get('REASONDESCRIPTION', '')
            reason_str = f" for {reason}" if pd.notna(reason) and str(reason).strip() else ""
            lines.append(f"  {start} to {stop}: {desc}{reason_str}")
        if len(meds_df) > SECTION_CAP:
            lines.append(f"  ... and {len(meds_df) - SECTION_CAP} more")
        lines.append("</medications>")
        sections.append("\n".join(lines))

    # Allergies (cap at SECTION_CAP)
    allg_df = records.get('allergies')
    if allg_df is not None:
        lines = ["<allergies>"]
        for _, row in allg_df.head(SECTION_CAP).iterrows():
            desc = row.get('DESCRIPTION', '')
            reaction = row.get('DESCRIPTION1', '')
            severity = row.get('SEVERITY1', '')
            extras = []
            if pd.notna(reaction) and str(reaction).strip():
                extras.append(str(reaction))
            if pd.notna(severity) and str(severity).strip():
                extras.append(str(severity))
            extra_str = f" ({', '.join(extras)})" if extras else ""
            lines.append(f"  {desc}{extra_str}")
        if len(allg_df) > SECTION_CAP:
            lines.append(f"  ... and {len(allg_df) - SECTION_CAP} more")
        lines.append("</allergies>")
        sections.append("\n".join(lines))

    # Procedures (cap at SECTION_CAP)
    proc_df = records.get('procedures')
    if proc_df is not None:
        proc_df = proc_df.sort_values('START')
        lines = ["<procedures>"]
        for _, row in proc_df.head(SECTION_CAP).iterrows():
            start = str(row.get('START', ''))[:10]
            desc = row.get('DESCRIPTION', '')
            lines.append(f"  {start}: {desc}")
        if len(proc_df) > SECTION_CAP:
            lines.append(f"  ... and {len(proc_df) - SECTION_CAP} more")
        lines.append("</procedures>")
        sections.append("\n".join(lines))

    # Observations - discriminative types only (cap at SECTION_CAP)
    obs_df = records.get('observations')
    if obs_df is not None:
        mask = obs_df['DESCRIPTION'].apply(
            lambda d: any(k in str(d) for k in DISCRIMINATIVE_OBS) if pd.notna(d) else False
        )
        filtered = obs_df[mask].sort_values('DATE')
        if not filtered.empty:
            lines = ["<observations>"]
            for _, row in filtered.tail(SECTION_CAP).iterrows():
                date = str(row.get('DATE', ''))[:10]
                desc = row.get('DESCRIPTION', '')
                val = row.get('VALUE', '')
                units = row.get('UNITS', '')
                u_str = f" {units}" if pd.notna(units) and str(units).strip() else ""
                lines.append(f"  {date}: {desc} = {val}{u_str}")
            if len(filtered) > SECTION_CAP:
                lines.append(f"  ... and {len(filtered) - SECTION_CAP} more")
            lines.append("</observations>")
            sections.append("\n".join(lines))

    return "\n\n".join(sections) if sections else "No clinical records available."


print("summarize_compact_raw() defined.")

summarize_compact_raw() defined.


In [10]:
strategy_c_metrics, strategy_c_cache = await run_opus_benchmark(
    "C: Compact Raw",
    summarize_compact_raw,
    benchmark_pairs, record_map, medical_records
)


  C: Compact Raw
Tokens per summary: mean=1846, median=1709, max=4322
Tokens per pair:    mean=3694, max=8561

Example summary (facility_004_5a9d98df-4309-437e-c91f-6d45126e6101):
----------------------------------------
<conditions>
  1978-11-15 to ongoing: Received certificate of high school equivalency (finding)
  1978-11-15 to 1979-11-21: Stress (finding)
  1982-11-24 to 1985-11-27: Part-time employment (finding)
  1982-11-24 to 1985-11-27: Stress (finding)
  1985-11-27 to ongoing: Essential hypertension (disorder)
  1988-12-14 to 1997-01-29: Stress (finding)
  1989-05-10 to ongoing: Miscarriage in first trimester
  1990-12-26 to 1991-07-24: Part-time employment (finding)
  1993-01-06 to ongoing: Body mass index 30+ - obesity (finding)
  1993-01-06 to 1999-02-10: Unhealthy alcohol drinking behavior (finding)
  1995-01-18 to 1996-01-24: Part-time employment (finding)
  1996-06-21 to ongoing: History of tubal ligation (situation)
  1997-01-29 to 1998-02-04: Part-time employment (fin

## 7. Strategy D: Structured Diff-Friendly (~800 tokens)

Designed specifically for pairwise comparison — organized by clinical category with year grouping:
- **Conditions** grouped by onset year, one line per year, with ongoing conditions marked `*` (e.g., `2014: Hypertension *; Prediabetes`)
- **Medications** as "drug (start_year–end_year or ongoing)"
- **Allergies** as flat list
- **Key observations**: latest 2 values per metric
- **Procedures** with years, sorted chronologically
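The condition grouping above comes down to a pandas `groupby` on the parsed onset year. Here is a minimal standalone sketch on an invented four-row table. The `group_by_onset_year` helper and the data are illustrative and not part of the notebook. Like the cell below, it marks a condition as ongoing when `STOP` is missing or blank, and it skips rows whose `START` does not parse:

```python
import pandas as pd

# Toy condition rows in a Synthea-like shape; the values are invented for illustration.
conditions = pd.DataFrame({
    "START": ["2014-03-02", "2014-07-19", "2019-01-05", "not a date"],
    "STOP": [None, "2016-02-01", None, None],
    "DESCRIPTION": ["Hypertension", "Prediabetes", "Obesity", "Unparseable onset"],
})


def group_by_onset_year(df: pd.DataFrame) -> list:
    """One line per onset year; conditions with no STOP date get a trailing ' *'."""
    df = df.copy()
    df["year"] = pd.to_datetime(df["START"], errors="coerce").dt.year
    ongoing = df["STOP"].isna() | df["STOP"].astype(str).str.strip().eq("")
    df["label"] = df["DESCRIPTION"] + ongoing.map({True: " *", False: ""})
    lines = []
    # groupby sorts the keys and, by default, drops rows whose year is NaN
    for year, grp in df.groupby("year"):
        lines.append(f"{int(year)}: {'; '.join(grp['label'])}")
    return lines


for line in group_by_onset_year(conditions):
    print(line)
```

`groupby` sorts its keys, so the years come out in ascending order without an explicit `sorted()` call. The unparseable row drops out because its year is `NaN`.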

In [11]:
def summarize_diff_friendly(patient_id, facility_id, medical_records):
    """Structured for pairwise comparison, grouped by year. Target ~800 tokens."""
    records = get_patient_records(patient_id, facility_id, medical_records)
    sections = []

    # CONDITIONS - grouped by onset year
    cond_df = records.get('conditions')
    if cond_df is not None:
        cond_df = cond_df.copy()
        cond_df['year'] = pd.to_datetime(cond_df['START'], errors='coerce').dt.year
        cond_df['is_ongoing'] = cond_df['STOP'].isna() | (cond_df['STOP'].astype(str).str.strip() == '')
        lines = ["CONDITIONS:"]
        # groupby sorts the years and drops NaN (unparseable START) by default
        for year, grp in cond_df.groupby('year'):
            descs = []
            for _, row in grp.iterrows():
                status = " *" if row['is_ongoing'] else ""
                descs.append(f"{row['DESCRIPTION']}{status}")
            lines.append(f"  {int(year)}: {'; '.join(descs)}")
        sections.append("\n".join(lines))

    # MEDICATIONS - drug (start_year-end_year or ongoing)
    meds_df = records.get('medications')
    if meds_df is not None:
        meds_df = meds_df.copy()
        lines = ["MEDICATIONS:"]
        for desc, grp in meds_df.groupby('DESCRIPTION', sort=False):
            start_dt = pd.to_datetime(grp['START'], errors='coerce').min()
            is_current = grp['STOP'].isna().any() | (grp['STOP'].astype(str).str.strip() == '').any()
            if is_current:
                period = f"{start_dt.year}\u2013ongoing" if pd.notna(start_dt) else "ongoing"
            else:
                end_dt = pd.to_datetime(grp['STOP'], errors='coerce').max()
                if pd.notna(start_dt) and pd.notna(end_dt):
                    period = f"{start_dt.year}\u2013{end_dt.year}"
                else:
                    period = "unknown"
            lines.append(f"- {desc} ({period})")
        sections.append("\n".join(lines))

    # ALLERGIES - flat list
    allg_df = records.get('allergies')
    if allg_df is not None:
        names = sorted(allg_df['DESCRIPTION'].unique())
        sections.append("ALLERGIES: " + "; ".join(names))

    # KEY OBSERVATIONS - latest 2 values per metric
    obs_df = records.get('observations')
    if obs_df is not None:
        obs_df = obs_df.copy()
        obs_df['date_dt'] = pd.to_datetime(obs_df['DATE'], errors='coerce')
        key_obs = [
            'Body Height', 'Body Weight', 'Body Mass Index',
            'Systolic Blood Pressure', 'Diastolic Blood Pressure',
            'Hemoglobin A1c/Hemoglobin.total in Blood',
            'Glucose', 'Total Cholesterol',
        ]
        lines = ["OBSERVATIONS:"]
        for obs_name in key_obs:
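            # regex=False: match the metric name as a literal substring (e.g. the '.' in 'Hemoglobin.total')
            match = obs_df[obs_df['DESCRIPTION'].str.contains(obs_name, case=False, na=False, regex=False)]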
            if match.empty:
                continue
            recent = match.sort_values('date_dt').tail(2)
            vals = []
            for _, row in recent.iterrows():
                v = row.get('VALUE', '')
                u = row.get('UNITS', '')
                d = str(row.get('DATE', ''))[:10]
                if pd.notna(v) and str(v).strip():
                    u_str = f" {u}" if pd.notna(u) and u else ""
                    vals.append(f"{v}{u_str} ({d})")
            if vals:
                lines.append(f"- {obs_name}: {', '.join(vals)}")
        sections.append("\n".join(lines))

    # PROCEDURES - each with its sorted years, listed in order of first appearance
    proc_df = records.get('procedures')
    if proc_df is not None:
        proc_df = proc_df.copy()
        proc_df['year'] = pd.to_datetime(proc_df['START'], errors='coerce').dt.year
        lines = ["PROCEDURES:"]
        for desc, grp in proc_df.groupby('DESCRIPTION', sort=False):
            years = sorted(grp['year'].dropna().unique())
            year_strs = [str(int(y)) for y in years]
            lines.append(f"- {desc} ({', '.join(year_strs)})")
        sections.append("\n".join(lines))

    return "\n\n".join(sections) if sections else "No clinical records available."


print("summarize_diff_friendly() defined.")

summarize_diff_friendly() defined.


In [12]:
strategy_d_metrics, strategy_d_cache = await run_opus_benchmark(
    "D: Structured Diff-Friendly",
    summarize_diff_friendly,
    benchmark_pairs, record_map, medical_records
)


  D: Structured Diff-Friendly
Tokens per summary: mean=760, median=627, max=2154
Tokens per pair:    mean=1508, max=3366

Example summary (facility_004_5a9d98df-4309-437e-c91f-6d45126e6101):
----------------------------------------
CONDITIONS:
  1978: Received certificate of high school equivalency (finding) *; Stress (finding)
  1982: Part-time employment (finding); Stress (finding)
  1985: Essential hypertension (disorder) *
  1988: Stress (finding)
  1989: Miscarriage in first trimester *
  1990: Part-time employment (finding)
  1993: Body mass index 30+ - obesity (finding) *; Unhealthy alcohol drinking behavior (finding)
  1995: Part-time employment (finding)
  1996: History of tubal ligation (situation) *
  1997: Part-time employment (finding)
  1998: Stress (finding)
  1999: Part-time employment (finding)
  2001: Stress (finding)
  2002: Part-time employment (finding)
  2003: Stress (finding)
  2005: Stress (finding)
  2006: Part-time employment (finding)
  2010: Stress (finding

## 8. Strategy E: Haiku LLM Summarizer (~500 tokens)

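Instead of hand-crafted rules, use Claude Haiku as the summarizer: feed it the raw patient records and ask for a ~500-word clinical summary optimized for entity resolution. The prompt sets a word target rather than a token target, so the summaries run longer than the heading's ~500 tokens once they are counted with the Gemma tokenizer.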

**Approach:**
1. Format raw records into text input for Haiku
2. Pre-build a cache of Haiku summaries for all unique record_ids in benchmark pairs
3. Wrap the cache in a sync function for `run_opus_benchmark()`

This tests whether an LLM, steered only by the prompt's list of what to keep and what to drop, can produce better matching summaries than the rule-based strategies.

In [14]:
async def call_haiku(prompt: str, system_prompt: str = "", retries: int = 3) -> str:
    """Call Claude Haiku via claude-agent-sdk with retry on failure."""
    for attempt in range(retries):
        try:
            result_parts = []
            async for message in query(
                prompt=prompt,
                options=ClaudeAgentOptions(
                    model="claude-haiku-4-5-20251001",
                    max_turns=1,
                    allowed_tools=[],
                    system_prompt=system_prompt,
                ),
            ):
                if hasattr(message, 'content'):
                    if isinstance(message.content, list):
                        for block in message.content:
                            if hasattr(block, 'text'):
                                result_parts.append(block.text)
                    elif isinstance(message.content, str):
                        result_parts.append(message.content)
            return "\n".join(result_parts)
        except Exception as e:
            if attempt < retries - 1:
                wait = 2 ** (attempt + 1)
                print(f"    Retry {attempt+1}/{retries-1} after error: {e!r} (waiting {wait}s)")
                await asyncio.sleep(wait)
            else:
                raise


RAW_SECTION_CAP = 50  # max rows per section in raw records for Haiku input


def format_raw_records(patient_id, facility_id, medical_records):
    """Format raw patient records as text input for Haiku summarization."""
    records = get_patient_records(patient_id, facility_id, medical_records)
    sections = []

    for key, label in [
        ('conditions', 'CONDITIONS'),
        ('medications', 'MEDICATIONS'),
        ('allergies', 'ALLERGIES'),
        ('observations', 'OBSERVATIONS'),
        ('procedures', 'PROCEDURES'),
        ('encounters', 'ENCOUNTERS'),
        ('immunizations', 'IMMUNIZATIONS'),
        ('careplans', 'CARE PLANS'),
        ('devices', 'DEVICES'),
        ('imaging_studies', 'IMAGING'),
    ]:
        df = records.get(key)
        if df is not None and not df.empty:
            lines = [f"=== {label} ({len(df)} rows) ==="]
            for _, row in df.head(RAW_SECTION_CAP).iterrows():
                parts = []
                for col in df.columns:
                    val = row[col]
                    if pd.notna(val) and str(val).strip():
                        parts.append(f"{col}={val}")
                lines.append("  " + " | ".join(parts))
            if len(df) > RAW_SECTION_CAP:
                lines.append(f"  ... and {len(df) - RAW_SECTION_CAP} more rows")
            sections.append("\n".join(lines))

    return "\n\n".join(sections) if sections else "No records available."


HAIKU_SUMMARIZER_PROMPT = """You are a medical record summarizer for an entity resolution system that matches patient records across facilities.

Given raw medical records for one patient at one facility, produce a concise clinical summary (~500 words) that preserves the features most useful for determining whether two records belong to the same person.

PRIORITIZE these highly discriminative features:
- Specific chronic conditions with onset dates (e.g., "Type 2 diabetes diagnosed 2018")
- Medications with exact dosages and time periods (e.g., "Metformin 500mg since 2018")
- Allergies with specific reactions and severity
- Procedures with exact dates (especially rare/distinctive ones)
- Key lab values and vitals with dates (HbA1c, cholesterol, BMI trends)
- Immunization dates

OMIT or minimize:
- Generic social findings (employment status, education level, stress)
- Routine screening procedures that appear on most patients
- Encounter metadata without clinical content
- Care plan boilerplate

Format as a structured clinical summary with clear section headers. Be specific — exact dates, dosages, and values matter more than general descriptions."""


# Build haiku_cache for all unique record_ids in benchmark pairs
unique_rids = set()
for r1, r2, _ in benchmark_pairs:
    unique_rids.add(r1)
    unique_rids.add(r2)
unique_rids = sorted(unique_rids)

print(f"Building Haiku cache for {len(unique_rids)} unique record_ids...")

haiku_cache = {}
errors = 0
for i, rid in enumerate(unique_rids):
    if rid not in record_map:
        continue
    pid, fid = record_map[rid]
    raw_text = format_raw_records(pid, fid, medical_records)
    prompt = f"Summarize the following patient records:\n\n{raw_text}"
    try:
        summary = await call_haiku(prompt, system_prompt=HAIKU_SUMMARIZER_PROMPT)
        haiku_cache[rid] = summary
    except Exception as e:
        errors += 1
        print(f"  FAILED {rid}: {e!r}")
        haiku_cache[rid] = "Summary unavailable due to error."

    if (i + 1) % 10 == 0 or (i + 1) == len(unique_rids):
        print(f"  {i+1}/{len(unique_rids)} done")

print(f"\nHaiku cache built: {len(haiku_cache)} summaries ({errors} errors)")
token_counts = [count_tokens(s) for s in haiku_cache.values()]
print(f"Tokens per summary: mean={np.mean(token_counts):.0f}, "
      f"median={np.median(token_counts):.0f}, max={max(token_counts)}")

Building Haiku cache for 95 unique record_ids...
  10/95 done
  20/95 done
  30/95 done
  40/95 done
  50/95 done
  60/95 done
  70/95 done
  80/95 done
  90/95 done
  95/95 done

Haiku cache built: 95 summaries (0 errors)
Tokens per summary: mean=1260, median=1296, max=2280
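The loop above awaits one Haiku call at a time, so 95 records mean 95 sequential round-trips. The sketch below adds bounded concurrency with `asyncio.Semaphore` and `asyncio.gather`. It assumes the API tolerates a handful of parallel requests. `build_cache_concurrently` and the stub `fake_summarize` are hypothetical stand-ins. A real `summarize_one` would wrap `format_raw_records` plus `call_haiku`:

```python
import asyncio


async def build_cache_concurrently(record_ids, summarize_one, max_concurrency=8):
    """Summarize every record id, keeping at most `max_concurrency` calls in flight.

    `summarize_one` is any async callable rid -> str. Failures are stored as the same
    placeholder string the sequential loop uses, and collected in `errors`.
    """
    semaphore = asyncio.Semaphore(max_concurrency)
    cache, errors = {}, []

    async def worker(rid):
        async with semaphore:  # waits while max_concurrency calls are already running
            try:
                cache[rid] = await summarize_one(rid)
            except Exception as exc:
                errors.append((rid, repr(exc)))
                cache[rid] = "Summary unavailable due to error."

    await asyncio.gather(*(worker(rid) for rid in record_ids))
    return cache, errors


# Hypothetical stub standing in for a real Haiku call.
async def fake_summarize(rid):
    await asyncio.sleep(0.01)
    if rid == "bad":
        raise RuntimeError("simulated API failure")
    return f"summary for {rid}"


# Plain-script entry point; in a notebook, use `await build_cache_concurrently(...)`.
cache, errors = asyncio.run(build_cache_concurrently(["a", "b", "bad"], fake_summarize))
print(f"{len(cache)} summaries, {len(errors)} errors")
```

Inside the notebook an event loop is already running, so call `await build_cache_concurrently(...)` directly. `asyncio.run` is only for plain scripts. `call_haiku`'s own retries would still apply per record.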


In [15]:
def summarize_haiku_cached(patient_id, facility_id, medical_records):
    """Look up pre-built Haiku summary from cache."""
    rid = facility_id + '_' + str(patient_id)
    return haiku_cache.get(rid, "No summary available.")

strategy_e_metrics, strategy_e_cache = await run_opus_benchmark(
    "E: Haiku LLM Summarizer",
    summarize_haiku_cached,
    benchmark_pairs, record_map, medical_records
)


  E: Haiku LLM Summarizer
Tokens per summary: mean=1260, median=1296, max=2280
Tokens per pair:    mean=2525, max=3658

Example summary (facility_004_5a9d98df-4309-437e-c91f-6d45126e6101):
----------------------------------------
# Clinical Summary

**Patient ID:** 5a9d98df-4309-437e-c91f-6d45126e6101  
**Facility:** facility_004  
**Record Period:** 1978–2015

---

## Chronic Conditions

**Essential Hypertension**
- Onset: November 27, 1985
- Status: Active/ongoing throughout record period
- Treatment initiated immediately upon diagnosis

**Obesity**
- BMI ≥30 documented: January 6, 1993
- Status: Chronic, documented at wellness visits

**Anemia**
- Referenced in ambulatory encounter: December 10, 1986
- Associated with iron supplementation therapy

---

## Medications

**Antihypertensive Regimen (since 1985)**
- **Hydrochlorothiazide 25 mg** oral tablet: November 27, 1985–present (continuing through 2009 documented)
  - Dosage stable throughout; 4 dispensations per year typical
- **
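The example summary above opens with the patient ID and facility. `format_raw_records` passes every non-empty column to Haiku, so any identifier columns in the tables reach the summarizer. If the same person kept the same patient UUID across facilities, those lines would hand the matching model the answer. Nothing in this section shows whether that is the case. Removing identifiers before summarization rules it out either way and also saves tokens. The sketch below assumes Synthea-style column names. `ID_COLUMNS` and `drop_identifier_columns` are hypothetical:

```python
import pandas as pd

# Hypothetical set of identifier columns (Synthea-style names); adjust to the real schema.
ID_COLUMNS = {"Id", "PATIENT", "ENCOUNTER", "PAYER", "ORGANIZATION", "PROVIDER"}


def drop_identifier_columns(df: pd.DataFrame, id_columns=ID_COLUMNS) -> pd.DataFrame:
    """Return a copy of df without columns that identify records rather than describe care."""
    return df.drop(columns=[c for c in df.columns if c in id_columns])


# Invented example row; the ID values are placeholders.
rows = pd.DataFrame({
    "START": ["1985-11-27"],
    "PATIENT": ["patient-0001"],
    "ENCOUNTER": ["encounter-0001"],
    "DESCRIPTION": ["Essential hypertension (disorder)"],
})
print(list(drop_identifier_columns(rows).columns))  # ['START', 'DESCRIPTION']
```

In `format_raw_records`, this would wrap each table (`df = drop_identifier_columns(df)`) before the row loop. The Haiku cache and the Strategy E benchmark would then need rerunning.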

## 9. Comparison

Side-by-side results for all strategies. The goal is to find the best F1 score while keeping tokens reasonable (~500–2000 per summary).

In [16]:
print(f"\n{'Strategy':<30s}  {'Tok/Sum':>8s}  {'Tok/Pair':>9s}  "
      f"{'Acc':>6s}  {'Prec':>6s}  {'Recall':>6s}  {'F1':>6s}  {'F1/kTok':>8s}")
print("=" * 95)

for name, m in all_results.items():
    f1_per_ktok = m['f1'] / (m['tokens_pair_mean'] / 1000) if m['tokens_pair_mean'] > 0 else 0
    print(f"{name:<30s}  {m['tokens_mean']:>8.0f}  {m['tokens_pair_mean']:>9.0f}  "
          f"{m['accuracy']:>6.3f}  {m['precision']:>6.3f}  {m['recall']:>6.3f}  "
          f"{m['f1']:>6.3f}  {f1_per_ktok:>8.3f}")

# Identify best strategies
best_f1 = max(all_results.items(), key=lambda x: x[1]['f1'])
best_efficiency = max(all_results.items(),
                      key=lambda x: x[1]['f1'] / max(x[1]['tokens_pair_mean'], 1))

print(f"\nBest F1:         {best_f1[0]} (F1={best_f1[1]['f1']:.3f}, "
      f"{best_f1[1]['tokens_pair_mean']:.0f} tok/pair)")
print(f"Best F1/token:   {best_efficiency[0]} (F1={best_efficiency[1]['f1']:.3f}, "
      f"{best_efficiency[1]['tokens_pair_mean']:.0f} tok/pair)")

# Compare to raw records ceiling
print(f"\nReference: Raw records (Opus) = F1 0.980 at ~12,157 tok/pair")


Strategy                         Tok/Sum   Tok/Pair     Acc    Prec  Recall      F1   F1/kTok
Baseline (full summary)             1019       2048   0.940   1.000   0.880   0.936     0.457
A: Temporal-Enhanced                2315       4564   0.940   1.000   0.880   0.936     0.205
B: Identity-Focused                  145        288   0.920   0.957   0.880   0.917     3.185
C: Compact Raw                      1846       3694   0.920   1.000   0.840   0.913     0.247
D: Structured Diff-Friendly          760       1508   0.940   1.000   0.880   0.936     0.621
E: Haiku LLM Summarizer             1260       2525   0.980   1.000   0.960   0.980     0.388

Best F1:         E: Haiku LLM Summarizer (F1=0.980, 2525 tok/pair)
Best F1/token:   B: Identity-Focused (F1=0.917, 288 tok/pair)

Reference: Raw records (Opus) = F1 0.980 at ~12,157 tok/pair
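The benchmark has 25 match and 25 non-match pairs (Section 1), so each reported metric corresponds to whole pair counts. That makes the size of the gaps concrete. The sketch below back-solves the confusion matrix from the table's precision and recall. `confusion_counts` is a hypothetical helper and assumes precision > 0:

```python
# (accuracy, precision, recall) per strategy, copied from the comparison table.
reported = {
    "Baseline (full summary)": (0.940, 1.000, 0.880),
    "A: Temporal-Enhanced": (0.940, 1.000, 0.880),
    "B: Identity-Focused": (0.920, 0.957, 0.880),
    "C: Compact Raw": (0.920, 1.000, 0.840),
    "D: Structured Diff-Friendly": (0.940, 1.000, 0.880),
    "E: Haiku LLM Summarizer": (0.980, 1.000, 0.960),
}
N_POS, N_NEG = 25, 25  # benchmark design: 25 match + 25 non-match pairs


def confusion_counts(precision, recall, n_pos=N_POS):
    """Back-solve integer (TP, FN, FP) from rounded precision and recall (precision > 0)."""
    tp = round(recall * n_pos)        # recall = TP / (TP + FN)
    fn = n_pos - tp
    fp = round(tp / precision - tp)   # precision = TP / (TP + FP)
    return tp, fn, fp


for name, (acc, p, r) in reported.items():
    tp, fn, fp = confusion_counts(p, r)
    tn = N_NEG - fp
    recomputed = (tp + tn) / (N_POS + N_NEG)
    print(f"{name:<30s} TP={tp:2d} FN={fn} FP={fp} TN={tn:2d} "
          f"acc={recomputed:.2f} (reported {acc:.2f})")
```

The reconstructed counts reproduce every reported accuracy. E's lead over Baseline, A, and D is two additional matched pairs (24 vs. 22 true positives). B's lower F1 comes from a single false positive. On 50 pairs, every gap in the table is one or two pairs.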


## 10. Recommendations

### Summary of Findings

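From the comparison table above:

1. **Best absolute F1**: E: Haiku LLM Summarizer matches the raw-records ceiling (F1=0.980) at 2525 tok/pair, roughly a fifth of the ~12,157 tok/pair that raw records need.
2. **Best F1-per-token efficiency**: B: Identity-Focused (F1=0.917 at 288 tok/pair, 3.185 F1/kTok), followed by D: Structured Diff-Friendly (0.621 F1/kTok).
3. **Practical choice**: For fine-tuning Gemma 1B with a ~2048-token sequence budget, which strategy fits while maximizing signal? Gemma 3 1B itself accepts a 32K-token context, so the ~2048 figure presumably reflects the fine-tuning sequence length rather than a model limit.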
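For the practical-choice question, one way to shortlist is to keep only the strategies whose pair length fits the sequence budget. Then take the highest F1, breaking ties toward fewer tokens. The sketch below uses the mean tok/pair column. `best_under_budget` and its `prompt_overhead` allowance (for instructions and the answer) are hypothetical. Mean length is optimistic: per the benchmark outputs above, D's longest pair is 3366 tokens and E's is 3658, so some pairs would still need truncation.

```python
# (F1, mean tokens per pair) from the comparison table.
results = {
    "Baseline (full summary)": (0.936, 2048),
    "A: Temporal-Enhanced": (0.936, 4564),
    "B: Identity-Focused": (0.917, 288),
    "C: Compact Raw": (0.913, 3694),
    "D: Structured Diff-Friendly": (0.936, 1508),
    "E: Haiku LLM Summarizer": (0.980, 2525),
}


def best_under_budget(results, budget_tokens, prompt_overhead=0):
    """Highest-F1 strategy whose mean pair length plus overhead fits the budget.

    Ties on F1 go to the shorter strategy (via -tokens). Returns None if nothing fits.
    """
    fitting = [
        (f1, -tokens, name)
        for name, (f1, tokens) in results.items()
        if tokens + prompt_overhead <= budget_tokens
    ]
    return max(fitting)[2] if fitting else None


print(best_under_budget(results, 2048))                       # D: Structured Diff-Friendly
print(best_under_budget(results, 2048, prompt_overhead=200))  # D: Structured Diff-Friendly
print(best_under_budget(results, 4096))                       # E: Haiku LLM Summarizer
```

At a 2048 budget, D wins the F1 tie with the baseline because it uses about 540 fewer tokens per pair. Raising the budget to 4096 lets E's higher F1 take over.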

### Next Steps

Based on results:
- If a strategy beats the full-summary baseline from this run (F1=0.936; the fine-tuning notebook reported 0.889 for `summarize_patient_records()`), update `llm-entity-resolution/src/summarize.py`
- Consider combining the best elements from multiple strategies
- Use the winning strategy to generate training data for Gemma 1B fine-tuning