# Entity Resolution Evaluation

This notebook evaluates the patient matching pipeline against ground truth data.

## Overview

The MedMatch AI system uses a 4-stage pipeline:
1. **Blocking** - Reduces O(n²) comparisons using phonetic and key-based blocking
2. **Rules** - Deterministic matching for clear cases (exact matches, MRN matches)
3. **Scoring** - Weighted feature scoring for moderate confidence cases
4. **AI** - Medical fingerprinting for ambiguous cases using Gemini API
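The progressive routing can be pictured as a cascade: each stage either decides a pair or defers it to the next, more expensive stage. The sketch below illustrates that idea only and is not MedMatch's implementation. The record fields, the two rules, the point weights, and the 30/80 thresholds are all hypothetical. The AI stage is a placeholder that a real system would replace with a model call.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Decision:
    is_match: bool
    confidence: float
    stage: str


# Hypothetical point weights (summing to 100) for fields that agree exactly
WEIGHTS = {"name_first": 20, "name_last": 40, "dob": 30, "gender": 10}


def rules_stage(a: dict, b: dict) -> Optional[Decision]:
    """Hypothetical deterministic rules that only decide clear-cut pairs."""
    if a["mrn"] == b["mrn"] and a["dob"] == b["dob"]:
        return Decision(True, 0.99, "rules")
    if a["dob"] != b["dob"] and a["name_last"] != b["name_last"]:
        return Decision(False, 0.99, "rules")
    return None  # not clear-cut: defer to scoring


def scoring_stage(a: dict, b: dict, low: int = 30, high: int = 80) -> Optional[Decision]:
    """Weighted agreement score; pairs strictly between the thresholds are deferred."""
    score = sum(w for field, w in WEIGHTS.items() if a[field] == b[field])
    if score >= high:
        return Decision(True, score / 100, "scoring")
    if score <= low:
        return Decision(False, 1 - score / 100, "scoring")
    return None  # ambiguous band: defer to AI


def ai_stage(a: dict, b: dict) -> Decision:
    """Placeholder: a real AI stage would compare medical histories with a model."""
    return Decision(False, 0.5, "ai")


def match_pair(a: dict, b: dict) -> Decision:
    # Cheap stages run first; the expensive AI stage only sees what they defer
    for stage in (rules_stage, scoring_stage):
        decision = stage(a, b)
        if decision is not None:
            return decision
    return ai_stage(a, b)


john = {"mrn": "M1", "name_first": "John", "name_last": "Smith", "dob": "1980-01-02", "gender": "M"}
johnny = {**john, "mrn": "M2", "name_first": "Johnny"}
jon_smyth = {**john, "mrn": "M3", "name_first": "Jon", "name_last": "Smyth"}

print(match_pair(john, johnny).stage)     # scoring (40 + 30 + 10 = 80 points)
print(match_pair(john, jon_smyth).stage)  # ai (30 + 10 = 40 points, inside the 30-80 band)
```

Integer points keep the threshold comparisons exact, so a pair scoring exactly 80 is never misrouted by floating-point rounding.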

## Targets

- **Easy cases**: ≥95% accuracy
- **Medium cases**: ≥85% accuracy
- **Hard/ambiguous cases**: ≥70% accuracy
- **Overall**: ≥85% accuracy

# Entity Resolution Evaluation - Phase 2 Complete

**Status:** ✅ **Phase 2 Complete** (January 2026)

**Overall Results:**

- **Accuracy:** 94.51% (437 test pairs)
- **All Accuracy Targets Exceeded:** ✅

**Accuracy by Difficulty:**

| Difficulty | Target | Achieved | Status |
|------------|--------|----------|--------|
| Easy       | 95%    | 100.00%  | ✅ PASS |
| Medium     | 85%    | 100.00%  | ✅ PASS |
| Hard       | 70%    | 88.24%   | ✅ PASS |
| Ambiguous  | 70%    | 80.54%   | ✅ PASS |

**Pipeline Configuration:**

This notebook evaluates the complete 4-stage progressive pipeline:

1. **Blocking** - Reduces 33,930 possible pairs → 437 candidates (98.7% reduction, 97.3% recall)
2. **Deterministic Rules** - Handles 74% of decisions (92.6% accuracy)
3. **Feature Scoring** - Handles 0% of decisions when AI enabled, 26% otherwise (87.6% accuracy)
4. **AI Medical Fingerprinting** - Handles 26% of decisions (100% accuracy)
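A little arithmetic can cross-check these figures. The record count of 261 is an assumption inferred from 33,930 being exactly C(261, 2). The notebook does not state it directly. The stage shares are rounded, so the weighted accuracy only approximately reproduces the reported 94.51%.

```python
from math import comb

# Assumption: the 33,930 possible pairs are all unordered pairs of 261 records
n_records = 261
all_pairs = comb(n_records, 2)
candidates = 437
reduction = 1 - candidates / all_pairs

# Reported (rounded) decision shares and per-stage accuracies with AI enabled
rules_share, rules_acc = 0.74, 0.926
ai_share, ai_acc = 0.26, 1.00
weighted_accuracy = rules_share * rules_acc + ai_share * ai_acc

print(all_pairs)                   # 33930
print(f"{reduction:.1%}")          # 98.7%
print(f"{weighted_accuracy:.2%}")  # 94.52%
```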

**Key Capabilities Demonstrated:**

- Medical abbreviation understanding (T2DM = Type 2 Diabetes, HTN = Hypertension)
- Name variation handling (John vs Johnny, typos, format differences)
- Medical history comparison using Gemini API
- Explainable decisions with confidence scores
- Progressive pipeline routing based on difficulty
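The first two capabilities rest on normalization, so that "T2DM" and "Type 2 Diabetes" (or "Johnny" and "John") compare as equal. The sketch below assumes simple lookup tables. The tables and helper names are hypothetical. In this pipeline the medical-history comparison is performed by the AI stage, so the sketch illustrates the idea and does not reproduce that stage.

```python
# Hypothetical lookup tables; a real system would need much larger curated lists
ABBREVIATIONS = {"t2dm": "type 2 diabetes", "htn": "hypertension"}
NICKNAMES = {"johnny": "john", "jon": "john", "bob": "robert", "bill": "william"}


def normalize_condition(text: str) -> str:
    key = text.strip().lower()
    return ABBREVIATIONS.get(key, key)  # expand known abbreviations, else keep as-is


def canonical_first_name(name: str) -> str:
    key = name.strip().lower()
    return NICKNAMES.get(key, key)  # map nicknames to a canonical form


def condition_overlap(a: list[str], b: list[str]) -> float:
    """Jaccard similarity of two condition lists after normalization."""
    set_a = {normalize_condition(c) for c in a}
    set_b = {normalize_condition(c) for c in b}
    union = set_a | set_b
    return len(set_a & set_b) / len(union) if union else 0.0


print(canonical_first_name("Johnny") == canonical_first_name("John"))          # True
print(condition_overlap(["T2DM", "HTN"], ["Type 2 Diabetes", "Hypertension"]))  # 1.0
```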

**Configuration Notes:**

- AI enabled with `api_rate_limit=0` (billing enabled, no rate limiting)
- Uses `gemini-2.5-flash` model
- Medical records loaded for AI comparison
- All 10 integration tests passing

**Documentation:**

- [Matching Module README](../src/medmatch/matching/README.md) - Complete architecture and usage
- [Scripts README](../scripts/README.md) - CLI wrapper for batch processing
- [Quick Start Guide](../docs/quickstart.md) - 5-minute getting started guide

---

In [None]:
# Setup
import sys
import os
from pathlib import Path

# Enable auto-reload of modules (detects code changes)
%load_ext autoreload
%autoreload 2

# Add src to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import date
from tqdm import tqdm

# MedMatch imports
from medmatch.matching import PatientRecord, PatientMatcher, MatchExplainer
from medmatch.evaluation import MatchEvaluator, EvaluationMetrics
from medmatch.data.models.patient import Demographics, Address

# Configure plotting
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')
%matplotlib inline

print("Setup complete! (Auto-reload enabled)")

## 1. Load Data

Load the synthetic demographics dataset and the ground truth. Then convert each demographic record, with its medical history when available, to a `PatientRecord` object.

In [None]:
# Load demographics
data_dir = project_root / 'data' / 'synthetic'
demographics_path = data_dir / 'synthetic_demographics.csv'
ground_truth_path = data_dir / 'ground_truth.csv'

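# Read ID columns as strings: missing values would otherwise turn them into floats (e.g. '123.0') and drop leading zeros
df_demo = pd.read_csv(demographics_path, dtype={'mrn': str, 'ssn_last4': str})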
df_gt = pd.read_csv(ground_truth_path)

print(f"Loaded {len(df_demo)} demographic records")
print(f"Loaded {len(df_gt)} ground truth entries")
print(f"\nUnique patients: {df_gt['patient_id'].nunique()}")
df_demo.head()
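Two pandas behaviors matter when converting CSV rows to records. First, an ID-like column with missing values is inferred as float, so `str()` yields `'123.0'` and leading zeros are lost. Second, `Series.get(key, default)` falls back to the default only when the label is absent. It does not fall back when the cell holds NaN. The two-row CSV below is hypothetical and only demonstrates both behaviors.

```python
import io

import pandas as pd

# Hypothetical two-row CSV using column names from this notebook
csv_text = "record_id,ssn_last4,mrn\nR1,0123,00042\nR2,,00043\n"

inferred = pd.read_csv(io.StringIO(csv_text))
as_text = pd.read_csv(io.StringIO(csv_text), dtype={"mrn": str, "ssn_last4": str})

# The missing value forces float64, so str() adds ".0" and the leading zero is gone
print(str(inferred.loc[0, "ssn_last4"]))  # 123.0
print(str(inferred.loc[0, "mrn"]))        # 42
print(as_text.loc[0, "ssn_last4"], as_text.loc[0, "mrn"])  # 0123 00042

# Series.get uses the default only for a missing label, not for a NaN cell
row = inferred.iloc[1]
print(row.get("ssn_last4", "fallback"))   # nan
```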

In [None]:
# Ground truth difficulty distribution
print("Difficulty distribution:")
print(df_gt['difficulty'].value_counts())

In [None]:
import json

def load_patient_records(df: pd.DataFrame, medical_records_path: Path = None) -> list:
    """Convert demographics DataFrame to PatientRecord list, with optional medical records."""
    records = []
    
    # Load medical records if path provided
    medical_by_patient = {}
    if medical_records_path and medical_records_path.exists():
        with open(medical_records_path, 'r') as f:
            medical_data = json.load(f)
        # Index by patient_id (each patient may have multiple medical records, use first)
        for mr in medical_data:
            patient_id = mr['patient_id']
            if patient_id not in medical_by_patient:
                medical_by_patient[patient_id] = mr
        print(f"Loaded {len(medical_by_patient)} medical records")
    
    for _, row in df.iterrows():
        # Parse date of birth
        dob_str = row['date_of_birth']
        if isinstance(dob_str, str):
            dob = date.fromisoformat(dob_str)
        else:
            dob = dob_str
        
        # Parse record date
        rec_date_str = row.get('record_date')
        if pd.isna(rec_date_str):
            rec_date = date.today()
        elif isinstance(rec_date_str, str):
            rec_date = date.fromisoformat(rec_date_str.split('T')[0])
        else:
            rec_date = rec_date_str
        
        # Create Address if available
        address = None
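        if pd.notna(row.get('address_street')):
            # Series.get returns its default only for missing labels, not NaN cells, so check each field
            address = Address(
                street=row['address_street'],
                city=row['address_city'] if pd.notna(row.get('address_city')) else '',
                state=row['address_state'] if pd.notna(row.get('address_state')) else '',
                zip_code=str(row['address_zip']) if pd.notna(row.get('address_zip')) else '',
            )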
        
        # Create Demographics object
        demo = Demographics(
            record_id=row['record_id'],
            patient_id=row['patient_id'],  # For validation only - not used in matching
            name_first=row['name_first'],
            name_middle=row.get('name_middle') if pd.notna(row.get('name_middle')) else None,
            name_last=row['name_last'],
            name_suffix=row.get('name_suffix') if pd.notna(row.get('name_suffix')) else None,
            date_of_birth=dob,
            gender=row['gender'],
            mrn=str(row['mrn']),  # Convert to string for Pydantic validation
            ssn_last4=str(row['ssn_last4']) if pd.notna(row.get('ssn_last4')) else None,  # Convert to string
            phone=row.get('phone') if pd.notna(row.get('phone')) else None,
            email=row.get('email') if pd.notna(row.get('email')) else None,
            address=address,
            record_source=row.get('record_source', 'unknown'),
            record_date=rec_date,
            data_quality_flag=row.get('data_quality_flag') if pd.notna(row.get('data_quality_flag')) else None,
        )
        
        # Get medical record for this patient (if available)
        medical = None
        patient_id = row['patient_id']
        if patient_id in medical_by_patient:
            mr_data = medical_by_patient[patient_id]
            # Import MedicalRecord model
            from medmatch.data.models.patient import MedicalRecord, MedicalHistory, MedicalCondition, Surgery
            
            # Parse medical history
            mh_data = mr_data.get('medical_history', {})
            conditions = [
                MedicalCondition(
                    name=c['name'],
                    abbreviation=c.get('abbreviation'),
                    onset_year=c.get('onset_year'),
                    status=c.get('status', 'active')
                ) for c in mh_data.get('conditions', [])
            ]
            surgeries = [
                Surgery(
                    procedure=s['procedure'],
                    date=date.fromisoformat(s['date']) if s.get('date') else None
                ) for s in mh_data.get('surgeries', [])
            ]
            medical_history = MedicalHistory(
                conditions=conditions,
                medications=mh_data.get('medications', []),
                allergies=mh_data.get('allergies', []),
                surgeries=surgeries,
                family_history=mh_data.get('family_history', []),
                social_history=mh_data.get('social_history', '')
            )
            
            # Create MedicalRecord
            medical = MedicalRecord(
                record_id=mr_data['record_id'],
                patient_id=mr_data['patient_id'],
                record_source=mr_data.get('record_source', 'unknown'),
                record_date=date.fromisoformat(mr_data['record_date'].split('T')[0]),
                chief_complaint=mr_data.get('chief_complaint'),
                medical_history=medical_history,
                assessment=mr_data.get('assessment'),
                plan=mr_data.get('plan'),
                clinical_notes=mr_data.get('clinical_notes'),
            )
        
        records.append(PatientRecord.from_demographics(demo, medical))
    
    return records

# Load all records WITH medical history
medical_records_path = data_dir / 'synthetic_medical_records.json'
records = load_patient_records(df_demo, medical_records_path)
print(f"Converted {len(records)} PatientRecord objects")

# Check how many have medical history
with_medical = sum(1 for r in records if r.conditions or r.medications)
print(f"Records with medical history: {with_medical}/{len(records)}")

## 2. Initialize Matcher

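Set up the `PatientMatcher` with blocking, rules, and scoring enabled. AI is disabled here so the run is fast and makes no model calls. The optional section at the end of this notebook shows how to enable it.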

In [None]:
# Initialize matcher WITHOUT AI (fast, no API calls)
# For production, use ai_backend="ollama" when use_ai=True
matcher = PatientMatcher(
    use_blocking=True,
    use_rules=True,
    use_scoring=True,
    use_ai=False,  # Set to True to enable AI for ambiguous cases
)

# Initialize evaluator
evaluator = MatchEvaluator(str(ground_truth_path))

# Initialize explainer for generating human-readable explanations
explainer = MatchExplainer()

## 3. Run Matching

Execute the full matching pipeline on all records.

In [None]:
# Run matching
print("Running entity resolution pipeline...")
results = matcher.match_datasets(records, show_progress=True)

print(f"\nGenerated {len(results)} match results")

In [None]:
# Quick summary
stats = matcher.get_stats(results)
print("\nMatching Statistics:")
print(f"  Total pairs evaluated: {stats['total_pairs']}")
print(f"  Matches found: {stats['matches']}")
print(f"  No-matches: {stats['no_matches']}")
print(f"  Average confidence: {stats['avg_confidence']:.3f}")
print(f"\n  By stage: {stats['by_stage']}")
print(f"  By type: {stats['by_match_type']}")

## 4. Evaluate Results

Calculate precision, recall, F1, and accuracy metrics.

In [None]:
# Overall evaluation
overall_metrics = evaluator.evaluate(results)
print(overall_metrics)

In [None]:
# Generate full report
report = evaluator.generate_report(results, verbose=True)
print(report)
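`MatchEvaluator` reports metrics that follow the standard definitions over match/no-match pair decisions. The sketch below computes them from confusion counts. It assumes a zero-division convention of 0.0, since the evaluator's own convention is not shown here. The counts in the usage line are illustrative and are not results from this notebook. They show why accuracy alone can mislead: when true negatives dominate, accuracy stays high even as recall drops.

```python
def classification_metrics(tp: int, fp: int, tn: int, fn: int) -> dict[str, float]:
    """Standard pair-level metrics; returns 0.0 where a denominator is zero."""
    total = tp + fp + tn + fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / total if total else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}


# Illustrative counts only: accuracy stays high because true negatives dominate
m = classification_metrics(tp=8, fp=2, tn=85, fn=5)
print({k: round(v, 3) for k, v in m.items()})
```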

## 5. Metrics by Difficulty

Compare performance across difficulty levels.

In [None]:
# Evaluate by difficulty
by_difficulty = evaluator.evaluate_by_difficulty(results)

# Create comparison DataFrame
difficulty_data = []
targets = {'easy': 0.95, 'medium': 0.85, 'hard': 0.70, 'ambiguous': 0.70}

for diff in ['easy', 'medium', 'hard', 'ambiguous']:
    if diff in by_difficulty:
        m = by_difficulty[diff]
        difficulty_data.append({
            'Difficulty': diff.capitalize(),
            'Pairs': m.total_pairs,
            'Accuracy': m.accuracy,
            'Target': targets[diff],
            'Precision': m.precision,
            'Recall': m.recall,
            'F1': m.f1_score,
            'Status': '✓ PASS' if m.accuracy >= targets[diff] else '✗ FAIL'
        })

df_metrics = pd.DataFrame(difficulty_data)
df_metrics
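Per-difficulty accuracy is a group-by over pair outcomes. The pandas sketch below uses a hypothetical five-pair frame and the same targets as the cell above. A difficulty with no pairs (here `medium`) does not appear in the result at all. That is why the cell above checks `if diff in by_difficulty`.

```python
import pandas as pd

# Hypothetical pair outcomes; the real ones come from the evaluator's ground truth
outcomes = pd.DataFrame({
    "difficulty": ["easy", "easy", "hard", "hard", "ambiguous"],
    "actual":     [True, False, True, True, False],
    "predicted":  [True, False, False, True, True],
})
targets_by_difficulty = {"easy": 0.95, "medium": 0.85, "hard": 0.70, "ambiguous": 0.70}

# A pair is correct when the prediction equals the ground-truth label
outcomes["correct"] = outcomes["actual"] == outcomes["predicted"]
summary = outcomes.groupby("difficulty")["correct"].agg(pairs="size", accuracy="mean")
summary["target"] = summary.index.map(targets_by_difficulty)
summary["passed"] = summary["accuracy"] >= summary["target"]
print(summary)
```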

In [None]:
# Visualize accuracy by difficulty
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(df_metrics))
width = 0.35

bars1 = ax.bar(x - width/2, df_metrics['Accuracy'], width, label='Actual', color='steelblue')
bars2 = ax.bar(x + width/2, df_metrics['Target'], width, label='Target', color='coral', alpha=0.7)

ax.set_ylabel('Accuracy')
ax.set_xlabel('Difficulty Level')
ax.set_title('Entity Resolution Accuracy by Difficulty')
ax.set_xticks(x)
ax.set_xticklabels(df_metrics['Difficulty'])
ax.legend()
ax.set_ylim(0, 1.1)

# Add value labels
for bar in bars1:
    height = bar.get_height()
    ax.annotate(f'{height:.1%}',
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

## 6. Confusion Matrix

Visualize true/false positive/negative distribution.

In [None]:
# Build confusion matrix
cm = np.array([
    [overall_metrics.true_negatives, overall_metrics.false_positives],
    [overall_metrics.false_negatives, overall_metrics.true_positives]
])

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Predicted: No Match', 'Predicted: Match'],
            yticklabels=['Actual: No Match', 'Actual: Match'],
            ax=ax)
ax.set_title('Confusion Matrix')
plt.tight_layout()
plt.show()

print(f"True Positives:  {overall_metrics.true_positives}")
print(f"True Negatives:  {overall_metrics.true_negatives}")
print(f"False Positives: {overall_metrics.false_positives}")
print(f"False Negatives: {overall_metrics.false_negatives}")
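The matrix above puts actual labels in rows and predicted labels in columns, so its layout is `[[TN, FP], [FN, TP]]`. scikit-learn's `confusion_matrix` produces the same layout for labels ordered (no match, match). The dependency-free sketch below builds the matrix from paired boolean labels. The labels are illustrative.

```python
def confusion_matrix(actual: list[bool], predicted: list[bool]) -> list[list[int]]:
    """Rows = actual (no match, match); columns = predicted (no match, match)."""
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    matrix = [[0, 0], [0, 0]]
    for a, p in zip(actual, predicted):
        matrix[int(a)][int(p)] += 1  # False -> index 0, True -> index 1
    return matrix


actual = [True, True, False, False, True]
predicted = [True, False, False, True, True]
(tn, fp), (fn, tp) = confusion_matrix(actual, predicted)
print(tn, fp, fn, tp)  # 1 1 1 2
```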

## 7. Error Analysis

Examine false positives and false negatives to understand failure modes.

In [None]:
# Find all errors
errors = evaluator.find_errors(results)
print(f"Total errors: {len(errors)}")

fp_errors = [e for e in errors if e.error_type == 'false_positive']
fn_errors = [e for e in errors if e.error_type == 'false_negative']

print(f"False Positives: {len(fp_errors)}")
print(f"False Negatives: {len(fn_errors)}")
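Counting errors by pipeline stage and difficulty quickly shows where the failure modes concentrate. The sketch below runs over hypothetical `(error_type, stage, difficulty)` tuples. With the notebook's own objects, the equivalent would be `Counter((e.error_type, e.stage) for e in errors)`, which uses the same attributes the cells above print.

```python
from collections import Counter

# Hypothetical (error_type, stage, difficulty) tuples for illustration
error_sample = [
    ("false_negative", "rules", "hard"),
    ("false_negative", "rules", "ambiguous"),
    ("false_positive", "scoring", "ambiguous"),
    ("false_negative", "scoring", "ambiguous"),
]

errors_by_type_and_stage = Counter((etype, stage) for etype, stage, _ in error_sample)
errors_by_difficulty = Counter(difficulty for _, _, difficulty in error_sample)

print(errors_by_type_and_stage.most_common(1))  # [(('false_negative', 'rules'), 2)]
print(errors_by_difficulty.most_common())       # [('ambiguous', 3), ('hard', 1)]
```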

In [None]:
# Analyze false positives (if any)
if fp_errors:
    print("=" * 60)
    print("FALSE POSITIVES (predicted match but actually different patients)")
    print("=" * 60)
    
    for i, err in enumerate(fp_errors[:5], 1):
        print(f"\n[{i}] {err.record_1_id} ↔ {err.record_2_id}")
        print(f"    Stage: {err.stage}")
        print(f"    Confidence: {err.confidence:.2f}")
        print(f"    Difficulty: {err.difficulty}")
        if err.explanation:
            print(f"    Explanation: {err.explanation}")
else:
    print("No false positives!")

In [None]:
# Analyze false negatives (if any)
if fn_errors:
    print("=" * 60)
    print("FALSE NEGATIVES (predicted no-match but actually same patient)")
    print("=" * 60)
    
    for i, err in enumerate(fn_errors[:5], 1):
        print(f"\n[{i}] {err.record_1_id} ↔ {err.record_2_id}")
        print(f"    Stage: {err.stage}")
        print(f"    Confidence: {err.confidence:.2f}")
        print(f"    Difficulty: {err.difficulty}")
        if err.explanation:
            print(f"    Explanation: {err.explanation[:100]}...")
else:
    print("No false negatives!")

## 8. Example Matches

Show sample match results with full explanations.

In [None]:
# Show sample matches
matches = [r for r in results if r.is_match]
print(f"Total matches found: {len(matches)}")

if matches:
    print("\n" + "=" * 60)
    print("SAMPLE MATCH EXPLANATIONS")
    print("=" * 60)
    
    for result in matches[:3]:
        print("\n" + explainer.explain(result, verbose=True))
        print("-" * 40)

In [None]:
# Show sample non-matches
non_matches = [r for r in results if not r.is_match][:3]

if non_matches:
    print("\n" + "=" * 60)
    print("SAMPLE NON-MATCH EXPLANATIONS")
    print("=" * 60)
    
    for result in non_matches:
        print("\n" + explainer.explain(result))
        print("-" * 40)

## 9. Stage Distribution

Analyze which pipeline stages are making decisions.

In [None]:
# Evaluate by stage
by_stage = evaluator.evaluate_by_stage(results)

stage_data = []
for stage, m in sorted(by_stage.items()):
    stage_data.append({
        'Stage': stage.capitalize(),
        'Pairs': m.total_pairs,
        'Percentage': m.total_pairs / len(results) * 100,
        'Accuracy': m.accuracy,
        'Precision': m.precision,
        'Recall': m.recall,
    })

df_stages = pd.DataFrame(stage_data)
df_stages

In [None]:
# Visualize stage distribution
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Pie chart of decision stage distribution
ax1.pie(df_stages['Pairs'], labels=df_stages['Stage'], autopct='%1.1f%%',
        colors=sns.color_palette('husl', len(df_stages)))
ax1.set_title('Decisions by Pipeline Stage')

# Bar chart of accuracy by stage
ax2.bar(df_stages['Stage'], df_stages['Accuracy'], color='steelblue')
ax2.set_ylabel('Accuracy')
ax2.set_xlabel('Pipeline Stage')
ax2.set_title('Accuracy by Decision Stage')
ax2.set_ylim(0, 1.1)

for i, v in enumerate(df_stages['Accuracy']):
    ax2.text(i, v + 0.02, f'{v:.1%}', ha='center', fontsize=10)

plt.tight_layout()
plt.show()

## 10. Summary

Final evaluation summary and target achievement status.

In [None]:
# Final summary
print("=" * 70)
print("ENTITY RESOLUTION EVALUATION SUMMARY")
print("=" * 70)

print(f"\n{'Metric':<20} {'Value':>10}")
print("-" * 32)
print(f"{'Total Pairs':<20} {overall_metrics.total_pairs:>10}")
print(f"{'Accuracy':<20} {overall_metrics.accuracy:>10.2%}")
print(f"{'Precision':<20} {overall_metrics.precision:>10.2%}")
print(f"{'Recall':<20} {overall_metrics.recall:>10.2%}")
print(f"{'F1 Score':<20} {overall_metrics.f1_score:>10.2%}")

print("\n" + "=" * 70)
print("TARGET ACHIEVEMENT")
print("=" * 70)

all_passed = True
for diff in ['easy', 'medium', 'hard', 'ambiguous']:
    if diff in by_difficulty:
        m = by_difficulty[diff]
        target = targets[diff]
        passed = m.accuracy >= target
        status = '✓ PASS' if passed else '✗ FAIL'
        all_passed = all_passed and passed
        print(f"{diff.capitalize():<12} {m.accuracy:>6.1%} (target: {target:.0%}) [{status}]")

print("\n" + "=" * 70)
if all_passed:
    print("üéâ ALL TARGETS MET - Phase 2.5 Evaluation Complete!")
else:
    print("⚠️  Some targets not met - review error analysis above")
print("=" * 70)

In [None]:
# Export summary statistics
summary_stats = evaluator.get_summary_stats(results)

# Save to JSON for programmatic access
import json
output_path = project_root / 'data' / 'synthetic' / 'evaluation_results.json'
with open(output_path, 'w') as f:
    json.dump(summary_stats, f, indent=2)
print(f"Summary stats saved to: {output_path}")

---

## Optional: Run with AI Enabled

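To enable AI medical fingerprinting for ambiguous cases, run the cell below.
**Note:** The default `ollama` backend requires a running local Ollama server with the `medgemma:1.5-4b-q4` model available. Switching to `ai_backend="gemini"` instead requires a valid `GOOGLE_AI_API_KEY` in your `.env` file.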

In [None]:
# Test AI is working - run this cell to verify
#
# Backend Options:
# - "ollama": Local MedGemma (HIPAA-compliant, recommended for production)
# - "gemini": Cloud API (development/testing only, requires GOOGLE_AI_API_KEY)

matcher_with_ai = PatientMatcher(
    use_blocking=True,
    use_rules=True,
    use_scoring=True,
    use_ai=True,
    ai_backend="ollama",  # Use local MedGemma via Ollama
    # ai_backend="gemini",  # Uncomment to use Gemini API instead
    model="medgemma:1.5-4b-q4",  # Quantized model (2-4x faster!)
    api_rate_limit=0,  # No rate limiting
)

print("Running with AI enabled (backend: ollama, model: medgemma:1.5-4b-q4)...")
results_with_ai = matcher_with_ai.match_datasets(records, show_progress=True)
