In [3]:
import os
import random
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.utils.class_weight import compute_class_weight

import shap
import optuna

import torch
import torch.nn as nn
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier
from typing import Tuple, List

Some useful global constants and setting the seed

In [4]:
SEED = 7
CORR_THRESHOLD = 0.85 # correlation threshold for dimensionality reduction
TEST_SIZE = 0.30 # train-test split
VAL_SIZE = 0.20  # train-val split   
N_TRIALS = 50                     
MAX_PRETRAIN_EPOCHS = 150
MAX_FINETUNE_EPOCHS = 200
EARLY_STOPPING_PATIENCE = 30
BATCH_SIZE = 2048
VIRTUAL_BATCH_SIZE = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

FOCAL_ALPHA = 0.75                 # Higher alpha for rare positive class
FOCAL_GAMMA = 2.0   

In [5]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Soem useful helper functions

In [None]:
def bootstrap_auc_ci(y_true, y_scores, n_bootstraps=2000, ci=0.95):
    """ 
    Simple Bootstrapping method to get an confidence interval on the AUROC score.
    """
    rng = np.random.default_rng(42)
    aucs = []

    y_true = np.array(y_true)
    y_scores = np.array(y_scores)

    for _ in range(n_bootstraps):
        idx = rng.integers(0, len(y_true), len(y_true))
        if len(np.unique(y_true[idx])) < 2:
            continue
        aucs.append(roc_auc_score(y_true[idx], y_scores[idx]))

    lower = np.percentile(aucs, (1 - ci) / 2 * 100)
    upper = np.percentile(aucs, (1 + ci) / 2 * 100)
    return np.mean(aucs), lower, upper

## Loading the dataset, pre-processing, and analysing the data

In [7]:
cohort_data = pd.read_csv('../cohort_data_new.csv')
cohort_data

Unnamed: 0,icustay_id,anion_gap_mean,anion_gap_sd,anion_gap_min,anion_gap_max,bicarbonate_mean,bicarbonate_sd,bicarbonate_min,bicarbonate_max,calcium_total_mean,...,urea_nitrogen_min,urea_nitrogen_max,white_blood_cells_mean,white_blood_cells_sd,white_blood_cells_min,white_blood_cells_max,age,gender,icu_los_hours,target
0,200003,13.375000,3.583195,9.0,21.0,25.250000,3.105295,18.0,28.0,7.771429,...,10.0,21.0,26.471429,13.176711,13.2,43.9,48,M,141,0
1,200007,15.500000,2.121320,14.0,17.0,23.000000,1.414214,22.0,24.0,8.900000,...,8.0,10.0,10.300000,1.272792,9.4,11.2,44,M,30,0
2,200009,9.500000,2.121320,8.0,11.0,23.333333,2.081666,21.0,25.0,8.000000,...,15.0,21.0,12.471429,1.471637,10.5,14.3,47,F,51,0
3,200012,,,,,,,,,,...,,,4.900000,,4.9,4.9,33,F,10,0
4,200014,10.000000,1.732051,9.0,12.0,24.000000,1.000000,23.0,25.0,7.733333,...,21.0,24.0,13.233333,2.203028,10.7,14.7,85,M,41,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30484,299992,15.375000,2.856153,11.0,25.0,23.125000,2.609556,15.0,26.0,8.307143,...,8.0,23.0,14.134783,3.781727,8.1,22.1,41,M,499,0
30485,299993,9.400000,1.341641,8.0,11.0,29.600000,2.073644,26.0,31.0,8.000000,...,12.0,15.0,12.600000,0.605530,12.0,13.3,26,M,67,0
30486,299994,16.157895,2.477973,13.0,24.0,21.631579,3.451417,17.0,31.0,8.100000,...,28.0,63.0,10.076190,2.642329,5.3,14.5,74,F,152,1
30487,299998,11.500000,1.732051,10.0,14.0,23.500000,1.290994,22.0,25.0,8.800000,...,20.0,22.0,9.900000,1.210372,7.9,11.0,87,M,46,1


In [8]:
print(f"Dataset shape: {cohort_data.shape}")
print(f"Readmission rate: {cohort_data['target'].mean() * 100:.2f}%")

Dataset shape: (30489, 93)
Readmission rate: 10.74%


