# <u>Overview</u>

The goal of this project is to combine data from a few different tables in the MIMIC-III dataset to train an XGBoost Classifier to predict if a patient is likely to die based only on the data available within 24 hours of their time of admission. Such a classifier can help caregivers triage their patients, allocating vital resources to those with the greatest chance at survival while identifying cases where palliative care is more appropriate.

While a final production model could certainly include a much wider variety of features, the data considered for this proof-of-concept are limited to the ADMISSIONS, PATIENTS, NOTEEVENTS, and PRESCRIPTIONS tables of the MIMIC-III dataset.

From the ADMISSIONS table, we include the admission location, insurance type, ethnicity, and the ground truth label of an in-hospital death. We also merge the gender and age of the patient from the PATIENTS table to these data.

The NOTEEVENTS table includes a wide variety of notes writen by nurses, physicians, radiologists, etc for each hospital admission. We use ClinicalBERT to tokenize and embed these notes before performing PCA on these embeddings and feeding these features into the XGBoost model.

Similarly, we take the names of the medications started within 24 hours of admission from the PRESCRIPTIONS table, tokenize, embed, and use PCA to add these data as features for the XGBoost model as well. 

Finally, we combine all features for each admission and tune an XGBoost model with RandomizedSearchCV to achieve 0.79 PR AUC and 0.85 precision on the true (death) labels. Since life-saving resources may be diverted away from patients whom this model predicts are most likely to die in-hospital, precision on the true class will determine the correctness of specifically these predictions and can be considered the most important metric for this model.

Note: Throughout the notebook, you will see many lines commented-out. This is for the purpose of hiding any cell outputs that would otherwise contain sensitive patient data. Please un-comment these lines before running these cells to view the output when connected to your own MIMIC-III data source.

# <u>Setup</u>

In [3]:
import pandas as pd
import numpy as np
import os

from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, average_precision_score
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from scipy.stats import uniform, randint

from transformers import AutoTokenizer, AutoModel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Replace with your own path to the MIMIC-III dataset
mimic_iii_path = '/mnt/2TB-HDD-Ubuntu/GitHub-Repositories/UT Austin/2025 Spring/AI-Healthcare/MIMIC-III-Dataset/mimic-iii-clinical-database-1.4'

# <u>Import data</u>

In [None]:
# This function will be used to filter out rows from NOTEEVENTS and PRESCRIPTIONS that are timestamped after 24 hours of admission
def filter_df_24h(admissions_df, data_df, log_dt_col:str) -> pd.DataFrame:
    merged_df = pd.merge(
        data_df,
        admissions_df[['HADM_ID', 'ADMITTIME', 'DEATH']],
        on='HADM_ID',
        how='inner'
    )

    # Calculate time difference between CHARTTIME and ADMITTIME
    merged_df['time_diff_hours'] = (merged_df[log_dt_col] - merged_df['ADMITTIME']).dt.total_seconds() / 3600

    # Filter rows where CHARTTIME is within 24 hours after ADMITTIME
    merged_24h_df = merged_df[merged_df['time_diff_hours'] <= 24].copy()

    # Drop the temporary column we created
    merged_24h_df = merged_24h_df.drop(['time_diff_hours'], axis=1)

    return merged_24h_df

## Admissions

In [6]:
admissions_df = pd.read_csv(os.path.join(mimic_iii_path, 'ADMISSIONS.csv'))
# admissions_df # commented out for privacy

Checking the ADMISSION_TYPE values reveals many NEWBORN and ELECTIVE admissions that are not likely relevant for triage. While we filter down admissions_df to include only the helpful columns, we will also drop any rows that are not EMERGENCY or URGENT types. We will also define our gold label DEATH column as a boolean and ensure that the ADMITTIME and DEATHTIME columns are datetime types.

Note that we do not keep the DIAGNOSIS column, although it would be relevant, as there is no way to determine at what time each diagnosis was made, and we therefore could not possibly filter these to the first 24 hours of each admission.

In [7]:
admissions_df['ADMISSION_TYPE'].value_counts()

ADMISSION_TYPE
EMERGENCY    42071
NEWBORN       7863
ELECTIVE      7706
URGENT        1336
Name: count, dtype: int64

In [8]:
admissions_cols = [
    'SUBJECT_ID',
    'HADM_ID',
    'ADMITTIME',
    'DEATHTIME',
    'ADMISSION_LOCATION',
    'INSURANCE',
    'ETHNICITY'
]

emergency_urgent = [
    'EMERGENCY',
    'URGENT'
]

