## Graph Information Extraction from clinical notes


In [1]:
# Verify environment after fixing
import numpy as np
import transformers
import datasets
import os
import random
import traceback

print(f"NumPy version: {np.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Datasets version: {datasets.__version__}")

# Also check if packages can be imported without errors
try:
    # Try to import some commonly used modules
    import pandas as pd
    from gliner import GLiNER
    # from relik import Relik

    print("\nAll required packages imported successfully!")
    print("The NumPy compatibility issue appears to be resolved.")

    # Check if the data directory exists
    data_dir = "src/data"
    if os.path.exists(data_dir):
        print(f"\nData directory '{data_dir}' exists.")
        # List any files in the directory
        files = os.listdir(data_dir)
        if files:
            print(f"Files in directory: {', '.join(files)}")
        else:
            print("No files in directory yet.")
    else:
        print(f"\nData directory '{data_dir}' does not exist yet. It will be created when running the extraction.")
    
except Exception as e:
    print(f"\nError importing packages: {e}")
    print("You may need to reinstall additional packages or restart the kernel.")

NumPy version: 2.2.5
Transformers version: 4.51.3
Datasets version: 3.6.0

All required packages imported successfully!
The NumPy compatibility issue appears to be resolved.

Data directory 'src/data' does not exist yet. It will be created when running the extraction.


In [2]:
from datasets import load_dataset
# from relik.inference.data.objects import RelikOutput
# from relik.retriever.indexers.document import Document


dataset = load_dataset(
    'AGBonnet/augmented-clinical-notes',
    split='train',
)

# sample = dataset.__iter__().__next__()

def nested_print(key, element, level=0):
    if isinstance(element, dict):
        print(f'{"│ "*(level)}├─{key}:')
        for k, v in element.items():
            nested_print(k, v, level+1)
    else:
        print(f'{"│ "*(level)}├─{key}: {element}')

# nested_print('sample', sample)

# Print the first 5 elements of the dataset using nested_print
for i, sample in enumerate(dataset):
    if i < 5:
        print(f"\nSample {i+1}:")
        nested_print('sample', sample)
    else:
        break


Sample 1:
├─sample:
│ ├─idx: 155216
│ ├─note: A a sixteen year-old girl, presented to our Outpatient department with the complaints of discomfort in the neck and lower back as well as restriction of body movements. She was not able to maintain an erect posture and would tend to fall on either side while standing up from a sitting position. She would keep her head turned to the right and upwards due to the sustained contraction of the neck muscles. There was a sideways bending of the back in the lumbar region. To counter the abnormal positioning of the back and neck, she would keep her limbs in a specific position to allow her body weight to be supported. Due to the restrictions with the body movements at the neck and in the lumbar region, she would require assistance in standing and walking. She would require her parents to help her with daily chores, including all activities of self-care.
She had been experiencing these difficulties for the past four months since when she was introdu

In [None]:
# extract idx and full_note
def extract_idx_full_note(sample):
    idx = sample['idx']
    full_note = sample['full_note']
    return {
        'idx': idx,
        'full_note': full_note
    }

# extract idx and full_note from the sample
dataset = dataset.map(
    extract_idx_full_note,
    remove_columns=dataset.column_names,
    batch_size=1000)

# Shuffle the dataset
random_seed = random.randint(0, 1000)
print(f"\nShuffling dataset with random seed: {random_seed}")
## Favourite seeds # 883,
dataset = dataset.shuffle(seed=random_seed)

# Print the first 5 elements of the dataset using nested_print
for i, sample in enumerate(dataset):
    if i < 5:
        print(f"\nSample {i+1}:")
        nested_print('sample', sample)
    else:
        break


Shuffling dataset with random seed: 883

Sample 1:
├─sample:
│ ├─idx: 101199
│ ├─full_note: A 23-year-old married female presented to our outpatient department with complaints of swelling of the vulva and ulcers on the perineal folds for 5 years. The swelling started in 2011 during 1st trimester of gestation, which was gradually progressive (1–5 cm in size). Few painless ulcers developed on the inner aspect of both thighs after 1 year, which gradually increased in number and size. Gradually, she developed pain in the perineum. There was no history of preceding fever, chronic cough, weight loss, oral ulcers, loose stool, and blood in stool. She gave a history of cesarean section, which was performed due to genital prolapse. Then, she was operated for prolapse, and later, biopsy was taken from the swelling to rule out postoperative lymphangioma. The biopsy showed nonspecific granulomas and chronic infiltrates. She was started on oral azathioprine, antibiotics, metronidazole, and steroid