In [9]:
lab_cols = [
    'anion_gap_mean', 'anion_gap_min', 'anion_gap_max', 'anion_gap_sd',
    'bicarbonate_mean', 'bicarbonate_min', 'bicarbonate_max', 'bicarbonate_sd',
    'calcium_total_mean', 'calcium_total_min', 'calcium_total_max', 'calcium_total_sd',
    'chloride_mean', 'chloride_min', 'chloride_max', 'chloride_sd',
    'creatinine_mean', 'creatinine_min', 'creatinine_max', 'creatinine_sd',
    'glucose_mean', 'glucose_min', 'glucose_max', 'glucose_sd',
    'hematocrit_mean', 'hematocrit_min', 'hematocrit_max', 'hematocrit_sd',
    'hemoglobin_mean', 'hemoglobin_min', 'hemoglobin_max', 'hemoglobin_sd',
    'mchc_mean', 'mchc_min', 'mchc_max', 'mchc_sd',
    'mcv_mean', 'mcv_min', 'mcv_max', 'mcv_sd',
    'magnesium_mean', 'magnesium_min', 'magnesium_max', 'magnesium_sd',
    'pt_mean', 'pt_min', 'pt_max', 'pt_sd',
    'phosphate_mean', 'phosphate_min', 'phosphate_max', 'phosphate_sd',
    'platelet_count_mean', 'platelet_count_min', 'platelet_count_max', 'platelet_count_sd',
    'potassium_mean', 'potassium_min', 'potassium_max', 'potassium_sd',
    'rdw_mean', 'rdw_min', 'rdw_max', 'rdw_sd',
    'red_blood_cells_mean', 'red_blood_cells_min', 'red_blood_cells_max', 'red_blood_cells_sd',
    'sodium_mean', 'sodium_min', 'sodium_max', 'sodium_sd',
    'urea_nitrogen_mean', 'urea_nitrogen_min', 'urea_nitrogen_max', 'urea_nitrogen_sd',
    'white_blood_cells_mean', 'white_blood_cells_min', 'white_blood_cells_max', 'white_blood_cells_sd',
    'age', 'icu_los_hours'
]

REmove the ICUstay_id and the gender

In [10]:
drop_cols = [c for c in cohort_data.columns if 'icustay_id' in c.lower() or 'gender' in c.lower()]
df = cohort_data.drop(columns=['icustay_id', 'gender'], errors='ignore')

X = df.drop(columns=['target'])
y = df['target']

Trying out some feature engineering

In [None]:
# X = X.select_dtypes(include=['number']).replace([np.inf, -np.inf], np.nan)
# print(f"initial feature matrix shape: {X.shape}")

# def create_engineered_features(df: pd.DataFrame) -> pd.DataFrame:
#     df_eng = df.copy()
    
#     # BUN/Creatinine ratio (kidney function indicator)
#     if 'urea_nitrogen_mean' in df_eng.columns and 'creatinine_mean' in df_eng.columns:
#         df_eng['bun_creatinine_ratio'] = (
#             df_eng['urea_nitrogen_mean'] / (df_eng['creatinine_mean'] + 1e-6)
#         )

#     # Variability indices (physiological instability)
#     variability_features = []
#     for base_name in ['glucose', 'potassium', 'sodium', 'hemoglobin']:
#         mean_col = f'{base_name}_mean'
#         sd_col = f'{base_name}_sd'
#         if mean_col in df_eng.columns and sd_col in df_eng.columns:
#             cv_col = f'{base_name}_cv'
#             df_eng[cv_col] = df_eng[sd_col] / (df_eng[mean_col] + 1e-6)
#             variability_features.append(cv_col)
    
#     # Range features (max - min)
#     for base_name in ['glucose', 'creatinine', 'potassium']:
#         max_col = f'{base_name}_max'
#         min_col = f'{base_name}_min'
#         if max_col in df_eng.columns and min_col in df_eng.columns:
#             range_col = f'{base_name}_range'
#             df_eng[range_col] = df_eng[max_col] - df_eng[min_col]
    
#     return df_eng

# X_engineered = create_engineered_features(X)
# print(f"final feature matrix shape: {X_engineered.shape}")

initial feature matrix shape: (30489, 90)
final feature matrix shape: (30489, 98)


Dimensionality reduction by deleting the columns with high correlation.

In [12]:
# corr = X_engineered.corr().abs()
# upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
# to_drop = [col for col in upper.columns if any(upper[col] >= CORR_THRESHOLD)]

# print(f"dropping {len(to_drop)} features")
# X_reduced = X_engineered.drop(columns=to_drop, errors='ignore')
# print(f"final feature count: {X_reduced.shape[1]}")
# feature_names = X_reduced.columns.tolist()

Creating the final train-val-test sets

In [13]:
# separate test set
X_temp, X_test, y_temp, y_test = train_test_split(X, y.values, test_size=TEST_SIZE, random_state=SEED, stratify=y.values)

# separate validation set 
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=VAL_SIZE, random_state=SEED, stratify=y_temp)

print(f"Train set: {X_train.shape[0]} samples ({y_train.mean()*100:.2f}% readmission)")
print(f"Validation set: {X_val.shape[0]} samples ({y_val.mean()*100:.2f}% readmission)")
print(f"Test set: {X_test.shape[0]} samples ({y_test.mean()*100:.2f}% readmission)")

# SimpleImputation using median strat and scaling;
# Imputation - FIT on train only and avoidning data leakage:
imputer = SimpleImputer(strategy="median")
X_train_imputed = imputer.fit_transform(X_train)
X_val_imputed = imputer.transform(X_val)
X_test_imputed = imputer.transform(X_test)

# Scaling - FIT on train only
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_imputed)
X_val_scaled = scaler.transform(X_val_imputed)
X_test_scaled = scaler.transform(X_test_imputed)

Train set: 17073 samples (10.74% readmission)
Validation set: 4269 samples (10.75% readmission)
Test set: 9147 samples (10.75% readmission)


### Pretraining the TABNet

