## Graph Information Extraction from Clinical Notes

This notebook implements medical entity extraction and linking using GLiNER and SciSpacy's UMLS and MeSH knowledge bases,
along with relation extraction using Google's Gemma and MedGemma models.

In [1]:
# Verify environment and import dependencies
import numpy as np
import transformers
import datasets
import os
import random
import traceback
import spacy
import scispacy
from scispacy.linking import EntityLinker
from scispacy.abbreviation import AbbreviationDetector
import pandas as pd
from gliner import GLiNER
from typing import List, Tuple, Dict
import re
from functools import lru_cache
import concurrent.futures
from collections import defaultdict

# Report the version of each core library so a run can be tied to its environment
for label, module in (
    ("NumPy", np),
    ("Transformers", transformers),
    ("Datasets", datasets),
    ("SpaCy", spacy),
    ("SciSpacy", scispacy),
):
    print(f"{label} version: {module.__version__}")

NumPy version: 1.26.4
Transformers version: 4.51.3
Datasets version: 3.6.0
SpaCy version: 3.6.1
SciSpacy version: 0.5.5


In [2]:
# Build the SciSpacy pipeline: base model, abbreviation detector, then UMLS linker
print("Loading SciSpacy model...")
nlp = spacy.load("en_core_sci_lg")

# The abbreviation detector is registered before the linker, which is
# configured below with resolve_abbreviations=True
nlp.add_pipe("abbreviation_detector")

# UMLS entity linker, keeping up to 3 candidates per mention
nlp.add_pipe(
    "scispacy_linker",
    config=dict(
        resolve_abbreviations=True,
        linker_name="umls",
        max_entities_per_mention=3,
    ),
)

print("SciSpacy pipeline loaded with components:", nlp.pipe_names)