admissions_df_filtered = admissions_df[admissions_df['ADMISSION_TYPE'].isin(emergency_urgent)][admissions_cols].reset_index(drop=True)
admissions_df_filtered['DEATH'] = ~admissions_df_filtered['DEATHTIME'].isna()

admissions_df_filtered['ADMITTIME'] = pd.to_datetime(admissions_df_filtered['ADMITTIME'])
admissions_df_filtered['DEATHTIME'] = pd.to_datetime(admissions_df_filtered['DEATHTIME'])

# admissions_df_filtered # commented out for privacy

## Note Events

In [9]:
noteevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'NOTEEVENTS.csv'))
# noteevents_df # commented out for privacy

  noteevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'NOTEEVENTS.csv'))


First we display a few samples of the notes, and we see a wide variety of formats and information included. This can also be seen by inspecting the CATEGORY values.

In [10]:
# commented out for privacy
# for note in noteevents_df['TEXT'].sample(10, random_state=42):
#     print(note)

In [11]:
noteevents_df['CATEGORY'].value_counts()

CATEGORY
Nursing/other        822497
Radiology            522279
Nursing              223556
ECG                  209051
Physician            141624
Discharge summary     59652
Echo                  45794
Respiratory           31739
Nutrition              9418
General                8301
Rehab Services         5431
Social Work            2670
Case Management         967
Pharmacy                103
Consult                  98
Name: count, dtype: int64

Next we set the type of the datetime columns and remove all rows where ISERROR is not null while filtering down the columns. We also remove any rows with missing CHARTTIMEs, as these make them impossible to filter by time.

In [12]:
noteevents_df['CHARTDATE'] = pd.to_datetime(noteevents_df['CHARTDATE'])
noteevents_df['CHARTTIME'] = pd.to_datetime(noteevents_df['CHARTTIME'])
noteevents_df['STORETIME'] = pd.to_datetime(noteevents_df['STORETIME'])
noteevents_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2083180 entries, 0 to 2083179
Data columns (total 11 columns):
 #   Column       Dtype         
---  ------       -----         
 0   ROW_ID       int64         
 1   SUBJECT_ID   int64         
 2   HADM_ID      float64       
 3   CHARTDATE    datetime64[ns]
 4   CHARTTIME    datetime64[ns]
 5   STORETIME    datetime64[ns]
 6   CATEGORY     object        
 7   DESCRIPTION  object        
 8   CGID         float64       
 9   ISERROR      float64       
 10  TEXT         object        
dtypes: datetime64[ns](3), float64(3), int64(2), object(3)
memory usage: 174.8+ MB


In [13]:
noteevents_df['ISERROR'].value_counts()

ISERROR
1.0    886
Name: count, dtype: int64

In [14]:
noteevents_cols = [
    'SUBJECT_ID',
    'HADM_ID',
    'CHARTTIME',
    'TEXT'
]

noteevents_df_filtered = noteevents_df[
    (noteevents_df['ISERROR'].isna())\
    &\
    (~noteevents_df['CHARTTIME'].isna())
][noteevents_cols].copy()
# noteevents_df_filtered # commented out for privacy

Finally here we filter out any notes with a CHARTTIME more than 24 hours after the time of admission, and we see an average of 6 notes per admission remaining.

In [15]:
noteevents_df_filtered_24h = filter_df_24h(admissions_df_filtered, noteevents_df_filtered, 'CHARTTIME').reset_index(drop=True)
# noteevents_df_filtered_24h # commented out for privacy

In [16]:
noteevents_df_filtered_24h['HADM_ID'].value_counts().mean()

5.981649941352796

## Patients

In [17]:
patients_df = pd.read_csv(os.path.join(mimic_iii_path, 'PATIENTS.csv'))
# patients_df # commented out for privacy

Just as before, we set the datetime DOB (date of birth) column to the proper type and filter down patients_df to just the columns we need. We also convert the string male/female GENDER column into a binary column.

In [18]:
patients_cols = [
    'SUBJECT_ID',
    'GENDER',
    'DOB'
]

patients_df['DOB'] = pd.to_datetime(patients_df['DOB'])

patients_df_filtered = patients_df[patients_cols].copy()

# Convert gender to binary feature (0 for female, 1 for male)
patients_df_filtered['GENDER_BINARY'] = patients_df_filtered['GENDER'].apply(lambda x: 1 if x == 'M' else 0)

# patients_df_filtered # commented out for privacy

## Prescriptions

