# Fine-Tuning Gemma 1B for Medical Record Entity Resolution

Fine-tune **Gemma 3 1B** with LoRA to classify pairs of medical records as same-patient matches or non-matches. Uses only clinical data (conditions, medications, allergies, procedures, immunizations, observations) — no demographics.

**Approach:**
- Generate balanced training pairs from ground truth
- **Strategy D**: Structured diff-friendly summaries (~760 tokens each, Opus ceiling F1=0.936)
- Binary classification: "True" (match) or "False" (non-match)
- LoRA fine-tuning on RunPod GPU (RTX A4000, bfloat16) — see `train_on_gpu.py` and `RUNPOD_GUIDE.md`

**This notebook**: Explores the data pipeline (sections 3-6), then evaluates the GPU-trained adapter (sections 7-8).

## 1. Environment Setup

In [1]:
import os
import sys
import torch

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

# Add project root so shared/ imports work
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Project root: {PROJECT_ROOT}")

PyTorch version: 2.10.0
MPS available: True
Using device: mps
Project root: /Users/alex/repos/Kaggle/SyntheticMass


## 2. HuggingFace Login

In [2]:
from dotenv import load_dotenv
from huggingface_hub import login, whoami

load_dotenv()

token = os.environ.get("HF_TOKEN")
if not token:
    raise ValueError("HF_TOKEN not found. Create a .env file with: HF_TOKEN=hf_your_token_here")

login(token=token)
print(f"Logged in as: {whoami()['name']}")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Logged in as: abicyclerider


## 3. Load Tokenizer

Load the **Gemma 3 1B Instruct** tokenizer for token length analysis. The full model is loaded later for evaluation.

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "google/gemma-3-1b-it"

print(f"Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Vocab size: {tokenizer.vocab_size:,}")

Loading tokenizer from google/gemma-3-1b-it...
Vocab size: 262,144


## 4. Load Data from Ground Truth

Generate balanced pairs entirely from ground truth + medical records. No dependency on demographic matching scores.

- **True match pairs**: Patients appearing at 2+ facilities yield cross-facility pairs
- **Non-match pairs**: Random pairs from different true patient identities

In [4]:
import random
import numpy as np
import pandas as pd
from shared.data_loader import load_facility_patients
from shared.ground_truth import (
    load_ground_truth,
    add_record_ids_to_ground_truth,
    generate_true_pairs_from_ground_truth,
)
from shared.medical_records import load_medical_records, get_patient_records

RUN_DIR = os.path.join(PROJECT_ROOT, "output", "augmented", "run_20260203_071928")

# Load patients and create record_ids
patients_df = load_facility_patients(RUN_DIR)
patients_df['record_id'] = patients_df['facility_id'] + '_' + patients_df['id'].astype(str)

# Load ground truth and add record_ids
ground_truth_df = load_ground_truth(RUN_DIR)
ground_truth_df = add_record_ids_to_ground_truth(ground_truth_df, patients_df)

# Generate true match pairs
true_pairs = generate_true_pairs_from_ground_truth(ground_truth_df)

# Build record_id -> (patient_uuid, facility_id) mapping
record_map = {}
for _, row in patients_df.iterrows():
    record_map[row['record_id']] = (row['id'], row['facility_id'])

# Generate non-match pairs (balanced with true pairs)
rid_to_true_id = (ground_truth_df.dropna(subset=['record_id'])
                  .set_index('record_id')['true_patient_id'].to_dict())
all_record_ids = list(rid_to_true_id.keys())

random.seed(42)
non_match_pairs = set()
target = len(true_pairs)
attempts = 0
while len(non_match_pairs) < target and attempts < target * 20:
    r1, r2 = random.sample(all_record_ids, 2)
    if rid_to_true_id.get(r1) != rid_to_true_id.get(r2):
        non_match_pairs.add(tuple(sorted([r1, r2])))
    attempts += 1

