In [None]:
import os
import random
from typing import Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.calibration import calibration_curve
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
# NOTE: `auc_score` is not exported by sklearn.metrics (importing it raises ImportError);
# `auc` is the function the ROC helper below actually calls.
from sklearn.metrics import (
    auc,
    average_precision_score,
    classification_report,
    confusion_matrix,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from sklearn.utils.class_weight import compute_class_weight

import shap
import optuna

import torch
import torch.nn as nn
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier

Some useful global constants and setting the seed

In [25]:
# --- Reproducibility / hardware -------------------------------------------
SEED = 7
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Data preparation ------------------------------------------------------
# Features whose absolute pairwise correlation reaches this value are dropped.
CORR_THRESHOLD = 0.85
# Fraction of all rows held out as the test set.
TEST_SIZE = 0.30
# Fraction of the remaining (non-test) rows held out for validation.
VAL_SIZE = 0.20

# --- Training budget -------------------------------------------------------
N_TRIALS = 50
MAX_PRETRAIN_EPOCHS = 150
MAX_FINETUNE_EPOCHS = 200
EARLY_STOPPING_PATIENCE = 30
BATCH_SIZE = 2048
VIRTUAL_BATCH_SIZE = 256

# --- Focal-loss settings ---------------------------------------------------
# Higher alpha for the rare positive class.
FOCAL_ALPHA = 0.75
FOCAL_GAMMA = 2.0

In [8]:
# Seed every random number generator the notebook relies on
# (PyTorch, NumPy's legacy global RNG, and Python's `random`).
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# When a GPU is present, also seed all CUDA devices and ask cuDNN for
# deterministic kernels with auto-tuning (benchmark mode) switched off.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(SEED)

Some useful helper functions

In [None]:
def bootstrap_auc_ci(y_true, y_scores, n_bootstraps=2000, ci=0.95, seed=42):
    """ 
    Percentile-bootstrap confidence interval for the AUROC score.

    The (label, score) pairs are resampled with replacement `n_bootstraps`
    times; resamples that contain a single class are skipped because the
    AUROC is undefined for them.

    Parameters
    ----------
    y_true : array-like
        Binary ground-truth labels.
    y_scores : array-like
        Predicted scores, same length as `y_true`.
    n_bootstraps : int
        Number of bootstrap resamples to draw.
    ci : float
        Confidence level, strictly between 0 and 1 (0.95 -> 2.5th/97.5th percentiles).
    seed : int
        Seed of the resampling generator (default keeps the previous fixed value).

    Returns
    -------
    (mean_auc, lower, upper)
        Mean of the bootstrap AUROCs and the lower/upper percentile bounds.

    Raises
    ------
    ValueError
        If `ci` is outside (0, 1), the inputs differ in length, `y_true` does
        not contain two classes, or no resample contained both classes.
    """
    if not 0.0 < ci < 1.0:
        raise ValueError(f"ci must be strictly between 0 and 1, got {ci}")

    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)

    if y_true.shape[0] != y_scores.shape[0]:
        raise ValueError(
            f"y_true and y_scores must have the same length, "
            f"got {y_true.shape[0]} and {y_scores.shape[0]}"
        )
    # Without both classes every resample would be skipped below.
    if np.unique(y_true).size < 2:
        raise ValueError("y_true must contain both classes to compute an AUROC")

    rng = np.random.default_rng(seed)
    n_samples = len(y_true)
    aucs = []

    for _ in range(n_bootstraps):
        idx = rng.integers(0, n_samples, n_samples)
        resampled_true = y_true[idx]
        if len(np.unique(resampled_true)) < 2:
            continue
        aucs.append(roc_auc_score(resampled_true, y_scores[idx]))

    # Fail with a clear message instead of passing an empty list to np.percentile.
    if not aucs:
        raise ValueError("no bootstrap resample contained both classes; increase n_bootstraps")

    lower = np.percentile(aucs, (1 - ci) / 2 * 100)
    upper = np.percentile(aucs, (1 + ci) / 2 * 100)
    return np.mean(aucs), lower, upper