In [4]:
# Model
gliner_model = GLiNER.from_pretrained("Ihor/gliner-biomed-bi-large-v1.0")

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

In [42]:
from mlx_lm import load, generate

# Compare performance of medgemma-4b-it with base gemma-4b-it model
# model_name = "google/medgemma-4b-it"
model_name = "google/gemma-3-4b-it"
# model_name = "Qwen/Qwen3-4B"
# model_name = "Qwen/Qwen3-30B-A3B"
medgemma_model, tokenizer = load(model_name)

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

In [43]:
# --- NER Function ---
def perform_ner(text_note):
    """Performs NER on a text using, gliner model, and label list.
        https://huggingface.co/Ihor/gliner-biomed-bi-large-v1.0
    """

    # Labels for the GLiNER model
    labels = ["Disease or Condition", "Medication", "Medication Dosage and Frequency",
              "Procedure", "Lab Test", "Lab Test Result", "Body Site",
              "Medical Device", "Demographic Information"]

    entities = gliner_model.predict_entities(
        text_note,
        labels=labels,
        threshold=0.5,
    )

    return entities

In [44]:
import wikipedia

def fetch_entity_descriptions(entity):
    """Fetch description of entity['text'] from Wikipedia knowledge base

    Args:
        entity (dict): Entity dictionary containing 'text' and other fields

    Returns:
        dict or None: Returns the entity with added description if found, None otherwise
    """
    try:
        
        # Get entity text
        query = entity['text']
        
        # Try to get a summary from Wikipedia
        try:
            # Search for the page
            search_results = wikipedia.search(query, results=1)
            if search_results:
                # Get the page summary
                page = wikipedia.page(search_results[0], auto_suggest=False)
                description = page.summary[0:200] + "..." if len(page.summary) > 200 else page.summary

                entity['description'] = description
                # Return the entity with description and page URL
                return entity
        except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError):
            # Entity not found or ambiguous
            pass

        return None  # Will be filtered out
    except Exception as e:
        print(f"Error fetching description for {entity['text']}: {str(e)}")
        return None



In [45]:
import re
from typing import List, Tuple

# pre-compiled regex:  ("something", "something", "something")
TUPLE_RX = re.compile(
    r'''\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"\s*\)''')

def extract_triples(raw: str) -> List[Tuple[str, str, str]]:
    """
    Return all well-formed 3-item tuples found in *raw*.
    Truncated / malformed lines are silently skipped.
    """
    return [match for match in TUPLE_RX.findall(raw)]

In [46]:
# Function to extract triples from a text note
def triplet_CIE(
    full_note_content: str,
    extracted_entities: List[dict],
    max_length: int = 512
) -> List[Tuple[str, str, str]]:
    """
    Extract triples from a text note using MedGemma model.

    Args:
        full_note_content (str): The full clinical note content.
        extracted_entities (List[dict]): List of entities extracted from the note.
        max_length (int): Maximum length for the model generation.
    Returns:
        List[Tuple[str, str, str]]: List of extracted triplets in the form (entity1, relation, entity2).
    """

    relationship_extraction_prompt = f""" I would like you to perform a Closed Information Extraction task on the following clinical note {full_note_content}
    You are provided with a list of entities extracted from the note in the form of JSON objects. {extracted_entities}
    Your task is to reason about the full clinical note and the entities provided, and generate at most 20 unique and high quality triplets
    of the form (entity1, relation, entity2) that represent the relationships between the entities in the full clinical note.
    Please return the triplets in the following format:
    [
      (entity1, relation, entity2),
      (entity3, relation, entity4),
      ...
    ]

    Note the response provides should be valid JSON Syntax and should contain no more than 20 triplets.
    The triplets should be relevant to the clinical note and the entities provided.
    The triplets should be unique and not contain any duplicates.
    The triplets should be concise and not contain any unnecessary information.
    The triplets list must end with a closing square bracket.
    """
    messages = [
        {
          "role": "system",
          "content": "You are a helpful medical assistant."
        },
        {
          "role": "user",
          "content": relationship_extraction_prompt
        }]
    inputs = tokenizer.apply_chat_template(
      messages, add_generation_prompt=True
    )

    # Generate text with MLX model
    triplet_str = generate(
      medgemma_model,
      tokenizer,
      prompt=inputs,
      verbose=False,
      max_tokens=max_length,
    )

    triples_list = extract_triples(triplet_str)
    
    return triples_list

