In [1]:
from models.encoder import QCLR_Classifier, COMET_Classifier
from tasks.fine_tuning import finetune_predict
from config_files.ASAN_Configs import Config as Configs
import pickle
import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
import random
import copy
import sklearn
from utils import seed_everything
import matplotlib.pyplot as plt
from dtaidistance import dtw
import pingouin as pg
from scipy.stats import iqr as scipy_iqr
from scipy.stats import wilcoxon
from sklearn.metrics.pairwise import cosine_similarity

In [3]:
import warnings

# --- Experiment configuration and reproducibility ---
configs = Configs()
RANDOM_SEED = configs.RANDOM_SEED
seed_everything(RANDOM_SEED)
warnings.filterwarnings('ignore')  # Suppress simple warnings

# --- Output directories ---
# exist_ok=True creates the directory only when it is missing, without the
# race between a separate os.path.exists() check and os.makedirs().
working_directory = configs.working_directory
dataset_save_path = working_directory
os.makedirs(working_directory, exist_ok=True)
logging_directory = configs.logging_directory
os.makedirs(logging_directory, exist_ok=True)

# --- Device selection ---
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Set the GPU 0 to use
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"The program will run on {device}!")

# --- Loop through each fold ---
N_FOLDS = 5  # number of folds iterated by the evaluation loop

# Accumulators pooled over all folds. For every indicator there is one list per
# retrieval method, in the order: QCLR ("my"), COMET ("base"), Random, FeatureAvg.

# SOFA score
all_folds_my_dtw, all_folds_base_dtw, all_folds_random_dtw, all_folds_featavg_dtw = (
    [], [], [], [])

# Total bilirubin
(all_folds_my_tbil_dtw, all_folds_base_tbil_dtw,
 all_folds_random_tbil_dtw, all_folds_featavg_tbil_dtw) = [], [], [], []

# Platelet count
(all_folds_my_plt_dtw, all_folds_base_plt_dtw,
 all_folds_random_plt_dtw, all_folds_featavg_plt_dtw) = [], [], [], []

# Lactic acid
(all_folds_my_latic_dtw, all_folds_base_latic_dtw,
 all_folds_random_latic_dtw, all_folds_featavg_latic_dtw) = [], [], [], []

def get_sequence(patient_data, seq_length, feature_index):
    """
    Extract the time series of one feature from a single patient's data.

    `patient_data` is indexed as [time step, feature]. The returned slice keeps
    time steps 0..seq_length inclusive, i.e. `seq_length` is used as the index
    of the last step to keep. Returns None when `seq_length` is not positive.
    """
    # NOTE(review): seq_length == 0 (a single time step) is rejected here, while
    # the FeatureAvg averaging accepts it (it uses seq + 1 > 0) -- confirm intent.
    if seq_length <= 0:
        return None
    return patient_data[:seq_length + 1, feature_index]


def _mean_dtw_per_feature(ref_seqs, neighbour_indices, pool_X, pool_seq, feature_indices):
    """
    Average, per feature, the DTW distance between a reference patient and its neighbours.

    ref_seqs          : {feature name: reference time series}
    neighbour_indices : indices into pool_X / pool_seq of the retrieved patients
    feature_indices   : {feature name: column index in the feature axis}

    Each feature of the reference is compared with the SAME feature of the
    neighbour. Neighbours for which get_sequence yields no series are skipped.
    Returns {feature name: mean distance}; every value is NaN when no neighbour
    was usable, so the four indicators always stay aligned.
    """
    distances = {name: [] for name in feature_indices}
    for nb_idx in neighbour_indices:
        nb_len = pool_seq[nb_idx].astype(int)
        nb_seqs = {name: get_sequence(pool_X[nb_idx], nb_len, col)
                   for name, col in feature_indices.items()}
        if any(seq is None or len(seq) == 0 for seq in nb_seqs.values()):
            continue
        for name, nb_seq in nb_seqs.items():
            distances[name].append(dtw.distance(ref_seqs[name], nb_seq))
    return {name: (np.mean(vals) if vals else np.nan) for name, vals in distances.items()}


