In [None]:
"""
File: XGBoost.ipynb
Code to train and evaluate an XGBoost model on MIMIC-IV FHIR dataset.
"""


def Project():
    """
    Return the project notes for this notebook (this docstring itself).

    __Objectives__
    0. Import data and separate unique visit tokens
    1. Reduce the number of features (manual selection, hierarchy aggregation)
    2. Create frequency features from event tokens
    3. Include num_visits, youngest and oldest age, and maybe time
    4. Use label column to create the prediction objective
    5. Train XGBoost model and evaluate on test dataset

    __Questions__
    0. Why does CEHR-BERT only have 512 possible concept and time tokens? -> Probably most tokens are not present in the sample

    __Extra__
    Hyperparameters: {learning rate (LR), maximum tree depth (max depth), number of estimators (n estimators),
                      column sampling by tree (colsample), row subsampling (subsample) and the regularization parameter α.}
    """
    # The function previously returned `ProjectObjectives.__doc__`, but no name
    # `ProjectObjectives` exists, so every call raised NameError. The notes live
    # in this function's own docstring.
    return Project.__doc__

In [None]:
import os

ROOT = "/fs01/home/afallah/odyssey/odyssey"
os.chdir(ROOT)
import sys
import numpy as np
import pandas as pd
import xgboost as xgb
import matplotlib.pyplot as plt

from sklearn.preprocessing import MaxAbsScaler
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import balanced_accuracy_score, precision_score, recall_score
from sklearn.metrics import (
    f1_score,
    roc_curve,
    auc,
    precision_recall_curve,
    roc_auc_score,
    average_precision_score,
)
from scipy.sparse import csr_matrix, hstack, vstack, save_npz, load_npz

from tqdm import tqdm

%matplotlib inline

# Directory holding the datasets, relative to the project ROOT defined above.
DATA_ROOT = f"{ROOT}/data"
# Parquet file read into `pretrain_data`.
DATA_PATH = f"{DATA_ROOT}/slurm_data/one_month/pretrain.parquet"
# Parquet file read into `finetune_data`.
FINE_TUNE_PATH = f"{DATA_ROOT}/slurm_data/one_month/fine_tune.parquet"
# Where the sparse patient-by-token frequency matrix is saved (and can be reloaded from).
FREQ_MATRIX_PATH = f"{DATA_ROOT}/slurm_data/one_month/patient_freq_matrix.npz"

In [None]:
# Read the pretraining split and keep only patients that have an event-token sequence
pretrain_data = pd.read_parquet(DATA_PATH)
has_pretrain_events = pretrain_data["event_tokens_2048"].notna()
pretrain_data = pretrain_data[has_pretrain_events]
pretrain_data

In [None]:
# Read the fine-tuning split and keep only patients that have an event-token sequence
finetune_data = pd.read_parquet(FINE_TUNE_PATH)
has_finetune_events = finetune_data["event_tokens_2048"].notna()
finetune_data = finetune_data[has_finetune_events]
finetune_data

In [52]:
# Stack both splits; when the same index value appears in both, the first
# occurrence (the pretraining row) is the one that is kept.
data = (
    pd.concat((pretrain_data, finetune_data))
    .reset_index()
    .drop_duplicates(subset="index", keep="first")
    .set_index("index")
)
data