In [19]:
prescriptions_df = pd.read_csv(os.path.join(mimic_iii_path, 'PRESCRIPTIONS.csv'))
# prescriptions_df # commented out for privacy

  prescriptions_df = pd.read_csv(os.path.join(mimic_iii_path, 'PRESCRIPTIONS.csv'))


Upon first inspection, it's easy to see many missing drug names and codes in prescriptions_df. We inspect the null counts for each identifying column and find that only DRUG has no missing values. This will be the drug identifier column we use.

In [20]:
prescriptions_df[['DRUG', 'DRUG_NAME_POE', 'DRUG_NAME_GENERIC', 'GSN', 'NDC']].isna().sum()

DRUG                       0
DRUG_NAME_POE        1664234
DRUG_NAME_GENERIC    1662989
GSN                   507164
NDC                     4463
dtype: int64

In [21]:
prescriptions_df['DRUG'].value_counts()

DRUG
Potassium Chloride      192993
Insulin                 143465
D5W                     142241
Furosemide              133122
0.9% Sodium Chloride    130147
                         ...  
Renaphro                     1
Morphine Sulfat              1
humulin R                    1
Meperidine PCA               1
rasagiline (Azilect)         1
Name: count, Length: 4525, dtype: int64

As we've done many times before, we convert STARTDATE to a datetime, filter the columns down, and then filter out all rows more than 24 hours after the admission datetime, leaving us with an average of 32 prescriptions for each HADM_ID.

In [22]:
prescriptions_cols = [
    'HADM_ID',
    'STARTDATE',
    'DRUG'
]

prescriptions_df['STARTDATE'] = pd.to_datetime(prescriptions_df['STARTDATE'])


prescriptions_df_filtered = prescriptions_df[prescriptions_cols].copy()
# prescriptions_df_filtered # commented out for privacy

In [23]:
prescriptions_df_filtered_24h = filter_df_24h(admissions_df_filtered, prescriptions_df_filtered, 'STARTDATE').reset_index(drop=True)
# prescriptions_df_filtered_24h # commented out for privacy

In [24]:
prescriptions_df_filtered_24h['HADM_ID'].value_counts().mean()

31.953456674827734

# <u>Feature Engineering</u>

## Methods

In [25]:
# This function will be used after aggregating the NOTEEVENTS and PRESCRIPTIONS features by HADM_ID
# This is simply to verify the aggregation process was successful
def assert_unique_HADM_ID(df):
    assert df['HADM_ID'].nunique() == len(df), 'Some HADM_IDs have multiple rows'

## Admissions-Patients

### Join patients

First we join the PATIENTS data to admissions_df_filtered by SUBJECT_ID. This will attach the patients' DOB and gender to each of their ADMISSIONS rows. 

We then calcuate the age at the time of admission by subtracting their DOB from the admission datetime. This ended up being more complicated than expected, as roughly 2500 admissions were paired with patients' DOBs in the 1800s -- in each case almost exactly 300 years before their admission. Obviously something went wrong with either the data entry or the anonymization transformations applied to the DOB values. These rows were dropped from the resulting table and the remaining ages were inspected to make sure they seemed reasonable with a mean age of 62 and a maximum of 89.

In [83]:
admissions_patients_df = pd.merge(
    admissions_df_filtered,
    patients_df_filtered,
    on='SUBJECT_ID',
    how='inner'
)
# admissions_patients_df # commented out for privacy

In [None]:
from pandas._libs.tslibs.np_datetime import OutOfBoundsDatetime

# I'll save a single error message as an example, so you don't have to read thousands of lines of output
example_error = None

# Calculate age safely with error handling for extreme date ranges
def calculate_age(admit_date, birth_date):
    global example_error
    try:
        return (admit_date - birth_date).days / 365
    except (OverflowError, OutOfBoundsDatetime):
        # If dates are too extreme, check if year difference is reasonable
        year_diff = admit_date.year - birth_date.year
        if 0 <= year_diff <= 120:  # Reasonable age range
            return year_diff
        else:
            # print(f'HADM_ID: {admit_date} - {birth_date} has unreasonable age difference: {year_diff}')
            example_error = f'HADM_ID: {admit_date} - {birth_date} has unreasonable age difference: {year_diff}'
            return float('nan')  # Return NaN for unreasonable values

# Apply the safer calculation
admissions_patients_df['AGE_YRS'] = admissions_patients_df.apply(
    lambda row: calculate_age(row['ADMITTIME'], row['DOB']), axis=1
)

print(example_error)

HADM_ID: 2178-09-17 18:31:00 - 1878-09-17 00:00:00 has unreasonable age difference: 300