def _average_feature_vectors(X, seq):
    """Per patient, mean of every feature over time steps 0..seq[p] (zeros if none)."""
    averaged = np.zeros((X.shape[0], X.shape[2]))
    for p in range(X.shape[0]):
        n_steps = seq[p].astype(int) + 1
        if n_steps > 0:
            averaged[p, :] = np.mean(X[p, :n_steps, :], axis=0)
    return averaged


def _load_encoder(model_cls, checkpoint_path):
    """Build `model_cls` from `configs`, load its checkpoint and switch to eval mode."""
    encoder = model_cls(input_dims=configs.input_dims,
                        output_dims=configs.output_dims,
                        depth=configs.depth,
                        p_output_dims=configs.num_classes, device=device,
                        flag_use_multi_gpu=configs.flag_use_multi_gpu)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    encoder.load_state_dict(torch.load(checkpoint_path, map_location=device))
    encoder.eval()
    return encoder


K = 3  # Top-K similar patients
# Sepsis-related blood test indicators (SOFA score, Total Bilirubin, Platelet Count, Lactic Acid)
SOFA_FEATURE_INDEX = 56  # SOFA score (Index 56)
TBIL_FEATURE_INDEX = 27  # Total Bilirubin
PLT_FEATURE_INDEX = 36  # Platelet Count
LATIC_FEATURE_INDEX = 49  # Lactic Acid
FEATURE_INDICES = {
    "sofa": SOFA_FEATURE_INDEX,
    "tbil": TBIL_FEATURE_INDEX,
    "plt": PLT_FEATURE_INDEX,
    "latic": LATIC_FEATURE_INDEX,
}
METHODS = ("my", "base", "random", "featavg")  # QCLR, COMET, Random, FeatureAvg
target_features = [59, 60, 70]  # critical intervention (59-intubation, 60-ECMO, 70-surgery)
# Clinical decision support is often critical during the initial stages of
# treatment planning, so only the early phase (the first four time steps) of
# the high-risk test patients is used when evaluating similarity.
EARLY_STEPS = 4

