In [2]:
import sys
!{sys.executable} -m pip install pytorch-tabnet

import pandas as pd
import numpy as np
import torch

from ucimlrepo import fetch_ucirepo
from pytorch_tabnet.tab_model import TabNetClassifier

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score

from diabetes_utils import clean_diabetes_data, plot_and_save_metrics

# Load raw data from UCI and build df
diabetes_data = fetch_ucirepo(id=296)
X_raw = diabetes_data.data.features
y_raw = diabetes_data.data.targets

# Make sure target column has a consistent name
if "readmitted" not in y_raw.columns:
    y_raw.columns = ["readmitted"]

df = pd.concat([X_raw, y_raw], axis=1)
print("Raw shape:", df.shape)

# Clean dataset with our helper function
df_clean = clean_diabetes_data(df)
print("Cleaned shape:", df_clean.shape)

# Build feature matrix for TabNet
# Target is the binary 30-day readmission flag
target_col = "readmit_30d"

df_tab = df_clean.copy()
y = df_tab[target_col].values
df_tab = df_tab.drop(columns=["readmitted", target_col])  # drop raw label + target

# Treat object columns as categorical
cat_cols = df_tab.select_dtypes(include="object").columns.tolist()
num_cols = [c for c in df_tab.columns if c not in cat_cols]

print("Categorical cols for TabNet:", cat_cols)
print("Numeric cols for TabNet:", num_cols)

# Label-encode each categorical column for TabNet
encoders = {}
for col in cat_cols:
    le = LabelEncoder()
    df_tab[col] = le.fit_transform(df_tab[col].astype(str))
    encoders[col] = le

# Feature matrix
X = df_tab.values

# Indices and dimensions of categorical features
cat_idxs = [df_tab.columns.get_loc(c) for c in cat_cols]
cat_dims = [df_tab[c].nunique() for c in cat_cols]

print("cat_idxs:", cat_idxs)
print("cat_dims:", cat_dims)


# Stratified 5-fold CV for TabNet
def build_tabnet():
    return TabNetClassifier(
        n_d=32,
        n_a=32,
        n_steps=3,
        gamma=1.3,
        cat_idxs=cat_idxs,
        cat_dims=cat_dims,
        cat_emb_dim=1,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        scheduler_params={"step_size": 10, "gamma": 0.9},
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type="sparsemax",
        verbose=0  # quiet for CV
    )

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_metrics = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), start=1):
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]

    tabnet_cv = build_tabnet()

    tabnet_cv.fit(
        X_tr, y_tr,
        eval_set=[(X_tr, y_tr), (X_val, y_val)],
        eval_name=["train", "valid"],
        eval_metric=["auc"],
        max_epochs=50,
        patience=5,
        batch_size=1024,
        virtual_batch_size=128
    )

    # Evaluate on validation part of this fold
    y_val_prob = tabnet_cv.predict_proba(X_val)[:, 1]
    y_val_pred = tabnet_cv.predict(X_val)

    fold_result = {
        "fold": fold,
        "accuracy": accuracy_score(y_val, y_val_pred),
        "roc_auc": roc_auc_score(y_val, y_val_prob),
        "f1_pos":  f1_score(y_val, y_val_pred, zero_division=0),
    }
    cv_metrics.append(fold_result)

    print(f"\nFold {fold}:")
    print(f"  accuracy: {fold_result['accuracy']:.3f}")
    print(f"  roc_auc:  {fold_result['roc_auc']:.3f}")
    print(f"  f1_pos:   {fold_result['f1_pos']:.3f}")

cv_df = pd.DataFrame(cv_metrics)
print("\n5-fold CV summary (TabNet)")
print(cv_df[["accuracy", "roc_auc", "f1_pos"]].mean().round(3))


# Original single trainâ€“test split + final TabNet
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y
)

print("Train shape:", X_train.shape)
print("Test shape:", X_test.shape)

# Define TabNet model for final training
tabnet_clf = TabNetClassifier(
    n_d=32,
    n_a=32,
    n_steps=3,
    gamma=1.3,
    cat_idxs=cat_idxs,
    cat_dims=cat_dims,
    cat_emb_dim=1,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_params={"step_size": 10, "gamma": 0.9},
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    mask_type="sparsemax",
    verbose=1
)

# Train TabNet (AUC is your c-stat) on train set, validate on test set
tabnet_clf.fit(
    X_train, y_train,
    eval_set=[(X_train, y_train), (X_test, y_test)],
    eval_name=["train", "valid"],
    eval_metric=["auc"],
    max_epochs=50,
    patience=5,
    batch_size=1024,
    virtual_batch_size=128
)

# Evaluate and save plots on held-out test set
# predict_proba returns (N, 2) for binary classification
y_prob = tabnet_clf.predict_proba(X_test)[:, 1]
y_pred = tabnet_clf.predict(X_test)