In [28]:
admissions_patients_df['AGE_YRS'].isna().sum()

2507

In [29]:
admissions_patients_df.dropna(subset=['AGE_YRS'], inplace=True, ignore_index=True)
admissions_patients_df['AGE_YRS'].describe()

count    40900.000000
mean        61.983834
std         17.691616
min          0.000000
25%         51.035616
50%         64.230137
75%         76.394521
max         89.060274
Name: AGE_YRS, dtype: float64

### Feature Extraction

For each column in the merged DataFrame that we plan to use as features, we print out the number of unique values. 5 INSURANCE values or 9 ADMISSION_LOCATION values are manageable for categorical features, but 41 distinct ETHNICITY values definitely seems excessive. 

Upon further inspection, we see that the ETHNICITY field was never standardized, as there are many overlapping ethnicities and multiple ethinicities combined into many values. There are also various 'OTHER' or 'DECLINED TO ANSWER' values that provide no information at all.

There are probably better ways to do this, but for now we just define five ethnicity 'buckets' to categorize these values into. We assign these ethnicity buckets simply by if they appear in the ETHNICITY value, in order of commonality. 

In [30]:
admissions_patients_feature_cols = [
    # Admissions features
    'SUBJECT_ID',
    'HADM_ID',
    'ADMISSION_LOCATION',
    'INSURANCE',
    'ETHNICITY',
    # Patients features
    'GENDER_BINARY',
    'AGE_YRS'
]
for column in admissions_patients_feature_cols:
    print(f'{column} unique values: {admissions_patients_df[column].nunique()}')

SUBJECT_ID unique values: 31663
HADM_ID unique values: 40900
ADMISSION_LOCATION unique values: 9
INSURANCE unique values: 5
ETHNICITY unique values: 41
GENDER_BINARY unique values: 2
AGE_YRS unique values: 18887


In [31]:
admissions_df_filtered['ETHNICITY'].value_counts(normalize=True)

ETHNICITY
WHITE                                                       0.698044
BLACK/AFRICAN AMERICAN                                      0.097450
UNKNOWN/NOT SPECIFIED                                       0.081876
HISPANIC OR LATINO                                          0.026977
OTHER                                                       0.022231
UNABLE TO OBTAIN                                            0.016173
ASIAN                                                       0.016034
PATIENT DECLINED TO ANSWER                                  0.006819
HISPANIC/LATINO - PUERTO RICAN                              0.004700
ASIAN - CHINESE                                             0.004354
BLACK/CAPE VERDEAN                                          0.003410
WHITE - RUSSIAN                                             0.003340
BLACK/HAITIAN                                               0.002143
MULTI RACE ETHNICITY                                        0.002050
HISPANIC/LATINO - DOMINI

In [32]:
ethnicity_buckets = [
    'WHITE',
    'BLACK',
    'HISPANIC',
    'ASIAN'
]

def map_ethnicity(ethnicity_str):
    ethnicity_upper = ethnicity_str.upper()
    for bucket in ethnicity_buckets:
        if bucket in ethnicity_upper:
            return bucket
    return 'OTHER'


admissions_patients_df['ETHNICITY_BUCKET'] = admissions_patients_df['ETHNICITY'].apply(map_ethnicity)
admissions_patients_df['ETHNICITY_BUCKET'].value_counts(normalize=True)

admissions_patients_feature_cols.append('ETHNICITY_BUCKET')
admissions_patients_feature_cols.remove('ETHNICITY')

Finally for these data, we convert all features except AGE_YRS into categorical type columns that can be parsed by XGBoost without much more work, add the DEATH label back onto the DataFrame, and assert that each HADM_ID has its own row.

In [33]:
admissions_patients_features_df = admissions_patients_df[admissions_patients_feature_cols].copy()

for col in admissions_patients_feature_cols:
    if col != 'AGE_YRS':  # Only convert non-numeric columns to category
        admissions_patients_features_df[col] = admissions_patients_features_df[col].astype('category')

admissions_patients_features_df['DEATH'] = admissions_patients_df['DEATH']
# admissions_patients_features_df # commented out for privacy

In [34]:
admissions_patients_features_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40900 entries, 0 to 40899
Data columns (total 8 columns):
 #   Column              Non-Null Count  Dtype   
---  ------              --------------  -----   
 0   SUBJECT_ID          40900 non-null  category
 1   HADM_ID             40900 non-null  category
 2   ADMISSION_LOCATION  40900 non-null  category
 3   INSURANCE           40900 non-null  category
 4   GENDER_BINARY       40900 non-null  category
 5   AGE_YRS             40900 non-null  float64 
 6   ETHNICITY_BUCKET    40900 non-null  category
 7   DEATH               40900 non-null  bool    
