## Graph Information Extraction from clinical notes


In [1]:
# Verify environment after fixing
import numpy as np
import transformers
import datasets
import os
import random
import traceback

print(f"NumPy version: {np.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Datasets version: {datasets.__version__}")

# Also check if packages can be imported without errors
try:
    # Try to import some commonly used modules
    import pandas as pd
    from gliner import GLiNER
    # from relik import Relik

    print("\nAll required packages imported successfully!")
    print("The NumPy compatibility issue appears to be resolved.")

    # Check if the data directory exists
    data_dir = "src/data"
    if os.path.exists(data_dir):
        print(f"\nData directory '{data_dir}' exists.")
        # List any files in the directory
        files = os.listdir(data_dir)
        if files:
            print(f"Files in directory: {', '.join(files)}")
        else:
            print("No files in directory yet.")
    else:
        print(f"\nData directory '{data_dir}' does not exist yet. It will be created when running the extraction.")
    
except Exception as e:
    print(f"\nError importing packages: {e}")
    print("You may need to reinstall additional packages or restart the kernel.")

NumPy version: 2.2.5
Transformers version: 4.51.3
Datasets version: 3.6.0

All required packages imported successfully!
The NumPy compatibility issue appears to be resolved.

Data directory 'src/data' does not exist yet. It will be created when running the extraction.


In [2]:
from datasets import load_dataset
# from relik.inference.data.objects import RelikOutput
# from relik.retriever.indexers.document import Document


dataset = load_dataset(
    'AGBonnet/augmented-clinical-notes',
    split='train',
)

# sample = dataset.__iter__().__next__()

def nested_print(key, element, level=0):
    if isinstance(element, dict):
        print(f'{"│ "*(level)}├─{key}:')
        for k, v in element.items():
            nested_print(k, v, level+1)
    else:
        print(f'{"│ "*(level)}├─{key}: {element}')

# nested_print('sample', sample)

# Print the first 5 elements of the dataset using nested_print
for i, sample in enumerate(dataset):
    if i < 5:
        print(f"\nSample {i+1}:")
        nested_print('sample', sample)
    else:
        break


Sample 1:
├─sample:
│ ├─idx: 155216
│ ├─note: A a sixteen year-old girl, presented to our Outpatient department with the complaints of discomfort in the neck and lower back as well as restriction of body movements. She was not able to maintain an erect posture and would tend to fall on either side while standing up from a sitting position. She would keep her head turned to the right and upwards due to the sustained contraction of the neck muscles. There was a sideways bending of the back in the lumbar region. To counter the abnormal positioning of the back and neck, she would keep her limbs in a specific position to allow her body weight to be supported. Due to the restrictions with the body movements at the neck and in the lumbar region, she would require assistance in standing and walking. She would require her parents to help her with daily chores, including all activities of self-care.
She had been experiencing these difficulties for the past four months since when she was introdu

In [3]:
# extract idx and full_note
def extract_idx_full_note(sample):
    idx = sample['idx']
    full_note = sample['full_note']
    return {
        'idx': idx,
        'full_note': full_note
    }

# extract idx and full_note from the sample
dataset = dataset.map(
    extract_idx_full_note,
    remove_columns=dataset.column_names,
    batch_size=1000)

# Shuffle the dataset
random_seed = random.randint(0, 1000)
print(f"\nShuffling dataset with random seed: {random_seed}")
## Favourite seeds # 883,
dataset = dataset.shuffle(seed=random_seed)

# Print the first 5 elements of the dataset using nested_print
for i, sample in enumerate(dataset):
    if i < 5:
        print(f"\nSample {i+1}:")
        nested_print('sample', sample)
    else:
        break


Shuffling dataset with random seed: 766

Sample 1:
├─sample:
│ ├─idx: 156806
│ ├─full_note: A 28-year-old male working as a tractor driver was brought to the Emergency surgical ward by his wife with complaints of cutting off at the left side at the base of his scrotum with a kitchen knife.\nHis history revealed that patient has been taking arrack for the last 6 years. He started with 200 ml of arrack and slowly increased the amount to 1l over a period of 6 years as he was not getting the desired effect he used to get with 200 ml. On occasions, when he did not have the money or opportunity to drink, he used to experience dysphoric state including severe anxiety, palpitations sweating, restlessness, and tremors of hands. These symptoms used to subside upon taking alcohol.\nHe also did not have any control over the amount of alcohol he is taking and the money he is spending on it once he started to drink. He also neglected his duties both at work and at home due to this habit, for which 