Unnamed: 0_level_0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens_2048,type_tokens_2048,age_tokens_2048,time_tokens_2048,visit_tokens_2048,position_tokens_2048,label
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
0,f8f3289a-057f-5fcc-a714-5f6109ca16c4,2,0,,,1,4,"[[CLS], [VS], 8938, [VE], [PAD], [PAD], [PAD],...","[1, 2, 7, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 18, 18, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 8262, 8262, 8262, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 2, 2, 2, 2049, 2049, 2049, 2049, 2049, 204...",0
1,9b62c9f4-3fdc-5020-82b5-ae5b8292445a,4,0,,,43,52,"[[CLS], [VS], 7569, 66689036430, 00904224461, ...","[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 3, 4, ...","[0, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28...","[0, 5963, 5963, 5963, 5963, 5963, 5963, 5963, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0
2,2ca522eb-dd89-5f79-8155-9599ea46b0b2,2,1,244.0,242.0,51,54,"[[CLS], [VS], 00904629261, 00904642281, 009046...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86...","[0, 8016, 8016, 8016, 8016, 8016, 8016, 8016, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0
4,02adf8a6-8bc0-55d3-81ae-4d8582094896,9,1,20.0,11.0,640,664,"[[CLS], [VS], 51079045420, 00006494300, 177140...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, ...","[0, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65...","[0, 8002, 8002, 8002, 8002, 8002, 8002, 8002, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1
5,744fe3c4-9b03-55ae-ac9f-6bc4e967cde7,3,0,,,80,86,"[[CLS], [VS], 7813, 7813, 7902, 7902, 9604, 00...","[1, 2, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29...","[0, 7582, 7582, 7582, 7582, 7582, 7582, 7582, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
59898,3dafdc7c-c80b-56f0-832a-3ab7bf5667cc,13,1,11.0,10.0,1956,1992,"[[CLS], [VS], 8611, 00574200202, 00008084199, ...","[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75...","[0, 8754, 8754, 8754, 8754, 8754, 8754, 8754, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1
164563,17cd52f2-9a32-5b2d-aec8-bbaefce55e7b,4,1,16.0,5.0,1877,1886,"[[CLS], [VS], 5491, 63739035410, 61958150101, ...","[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58...","[0, 6777, 6777, 6777, 6777, 6777, 6777, 6777, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1
22291,052ca40b-d12e-5390-a9f1-70a6edbfa162,2,1,63.0,24.0,1879,1882,"[[CLS], [VS], 5A1955Z, 0B9B8ZX, 3E0G76Z, 02100...","[1, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ...","[0, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89...","[0, 6083, 6083, 6083, 6083, 6083, 6083, 6083, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1
28930,472730fd-189f-524c-a649-1f9d184e80c7,6,1,2.0,0.0,1090,1105,"[[CLS], [VS], 00904516561, 57896042101, 764390...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79...","[0, 6498, 6498, 6498, 6498, 6498, 6498, 6498, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1


In [53]:
# Collect every distinct event token seen in any patient sequence (special tokens included)
token_pool = set()

for patient_event_tokens in tqdm(
    data["event_tokens_2048"].values, desc="Loading Tokens", unit=" Patients"
):
    token_pool.update(patient_event_tokens)

# Descending lexicographic order, so the bracketed special tokens come first
unique_event_tokens = sorted(token_pool, reverse=True)

print(
    f"Complete list of unique event tokens\nLength: {len(unique_event_tokens)}\nHead: {unique_event_tokens[:30]}..."
)

Loading Tokens: 100%|██████████| 170669/170669 [00:41<00:00, 4144.33 Patients/s]

Complete list of unique event tokens
Length: 19915
Head: ['[W_3]', '[W_2]', '[W_1]', '[W_0]', '[VS]', '[VE]', '[PAD]', '[M_9]', '[M_8]', '[M_7]', '[M_6]', '[M_5]', '[M_4]', '[M_3]', '[M_2]', '[M_1]', '[M_12]', '[M_11]', '[M_10]', '[M_0]', '[LT]', '[CLS]', 'XY0VX83', 'XW0DXR5', 'XW0DX82', 'XW043C3', 'XW043B3', 'XW04351', 'XW033H4', 'XW033B3']...





In [54]:
# Structural tokens that are not treated as features
w_tokens = [f"[W_{i}]" for i in range(4)]
m_tokens = [f"[M_{i}]" for i in range(0, 13)]
special_tokens = ["[CLS]", "[PAD]", "[VS]", "[VE]"] + w_tokens + m_tokens + ["[LT]"]

# Feature columns: an "id" column first, then every non-special token in vocabulary order
feature_event_tokens = ["id"]
feature_event_tokens.extend(
    token for token in unique_event_tokens if token not in special_tokens
)

print(feature_event_tokens[:20])

