## Graph Information Extraction from clinical notes


In [1]:
# Verify environment after fixing
import numpy as np
import transformers
import datasets
import os
import traceback

print(f"NumPy version: {np.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Datasets version: {datasets.__version__}")

# Also check if packages can be imported without errors
try:
    # Try to import some commonly used modules
    import pandas as pd
    from gliner import GLiNER
    # from relik import Relik

    print("\nAll required packages imported successfully!")
    print("The NumPy compatibility issue appears to be resolved.")

    # Check if the data directory exists
    data_dir = "src/data"
    if os.path.exists(data_dir):
        print(f"\nData directory '{data_dir}' exists.")
        # List any files in the directory
        files = os.listdir(data_dir)
        if files:
            print(f"Files in directory: {', '.join(files)}")
        else:
            print("No files in directory yet.")
    else:
        print(f"\nData directory '{data_dir}' does not exist yet. It will be created when running the extraction.")
    
except Exception as e:
    print(f"\nError importing packages: {e}")
    print("You may need to reinstall additional packages or restart the kernel.")

NumPy version: 2.2.5
Transformers version: 4.51.3
Datasets version: 3.6.0

All required packages imported successfully!
The NumPy compatibility issue appears to be resolved.

Data directory 'src/data' does not exist yet. It will be created when running the extraction.


In [2]:
from datasets import load_dataset
# from relik.inference.data.objects import RelikOutput
# from relik.retriever.indexers.document import Document


dataset = load_dataset(
    'AGBonnet/augmented-clinical-notes',
    split='train',
)

# sample = dataset.__iter__().__next__()

def nested_print(key, element, level=0):
    if isinstance(element, dict):
        print(f'{"│ "*(level)}├─{key}:')
        for k, v in element.items():
            nested_print(k, v, level+1)
    else:
        print(f'{"│ "*(level)}├─{key}: {element}')

# nested_print('sample', sample)

# Print the first 5 elements of the dataset using nested_print
for i, sample in enumerate(dataset):
    if i < 5:
        print(f"\nSample {i+1}:")
        nested_print('sample', sample)
    else:
        break


Sample 1:
├─sample:
│ ├─idx: 155216
│ ├─note: A a sixteen year-old girl, presented to our Outpatient department with the complaints of discomfort in the neck and lower back as well as restriction of body movements. She was not able to maintain an erect posture and would tend to fall on either side while standing up from a sitting position. She would keep her head turned to the right and upwards due to the sustained contraction of the neck muscles. There was a sideways bending of the back in the lumbar region. To counter the abnormal positioning of the back and neck, she would keep her limbs in a specific position to allow her body weight to be supported. Due to the restrictions with the body movements at the neck and in the lumbar region, she would require assistance in standing and walking. She would require her parents to help her with daily chores, including all activities of self-care.
She had been experiencing these difficulties for the past four months since when she was introdu

In [3]:
# extract idx and full_note
def extract_idx_full_note(sample):
    idx = sample['idx']
    full_note = sample['full_note']
    return {
        'idx': idx,
        'full_note': full_note
    }

# extract idx and full_note from the sample
dataset = dataset.map(
    extract_idx_full_note,
    remove_columns=dataset.column_names,
    batch_size=1000)

# Shuffle the dataset
dataset = dataset.shuffle(seed=42)

# Print the first 5 elements of the dataset using nested_print
for i, sample in enumerate(dataset):
    if i < 5:
        print(f"\nSample {i+1}:")
        nested_print('sample', sample)
    else:
        break