def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray,
                          title: str = "Confusion Matrix"):
    """Show the confusion matrix of `y_pred` against `y_true` as an annotated heatmap.

    Both axes are labelled 'No Readmission' / 'Readmission'; `title` is used
    as the figure title. The figure is displayed and nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        confusion_matrix(y_true, y_pred),
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=['No Readmission', 'Readmission'],
        yticklabels=['No Readmission', 'Readmission'],
        ax=ax,
    )
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    plt.show()

def plot_roc_curve(y_true: np.ndarray, y_scores: np.ndarray):
    """Display the ROC curve of `y_scores` against `y_true`.

    The area under the curve is shown in the legend and a dashed diagonal
    marks the random-classifier reference. Nothing is returned.
    """
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    plt.show()

def plot_precision_recall_curve(y_true: np.ndarray, y_scores: np.ndarray):
    """Display the precision-recall curve (more informative than ROC on imbalanced data).

    The average precision is shown in the legend; a dashed horizontal line at
    `y_true.mean()` marks the baseline. Nothing is returned.
    """
    precision, recall, _ = precision_recall_curve(y_true, y_scores)
    ap_score = average_precision_score(y_true, y_scores)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall, precision, color='blue', lw=2, label=f'PR curve (AP = {ap_score:.4f})')
    ax.axhline(y=y_true.mean(), color='red', linestyle='--', label=f'Baseline ({y_true.mean():.3f})')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend(loc="best")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    plt.show()

def plot_calibration_curve(y_true: np.ndarray, y_scores: np.ndarray, n_bins: int = 10):
    """Display a calibration (reliability) curve using `n_bins` uniform-width bins.

    The observed fraction of positives is plotted against the mean predicted
    probability per bin, next to the dashed diagonal of perfect calibration.
    Nothing is returned.
    """
    fraction_of_positives, mean_predicted_value = calibration_curve(
        y_true,
        y_scores,
        n_bins=n_bins,
        strategy='uniform',
    )

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(mean_predicted_value, fraction_of_positives, 's-', label='Model', color='blue')
    ax.plot([0, 1], [0, 1], '--', label='Perfect calibration', color='gray')
    ax.set_xlabel('Mean Predicted Probability')
    ax.set_ylabel('Fraction of Positives')
    ax.set_title('Calibration Curve')
    ax.legend(loc='best')
    ax.grid(alpha=0.3)
    fig.tight_layout()
    plt.show()

## Loading the dataset, pre-processing, and analysing the data

In [10]:
# Load the cohort table from a CSV located one directory above the notebook.
cohort_data = pd.read_csv('../cohort_data_new.csv')
# Bare expression: lets Jupyter render a preview of the frame as the cell output.
cohort_data

Unnamed: 0,icustay_id,anion_gap_mean,anion_gap_sd,anion_gap_min,anion_gap_max,bicarbonate_mean,bicarbonate_sd,bicarbonate_min,bicarbonate_max,calcium_total_mean,...,urea_nitrogen_min,urea_nitrogen_max,white_blood_cells_mean,white_blood_cells_sd,white_blood_cells_min,white_blood_cells_max,age,gender,icu_los_hours,target
0,200003,13.375000,3.583195,9.0,21.0,25.250000,3.105295,18.0,28.0,7.771429,...,10.0,21.0,26.471429,13.176711,13.2,43.9,48,M,141,0
1,200007,15.500000,2.121320,14.0,17.0,23.000000,1.414214,22.0,24.0,8.900000,...,8.0,10.0,10.300000,1.272792,9.4,11.2,44,M,30,0
2,200009,9.500000,2.121320,8.0,11.0,23.333333,2.081666,21.0,25.0,8.000000,...,15.0,21.0,12.471429,1.471637,10.5,14.3,47,F,51,0
3,200012,,,,,,,,,,...,,,4.900000,,4.9,4.9,33,F,10,0
4,200014,10.000000,1.732051,9.0,12.0,24.000000,1.000000,23.0,25.0,7.733333,...,21.0,24.0,13.233333,2.203028,10.7,14.7,85,M,41,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30484,299992,15.375000,2.856153,11.0,25.0,23.125000,2.609556,15.0,26.0,8.307143,...,8.0,23.0,14.134783,3.781727,8.1,22.1,41,M,499,0
30485,299993,9.400000,1.341641,8.0,11.0,29.600000,2.073644,26.0,31.0,8.000000,...,12.0,15.0,12.600000,0.605530,12.0,13.3,26,M,67,0
30486,299994,16.157895,2.477973,13.0,24.0,21.631579,3.451417,17.0,31.0,8.100000,...,28.0,63.0,10.076190,2.642329,5.3,14.5,74,F,152,1
30487,299998,11.500000,1.732051,10.0,14.0,23.500000,1.290994,22.0,25.0,8.800000,...,20.0,22.0,9.900000,1.210372,7.9,11.0,87,M,46,1


In [11]:
# Sanity check: table dimensions and the share of rows with target == 1.
print(
    f"Dataset shape: {cohort_data.shape}",
    f"Readmission rate: {cohort_data['target'].mean() * 100:.2f}%",
    sep="\n",
)

Dataset shape: (30489, 93)
Readmission rate: 10.74%


In [12]:
# Column names of the per-stay lab summaries: for every analyte the four
# statistics mean/min/max/sd (in that order), followed by age and ICU length of stay.
lab_cols = [
    f"{analyte}_{stat}"
    for analyte in (
        "anion_gap", "bicarbonate", "calcium_total", "chloride", "creatinine",
        "glucose", "hematocrit", "hemoglobin", "mchc", "mcv",
        "magnesium", "pt", "phosphate", "platelet_count", "potassium",
        "rdw", "red_blood_cells", "sodium", "urea_nitrogen", "white_blood_cells",
    )
    for stat in ("mean", "min", "max", "sd")
] + ["age", "icu_los_hours"]

Remove the `icustay_id` and `gender` columns

In [13]:
# Columns excluded from modelling: every column whose name contains
# 'icustay_id' or 'gender' (case-insensitive match).
drop_cols = [c for c in cohort_data.columns if 'icustay_id' in c.lower() or 'gender' in c.lower()]
# Drop exactly the columns selected above. They were taken from
# cohort_data.columns, so no errors='ignore' is needed.
df = cohort_data.drop(columns=drop_cols)

# Feature matrix and label vector.
X = df.drop(columns=['target'])
y = df['target']

Trying out some feature engineering

In [14]:
# Keep only numeric columns and turn +/-inf into NaN, so that infinities are
# treated as missing values by the later imputation step.
X = (
    X.select_dtypes(include=['number'])
     .replace([np.inf, -np.inf], np.nan)
)
print(f"initial feature matrix shape: {X.shape}")

def create_engineered_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` with derived lab features appended.

    Each feature is added only when all of its source columns are present:

    * ``bun_creatinine_ratio`` = urea_nitrogen_mean / creatinine_mean
      (kidney function indicator)
    * ``<lab>_cv`` = <lab>_sd / <lab>_mean for glucose, potassium, sodium and
      hemoglobin (variability / physiological instability)
    * ``<lab>_range`` = <lab>_max - <lab>_min for glucose, creatinine and potassium

    Parameters
    ----------
    df : pd.DataFrame
        Input feature table; it is not modified.

    Returns
    -------
    pd.DataFrame
        Copy of `df` with the new columns appended in the order listed above.
    """
    # Added to every denominator so a zero value does not produce inf.
    eps = 1e-6
    df_eng = df.copy()

    # BUN/Creatinine ratio (kidney function indicator)
    if 'urea_nitrogen_mean' in df_eng.columns and 'creatinine_mean' in df_eng.columns:
        df_eng['bun_creatinine_ratio'] = (
            df_eng['urea_nitrogen_mean'] / (df_eng['creatinine_mean'] + eps)
        )

    # Variability indices (coefficient of variation = sd / mean)
    for base_name in ['glucose', 'potassium', 'sodium', 'hemoglobin']:
        mean_col = f'{base_name}_mean'
        sd_col = f'{base_name}_sd'
        if mean_col in df_eng.columns and sd_col in df_eng.columns:
            df_eng[f'{base_name}_cv'] = df_eng[sd_col] / (df_eng[mean_col] + eps)

    # Range features (max - min)
    for base_name in ['glucose', 'creatinine', 'potassium']:
        max_col = f'{base_name}_max'
        min_col = f'{base_name}_min'
        if max_col in df_eng.columns and min_col in df_eng.columns:
            df_eng[f'{base_name}_range'] = df_eng[max_col] - df_eng[min_col]

    return df_eng