['id', 'XY0VX83', 'XW0DXR5', 'XW0DX82', 'XW043C3', 'XW043B3', 'XW04351', 'XW033H4', 'XW033B3', 'XW03372', 'XW03331', 'X2RF332', 'X2RF032', 'X2C1361', 'X2C0361', 'X2A5312', 'HZ99ZZZ', 'HZ87ZZZ', 'HZ85ZZZ', 'HZ81ZZZ']


In [None]:
###  Get and save frequencies of each token for each patient sequence.  ###

# The matrix has one row per patient (in `data` order) and one column per entry
# of `feature_event_tokens`; column 0 is the "id" column holding the row's index
# value from `data`, the remaining columns hold token counts.
#
# Rows are assembled directly as sparse (row, col, value) triplets. Building a
# dense ~20k-entry dict/list per patient and buffering tens of thousands of
# them before converting to CSR is both slow and very memory hungry.

buffer_size = 50000  # patients per sparse chunk; bounds the size of the triplet lists
num_feature_columns = len(feature_event_tokens)
token_to_column = {token: col for col, token in enumerate(feature_event_tokens)}
special_token_set = set(special_tokens)  # O(1) membership tests in the hot loop


def _triplets_to_csr(rows, cols, vals, n_rows):
    """Build one CSR chunk of shape (n_rows, num_feature_columns) from COO triplets."""
    chunk = csr_matrix(
        (vals, (rows, cols)), shape=(n_rows, num_feature_columns), dtype=np.int64
    )
    # An id of 0 would otherwise be stored as an explicit zero entry.
    chunk.eliminate_zeros()
    return chunk


chunks = []
rows, cols, vals = [], [], []
rows_in_chunk = 0

for idx, patient_tokens in tqdm(
    data["event_tokens_2048"].items(),
    total=len(data),
    desc="Loading Tokens",
    unit=" Patients",
):
    # Column 0: the patient's index value in `data`
    rows.append(rows_in_chunk)
    cols.append(0)
    vals.append(idx)

    # Count every non-special token of this patient
    token_counts = {}
    for event_token in patient_tokens:
        if event_token not in special_token_set:
            token_counts[event_token] = token_counts.get(event_token, 0) + 1

    for event_token, count in token_counts.items():
        rows.append(rows_in_chunk)
        cols.append(token_to_column[event_token])
        vals.append(count)

    rows_in_chunk += 1

    if rows_in_chunk >= buffer_size:
        chunks.append(_triplets_to_csr(rows, cols, vals, rows_in_chunk))
        rows, cols, vals = [], [], []
        rows_in_chunk = 0

# Flush the final, partially filled chunk
if rows_in_chunk:
    chunks.append(_triplets_to_csr(rows, cols, vals, rows_in_chunk))

if chunks:
    patient_freq_matrix = vstack(chunks, format="csr")
else:
    # No patients at all: keep a well-formed, empty matrix instead of None
    patient_freq_matrix = csr_matrix((0, num_feature_columns), dtype=np.int64)


save_npz(FREQ_MATRIX_PATH, patient_freq_matrix)
print(f"Save & Done! Final Matrix Shape: {patient_freq_matrix.shape}")

Loading Tokens: 3945 Patients [00:10, 386.67 Patients/s]

In [None]:
# Load frequency matrix
# Uncomment the next line to reuse the matrix previously saved at FREQ_MATRIX_PATH
# instead of rebuilding it with the cell above.
# patient_freq_matrix = load_npz(FREQ_MATRIX_PATH)
num_patients = patient_freq_matrix.shape[0]  # number of rows in the frequency matrix
patient_freq_matrix

In [None]:
def find_min_greater_than_zero(lst):
    """Return the smallest strictly positive value in `lst`, or 0 if there is none."""
    values = np.array(lst)
    positives = values[values > 0]

    # Guard clause: nothing above zero (including empty input)
    if len(positives) == 0:
        return 0

    return np.min(positives)


