**1. Dependencies and Setup**

In [1]:
!pip install catboost tslearn

Collecting catboost
  Downloading catboost-1.2.7-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting tslearn
  Downloading tslearn-0.6.3-py3-none-any.whl.metadata (14 kB)
Downloading catboost-1.2.7-cp310-cp310-manylinux2014_x86_64.whl (98.7 MB)
Downloading tslearn-0.6.3-py3-none-any.whl (374 kB)
Installing collected packages: tslearn, catboost
Successfully installed catboost-1.2.7 tslearn-0.6.3


**2. Importing Libraries**

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from scipy import stats
from sklearn.preprocessing import StandardScaler
import matplotlib.dates as mdates
from tslearn.preprocessing import TimeSeriesScalerMeanVariance

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

**3. Data Loading**

In [3]:
df = pd.read_csv("final_data_clipped.csv")

print(df.head())



   Unnamed: 0  patient_id    pr_display spo2_display resp_display  \
0           0        7001  71 beats/min          96%          17%   
1           1        7001  71 beats/min          96%          17%   
2           2        7001  71 beats/min          96%          17%   
3           3        7001  71 beats/min          96%          17%   
4           4        7001  71 beats/min          96%          17%   

   pulse_rate_obscount  pulse_rate_avg  pulse_rate_min  pulse_rate_max  \
0                 10.0           70.01           66.63           71.52   
1                 10.0           70.01           66.63           71.52   
2                 10.0           70.01           66.63           71.52   
3                 10.0           70.01           66.63           71.52   
4                 10.0           70.01           66.63           71.52   

   pulse_rate_iqr  ...  QC Deviation from median.1  \
0            2.88  ...                    0.047985   
1            2.88  ...          

**4. Exploratory Data Analysis and Cleaning**

In this section, the relevant features are identified and separate subsets are created for each treatment agent (JNJ or BMS).

In [4]:
complete_set = [
    'PT_ID','CRS on date (0 No, 1 Yes)','Agent (JNJ/BMS/Caribou)','datetime', 'spo2_avg', 'pulse_rate_avg','respiratory_rate_avg',
    'covered_skin_temperature_avg','covered_axil_temperature_avg','Highest Ferritin','Highest CRP','IL8','TNFRSF9','TIE2','MCP-3',
    'CD40-L','IL-1 alpha','CD244','EGF','ANGPT1','IL7','PGF','IL6','ADGRG1','MCP-1','CRTAM','CXCL11','MCP-4','TRAIL','FGF2','CXCL9',
    'CD8A','CAIX','MUC-16','ADA','CD4','NOS3','IL2','Gal-9','VEGFR-2','CD40','IL18','GZMH','KIR3DL1','LAP TGF-beta-1','CXCL1','TNFSF14',
    'IL33','TWEAK','PDGF subunit B','PDCD1','FASLG','CD28','CCL19','MCP-2','CCL4','IL15','Gal-1','PD-L1','CD27','CXCL5','IL5','HGF','GZMA',
    'HO-1','CX3CL1','CXCL10','CD70','IL10','TNFRSF12A','CCL23','CD5','CCL3','MMP7','ARG1','NCR1','DCN','TNFRSF21','TNFRSF4','MIC-A/B',
    'CCL17','ANGPT2','PTN','CXCL12','IFN-gamma','LAMP3','CASP-8','ICOSLG','MMP12','CXCL13','PD-L2','VEGFA','IL4','LAG3','IL12RB1','IL13',
    'CCL20','TNF','KLRD1','GZMB','CD83','IL12','CSF-1'
]

Different feature sets are selected for JNJ and BMS.

In [5]:
columns_test_JNJ = ['Agent (JNJ/BMS/Caribou)', 'CAIX', 'CASP-8', 'CCL23', 'CD40-L', 'CD70',
'CRS on date (0 No, 1 Yes)', 'CXCL10', 'CXCL11', 'CXCL13', 'FASLG',
'FGF2', 'GZMB', 'GZMH', 'Highest CRP', 'Highest Ferritin', 'IFN-gamma',
'IL10', 'IL13', 'IL15', 'IL6', 'IL8', 'MCP-2', 'MMP12', 'PT_ID',
'TIE2', 'TNFRSF9', 'TNFSF14', 'covered_skin_temperature_avg', 'datetime',
'pulse_rate_avg', 'respiratory_rate_avg', 'spo2_avg']
columns_test_BMS = ['FASLG', 'MCP-1', 'CD8A', 'CD70', 'CCL19', 'Highest CRP', 'KLRD1', 'TNFRSF9', 'CXCL12', 'ADGRG1', 'IL2', 'CXCL11', 'GZMH', 'TRAIL', 'IL5', 'TNFSF14', 'HO-1', 'CXCL1', 'CXCL5', 'CD244',
 'PT_ID', 'CRS on date (0 No, 1 Yes)', 'Agent (JNJ/BMS/Caribou)', 'datetime', 'spo2_avg', 'pulse_rate_avg', 'respiratory_rate_avg', 'covered_skin_temperature_avg', 'IL8', 'IL6', 'CXCL10',
 'IFN-gamma', 'CCL23', 'CASP-8', 'CXCL13']
df_subset_JNJ = df[columns_test_JNJ]
df_subset_BMS = df[columns_test_BMS]
df_subset_JNJ.head(5)
df_subset_BMS.head(5)

Unnamed: 0,FASLG,MCP-1,CD8A,CD70,CCL19,Highest CRP,KLRD1,TNFRSF9,CXCL12,ADGRG1,...,pulse_rate_avg,respiratory_rate_avg,covered_skin_temperature_avg,IL8,IL6,CXCL10,IFN-gamma,CCL23,CASP-8,CXCL13
0,5.80632,11.55053,8.24133,3.73004,10.73964,22.6,4.53595,5.67664,2.73464,2.15725,...,70.01,24.35,27.65,5.57085,4.75454,9.13509,5.93862,10.8933,4.38744,6.72158
1,5.80632,11.55053,8.24133,3.73004,10.73964,22.6,4.53595,5.67664,2.73464,2.15725,...,70.01,24.35,27.65,5.57085,4.75454,9.13509,5.93862,10.8933,4.38744,6.72158
2,5.810654,11.572441,8.239465,3.734411,10.745863,22.574747,4.536448,5.677832,2.735678,2.15513,...,70.01,19.86,27.735,5.575702,4.773422,9.158341,6.003697,10.897528,4.40339,6.724164
3,5.814989,11.594352,8.237599,3.738781,10.752086,22.549495,4.536946,5.679023,2.736717,2.15301,...,70.01,15.37,27.82,5.580553,4.792303,9.181591,6.068774,10.901756,4.41934,6.726747
4,5.819323,11.616262,8.235734,3.743152,10.75831,22.524242,4.537444,5.680215,2.737755,2.15089,...,70.01,11.77,27.78,5.585405,4.811185,9.204842,6.133852,10.905983,4.435289,6.729331


In [6]:
# Keep the rows for each agent; the JNJ subset also includes rows labeled 'JNJ OOS'
df_JNJ = df_subset_JNJ[df_subset_JNJ['Agent (JNJ/BMS/Caribou)'].isin(['JNJ', 'JNJ OOS'])]
df_BMS = df_subset_BMS[df_subset_BMS['Agent (JNJ/BMS/Caribou)'] == 'BMS']

**5. Data Individualization and Baseline Adjustment**

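From here on, only the JNJ subset is used. Each patient's measurements are expressed relative to that patient's first recorded value (baseline subtraction). As a result, the features describe within-patient change rather than differences in absolute level between patients.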

In [7]:
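data = df_JNJ.copy()

# Order each patient's rows chronologically so that 'first' below is the earliest measurement
data['datetime'] = pd.to_datetime(data['datetime'])
data = data.sort_values(by=['PT_ID', 'datetime'])

# Per-patient first rows (kept for reference)
baseline = data.groupby('PT_ID').first().reset_index()

columns_to_individualize = [col for col in columns_test_JNJ if col not in ['PT_ID', 'CRS on date (0 No, 1 Yes)', 'Agent (JNJ/BMS/Caribou)', 'datetime']]

for col in columns_to_individualize:
    data[col] = pd.to_numeric(data[col], errors='coerce')
    # Subtract each patient's baseline (their first non-missing value) from every measurement
    data[col] = data[col] - data.groupby('PT_ID')[col].transform('first')

data.head()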

Unnamed: 0,Agent (JNJ/BMS/Caribou),CAIX,CASP-8,CCL23,CD40-L,CD70,"CRS on date (0 No, 1 Yes)",CXCL10,CXCL11,CXCL13,...,MMP12,PT_ID,TIE2,TNFRSF9,TNFSF14,covered_skin_temperature_avg,datetime,pulse_rate_avg,respiratory_rate_avg,spo2_avg
42639,JNJ,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,7004,0.0,0.0,0.0,0.0,2022-12-05 15:44:00,0.0,0.0,0.0
42640,JNJ,1.8e-05,-0.005952,-0.000287,0.004429,-0.001577,0,-0.001205,0.004758,-5.1e-05,...,-0.000666,7004,-0.000192,-0.001087,-0.000739,0.0,2022-12-05 15:45:00,0.0,0.0,0.0
42641,JNJ,3.6e-05,-0.011904,-0.000574,0.008858,-0.003154,0,-0.00241,0.009515,-0.000102,...,-0.001331,7004,-0.000384,-0.002174,-0.001478,0.01,2022-12-05 15:46:00,0.0,7.21,0.0
42642,JNJ,5.5e-05,-0.017857,-0.000861,0.013288,-0.004731,0,-0.003616,0.014273,-0.000153,...,-0.001997,7004,-0.000576,-0.003261,-0.002216,0.08,2022-12-05 15:47:00,0.0,4.605,0.0
42643,JNJ,7.3e-05,-0.023809,-0.001148,0.017717,-0.006308,0,-0.004821,0.01903,-0.000203,...,-0.002663,7004,-0.000767,-0.004347,-0.002955,0.02,2022-12-05 15:48:00,0.0,2.0,0.0
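The baseline subtraction above relies on `groupby(...).transform('first')`. This call broadcasts each patient's first non-missing value back onto all of that patient's rows. The following is a small sketch on made-up values. The IL6 numbers and timestamps are hypothetical.

```python
import pandas as pd

# Toy data (hypothetical values): two patients, three time points each
toy = pd.DataFrame({
    'PT_ID': [1, 1, 1, 2, 2, 2],
    'datetime': pd.to_datetime(['2022-01-01 00:00', '2022-01-01 00:01', '2022-01-01 00:02'] * 2),
    'IL6': [5.0, 6.5, 7.0, 3.0, 2.5, 4.0],
})

# Sort chronologically within each patient so 'first' is the earliest measurement
toy = toy.sort_values(['PT_ID', 'datetime'])

# transform('first') returns one value per row: that patient's first non-missing IL6
baseline = toy.groupby('PT_ID')['IL6'].transform('first')
toy['IL6_delta'] = toy['IL6'] - baseline

print(toy[['PT_ID', 'IL6', 'IL6_delta']])
```

Every patient's first row becomes 0. Later rows show the change from that starting point.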


**6. Feature Engineering: Rolling and Lagged Features**

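Additional features are generated to capture short-term trends and variability in the measurements. For each feature, two kinds of values are computed within each patient:

- lagged values at 30-row steps (30, 60, ..., 180 rows back);
- the rolling mean, standard deviation, minimum, and maximum over the preceding 180 rows.

The timestamps above show one row per minute, so this corresponds to lags every 30 minutes and a 3-hour look-back window. The windows are counted in rows rather than in time, however. They therefore match these durations only if sampling is regular and has no gaps.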
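Both operations are applied per patient, so lags and windows never reach into another patient's rows. The following sketch shows this on a toy frame. A lag of 2 rows and a 3-row window stand in for the notebook's 30-row lags and 180-row window. All values are hypothetical.

```python
import pandas as pd

# Toy data (hypothetical): two patients, already sorted by patient and time
toy = pd.DataFrame({
    'PT_ID': [1, 1, 1, 1, 2, 2, 2],
    'hr': [70.0, 72.0, 74.0, 76.0, 90.0, 88.0, 86.0],
})

lag = 2      # stands in for time_interval (30 rows)
window = 3   # stands in for rolling_size (180 rows)

g = toy.groupby('PT_ID')['hr']

# shift() operates within each patient, so the first `lag` rows of each patient are NaN
toy[f'hr_lag_{lag}'] = g.shift(lag)

# Trailing window over the current row and up to window - 1 earlier rows of the same patient;
# reset_index drops the PT_ID level so the result aligns with the original row index
toy[f'hr_rolling_mean_{window}'] = (
    g.rolling(window, min_periods=1).mean().reset_index(level=0, drop=True)
)

print(toy)
```

Patient 2's first rows have no lag value and start a fresh rolling mean (90.0). This confirms that nothing is carried over from patient 1.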

In [8]:
# Lag and rolling-window settings, counted in rows (the timestamps suggest one row per minute)
lag_size = 6          # number of lagged copies per feature
time_interval = 30    # spacing between lags in rows (about 30 minutes)

# Sort data by patient ID and datetime so shifts and windows run forward in time
data = data.sort_values(by=['PT_ID', 'datetime'])

# Function to create per-patient lagged features and rolling statistics
def add_past_features(data, columns, lag_size, time_interval):
    rolling_size = lag_size * time_interval  # 180 rows (about 3 hours)
    grouped = data.groupby('PT_ID')
    new_features = {}

    for col in columns:
        # Lagged values: 30, 60, ..., 180 rows back within the same patient
        for lag_base in range(1, lag_size + 1):
            lag = lag_base * time_interval
            new_features[f'{col}_lag_{lag}'] = grouped[col].shift(lag)

        # Rolling statistics over the current row and the preceding rolling_size - 1 rows
        rolling = grouped[col].rolling(rolling_size, min_periods=1)
        for stat in ['mean', 'std', 'min', 'max']:
            new_features[f'{col}_rolling_{stat}_{rolling_size}'] = (
                getattr(rolling, stat)().reset_index(level=0, drop=True)
            )

    # Add all new columns at once; inserting them one by one fragments the DataFrame
    # and triggers pandas' PerformanceWarning
    return pd.concat([data, pd.DataFrame(new_features, index=data.index)], axis=1)

# Add past features for selected columns
columns_to_process = [col for col in columns_test_JNJ if col not in ['PT_ID', 'CRS on date (0 No, 1 Yes)', 'Agent (JNJ/BMS/Caribou)', 'datetime']]
data = add_past_features(data, columns_to_process, lag_size, time_interval)

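The row-based windows above equal 30 or 180 minutes only if every patient has exactly one row per minute. If gaps are possible, pandas also supports offset-based windows on a datetime index. The sketch below computes a per-patient trailing mean over a time window. This is an alternative, not what the notebook does, and `time_rolling_mean` is a hypothetical helper. It assumes a recent pandas version, which requires timestamps to be sorted only within each patient.

```python
import pandas as pd

def time_rolling_mean(data, col, window='180min'):
    """Hypothetical helper: per-patient trailing mean over a time-based window.

    Assumes `data` is sorted by ['PT_ID', 'datetime'], 'datetime' is datetime64,
    and PT_ID has no missing values.
    """
    result = (
        data.set_index('datetime')
            .groupby('PT_ID')[col]
            .rolling(window, min_periods=1)   # window covers (t - window, t]
            .mean()
    )
    # Groups come out in sorted PT_ID order, which matches the row order of `data`
    return pd.Series(result.to_numpy(), index=data.index)

# Toy data (hypothetical): patient 1 has a 9-minute gap between its 2nd and 3rd rows
toy = pd.DataFrame({
    'PT_ID': [1, 1, 1, 2, 2],
    'datetime': pd.to_datetime(['2022-01-01 00:00', '2022-01-01 00:01', '2022-01-01 00:10',
                                '2022-01-01 00:00', '2022-01-01 00:05']),
    'hr': [60.0, 62.0, 70.0, 80.0, 90.0],
})
out = time_rolling_mean(toy, 'hr', window='5min')
print(out.tolist())  # [60.0, 61.0, 70.0, 80.0, 90.0]
```

A 2-row window would have averaged 62 and 70 at 00:10. The 5-minute window instead ignores the stale reading before the gap.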

**7. Creating the Target Column: CRS in 6 Hours**

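In this section, a binary target variable, CRS_in_6_hours, is created. A row is labeled 1 if, six hours after that row's timestamp, the patient is within a CRS episode. An episode runs from its first CRS-positive row to the first CRS-negative row that follows, or to the patient's last row if the episode does not end. Two consequences follow:

- Rows already inside an episode are also labeled 1, as long as the episode lasts at least six more hours.
- An episode that starts and ends entirely within the next six hours does not make the row positive.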
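The rule implemented in the function below can be written as a formula. Let $t$ be a row's timestamp. Let $[s_k, e_k]$ be the $k$-th CRS episode of the same patient, where $s_k$ is the timestamp of the first CRS-positive row. The end $e_k$ is the timestamp of the first CRS-negative row after it, or of the patient's last row if the episode never ends. The label $y(t)$ is then:

$$
y(t) = 1 \iff \exists k:\; s_k \le t + 6\,\mathrm{h} \le e_k \iff \exists k:\; s_k - 6\,\mathrm{h} \le t \le e_k - 6\,\mathrm{h}.
$$

For each episode, the positive rows form a window of the same length as the episode, shifted six hours earlier.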

In [9]:
from datetime import timedelta

def assign_crs_in_6_hours(data):
    """
    Assign CRS_in_6_hours for each row based on whether `datetime + 6 hours` falls within a CRS occurrence time frame.

    Parameters:
        data (DataFrame): Input DataFrame with 'PT_ID', 'datetime', and 'CRS on date (0 No, 1 Yes)' columns.

    Returns:
        DataFrame: Updated DataFrame with a new column 'CRS_in_6_hours'.
    """
    # Ensure 'datetime' is a datetime object
    data['datetime'] = pd.to_datetime(data['datetime'])
    data = data.sort_values(by=['PT_ID', 'datetime'])

    # Initialize a new column
    data['CRS_in_6_hours'] = 0

    # Process each patient group separately
    for pt_id, group in data.groupby('PT_ID'):
        # Sort by datetime for the current patient
        group = group.sort_values('datetime')

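        # Identify CRS episode start and end rows. fill_value=0 treats the first row as if it were
        # preceded by a non-CRS row, so an episode already under way at the first measurement is
        # still detected as a start and stays paired with its own end
        crs_flag = group['CRS on date (0 No, 1 Yes)']
        crs_start = group.index[(crs_flag.shift(1, fill_value=0) == 0) & (crs_flag == 1)].tolist()
        crs_end = group.index[(crs_flag.shift(1, fill_value=0) == 1) & (crs_flag == 0)].tolist()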

        # If a CRS event starts but does not end, assume it continues until the last datetime
        if len(crs_start) > len(crs_end):
            crs_end.append(group.index[-1])

        # Assign CRS_in_6_hours for each row
        for start_idx, end_idx in zip(crs_start, crs_end):
            crs_start_time = group.loc[start_idx, 'datetime']
            crs_end_time = group.loc[end_idx, 'datetime']

            # Any datetime + 6 hours within the CRS occurrence timeframe is set to 1
            within_crs_timeframe = (group['datetime'] + timedelta(hours=6) >= crs_start_time) & \
                                   (group['datetime'] + timedelta(hours=6) <= crs_end_time)
            data.loc[group[within_crs_timeframe].index, 'CRS_in_6_hours'] = 1

    return data

data = assign_crs_in_6_hours(data)
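
The column name could also suggest a different target: "any CRS recorded in the next six hours". That target can be computed directly from the CRS flag. The following sketch shows this alternative definition, which is not the one used in this notebook. `crs_within_next` is a hypothetical helper, and the toy data are made up.

```python
import numpy as np
import pandas as pd

FLAG = 'CRS on date (0 No, 1 Yes)'

def crs_within_next(data, horizon_hours=6, flag=FLAG):
    """Hypothetical helper: 1 if the same patient has a CRS-positive row in (t, t + horizon]."""
    horizon = np.timedelta64(horizon_hours, 'h')
    labels = pd.Series(0, index=data.index, dtype=int)
    for _, group in data.groupby('PT_ID'):
        group = group.sort_values('datetime')
        times = group['datetime'].to_numpy()
        crs_times = times[group[flag].to_numpy() == 1]  # sorted, since group is sorted
        # Count CRS timestamps <= t and <= t + horizon; a larger second count means
        # at least one CRS timestamp falls in (t, t + horizon]
        upto_now = np.searchsorted(crs_times, times, side='right')
        upto_horizon = np.searchsorted(crs_times, times + horizon, side='right')
        labels.loc[group.index] = (upto_horizon > upto_now).astype(int)
    return labels

# Toy data (hypothetical): patient 1 has a one-hour CRS episode starting at 05:00
toy = pd.DataFrame({
    'PT_ID': [1, 1, 1, 1, 1, 2, 2],
    'datetime': pd.to_datetime([
        '2022-01-01 00:00', '2022-01-01 03:00', '2022-01-01 05:00',
        '2022-01-01 06:00', '2022-01-01 12:00',
        '2022-01-01 00:00', '2022-01-01 01:00',
    ]),
    FLAG: [0, 0, 1, 0, 0, 0, 0],
})
print(crs_within_next(toy).tolist())  # [1, 1, 0, 0, 0, 0, 0]
```

For comparison, the notebook's rule labels only the 00:00 row of this toy patient as positive, because 03:00 + 6 h = 09:00 falls after the episode ends at 06:00.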


In [10]:
data[data['CRS_in_6_hours']==1]

Unnamed: 0,Agent (JNJ/BMS/Caribou),CAIX,CASP-8,CCL23,CD40-L,CD70,"CRS on date (0 No, 1 Yes)",CXCL10,CXCL11,CXCL13,...,spo2_avg_lag_60,spo2_avg_lag_90,spo2_avg_lag_120,spo2_avg_lag_150,spo2_avg_lag_180,spo2_avg_rolling_mean_180,spo2_avg_rolling_std_180,spo2_avg_rolling_min_180,spo2_avg_rolling_max_180,CRS_in_6_hours
63665,JNJ OOS,-1.396280,0.303270,0.435560,1.484800,0.071950,0,1.624240,1.996430,0.801890,...,-1.43,-2.82,-3.510,-3.630000,-3.270000,-2.650250,0.913485,-4.87,-0.68,1
63666,JNJ OOS,-1.396280,0.303270,0.435560,1.484800,0.071950,0,1.624240,1.996430,0.801890,...,-1.53,-2.69,-3.580,-3.550000,-3.320000,-2.641917,0.914183,-4.87,-0.68,1
63667,JNJ OOS,-1.396280,0.303270,0.435560,1.484800,0.071950,0,1.624240,1.996430,0.801890,...,-2.02,-3.42,-3.820,-3.620000,-3.170000,-2.634417,0.915364,-4.87,-0.68,1
63668,JNJ OOS,-1.396280,0.303270,0.435560,1.484800,0.071950,0,1.624240,1.996430,0.801890,...,-1.89,-3.59,-3.550,-3.630000,-3.070000,-2.633028,0.914888,-4.87,-0.68,1
63669,JNJ OOS,-1.396280,0.303270,0.435560,1.484800,0.071950,0,1.624240,1.996430,0.801890,...,-1.58,-3.49,-3.600,-3.630000,-2.940000,-2.626806,0.916596,-4.87,-0.68,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
313952,JNJ,-0.026034,0.619929,0.278777,-0.880114,0.880528,1,2.763175,1.728585,0.362329,...,5.03,1.17,7.225,4.342941,4.156250,5.047976,3.472780,-12.00,11.00,1
313953,JNJ,-0.025722,0.620340,0.278712,-0.880123,0.880709,1,2.763756,1.729170,0.362344,...,6.14,1.68,7.330,4.505000,3.984375,5.074729,3.483075,-12.00,11.00,1
313954,JNJ,-0.025411,0.620752,0.278647,-0.880132,0.880889,1,2.764337,1.729754,0.362359,...,7.25,4.87,7.435,4.667059,3.812500,5.100104,3.490457,-12.00,11.00,1
313955,JNJ,-0.025099,0.621163,0.278581,-0.880141,0.881070,1,2.764919,1.730338,0.362374,...,7.19,6.62,7.540,4.829118,3.640625,5.111712,3.489051,-12.00,11.00,1


**8. Model Training and Evaluation**

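In this section, three gradient-boosting models (LightGBM, CatBoost, XGBoost) are trained to predict CRS_in_6_hours. For LightGBM and XGBoost, 5-fold cross-validation is performed at the patient level: the unique patient IDs are split into folds, so all of a patient's rows fall in either the training set or the test set. Class imbalance is handled differently per model. LightGBM and XGBoost use random oversampling of the training folds, while CatBoost uses class weights. Performance is evaluated with accuracy, AUC-ROC, per-class recall, and classification reports.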
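The patient-level split used for LightGBM and XGBoost (`KFold` over `unique_patients`, then `isin`) can also be expressed with scikit-learn's `GroupKFold`. This splitter guarantees that no patient appears on both sides of a fold. A minimal sketch on hypothetical toy data follows. Note that `GroupKFold` balances the number of rows per fold rather than the number of patients, so its fold membership would differ from the notebook's.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold

# Toy data (hypothetical): 5 patients with 4 rows each
toy = pd.DataFrame({
    'PT_ID': np.repeat([101, 102, 103, 104, 105], 4),
    'feature': np.arange(20, dtype=float),
    'CRS_in_6_hours': np.tile([0, 0, 0, 1], 5),
})

gkf = GroupKFold(n_splits=5)
folds = []
for train_idx, test_idx in gkf.split(toy[['feature']], toy['CRS_in_6_hours'], groups=toy['PT_ID']):
    train_patients = set(toy['PT_ID'].iloc[train_idx])
    test_patients = set(toy['PT_ID'].iloc[test_idx])
    # Each patient's rows land entirely on one side of the split
    assert train_patients.isdisjoint(test_patients)
    folds.append(sorted(test_patients))

print(folds)
```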

**a. LightGBM Model with Oversampling**


A LightGBM classifier is trained on the oversampled training folds and evaluated on the held-out patients of each fold, using the default 0.5 probability cutoff.
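The "accuracy for class 0/1" printed at the end of the cell is per-class recall, computed from the confusion matrix summed over folds. For class 0 this is specificity, and for class 1 it is sensitivity. Passing `labels=[0, 1]` to `confusion_matrix` keeps every fold's matrix 2×2, which the element-wise sum requires. Without it, a fold whose labels and predictions are all 0 yields a 1×1 matrix, and NumPy would silently broadcast it onto all four cells. The following sketch uses made-up predictions.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical per-fold (y_true, y_pred); the second fold contains only class 0
fold_predictions = [
    (np.array([0, 0, 1, 1]), np.array([0, 1, 1, 0])),
    (np.array([0, 0, 0]), np.array([0, 0, 0])),
]

total = np.zeros((2, 2), dtype=int)
for y_true, y_pred in fold_predictions:
    # labels=[0, 1] forces a 2x2 matrix even when a fold has a single class
    total += confusion_matrix(y_true, y_pred, labels=[0, 1])

tn, fp, fn, tp = total.ravel()
specificity = tn / (tn + fp)   # reported as "accuracy for class 0"
sensitivity = tp / (tp + fn)   # reported as "accuracy for class 1"
print(total, specificity, sensitivity)
```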

In [11]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score, confusion_matrix
from imblearn.over_sampling import RandomOverSampler
from lightgbm import LGBMClassifier

if 'datetime' in data.columns:
    data = data.drop(columns=['datetime'])
if 'cutoff_date' in data.columns:
    data = data.drop(columns=['cutoff_date'])

# Define columns to drop (identifiers, target)
drop_cols = ['PT_ID', 'CRS_in_6_hours', 'CRS on date (0 No, 1 Yes)', 'Agent (JNJ/BMS/Caribou)']
feature_cols = [col for col in data.columns if col not in drop_cols]

unique_patients = data['PT_ID'].unique()
kf = KFold(n_splits=5, shuffle=False)

cv_accuracies = []
cv_auc_scores = []
aggregate_conf_matrix = np.array([[0, 0],
                                  [0, 0]])

for fold, (train_idx, test_idx) in enumerate(kf.split(unique_patients)):
    print(f"\nFold {fold + 1}/5")

    train_patients = unique_patients[train_idx]
    test_patients = unique_patients[test_idx]

    train_data = data[data['PT_ID'].isin(train_patients)]
    test_data = data[data['PT_ID'].isin(test_patients)]

    X_train = train_data[feature_cols]
    y_train = train_data['CRS_in_6_hours']

    X_test = test_data[feature_cols]
    y_test = test_data['CRS_in_6_hours']

    # Handle class imbalance with Random Oversampling
    oversampler = RandomOverSampler(random_state=42)
    X_train_resampled, y_train_resampled = oversampler.fit_resample(X_train, y_train)

    # Print class distribution after oversampling
    print("Class distribution after oversampling:")
    print(pd.Series(y_train_resampled).value_counts())


    lgbm_model = LGBMClassifier(
        objective='binary',
        max_depth=6,
        learning_rate=0.1,
        n_estimators=100,
        random_state=42
    )

    # Train the model
    lgbm_model.fit(X_train_resampled, y_train_resampled)

    # Predict on test set
    y_pred = lgbm_model.predict(X_test)
    y_prob = lgbm_model.predict_proba(X_test)[:, 1]

    # Calculate metrics
    accuracy = accuracy_score(y_test, y_pred)
    auc_score = roc_auc_score(y_test, y_prob)
    # labels=[0, 1] keeps the matrix 2x2 even if a test fold contains only one class
    conf_matrix = confusion_matrix(y_test, y_pred, labels=[0, 1])

    print("Accuracy:", accuracy)
    print("AUC-ROC Score:", auc_score)
    print("Classification Report:\n", classification_report(y_test, y_pred))

    cv_accuracies.append(accuracy)
    cv_auc_scores.append(auc_score)
    aggregate_conf_matrix += conf_matrix

# Aggregate results
print("\nCross-Validation Results:")
print(f"Average Accuracy: {np.mean(cv_accuracies)}")
print(f"Average AUC-ROC: {np.mean(cv_auc_scores)}")
print("Aggregated Confusion Matrix:\n", aggregate_conf_matrix)

tn, fp, fn, tp = aggregate_conf_matrix.ravel()
class0_accuracy = tn / (tn + fp) if (tn + fp) > 0 else 0
class1_accuracy = tp / (fn + tp) if (fn + tp) > 0 else 0

print(f"Accuracy for class 0: {class0_accuracy}")
print(f"Accuracy for class 1: {class1_accuracy}")



Fold 1/5
Class distribution after oversampling:
CRS_in_6_hours
0    108387
1    108387
Name: count, dtype: int64
[LightGBM] [Info] Number of positive: 108387, number of negative: 108387
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.474370 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 79654
[LightGBM] [Info] Number of data points in the train set: 216774, number of used features: 319
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
Accuracy: 0.8955329709288584
AUC-ROC Score: 0.8911451905504334
Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.92      0.94     40638
           1       0.16      0.37      0.22      1672

    accuracy                           0.90     42310
   macro avg       0.56      0.64      0.58     42310
weighted avg       0.94      0.90      0.92     42310


Fold 2/5
Class distribution

**b. CatBoost Model**


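A CatBoost classifier is also tested. Instead of oversampling, it handles the class imbalance with class weights (1 for class 0, 2 for class 1). In each fold, the decision threshold is chosen to maximize balanced accuracy. This cell's validation setup differs from the LightGBM and XGBoost cells in two ways:

- It splits X_train/y_train, the training rows left over from the last LightGBM fold, with a row-level StratifiedKFold rather than by patient. Rows from the same patient can therefore appear in both the training and validation folds.
- The threshold is tuned on the same fold it is evaluated on.

The resulting scores are therefore likely optimistic and not directly comparable with the patient-level results.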
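The threshold search in the cell below picks the cutoff that maximizes balanced accuracy, i.e. the mean of the two per-class recalls. The following vectorized NumPy sketch performs the same search. `best_threshold` is a hypothetical helper, and it assumes both classes are present. Tuning the threshold on training data or on a separate inner split, rather than on the fold being scored, would avoid the optimism noted above.

```python
import numpy as np

def best_threshold(y_true, y_prob, thresholds=np.linspace(0, 1, 101)):
    """Hypothetical helper: return (threshold, balanced_accuracy) maximizing mean per-class recall."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    pos = y_true == 1
    neg = ~pos
    # preds[i, j] is the 0/1 prediction for sample j at thresholds[i]
    preds = y_prob[None, :] >= thresholds[:, None]
    recall_1 = (preds & pos).sum(axis=1) / pos.sum()
    recall_0 = (~preds & neg).sum(axis=1) / neg.sum()
    bal_acc = 0.5 * (recall_0 + recall_1)
    best = int(np.argmax(bal_acc))  # first maximum, like the strict '>' update in the loop
    return thresholds[best], bal_acc[best]

th, bal_acc = best_threshold([0, 0, 0, 1, 1], [0.125, 0.235, 0.645, 0.575, 0.935])
print(round(th, 2), round(bal_acc, 4))  # 0.24 0.8333
```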

In [12]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score, recall_score
from catboost import CatBoostClassifier
import pandas as pd
import numpy as np


catboost_model = CatBoostClassifier(
    loss_function='Logloss',
    max_depth=6,
    learning_rate=0.1,
    iterations=100,
    random_seed=42,
    verbose=False,
    class_weights=[1, 2.0]
)


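# Note: this splits X_train/y_train (the training rows of the last LightGBM fold) at the row
# level, not by patient as in the LightGBM and XGBoost cells, so a patient's rows can fall in
# both the training and validation folds
kf = StratifiedKFold(n_splits=5, shuffle=False)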

fold_metrics = []
fold = 1

for train_idx, test_idx in kf.split(X_train, y_train):
    print(f"Fold {fold}/5")

    # Split data into train and test sets for this fold
    X_fold_train, X_fold_test = X_train.iloc[train_idx], X_train.iloc[test_idx]
    y_fold_train, y_fold_test = y_train.iloc[train_idx], y_train.iloc[test_idx]

    # Class distributions
    print(f"Training Class Distribution:\n{y_fold_train.value_counts()}")
    print(f"Validation Class Distribution:\n{y_fold_test.value_counts()}")

    # Train the CatBoost model directly with class weights
    catboost_model.fit(X_fold_train, y_fold_train)

    # Predict probabilities
    y_prob = catboost_model.predict_proba(X_fold_test)[:, 1]

    thresholds = np.linspace(0, 1, 101)
    best_bal_acc = 0.0
    best_threshold = 0.5

    for th in thresholds:
        y_pred_th = (y_prob >= th).astype(int)
        rec_class0 = recall_score(y_fold_test, y_pred_th, pos_label=0)
        rec_class1 = recall_score(y_fold_test, y_pred_th, pos_label=1)
        bal_acc = 0.5 * (rec_class0 + rec_class1)
        if bal_acc > best_bal_acc:
            best_bal_acc = bal_acc
            best_threshold = th

    # Use best threshold
    y_pred = (y_prob >= best_threshold).astype(int)

    acc = accuracy_score(y_fold_test, y_pred)
    auc = roc_auc_score(y_fold_test, y_prob)
    report = classification_report(y_fold_test, y_pred, digits=2, output_dict=True)

    # Extract metrics
    f1_class0 = report["0"]["f1-score"]
    f1_class1 = report["1"]["f1-score"]
    recall_class0 = report["0"]["recall"]
    recall_class1 = report["1"]["recall"]

    fold_metrics.append({
        "Fold": fold,
        "Best Threshold": best_threshold,
        "Accuracy": acc,
        "AUC-ROC": auc,
        "F1-Score (Class 0)": f1_class0,
        "F1-Score (Class 1)": f1_class1,
        "Recall (Class 0)": recall_class0,
        "Recall (Class 1)": recall_class1,
        "Balanced Accuracy": best_bal_acc
    })

    # Print fold results
    print(f"Best Threshold: {best_threshold:.3f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"AUC-ROC: {auc:.4f}")
    print(f"Balanced Accuracy: {best_bal_acc:.4f}")
    print("Classification Report:")
    print(classification_report(y_fold_test, y_pred, digits=2))
    print("-" * 50)

    fold += 1

fold_metrics_df = pd.DataFrame(fold_metrics)
print("\nCross-Validation Results:")
print(fold_metrics_df)

# Compute averages of metrics
avg_threshold = fold_metrics_df["Best Threshold"].mean()
avg_f1_class1 = fold_metrics_df["F1-Score (Class 1)"].mean()
avg_accuracy = fold_metrics_df["Accuracy"].mean()
avg_auc = fold_metrics_df["AUC-ROC"].mean()
avg_recall_class0 = fold_metrics_df["Recall (Class 0)"].mean()
avg_recall_class1 = fold_metrics_df["Recall (Class 1)"].mean()
avg_bal_acc = fold_metrics_df["Balanced Accuracy"].mean()

print("\nCross-Validation Results:")
print(f"Average Accuracy: {avg_accuracy}")
print(f"Average AUC-ROC: {avg_auc}")
print(f"Average Accuracy (Class 0): {avg_recall_class0}")  # Using recall as proxy
print(f"Average Accuracy (Class 1): {avg_recall_class1}")  # Using recall as proxy




Fold 1/5
Training Class Distribution:
CRS_in_6_hours
0    98070
1    15982
Name: count, dtype: int64
Validation Class Distribution:
CRS_in_6_hours
0    24518
1     3996
Name: count, dtype: int64
Best Threshold: 0.010
Accuracy: 0.7262
AUC-ROC: 0.6591
Balanced Accuracy: 0.5332
Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.80      0.83     24518
           1       0.18      0.27      0.21      3996

    accuracy                           0.73     28514
   macro avg       0.52      0.53      0.52     28514
weighted avg       0.77      0.73      0.75     28514

--------------------------------------------------
Fold 2/5
Training Class Distribution:
CRS_in_6_hours
0    98070
1    15983
Name: count, dtype: int64
Validation Class Distribution:
CRS_in_6_hours
0    24518
1     3995
Name: count, dtype: int64
Best Threshold: 0.010
Accuracy: 0.9298
AUC-ROC: 0.9799
Balanced Accuracy: 0.9160
Classification Report:
              precision 

**c. XGBoost Model Training**


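In this section, an XGBoost model is trained and evaluated with the same patient-level cross-validation and random oversampling as the LightGBM model. Instead of the default 0.5 cutoff, predictions use a fixed probability threshold of 0.01. This favors recall of the CRS class at the cost of more false positives. Performance metrics are calculated for comparison with the other models.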

In [13]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score, confusion_matrix
from imblearn.over_sampling import RandomOverSampler
import xgboost as xgb

if 'cutoff_date' in data.columns:
    data = data.drop(columns=['cutoff_date'])

if 'datetime' in data.columns:
    data = data.drop(columns=['datetime'])

drop_cols = ['PT_ID', 'CRS on date (0 No, 1 Yes)', 'Agent (JNJ/BMS/Caribou)', 'CRS_in_6_hours']
feature_cols = [col for col in data.columns if col not in drop_cols]

unique_patients = data['PT_ID'].unique()
kf = KFold(n_splits=5, shuffle=False)

cv_accuracies = []
cv_auc_scores = []
cv_classification_reports = []
aggregate_conf_matrix = np.array([[0, 0],
                                  [0, 0]])

for fold, (train_idx, test_idx) in enumerate(kf.split(unique_patients)):
    print(f"\nFold {fold + 1}/5")

    train_patients = unique_patients[train_idx]
    test_patients = unique_patients[test_idx]

    train_data = data[data['PT_ID'].isin(train_patients)]
    test_data = data[data['PT_ID'].isin(test_patients)]

    X_train = train_data[feature_cols]
    y_train = train_data['CRS_in_6_hours']

    X_test = test_data[feature_cols]
    y_test = test_data['CRS_in_6_hours']


    # Random Oversampling
    oversampler = RandomOverSampler(random_state=42)
    X_train_resampled, y_train_resampled = oversampler.fit_resample(X_train, y_train)

    dtrain = xgb.DMatrix(X_train_resampled, label=y_train_resampled)
    dtest = xgb.DMatrix(X_test, label=y_test)

    params = {
        'objective': 'binary:logistic',
        'eval_metric': 'logloss',
        'max_depth': 6,
        'learning_rate': 0.1
    }

    model = xgb.train(params, dtrain, num_boost_round=100)

    y_prob = model.predict(dtest)
    threshold = 0.01
    y_pred = (y_prob > threshold).astype(int)

    accuracy = accuracy_score(y_test, y_pred)
    auc_score = roc_auc_score(y_test, y_prob)
    # labels=[0, 1] keeps the matrix 2x2 even if a test fold contains only one class
    conf_matrix = confusion_matrix(y_test, y_pred, labels=[0, 1])
    class_report = classification_report(y_test, y_pred, output_dict=True)

    print("Accuracy:", accuracy)
    print("AUC-ROC Score:", auc_score)
    print("Classification Report:\n", classification_report(y_test, y_pred))

    cv_accuracies.append(accuracy)
    cv_auc_scores.append(auc_score)
    cv_classification_reports.append(class_report)
    aggregate_conf_matrix += conf_matrix

print("\nCross-Validation Results:")
print(f"Average Accuracy: {np.mean(cv_accuracies)}")
print(f"Average AUC-ROC: {np.mean(cv_auc_scores)}")
print("Aggregated Confusion Matrix:\n", aggregate_conf_matrix)

tn, fp, fn, tp = aggregate_conf_matrix.ravel()
class0_accuracy = tn / (tn + fp) if (tn + fp) > 0 else 0
class1_accuracy = tp / (fn + tp) if (fn + tp) > 0 else 0

print(f"Accuracy for class 0: {class0_accuracy}")
print(f"Accuracy for class 1: {class1_accuracy}")




Fold 1/5
Accuracy: 0.8677144883006381
AUC-ROC Score: 0.9142930500738108
Classification Report:
               precision    recall  f1-score   support

           0       0.98      0.88      0.93     40638
           1       0.18      0.66      0.28      1672

    accuracy                           0.87     42310
   macro avg       0.58      0.77      0.61     42310
weighted avg       0.95      0.87      0.90     42310


Fold 2/5
Accuracy: 0.7714589288876521
AUC-ROC Score: 0.8804348313109012
Classification Report:
               precision    recall  f1-score   support

           0       1.00      0.76      0.87     38124
           1       0.13      1.00      0.24      1405

    accuracy                           0.77     39529
   macro avg       0.57      0.88      0.55     39529
weighted avg       0.97      0.77      0.84     39529


Fold 3/5
Accuracy: 0.7884872197361947
AUC-ROC Score: 0.7805919170598046
Classification Report:
               precision    recall  f1-score   support



**9. Results Evaluation**

**CatBoost** emerged as the most effective model for predicting CRS_in_6_hours, outperforming LightGBM and XGBoost.

It achieved a higher reported AUC-ROC (around 0.90), indicating better discrimination between time points that precede CRS and those that do not. The other models struggled with the minority class (imminent CRS cases). CatBoost, with class weights and a per-fold tuned threshold, maintained more balanced performance across both classes.

This matters in a clinical setting, where correctly identifying even a small number of high-risk patients can change treatment decisions. Two caveats apply before CatBoost is adopted for guiding early interventions:

- It was validated with a row-level rather than a patient-level split.
- Its decision threshold was tuned on the same fold it was scored on.

Both make its scores likely optimistic and not directly comparable with the patient-level LightGBM and XGBoost results. Re-evaluating CatBoost under the same patient-level protocol would confirm whether the ranking holds.