# Append the derived features to the numeric feature matrix and report the new shape.
X_engineered = create_engineered_features(X)
print(f"final feature matrix shape: {X_engineered.shape}")

initial feature matrix shape: (30489, 90)
final feature matrix shape: (30489, 98)


Dimensionality reduction by deleting the columns with high correlation.

In [15]:
# Absolute pairwise correlations between all features.
# NOTE(review): this is computed on the full feature table, i.e. before the
# train/validation/test split performed in the next cell.
corr = X_engineered.corr().abs()

# Keep only the strict upper triangle so each pair is inspected once; a column
# is dropped when it correlates at or above the threshold with any earlier column.
upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
to_drop = upper.columns[(upper >= CORR_THRESHOLD).any(axis=0).to_numpy()].tolist()

print(f"dropping {len(to_drop)} features")
X_reduced = X_engineered.drop(columns=to_drop, errors='ignore')
print(f"final feature count: {X_reduced.shape[1]}")
feature_names = X_reduced.columns.tolist()

dropping 42 features
final feature count: 56


Creating the final train-val-test sets

In [16]:
# Step 1: hold out the test set (stratified on the label).
X_temp, X_test, y_temp, y_test = train_test_split(
    X_reduced,
    y.values,
    test_size=TEST_SIZE,
    random_state=SEED,
    stratify=y.values,
)

# Step 2: carve the validation set out of the remainder (stratified again).
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=VAL_SIZE,
    random_state=SEED,
    stratify=y_temp,
)

print(f"Train set: {X_train.shape[0]} samples ({y_train.mean()*100:.2f}% readmission)")
print(f"Validation set: {X_val.shape[0]} samples ({y_val.mean()*100:.2f}% readmission)")
print(f"Test set: {X_test.shape[0]} samples ({y_test.mean()*100:.2f}% readmission)")

# Median imputation. The medians are fitted on the training split only and
# then applied to validation and test, which avoids data leakage.
imputer = SimpleImputer(strategy="median")
X_train_imputed = imputer.fit_transform(X_train)
X_val_imputed, X_test_imputed = imputer.transform(X_val), imputer.transform(X_test)

# Standardisation, likewise fitted on the training split only.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_imputed)
X_val_scaled, X_test_scaled = scaler.transform(X_val_imputed), scaler.transform(X_test_imputed)

Train set: 17073 samples (10.74% readmission)
Validation set: 4269 samples (10.75% readmission)
Test set: 9147 samples (10.75% readmission)


### Pretraining TabNet

In [18]:
def run_pretraining(X_train: np.ndarray, X_val: np.ndarray, pretrain_params: dict):
    """Fit a `TabNetPretrainer` on `X_train` (unsupervised) and return it.

    `pretrain_params` is forwarded unchanged to the `TabNetPretrainer`
    constructor. `X_val` is the only evaluation set. The epoch limit,
    early-stopping patience and batch sizes come from the notebook-level
    constants.
    """
    model = TabNetPretrainer(**pretrain_params)

    fit_options = dict(
        max_epochs=MAX_PRETRAIN_EPOCHS,
        patience=EARLY_STOPPING_PATIENCE,
        batch_size=BATCH_SIZE,
        virtual_batch_size=VIRTUAL_BATCH_SIZE,
        num_workers=0,
        drop_last=False,
    )
    model.fit(X_train=X_train, eval_set=[X_val], **fit_options)

    print("Pretraining complete!")
    return model

# Constructor arguments for the unsupervised TabNet pretrainer.
pretrain_params = {
    "n_d": 32,
    "n_a": 32,
    "n_steps": 5,
    "gamma": 1.5,
    "optimizer_fn": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "mask_type": "entmax",
    "device_name": DEVICE,
}

# Pretrain on the scaled training split; the scaled validation split is the evaluation set.
pretrainer = run_pretraining(X_train_scaled, X_val_scaled, pretrain_params)



epoch 0  | loss: 2873458.54586| val_0_unsup_loss_numpy: 249431.34375|  0:00:02s
epoch 1  | loss: 2178197.42618| val_0_unsup_loss_numpy: 453751.53125|  0:00:05s
epoch 2  | loss: 1646358.82802| val_0_unsup_loss_numpy: 418497.21875|  0:00:08s
epoch 3  | loss: 1313544.28527| val_0_unsup_loss_numpy: 350956.40625|  0:00:11s
epoch 4  | loss: 1134472.02326| val_0_unsup_loss_numpy: 384032.40625|  0:00:14s
epoch 5  | loss: 910770.47804| val_0_unsup_loss_numpy: 499828.5|  0:00:16s
epoch 6  | loss: 797942.89425| val_0_unsup_loss_numpy: 175733.984375|  0:00:19s
epoch 7  | loss: 701835.1237| val_0_unsup_loss_numpy: 278253.875|  0:00:22s
epoch 8  | loss: 547988.5787| val_0_unsup_loss_numpy: 356743.375|  0:00:24s
epoch 9  | loss: 528224.34876| val_0_unsup_loss_numpy: 171060.046875|  0:00:27s
epoch 10 | loss: 476515.14827| val_0_unsup_loss_numpy: 186770.4375|  0:00:30s
epoch 11 | loss: 458406.31128| val_0_unsup_loss_numpy: 143591.578125|  0:00:33s
epoch 12 | loss: 389109.70262| val_0_unsup_loss_numpy: 



Hyperparameter search