# Get extra features
# The age tokens live in the "age_tokens_2048" column (there is no "age_tokens"
# column in `data`, so the previous lookup raised KeyError).
age_token_sequences = data["age_tokens_2048"]

num_visits = data["num_visits"].values
# Smallest age token above zero per patient (0 if the sequence has none)
min_age = [
    find_min_greater_than_zero(patient_age_tokens)
    for patient_age_tokens in age_token_sequences
]
# Largest age token per patient
max_age = [np.max(patient_age_tokens) for patient_age_tokens in age_token_sequences]

# Add extra features to the frequency dataset as three trailing columns:
# num_visits, min_age, max_age (in that order).
# NOTE: this cell is not idempotent. Re-running it appends the three columns
# again and drops another leading column; rebuild/reload the matrix first.
extra_features = csr_matrix([num_visits, min_age, max_age]).T
patient_freq_matrix = hstack([patient_freq_matrix, extra_features], format="csr")
patient_freq_matrix = patient_freq_matrix[:, 1:]  # Drop id feature
patient_freq_matrix

In [None]:
# Get intuition about the frequency of different features in the dataset
report_threshold = 50

# Number of patients (rows) with a nonzero entry, per feature (column)
patients_per_feature = patient_freq_matrix.getnnz(axis=0)

# ">=" so the count matches the "at least" wording of the report below
# (the previous strict ">" excluded features seen in exactly `report_threshold` patients).
features_above_threshold = np.count_nonzero(patients_per_feature >= report_threshold)
print(
    f"How many features have been reported for at least {report_threshold} patients?\n{features_above_threshold} Features"
)

# Plot the histogram of feature frequency
# plt.hist(patient_freq_matrix.getnnz(axis=0), bins=range(num_patients+1), edgecolor='black')
# plt.xlabel('Number of Nonzero Rows')
# plt.ylabel('Number of Columns')
# plt.title('Histogram of Nonzero Rows per Column')
# plt.show()

In [None]:
# Pick features to train the model on: the NUM_FEATURES columns that are
# nonzero for the largest number of patients.
NUM_FEATURES = 10000
features_sorted_by_freq = np.argsort(-patient_freq_matrix.getnnz(axis=0))
# Slice exactly NUM_FEATURES columns. The previous `NUM_FEATURES + 1` selected
# one column more than the constant states; the id column is already dropped
# from the matrix, so there is nothing extra to compensate for.
# NOTE(review): confirm the "+ 1" was not intentional.
selected_features = features_sorted_by_freq[:NUM_FEATURES]
selected_features

In [None]:
# Define custom labels: positive when `death_after_end` lies in [0, 180].
# The previous condition `== 0` combined with `<= 180` only ever matched a value
# of exactly 0, which made the upper bound meaningless.
# NOTE(review): the original comment said "death in 12 M", but the threshold is
# 180 (presumably days, i.e. ~6 months) — confirm the intended horizon.
# Rows with a missing `death_after_end` compare as False and get label 0.
data["label"] = data["death_after_end"].between(0, 180).astype(int)

print(f"Total positive labels: {sum(data['label'])} out of {len(data)}")

In [None]:
# Assemble the design matrix (selected columns only) and the target vector
X = patient_freq_matrix[:, selected_features]
Y = data["label"].to_numpy()

# Feature scaling was tried and left disabled: it didn't improve performance
# scaler = MaxAbsScaler()
# X = scaler.fit_transform(X)

# Stratified 80/20 train/test split with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y,
    test_size=0.2,
    random_state=1,
    stratify=Y,
)

In [None]:
# Single XGBoost Classifier
# Binary logistic objective with otherwise default hyperparameters; the seed is
# fixed so repeated runs train the same model. Fitted on the training split only.
xgb_model = xgb.XGBClassifier(objective="binary:logistic", random_state=23)
xgb_model.fit(X_train, y_train)

In [None]:
### ASSESS MODEL PERFORMANCE ###

# Hard class predictions for train, test, and all data.
# These feed the threshold-based metrics (balanced accuracy, F1, precision, recall).
y_train_pred = xgb_model.predict(X_train)
y_test_pred = xgb_model.predict(X_test)
all_data_pred = xgb_model.predict(X)

