In [None]:
"""
File: XGBoost.ipynb
Code to train and evaluate an XGBoost model on MIMIC-IV FHIR dataset.
"""

def Project():
    """
    __Objectives__
    0. Import data and separate unique visit tokens
    >>> 1. Reduce the number of features (manual selection, hierarchy aggregation)
    2. Create frequency features from event tokens
    >>> 3. Include num_visits, youngest and oldest age, and maybe time
    4. Use label column to create the prediction objective
    >>> 5. Train XGBoost model and evaluate on test dataset
    >>> 6. Use LightGBM and CatBoost

    __Questions__
    0. Why does CEHR-BERT only have 512 possible concept and time tokens? -> Probably most tokens are not present in the sample

    __Extra__
    Hyperparameters: {learning rate (LR), maximum tree depth (max depth), number of estimators (n estimators),
                      column sampling by tree (colsample), row subsampling (subsample) and the regularization parameter α.}
    """
    # The project notes live in this function's own docstring; return them.
    # (Referencing a differently named object here would raise NameError.)
    return Project.__doc__

In [56]:
import os
from collections import Counter

import scipy
import numpy as np
import pandas as pd
import xgboost as xgb
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report
from scipy.sparse import csr_matrix, hstack, vstack, save_npz, load_npz

from tqdm import tqdm

# Every input and cached artifact lives under a single baseline directory (relative to this notebook).
ROOT = '../../data/baseline'

DATA_PATH = ROOT + '/patient_sequences.parquet'                         # full patient sequences
SAMPLE_DATA_PATH = ROOT + '/CEHR-BERT_sample_patient_sequence.parquet'  # CEHR-BERT sample sequences
FREQ_DF_PATH = ROOT + '/patient_feature_freq.csv'                       # NOTE(review): not referenced by the cells below
FREQ_MATRIX_PATH = ROOT + '/patient_freq_matrix.npz'                    # cached sparse token-count matrix

In [None]:
# Load the full cohort and the CEHR-BERT sample, in that order.
# NOTE(review): `sample_data` is not referenced by the later cells here — confirm it is still needed.
data, sample_data = (pd.read_parquet(path) for path in (DATA_PATH, SAMPLE_DATA_PATH))
data

In [None]:
# Gather the vocabulary of distinct event tokens over every patient sequence.
unique_event_tokens = set()
for token_sequence in tqdm(data['event_tokens_updated'].values, desc='Loading Tokens', unit=' Patients'):
    unique_event_tokens.update(token_sequence)

# Freeze into a reverse-sorted list so the token order (and the feature-column order derived from it) is deterministic.
unique_event_tokens = sorted(unique_event_tokens, reverse=True)

print(f"Complete list of unique event tokens\nLength: {len(unique_event_tokens)}\nHead: {unique_event_tokens[:10]}...")

In [None]:
# Dump the whole vocabulary (large output) — presumably inspected by eye to pick the structural tokens excluded below.
print(repr(unique_event_tokens))

In [None]:
# Structural (non-event) tokens to drop from the features: padding, presumably visit start/end ('VS'/'VE'),
# week/month interval tokens ('W_0'..'W_3', 'M_0'..'M_12') and 'LT'. NOTE(review): confirm token meanings.
special_tokens = [
    '[PAD]', 'VS', 'VE',
    *(f'W_{i}' for i in range(4)),
    *(f'M_{i}' for i in range(13)),
    'LT',
]

# Column layout of the frequency matrix: an 'id' column first, then every non-special token.
feature_event_tokens = ['id', *(token for token in unique_event_tokens if token not in special_tokens)]

# Empty frame with the feature columns, displayed for inspection.
# NOTE(review): not referenced by the later cells here.
patient_freq_df = pd.DataFrame(columns=feature_event_tokens)
patient_freq_df

In [None]:
"""
Get and save frequencies of each token for each patient sequence.
"""

id2patient = {}
patient_freq_matrix = None
buffer_size = 10000
df_buffer = []
matrix_buffer = []


for idx, patient in tqdm(data.iterrows(), desc='Loading Tokens', unit=' Patients'):
    patient_history = {token:0 for token in feature_event_tokens}
    patient_history['id'] = idx
    id2patient[idx] = patient['patient_id']

    for event_token in patient['event_tokens_updated']:
        if event_token not in special_tokens:
            patient_history[event_token] += 1


    # TODO Code to normalize frequencies


    matrix_buffer.append(list(patient_history.values()))

    if len(matrix_buffer) >= buffer_size:
        current_matrix = csr_matrix(matrix_buffer, shape=(len(matrix_buffer), len(feature_event_tokens)))

        if patient_freq_matrix is None:
            patient_freq_matrix = current_matrix
        else:
            patient_freq_matrix = vstack([patient_freq_matrix, current_matrix], format='csr')

        matrix_buffer = []