In [32]:
def make_objective(X_train, y_train, X_val, y_val, class_weights_tensor, pretrainer):
    """Build the Optuna objective that fine-tunes a TabNet classifier.

    The returned callable samples a hyperparameter set from the trial, fits a
    `TabNetClassifier` initialised from `pretrainer`, and returns the
    validation AUROC. Any exception raised during fitting is printed and the
    trial is reported as pruned.

    `class_weights_tensor` is accepted for interface compatibility but is not
    used by the objective.

    NOTE(review): pytorch-tabnet appears to overwrite n_d / n_a / n_steps /
    mask_type with the pretrainer's values when `from_unsupervised` is passed.
    If so, those four search dimensions have no effect here. Confirm against
    the installed library version.
    """

    def objective(trial):
        # Sample in a fixed order: n_d, n_a, n_steps, gamma, lambda_sparse, lr, mask_type.
        model_kwargs = {
            "n_d": trial.suggest_int("n_d", 16, 64),
            "n_a": trial.suggest_int("n_a", 16, 64),
            "n_steps": trial.suggest_int("n_steps", 3, 7),
            "gamma": trial.suggest_float("gamma", 1.0, 2.5),
            "lambda_sparse": trial.suggest_float("lambda_sparse", 1e-6, 1e-3, log=True),
        }
        learning_rate = trial.suggest_float("lr", 1e-4, 5e-3, log=True)
        model_kwargs["mask_type"] = trial.suggest_categorical("mask_type", ["sparsemax", "entmax"])

        clf = TabNetClassifier(
            optimizer_fn=torch.optim.Adam,
            optimizer_params=dict(lr=learning_rate),
            device_name=DEVICE,
            verbose=0,
            **model_kwargs,
        )

        fit_options = dict(
            eval_set=[(X_val, y_val)],
            eval_name=["val"],
            eval_metric=["auc"],
            max_epochs=MAX_FINETUNE_EPOCHS,
            patience=20,
            batch_size=BATCH_SIZE,
            virtual_batch_size=VIRTUAL_BATCH_SIZE,
            num_workers=0,
            drop_last=False,
            from_unsupervised=pretrainer,
        )
        try:
            clf.fit(X_train, y_train, **fit_options)
        except Exception as e:
            # Best effort: a failing configuration is logged and skipped, not fatal.
            print(f"Trial failed: {e}")
            raise optuna.exceptions.TrialPruned()

        # Score = AUROC of the positive-class probability on the validation split.
        positive_proba = clf.predict_proba(X_val)[:, 1]
        return roc_auc_score(y_val, positive_proba)

    return objective

In [33]:
print("hyperparameter optimization under way")

# Seed the sampler so the sequence of suggested hyperparameters is repeatable,
# in line with the seeding of random / NumPy / PyTorch earlier in the notebook.
study = optuna.create_study(
    direction="maximize",
    study_name="tabnet_readmission_10pct",
    sampler=optuna.samplers.TPESampler(seed=SEED),
)

# `class_weights_tensor` is not defined by any earlier cell, so it is built here
# to keep Restart & Run All working.
# NOTE(review): reconstructed as "balanced" class weights from the training
# labels; adjust if a different weighting was intended.
class_weights_tensor = torch.tensor(
    compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train),
    dtype=torch.float32,
    device=DEVICE,
)

objective = make_objective(X_train_scaled, y_train, X_val_scaled, y_val, class_weights_tensor, pretrainer)
study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True)
print(f"\nBest trial: {study.best_trial.number}")
print(f"Best validation AUROC: {study.best_value:.4f}")
print("\nBest hyperparameters:")
for key, value in study.best_trial.params.items():
    print(f"  {key}: {value}")

best_params = study.best_trial.params

[I 2025-11-29 14:35:05,300] A new study created in memory with name: tabnet_readmission_10pct


hyperparameter optimization under way


  0%|          | 0/50 [00:00<?, ?it/s]




Early stopping occurred at epoch 25 with best_epoch = 5 and best_val_auc = 0.55876




[I 2025-11-29 14:36:05,597] Trial 0 finished with value: 0.5587609146895853 and parameters: {'n_d': 16, 'n_a': 42, 'n_steps': 3, 'gamma': 2.0279761670348284, 'lambda_sparse': 0.00023585544043090657, 'lr': 0.0005995208416834294, 'mask_type': 'sparsemax'}. Best is trial 0 with value: 0.5587609146895853.





Early stopping occurred at epoch 21 with best_epoch = 1 and best_val_auc = 0.55598




[I 2025-11-29 14:37:06,992] Trial 1 finished with value: 0.5559792771001664 and parameters: {'n_d': 27, 'n_a': 21, 'n_steps': 4, 'gamma': 2.4362947429033426, 'lambda_sparse': 2.1576224252756286e-06, 'lr': 0.00014885526310768444, 'mask_type': 'entmax'}. Best is trial 0 with value: 0.5587609146895853.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.55157




[I 2025-11-29 14:38:03,513] Trial 2 finished with value: 0.5515722299418455 and parameters: {'n_d': 50, 'n_a': 17, 'n_steps': 7, 'gamma': 1.6818613284896964, 'lambda_sparse': 2.071243394344492e-06, 'lr': 0.0019745299263699027, 'mask_type': 'sparsemax'}. Best is trial 0 with value: 0.5587609146895853.





Early stopping occurred at epoch 30 with best_epoch = 10 and best_val_auc = 0.56718




[I 2025-11-29 14:39:39,219] Trial 3 finished with value: 0.567176161803304 and parameters: {'n_d': 19, 'n_a': 42, 'n_steps': 7, 'gamma': 1.5709721776426693, 'lambda_sparse': 2.140208127619011e-05, 'lr': 0.00024370476626022695, 'mask_type': 'sparsemax'}. Best is trial 3 with value: 0.567176161803304.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.55714




[I 2025-11-29 14:40:48,106] Trial 4 finished with value: 0.5571406515362051 and parameters: {'n_d': 24, 'n_a': 49, 'n_steps': 7, 'gamma': 1.9786304592963204, 'lambda_sparse': 8.266405914122419e-05, 'lr': 0.0015557032472153364, 'mask_type': 'entmax'}. Best is trial 3 with value: 0.567176161803304.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.57306




[I 2025-11-29 14:41:52,718] Trial 5 finished with value: 0.5730610879522413 and parameters: {'n_d': 45, 'n_a': 41, 'n_steps': 4, 'gamma': 1.2501140854293475, 'lambda_sparse': 0.0002938146458747543, 'lr': 0.0001831861495350541, 'mask_type': 'entmax'}. Best is trial 5 with value: 0.5730610879522413.