# Stats
multi_facility = ground_truth_df.groupby('true_patient_id')['facility_id'].nunique()
print(f"Total patient records: {len(patients_df)}")
print(f"Unique patients: {ground_truth_df['true_patient_id'].nunique()}")
print(f"Multi-facility patients: {(multi_facility >= 2).sum()}")
print(f"True match pairs: {len(true_pairs)}")
print(f"Non-match pairs: {len(non_match_pairs)}")

Total patient records: 1228
Unique patients: 571
Multi-facility patients: 354
True match pairs: 1121
Non-match pairs: 1121


## 5. Strategy D: Structured Diff-Friendly Summaries

**Clinical data only** — no demographics (name, DOB, address, SSN, gender). Structured for pairwise comparison with conditions grouped by year, medications with date ranges, key observations, and procedures with years. ~760 tokens per summary (Opus achieves F1=0.936 on these).

In [5]:
# Load all medical records
medical_records = load_medical_records(RUN_DIR)


def summarize_diff_friendly(patient_id, facility_id, medical_records):
    """Structured for pairwise comparison, grouped by year. Target ~800 tokens."""
    records = get_patient_records(patient_id, facility_id, medical_records)
    sections = []

    # CONDITIONS - grouped by onset year
    cond_df = records.get('conditions')
    if cond_df is not None:
        cond_df = cond_df.copy()
        cond_df['year'] = pd.to_datetime(cond_df['START'], errors='coerce').dt.year
        cond_df['is_ongoing'] = cond_df['STOP'].isna() | (cond_df['STOP'].astype(str).str.strip() == '')
        lines = ["CONDITIONS:"]
        for year, grp in sorted(cond_df.groupby('year')):
            if pd.isna(year):
                continue
            descs = []
            for _, row in grp.iterrows():
                status = " *" if row['is_ongoing'] else ""
                descs.append(f"{row['DESCRIPTION']}{status}")
            lines.append(f"  {int(year)}: {'; '.join(descs)}")
        sections.append("\n".join(lines))

    # MEDICATIONS - drug (start_year-end_year or ongoing)
    meds_df = records.get('medications')
    if meds_df is not None:
        meds_df = meds_df.copy()
        lines = ["MEDICATIONS:"]
        for desc, grp in meds_df.groupby('DESCRIPTION', sort=False):
            start_dt = pd.to_datetime(grp['START'], errors='coerce').min()
            is_current = grp['STOP'].isna().any() | (grp['STOP'].astype(str).str.strip() == '').any()
            if is_current:
                period = f"{start_dt.year}\u2013ongoing" if pd.notna(start_dt) else "ongoing"
            else:
                end_dt = pd.to_datetime(grp['STOP'], errors='coerce').max()
                if pd.notna(start_dt) and pd.notna(end_dt):
                    period = f"{start_dt.year}\u2013{end_dt.year}"
                else:
                    period = "unknown"
            lines.append(f"- {desc} ({period})")
        sections.append("\n".join(lines))

    # ALLERGIES - flat list
    allg_df = records.get('allergies')
    if allg_df is not None:
        names = sorted(allg_df['DESCRIPTION'].unique())
        sections.append("ALLERGIES: " + "; ".join(names))

    # KEY OBSERVATIONS - latest 2 values per metric
    obs_df = records.get('observations')
    if obs_df is not None:
        obs_df = obs_df.copy()
        obs_df['date_dt'] = pd.to_datetime(obs_df['DATE'], errors='coerce')
        key_obs = [
            'Body Height', 'Body Weight', 'Body Mass Index',
            'Systolic Blood Pressure', 'Diastolic Blood Pressure',
            'Hemoglobin A1c/Hemoglobin.total in Blood',
            'Glucose', 'Total Cholesterol',
        ]
        lines = ["OBSERVATIONS:"]
        for obs_name in key_obs:
            match = obs_df[obs_df['DESCRIPTION'].str.contains(obs_name, case=False, na=False)]
            if match.empty:
                continue
            recent = match.sort_values('date_dt').tail(2)
            vals = []
            for _, row in recent.iterrows():
                v = row.get('VALUE', '')
                u = row.get('UNITS', '')
                d = str(row.get('DATE', ''))[:10]
                if pd.notna(v) and str(v).strip():
                    u_str = f" {u}" if pd.notna(u) and u else ""
                    vals.append(f"{v}{u_str} ({d})")
            if vals:
                lines.append(f"- {obs_name}: {', '.join(vals)}")
        sections.append("\n".join(lines))

    # PROCEDURES - with years, chronological
    proc_df = records.get('procedures')
    if proc_df is not None:
        proc_df = proc_df.copy()
        proc_df['year'] = pd.to_datetime(proc_df['START'], errors='coerce').dt.year
        lines = ["PROCEDURES:"]
        for desc, grp in proc_df.groupby('DESCRIPTION', sort=False):
            years = sorted(grp['year'].dropna().unique())
            year_strs = [str(int(y)) for y in years]
            lines.append(f"- {desc} ({', '.join(year_strs)})")
        sections.append("\n".join(lines))

    return "\n\n".join(sections) if sections else "No clinical records available."