for fold_idx in range(1, N_FOLDS + 1):
    print(f"\n--- Processing Fold {fold_idx} ---")

    # --- A. Load Data and Embeddings for the fold ---
    # SECURITY: pickle.load can execute arbitrary code; only load trusted fold files.
    with open('data/asan_fold_' + str(fold_idx) + '.pkl', 'rb') as f:
        fold_data = pickle.load(f)

    print(f"Fold {fold_idx}:")
    print("Train set size:", len(fold_data['X_train']),
          (fold_data['y_train'] == 0).sum(), (fold_data['y_train'] == 1).sum())
    print("Validation set size:", len(fold_data['X_val']),
          (fold_data['y_val'] == 0).sum(), (fold_data['y_val'] == 1).sum())
    print("Test set size:", len(fold_data['X_test']),
          (fold_data['y_test'] == 0).sum(), (fold_data['y_test'] == 1).sum())

    X_train = fold_data['X_train']
    y_train = fold_data['y_train']
    seq_train = fold_data['seq_train']
    id_train = fold_data['id_train']

    X_test = fold_data['X_test']
    y_test = fold_data['y_test']
    seq_test = fold_data['seq_test']
    id_test = fold_data['id_test']

    # The retrieval pool is the TRAINING set. The *_val names are kept for the
    # pool, but they hold copies of the training arrays, not the validation split.
    X_val = X_train.copy()
    seq_val = seq_train.copy()
    y_val = y_train.copy()
    id_val = id_train.copy()

    # Keep only test patients with at least one critical-intervention flag set
    # at any time step.
    mask = np.any(X_test[:, :, target_features] == 1, axis=(1, 2))
    indices = np.where(mask)[0]
    print('Number of Selected Patient (critical intervention): ', indices.shape)
    X_test = X_test[indices]
    y_test = y_test[indices]
    seq_test = seq_test[indices]
    id_test = id_test[indices]

    X_test_ori = X_test.copy()  # untruncated copy of the selected test patients
    X_test[:, EARLY_STEPS:, :] = 0      # blank everything after the early phase
    seq_test[:] = EARLY_STEPS - 1       # last kept time-step index for every test patient

    # Checkpoints are named after the fold index.
    RANDOM_SEED = fold_idx

    # --- QCLR embeddings ---
    model = _load_encoder(QCLR_Classifier,
                          "test_run/models/QCLR_ASAN/seed" + str(RANDOM_SEED) + ".pt")
    _, emb_val_tensor = finetune_predict(model, X_val, y_val)
    _, emb_test_tensor = finetune_predict(model, X_test, y_test)
    print('eSOFA Score (unique): ', np.unique(X_val[:, :, 56]))
    del model
    torch.cuda.empty_cache()

    # --- COMET embeddings ---
    model = _load_encoder(COMET_Classifier,
                          "test_run/models/COMET_ASAN/seed" + str(RANDOM_SEED) + ".pt")
    _, emb_val_base_tensor = finetune_predict(model, X_val, y_val)
    _, emb_test_base_tensor = finetune_predict(model, X_test, y_test)
    del model
    torch.cuda.empty_cache()

    # Convert embeddings to numpy (cosine_similarity normalises internally)
    my_emd_train = emb_val_tensor.cpu().numpy()
    my_emd_test = emb_test_tensor.cpu().numpy()
    base_emd_train = emb_val_base_tensor.cpu().numpy()
    base_emd_test = emb_test_base_tensor.cpu().numpy()

    n_train_patients = X_val.shape[0]
    n_test_patients = X_test.shape[0]

    print("Calculating average feature vectors for FeatureAvg baseline...")
    avg_vec_train = _average_feature_vectors(X_val, seq_val)
    avg_vec_test = _average_feature_vectors(X_test, seq_test)
    print("Average feature vectors calculated.")

    # --- Steps 1 & 2: Find similar patients and calculate DTW distances ---
    # fold_dtw[method][feature] -> one averaged DTW distance per reference patient
    fold_dtw = {method: {name: [] for name in FEATURE_INDICES} for method in METHODS}
    k_random = min(K, n_train_patients)

    # Repeat with each patient in the test set as a 'reference patient (anchor)'
    print(f"Calculating Top-{K} similar patients and SOFA DTW distances...")
    for ref_idx in range(n_test_patients):

        # 0. Extract the reference patient's time series for every indicator
        ref_len = seq_test[ref_idx].astype(int)
        ref_seqs = {name: get_sequence(X_test[ref_idx], ref_len, col)
                    for name, col in FEATURE_INDICES.items()}
        if ref_seqs["sofa"] is None or len(ref_seqs["sofa"]) == 0:
            continue  # Skip if the reference patient has no usable time series

        # 1. Retrieve neighbours from the pool with each method
        #    (cosine similarity, highest first, for the three learned/averaged spaces)
        similarities_my = cosine_similarity(my_emd_test[ref_idx].reshape(1, -1), my_emd_train)[0]
        similarities_base = cosine_similarity(base_emd_test[ref_idx].reshape(1, -1), base_emd_train)[0]
        similarities_featavg = cosine_similarity(avg_vec_test[ref_idx].reshape(1, -1), avg_vec_train)[0]
        neighbours = {
            "my": np.argsort(similarities_my)[::-1][:K],
            "base": np.argsort(similarities_base)[::-1][:K],
            # An int population is sampled as np.arange(n_train_patients).
            "random": np.random.choice(n_train_patients, k_random, replace=False),
            "featavg": np.argsort(similarities_featavg)[::-1][:K],
        }

        # 2. Average DTW distance between the reference and its neighbours,
        #    feature against the same feature. NaN is stored when it cannot be
        #    computed; the summary statistics downstream are NaN-aware.
        for method in METHODS:
            mean_dtw = _mean_dtw_per_feature(ref_seqs, neighbours[method],
                                             X_val, seq_val, FEATURE_INDICES)
            for name, value in mean_dtw.items():
                fold_dtw[method][name].append(value)

    # --- E. Extend Aggregate Lists (End of Fold Loop) ---
    all_folds_my_dtw.extend(fold_dtw["my"]["sofa"])
    all_folds_base_dtw.extend(fold_dtw["base"]["sofa"])
    all_folds_random_dtw.extend(fold_dtw["random"]["sofa"])
    all_folds_featavg_dtw.extend(fold_dtw["featavg"]["sofa"])

    all_folds_my_tbil_dtw.extend(fold_dtw["my"]["tbil"])
    all_folds_base_tbil_dtw.extend(fold_dtw["base"]["tbil"])
    all_folds_random_tbil_dtw.extend(fold_dtw["random"]["tbil"])
    all_folds_featavg_tbil_dtw.extend(fold_dtw["featavg"]["tbil"])

    all_folds_my_plt_dtw.extend(fold_dtw["my"]["plt"])
    all_folds_base_plt_dtw.extend(fold_dtw["base"]["plt"])
    all_folds_random_plt_dtw.extend(fold_dtw["random"]["plt"])
    all_folds_featavg_plt_dtw.extend(fold_dtw["featavg"]["plt"])

    all_folds_my_latic_dtw.extend(fold_dtw["my"]["latic"])
    all_folds_base_latic_dtw.extend(fold_dtw["base"]["latic"])
    all_folds_random_latic_dtw.extend(fold_dtw["random"]["latic"])
    all_folds_featavg_latic_dtw.extend(fold_dtw["featavg"]["latic"])

    print("Calculation finished.")
    print(
        f"Number of valid paired comparisons: {len(fold_dtw['my']['sofa'])}")