In [14]:
def run_pretraining(X_train: np.ndarray, X_val: np.ndarray, pretrain_params: dict):
    """Run unsupervised pretraining"""
    pretrainer = TabNetPretrainer(**pretrain_params)
    pretrainer.fit(
        X_train=X_train,
        eval_set=[X_val],
        max_epochs=MAX_PRETRAIN_EPOCHS,
        patience=EARLY_STOPPING_PATIENCE,
        batch_size=BATCH_SIZE,
        virtual_batch_size=VIRTUAL_BATCH_SIZE,
        num_workers=0,
        drop_last=False
    )
    
    print("Pretraining complete!")
    return pretrainer

pretrain_params = dict(
    n_d=32, 
    n_a=32,
    n_steps=5,
    gamma=1.5,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=1e-3),
    mask_type="entmax",
    device_name=DEVICE
)

pretrainer = run_pretraining(X_train_scaled, X_val_scaled, pretrain_params)



epoch 0  | loss: 12.34471| val_0_unsup_loss_numpy: 3.271250009536743|  0:00:03s
epoch 1  | loss: 11.24094| val_0_unsup_loss_numpy: 3.1070899963378906|  0:00:06s
epoch 2  | loss: 9.99972 | val_0_unsup_loss_numpy: 3.025629997253418|  0:00:09s
epoch 3  | loss: 9.19899 | val_0_unsup_loss_numpy: 3.0878798961639404|  0:00:12s
epoch 4  | loss: 8.46691 | val_0_unsup_loss_numpy: 3.0285000801086426|  0:00:14s
epoch 5  | loss: 7.69079 | val_0_unsup_loss_numpy: 3.1906800270080566|  0:00:17s
epoch 6  | loss: 7.14098 | val_0_unsup_loss_numpy: 3.362839937210083|  0:00:20s
epoch 7  | loss: 6.58374 | val_0_unsup_loss_numpy: 2.852479934692383|  0:00:23s
epoch 8  | loss: 6.17838 | val_0_unsup_loss_numpy: 2.8965299129486084|  0:00:25s
epoch 9  | loss: 5.66219 | val_0_unsup_loss_numpy: 2.9274098873138428|  0:00:28s
epoch 10 | loss: 5.28782 | val_0_unsup_loss_numpy: 2.619230031967163|  0:00:32s
epoch 11 | loss: 4.74309 | val_0_unsup_loss_numpy: 2.4046199321746826|  0:00:36s
epoch 12 | loss: 4.38431 | val_0_



Hyperparam Search

In [20]:
def make_objective(X_train, y_train, X_val, y_val, pretrainer):
    """Create Optuna objective function"""
    
    def objective(trial):
        # Hyperparameters to tune
        n_d = trial.suggest_int("n_d", 16, 64)
        n_a = trial.suggest_int("n_a", 16, 64)
        n_steps = trial.suggest_int("n_steps", 3, 7)
        gamma = trial.suggest_float("gamma", 1.0, 2.5)
        lambda_sparse = trial.suggest_float("lambda_sparse", 1e-6, 1e-3, log=True)
        lr = trial.suggest_float("lr", 1e-4, 5e-3, log=True)
        mask_type = trial.suggest_categorical("mask_type", ["sparsemax", "entmax"])
        
        clf = TabNetClassifier(n_d=n_d, n_a=n_a, n_steps=n_steps, gamma=gamma, lambda_sparse=lambda_sparse, optimizer_fn=torch.optim.Adam,optimizer_params=dict(lr=lr), mask_type=mask_type, device_name=DEVICE, verbose=0)
        
        try:
            clf.fit(
                X_train, y_train,
                eval_set=[(X_val, y_val)],
                eval_name=["val"],
                eval_metric=["auc"],
                max_epochs=MAX_FINETUNE_EPOCHS,
                patience=20, 
                batch_size=BATCH_SIZE,
                virtual_batch_size=VIRTUAL_BATCH_SIZE,
                num_workers=0,
                drop_last=False,
                from_unsupervised=pretrainer,
            )
        except Exception as e:
            print(f"Trial failed: {e}")
            raise optuna.exceptions.TrialPruned()
        
        pred_proba = clf.predict_proba(X_val)[:, 1]
        auc = roc_auc_score(y_val, pred_proba)
        
        return auc
    
    return objective

In [21]:
print("hyperparameter optimization under way")
study = optuna.create_study(direction="maximize", study_name="tabnet_readmission_10pct")

objective = make_objective(X_train_scaled, y_train, X_val_scaled, y_val, pretrainer)
study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True)
print(f"\nBest trial: {study.best_trial.number}")
print(f"Best validation AUROC: {study.best_value:.4f}")
print("\nBest hyperparameters:")
for key, value in study.best_trial.params.items():
    print(f"  {key}: {value}")

best_params = study.best_trial.params

[I 2025-11-30 08:14:38,931] A new study created in memory with name: tabnet_readmission_10pct


hyperparameter optimization under way


  0%|          | 0/50 [00:00<?, ?it/s]




Early stopping occurred at epoch 98 with best_epoch = 78 and best_val_auc = 0.60442