# Build summary cache for all needed records
all_needed = set()
for r1, r2 in true_pairs | non_match_pairs:
    all_needed.add(r1)
    all_needed.add(r2)

print(f"Generating Strategy D summaries for {len(all_needed)} unique records...")
summary_cache = {}
for rid in all_needed:
    if rid in record_map:
        pid, fid = record_map[rid]
        summary_cache[rid] = summarize_diff_friendly(pid, fid, medical_records)

# Token length stats
token_lengths = [len(tokenizer.encode(s)) for s in summary_cache.values()]
pair_lengths = []
for r1, r2 in list(true_pairs)[:200] + list(non_match_pairs)[:200]:
    if r1 in summary_cache and r2 in summary_cache:
        combined = summary_cache[r1] + "\n\n" + summary_cache[r2]
        pair_lengths.append(len(tokenizer.encode(combined)))

print(f"\nSingle summary token lengths:")
print(f"  Mean: {np.mean(token_lengths):.0f}, Median: {np.median(token_lengths):.0f}")
print(f"  Min: {min(token_lengths)}, Max: {max(token_lengths)}")
print(f"  95th pctl: {np.percentile(token_lengths, 95):.0f}")
print(f"\nPair token lengths (summaries only, sample of {len(pair_lengths)}):")
print(f"  Mean: {np.mean(pair_lengths):.0f}, Max: {max(pair_lengths)}")

Generating Strategy D summaries for 1190 unique records...

Single summary token lengths:
  Mean: 830, Median: 692
  Min: 6, Max: 5247
  95th pctl: 1978

Pair token lengths (summaries only, sample of 400):
  Mean: 1513, Max: 4673


### Example Pairs: Match vs Non-Match (Strategy D)

What does the model actually see? Here are a few true matches (same patient at different facilities) and non-matches (different patients) side by side.

In [6]:
# Show 2 true match pairs and 2 non-match pairs
sample_true = list(true_pairs)[:2]
sample_false = list(non_match_pairs)[:2]

for label_name, pairs in [("TRUE MATCH", sample_true), ("NON-MATCH", sample_false)]:
    for r1, r2 in pairs:
        if r1 not in summary_cache or r2 not in summary_cache:
            continue
        fac1 = r1.split('_')[0] + '_' + r1.split('_')[1]
        fac2 = r2.split('_')[0] + '_' + r2.split('_')[1]
        print(f"{'='*70}")
        print(f"  {label_name}  |  {fac1} vs {fac2}")
        print(f"{'='*70}")
        print(f"\n--- Record A ({fac1}) ---")
        print(summary_cache[r1])
        print(f"\n--- Record B ({fac2}) ---")
        print(summary_cache[r2])
        print()

  TRUE MATCH  |  facility_002 vs facility_003

--- Record A (facility_002) ---
CONDITIONS:
  2010: Medication review due (situation); Stress (finding)
  2015: Abnormal findings diagnostic imaging heart+coronary circulat (finding) *
  2018: Part-time employment (finding); Limited social contact (finding); Victim of intimate partner abuse (finding)
  2019: Acute viral pharyngitis (disorder)
  2022: Medication review due (situation); Full-time employment (finding)
  2024: Severe anxiety (panic) (finding)