# --- Step 3: Test for statistical significance ---
def _dtw_summary(my_dtw, base_dtw, random_dtw, featavg_dtw):
    """Build the NaN-aware per-method summary columns for one indicator.

    The four arguments are the pooled DTW arrays of My Model, Base Model,
    Random and FeatureAvg, in that order.
    """
    per_method = [my_dtw, base_dtw, random_dtw, featavg_dtw]
    return {
        'Method': ['My Model', 'Base Model', 'Random', 'FeatureAvg'],
        'Mean DTW': [np.nanmean(values) for values in per_method],
        'Std Dev DTW': [np.nanstd(values) for values in per_method],
        'Median DTW': [np.nanmedian(values) for values in per_method],
        'IQR DTW': [scipy_iqr(values, nan_policy='omit') for values in per_method],
        'N Valid': [np.sum(~np.isnan(values)) for values in per_method],
    }


# Convert the pooled per-patient lists to numpy arrays (same names are rebound).
all_folds_my_dtw, all_folds_base_dtw, all_folds_random_dtw, all_folds_featavg_dtw = (
    np.array(all_folds_my_dtw), np.array(all_folds_base_dtw),
    np.array(all_folds_random_dtw), np.array(all_folds_featavg_dtw))

(all_folds_my_tbil_dtw, all_folds_base_tbil_dtw,
 all_folds_random_tbil_dtw, all_folds_featavg_tbil_dtw) = (
    np.array(all_folds_my_tbil_dtw), np.array(all_folds_base_tbil_dtw),
    np.array(all_folds_random_tbil_dtw), np.array(all_folds_featavg_tbil_dtw))