In [None]:
# Function to generate MERGE syntax from Named Entities
def generate_merge_syntax(triples_list: List[Tuple[str, str, str]], max_length: int = 512) -> str:

    cypher_merge_prompt = f"""You are a Cypher MERGE-query generator.
    Input:
    You are provided with the following list of triplets {triples_list} in the form of (entity1, relation, entity2) that represent the relationships between entities in a clinical note.
    Task:
    Your task is to generate a Cypher query that merges only the entities in the triplets, without any relationships.
    The entities should be merged as nodes in the graph database.
    Please return the Cypher query in the following format:
    MERGE (e0:Entity {{name: "entity1_name"}})
    MERGE (e1:Entity {{name: "entity2_name"}})
    """

    cypher_messages = [
        {
          "role": "system",
          "content": "You are a Cypher query generator."
        },
        {
          "role": "user",
          "content": cypher_merge_prompt
        }]
    cypher_inputs = tokenizer.apply_chat_template(
      cypher_messages, add_generation_prompt=True
    )

    # Generate Cypher query with MLX model
    cypher_query = generate(
      medgemma_model,
      tokenizer,
      prompt=cypher_inputs,
      verbose=False,
      max_tokens=max_length,
    )

    return cypher_query.strip()

In [74]:
# Function to extract Merge relationships from triples
def extract_merge_relationships(triples_list: List[Tuple[str, str, str]], merge_entity_queries: str, max_length: int = 512) -> str:
    # cypher_merge_prompt = f"""You are a Cypher query generator.
    # You are provided with the following list of triplets {triples_list} in the form of (entity1, relation, entity2) that represent the relationships between entities in a clinical note.
    # Your task is to generate a Cypher query that merges the entities and relationships in the triplets.
    # Please return the Cypher query in the following format:
    # MERGE (entity1)-[:RELATIONSHIP {{type: "relation_type"}}]->(entity2)
    # MERGE (entity3)-[:RELATIONSHIP {{type: "relation_type"}}]->(entity4)
    # """

    cypher_merge_prompt = f"""You are a Cypher query generator.

      You are given:
        •	A list of triples called {triples_list} in the form of (entity1, relation, entity2). Each triple represents a relationship between two entities found in a clinical note.
        •	A list of Cypher MERGE queries for entities called {merge_entity_queries}. Only entities that appear in these MERGE queries should be used when creating relationship queries.

      Your task:
      Generate a Cypher query that MERGEs the relationships represented by each triple only if both entities in the triple have a corresponding MERGE query in {merge_entity_queries}.

      Instructions and rules:
        •	For each valid triple, use the format: MERGE (eX)-[:RELATIONSHIP {{type: "relation_type"}}]->(eY)
        •	The entity variables (e.g., e0, e1) must match those assigned in {merge_entity_queries}.
        •	Use the exact relation string from the triple as the type property in the relationship.
        •	Do not redeclare entity variables (e.g., e0, e1) with labels or properties. they have already been declared in {merge_entity_queries} so only use the variable for all relationships.
        •	Sanitize all entity and relation names:
        •	Wrap all string values (names, relation types) in double quotes.
        •	Escape internal quotes or special characters if present.
        •	Do not create MERGE statements for entities or relationships not present in {merge_entity_queries}.
        •	Ensure that each MERGE statement is syntactically correct for Cypher.
        •	Avoid duplicate relationships: If a relationship between two entities with the same type appears more than once, only include it once.
        •	Do not comment or explain; just return the Cypher queries.

      Example output format:
      MERGE (e0)-[:RELATIONSHIP {{type: "relation_type"}}]->(e1)
      MERGE (e2)-[:RELATIONSHIP {{type: "relation_type"}}]->(e3)
    """

    cypher_messages = [
        {
          "role": "system",
          "content": "You are a Cypher query generator."
        },
        {
          "role": "user",
          "content": cypher_merge_prompt
        }]
    cypher_inputs = tokenizer.apply_chat_template(
      cypher_messages, add_generation_prompt=True
    )

    # Generate Cypher query with MLX model
    cypher_relationship_query = generate(
      medgemma_model,
      tokenizer,
      prompt=cypher_inputs,
      verbose=False,
      max_tokens=max_length,
    )

    return cypher_relationship_query.strip()