Loading SciSpacy model...


  deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(  # type: ignore[union-attr]
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


SciSpacy pipeline loaded with components: ['tok2vec', 'tagger', 'attribute_ruler', 'lemmatizer', 'parser', 'ner', 'abbreviation_detector', 'scispacy_linker']


In [3]:
# Load dataset
from datasets import load_dataset

# Training split of the augmented clinical notes corpus
dataset = load_dataset('AGBonnet/augmented-clinical-notes', split='train')

def nested_print(key, element, level=0):
    """Print ``key``/``element`` as a tree, recursing into nested dicts.

    Args:
        key: Label printed for this node.
        element: Value to print; a dict is expanded one level deeper per
            nesting, anything else is printed inline after the key.
        level: Current depth; each level adds one "│ " prefix.
    """
    indent = "│ " * level

    # Leaf: key and value on a single line
    if not isinstance(element, dict):
        print(f'{indent}├─{key}: {element}')
        return

    # Branch: header line, then every child one level deeper
    print(f'{indent}├─{key}:')
    for child_key, child_value in element.items():
        nested_print(child_key, child_value, level + 1)

# Extract idx and full_note
def extract_idx_full_note(sample):
    """Project a dataset row down to its 'idx' and 'full_note' fields.

    Args:
        sample: Mapping that contains at least the keys 'idx' and 'full_note'.

    Returns:
        A new dict holding only those two keys.
    """
    return {field: sample[field] for field in ('idx', 'full_note')}

# Reduce every row to just 'idx' and 'full_note'; all original columns are removed.
# NOTE(review): batch_size is passed without batched=True — confirm it has the intended effect.
dataset = dataset.map(
    extract_idx_full_note,
    remove_columns=dataset.column_names,
    batch_size=1000)

# Shuffle dataset with a seed drawn at run time. The seed is printed here and is
# also embedded in the results file name, so a given ordering can be identified.
random_seed = random.randint(0, 1000)
print(f"\nShuffling dataset with random seed: {random_seed}")
dataset = dataset.shuffle(seed=random_seed)


Shuffling dataset with random seed: 650


In [4]:
# Load the GLiNER biomedical checkpoint used for the initial NER pass;
# gliner_model is read by perform_ner_with_linking.
gliner_model = GLiNER.from_pretrained("Ihor/gliner-biomed-bi-large-v1.0")

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

In [5]:
# Load MedGemma model for relationship extraction
from mlx_lm import load, generate

# TODO: implement confidence scoring and post-filtering of extracted relations

model_name = "google/medgemma-4b-it"
medgemma_model, medgemma_tokenizer = load(model_name)
# Selects MedGemma (rather than Gemma) in triplet_CIE and in the results file name
use_medgemma = True

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

In [6]:
# Load Gemma model for relationship extraction (used by triplet_CIE when use_medgemma is False)
from mlx_lm import load, generate

# NOTE: model_name is reassigned here; it previously held the MedGemma model id
model_name = "google/gemma-3-4b-it"
gemma_model, gemma_tokenizer = load(model_name)
# Complement of use_medgemma; not read by any other cell visible in this notebook
use_gemma = not use_medgemma

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

In [7]:
# Enhanced NER function with SciSpacy entity linking
def perform_ner_with_linking(text_note: str) -> List[Dict]:
    """Extract entities with GLiNER and enrich them with SciSpacy UMLS links.

    GLiNER proposes the entity spans. The SciSpacy pipeline is run over the
    same text and its linked mentions are matched to the GLiNER entities by
    case-insensitive text equality (not by character offsets).

    Args:
        text_note: Raw clinical note text.

    Returns:
        The GLiNER entity dicts, mutated in place, with 'label' and 'score'
        removed. An entity whose text matched a linked SciSpacy mention also
        carries 'cui', 'canonical_name', 'description', 'semantic_types',
        'linking_score' and 'alternative_candidates'. An entity whose text
        equals a detected abbreviation also carries 'expanded_form'.
    """
    # Labels for the GLiNER model
    labels = ["Disease or Condition", "Medication", "Medication Dosage and Frequency",
              "Procedure", "Lab Test", "Lab Test Result", "Body Site",
              "Medical Device", "Demographic Information"]

    # Extract entities with GLiNER
    entities = gliner_model.predict_entities(
        text_note,
        labels=labels,
        threshold=0.5,
    )

    # Process text with SciSpacy for entity linking
    doc = nlp(text_note)

    # The linker's knowledge base does not change between mentions, so it is
    # resolved once instead of once per candidate.
    knowledge_base = nlp.get_pipe("scispacy_linker").kb

    # Map lower-cased mention text -> list of candidate dicts (best first)
    entity_links = {}
    for ent in doc.ents:
        if not ent._.kb_ents:
            continue
        candidates = []
        for cui, score in ent._.kb_ents[:3]:  # Top 3 candidates
            kb_entity = knowledge_base.cui_to_entity[cui]
            candidates.append({
                'cui': cui,
                'score': score,
                'name': kb_entity.canonical_name,
                'definition': kb_entity.definition if kb_entity.definition else '',
                'types': list(kb_entity.types)
            })
        # A later mention with the same text replaces an earlier one
        entity_links[ent.text.lower()] = candidates

    # Enhance GLiNER entities with SciSpacy linking
    enhanced_entities = []
    for entity in entities:
        candidates = entity_links.get(entity['text'].lower())
        if candidates:
            # Use the top candidate; keep the rest as alternatives
            top_candidate = candidates[0]
            entity['cui'] = top_candidate['cui']
            entity['canonical_name'] = top_candidate['name']
            entity['description'] = top_candidate['definition']
            entity['semantic_types'] = top_candidate['types']
            entity['linking_score'] = top_candidate['score']
            entity['alternative_candidates'] = candidates[1:]
        enhanced_entities.append(entity)

    # Abbreviation text -> long form, as detected by the abbreviation component
    abbreviations = {
        abbr.text: abbr._.long_form.text for abbr in doc._.abbreviations
    }

    for entity in enhanced_entities:
        # Abbreviation matching is case-sensitive, unlike the link matching above
        if entity['text'] in abbreviations:
            entity['expanded_form'] = abbreviations[entity['text']]

        # Remove label and score from GLiNER entities
        entity.pop('label', None)
        entity.pop('score', None)

    print(f"First 5 enhanced entities: {enhanced_entities[:5]}")
    return enhanced_entities

In [8]:
# Cache for entity linking results
@lru_cache(maxsize=10000)
def cached_entity_lookup(entity_text: str) -> Dict:
    """Link a single entity string to its top UMLS candidate, memoizing results.

    Args:
        entity_text: Text of the entity to link; it is run through the
            SciSpacy pipeline on a cache miss.

    Returns:
        A dict with 'cui', 'score', 'name' and 'definition' for the best
        candidate of the first detected entity, or None when the text yields
        no entity or no knowledge-base candidates. The returned dict is the
        cached object itself, so callers should not mutate it.
    """
    doc = nlp(entity_text)
    if not doc.ents:
        return None

    first_ent = doc.ents[0]
    if not first_ent._.kb_ents:
        return None

    cui, score = first_ent._.kb_ents[0]
    # The knowledge base lives on the linker component; the pipeline object
    # itself has no `kb` attribute.
    kb_entity = nlp.get_pipe("scispacy_linker").kb.cui_to_entity[cui]
    return {
        'cui': cui,
        'score': score,
        'name': kb_entity.canonical_name,
        'definition': kb_entity.definition if kb_entity.definition else ''
    }

In [9]:
# Regex for triple extraction
# Matches ("a", "b", "c") with optional whitespace around each item; the three
# items are captured and may not themselves contain a double quote.
TUPLE_RX = re.compile(
    r'''\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"\s*\)''')

def extract_triples(raw: str) -> List[Tuple[str, str, str]]:
    """Collect every well-formed double-quoted 3-tuple found in ``raw``.

    Args:
        raw: Free text, typically model output, that may contain tuples.

    Returns:
        One (first, second, third) tuple of strings per match, in order of
        appearance; an empty list when nothing matches.
    """
    return [hit.groups() for hit in TUPLE_RX.finditer(raw)]

In [10]:
# Enhanced triple extraction with entity metadata
def triplet_CIE(
    full_note_content: str,
    extracted_entities: List[dict],
    max_length: int = 512,
) -> List[Tuple[str, str, str]]:
    """Ask the selected model (MedGemma or Gemma) for triples over known entities.

    Args:
        full_note_content: The clinical note, embedded verbatim in the prompt.
        extracted_entities: Entity dicts; each needs 'text' and may carry
            'canonical_name' and 'description', which are added to its line.
        max_length: Passed to the generator as its ``max_tokens`` limit.

    Returns:
        The triples that extract_triples finds in the generated text.
    """

    def _describe(entity: dict) -> str:
        # One bullet per entity: quoted text, optional UMLS name, and an
        # optional definition cut to 100 characters (always followed by "...")
        pieces = [f"- \"{entity['text']}\""]
        if 'canonical_name' in entity:
            pieces.append(f" [UMLS: {entity['canonical_name']}]")
        if entity.get('description'):
            pieces.append(f" - {entity['description'][:100]}...")
        return "".join(pieces)

    entities_text = "\n".join(_describe(entity) for entity in extracted_entities)

    relationship_extraction_prompt = f"""Your goal is to perform a Closed Information Extraction task on the following clinical note:

{full_note_content}

You are provided with a list of medical entities extracted from the note:
{entities_text}

Your task is to generate high quality triplets of the form (entity1, relation, entity2) where:
- The relationship is explicitly stated or strongly implied in the clinical note
- The entities are from the provided list (use the exact text as it appears)
- The triplets should be clinically meaningful and relevant

Please return the triplets in the following format:
[
  ("entity1", "relation", "entity2"),
  ("entity3", "relation", "entity4"),
  ...
]
"""

    system_turn = {
        "role": "system",
        "content": "You are an expert clinical information extraction system specialized in identifying medical relationships."
    }
    user_turn = {
        "role": "user",
        "content": relationship_extraction_prompt
    }

    # The module-level use_medgemma flag decides which loaded model is used
    model, tokenizer = (
        (medgemma_model, medgemma_tokenizer)
        if use_medgemma
        else (gemma_model, gemma_tokenizer)
    )

    chat_prompt = tokenizer.apply_chat_template(
        [system_turn, user_turn], add_generation_prompt=True
    )

    # Generate text with MLX model, then keep only the well-formed tuples
    generated_text = generate(
        model,
        tokenizer,
        prompt=chat_prompt,
        verbose=False,
        max_tokens=max_length,
    )
    return extract_triples(generated_text)

In [11]:
def get_unique_entities(triples_list):
    """Return the distinct entities used in ``triples_list``, in first-seen order.

    Subject and object of every triple are collected; the relation is ignored.
    Order is deterministic (first appearance), so anything numbered from this
    list, such as Cypher variable names, is stable from run to run.

    Args:
        triples_list: Iterable of (entity1, relation, entity2) triples.

    Returns:
        List of unique entities; empty when no triples are given.
    """
    # A dict keeps insertion order while de-duplicating its keys
    ordered = {}
    for entity1, _, entity2 in triples_list:
        ordered.setdefault(entity1, None)
        ordered.setdefault(entity2, None)
    return list(ordered)

In [12]:
def generate_merge_query_with_metadata(entities, entity_metadata):
    """Generate one Cypher MERGE statement per entity, with UMLS metadata.

    Each entity becomes ``MERGE (e<i>:MedicalEntity {name: "...", ...})``,
    where ``i`` is its position in ``entities``. String values are escaped
    (backslash and double quote) so that text containing those characters
    cannot terminate the Cypher string literal early.

    Args:
        entities: Sequence of entity names.
        entity_metadata: Mapping of entity name to a metadata dict that may
            contain 'cui', 'canonical_name', 'semantic_types' (iterable of
            strings, joined with '|') and 'linking_score' (written unquoted).

    Returns:
        The MERGE statements joined by newlines; '' when ``entities`` is empty.
    """

    def _quoted(value):
        # Backslashes must be doubled before quotes are escaped, otherwise the
        # backslashes introduced for the quotes would be doubled as well.
        text = str(value).replace('\\', '\\\\').replace('"', '\\"')
        return f'"{text}"'

    queries = []
    for i, entity in enumerate(entities):
        properties = [f'name: {_quoted(entity)}']

        # Add metadata if available
        if entity in entity_metadata:
            metadata = entity_metadata[entity]
            if 'cui' in metadata:
                properties.append(f'cui: {_quoted(metadata["cui"])}')
            if 'canonical_name' in metadata:
                properties.append(f'canonical_name: {_quoted(metadata["canonical_name"])}')
            if 'semantic_types' in metadata:
                types_str = '|'.join(metadata['semantic_types'])
                properties.append(f'semantic_types: {_quoted(types_str)}')
            if 'linking_score' in metadata:
                properties.append(f'linking_score: {metadata["linking_score"]}')

        property_map = ', '.join(properties)
        queries.append(f'MERGE (e{i}:MedicalEntity {{{property_map}}})')

    return '\n'.join(queries)

In [13]:
def generate_merge_relationships(triples_list, merge_entity_queries):
    """Generate Cypher MERGE statements for relationships.

    Variable names (``e0``, ``e1``, ...) are recovered by parsing the entity
    MERGE statements in ``merge_entity_queries``. A name in that text may be
    written in escaped form (``\\"`` or ``\\\\``), so each name is registered
    both as written and unescaped, letting raw triple text match it.

    Args:
        triples_list: Iterable of (entity1, relation, entity2) triples.
        merge_entity_queries: Text containing
            ``MERGE (e<N>:MedicalEntity {name: "..."`` statements.

    Returns:
        Newline-joined ``MERGE (a)-[:RELATIONSHIP {type: "..."}]->(b)``
        statements, one per distinct (source, relation, target). The relation
        is lower-cased with spaces replaced by underscores. Triples whose
        entities have no variable are skipped.
    """
    node_pattern = r'MERGE\s*\((e\d+):MedicalEntity\s*\{\s*name:\s*"((?:[^"\\]|\\.)*)"'

    # Parse entity name to variable mapping
    entity_var_map = {}
    for match in re.finditer(node_pattern, merge_entity_queries):
        var, name = match.group(1), match.group(2)
        # The name exactly as written always wins ...
        entity_var_map[name] = var
        # ... and its unescaped form is added only if that key is still free
        entity_var_map.setdefault(re.sub(r'\\(.)', r'\1', name), var)

    def escape_string(s):
        # Backslashes first, so the ones added for quotes are not doubled
        return s.replace('\\', '\\\\').replace('"', '\\"')

    def format_relation(relation):
        return escape_string(relation.lower().replace(" ", "_"))

    seen = set()
    cypher_lines = []
    for entity1, relation, entity2 in triples_list:
        var1 = entity_var_map.get(entity1)
        var2 = entity_var_map.get(entity2)
        if not (var1 and var2):
            continue
        relation_type = format_relation(relation)
        key = (var1, relation_type, var2)
        if key in seen:
            continue
        seen.add(key)
        cypher_lines.append(
            f'MERGE ({var1})-[:RELATIONSHIP {{type: "{relation_type}"}}]->({var2})'
        )
    return '\n'.join(cypher_lines)

In [14]:
# Enhanced parallel processing function
def run_parallel(sample_item):
    """Run the full extraction pipeline on one (index, sample) pair.

    Steps: NER with entity linking, de-duplication by entity text,
    confidence filtering, triple extraction, and Cypher generation.

    Args:
        sample_item: Tuple of (index, sample) where sample has 'idx' and
            'full_note'. The index is not used.

    Returns:
        On success, a dict with 'note_id', 'entities', 'content', 'triplets',
        'cypher_query' and 'cypher_relationship_query'. On any exception, a
        dict with 'note_id' (None if it was not read yet), 'error' and
        'traceback'; the exception is not re-raised.
    """
    # Bound before the try block so the handler can always report it, even
    # when unpacking the input or reading 'idx' is what failed.
    note_id = None
    try:
        _, sample = sample_item
        note_id = sample["idx"]
        full_note_content = sample["full_note"]

        # Perform NER with SciSpacy entity linking
        extracted_entities = perform_ner_with_linking(full_note_content)

        # Remove duplicate entities based on text, keeping the occurrence with
        # the highest linking score (unlinked entities count as 0)
        unique_entities = {}
        for entity in extracted_entities:
            key = entity['text']
            if key not in unique_entities or \
               entity.get('linking_score', 0) > unique_entities[key].get('linking_score', 0):
                unique_entities[key] = entity

        extracted_entities = list(unique_entities.values())
        print(f"Extracted {len(extracted_entities)} unique entities from note {note_id}.")

        # Keep entities that either were never linked or linked with score > 0.7
        high_confidence_entities = [
            ent for ent in extracted_entities
            if 'linking_score' not in ent or ent['linking_score'] > 0.7
        ]

        print(f"Filtered to {len(high_confidence_entities)} high-confidence entities.")

        # Extract relationships
        triples_list = triplet_CIE(
            full_note_content=full_note_content,
            extracted_entities=high_confidence_entities,
            max_length=512,
        )
        print(f"Extracted {len(triples_list)} triplets from note {note_id}.")

        # Get unique entities from triplets
        triple_entities = get_unique_entities(triples_list)
        print(f"Found {len(triple_entities)} unique entities in triplets.")

        # Metadata comes from all de-duplicated entities, not only the
        # high-confidence subset
        entity_metadata = {ent['text']: ent for ent in extracted_entities}

        # Generate Cypher queries with metadata
        cypher_query = generate_merge_query_with_metadata(triple_entities, entity_metadata)
        cypher_relationship_query = generate_merge_relationships(triples_list, cypher_query)

        return {
            "note_id": note_id,
            "entities": extracted_entities,
            "content": full_note_content,
            "triplets": triples_list,
            "cypher_query": cypher_query,
            "cypher_relationship_query": cypher_relationship_query
        }
    except Exception as e:
        return {"note_id": note_id, "error": str(e), "traceback": traceback.format_exc()}

In [15]:
# Main execution cell
import concurrent.futures
from itertools import islice

# Number of notes processed in this run; adjust based on compute resources
MAX_ITEMS = 1

# Create output directory if it doesn't exist
os.makedirs("data", exist_ok=True)
file_path = "data/" + ("medgemma_" if use_medgemma else "gemma_") + str(random_seed) + "_enhanced_extraction_results.txt"

# Start a fresh results file with a banner naming the model in use. The label
# is chosen first so the conditional cannot swallow the rest of the sentence.
model_label = "MedGemma" if use_medgemma else "Gemma"
with open(file_path, "w", encoding="utf-8") as f:
    f.write("="*30)
    f.write(f" Enhanced Extraction Process using {model_label} Model ")
    f.write("="*30 + "\n\n")

# Take only the first MAX_ITEMS (index, sample) pairs; islice stops iterating
# early, so the rest of the dataset is never materialized.
items_to_process = list(islice(enumerate(dataset), MAX_ITEMS))

print(f"Total items to process: {len(items_to_process)}")

# Run the pipeline on a small thread pool; executor.map keeps the input order
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    ner_results = list(executor.map(run_parallel, items_to_process))

# Write enhanced results to file
with open(file_path, "a", encoding="utf-8") as f:
    for result in ner_results:
        if "error" in result:
            f.write(f"Error processing note {result['note_id']}: {result['error']}\n")
            f.write(result['traceback'] + "\n")
            continue

        f.write(f"Note ID: {result['note_id']} (Seed: {random_seed})\n")
        f.write("\nContent:\n")
        f.write(result['content'] + "\n\n")
        f.write("Entities with UMLS Linking:\n")

        for entity in result['entities']:
            f.write(f"  - Text: {entity['text']}\n")

            if 'cui' in entity:
                # Format the score only when present; applying ':.3f' to the
                # 'N/A' placeholder would raise ValueError.
                linking_score = entity.get('linking_score')
                score_text = f"{linking_score:.3f}" if linking_score is not None else "N/A"
                f.write(f"    CUI: {entity['cui']}\n")
                f.write(f"    Canonical Name: {entity.get('canonical_name', 'N/A')}\n")
                f.write(f"    Linking Score: {score_text}\n")

            if 'semantic_types' in entity:
                f.write(f"    Semantic Types: {', '.join(entity['semantic_types'])}\n")

            if 'description' in entity and entity['description']:
                f.write(f"    Description: {entity['description'][:150]}...\n")

            if 'expanded_form' in entity:
                f.write(f"    Expanded Form: {entity['expanded_form']}\n")

            if 'alternative_candidates' in entity and entity['alternative_candidates']:
                f.write("    Alternative Candidates:\n")
                for alt in entity['alternative_candidates'][:2]:
                    f.write(f"      - {alt['name']} (CUI: {alt['cui']}, Score: {alt['score']:.3f})\n")

        f.write("\nTriplets:\n")
        for triplet in result['triplets']:
            f.write(f"  - ({triplet[0]}, {triplet[1]}, {triplet[2]})\n")

        f.write("\nCypher Query with Metadata:\n")
        f.write(result['cypher_query'] + "\n\n")
        f.write("Cypher Relationship Query:\n")
        f.write(result['cypher_relationship_query'] + "\n")
        f.write("="*50 + "\n\n")

print(f"Processing complete. Check {file_path} for results.")
# NOTE: no visible cell calls cached_entity_lookup, so these statistics are
# expected to show an unused cache.
print(f"SciSpacy entity linking cache size: {cached_entity_lookup.cache_info()}")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Total items to process: 1
First 5 enhanced entities: [{'start': 87, 'end': 112, 'text': 'native aortic coarctation'}, {'start': 316, 'end': 334, 'text': 'right and left arm'}, {'start': 377, 'end': 397, 'text': 'right and left ankle'}, {'start': 433, 'end': 450, 'text': 'upper extremities', 'cui': 'C1140618', 'canonical_name': 'Upper Extremity', 'description': 'The region of the upper limb in animals, extending from the deltoid region to the HAND, and including the ARM; AXILLA; and SHOULDER.', 'semantic_types': ['T023'], 'linking_score': 0.9906526207923889, 'alternative_candidates': [{'cui': 'C0003793', 'score': 0.8803924322128296, 'name': 'Bone structure of upper limb', 'definition': 'The bones of the upper and lower ARM. They include the CLAVICLE and SCAPULA.', 'types': ['T023']}, {'cui': 'C0222201', 'score': 0.8537867069244385, 'name': 'Skin structure of upper limb', 'definition': 'The integumentary covering of the upper extremities.', 'types': ['T023']}]}, {'start': 458, 'end': 481