(all_folds_my_plt_dtw, all_folds_base_plt_dtw,
 all_folds_random_plt_dtw, all_folds_featavg_plt_dtw) = (
    np.array(all_folds_my_plt_dtw), np.array(all_folds_base_plt_dtw),
    np.array(all_folds_random_plt_dtw), np.array(all_folds_featavg_plt_dtw))

(all_folds_my_latic_dtw, all_folds_base_latic_dtw,
 all_folds_random_latic_dtw, all_folds_featavg_latic_dtw) = (
    np.array(all_folds_my_latic_dtw), np.array(all_folds_base_latic_dtw),
    np.array(all_folds_random_latic_dtw), np.array(all_folds_featavg_latic_dtw))

# eSOFA score
results_summary = _dtw_summary(
    all_folds_my_dtw, all_folds_base_dtw, all_folds_random_dtw, all_folds_featavg_dtw)
results_df = pd.DataFrame(results_summary)
print("\n--- Summary Statistics (eSOFA Score) ---")
print(results_df.to_string(index=False, float_format="%.4f"))

# Total bilirubin
results_summary_tbil = _dtw_summary(
    all_folds_my_tbil_dtw, all_folds_base_tbil_dtw,
    all_folds_random_tbil_dtw, all_folds_featavg_tbil_dtw)
results_df_tbil = pd.DataFrame(results_summary_tbil)
print("\n--- Summary Statistics (Total Bilirubin) ---")
print(results_df_tbil.to_string(index=False, float_format="%.4f"))

# Platelet count
results_summary_plt = _dtw_summary(
    all_folds_my_plt_dtw, all_folds_base_plt_dtw,
    all_folds_random_plt_dtw, all_folds_featavg_plt_dtw)
results_df_plt = pd.DataFrame(results_summary_plt)
print("\n--- Summary Statistics (Platelet Count) ---")
print(results_df_plt.to_string(index=False, float_format="%.4f"))

# Lactic acid
results_summary_latic = _dtw_summary(
    all_folds_my_latic_dtw, all_folds_base_latic_dtw,
    all_folds_random_latic_dtw, all_folds_featavg_latic_dtw)
results_df_latic = pd.DataFrame(results_summary_latic)
print("\n--- Summary Statistics (Lactic Acid) ---")
print(results_df_latic.to_string(index=False, float_format="%.4f"))

print("\n--- Statistical Test (Wilcoxon Signed-Rank Test) ---")
alpha = 0.05  # default significance level used by perform_wilcoxon

# --- Function to perform and print Wilcoxon test ---
def perform_wilcoxon(data1, data2, label1, label2, alternative='less', significance_level=None):
    """
    Run a paired Wilcoxon signed-rank test with Pingouin and print a short report.

    Parameters
    ----------
    data1, data2 : paired samples, passed unchanged to pg.wilcoxon.
    label1, label2 : display names used in the printed report.
    alternative : forwarded to pg.wilcoxon ('less', 'greater' or 'two-sided').
        The printed alternative hypothesis follows this value.
    significance_level : threshold for the significance verdict. When None
        (default), the module-level `alpha` is used.

    Returns
    -------
    (p_val, rbc, cles) : the 'p-val', 'RBC' and 'CLES' entries of the first
        row of the Pingouin result table.
    """
    threshold = alpha if significance_level is None else significance_level
    # Symbol shown in the H1 line; unknown values are printed verbatim.
    relation = {'less': '<', 'greater': '>', 'two-sided': '!='}.get(alternative, alternative)

    print(f"\nComparison: {label1} vs {label2}")
    # Pingouin handles NaNs by default using pairwise deletion
    test_results = pg.wilcoxon(data1, data2, alternative=alternative)
    p_val = test_results['p-val'].iloc[0]
    rbc = test_results['RBC'].iloc[0] # Rank-Biserial Correlation
    cles = test_results['CLES'].iloc[0] # Common Language Effect Size

    print(f"   Wilcoxon Test (H1: {label1} {relation} {label2}):")
    print(f"     P-value: {p_val:.4g}") # Use general format for small p-values
    print(f"     Effect Size (Rank-Biserial Correlation, RBC): {rbc:.4f}")
    print(f"     Effect Size (Common Language Effect Size, CLES): {cles:.4f}")

    if p_val < threshold:
        print(f"     Result: Statistically significant difference (p < {threshold}).")
    else:
        print(f"     Result: No statistically significant difference (p >= {threshold}).")
    return p_val, rbc, cles

