## Graph Information Extraction from clinical notes


In [1]:
import mlx.core as mx
from datasets import load_dataset
from transformers import AutoTokenizer, AutoConfig


dataset = load_dataset(
    'AGBonnet/augmented-clinical-notes',
    # 'starmpcc/Asclepius-Synthetic-Clinical-Notes',
    split='train',
    streaming=True
)

sample = dataset.__iter__().__next__()

def nested_print(key, element, level=0):
    if isinstance(element, dict):
        print(f'{"│ "*(level)}├─{key}:')
        for k, v in element.items():
            nested_print(k, v, level+1)
    else:
        print(f'{"│ "*(level)}├─{key}: {element}')

nested_print('sample', sample)

├─sample:
│ ├─idx: 155216
│ ├─note: A a sixteen year-old girl, presented to our Outpatient department with the complaints of discomfort in the neck and lower back as well as restriction of body movements. She was not able to maintain an erect posture and would tend to fall on either side while standing up from a sitting position. She would keep her head turned to the right and upwards due to the sustained contraction of the neck muscles. There was a sideways bending of the back in the lumbar region. To counter the abnormal positioning of the back and neck, she would keep her limbs in a specific position to allow her body weight to be supported. Due to the restrictions with the body movements at the neck and in the lumbar region, she would require assistance in standing and walking. She would require her parents to help her with daily chores, including all activities of self-care.
She had been experiencing these difficulties for the past four months since when she was introduced to olan

In [2]:
# extract idx and full_note
def extract_idx_full_note(sample):
    idx = sample['idx']
    full_note = sample['full_note']
    # idx = sample['patient_id']
    # full_note = sample['note']
    return {
        'idx': idx,
        'full_note': full_note
    }

# extract idx and full_note from the sample
dataset = dataset.map(
    extract_idx_full_note,
    remove_columns=dataset.column_names,
    batch_size=1000)

dataset = dataset.take(5)

# iterate over the dataset and print idx and full_note
for sample in dataset:
    nested_print(sample['idx'], sample['full_note'])

├─155216: A a sixteen year-old girl, presented to our Outpatient department with the complaints of discomfort in the neck and lower back as well as restriction of body movements. She was not able to maintain an erect posture and would tend to fall on either side while standing up from a sitting position. She would keep her head turned to the right and upwards due to the sustained contraction of the neck muscles. There was a sideways bending of the back in the lumbar region. To counter the abnormal positioning of the back and neck, she would keep her limbs in a specific position to allow her body weight to be supported. Due to the restrictions with the body movements at the neck and in the lumbar region, she would require assistance in standing and walking. She would require her parents to help her with daily chores, including all activities of self-care.\nShe had been experiencing these difficulties for the past four months since when she was introduced to olanzapine tablets for the co

In [3]:
# --- NER Function ---
def perform_ner_with_mlx(text_note, ner_tokenizer, ner_mlx_model, ner_id2label_map):
  """Performs NER on a text using a tokenizer, MLX model, and id2label map."""
  inputs = ner_tokenizer(text_note, return_tensors="np", padding="max_length", truncation=True, max_length=512)

  input_ids_mx = mx.array(inputs["input_ids"])
  attention_mask_mx = mx.array(inputs["attention_mask"])

  # Get logits from the MLX model
  logits = ner_mlx_model(input_ids_mx, attention_mask=attention_mask_mx)

  # Get predicted label IDs
  predictions_mx = mx.argmax(logits, axis=-1)[0] # Predictions for the first (and only) item in batch
  predicted_ids = predictions_mx.tolist()

  tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

  entities = []
  current_entity_tokens = []
  current_entity_label_type = None

  for token, pred_id in zip(tokens, predicted_ids):
    # Skip special tokens
    if token in [ner_tokenizer.cls_token, ner_tokenizer.sep_token, ner_tokenizer.pad_token]:
      continue

    label_str = ner_id2label_map.get(pred_id, "O") # Default to "O" if ID not in map

    if label_str.startswith("B-"): # Beginning of a new entity
      if current_entity_tokens: # If there's an existing entity, save it
        entities.append({
          "text": ner_tokenizer.convert_tokens_to_string(current_entity_tokens),
          "label": current_entity_label_type
        })
      current_entity_tokens = [token]
      current_entity_label_type = label_str[2:] # Get label type (e.g., "PROBLEM")
    elif label_str.startswith("I-") and current_entity_label_type == label_str[2:]:
      # Inside an entity of the same type
      current_entity_tokens.append(token)
    else: # Outside an entity (O) or a different entity type
      if current_entity_tokens: # If there's an existing entity, save it
        entities.append({
          "text": ner_tokenizer.convert_tokens_to_string(current_entity_tokens),
          "label": current_entity_label_type
        })
      current_entity_tokens = []
      current_entity_label_type = None

  if current_entity_tokens: # Append any last entity
    entities.append({
      "text": ner_tokenizer.convert_tokens_to_string(current_entity_tokens),
      "label": current_entity_label_type
    })

  return entities