Sample 1:
├─sample:
│ ├─idx: 5107
│ ├─full_note: A 35-year-old gravida 6, para 5 mother who is 38-week pregnant from last normal menstrual period has presented to Tercha General Hospital (a rural hospital in Southern Ethiopia). The patient is referred from a health center 60 kms far from this hospital for suspected “big baby” in labor. The patient was an illiterate housewife. In terms of past obstetrics history, all previous deliveries occurred at home vaginally with live birth with no major complication. During the index pregnancy, she had antenatal care visits at a nearby health center without ultrasound examination. She reports that the current pregnancy is heavier than previous ones and associated with significant discomfort than her previous pregnancy experiences. Otherwise, she has no self or family history of twinning in the past.\nExamination shows a stable gravida with normal vital signs. Abdominal examination shows big for date uterus with two cephalic poles in the lower abd

In [4]:
# Model
gliner_model = GLiNER.from_pretrained("Ihor/gliner-biomed-bi-large-v1.0")

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

In [5]:
from mlx_lm import load, generate

model_name = "google/medgemma-4b-it"
medgemma_model, tokenizer = load(model_name)

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

In [6]:
# --- NER Function ---
def perform_ner(text_note):
    """Performs NER on a text using, gliner model, and label list.
        https://huggingface.co/Ihor/gliner-biomed-bi-large-v1.0
    """

    # Labels for the GLiNER model
    labels = ["Disease or Condition", "Medication", "Medication Dosage and Frequency",
              "Procedure", "Lab Test", "Lab Test Result", "Body Site",
              "Medical Device", "Demographic Information"]

    entities = gliner_model.predict_entities(
        text_note,
        labels=labels,
        threshold=0.5,
    )

    return entities

In [7]:
import wikipedia

def fetch_entity_descriptions(entity):
    """Fetch description of entity['text'] from Wikipedia knowledge base
    
    Args:
        entity (dict): Entity dictionary containing 'text' and other fields

    Returns:
        dict or None: Returns the entity with added description if found, None otherwise
    """
    try:
        
        # Get entity text
        query = entity['text']
        
        # Try to get a summary from Wikipedia
        try:
            # Search for the page
            search_results = wikipedia.search(query, results=1)
            if search_results:
                # Get the page summary
                page = wikipedia.page(search_results[0], auto_suggest=False)
                description = page.summary[0:200] + "..." if len(page.summary) > 200 else page.summary

                entity['description'] = description
                # Return the entity with description and page URL
                return entity
        except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError):
            # Entity not found or ambiguous
            pass

        return None  # Will be filtered out
    except Exception as e:
        print(f"Error fetching description for {entity['text']}: {str(e)}")
        return None



In [8]:
import re
from typing import List, Tuple

# pre-compiled regex:  ("something", "something", "something")
TUPLE_RX = re.compile(
    r'''\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"\s*\)''')

def extract_triples(raw: str) -> List[Tuple[str, str, str]]:
    """
    Return all well-formed 3-item tuples found in *raw*.
    Truncated / malformed lines are silently skipped.
    """
    return [match for match in TUPLE_RX.findall(raw)]