# Positive-class probabilities for the ranking metrics (AUROC, PR curve, average
# precision, ROC curve). Feeding hard 0/1 labels into these collapses each curve
# to a single operating point and misreports the scores.
y_train_proba = xgb_model.predict_proba(X_train)[:, 1]
y_test_proba = xgb_model.predict_proba(X_test)[:, 1]
all_data_proba = xgb_model.predict_proba(X)[:, 1]

# NOTE: "All Data" includes the training rows, so it is not a held-out estimate.

# Balanced Accuracy
y_train_accuracy = balanced_accuracy_score(y_train, y_train_pred)
y_test_accuracy = balanced_accuracy_score(y_test, y_test_pred)
all_data_accuracy = balanced_accuracy_score(Y, all_data_pred)

# F1 Score
y_train_f1 = f1_score(y_train, y_train_pred)
y_test_f1 = f1_score(y_test, y_test_pred)
all_data_f1 = f1_score(Y, all_data_pred)

# Precision
y_train_precision = precision_score(y_train, y_train_pred)
y_test_precision = precision_score(y_test, y_test_pred)
all_data_precision = precision_score(Y, all_data_pred)

# Recall
y_train_recall = recall_score(y_train, y_train_pred)
y_test_recall = recall_score(y_test, y_test_pred)
all_data_recall = recall_score(Y, all_data_pred)

# AUROC
y_train_auroc = roc_auc_score(y_train, y_train_proba)
y_test_auroc = roc_auc_score(y_test, y_test_proba)
all_data_auroc = roc_auc_score(Y, all_data_proba)

# AUC-PR (Area Under the Precision-Recall Curve)
y_train_p, y_train_r, _ = precision_recall_curve(y_train, y_train_proba)
y_test_p, y_test_r, _ = precision_recall_curve(y_test, y_test_proba)
all_data_p, all_data_r, _ = precision_recall_curve(Y, all_data_proba)

y_train_auc_pr = auc(y_train_r, y_train_p)
y_test_auc_pr = auc(y_test_r, y_test_p)
all_data_auc_pr = auc(all_data_r, all_data_p)

# Average Precision Score (APS)
y_train_aps = average_precision_score(y_train, y_train_proba)
y_test_aps = average_precision_score(y_test, y_test_proba)
all_data_aps = average_precision_score(Y, all_data_proba)

# Print Metrics
print(
    f"Balanced Accuracy\nTrain: {y_train_accuracy:.5f}  |  Test: {y_test_accuracy:.5f}  |  All Data: {all_data_accuracy:.5f}\n"
)
print(
    f"F1 Score\nTrain: {y_train_f1:.5f}  |  Test: {y_test_f1:.5f}  |  All Data: {all_data_f1:.5f}\n"
)
print(
    f"Precision\nTrain: {y_train_precision:.5f}  |  Test: {y_test_precision:.5f}  |  All Data: {all_data_precision:.5f}\n"
)
print(
    f"Recall\nTrain: {y_train_recall:.5f}  |  Test: {y_test_recall:.5f}  |  All Data: {all_data_recall:.5f}\n"
)
print(
    f"AUROC\nTrain: {y_train_auroc:.5f}  |  Test: {y_test_auroc:.5f}  |  All Data: {all_data_auroc:.5f}\n"
)
print(
    f"AUC-PR\nTrain: {y_train_auc_pr:.5f}  |  Test: {y_test_auc_pr:.5f}  |  All Data: {all_data_auc_pr:.5f}\n"
)
print(
    f"Average Precision Score\nTrain: {y_train_aps:.5f}  |  Test: {y_test_aps:.5f}  |  All Data: {all_data_aps:.5f}\n"
)

# Plot ROC Curve (from probabilities, so the curve has more than one operating point)
fpr_train, tpr_train, _ = roc_curve(y_train, y_train_proba)
fpr_test, tpr_test, _ = roc_curve(y_test, y_test_proba)
fpr_all_data, tpr_all_data, _ = roc_curve(Y, all_data_proba)