tabnet_results = {
    "accuracy": round(accuracy_score(y_test, y_pred), 3),
    "roc_auc": round(roc_auc_score(y_test, y_prob), 3),
    "f1_pos":  round(f1_score(y_test, y_pred, zero_division=0), 3),
}

print("\nTabNet model results (no k fold):")
for k, v in tabnet_results.items():
    print(f"  {k}: {v}")

# Save plots
plot_and_save_metrics("tabnet", y_test, y_prob)

# Save probabilites
np.save("y_test_tabnet.npy", y_test)
np.save("probs_tabnet.npy", y_prob)



  df = pd.read_csv(data_url)


Raw shape: (101766, 48)
Cleaned shape: (101766, 49)
Categorical cols for TabNet: ['race', 'gender', 'age', 'diag_1', 'diag_2', 'diag_3', 'max_glu_serum', 'A1Cresult', 'metformin', 'repaglinide', 'nateglinide', 'chlorpropamide', 'glimepiride', 'acetohexamide', 'glipizide', 'glyburide', 'tolbutamide', 'pioglitazone', 'rosiglitazone', 'acarbose', 'miglitol', 'troglitazone', 'tolazamide', 'examide', 'citoglipton', 'insulin', 'glyburide-metformin', 'glipizide-metformin', 'glimepiride-pioglitazone', 'metformin-rosiglitazone', 'metformin-pioglitazone', 'change', 'diabetesMed', 'diag_1_group', 'diag_2_group', 'diag_3_group']
Numeric cols for TabNet: ['admission_type_id', 'discharge_disposition_id', 'admission_source_id', 'time_in_hospital', 'num_lab_procedures', 'num_procedures', 'num_medications', 'number_outpatient', 'number_emergency', 'number_inpatient', 'number_diagnoses']
cat_idxs: [0, 1, 2, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 3




Fold 1:
  accuracy: 0.888
  roc_auc:  0.637
  f1_pos:   0.003

Early stopping occurred at epoch 7 with best_epoch = 2 and best_valid_auc = 0.62245





Fold 2:
  accuracy: 0.889
  roc_auc:  0.622
  f1_pos:   0.015

Early stopping occurred at epoch 14 with best_epoch = 9 and best_valid_auc = 0.62454





Fold 3:
  accuracy: 0.888
  roc_auc:  0.625
  f1_pos:   0.000

Early stopping occurred at epoch 11 with best_epoch = 6 and best_valid_auc = 0.63209





Fold 4:
  accuracy: 0.888
  roc_auc:  0.632
  f1_pos:   0.001

Early stopping occurred at epoch 10 with best_epoch = 5 and best_valid_auc = 0.61854





Fold 5:
  accuracy: 0.888
  roc_auc:  0.619
  f1_pos:   0.016

5-fold CV summary (TabNet)
accuracy    0.888
roc_auc     0.627
f1_pos      0.007
dtype: float64
Train shape: (81412, 47)
Test shape: (20354, 47)




epoch 0  | loss: 0.38087 | train_auc: 0.55596 | valid_auc: 0.55749 |  0:00:09s
epoch 1  | loss: 0.34396 | train_auc: 0.59705 | valid_auc: 0.58683 |  0:00:18s
epoch 2  | loss: 0.34294 | train_auc: 0.61355 | valid_auc: 0.60617 |  0:00:27s
epoch 3  | loss: 0.34158 | train_auc: 0.61044 | valid_auc: 0.60502 |  0:00:36s
epoch 4  | loss: 0.34134 | train_auc: 0.61588 | valid_auc: 0.61566 |  0:00:45s
epoch 5  | loss: 0.34048 | train_auc: 0.61332 | valid_auc: 0.60976 |  0:00:53s
epoch 6  | loss: 0.3407  | train_auc: 0.6069  | valid_auc: 0.59862 |  0:01:03s
epoch 7  | loss: 0.34102 | train_auc: 0.61766 | valid_auc: 0.61145 |  0:01:11s
epoch 8  | loss: 0.33941 | train_auc: 0.62823 | valid_auc: 0.6178  |  0:01:20s
epoch 9  | loss: 0.33957 | train_auc: 0.59925 | valid_auc: 0.60117 |  0:01:29s
epoch 10 | loss: 0.33809 | train_auc: 0.63007 | valid_auc: 0.6267  |  0:01:39s
epoch 11 | loss: 0.33811 | train_auc: 0.61593 | valid_auc: 0.61683 |  0:01:48s
epoch 12 | loss: 0.33854 | train_auc: 0.57552 | vali




TabNet model results (no k fold):
  accuracy: 0.888
  roc_auc: 0.627
  f1_pos: 0.002