if matrix_buffer:
    current_matrix = csr_matrix(matrix_buffer, shape=(len(matrix_buffer), len(feature_event_tokens)))

    if patient_freq_matrix is None:
        patient_freq_matrix = current_matrix
    else:
        patient_freq_matrix = vstack([patient_freq_matrix, current_matrix], format='csr')

    matrix_buffer = []


save_npz(FREQ_MATRIX_PATH, patient_freq_matrix)
print(f"Save & Done! Final Matrix Shape: {patient_freq_matrix.shape}")

In [None]:
# Reload the cached matrix so downstream cells need not re-run the slow construction above.
patient_freq_matrix = load_npz(FREQ_MATRIX_PATH)
num_patients, num_feature_columns = patient_freq_matrix.shape
patient_freq_matrix

In [35]:
# Get intuition about the frequency of different features in the dataset:
# per column, the number of patients (rows) with a stored non-zero count.
patients_per_feature = patient_freq_matrix.getnnz(axis=0)
# How many columns are present in more than 1500 patients.
print(np.count_nonzero(patients_per_feature > 1500))
# plt.hist(patient_freq_matrix.getnnz(axis=0), bins=range(num_patients+1), edgecolor='black')
# plt.xlabel('Number of Nonzero Rows')
# plt.ylabel('Number of Columns')
# plt.title('Histogram of Nonzero Rows per Column')
# plt.show()

1023


In [43]:
# Rank columns by how many patients have a non-zero count (most common first) and keep the top ones.
NUM_SELECTED_FEATURES = 5000
# The 'id' column stores row labels, not token counts; exclude it explicitly rather than assuming
# it is always ranked first (the previous `[1:5001]` slice silently relied on that).
ID_COLUMN = feature_event_tokens.index('id')

# Stable sort so ties at the cut-off are broken by column order, reproducibly across numpy builds.
features_sorted_by_freq = np.argsort(-patient_freq_matrix.getnnz(axis=0), kind='stable')
selected_features = features_sorted_by_freq[features_sorted_by_freq != ID_COLUMN][:NUM_SELECTED_FEATURES]
selected_features

array([ 2303,  5846,  5570,  5875,  5767,  4446,  5512,  4440,  5467,
        5510,  5877,  5766,  5847,  5648, 15809,  5649,  5466,  5715,
        4214,  4447,  4285,  5511,  5516,  4344,  4213,  5370,  4441,
        4286,  4215,  5468,  5874,  5807,  5647,  5568,  4343,  5806,
        5876,  5369,  5517,  5650,  5515,  5509,  4295,  5465,  5848,
        4448,  4294,  5569, 15786,  5513,  5768,  4287, 16575,  4345,
        5805, 16582,  5765,  4378,  4377,  5371,  4442,  4382,  4383,
        4212,  5518,  5808,  5845,  5878,  4305,  4409,  5571,  5572,
        5716,  5646,  4310,  4306,  5714,  4410,  4449,  4379,  4384,
        4293,  5514,  5849,  5368,  4311,  4288,  5717,  4443, 16117,
        5084, 15815,  4342,  4309, 16471,  4381,  4304,  4216,  5469,
        4380,  4376,  5769, 16364,  4621,  5809,  4411,  5830,  4385,
        4292,  4450,  5859,  4408,  1961,  5831,  4444, 15794,  4289,
        2933,  4516, 16193,  4308, 15904,  4514, 16430, 15816,  4370,
        5906,  5918,

In [49]:
# Keep only the selected token columns; matrix rows were built in `data` row order, so labels align by position.
X = patient_freq_matrix[:, selected_features]
Y = data['label'].values

# 70/30 split, stratified on the label so both sides keep the class ratio; seeded for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y,
    random_state=23,
    stratify=Y,
    test_size=0.3,
)

In [52]:
# Binary classifier with otherwise-default XGBoost hyperparameters, seeded for reproducibility.
# `fit` returns the estimator itself, so construction and training chain into one expression.
xgb_model = xgb.XGBClassifier(random_state=23, objective='binary:logistic').fit(X_train, y_train)
xgb_model

In [57]:
# Predict on each split: train, held-out test, and the full dataset.
# NOTE: "All Data" includes the training rows, so it is an optimistic blend of train and test performance.
y_train_pred, y_test_pred, all_data_pred = (xgb_model.predict(features) for features in (X_train, X_test, X))

# Balanced accuracy (mean per-class recall) is insensitive to label imbalance.
y_train_accuracy, y_test_accuracy, all_data_accuracy = (
    balanced_accuracy_score(truth, predicted)
    for truth, predicted in ((y_train, y_train_pred), (y_test, y_test_pred), (Y, all_data_pred))
)

print(
    "Balanced Accuracy\n"
    f"Train: {y_train_accuracy}  |  Test: {y_test_accuracy}  |  All Data: {all_data_accuracy}"
)

Accuracy
Train: 0.8364036278920834  |  Test: 0.759330187137921  |  All Data: 0.8209862237693658