MEDICATIONS:
- Hydrochlorothiazide 25 MG Oral Tablet (2014–2025)
- lisinopril 10 MG Oral Tablet (2014–2025)
- amLODIPine 2.5 MG Oral Tablet (2018–2025)

OBSERVATIONS:
- Body Height: 171.6 cm (2022-03-24), 171.6 cm (2024-04-04)
- Body Weight: 80.4 kg (2022-03-24), 80.4 kg (2024-04-04)
- Body Mass Index: 27.3 kg/m2 (2022-03-24), 27.3 kg/m2 (2024-04-04)
- Systolic Blood Pressure: 99.0 mm[Hg] (2022-03-24), 96.0 mm[Hg] (2024-04-04)
- Diastolic Blood Pressure: 66.0 mm[Hg] (2022-03-24), 72.0 mm

## 6. Load Test Set from Hub

Load the test split from the HF Hub dataset (built by `prepare_dataset.py`). This guarantees no overlap with the training data used on GPU.

In [7]:
from datasets import load_dataset

DATASET_REPO = "abicyclerider/entity-resolution-pairs"

print(f"Loading dataset from {DATASET_REPO}...")
dataset = load_dataset(DATASET_REPO)

# Extract test prompts and labels
test_prompts = []
test_labels = []
for example in dataset["test"]:
    test_prompts.append(example["messages"][0]["content"])
    test_labels.append(example["messages"][1]["content"] == "True")

n_match = sum(test_labels)
n_non = len(test_labels) - n_match
print(f"\nDataset splits:")
print(f"  Train: {len(dataset['train'])}")
print(f"  Eval:  {len(dataset['eval'])}")
print(f"  Test:  {len(dataset['test'])} ({n_match} match + {n_non} non-match)")

# Token length stats for test set
test_lengths = []
for example in dataset["test"]:
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    test_lengths.append(len(tokenizer.encode(text)))
print(f"\nTest sequence lengths:")
print(f"  Mean: {np.mean(test_lengths):.0f}, Max: {max(test_lengths)}")

Loading dataset from abicyclerider/entity-resolution-pairs...


README.md:   0%|          | 0.00/543 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/1.50M [00:00<?, ?B/s]

data/eval-00000-of-00001.parquet:   0%|          | 0.00/324k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/326k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1568 [00:00<?, ? examples/s]

Generating eval split:   0%|          | 0/336 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/338 [00:00<?, ? examples/s]


Dataset splits:
  Train: 1568
  Eval:  336
  Test:  338 (169 match + 169 non-match)

Test sequence lengths:
  Mean: 1566, Max: 4555


## 7. Benchmark Base Model (Before Fine-Tuning)

Run the base Gemma 1B model on the Hub test set with greedy decoding. This gives us a baseline to compare against after fine-tuning.

In [8]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix


def predict_match(model, tokenizer, prompt):
    """Predict whether two medical records match from a formatted prompt."""
    messages = [{"role": "user", "content": prompt}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=4096
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=8, do_sample=False)

    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    ).strip().lower()

    if "true" in response:
        return True
    elif "false" in response:
        return False
    return None


def evaluate_predictions(labels, preds):
    """Compute classification metrics."""
    return {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, zero_division=0),
        'recall': recall_score(labels, preds, zero_division=0),
        'f1': f1_score(labels, preds, zero_division=0),
    }


def run_evaluation(model, tokenizer, test_prompts, test_labels, label="Model"):
    """Run model on test prompts and return metrics + predictions."""
    preds, labels, indices = [], [], []
    for i, (prompt, true_label) in enumerate(zip(test_prompts, test_labels)):
        pred = predict_match(model, tokenizer, prompt)
        if pred is not None:
            preds.append(pred)
            labels.append(true_label)
            indices.append(i)
        if (i + 1) % 50 == 0:
            print(f"  {i+1}/{len(test_prompts)}...")

    metrics = evaluate_predictions(labels, preds)
    print(f"\n{label} ({len(preds)} parseable / {len(test_prompts)} total):")
    for m, v in metrics.items():
        print(f"  {m:>10s}: {v:.3f}")
    print(f"\nConfusion matrix (rows=actual, cols=predicted):")
    print(confusion_matrix(labels, preds))
    return metrics, preds, labels