# Perform tests
# One row per indicator: (label, My Model, Base Model, Random, FeatureAvg) pooled DTW values.
# NOTE: 12 one-sided tests are run at the same threshold without any
# multiple-comparison correction.
_dtw_by_indicator = [
    ("eSOFA Score", all_folds_my_dtw, all_folds_base_dtw,
     all_folds_random_dtw, all_folds_featavg_dtw),
    ("Total Bilirubin", all_folds_my_tbil_dtw, all_folds_base_tbil_dtw,
     all_folds_random_tbil_dtw, all_folds_featavg_tbil_dtw),
    ("Platelet Count", all_folds_my_plt_dtw, all_folds_base_plt_dtw,
     all_folds_random_plt_dtw, all_folds_featavg_plt_dtw),
    ("Lactic Acid", all_folds_my_latic_dtw, all_folds_base_latic_dtw,
     all_folds_random_latic_dtw, all_folds_featavg_latic_dtw),
]

# wilcoxon_results[indicator][baseline] -> (p_val, rbc, cles). Every indicator
# is kept instead of being overwritten by the next one.
wilcoxon_results = {}
for indicator_label, my_dtw, base_dtw, random_dtw, featavg_dtw in _dtw_by_indicator:
    print(f"\n--- Statistical Tests ({indicator_label}) ---")
    # Dict entries are evaluated top to bottom, so the report order is
    # Base Model, Random, FeatureAvg.
    wilcoxon_results[indicator_label] = {
        "Base Model": perform_wilcoxon(my_dtw, base_dtw, "My Model", "Base Model"),
        "Random": perform_wilcoxon(my_dtw, random_dtw, "My Model", "Random"),
        "FeatureAvg": perform_wilcoxon(my_dtw, featavg_dtw, "My Model", "FeatureAvg"),
    }

# Backward-compatible names: as before, they end up holding the Lactic Acid results.
p_my_base, rbc_my_base, cles_my_base = wilcoxon_results["Lactic Acid"]["Base Model"]
p_my_random, rbc_my_random, cles_my_random = wilcoxon_results["Lactic Acid"]["Random"]
p_my_featavg, rbc_my_featavg, cles_my_featavg = wilcoxon_results["Lactic Acid"]["FeatureAvg"]


The program will run on cuda!

--- Processing Fold 1 ---
Fold 1:
Train set size: 9500 6405 3095
Train set size: 2374 1595 779
Train set size: 2969 2000 969
Number of Selected Patient (critical intervention):  (1149,)
eSOFA Score (unique):  [0.         0.00199203 0.20119522 0.40039841 0.59960159 0.79880478
 0.99800797]
Calculating average feature vectors for FeatureAvg baseline...
Average feature vectors calculated.
Calculating Top-3 similar patients and SOFA DTW distances...
Calculation finished.
Number of valid paired comparisons: 1149

--- Processing Fold 2 ---
Fold 2:
Train set size: 9500 6391 3109
Train set size: 2374 1609 765
Train set size: 2969 2000 969
Number of Selected Patient (critical intervention):  (1135,)
eSOFA Score (unique):  [0.         0.00199203 0.20119522 0.40039841 0.59960159 0.79880478
 0.99800797]
Calculating average feature vectors for FeatureAvg baseline...
Average feature vectors calculated.
Calculating Top-3 similar patients and SOFA DTW distances...
Calcula