dtypes: bool(1), category(6), float64(1)
memory usage: 3.3 MB


In [35]:
assert_unique_HADM_ID(admissions_patients_features_df)

## Note Events

### Extract Embeddings

Here we load the clinicalBERT model from HuggingFace and write a function that will return the embeddings for a batch of input strings. 

Note here that the `tokenizer_max_length` is set to 512 as that is the maximum input size for clinicalBERT. We also use `truncation=True` in the tokenizer to ensure that no token sequences will exceed this length, but unfortunately this does mean that many (almost half) of our notes are getting truncated to some extent. This can be seen by tokenizing a random sample of 1000 notes and inspecting their statistics, showing a median token sequence length 459.

Also note that each embedding output from clinicalBERT will be of shape (768,).

In [36]:
clinicalBERT = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
clinicalBERT_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

def get_bert_embeddings(texts:list[str], model, tokenizer, batch_size=32, tokenizer_max_length=512) -> np.ndarray:
    if torch.cuda.is_available():
        print('Using CUDA')
        model = model.cuda()

    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch_idx = i // batch_size
        if batch_idx % 20 == 0: # Print every 50 batches
            print(f'Processing batch {batch_idx} of {len(texts) // batch_size}')
            
        batch_texts = texts[i:i+batch_size]
        inputs = tokenizer(batch_texts, padding="max_length", truncation=True, 
                          max_length=tokenizer_max_length, return_tensors="pt")  # Reduced max_length
        
        # Move to GPU if available
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
            
        with torch.no_grad():
            outputs = model(**inputs)
        
        batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        embeddings.append(batch_embeddings)

    return np.vstack(embeddings)

In [37]:
example_text = "Patient presented with hypotension and low oxygen saturation."
embedding = get_bert_embeddings([example_text], clinicalBERT, clinicalBERT_tokenizer, batch_size=1)
embedding[0].shape

Using CUDA
Processing batch 0 of 1


(768,)

In [38]:
noteevents_test_tokens = noteevents_df_filtered_24h['TEXT'].sample(1000, random_state=42)\
    .apply(lambda x: clinicalBERT_tokenizer.encode(x, truncation=False))
# noteevents_test_tokens # commented out for privacy

In [39]:
print(f'Minimum token length: {noteevents_test_tokens.apply(len).min()}')
print(f'Maximum token length: {noteevents_test_tokens.apply(len).max()}')
print(f'Mean token length: {noteevents_test_tokens.apply(len).mean()}')
print(f'Median token length: {noteevents_test_tokens.apply(len).median()}')

Minimum token length: 13
Maximum token length: 5069
Mean token length: 673.372
Median token length: 459.0


We use the following cell to generate the embeddings for all notes in our dataset, but note here that this took 24 minutes to process 229486 notes on a RTX 4070 Super GPU. To avoid re-calculating these embeddings, I've also included code here to save and load them as a numpy file.

In [40]:
# noteevents_embeddings = get_bert_embeddings_batched( # uncomment to run
#     list(noteevents_df_filtered_24h['TEXT']),
#     clinicalBERT,
#     clinicalBERT_tokenizer,
#     batch_size=256 # Takes about 10GB of GPU memory
# )
# noteevents_embeddings

In [41]:
noteevents_embeddings_path = os.path.join('feature-embeddings', 'noteevents_embeddings.npy')

# np.save(noteevents_embeddings_path, noteevents_embeddings)
# print(f"Saved embeddings to {noteevents_embeddings_path}")

In [42]:
noteevents_embeddings = np.load(noteevents_embeddings_path)
noteevents_embeddings.shape

(229486, 768)

### Aggregate Embeddings

Now that we have embeddings for each note, we still have about 6 notes per hospital admission. We need to aggregate these embeddings by taking the average and count for each HADM_ID.

In [43]:
noteevents_embeddings_df = pd.DataFrame({
    'HADM_ID': noteevents_df_filtered_24h['HADM_ID'],
    'NOTE_EMBEDDING': list(noteevents_embeddings),
    'DEATH': noteevents_df_filtered_24h['DEATH']
})
# noteevents_embeddings_df # commented out for privacy