In [4]:
TOKENIZER_AND_CONFIG_SOURCE = "../models/mlx_biomedical-ner"

try:
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_AND_CONFIG_SOURCE)
  config = AutoConfig.from_pretrained(TOKENIZER_AND_CONFIG_SOURCE)
  id2label = config.id2label # {int_id: "LABEL_NAME"}

  def mlx_model_bert(input_ids_mx_array, attention_mask):
      batch_s, seq_l = input_ids_mx_array.shape
      num_l = len(id2label)
      return mx.random.uniform(low=0.0, high=1.0,
                             shape=(batch_s, seq_l, num_l))

  mlx_model = mlx_model_bert
  print(f"Using tokenizer and model from '{TOKENIZER_AND_CONFIG_SOURCE}'")
  print(f"Loaded id2label map with {len(id2label)} labels. First few: {list(id2label.items())[:20]}")
except Exception as e:
  print(f"Error loading tokenizer or config: {e}")
  print("Ensure you have the correct path and files for the tokenizer and config.")
  raise

Using tokenizer and model from '../models/mlx_biomedical-ner'
Loaded id2label map with 84 labels. First few: [(0, 'O'), (1, 'B-Activity'), (2, 'B-Administration'), (3, 'B-Age'), (4, 'B-Area'), (5, 'B-Biological_attribute'), (6, 'B-Biological_structure'), (7, 'B-Clinical_event'), (8, 'B-Color'), (9, 'B-Coreference'), (10, 'B-Date'), (11, 'B-Detailed_description'), (12, 'B-Diagnostic_procedure'), (13, 'B-Disease_disorder'), (14, 'B-Distance'), (15, 'B-Dosage'), (16, 'B-Duration'), (17, 'B-Family_history'), (18, 'B-Frequency'), (19, 'B-Height')]


In [5]:
for sample_item in dataset:
  note_id = sample_item['idx']
  full_note_content = sample_item['full_note']

  if not full_note_content or not full_note_content.strip():
    print(f"Note {note_id}: Skipped (empty).")
    continue

  print(f"\n--- Note ID: {note_id} ---")
  print(f"  Full Note Content: {full_note_content[:100]}...") # Display first 100 chars for brevity
  # Perform NER
  extracted_entities = perform_ner_with_mlx(full_note_content, tokenizer, mlx_model, id2label)

  # For validation, dump note_id, full_note_content, and extracted entities to a file
  with open("data/ner_output.txt", "a") as f:
    f.write(f"Note ID: {note_id}\n")
    f.write(f"Full Note Content: {full_note_content}\n")
    f.write(f"Extracted Entities: {extracted_entities}\n")
    f.write("\n" + "="*50 + "\n")

  # if extracted_entities:
  #   print(f"  Found {len(extracted_entities)} entities:")
  #   for entity_info in extracted_entities[:5]: # Display first 5 entities for brevity
  #     print(f"    Text: \"{entity_info['text']}\", Label: {entity_info['label']}")
  #   if len(extracted_entities) > 5:
  #     print("    ... (more entities found)")
  # else:
  #   print("  No entities found.")

  # print("\nNER processing with MLX finished.")


--- Note ID: 155216 ---
  Full Note Content: A a sixteen year-old girl, presented to our Outpatient department with the complaints of discomfort ...

--- Note ID: 77465 ---
  Full Note Content: This is the case of a 56-year-old man that was complaining of a dump pain on the right back and a sw...

--- Note ID: 133948 ---
  Full Note Content: A 36-year old female patient visited our hospital with a chief complaint of pain and restricted rang...

--- Note ID: 80176 ---
  Full Note Content: A 49-year-old male presented with a complaint of pain in the left proximal forearm after a fall. The...

--- Note ID: 72232 ---
  Full Note Content: A 47-year-old male patient was referred to the rheumatology clinic because of recurrent attacks of p...