In [4]:
# Model
gliner_model = GLiNER.from_pretrained("Ihor/gliner-biomed-bi-large-v1.0")

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

In [25]:
from mlx_lm import load, generate

# Compare performance of medgemma-4b-it with base gemma-4b-it model
model_name = "google/medgemma-4b-it"
# model_name = "google/gemma-3-4b-it"
# model_name = "Qwen/Qwen3-4B"
# model_name = "Qwen/Qwen3-30B-A3B"
medgemma_model, tokenizer = load(model_name)

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

In [26]:
# --- NER Function ---
def perform_ner(text_note):
    """Performs NER on a text using, gliner model, and label list.
        https://huggingface.co/Ihor/gliner-biomed-bi-large-v1.0
    """

    # Labels for the GLiNER model
    labels = ["Disease or Condition", "Medication", "Medication Dosage and Frequency",
              "Procedure", "Lab Test", "Lab Test Result", "Body Site",
              "Medical Device", "Demographic Information"]

    entities = gliner_model.predict_entities(
        text_note,
        labels=labels,
        threshold=0.5,
    )

    return entities

In [27]:
import wikipedia

def fetch_entity_descriptions(entity):
    """Fetch description of entity['text'] from Wikipedia knowledge base

    Args:
        entity (dict): Entity dictionary containing 'text' and other fields

    Returns:
        dict or None: Returns the entity with added description if found, None otherwise
    """
    try:
        
        # Get entity text
        query = entity['text']
        
        # Try to get a summary from Wikipedia
        try:
            # Search for the page
            search_results = wikipedia.search(query, results=1)
            if search_results:
                # Get the page summary
                page = wikipedia.page(search_results[0], auto_suggest=False)
                description = page.summary[0:200] + "..." if len(page.summary) > 200 else page.summary

                entity['description'] = description
                # Return the entity with description and page URL
                return entity
        except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError):
            # Entity not found or ambiguous
            pass

        return None  # Will be filtered out
    except Exception as e:
        print(f"Error fetching description for {entity['text']}: {str(e)}")
        return None



In [28]:
import re
from typing import List, Tuple

# pre-compiled regex:  ("something", "something", "something")
TUPLE_RX = re.compile(
    r'''\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"\s*\)''')

def extract_triples(raw: str) -> List[Tuple[str, str, str]]:
    """
    Return all well-formed 3-item tuples found in *raw*.
    Truncated / malformed lines are silently skipped.
    """
    return [match for match in TUPLE_RX.findall(raw)]

In [29]:
# Function to extract triples from a text note
def triplet_CIE(
    full_note_content: str,
    extracted_entities: List[dict],
    max_length: int = 512
) -> List[Tuple[str, str, str]]:
    """
    Extract triples from a text note using MedGemma model.

    Args:
        full_note_content (str): The full clinical note content.
        extracted_entities (List[dict]): List of entities extracted from the note.
        max_length (int): Maximum length for the model generation.
    Returns:
        List[Tuple[str, str, str]]: List of extracted triplets in the form (entity1, relation, entity2).
    """

    relationship_extraction_prompt = f""" Your goal is to perform a Closed Information Extraction task on the following clinical note {full_note_content}
    You are provided with a list of entities extracted from the note in the form of JSON objects. {extracted_entities}, each object contains the following fields:
    - "text": The text of the entity
    - "label": The label that can be used to identify the entity type
    - "description": A short description of the entity (if available) that can be used to understand the entity better
    Your task is to reason about the full clinical note and the list of entities provided, and generate high quality triplets of the form (entity1, relation, entity2)
    where the relationship is explicitly stated or strongly implied in the clinical note and the entities are from the provided list.
    The triplets should be concise and relevant to the clinical note.
    You should not generate any triplets that are not relevant to the clinical note or the entities provided.
    The triplets should be unique and not contain any duplicates.
    Please return the triplets in the following format:
    [
      (entity1, relation, entity2),
      (entity3, relation, entity4),
      ...
    ]
    """
    messages = [
        {
          "role": "system",
          "content": "You are an expert clinical information extraction system. Your task is to analyze a clinical note and a pre-compiled list of named entities to extract meaningful, factual relationships."
        },
        {
          "role": "user",
          "content": relationship_extraction_prompt
        }]
    inputs = tokenizer.apply_chat_template(
      messages, add_generation_prompt=True
    )

    # Generate text with MLX model
    triplet_str = generate(
      medgemma_model,
      tokenizer,
      prompt=inputs,
      verbose=False,
      max_tokens=max_length,
    )

    triples_list = extract_triples(triplet_str)
    
    return triples_list