In [44]:
noteevents_embeddings_df_avg = noteevents_embeddings_df.groupby('HADM_ID').agg(
    NOTE_EMBEDDING_AVG=('NOTE_EMBEDDING', lambda x: np.mean(np.vstack(x), axis=0)),
    NOTE_COUNT=('NOTE_EMBEDDING', 'count'),  # Count number of notes per HADM_ID
    DEATH=('DEATH', 'first')  # Keep the first DEATH value (all are the same per HADM_ID)
).reset_index()
# noteevents_embeddings_df_avg # commented out for privacy

### PCA Embeddings

Finally, remember that each embedding is of shape (768,), which is far too large to reasonably put into an XGBoost as raw features. To make the dimension more manageable, we'll perform PCA on each HADM_ID's average note embedding and take the top 10 components to retain 86% of the embeddings' variance. These 10 components and the NOTE_COUNT will make up the final features from our NOTEEVENTS data.

In [45]:
# Stack the embeddings into a numpy array
noteevents_embeddings_array = np.vstack(noteevents_embeddings_df_avg['NOTE_EMBEDDING_AVG'].values)

# Apply PCA to reduce dimensionality
pca = PCA(n_components=10) # 10 components retain 86% variance
noteevents_reduced_embeddings = pca.fit_transform(noteevents_embeddings_array)

# Create DataFrame with reduced embeddings
noteevents_features_df = pd.DataFrame(
    noteevents_reduced_embeddings, 
    columns=[f'note_pca_{i}' for i in range(noteevents_reduced_embeddings.shape[1])]
)

# Add back HADM_ID, NOTE_COUNT, 7 DEATH columns
noteevents_features_df['HADM_ID'] = noteevents_embeddings_df_avg['HADM_ID'].astype('category')
noteevents_features_df['NOTE_COUNT'] = noteevents_embeddings_df_avg['NOTE_COUNT']
noteevents_features_df['DEATH'] = noteevents_embeddings_df_avg['DEATH']

print(f"Variance explained by PCA components: {pca.explained_variance_ratio_.sum():.4f}")
# noteevents_features_df # commented out for privacy

Variance explained by PCA components: 0.8556


In [46]:
assert_unique_HADM_ID(noteevents_features_df)

## Prescriptions

### Feature extraction

Just like we did to get the embedding features for NOTEEVENTS, we do almost the same to get the features for the DRUG names in PRESCRIPTIONS. The only noteable difference is that the embeddings for the DRUG names are much shorter than the notes, allowing us to drop the `tokenizer_max_length` down to 32 and make the embedding processing much more efficient.

In [47]:
prescriptions_feature_cols = [
    'HADM_ID',
    'DRUG'
]
prescriptions_features_df_many = prescriptions_df_filtered_24h[prescriptions_feature_cols].copy()

prescriptions_features_df_many['DEATH'] = prescriptions_df_filtered_24h['DEATH']
# prescriptions_features_df_many # commented out for privacy

### Get ClinicalBERT Embeddings

In [48]:
prescriptions_test_tokens = prescriptions_features_df_many['DRUG'].sample(1000, random_state=42)\
    .apply(lambda x: clinicalBERT_tokenizer.encode(x, truncation=False))
# prescriptions_test_tokens # commented out for privacy

In [49]:
print(f'Minimum token length: {prescriptions_test_tokens.apply(len).min()}')
print(f'Maximum token length: {prescriptions_test_tokens.apply(len).max()}')
print(f'Mean token length: {prescriptions_test_tokens.apply(len).mean()}')
print(f'Median token length: {prescriptions_test_tokens.apply(len).median()}')

Minimum token length: 3
Maximum token length: 22
Mean token length: 7.182
Median token length: 7.0


In [None]:
# prescriptions_embeddings = get_bert_embeddings_batched( # uncomment to run
#     list(prescriptions_features_df_many['DRUG']),
#     clinicalBERT,
#     clinicalBERT_tokenizer,
#     tokenizer_max_length=32, # Reduce max_length to 32 to match max token length of 22
#     batch_size=256 # Takes about 10GB of GPU memory
# )
# prescriptions_embeddings # commented out for privacy

In [51]:
prescriptions_embeddings_path = os.path.join(os.getcwd(), 'feature-embeddings', 'prescriptions_embeddings.npy')

# np.save(prescriptions_embeddings_path, prescriptions_embeddings)
# print(f"Saved embeddings to {prescriptions_embeddings_path}")

In [52]:
prescriptions_embeddings = np.load(prescriptions_embeddings_path)
prescriptions_embeddings.shape

(1265964, 768)