Early stopping occurred at epoch 22 with best_epoch = 2 and best_val_auc = 0.57813




[I 2025-11-29 14:43:04,921] Trial 6 finished with value: 0.5781260185614052 and parameters: {'n_d': 58, 'n_a': 24, 'n_steps': 5, 'gamma': 1.4096510069916823, 'lambda_sparse': 3.969473009794932e-06, 'lr': 0.00013612241447842713, 'mask_type': 'sparsemax'}. Best is trial 6 with value: 0.5781260185614052.





Early stopping occurred at epoch 26 with best_epoch = 6 and best_val_auc = 0.56918




[I 2025-11-29 14:44:25,827] Trial 7 finished with value: 0.5691849793285644 and parameters: {'n_d': 48, 'n_a': 16, 'n_steps': 7, 'gamma': 1.7060753686316334, 'lambda_sparse': 3.204512131800496e-05, 'lr': 0.00023362139296410526, 'mask_type': 'sparsemax'}. Best is trial 6 with value: 0.5781260185614052.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.57042




[I 2025-11-29 14:45:32,574] Trial 8 finished with value: 0.5704181176699318 and parameters: {'n_d': 64, 'n_a': 47, 'n_steps': 4, 'gamma': 1.488164595382199, 'lambda_sparse': 1.3009746949084117e-06, 'lr': 0.000600288387434428, 'mask_type': 'entmax'}. Best is trial 6 with value: 0.5781260185614052.





Early stopping occurred at epoch 26 with best_epoch = 6 and best_val_auc = 0.57014




[I 2025-11-29 14:46:53,165] Trial 9 finished with value: 0.5701376380240051 and parameters: {'n_d': 18, 'n_a': 59, 'n_steps': 6, 'gamma': 1.7849146144126498, 'lambda_sparse': 9.352867688957997e-05, 'lr': 0.00030018372508486946, 'mask_type': 'entmax'}. Best is trial 6 with value: 0.5781260185614052.





Early stopping occurred at epoch 51 with best_epoch = 31 and best_val_auc = 0.65634




[I 2025-11-29 14:49:26,132] Trial 10 finished with value: 0.6563406698345714 and parameters: {'n_d': 63, 'n_a': 29, 'n_steps': 5, 'gamma': 1.1177037425123468, 'lambda_sparse': 7.439923686686814e-06, 'lr': 0.004478360820807508, 'mask_type': 'sparsemax'}. Best is trial 10 with value: 0.6563406698345714.





Early stopping occurred at epoch 85 with best_epoch = 65 and best_val_auc = 0.69462




[I 2025-11-29 14:54:00,132] Trial 11 finished with value: 0.6946191366602051 and parameters: {'n_d': 63, 'n_a': 29, 'n_steps': 5, 'gamma': 1.0368448017651986, 'lambda_sparse': 7.962313494947115e-06, 'lr': 0.004503380969275998, 'mask_type': 'sparsemax'}. Best is trial 11 with value: 0.6946191366602051.





Early stopping occurred at epoch 84 with best_epoch = 64 and best_val_auc = 0.69116




[I 2025-11-29 15:00:30,564] Trial 12 finished with value: 0.6911630327254845 and parameters: {'n_d': 64, 'n_a': 30, 'n_steps': 5, 'gamma': 1.042767635337442, 'lambda_sparse': 9.308953617762418e-06, 'lr': 0.004801163413084997, 'mask_type': 'sparsemax'}. Best is trial 11 with value: 0.6946191366602051.





Early stopping occurred at epoch 86 with best_epoch = 66 and best_val_auc = 0.68232




[I 2025-11-29 15:06:17,739] Trial 13 finished with value: 0.6823177740037397 and parameters: {'n_d': 55, 'n_a': 31, 'n_steps': 5, 'gamma': 1.0238502352569834, 'lambda_sparse': 9.86114278551921e-06, 'lr': 0.0035119091345912606, 'mask_type': 'sparsemax'}. Best is trial 11 with value: 0.6946191366602051.





Early stopping occurred at epoch 135 with best_epoch = 115 and best_val_auc = 0.67634




[I 2025-11-29 15:14:38,310] Trial 14 finished with value: 0.6763436433190949 and parameters: {'n_d': 37, 'n_a': 33, 'n_steps': 6, 'gamma': 1.2581888201660896, 'lambda_sparse': 1.809310610189356e-05, 'lr': 0.0023094446407705507, 'mask_type': 'sparsemax'}. Best is trial 11 with value: 0.6946191366602051.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.56531




[I 2025-11-29 15:15:39,367] Trial 15 finished with value: 0.5653074411450203 and parameters: {'n_d': 56, 'n_a': 35, 'n_steps': 6, 'gamma': 1.000267584828258, 'lambda_sparse': 5.231007449961608e-06, 'lr': 0.0010637854893068462, 'mask_type': 'sparsemax'}. Best is trial 11 with value: 0.6946191366602051.





Early stopping occurred at epoch 89 with best_epoch = 69 and best_val_auc = 0.67161




[I 2025-11-29 15:21:09,442] Trial 16 finished with value: 0.6716100846871265 and parameters: {'n_d': 37, 'n_a': 26, 'n_steps': 3, 'gamma': 1.2545857965130243, 'lambda_sparse': 5.8927958086102926e-05, 'lr': 0.003165664763134903, 'mask_type': 'sparsemax'}. Best is trial 11 with value: 0.6946191366602051.





Early stopping occurred at epoch 92 with best_epoch = 72 and best_val_auc = 0.62943




[I 2025-11-29 15:28:08,538] Trial 17 finished with value: 0.6294334940158625 and parameters: {'n_d': 60, 'n_a': 36, 'n_steps': 4, 'gamma': 2.493670594753345, 'lambda_sparse': 1.3805595230608313e-05, 'lr': 0.0049052763757230665, 'mask_type': 'sparsemax'}. Best is trial 11 with value: 0.6946191366602051.





Early stopping occurred at epoch 21 with best_epoch = 1 and best_val_auc = 0.55817