In [30]:
def get_unique_entities(triples_list):
  """
  Given a list of triples (entity1, relation, entity2), return a list of unique entities.
  """
  entities = set()
  for entity1, _, entity2 in triples_list:
    entities.add(entity1)
    entities.add(entity2)
  return list(entities)

In [31]:
def generate_merge_query(entities):
    """Generates a Cypher MERGE query for the given list of entities.

    Args:
        entities: A list of entities in the form [entity1_name, entity2_name, ...].

    Returns:
        A string containing the Cypher MERGE query.
    """
    query = "MERGE (e0:Entity {name: \"" + entities[0] + "\"})\n"
    for i in range(1, len(entities)):
        query += "MERGE (e" + str(i+1) + ":Entity {name: \"" + entities[i] + "\"})\n"
    return query

In [32]:
def generate_merge_relationships(triples_list, merge_entity_queries):
  """
  Generate Cypher MERGE statements for relationships between entities, using entity variables
  as assigned in merge_entity_queries. Only relationships between entities present in merge_entity_queries
  are included. Duplicate relationships are avoided.

  Args:
    triples_list (List[Tuple[str, str, str]]): List of (entity1, relation, entity2) triples.
    merge_entity_queries (str): Cypher MERGE statements for entities, e.g. 'MERGE (e0:Entity {name: "foo"})\nMERGE (e1:Entity {name: "bar"})'

  Returns:
    str: Cypher MERGE statements for relationships.
  """

  # Parse entity name to variable mapping from merge_entity_queries
  entity_var_map = {}
  for match in re.finditer(r'MERGE\s*\((e\d+):Entity\s*\{\s*name:\s*"((?:[^"\\]|\\.)*)"\s*\}\)', merge_entity_queries):
    var, name = match.group(1), match.group(2)
    entity_var_map[name] = var

  # Helper to escape double quotes in relation type
  def escape_quotes(s):
    return s.replace('"', '\\"')

  # Helper to convert relation string to lowercase, replace spaces with underscores, and escape quotes
  def format_relation(relation):
    return escape_quotes(relation.lower().replace(" ", "_"))

  seen = set()
  cypher_lines = []
  for entity1, relation, entity2 in triples_list:
    var1 = entity_var_map.get(entity1)
    var2 = entity_var_map.get(entity2)
    if var1 and var2:
      key = (var1, format_relation(relation), var2)
      if key not in seen:
        cypher_lines.append(
          f'MERGE ({var1})-[:RELATIONSHIP {{type: "{format_relation(relation)}"}}]->({var2})'
        )
        seen.add(key)
  return "\n".join(cypher_lines)