[I 2025-11-30 08:19:46,023] Trial 0 finished with value: 0.6044196272851514 and parameters: {'n_d': 18, 'n_a': 16, 'n_steps': 3, 'gamma': 2.033045226025318, 'lambda_sparse': 0.00014661664502985748, 'lr': 0.0002983613163810673, 'mask_type': 'sparsemax'}. Best is trial 0 with value: 0.6044196272851514.





Early stopping occurred at epoch 195 with best_epoch = 175 and best_val_auc = 0.6489




[I 2025-11-30 08:31:29,848] Trial 1 finished with value: 0.6489029557579813 and parameters: {'n_d': 22, 'n_a': 45, 'n_steps': 7, 'gamma': 1.5685858968175923, 'lambda_sparse': 3.5651468876369945e-06, 'lr': 0.00043195168127290416, 'mask_type': 'sparsemax'}. Best is trial 1 with value: 0.6489029557579813.





Early stopping occurred at epoch 91 with best_epoch = 71 and best_val_auc = 0.62293




[I 2025-11-30 08:38:21,339] Trial 2 finished with value: 0.6229344289480154 and parameters: {'n_d': 44, 'n_a': 39, 'n_steps': 6, 'gamma': 1.311666929999685, 'lambda_sparse': 2.636230921113233e-06, 'lr': 0.000663048492764778, 'mask_type': 'sparsemax'}. Best is trial 1 with value: 0.6489029557579813.





Early stopping occurred at epoch 138 with best_epoch = 118 and best_val_auc = 0.66382




[I 2025-11-30 08:50:13,908] Trial 3 finished with value: 0.6638195552353342 and parameters: {'n_d': 21, 'n_a': 55, 'n_steps': 6, 'gamma': 1.0984980106465843, 'lambda_sparse': 2.5076929589126155e-06, 'lr': 0.0004933357422732249, 'mask_type': 'entmax'}. Best is trial 3 with value: 0.6638195552353342.





Early stopping occurred at epoch 132 with best_epoch = 112 and best_val_auc = 0.61762




[I 2025-11-30 09:00:18,877] Trial 4 finished with value: 0.6176173239782936 and parameters: {'n_d': 32, 'n_a': 28, 'n_steps': 5, 'gamma': 1.9110611282235221, 'lambda_sparse': 0.0005008706036628327, 'lr': 0.00027422576685063536, 'mask_type': 'sparsemax'}. Best is trial 3 with value: 0.6638195552353342.





Early stopping occurred at epoch 62 with best_epoch = 42 and best_val_auc = 0.60574




[I 2025-11-30 09:04:25,054] Trial 5 finished with value: 0.6057393969544657 and parameters: {'n_d': 32, 'n_a': 17, 'n_steps': 5, 'gamma': 2.383063291676091, 'lambda_sparse': 0.0008669481247850811, 'lr': 0.000935723285923363, 'mask_type': 'entmax'}. Best is trial 3 with value: 0.6638195552353342.





Early stopping occurred at epoch 134 with best_epoch = 114 and best_val_auc = 0.61073




[I 2025-11-30 09:13:28,989] Trial 6 finished with value: 0.6107342791301413 and parameters: {'n_d': 46, 'n_a': 37, 'n_steps': 7, 'gamma': 1.3257625796457084, 'lambda_sparse': 1.2198898031490479e-06, 'lr': 0.00021946096543762233, 'mask_type': 'sparsemax'}. Best is trial 3 with value: 0.6638195552353342.





Early stopping occurred at epoch 21 with best_epoch = 1 and best_val_auc = 0.52314




[I 2025-11-30 09:15:02,627] Trial 7 finished with value: 0.523144288336507 and parameters: {'n_d': 28, 'n_a': 63, 'n_steps': 3, 'gamma': 2.370038476932514, 'lambda_sparse': 1.0865607426812601e-06, 'lr': 0.00030585618574312706, 'mask_type': 'entmax'}. Best is trial 3 with value: 0.6638195552353342.





Early stopping occurred at epoch 22 with best_epoch = 2 and best_val_auc = 0.52868




[I 2025-11-30 09:16:39,379] Trial 8 finished with value: 0.5286806877898432 and parameters: {'n_d': 54, 'n_a': 44, 'n_steps': 6, 'gamma': 2.311996215490513, 'lambda_sparse': 1.1554645357642597e-05, 'lr': 0.00017267253312817468, 'mask_type': 'entmax'}. Best is trial 3 with value: 0.6638195552353342.





Early stopping occurred at epoch 159 with best_epoch = 139 and best_val_auc = 0.63583




[I 2025-11-30 09:25:11,278] Trial 9 finished with value: 0.6358296307732776 and parameters: {'n_d': 17, 'n_a': 63, 'n_steps': 5, 'gamma': 2.0076551734606176, 'lambda_sparse': 0.0001626610272939941, 'lr': 0.00046765622380281016, 'mask_type': 'sparsemax'}. Best is trial 3 with value: 0.6638195552353342.





Early stopping occurred at epoch 70 with best_epoch = 50 and best_val_auc = 0.67686