# Load base model
print(f"Loading {MODEL_ID}...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, dtype=torch.float32, device_map="mps",
)

print(f"\nBenchmarking base model on {len(test_prompts)} test pairs...")
base_metrics, base_preds, base_labels = run_evaluation(
    model, tokenizer, test_prompts, test_labels, "Base Model"
)

# Free memory before loading adapter
del model
torch.mps.empty_cache()

Loading google/gemma-3-1b-it...


Loading weights:   0%|          | 0/340 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



Benchmarking base model on 338 test pairs...
  50/338...
  100/338...
  150/338...
  200/338...
  250/338...
  300/338...

Base Model (336 parseable / 338 total):
    accuracy: 0.527
   precision: 0.523
      recall: 0.675
          f1: 0.589

Confusion matrix (rows=actual, cols=predicted):
[[ 63 104]
 [ 55 114]]


## 8. Evaluate GPU-Trained Adapter

Download the LoRA adapter trained on RunPod GPU (bfloat16, batch=1, grad_accum=16, 2048 max_length, full dataset) and evaluate on the test set.

In [9]:
from peft import PeftModel

GPU_ADAPTER_REPO = "abicyclerider/gemma-1b-entity-resolution-lora"

print(f"Loading {MODEL_ID}...")
gpu_base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, dtype=torch.float32, device_map="mps",
)

print(f"Loading GPU-trained adapter from {GPU_ADAPTER_REPO}...")
gpu_model = PeftModel.from_pretrained(gpu_base_model, GPU_ADAPTER_REPO)

print(f"\nEvaluating GPU-trained adapter on {len(test_prompts)} test pairs...")
gpu_metrics, gpu_preds, gpu_labels = run_evaluation(
    gpu_model, tokenizer, test_prompts, test_labels, "GPU Fine-Tuned"
)

# Side-by-side comparison
print(f"\n{'Metric':>10s}  {'Base':>8s}  {'GPU FT':>10s}  {'Delta':>8s}")
print("-" * 44)
for m in ['accuracy', 'precision', 'recall', 'f1']:
    b = base_metrics[m]
    gpu = gpu_metrics[m]
    print(f"{m:>10s}  {b:>8.3f}  {gpu:>10.3f}  {gpu-b:>+8.3f}")

# Cleanup
del gpu_base_model, gpu_model
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

Loading google/gemma-3-1b-it...


Loading weights:   0%|          | 0/340 [00:00<?, ?it/s]

Loading GPU-trained adapter from abicyclerider/gemma-1b-entity-resolution-lora...

Evaluating GPU-trained adapter on 338 test pairs...
  50/338...
  100/338...
  150/338...
  200/338...
  250/338...
  300/338...

GPU Fine-Tuned (336 parseable / 338 total):
    accuracy: 0.571
   precision: 0.544
      recall: 0.917
          f1: 0.683

Confusion matrix (rows=actual, cols=predicted):
[[ 37 130]
 [ 14 155]]

    Metric      Base      GPU FT     Delta
--------------------------------------------
  accuracy     0.527       0.571    +0.045
 precision     0.523       0.544    +0.021
    recall     0.675       0.917    +0.243
        f1     0.589       0.683    +0.094


## 9. Next Steps

### Optimize training
- **More data**: Use all available pairs instead of capping at 500+500
- **Higher LoRA rank**: Try `r=16` or `r=32` for more capacity
- **Sequence length**: Current 2048 may truncate some pairs — monitor truncation rate

### Larger models
- **QLoRA on GPU**: Fine-tune MedGemma 4B or 27B with 4-bit quantization on A100/H100
- **MedGemma**: Same architecture as Gemma — same LoRA approach works directly

### Integration
- Use the fine-tuned model as a **gray zone classifier** in `llm-entity-resolution/src/classify.py`
- Replace MedGemma 4B Ollama calls with this faster, local fine-tuned model
- Pipeline: demographic blocking → similarity scoring → fine-tuned LLM for ambiguous pairs