In [53]:
prescriptions_embeddings_df = pd.DataFrame({
    'HADM_ID': prescriptions_features_df_many['HADM_ID'],
    'DRUG_EMBEDDING': list(prescriptions_embeddings),
    'DEATH': prescriptions_features_df_many['DEATH']
})
# prescriptions_embeddings_df # commented out for privacy

### Aggregate embeddings

After embedding each DRUG name, again we aggregate by averaging the embeddings for each HADM_ID and apply PCA to these embedding averages, keeping the top 10 components for 72% of the variance.

In [54]:
prescriptions_embeddings_df_avg = prescriptions_embeddings_df.groupby('HADM_ID').agg(
    DRUG_EMBEDDING_AVG=('DRUG_EMBEDDING', lambda x: np.mean(np.vstack(x), axis=0)),
    DRUG_COUNT=('DRUG_EMBEDDING', 'count'),  # Count number of drugs per HADM_ID
    DEATH=('DEATH', 'first')  # Keep the first DEATH value (all are the same per HADM_ID)
).reset_index()

# prescriptions_embeddings_df_avg # commented out for privacy

### PCA Embeddings

In [55]:
# Stack the embeddings into a numpy array
prescriptions_embeddings_array = np.vstack(prescriptions_embeddings_df_avg['DRUG_EMBEDDING_AVG'].values)

# Apply PCA to reduce dimensionality
pca = PCA(n_components=10) # 10 components retain 72% variance
prescriptions_reduced_embeddings = pca.fit_transform(prescriptions_embeddings_array)

# Create DataFrame with reduced embeddings
prescriptions_features_df = pd.DataFrame(
    prescriptions_reduced_embeddings, 
    columns=[f'rx_pca_{i}' for i in range(prescriptions_reduced_embeddings.shape[1])]
)

# Add back HADM_ID, DRUG_COUNT, & DEATH columns
prescriptions_features_df['HADM_ID'] = prescriptions_embeddings_df_avg['HADM_ID'].astype('category')
prescriptions_features_df['DRUG_COUNT'] = prescriptions_embeddings_df_avg['DRUG_COUNT']
prescriptions_features_df['DEATH'] = prescriptions_embeddings_df_avg['DEATH']

print(f"Variance explained by PCA components: {pca.explained_variance_ratio_.sum():.4f}")
# prescriptions_features_df # commented out for privacy

Variance explained by PCA components: 0.7163


In [56]:
assert_unique_HADM_ID(prescriptions_features_df)

# <u>Tune Final XGBoost Classifier</u>

We're almost done. Now we just need to merge our feature sets together by HADM_ID before plugging into an XGBoost model and tuning to get the highest precision possible.

## Merge feature sets

We merge the features in two steps, first joining the noteevents features to the admissions, then joining the prescription features to this merged feature set. We then do a final assertion to ensure that HADM_ID is unique to each row and a last check on our column datatypes.

In [57]:
merged_features_df = pd.merge(
    admissions_patients_features_df,
    noteevents_features_df,
    on=['HADM_ID', 'DEATH'],
    how='inner'
)
# merged_features_df # commented out for privacy

In [58]:
merged_features_df = pd.merge(
    merged_features_df,
    prescriptions_features_df,
    on=['HADM_ID', 'DEATH'],
    how='inner'
)
# merged_features_df # commented out for privacy

In [59]:
assert_unique_HADM_ID(merged_features_df)

In [None]:
merged_features_df['HADM_ID'] = merged_features_df['HADM_ID'].astype('category')
merged_features_df.info()

## Tune and eval XGBoost

For the very last step we write a script that first tunes the hyperparameters of an `XGBClassifier` using `RandomizedSearchCV` to iterate over 5 folds 30 times each, then takes the best model parameters and evaluates its performance on a 80/20 train-test split of our entire dataset. 

There are some details to pay attention to here. Firstly, notice that our dataset has a massive class imbalance, with only 11.9% True labels. We address this by setting the `scale_pos_weight` of the XGBoost Classifier accordingly. We also make sure to pay attention to weighted performance metrics when possible. 

Second is that many of our features are of categorical type. If we were training another model, like logistic regression, we would have to encode these values somehow, but luckily `XGBClassifier` supports categorical variables natively as long as we set `enable_categorical = True`. 

Third, we set `RandomizedSearchCV` to optimize for precision, since specifically precision on the true labels is the most important metric for this model. 

Fourth and finally, we display the feature importances for the best model for the purpose of interpretability. While we can't dive into the black box of the embedding PCA components, we can see that the top components for the notes and the prescriptions were two of the most impactful features in predicting their in-hospital death, along with the patient's age and ethnicity.