In [33]:
# Now process NER in parallel
def run_parallel(sample_item):
  try:
    index, sample_item = sample_item  # Unpack the tuple
    note_id = sample_item["idx"]
    full_note_content = sample_item["full_note"]
    extracted_entities = perform_ner(full_note_content)

    # Remove duplicate entities
    extracted_entities = list({entity['text']: entity for entity in extracted_entities}.values())
    print(f"Extracted {len(extracted_entities)} entities from note {note_id}.")

    # Entity Linking in parallel as before
    extracted_entities = list(filter(fetch_entity_descriptions, extracted_entities))

    # Filter only text and description
    extracted_entities = list(map(lambda x: { "text": x['text'], "label": x['label'], "description": x.get('description', '')}, extracted_entities))

    print(f"Filtered {len(extracted_entities)} entities with label and description from note {note_id}.")

    triples_list = triplet_CIE(
      full_note_content=full_note_content,
      extracted_entities=extracted_entities,
    )
    print(f"Extracted {len(triples_list)} triplets from note {note_id}.")

    triple_entities = get_unique_entities(triples_list)
    print(f"Extracted {len(triple_entities)} unique entities from triplets in note {note_id}.")
    # Generate Cypher query for MERGE syntax
    cypher_query = generate_merge_query(triple_entities)
    print(f"Generated Cypher query for note {note_id}: {cypher_query}")

    if not cypher_query:
      raise ValueError(f"Generated Cypher query is empty for note {note_id}.")

    # Generate Cypher query for relationships
    cypher_relationship_query = generate_merge_relationships(triples_list, cypher_query)
    print(f"Generated Cypher relationship query for note {note_id}: {cypher_relationship_query}")

    # Return the results
    return {
      "note_id": note_id,
      "entities": extracted_entities,
      "content": full_note_content,
      "triplets": triples_list,
      "cypher_query": cypher_query,
      "cypher_relationship_query": cypher_relationship_query
    }
  except Exception as e:
    return {"note_id": note_id, "error": str(e), "traceback": traceback.format_exc()}


In [34]:
import concurrent.futures

file_path = "data/ner_output.txt"

# Clear previous contents if needed
with open(file_path, "w") as f:
  f.write("Starting extraction process\n")
  f.write("="*50 + "\n\n")

# Prepare the items to process

# Convert dataset to list for processing
items_to_process = list(enumerate(dataset))

# Pick the first item for testing
items_to_process = items_to_process[:1] # Increasing length of this list requires more compute resources

print(f"Total items to process: {len(items_to_process)}")
# Execute NER in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
  ner_results = list(executor.map(run_parallel, items_to_process))

# Write results to file
with open(file_path, "a") as f:
    for result in ner_results:
        if "error" in result:
            f.write(f"Error processing note {result['note_id']}: {result['error']}\n")
            f.write(result['traceback'] + "\n")
        else:
            f.write(f"Note ID: {result['note_id']} and SEED: {random_seed}\n")
            f.write("\nContent:\n")
            f.write(result['content'] + "\n\n")
            f.write("Entities:\n")
            for entity in result['entities']:
              f.write(f"  - Text: {entity['text']}\n")
              f.write(f"    Label: {entity['label']}\n")
              if entity.get('description'):
                f.write(f"    Description: {entity['description'][:100]}\n")
            f.write("Triplets:\n")
            for triplet in result['triplets']:
              f.write(f"  - ({triplet[0]}, {triplet[1]}, {triplet[2]})\n")
            f.write("\nCypher Query:\n")
            f.write(result['cypher_query'] + "\n\n")
            f.write("Cypher Relationship Query:\n")
            f.write(result['cypher_relationship_query'] + "\n")
            f.write("="*50 + "\n")

print("Processing complete. Check src/data/ner_output.txt for results.")

Total items to process: 1
Extracted 17 entities from note 156806.
Filtered 17 entities with label and description from note 156806.
Extracted 27 triplets from note 156806.
Extracted 7 unique entities from triplets in note 156806.
Generated Cypher query for note 156806: MERGE (e0:Entity {name: "left side"})
MERGE (e2:Entity {name: "delirium"})
MERGE (e3:Entity {name: "Mental and behavioral disorders"})
MERGE (e4:Entity {name: "alcohol withdrawal state"})
MERGE (e5:Entity {name: "arrack"})
MERGE (e6:Entity {name: "scrotum"})
MERGE (e7:Entity {name: "kitchen knife"})

Generated Cypher relationship query for note 156806: MERGE (e0)-[:RELATIONSHIP {type: "location_of_injury"}]->(e6)
MERGE (e7)-[:RELATIONSHIP {type: "instrument_used_in_injury"}]->(e6)
MERGE (e5)-[:RELATIONSHIP {type: "substance_causing_alcohol_withdrawal"}]->(e4)
MERGE (e5)-[:RELATIONSHIP {type: "substance_causing_alcohol_withdrawal"}]->(e2)
MERGE (e5)-[:RELATIONSHIP {type: "substance_causing_alcohol_withdrawal"}]->(e3)
Proc