[I 2025-11-30 09:28:12,811] Trial 10 finished with value: 0.6768622876388819 and parameters: {'n_d': 63, 'n_a': 54, 'n_steps': 4, 'gamma': 1.1067122979255082, 'lambda_sparse': 1.825499161823432e-05, 'lr': 0.002640349555139673, 'mask_type': 'entmax'}. Best is trial 10 with value: 0.6768622876388819.





Early stopping occurred at epoch 62 with best_epoch = 42 and best_val_auc = 0.67627




[I 2025-11-30 09:30:47,195] Trial 11 finished with value: 0.6762727371496863 and parameters: {'n_d': 63, 'n_a': 54, 'n_steps': 4, 'gamma': 1.0180464216150498, 'lambda_sparse': 2.1932688704884334e-05, 'lr': 0.0029947412787386498, 'mask_type': 'entmax'}. Best is trial 10 with value: 0.6768622876388819.





Early stopping occurred at epoch 60 with best_epoch = 40 and best_val_auc = 0.68583




[I 2025-11-30 09:33:33,910] Trial 12 finished with value: 0.6858347771887991 and parameters: {'n_d': 64, 'n_a': 53, 'n_steps': 4, 'gamma': 1.0221896225632354, 'lambda_sparse': 2.7885294628576305e-05, 'lr': 0.0033257933569175143, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 92 with best_epoch = 72 and best_val_auc = 0.68044




[I 2025-11-30 09:39:01,400] Trial 13 finished with value: 0.6804396182503332 and parameters: {'n_d': 63, 'n_a': 53, 'n_steps': 4, 'gamma': 1.3409903630528681, 'lambda_sparse': 5.938892311565925e-05, 'lr': 0.004234061731246198, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 70 with best_epoch = 50 and best_val_auc = 0.68439




[I 2025-11-30 09:43:39,675] Trial 14 finished with value: 0.6843880626032856 and parameters: {'n_d': 56, 'n_a': 49, 'n_steps': 4, 'gamma': 1.5208466965455254, 'lambda_sparse': 6.842033003026489e-05, 'lr': 0.004703988029622159, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 99 with best_epoch = 79 and best_val_auc = 0.66439




[I 2025-11-30 09:49:43,493] Trial 15 finished with value: 0.6643948101258584 and parameters: {'n_d': 53, 'n_a': 46, 'n_steps': 4, 'gamma': 1.5981776191660957, 'lambda_sparse': 5.3957169312943066e-05, 'lr': 0.0015102890281563108, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 104 with best_epoch = 84 and best_val_auc = 0.66546




[I 2025-11-30 09:55:56,147] Trial 16 finished with value: 0.6654572590190931 and parameters: {'n_d': 55, 'n_a': 33, 'n_steps': 3, 'gamma': 1.685819329497655, 'lambda_sparse': 9.072173261477716e-06, 'lr': 0.0018070083746098832, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 61 with best_epoch = 41 and best_val_auc = 0.67996




[I 2025-11-30 09:59:35,953] Trial 17 finished with value: 0.6799621452547189 and parameters: {'n_d': 58, 'n_a': 49, 'n_steps': 4, 'gamma': 1.4659709207561753, 'lambda_sparse': 0.00012575794514134124, 'lr': 0.004885742676160388, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 86 with best_epoch = 66 and best_val_auc = 0.64664




[I 2025-11-30 10:04:56,044] Trial 18 finished with value: 0.6466396765763757 and parameters: {'n_d': 48, 'n_a': 61, 'n_steps': 3, 'gamma': 1.8107185778360457, 'lambda_sparse': 3.967992808019166e-05, 'lr': 0.0015195303559948493, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 20 with best_epoch = 0 and best_val_auc = 0.53359




[I 2025-11-30 10:06:24,273] Trial 19 finished with value: 0.5335926554932268 and parameters: {'n_d': 39, 'n_a': 28, 'n_steps': 4, 'gamma': 1.2087142256694299, 'lambda_sparse': 0.0002956774320181486, 'lr': 0.0001186915656307489, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 47 with best_epoch = 27 and best_val_auc = 0.63692




[I 2025-11-30 10:09:50,440] Trial 20 finished with value: 0.636923815895562 and parameters: {'n_d': 58, 'n_a': 59, 'n_steps': 5, 'gamma': 2.177902198798227, 'lambda_sparse': 8.868822776627862e-05, 'lr': 0.0028938675016151934, 'mask_type': 'entmax'}. Best is trial 12 with value: 0.6858347771887991.





Early stopping occurred at epoch 94 with best_epoch = 74 and best_val_auc = 0.68743




[I 2025-11-30 10:16:50,917] Trial 21 finished with value: 0.6874301660004918 and parameters: {'n_d': 64, 'n_a': 51, 'n_steps': 4, 'gamma': 1.349774058053704, 'lambda_sparse': 5.78971363203213e-05, 'lr': 0.004680550316052563, 'mask_type': 'entmax'}. Best is trial 21 with value: 0.6874301660004918.





Early stopping occurred at epoch 55 with best_epoch = 35 and best_val_auc = 0.66887




[I 2025-11-30 10:20:31,338] Trial 22 finished with value: 0.668865329742279 and parameters: {'n_d': 59, 'n_a': 50, 'n_steps': 4, 'gamma': 1.455573761961228, 'lambda_sparse': 3.16797229998993e-05, 'lr': 0.003819791229955228, 'mask_type': 'entmax'}. Best is trial 21 with value: 0.6874301660004918.





Early stopping occurred at epoch 73 with best_epoch = 53 and best_val_auc = 0.67614




[I 2025-11-30 10:25:01,002] Trial 23 finished with value: 0.6761412176419124 and parameters: {'n_d': 64, 'n_a': 42, 'n_steps': 3, 'gamma': 1.1877100346536993, 'lambda_sparse': 9.138596248650149e-06, 'lr': 0.002080837512443387, 'mask_type': 'entmax'}. Best is trial 21 with value: 0.6874301660004918.





Early stopping occurred at epoch 115 with best_epoch = 95 and best_val_auc = 0.64048




[I 2025-11-30 10:32:27,606] Trial 24 finished with value: 0.6404794171970334 and parameters: {'n_d': 50, 'n_a': 48, 'n_steps': 4, 'gamma': 1.4326690076143853, 'lambda_sparse': 6.972635010781019e-05, 'lr': 0.001038862222180127, 'mask_type': 'entmax'}. Best is trial 21 with value: 0.6874301660004918.





Early stopping occurred at epoch 92 with best_epoch = 72 and best_val_auc = 0.69138




[I 2025-11-30 10:38:15,635] Trial 25 finished with value: 0.6913814694731787 and parameters: {'n_d': 59, 'n_a': 57, 'n_steps': 5, 'gamma': 1.7111039914098805, 'lambda_sparse': 0.00024137154945623953, 'lr': 0.0035177974693979767, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 142 with best_epoch = 122 and best_val_auc = 0.67141




[I 2025-11-30 10:46:17,173] Trial 26 finished with value: 0.6714099463057314 and parameters: {'n_d': 60, 'n_a': 57, 'n_steps': 5, 'gamma': 1.7105285581605927, 'lambda_sparse': 0.0002907909568832351, 'lr': 0.003413363747905262, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 92 with best_epoch = 72 and best_val_auc = 0.6724




[I 2025-11-30 10:51:22,387] Trial 27 finished with value: 0.6724023467654779 and parameters: {'n_d': 50, 'n_a': 59, 'n_steps': 5, 'gamma': 1.2792981500891853, 'lambda_sparse': 0.0002818227634273266, 'lr': 0.002154034242579788, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 93 with best_epoch = 73 and best_val_auc = 0.67283




[I 2025-11-30 10:56:25,323] Trial 28 finished with value: 0.6728272119579825 and parameters: {'n_d': 42, 'n_a': 52, 'n_steps': 6, 'gamma': 1.0508831413752007, 'lambda_sparse': 1.9178900342334478e-05, 'lr': 0.0012440825934780472, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 98 with best_epoch = 78 and best_val_auc = 0.66353




[I 2025-11-30 11:01:18,428] Trial 29 finished with value: 0.6635313559661251 and parameters: {'n_d': 38, 'n_a': 59, 'n_steps': 5, 'gamma': 1.8288352009467073, 'lambda_sparse': 0.00015735381798312893, 'lr': 0.002902113645866384, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 98 with best_epoch = 78 and best_val_auc = 0.67253




[I 2025-11-30 11:06:10,943] Trial 30 finished with value: 0.6725324367133847 and parameters: {'n_d': 60, 'n_a': 64, 'n_steps': 3, 'gamma': 1.181078575384395, 'lambda_sparse': 0.00011079086445359092, 'lr': 0.002406048691409037, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 101 with best_epoch = 81 and best_val_auc = 0.68213




[I 2025-11-30 11:11:13,354] Trial 31 finished with value: 0.6821299298372017 and parameters: {'n_d': 56, 'n_a': 51, 'n_steps': 4, 'gamma': 1.6107769180210172, 'lambda_sparse': 5.0626960353130955e-05, 'lr': 0.004964336946369967, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 59 with best_epoch = 39 and best_val_auc = 0.66198




[I 2025-11-30 11:14:09,495] Trial 32 finished with value: 0.6619754230067646 and parameters: {'n_d': 52, 'n_a': 46, 'n_steps': 4, 'gamma': 1.5265358618301703, 'lambda_sparse': 0.00018815850128125835, 'lr': 0.0038019949761093096, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 137 with best_epoch = 117 and best_val_auc = 0.66743




[I 2025-11-30 11:20:59,810] Trial 33 finished with value: 0.6674323389314898 and parameters: {'n_d': 61, 'n_a': 56, 'n_steps': 5, 'gamma': 1.6663461854448247, 'lambda_sparse': 2.7657801951503034e-05, 'lr': 0.003775496614955459, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 62 with best_epoch = 42 and best_val_auc = 0.68239




[I 2025-11-30 11:24:27,494] Trial 34 finished with value: 0.6823872506132812 and parameters: {'n_d': 57, 'n_a': 40, 'n_steps': 3, 'gamma': 1.3731131881029426, 'lambda_sparse': 6.049416291597958e-06, 'lr': 0.004876919215303812, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 109 with best_epoch = 89 and best_val_auc = 0.64588




[I 2025-11-30 11:30:40,016] Trial 35 finished with value: 0.6458814380228615 and parameters: {'n_d': 62, 'n_a': 47, 'n_steps': 6, 'gamma': 1.543870894595551, 'lambda_sparse': 8.485627089426298e-05, 'lr': 0.0007129912388642234, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 112 with best_epoch = 92 and best_val_auc = 0.67373




[I 2025-11-30 11:37:30,010] Trial 36 finished with value: 0.6737341247376758 and parameters: {'n_d': 64, 'n_a': 43, 'n_steps': 4, 'gamma': 1.9359256547880368, 'lambda_sparse': 0.0005430482143903764, 'lr': 0.0018536683466806778, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 77 with best_epoch = 57 and best_val_auc = 0.6614




[I 2025-11-30 11:42:38,362] Trial 37 finished with value: 0.6613961653486126 and parameters: {'n_d': 56, 'n_a': 57, 'n_steps': 7, 'gamma': 1.7939161873912624, 'lambda_sparse': 4.025863786186014e-05, 'lr': 0.003189135134422202, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 75 with best_epoch = 55 and best_val_auc = 0.66838




[I 2025-11-30 11:46:56,169] Trial 38 finished with value: 0.6683827103311433 and parameters: {'n_d': 45, 'n_a': 36, 'n_steps': 5, 'gamma': 1.2730859440538689, 'lambda_sparse': 0.0005043973588640182, 'lr': 0.002462872221858567, 'mask_type': 'entmax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 54 with best_epoch = 34 and best_val_auc = 0.68842




[I 2025-11-30 11:50:53,661] Trial 39 finished with value: 0.6884159904848495 and parameters: {'n_d': 51, 'n_a': 51, 'n_steps': 6, 'gamma': 1.1130046690463282, 'lambda_sparse': 0.0009154426057575289, 'lr': 0.0039021481711319027, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 178 with best_epoch = 158 and best_val_auc = 0.67844




[I 2025-11-30 12:09:26,219] Trial 40 finished with value: 0.6784430949399298 and parameters: {'n_d': 51, 'n_a': 19, 'n_steps': 6, 'gamma': 1.1342847630975037, 'lambda_sparse': 0.0009596193753775677, 'lr': 0.0006089079510678374, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 52 with best_epoch = 32 and best_val_auc = 0.6856




[I 2025-11-30 12:12:03,794] Trial 41 finished with value: 0.6855968984269124 and parameters: {'n_d': 60, 'n_a': 51, 'n_steps': 6, 'gamma': 1.0115419332008881, 'lambda_sparse': 0.0002112181205928748, 'lr': 0.004289145238749252, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 84 with best_epoch = 64 and best_val_auc = 0.69




[I 2025-11-30 12:16:54,973] Trial 42 finished with value: 0.6899953682260306 and parameters: {'n_d': 60, 'n_a': 52, 'n_steps': 6, 'gamma': 1.0171040587544928, 'lambda_sparse': 0.0007603883784493343, 'lr': 0.0035891865514553313, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 62 with best_epoch = 42 and best_val_auc = 0.68434




[I 2025-11-30 12:20:19,421] Trial 43 finished with value: 0.6843434603354318 and parameters: {'n_d': 61, 'n_a': 55, 'n_steps': 7, 'gamma': 1.0880602719889694, 'lambda_sparse': 0.0007756215222284108, 'lr': 0.0035910822322213322, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 83 with best_epoch = 63 and best_val_auc = 0.61006




[I 2025-11-30 12:25:03,371] Trial 44 finished with value: 0.610064673288388 and parameters: {'n_d': 36, 'n_a': 53, 'n_steps': 6, 'gamma': 1.2548969437809554, 'lambda_sparse': 0.0006447989659234164, 'lr': 0.0003714376022450471, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 47 with best_epoch = 27 and best_val_auc = 0.68085




[I 2025-11-30 12:27:56,031] Trial 45 finished with value: 0.6808467569004854 and parameters: {'n_d': 26, 'n_a': 60, 'n_steps': 6, 'gamma': 1.082839715153365, 'lambda_sparse': 0.00039688336066339195, 'lr': 0.0025762401014970218, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 57 with best_epoch = 37 and best_val_auc = 0.65671




[I 2025-11-30 12:31:28,897] Trial 46 finished with value: 0.6567066371605511 and parameters: {'n_d': 54, 'n_a': 56, 'n_steps': 7, 'gamma': 2.488077498301922, 'lambda_sparse': 0.0003544298137617609, 'lr': 0.0033335817339692772, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 88 with best_epoch = 68 and best_val_auc = 0.6718




[I 2025-11-30 12:35:52,990] Trial 47 finished with value: 0.6717964992937975 and parameters: {'n_d': 48, 'n_a': 62, 'n_steps': 6, 'gamma': 1.138506256526948, 'lambda_sparse': 0.0007002407260733054, 'lr': 0.001605447614760261, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 101 with best_epoch = 81 and best_val_auc = 0.68924




[I 2025-11-30 12:41:06,688] Trial 48 finished with value: 0.6892365578485695 and parameters: {'n_d': 64, 'n_a': 45, 'n_steps': 5, 'gamma': 1.0025942828427767, 'lambda_sparse': 0.0004295411263328813, 'lr': 0.002272200268295512, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.





Early stopping occurred at epoch 61 with best_epoch = 41 and best_val_auc = 0.66886




[I 2025-11-30 12:44:13,476] Trial 49 finished with value: 0.6688601833267573 and parameters: {'n_d': 58, 'n_a': 45, 'n_steps': 5, 'gamma': 1.367113271007407, 'lambda_sparse': 0.0004104048356351314, 'lr': 0.004180368268552419, 'mask_type': 'sparsemax'}. Best is trial 25 with value: 0.6913814694731787.

Best trial: 25
Best validation AUROC: 0.6914

Best hyperparameters:
  n_d: 59
  n_a: 57
  n_steps: 5
  gamma: 1.7111039914098805
  lambda_sparse: 0.00024137154945623953
  lr: 0.0035177974693979767
  mask_type: entmax


Retraining using the best params

In [22]:
final_model = TabNetClassifier(n_d=best_params["n_d"], n_a=best_params["n_a"], n_steps=best_params["n_steps"], gamma=best_params["gamma"], lambda_sparse=best_params["lambda_sparse"], optimizer_fn=torch.optim.Adam, optimizer_params=dict(lr=best_params["lr"]), mask_type=best_params.get("mask_type", "entmax"), device_name=DEVICE, verbose=1)

final_model.fit(X_train_scaled, y_train, eval_set=[(X_val_scaled, y_val)], eval_name=["val"], eval_metric=["auc"], max_epochs=MAX_FINETUNE_EPOCHS, patience=EARLY_STOPPING_PATIENCE, batch_size=BATCH_SIZE, virtual_batch_size=VIRTUAL_BATCH_SIZE, num_workers=0, drop_last=False, from_unsupervised=pretrainer, weights=1)
final_model.save_model("tabnet_readmission_final")
print("Model saved.")



epoch 0  | loss: 1.13584 | val_auc: 0.52226 |  0:00:02s
epoch 1  | loss: 0.89311 | val_auc: 0.55845 |  0:00:04s
epoch 2  | loss: 0.82872 | val_auc: 0.5585  |  0:00:06s
epoch 3  | loss: 0.78105 | val_auc: 0.57096 |  0:00:08s
epoch 4  | loss: 0.76253 | val_auc: 0.59085 |  0:00:10s
epoch 5  | loss: 0.73844 | val_auc: 0.60154 |  0:00:13s
epoch 6  | loss: 0.72085 | val_auc: 0.60592 |  0:00:15s
epoch 7  | loss: 0.71155 | val_auc: 0.62129 |  0:00:17s
epoch 8  | loss: 0.70354 | val_auc: 0.61798 |  0:00:19s
epoch 9  | loss: 0.69408 | val_auc: 0.61406 |  0:00:21s
epoch 10 | loss: 0.69128 | val_auc: 0.61359 |  0:00:23s
epoch 11 | loss: 0.69554 | val_auc: 0.62069 |  0:00:25s
epoch 12 | loss: 0.69249 | val_auc: 0.62463 |  0:00:28s
epoch 13 | loss: 0.6868  | val_auc: 0.62261 |  0:00:30s
epoch 14 | loss: 0.67813 | val_auc: 0.63902 |  0:00:32s
epoch 15 | loss: 0.68438 | val_auc: 0.63961 |  0:00:34s
epoch 16 | loss: 0.67382 | val_auc: 0.63584 |  0:00:36s
epoch 17 | loss: 0.67438 | val_auc: 0.64296 |  0



Successfully saved model at tabnet_readmission_final.zip
Model saved.


Evaluation

In [23]:
y_test_proba = final_model.predict_proba(X_test_scaled)[:, 1]

y_test_pred_default = (y_test_proba >= 0.5).astype(int)

auc_score = roc_auc_score(y_test, y_test_proba)
auc_mean, auc_lower, auc_upper = bootstrap_auc_ci(y_test, y_test_proba, n_bootstraps=2000)

print(f"\nAUROC: {auc_score:.4f}")
print(f"95% CI: [{auc_lower:.4f}, {auc_upper:.4f}]")

# Classification report (default threshold)
print("Classification Report (threshold=0.5)")
print(classification_report(y_test, y_test_pred_default, target_names=['No Readmission', 'Readmission']))


AUROC: 0.6507
95% CI: [0.6315, 0.6684]
Classification Report (threshold=0.5)
                precision    recall  f1-score   support

No Readmission       0.93      0.62      0.75      8164
   Readmission       0.17      0.62      0.26       983

      accuracy                           0.62      9147
     macro avg       0.55      0.62      0.50      9147
  weighted avg       0.85      0.62      0.69      9147



Plotting

In [24]:
# Get feature importance from TabNet's attention mechanism
feature_importance = final_model.feature_importances_

# Create dataframe
importance_df = pd.DataFrame({
    'feature': feature_names,
    'importance': feature_importance
}).sort_values('importance', ascending=False)

print("\nTop 20 most important features:")
print(importance_df.head(20).to_string(index=False))

NameError: name 'feature_names' is not defined