[I 2025-11-29 15:29:27,397] Trial 18 finished with value: 0.558174795144071 and parameters: {'n_d': 52, 'n_a': 56, 'n_steps': 6, 'gamma': 1.324276462590059, 'lambda_sparse': 0.0007902938726958225, 'lr': 0.001169721945218466, 'mask_type': 'sparsemax'}. Best is trial 11 with value: 0.6946191366602051.





Early stopping occurred at epoch 117 with best_epoch = 97 and best_val_auc = 0.69897




[I 2025-11-29 15:36:28,741] Trial 19 finished with value: 0.6989738619273898 and parameters: {'n_d': 44, 'n_a': 22, 'n_steps': 5, 'gamma': 1.1481212253039474, 'lambda_sparse': 3.9715216216283716e-05, 'lr': 0.002616746364609544, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.55093




[I 2025-11-29 15:37:36,148] Trial 20 finished with value: 0.550933788505195 and parameters: {'n_d': 42, 'n_a': 21, 'n_steps': 4, 'gamma': 2.22080246155102, 'lambda_sparse': 3.987630801446338e-05, 'lr': 0.0025689181416155907, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 74 with best_epoch = 54 and best_val_auc = 0.69043




[I 2025-11-29 15:42:13,358] Trial 21 finished with value: 0.6904345290172061 and parameters: {'n_d': 33, 'n_a': 27, 'n_steps': 5, 'gamma': 1.1472581259463779, 'lambda_sparse': 4.541535755278825e-06, 'lr': 0.0038733305274092095, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 69 with best_epoch = 49 and best_val_auc = 0.68627




[I 2025-11-29 15:45:51,524] Trial 22 finished with value: 0.6862739379799748 and parameters: {'n_d': 61, 'n_a': 21, 'n_steps': 5, 'gamma': 1.1302094176316995, 'lambda_sparse': 1.1070200654745234e-05, 'lr': 0.004936519617910341, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.




Stop training because you reached max_epochs = 200 with best_epoch = 184 and best_val_auc = 0.68766




[I 2025-11-29 16:01:22,879] Trial 23 finished with value: 0.6876594674031759 and parameters: {'n_d': 53, 'n_a': 36, 'n_steps': 6, 'gamma': 1.3772927296772242, 'lambda_sparse': 2.497874440161843e-05, 'lr': 0.0028079706948511795, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 136 with best_epoch = 116 and best_val_auc = 0.69736




[I 2025-11-29 16:14:40,445] Trial 24 finished with value: 0.6973558860697969 and parameters: {'n_d': 64, 'n_a': 30, 'n_steps': 5, 'gamma': 1.093844423187339, 'lambda_sparse': 0.00012061187296350777, 'lr': 0.0015794205169544634, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.5634




[I 2025-11-29 16:16:32,008] Trial 25 finished with value: 0.5633992646344044 and parameters: {'n_d': 45, 'n_a': 24, 'n_steps': 5, 'gamma': 1.1726027008265874, 'lambda_sparse': 0.00015330830144466773, 'lr': 0.0014895582906838001, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 21 with best_epoch = 1 and best_val_auc = 0.56771




[I 2025-11-29 16:18:22,914] Trial 26 finished with value: 0.5677128185774164 and parameters: {'n_d': 58, 'n_a': 64, 'n_steps': 5, 'gamma': 1.4705229631328038, 'lambda_sparse': 0.0005535329808370042, 'lr': 0.0010042790056172131, 'mask_type': 'entmax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.56413




[I 2025-11-29 16:20:28,277] Trial 27 finished with value: 0.5641346302300447 and parameters: {'n_d': 32, 'n_a': 38, 'n_steps': 4, 'gamma': 1.2096665019862973, 'lambda_sparse': 4.311694788554745e-05, 'lr': 0.001918604302579347, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.5665




[I 2025-11-29 16:22:13,179] Trial 28 finished with value: 0.5665048404897101 and parameters: {'n_d': 46, 'n_a': 27, 'n_steps': 6, 'gamma': 1.5704415623434569, 'lambda_sparse': 8.080724414556316e-05, 'lr': 0.0008069327988811989, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 26 with best_epoch = 6 and best_val_auc = 0.57556




[I 2025-11-29 16:24:04,371] Trial 29 finished with value: 0.575559100864026 and parameters: {'n_d': 41, 'n_a': 45, 'n_steps': 3, 'gamma': 1.8619922408940606, 'lambda_sparse': 0.0002022971092479664, 'lr': 0.00043918667447453047, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.55698




[I 2025-11-29 16:25:34,836] Trial 30 finished with value: 0.5569762521514876 and parameters: {'n_d': 52, 'n_a': 19, 'n_steps': 5, 'gamma': 1.315291928010348, 'lambda_sparse': 0.0003548972450958969, 'lr': 0.0014794101100865516, 'mask_type': 'sparsemax'}. Best is trial 19 with value: 0.6989738619273898.





Early stopping occurred at epoch 90 with best_epoch = 70 and best_val_auc = 0.71572




[I 2025-11-29 16:32:37,448] Trial 31 finished with value: 0.7157154375310929 and parameters: {'n_d': 64, 'n_a': 32, 'n_steps': 5, 'gamma': 1.0469848880092827, 'lambda_sparse': 0.00014926118239511857, 'lr': 0.0034650554044617, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 108 with best_epoch = 88 and best_val_auc = 0.69672




[I 2025-11-29 16:40:07,139] Trial 32 finished with value: 0.6967188741930135 and parameters: {'n_d': 61, 'n_a': 31, 'n_steps': 5, 'gamma': 1.0785227481620048, 'lambda_sparse': 0.00013198745319180054, 'lr': 0.0032972605453969526, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 75 with best_epoch = 55 and best_val_auc = 0.6678




[I 2025-11-29 16:45:14,298] Trial 33 finished with value: 0.6678011653772036 and parameters: {'n_d': 60, 'n_a': 33, 'n_steps': 4, 'gamma': 1.1061757105490653, 'lambda_sparse': 0.00014015863696785092, 'lr': 0.002266029839881153, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 116 with best_epoch = 96 and best_val_auc = 0.68453




[I 2025-11-29 16:53:26,549] Trial 34 finished with value: 0.6845270158223686 and parameters: {'n_d': 56, 'n_a': 33, 'n_steps': 5, 'gamma': 1.0863715708512913, 'lambda_sparse': 0.00012049905040663968, 'lr': 0.0019010538837787342, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 88 with best_epoch = 68 and best_val_auc = 0.66694




[I 2025-11-29 16:59:15,597] Trial 35 finished with value: 0.6669371393935235 and parameters: {'n_d': 25, 'n_a': 24, 'n_steps': 6, 'gamma': 1.1807099207784513, 'lambda_sparse': 0.0004239838328553843, 'lr': 0.003158654410155728, 'mask_type': 'entmax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.56081




[I 2025-11-29 17:00:42,686] Trial 36 finished with value: 0.560810617627045 and parameters: {'n_d': 60, 'n_a': 39, 'n_steps': 4, 'gamma': 2.2757471876739346, 'lambda_sparse': 0.0002246164325857424, 'lr': 0.001762169917455482, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 146 with best_epoch = 126 and best_val_auc = 0.6915




[I 2025-11-29 17:10:24,060] Trial 37 finished with value: 0.6915044116217499 and parameters: {'n_d': 49, 'n_a': 43, 'n_steps': 6, 'gamma': 1.3392617447480988, 'lambda_sparse': 5.67675780992679e-05, 'lr': 0.0026791977630796454, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 122 with best_epoch = 102 and best_val_auc = 0.647




[I 2025-11-29 17:16:27,187] Trial 38 finished with value: 0.6470010693107806 and parameters: {'n_d': 58, 'n_a': 23, 'n_steps': 5, 'gamma': 1.613187633106881, 'lambda_sparse': 6.823510631412909e-05, 'lr': 0.003657746953536084, 'mask_type': 'entmax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.564




[I 2025-11-29 17:17:29,895] Trial 39 finished with value: 0.5640008234264835 and parameters: {'n_d': 31, 'n_a': 18, 'n_steps': 5, 'gamma': 1.4572422281575204, 'lambda_sparse': 0.00011326240398828456, 'lr': 0.0012735716576099986, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 105 with best_epoch = 85 and best_val_auc = 0.67746




[I 2025-11-29 17:23:03,890] Trial 40 finished with value: 0.6774589859274127 and parameters: {'n_d': 54, 'n_a': 31, 'n_steps': 4, 'gamma': 1.2476717439567995, 'lambda_sparse': 0.00017898388729402037, 'lr': 0.00245824902525999, 'mask_type': 'entmax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 45 with best_epoch = 25 and best_val_auc = 0.66181




[I 2025-11-29 17:25:48,179] Trial 41 finished with value: 0.6618101658861271 and parameters: {'n_d': 62, 'n_a': 28, 'n_steps': 5, 'gamma': 1.0753577719450917, 'lambda_sparse': 3.292667998842974e-05, 'lr': 0.0039520326138579595, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 110 with best_epoch = 90 and best_val_auc = 0.69838




[I 2025-11-29 17:32:22,134] Trial 42 finished with value: 0.6983808804945133 and parameters: {'n_d': 64, 'n_a': 26, 'n_steps': 5, 'gamma': 1.026460123728977, 'lambda_sparse': 0.0002908696486795256, 'lr': 0.003064334547572313, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 128 with best_epoch = 108 and best_val_auc = 0.68588




[I 2025-11-29 17:40:09,089] Trial 43 finished with value: 0.6858822385763872 and parameters: {'n_d': 57, 'n_a': 25, 'n_steps': 5, 'gamma': 1.1747113463240682, 'lambda_sparse': 0.0002794777969203652, 'lr': 0.0029548514115480293, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 77 with best_epoch = 57 and best_val_auc = 0.6657




[I 2025-11-29 17:43:42,895] Trial 44 finished with value: 0.6657014278443953 and parameters: {'n_d': 64, 'n_a': 22, 'n_steps': 5, 'gamma': 1.0133878965103365, 'lambda_sparse': 0.0008685544935126301, 'lr': 0.002130148029457512, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 27 with best_epoch = 7 and best_val_auc = 0.57778




[I 2025-11-29 17:44:55,620] Trial 45 finished with value: 0.577780636897512 and parameters: {'n_d': 61, 'n_a': 20, 'n_steps': 5, 'gamma': 1.0902178449135784, 'lambda_sparse': 0.0005093753699000485, 'lr': 0.00010666012506959906, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 102 with best_epoch = 82 and best_val_auc = 0.69306




[I 2025-11-29 17:50:04,123] Trial 46 finished with value: 0.6930632037008445 and parameters: {'n_d': 64, 'n_a': 16, 'n_steps': 6, 'gamma': 1.284292232689549, 'lambda_sparse': 0.00028264600708784236, 'lr': 0.003416495339806995, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.56229




[I 2025-11-29 17:51:21,923] Trial 47 finished with value: 0.5622922134733158 and parameters: {'n_d': 59, 'n_a': 31, 'n_steps': 4, 'gamma': 2.0302280675190953, 'lambda_sparse': 9.281271835711246e-05, 'lr': 0.0016223840145939829, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 120 with best_epoch = 100 and best_val_auc = 0.7042




[I 2025-11-29 17:56:55,950] Trial 48 finished with value: 0.7042040496571916 and parameters: {'n_d': 62, 'n_a': 29, 'n_steps': 5, 'gamma': 1.2123156808321565, 'lambda_sparse': 0.0001735648393404781, 'lr': 0.0038589495345165556, 'mask_type': 'sparsemax'}. Best is trial 31 with value: 0.7157154375310929.





Early stopping occurred at epoch 93 with best_epoch = 73 and best_val_auc = 0.69981




[I 2025-11-29 18:01:07,699] Trial 49 finished with value: 0.6998050080341265 and parameters: {'n_d': 21, 'n_a': 26, 'n_steps': 7, 'gamma': 1.2127188790782348, 'lambda_sparse': 0.0006079857488742656, 'lr': 0.004107585810310206, 'mask_type': 'entmax'}. Best is trial 31 with value: 0.7157154375310929.

Best trial: 31
Best validation AUROC: 0.7157

Best hyperparameters:
  n_d: 64
  n_a: 32
  n_steps: 5
  gamma: 1.0469848880092827
  lambda_sparse: 0.00014926118239511857
  lr: 0.0034650554044617
  mask_type: sparsemax


Retraining the final model with the best hyperparameters found by the Optuna search

In [34]:
# Rebuild the classifier from the winning trial's hyperparameters.
final_model = TabNetClassifier(
    n_d=best_params["n_d"],
    n_a=best_params["n_a"],
    n_steps=best_params["n_steps"],
    gamma=best_params["gamma"],
    lambda_sparse=best_params["lambda_sparse"],
    optimizer_fn=torch.optim.Adam,
    optimizer_params={"lr": best_params["lr"]},
    # Falls back to "entmax" when the search did not record a mask type.
    mask_type=best_params.get("mask_type", "entmax"),
    device_name=DEVICE,
    verbose=1,
)

# Fit on the training split, early-stopping on validation AUC.
final_model.fit(
    X_train_scaled,
    y_train,
    eval_set=[(X_val_scaled, y_val)],
    eval_name=["val"],
    eval_metric=["auc"],
    max_epochs=MAX_FINETUNE_EPOCHS,
    patience=EARLY_STOPPING_PATIENCE,
    batch_size=BATCH_SIZE,
    virtual_batch_size=VIRTUAL_BATCH_SIZE,
    num_workers=0,
    drop_last=False,
    # NOTE(review): warm-starts from `pretrainer`; pytorch-tabnet may override the
    # architecture arguments above with the pretrainer's own -- confirm they match.
    from_unsupervised=pretrainer,
    # NOTE(review): per the pytorch-tabnet docs, weights=1 requests sampling by
    # inverse class frequency -- confirm for the installed version.
    weights=1,
)

# Persist the fitted model under the given path prefix.
final_model.save_model("tabnet_readmission_final")
print("Model saved.")



epoch 0  | loss: 2.69476 | val_auc: 0.56769 |  0:00:02s
epoch 1  | loss: 1.25609 | val_auc: 0.52337 |  0:00:04s
epoch 2  | loss: 1.09138 | val_auc: 0.53692 |  0:00:06s
epoch 3  | loss: 0.92045 | val_auc: 0.54232 |  0:00:08s
epoch 4  | loss: 0.82316 | val_auc: 0.5397  |  0:00:10s
epoch 5  | loss: 0.80664 | val_auc: 0.54415 |  0:00:12s
epoch 6  | loss: 0.78022 | val_auc: 0.56022 |  0:00:14s
epoch 7  | loss: 0.74605 | val_auc: 0.56819 |  0:00:16s
epoch 8  | loss: 0.74568 | val_auc: 0.57904 |  0:00:19s
epoch 9  | loss: 0.72927 | val_auc: 0.60036 |  0:00:21s
epoch 10 | loss: 0.71823 | val_auc: 0.60178 |  0:00:23s
epoch 11 | loss: 0.6999  | val_auc: 0.61139 |  0:00:26s
epoch 12 | loss: 0.69202 | val_auc: 0.61279 |  0:00:28s
epoch 13 | loss: 0.6824  | val_auc: 0.61792 |  0:00:31s
epoch 14 | loss: 0.68281 | val_auc: 0.63204 |  0:00:34s
epoch 15 | loss: 0.6707  | val_auc: 0.64659 |  0:00:37s
epoch 16 | loss: 0.66472 | val_auc: 0.64739 |  0:00:39s
epoch 17 | loss: 0.66352 | val_auc: 0.64934 |  0



Successfully saved model at tabnet_readmission_final.zip
Model saved.


Evaluation

In [36]:
# Predicted probability of the positive class (column 1) for each test row.
y_test_proba = final_model.predict_proba(X_test_scaled)[:, 1]

# Hard labels obtained by cutting the probabilities at 0.5.
y_test_pred_default = (y_test_proba >= 0.5).astype(int)

# Point estimate of the test AUROC plus a bootstrap confidence interval.
# NOTE(review): `auc_score` also appears in the notebook's sklearn.metrics import
# line; this assignment rebinds that name.
auc_score = roc_auc_score(y_test, y_test_proba)
auc_mean, auc_lower, auc_upper = bootstrap_auc_ci(
    y_test,
    y_test_proba,
    n_bootstraps=2000,
)

# Headline metrics followed by the label for the classification report
# (default threshold).
print(
    f"\nAUROC: {auc_score:.4f}",
    f"95% CI: [{auc_lower:.4f}, {auc_upper:.4f}]",
    "Classification Report (threshold=0.5)",
    sep="\n",
)
print(
    classification_report(
        y_test,
        y_test_pred_default,
        target_names=['No Readmission', 'Readmission'],
    )
)


AUROC: 0.6910
95% CI: [0.6735, 0.7073]
Classification Report (threshold=0.5)
                precision    recall  f1-score   support

No Readmission       0.94      0.67      0.78      8164
   Readmission       0.19      0.62      0.29       983

      accuracy                           0.67      9147
     macro avg       0.56      0.65      0.53      9147
  weighted avg       0.86      0.67      0.73      9147



Plotting

In [40]:
# Per-feature importance scores exposed by the fitted TabNet model.
feature_importance = final_model.feature_importances_

# Pair every feature name with its score, then rank from most to least important.
importance_df = pd.DataFrame(
    {"feature": feature_names, "importance": feature_importance}
)
importance_df = importance_df.sort_values(by="importance", ascending=False)

# Show the 20 highest-ranked features as a plain-text table without the index.
print(
    "\nTop 20 most important features:",
    importance_df.head(20).to_string(index=False),
    sep="\n",
)


Top 20 most important features:
            feature  importance
                age    0.047833
 urea_nitrogen_mean    0.045689
          mchc_mean    0.037300
      phosphate_max    0.037054
      icu_los_hours    0.035873
 calcium_total_mean    0.035227
    hematocrit_mean    0.032977
platelet_count_mean    0.032176
      chloride_mean    0.030202
      potassium_min    0.030158
     potassium_mean    0.029994
       glucose_mean    0.026153
           mch_mean    0.025484
           rdw_mean    0.024030
   urea_nitrogen_sd    0.020903
     bicarbonate_sd    0.020013
            ptt_min    0.019935
       magnesium_sd    0.019637
  calcium_total_min    0.019096
             mch_sd    0.018270