# Plot Information
plt.figure(figsize=(10, 7))
plt.plot(fpr_train, tpr_train, label=f"Train AUROC={y_train_auroc:.2f}")
plt.plot(fpr_test, tpr_test, label=f"Test AUROC={y_test_auroc:.2f}")
plt.plot(fpr_all_data, tpr_all_data, label=f"All Data AUROC={all_data_auroc:.2f}")
plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Random")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC) Curve")
plt.legend()
plt.show()

In [None]:
### Assess which features are the most important

# Importance score per trained feature, as exposed by the fitted model
feature_importances = xgb_model.feature_importances_

# Pair every selected column index with its importance, highest importance first
feature_importance_pairs = list(zip(selected_features, feature_importances))
sorted_importances = sorted(
    feature_importance_pairs, key=lambda pair: pair[1], reverse=True
)

# Display the top 10 most important features
top_features = sorted_importances[:10]
for feature, importance in top_features:
    print(f"{feature}: {importance}")

In [None]:
### SCRIPT FOR K-FOLD VALIDATION ###
N_FOLDS = 10

# Initialize StratifiedKFold (shuffled, fixed seed, class ratio kept per fold)
stratified_kfold = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

# Initialize lists to store performance metrics for each fold
accuracy_scores = []
f1_scores = []
precisions = []
recalls = []
aurocs = []
auc_prs = []
average_precision_scores = []

# Perform k-fold cross-validation
for train_index, test_index in tqdm(
    stratified_kfold.split(X, Y),
    total=N_FOLDS,
    desc=f"{N_FOLDS}-Fold Validation",
    unit=" Model(s)",
):
    # Get the relevant train and test data
    X_train_fold, X_test_fold = X[train_index], X[test_index]
    y_train_fold, y_test_fold = Y[train_index], Y[test_index]

    # Create a new XGBoost model for each fold. It gets its own name so the
    # single model trained earlier (`xgb_model`) is not silently overwritten.
    fold_model = xgb.XGBClassifier(objective="binary:logistic", random_state=23)

    # Train the model on the training fold
    fold_model.fit(X_train_fold, y_train_fold)

    # Hard labels for threshold-based metrics; positive-class probabilities for
    # ranking metrics (AUROC, PR curve, average precision), which are
    # misreported when computed from 0/1 predictions.
    y_pred_fold = fold_model.predict(X_test_fold)
    y_proba_fold = fold_model.predict_proba(X_test_fold)[:, 1]

    # Calculate performance metrics
    accuracy_fold = balanced_accuracy_score(y_test_fold, y_pred_fold)
    f1_fold = f1_score(y_test_fold, y_pred_fold)
    precision_fold = precision_score(y_test_fold, y_pred_fold)
    recall_fold = recall_score(y_test_fold, y_pred_fold)
    auroc_fold = roc_auc_score(y_test_fold, y_proba_fold)
    p_fold, r_fold, _ = precision_recall_curve(y_test_fold, y_proba_fold)
    auc_pr_fold = auc(r_fold, p_fold)
    average_precision_score_fold = average_precision_score(y_test_fold, y_proba_fold)

    # Append metrics to lists
    accuracy_scores.append(accuracy_fold)
    f1_scores.append(f1_fold)
    precisions.append(precision_fold)
    recalls.append(recall_fold)
    aurocs.append(auroc_fold)
    auc_prs.append(auc_pr_fold)
    average_precision_scores.append(average_precision_score_fold)

# Print average metrics across all folds
print(f"Average Balanced Accuracy: {np.mean(accuracy_scores):.5f}")
print(f"Average F1 Score: {np.mean(f1_scores):.5f}")
print(f"Average Precision: {np.mean(precisions):.5f}")
print(f"Average Recall: {np.mean(recalls):.5f}")
print(f"Average AUROC: {np.mean(aurocs):.5f}")
print(f"Average AUC-PR: {np.mean(auc_prs):.5f}")
print(f"Average Precision Score: {np.mean(average_precision_scores):.5f}")