In [1]:
# autoreload
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import sys
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from MIMIC_IV_HAIM_API import split_note_document, get_biobert_embeddings

  _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
  _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
  _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)


In [None]:
mm_dir = "/data/wang/junh/datasets/multimodal"
output_dir = os.path.join(mm_dir, "preprocessing")

notes_df = pd.read_pickle(os.path.join(output_dir, "clinic_notes_text.pkl"))
rad_notes_df = pd.read_pickle(os.path.join(output_dir, "notes_text.pkl"))

In [None]:
notes_df.columns

Index(['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq',
       'charttime', 'storetime', 'text', 'stay_id', 'icu_time_delta',
       'hosp_time_delta'],
      dtype='object')

In [None]:
notes_df.rename(columns={"text": "dis_text"}, inplace=True)

In [None]:
rad_notes_df = rad_notes_df.merge(notes_df[['subject_id','hadm_id', 'note_id', 'dis_text']], on=['subject_id','hadm_id'], how="left")

In [None]:
icu_rad_notes_df = rad_notes_df[rad_notes_df['stay_id'].notna()]

In [None]:
print("Number of ICU radiology notes: ", len(icu_rad_notes_df))
print("Number of unique stays: ", len(icu_rad_notes_df['stay_id'].unique()))
print("Missing discharge summaries: ", icu_rad_notes_df['dis_text'].isna().sum())

Number of ICU radiology notes:  282833
Number of unique stays:  56824
Missing discharge summaries:  3996


In [None]:
icu_rad_notes_df.dropna(subset=['dis_text'], inplace=True)

In [None]:
icu_rad_notes_df.columns

Index(['note_id_x', 'subject_id', 'hadm_id', 'note_type', 'note_seq',
       'charttime', 'storetime', 'text', 'stay_id', 'icu_time_delta',
       'hosp_time_delta', 'note_id_y', 'dis_text'],
      dtype='object')

In [None]:
from tqdm import tqdm
# Set batch size (you can tune this based on your GPU memory)
BATCH_SIZE = 16

# Set device to use GPU with ID 1
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(device)

icu_rad_notes_df['biobert_embeddings'] = None
icu_rad_notes_df['longformer_embeddings'] = None

def process_batch(chunk_batch, device, model_name='biobert'):
    # Process all chunks in the batch at once
    curr_embeddings, _ = get_biobert_embeddings(chunk_batch, device, model_name)
    embeddings_batch = curr_embeddings.detach().cpu().numpy()
    return embeddings_batch

# Process in batches
for index_start in tqdm(range(0, icu_rad_notes_df.shape[0], BATCH_SIZE)):
    index_end = min(index_start + BATCH_SIZE, icu_rad_notes_df.shape[0])
    batch_df = icu_rad_notes_df.iloc[index_start:index_end]

    for index, row in batch_df.iterrows():
        curr_subject_id = int(row['subject_id'])
        curr_note_id = row['note_id_x']
        curr_text = row['text']
        curr_dis_text = row['dis_text']

        # Process 'text' column with BioBERT embeddings (with chunking)
        text_chunk_parse, text_chunk_length = split_note_document(curr_text, 15)
        text_embeddings = []
        for chunk_batch_start in range(0, len(text_chunk_parse), BATCH_SIZE):
            chunk_batch_end = min(chunk_batch_start + BATCH_SIZE, len(text_chunk_parse))
            chunk_batch = text_chunk_parse[chunk_batch_start:chunk_batch_end]
            embeddings_batch = process_batch(chunk_batch, device, model_name='biobert')
            text_embeddings.extend(embeddings_batch)
        
        # Process 'dis_text' column with Longformer embeddings (no chunking)
        dis_text_embeddings = process_batch([curr_dis_text], device, model_name='longformer')

        # Store the results in the DataFrame
        icu_rad_notes_df.at[index, 'biobert_embeddings'] = text_embeddings
        icu_rad_notes_df.at[index, 'longformer_embeddings'] = dis_text_embeddings





cuda:5


100%|██████████| 17428/17428 [36:59:38<00:00,  7.64s/it]   


In [None]:
icu_rad_notes_df.to_pickle(os.path.join(output_dir, "DuoAllchunk_icu_notes_text_embeddings.pkl"))