In [9]:
# Now process NER in parallel
def run_parallel(sample_item):
  try:
    index, sample_item = sample_item  # Unpack the tuple
    note_id = sample_item["idx"]
    full_note_content = sample_item["full_note"]
    extracted_entities = perform_ner(full_note_content)

    # Remove duplicate entities
    extracted_entities = list({entity['text']: entity for entity in extracted_entities}.values())
    print(f"Extracted {len(extracted_entities)} entities from note {note_id}.")

    # Entity Linking in parallel as before
    extracted_entities = list(filter(fetch_entity_descriptions, extracted_entities))

    # Filter only text and description
    extracted_entities = list(map(lambda x: { "text": x['text'], "label": x['label'], "description": x.get('description', '')}, extracted_entities))

    # Continue with the rest of the processing (LLM generation, etc.)
    relationship_extraction_prompt = f""" I would like you to perform a Closed Information Extraction task on the following clinical note {full_note_content}
    You are provided with a list of entities extracted from the note in the form of JSON objects. {extracted_entities}
    Your task is to reason about the full clinical note and the entities provided, and generate at most 20 unique and high quality triplets
    of the form (entity1, relation, entity2) that represent the relationships between the entities in the full clinical note.
    Please return the triplets in the following format:
    [
      (entity1, relation, entity2),
      (entity3, relation, entity4),
      ...
    ]

    Note the response provides should be valid JSON Syntax and should contain no more than 20 triplets.
    The triplets should be relevant to the clinical note and the entities provided.
    The triplets should be unique and not contain any duplicates.
    The triplets should be concise and not contain any unnecessary information.
    The triplets list must end with a closing square bracket.
    """
    messages = [
        {
          "role": "system",
          "content": "You are a helpful medical assistant."
        },
        {
          "role": "user",
          "content": relationship_extraction_prompt
        }]
    inputs = tokenizer.apply_chat_template(
      messages, add_generation_prompt=True
    )

    # Generate text with MLX model
    triplet_str = generate(
      medgemma_model,
      tokenizer,
      prompt=inputs,
      verbose=False,
      max_tokens=512,
    )

    triplet_list = extract_triples(triplet_str)


    cypher_merge_prompt = f"""You are a Cypher query generator.
    You are provided with the following list of triplets {triplet_list} in the form of (entity1, relation, entity2) that represent the relationships between entities in a clinical note.
    Your task is to generate a Cypher query that merges the entities and relationships in the triplets.
    Please return the Cypher query in the following format:
    MERGE (entity1:Entity {{name: "entity1_name"}})
    MERGE (entity2:Entity {{name: "entity2_name"}})
    MERGE (entity1)-[:RELATIONSHIP {{type: "relation_type"}}]->(entity2)
    """
    cypher_messages = [
        {
          "role": "system",
          "content": "You are a Cypher query generator."
        },
        {
          "role": "user",
          "content": cypher_merge_prompt
        }]
    cypher_inputs = tokenizer.apply_chat_template(
      cypher_messages, add_generation_prompt=True
    )

    # Generate Cypher query with MLX model
    cypher_query = generate(
      medgemma_model,
      tokenizer,
      prompt=cypher_inputs,
      verbose=False,
      max_tokens=512,
    )
    # Return the note_id, extracted entities, full note content, and triplet list
    # print(f"Generated {len(triplet_list)} triplets for note {note_id}.")

    return {
      "note_id": note_id,
      "entities": extracted_entities,
      "content": full_note_content,
      "triplets": triplet_list,
      "cypher_query": cypher_query
    }
  except Exception as e:
    return {"note_id": note_id, "error": str(e), "traceback": traceback.format_exc()}


In [10]:
import concurrent.futures

file_path = "data/ner_output.txt"

# Clear previous contents if needed
with open(file_path, "w") as f:
  f.write("Starting extraction process\n")
  f.write("="*50 + "\n\n")

# Prepare the items to process

# Convert dataset to list for processing
items_to_process = list(enumerate(dataset))

# Pick the first 3 items for demonstration
items_to_process = items_to_process[:1] # Increasing length of this list requires more compute resources

print(f"Total items to process: {len(items_to_process)}")
# Execute NER in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
  ner_results = list(executor.map(run_parallel, items_to_process))

# Write results to file
with open(file_path, "a") as f:
    for result in ner_results:
        if "error" in result:
            f.write(f"Error processing note {result['note_id']}: {result['error']}\n")
            f.write(result['traceback'] + "\n")
        else:
            f.write(f"Note ID: {result['note_id']}\n")
            f.write(f"Entities: {result['entities']}\n")
            f.write(f"Content: {result['content']}\n")
            f.write(f"Triplets: {result['triplets']}\n")
            f.write(f"Cypher Query: {result['cypher_query']}\n")
            f.write("="*50 + "\n")

print("Processing complete. Check src/data/ner_output.txt for results.")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Total items to process: 1
Extracted 21 entities from note 5107.
Processing complete. Check src/data/ner_output.txt for results.