In [64]:
merged_deathrate = merged_features_df['DEATH'].mean()
merged_deathrate

0.11880769926989682

In [65]:
# Prepare feature and target columns
feature_cols = [col for col in merged_features_df.columns if col not in ['HADM_ID', 'SUBJECT_ID', 'DEATH']]
X = merged_features_df[feature_cols]
y = merged_features_df['DEATH']

# Define the hyperparameter search space for randomized search
param_dist = {
    'max_depth': randint(3, 10),
    'learning_rate': uniform(0.01, 0.2),
    'n_estimators': randint(100, 500),
    'min_child_weight': randint(1, 10),
    'gamma': uniform(0, 0.5),
    'subsample': uniform(0.6, 0.4),
    'colsample_bytree': uniform(0.6, 0.4),
    'scale_pos_weight': uniform(1, 10),
    'reg_alpha': uniform(0, 1),
    'reg_lambda': uniform(1, 10)
}

# Create XGBoost classifier
xgb_model = XGBClassifier(
    objective='binary:logistic',
    eval_metric='aucpr',
    random_state=42,
    enable_categorical = True,
    scale_pos_weight = 1 / merged_deathrate  # Adjust class imbalance
)

# Set up cross-validation
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Set up scoring metrics
scoring = {
    'precision': 'precision',
    'recall': 'recall',
    'f1': 'f1',
    'auc': 'roc_auc'
}

# Set up RandomizedSearchCV with precision as the primary optimization metric
random_search = RandomizedSearchCV(
    estimator=xgb_model,
    param_distributions=param_dist,
    n_iter=30,  # Number of parameter settings sampled
    scoring=scoring,
    refit='precision',  # Optimize for precision
    cv=cv,
    verbose=1,
    random_state=42,
    return_train_score=True
)

# Fit the model
random_search.fit(X, y)

# Get the best parameters and results
best_params = random_search.best_params_
best_score = random_search.best_score_

print("\n====================")
print(f"Best Precision Score: {best_score:.4f}")
print("Best Parameters:")
for param, value in best_params.items():
    print(f"{param}: {value}")

# Get all the cross-validation results for the best model
cv_results = random_search.cv_results_
best_index = random_search.best_index_

# Train final model with best parameters
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

final_model = XGBClassifier(
    **best_params,
    objective='binary:logistic',
    random_state=42,
    enable_categorical = True,
    tree_method = 'hist'  # Required for categorical feature support
)
final_model.fit(X, y)

# Evaluate the model
y_pred = final_model.predict(X_test)

print("\nBest Model Metrics:")
print('Accuracy:', accuracy_score(y_test, y_pred))
print(f"ROC AUC: {roc_auc_score(y_test, final_model.predict_proba(X_test)[:,1]):.4f}")
print(f"PR AUC: {average_precision_score(y_test, final_model.predict_proba(X_test)[:,1], average='weighted'):.4f}")
print(classification_report(y_test, y_pred))

# Print feature importances
importances = final_model.feature_importances_
feat_importances = pd.DataFrame({
    'feature': X.columns,
    'importance': importances
}).sort_values(by='importance', ascending=False)
print('Feature importances:')
display(feat_importances)


Fitting 5 folds for each of 30 candidates, totalling 150 fits

Best Precision Score: 0.5272
Best Parameters:
colsample_bytree: 0.7727780074568463
gamma: 0.14561457009902096
learning_rate: 0.1323705789444759
max_depth: 4
min_child_weight: 3
n_estimators: 463
reg_alpha: 0.5142344384136116
reg_lambda: 6.924145688620425
scale_pos_weight: 1.4645041271999772
subsample: 0.8430179407605753

Best Model Metrics:
Accuracy: 0.9291101055806938
ROC AUC: 0.9575
PR AUC: 0.7909
              precision    recall  f1-score   support

       False       0.94      0.99      0.96      5834
        True       0.85      0.50      0.63       796

    accuracy                           0.93      6630
   macro avg       0.89      0.74      0.79      6630
weighted avg       0.92      0.93      0.92      6630

Feature importances:


Unnamed: 0,feature,importance
8,note_pca_3,0.102528
3,AGE_YRS,0.07241
18,rx_pca_2,0.056205
4,ETHNICITY_BUCKET,0.047397
17,rx_pca_1,0.044369
20,rx_pca_4,0.042947
12,note_pca_7,0.041543
1,INSURANCE,0.039738
5,note_pca_0,0.035433
11,note_pca_6,0.03501
