# Complete NER Pipeline for Building a Patient–Disease Graph

-- *Google Colab notebook* --

Extract medical entities from clinical case reports using OpenMed NER models, deduplicate them, and export entity and patient–entity relationship CSV files for ingestion into a Neo4j graph database.

# Setup and Imports

In [None]:
!pip install transformers torch pandas numpy rapidfuzz tqdm

In [None]:
import pandas as pd
import numpy as np
from transformers import pipeline
import torch
from tqdm.auto import tqdm
import gc
import json
from collections import defaultdict
import hashlib
from rapidfuzz.fuzz import ratio
from rapidfuzz.process import cdist
import os

import warnings
warnings.filterwarnings('ignore')

# Report the installed PyTorch version and whether a CUDA GPU is visible.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

In [None]:
# Mount Google Drive when running in Colab. Outside Colab the google.colab
# module does not exist, so skip the mount instead of crashing on import;
# the Config paths under /content/drive must then exist locally.
try:
    from google.colab import drive
except ImportError:
    print("google.colab not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

# Configuration

In [None]:
class Config:
    """
    Central settings for the NER pipeline.

    Every setting is a class attribute, so values can be read from the class
    itself or from the module-level ``config`` instance created below.
    """

    # NER models to run, keyed by a short model type used in progress output.
    MODELS = {
        'disease': dict(
            model_name="OpenMed/OpenMed-NER-PathologyDetect-PubMed-109M",
            entity_label='DISEASE',
        ),
    }

    # Output location (a Google Drive path when running in Colab) and the
    # name of the filtered patient-cases CSV written there.
    OUTPUT_DIR = '/content/drive/MyDrive/'
    CASES_FILE = "cases.csv"

    # Input CSV of synthetic HPI cases, stored under OUTPUT_DIR.
    DATASET = OUTPUT_DIR + "synthetic/synthetic_hpi_cases_1k_combined.csv"

    # Entity filtering thresholds.
    MIN_CONFIDENCE = 0.75  # minimum model confidence score to keep an entity
    MIN_LENGTH = 3  # minimum character length of the normalized entity text

    # Device passed to the transformers pipeline: GPU 0 when CUDA is
    # available, otherwise -1 (CPU).
    DEVICE = -1 if not torch.cuda.is_available() else 0


config = Config()

# Utilities


In [None]:
def clean_entity_text(text):
    """
    Return a canonical lower-case form of an entity mention.

    Accepts a plain string; a dict (uses its ``word`` value, else its
    ``text`` value, else the dict's own ``str``); or any other object
    (converted with ``str``). Leading/trailing whitespace is removed and
    every internal whitespace run collapses to a single space.
    """
    if isinstance(text, str):
        candidate = text
    elif isinstance(text, dict):
        candidate = text.get("word") or text.get("text") or str(text)
    else:
        candidate = str(text)

    # split() with no argument splits on any whitespace run (including
    # newlines and tabs) and ignores leading/trailing whitespace.
    return " ".join(candidate.lower().split())


In [None]:
def fuzzy_match_entities(entities_list, similarity_threshold=85):
    """
    Deduplicate entities using RapidFuzz fuzzy matching.

    Mentions are grouped by ``entity_group``. Within each group, texts are
    normalized with ``clean_entity_text`` and greedily clustered: the first
    not-yet-assigned text seeds a cluster that absorbs every other
    not-yet-assigned text whose ``rapidfuzz.fuzz.ratio`` to the seed is at
    least ``similarity_threshold``. Each mention ends up in exactly one
    cluster.

    Args:
        entities_list: Entity dicts. Text is read from ``word`` (falling back
            to ``text``); ``score``, ``entity_group`` and ``patient_id`` are
            optional.
        similarity_threshold: Minimum ``ratio`` score (0-100) for a text to
            join a cluster.

    Returns:
        One dict per cluster, sorted by ``score`` descending, with keys:
        ``text`` (shortest normalized member text), ``canonical_text`` (most
        frequent stripped original text; ties go to the first occurrence),
        ``entity_group``, ``score`` (mean member score), ``merged_mentions``
        (number of mentions) and ``patient_count`` (distinct non-None
        ``patient_id`` values). Returns ``[]`` for empty input.
    """
    from collections import Counter

    if not entities_list:
        return []

    # Group entities by type and normalize text
    grouped = defaultdict(list)
    for e in entities_list:
        raw_text = e.get('word') or e.get('text') or ''
        grouped[e.get('entity_group', 'UNKNOWN')].append({
            'text': clean_entity_text(raw_text),
            'original': raw_text.strip(),
            'score': float(e.get('score', 0.0)),
            'patient_id': e.get('patient_id'),
        })

    results = []

    for entity_type, entities in grouped.items():
        # Score each distinct normalized text once instead of every mention
        # pair: identical texts always share a cluster, and a mention-level
        # similarity matrix grows quadratically with the number of mentions.
        mentions_by_text = defaultdict(list)
        for idx, ent in enumerate(entities):
            mentions_by_text[ent['text']].append(idx)
        texts = list(mentions_by_text)
        sim_matrix = cdist(texts, texts, scorer=ratio)

        assigned = set()
        for i, row in enumerate(sim_matrix):
            if i in assigned:
                continue

            # Only texts not yet claimed by an earlier cluster may join, so
            # no mention is counted twice (fuzzy similarity is not
            # transitive). The seed is always a member, so a cluster is never
            # empty even for an unreachable threshold.
            members = [i] + [
                j for j, sim in enumerate(row)
                if j != i and j not in assigned and sim >= similarity_threshold
            ]
            assigned.update(members)

            # Expand to the underlying mentions, preserving input order.
            cluster_idx = sorted(
                k for j in members for k in mentions_by_text[texts[j]]
            )
            cluster = [entities[k] for k in cluster_idx]

            scores = [c['score'] for c in cluster]
            patient_ids = {
                c['patient_id'] for c in cluster if c['patient_id'] is not None
            }
            # most_common orders equal counts by first occurrence, so the
            # canonical text (and any ID derived from it) is reproducible
            # across runs, unlike iterating a set of strings.
            canonical = Counter(c['original'] for c in cluster).most_common(1)[0][0]

            results.append({
                'text': min((c['text'] for c in cluster), key=len),
                'canonical_text': canonical,
                'entity_group': entity_type,
                'score': sum(scores) / len(scores),
                'merged_mentions': len(cluster),
                'patient_count': len(patient_ids),
            })

    return sorted(results, key=lambda x: x['score'], reverse=True)

In [None]:
def create_entity_id(entity_text: str, entity_type: str) -> str:
    """
    Build a deterministic identifier for an entity.

    The ID is the upper-cased first three characters of ``entity_type``, an
    underscore, and the first 8 hex digits of the MD5 digest of
    ``"<normalized text>_<entity_type>"``, where the text is normalized with
    ``clean_entity_text``. Equal (normalized text, type) pairs always yield
    the same ID; the 32-bit hash suffix is not guaranteed collision-free.
    """
    key = f"{clean_entity_text(entity_text)}_{entity_type}"
    digest = hashlib.md5(key.encode("utf-8")).hexdigest()
    prefix = entity_type[:3].upper()
    return prefix + "_" + digest[:8]


# Model Execution

In [None]:
def run_ner_model(cases, model_config, model_type, config):
    """
    Run one NER model over clinical notes and return the filtered entities.

    Args:
        cases: Iterable of case dicts. ``clinical_note`` is the text passed
            to the model, and ``patient_id`` tags every entity found in it.
        model_config: Dict providing ``model_name`` (the model id passed to
            ``transformers.pipeline``).
        model_type: Short label used in progress output and stored on each
            entity as ``model_type``.
        config: Settings object providing ``DEVICE``, ``MIN_CONFIDENCE`` and
            ``MIN_LENGTH``.

    Returns:
        List of dicts with keys ``patient_id``, ``entity_group``
        (upper-cased), ``text`` (normalized via ``clean_entity_text``),
        ``score`` (float) and ``model_type``. Entities scoring below
        ``config.MIN_CONFIDENCE``, or whose normalized text is shorter than
        ``config.MIN_LENGTH`` characters, are dropped. Cases on which the
        model raises are reported and skipped.
    """

    print(f"\n{'='*90}")
    print(f"Processing with {model_type.upper()} model: {model_config['model_name']}")
    print(f"{'='*90}")

    # Token-classification pipeline that groups tokens into entity spans.
    ner_pipeline = pipeline(
        task="ner",
        model=model_config["model_name"],
        aggregation_strategy="average",
        device=config.DEVICE,
    )

    all_entities = []
    try:
        for case in tqdm(cases, desc=f"{model_type} NER"):
            text = case.get("clinical_note", "")
            cid = case.get("patient_id")

            try:
                ents = ner_pipeline(text)
            except Exception as e:
                # Best effort: one unprocessable note should not abort the run.
                print(f"Error on {cid}: {e}")
                continue

            for ent in ents:
                # Confidence filtering
                score = float(ent.get("score", 0.0))
                if score < config.MIN_CONFIDENCE:
                    continue

                # Length filtering on the normalized text. MIN_LENGTH is an
                # inclusive minimum, so texts of exactly that length (e.g.
                # three-letter acronyms with the default of 3) are kept.
                raw_text = ent.get("word") or ent.get("text") or ""
                entity_text = clean_entity_text(raw_text)
                if len(entity_text) < config.MIN_LENGTH:
                    continue

                all_entities.append({
                    "patient_id": cid,
                    "entity_group": str(ent.get("entity_group", "UNKNOWN")).upper(),
                    "text": entity_text,
                    "score": score,
                    "model_type": model_type,
                })
    finally:
        # Free the model and any cached GPU memory even if the loop above
        # was interrupted, so the next model does not compete for memory.
        del ner_pipeline
        gc.collect()
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception as err:  # best-effort cleanup: report, don't crash
                print(f"Warning: could not empty CUDA cache: {err}")

    return all_entities


# Pipeline Driver

In [None]:
def _map_records_to_entity_ids(records, entities):
    """
    Map every extracted mention to the ``entity_id`` of its deduplicated entity.

    Each deduplicated entity exposes only a representative ``text`` and a
    ``canonical_text``, not its full member list. An exact lookup therefore
    misses fuzzy-merged variants, which would silently lose their patient
    links. Instead, each distinct mention text is assigned to the entity of
    the same ``entity_group`` whose representative strings it scores highest
    against with ``rapidfuzz.fuzz.ratio`` (exact matches score 100). This is
    a nearest-representative approximation of cluster membership.

    Args:
        records: Mention dicts with ``text`` and ``entity_group``.
        entities: Deduplicated entity dicts with ``text``, ``canonical_text``,
            ``entity_group`` and an assigned ``entity_id``.

    Returns:
        Dict mapping ``(normalized mention text, entity_group)`` to
        ``entity_id``. Mentions whose group has no deduplicated entity are
        absent.
    """
    # Representative strings per group, each paired with its entity's ID.
    reps_by_group = defaultdict(list)
    for e in entities:
        rep_texts = dict.fromkeys(
            (e["text"].lower().strip(), e["canonical_text"].lower().strip())
        )
        for rep in rep_texts:
            reps_by_group[e["entity_group"]].append((rep, e["entity_id"]))

    # Distinct mention texts per group, in first-seen order.
    texts_by_group = defaultdict(dict)
    for r in records:
        texts_by_group[r["entity_group"]][r["text"].lower().strip()] = None

    mapping = {}
    for group, texts in texts_by_group.items():
        reps = reps_by_group.get(group)
        if not reps:
            continue
        queries = list(texts)
        scores = cdist(queries, [rep for rep, _ in reps], scorer=ratio)
        for query, best in zip(queries, np.argmax(scores, axis=1)):
            mapping[(query, group)] = reps[best][1]
    return mapping


def run_ner_pipeline():
    """
    End-to-end NER pipeline: load the dataset, run every configured NER model,
    deduplicate entities, assign entity IDs, build patient–entity
    relationships, and save CSV outputs.

    Files written under ``config.OUTPUT_DIR``:
        * ``disease_entities.csv``: one row per deduplicated DISEASE entity.
        * ``case_has_disease.csv``: unique (patient_id, entity_id) pairs with
          ``relation_type`` ``HAS_DISEASE``. Written only when non-empty.
        * ``config.CASES_FILE``: the input cases that have at least one
          relationship.

    Returns:
        Dict with ``cases_file``, ``disease`` (entity CSV path) and
        ``relationships`` (entity type -> relationship CSV path). If no
        entities are extracted, nothing is written, and ``entities_file`` and
        ``disease`` are ``None`` while ``relationships`` is empty.
    """

    # Ensure output directory exists
    os.makedirs(config.OUTPUT_DIR, exist_ok=True)

    # Load the dataset and drop rows with a missing value in any column.
    df = pd.read_csv(config.DATASET)
    df_cases = df.dropna()
    cases = df_cases.to_dict("records")

    all_entities = []
    for model_type, model_conf in config.MODELS.items():
        all_entities.extend(run_ner_model(cases, model_conf, model_type, config))

    cases_path = os.path.join(config.OUTPUT_DIR, config.CASES_FILE)

    # Exit early if nothing extracted
    if not all_entities:
        print("No entities found.")
        return {
            "cases_file": cases_path,
            "entities_file": None,
            "disease": None,
            "relationships": {}
        }

    # Deduplicate entities globally and give each cluster a stable ID.
    deduped_entities = fuzzy_match_entities(all_entities)
    for e in deduped_entities:
        e["entity_id"] = create_entity_id(
            e["canonical_text"].lower().strip(), e["entity_group"]
        )

    # Clusters that share a canonical text hash to the same ID. Keep one row
    # per ID so the node file has unique keys; since fuzzy_match_entities
    # sorts by score, the row kept is the highest-scoring one.
    df_final = pd.DataFrame(deduped_entities).drop_duplicates(subset="entity_id")

    # Link every extracted mention (not only the cluster representatives) to
    # its entity, then build patient -> entity relationships.
    entity_ids = _map_records_to_entity_ids(all_entities, deduped_entities)
    relationship_records = []
    for record in all_entities:
        group = record["entity_group"]
        matched_id = entity_ids.get((record["text"].lower().strip(), group))
        if matched_id:
            relationship_records.append({
                "patient_id": record["patient_id"],
                "entity_id": matched_id,
                "entity_group": group,
            })

    # Explicit columns keep the column selections below valid even if nothing
    # matched. A patient mentioning the same entity several times would
    # otherwise yield repeated identical rows (and possibly parallel edges on
    # import), so each pair is kept once.
    df_relationships = pd.DataFrame(
        relationship_records, columns=["patient_id", "entity_id", "entity_group"]
    ).drop_duplicates()

    # Relationship mapping
    rels_map = {"DISEASE": "HAS_DISEASE"}

    # Save entities by type
    for t in ["DISEASE"]:
        sub = df_final[df_final["entity_group"] == t]
        sub_path = os.path.join(config.OUTPUT_DIR, f"{t.lower()}_entities.csv")
        sub.to_csv(sub_path, index=False)
        print(f"Saved {t} entities ({len(sub)}) to: {sub_path}")

    print(f"\n{'='*120}")

    # Save relationships to a file
    rel_files = {}
    for entity_type, rel_name in rels_map.items():
        sub = df_relationships.loc[
            df_relationships["entity_group"] == entity_type,
            ["patient_id", "entity_id"],
        ].copy()
        if not sub.empty:
            sub["relation_type"] = rel_name
            out_path = os.path.join(config.OUTPUT_DIR,
                                    f"case_{rel_name.lower()}.csv")
            sub.to_csv(out_path, index=False)
            print(f"Saved relationships: {rel_name} ({len(sub)}) → {out_path}")
            rel_files[entity_type] = out_path

    print(f"\n{'='*120}")

    # Keep only the patient cases that appear in the graph.
    patients_in_graph = df_relationships["patient_id"].unique()
    print(f"Patients remaining in graph: {len(patients_in_graph)}")

    df_cases_filtered = df_cases[df_cases["patient_id"].isin(patients_in_graph)]
    df_cases_filtered.to_csv(cases_path, index=False)

    print(f"Saved patient cases metadata to {config.CASES_FILE}")

    return {
        "cases_file": cases_path,
        "disease": os.path.join(config.OUTPUT_DIR, "disease_entities.csv"),
        "relationships": rel_files
    }

# Run the Pipeline

In [None]:
# Run the whole pipeline. The bare name on the last line of the cell makes
# the notebook display the returned dict of output file paths.
file_paths = run_ner_pipeline()
file_paths