In [75]:
# Now process NER in parallel
def run_parallel(sample_item):
  try:
    index, sample_item = sample_item  # Unpack the tuple
    note_id = sample_item["idx"]
    full_note_content = sample_item["full_note"]
    extracted_entities = perform_ner(full_note_content)

    # Remove duplicate entities
    extracted_entities = list({entity['text']: entity for entity in extracted_entities}.values())
    print(f"Extracted {len(extracted_entities)} entities from note {note_id}.")

    # Entity Linking in parallel as before
    extracted_entities = list(filter(fetch_entity_descriptions, extracted_entities))

    # Filter only text and description
    extracted_entities = list(map(lambda x: { "text": x['text'], "label": x['label'], "description": x.get('description', '')}, extracted_entities))

    print(f"Filtered {len(extracted_entities)} entities with descriptions from note {note_id}.")

    triples_list = triplet_CIE(
      full_note_content=full_note_content,
      extracted_entities=extracted_entities,
    )
    print(f"Extracted {len(triples_list)} triplets from note {note_id}.")

    # Generate Cypher query for MERGE syntax
    cypher_query = generate_merge_syntax(triples_list)
    print(f"Generated Cypher query for note {note_id}: {cypher_query}")

    if not cypher_query:
      raise ValueError(f"Generated Cypher query is empty for note {note_id}.")

    # Generate Cypher query for relationships
    cypher_relationship_query = extract_merge_relationships(triples_list, cypher_query)
    print(f"Generated Cypher relationship query for note {note_id}: {cypher_relationship_query}")

    # Return the results
    return {
      "note_id": note_id,
      "entities": extracted_entities,
      "content": full_note_content,
      "triplets": triples_list,
      "cypher_query": cypher_query,
      "cypher_relationship_query": cypher_relationship_query
    }
  except Exception as e:
    return {"note_id": note_id, "error": str(e), "traceback": traceback.format_exc()}


In [76]:
import concurrent.futures

file_path = "data/ner_output.txt"

# Clear previous contents if needed
with open(file_path, "w") as f:
  f.write("Starting extraction process\n")
  f.write("="*50 + "\n\n")

# Prepare the items to process

# Convert dataset to list for processing
items_to_process = list(enumerate(dataset))

# Pick the first item for testing
items_to_process = items_to_process[:1] # Increasing length of this list requires more compute resources

print(f"Total items to process: {len(items_to_process)}")
# Execute NER in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
  ner_results = list(executor.map(run_parallel, items_to_process))

# Write results to file
with open(file_path, "a") as f:
    for result in ner_results:
        if "error" in result:
            f.write(f"Error processing note {result['note_id']}: {result['error']}\n")
            f.write(result['traceback'] + "\n")
        else:
            f.write(f"Note ID: {result['note_id']}\n")
            f.write("Entities:\n")
            for entity in result['entities']:
              f.write(f"  - Text: {entity['text']}\n")
              f.write(f"    Label: {entity['label']}\n")
              if entity.get('description'):
                f.write(f"    Description: {entity['description']}\n")
            f.write("\nContent:\n")
            f.write(result['content'] + "\n\n")
            f.write("Triplets:\n")
            for triplet in result['triplets']:
              f.write(f"  - ({triplet[0]}, {triplet[1]}, {triplet[2]})\n")
            f.write("\nCypher Query:\n")
            f.write(result['cypher_query'] + "\n\n")
            f.write("Cypher Relationship Query:\n")
            f.write(result['cypher_relationship_query'] + "\n")
            f.write("="*50 + "\n")

print("Processing complete. Check src/data/ner_output.txt for results.")

Total items to process: 1
Extracted 32 entities from note 101199.
Filtered 32 entities with descriptions from note 101199.
Extracted 25 triplets from note 101199.
Generated Cypher query for note 101199: ```cypher
MERGE (e0:Entity {name: "vulva"})
MERGE (e1:Entity {name: "perineal folds"})
MERGE (e2:Entity {name: "perineum"})
MERGE (e3:Entity {name: "cesarean section"})
MERGE (e4:Entity {name: "genital prolapse"})
MERGE (e5:Entity {name: "biopsy"})
MERGE (e6:Entity {name: "epithelioid cell granuloma"})
MERGE (e7:Entity {name: "oral azathioprine"})
MERGE (e8:Entity {name: "antibiotics"})
MERGE (e9:Entity {name: "metronidazole"})
MERGE (e10:Entity {name: "steroids"})
MERGE (e11:Entity {name: "vulvar CD"})
MERGE (e12:Entity {name: "cutaneous tuberculosis"})
MERGE (e13:Entity {name: "sarcoidosis"})
MERGE (e14:Entity {name: "deep fungal infection"})
MERGE (e15:Entity {name: "lymphogranuloma venereum"})
MERGE (e16:Entity {name: "adalimumab"})
MERGE (e17:Entity {name: "azathioprine"})
MERGE (e