In [None]:
import sys
!{sys.executable} -m pip install pytorch-tabnet

import pandas as pd
import numpy as np
import torch

from ucimlrepo import fetch_ucirepo
from pytorch_tabnet.tab_model import TabNetClassifier

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score

from diabetes_utils import clean_diabetes_data, plot_and_save_metrics

# 1) Fetch the raw dataset from UCI (repository id 296) and assemble one frame
diabetes_data = fetch_ucirepo(id=296)
X_raw, y_raw = diabetes_data.data.features, diabetes_data.data.targets

# Give the target column a fixed name so later steps can refer to it
if not ("readmitted" in y_raw.columns):
    y_raw.columns = ["readmitted"]

# Features and target side by side, aligned on the row index
df = pd.concat([X_raw, y_raw], axis=1)
print("Raw shape:", df.shape)

# 2) Run the project cleaning helper on the combined frame
df_clean = clean_diabetes_data(df)
print("Cleaned shape:", df_clean.shape)

# 3) Assemble the TabNet inputs
#    The label is the binary 30-day readmission flag
target_col = "readmit_30d"

df_tab = df_clean.copy()
y = df_tab[target_col].values
# Drop the raw label and the target so neither remains among the features
df_tab = df_tab.drop(columns=["readmitted", target_col])

# Object-dtype columns are treated as categorical; all others as numeric
cat_cols = list(df_tab.select_dtypes(include="object").columns)
num_cols = [name for name in df_tab.columns if name not in cat_cols]

print("Categorical cols for TabNet:", cat_cols)
print("Numeric cols for TabNet:", num_cols)

# One LabelEncoder per categorical column; values are cast to str before
# encoding, and the fitted encoders are kept in `encoders` keyed by column
encoders = {col: LabelEncoder() for col in cat_cols}
for col, le in encoders.items():
    df_tab[col] = le.fit_transform(df_tab[col].astype(str))

# Feature matrix
X = df_tab.values

# Column positions and number of distinct codes of each categorical feature
cat_idxs = df_tab.columns.get_indexer(cat_cols).tolist()
cat_dims = [len(encoders[col_name].classes_) for col_name in cat_cols]

print("cat_idxs:", cat_idxs)
print("cat_dims:", cat_dims)

# 4) Train-test split: 20% held out, stratified on the label, fixed seed
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print("Train shape:", X_train.shape)
print("Test shape:", X_test.shape)

# 5) Configure the TabNet classifier
tabnet_clf = TabNetClassifier(
    # categorical columns: positions, cardinalities, embedding width
    cat_idxs=cat_idxs,
    cat_dims=cat_dims,
    cat_emb_dim=1,
    # network hyperparameters
    n_d=32,
    n_a=32,
    n_steps=3,
    gamma=1.3,
    mask_type="sparsemax",
    # Adam optimizer with a StepLR learning-rate schedule
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 10, "gamma": 0.9},
    verbose=1,
)

# 6) Train TabNet (AUC is your c-stat)
# pytorch-tabnet early-stops on the LAST entry of eval_set and restores the
# best-epoch weights from it. Passing the test set there would select the model
# on the very data used for the final metrics, so a validation subset is split
# off the training data instead and the test set stays untouched until step 7.
X_fit, X_valid, y_fit, y_valid = train_test_split(
    X_train,
    y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train,
)
print("Fit shape:", X_fit.shape)
print("Validation shape:", X_valid.shape)

tabnet_clf.fit(
    X_fit, y_fit,
    eval_set=[(X_fit, y_fit), (X_valid, y_valid)],
    eval_name=["train", "valid"],
    eval_metric=["auc"],
    max_epochs=50,
    patience=5,  # stop after 5 epochs without improvement in validation AUC
    batch_size=1024,
    virtual_batch_size=128
)

# 7) Score the held-out test set, then write plots and arrays to disk
# Column 1 of predict_proba is taken as the positive-class probability
y_prob = tabnet_clf.predict_proba(X_test)[:, 1]
y_pred = tabnet_clf.predict(X_test)

# Each metric rounded to three decimals
tabnet_results = {
    metric_name: round(score, 3)
    for metric_name, score in (
        ("accuracy", accuracy_score(y_test, y_pred)),
        ("roc_auc", roc_auc_score(y_test, y_prob)),
        ("f1_pos", f1_score(y_test, y_pred, zero_division=0)),
    )
}

print("\nTabNet model results:")
for name, value in tabnet_results.items():
    print(f"  {name}: {value}")

# Save plots
plot_and_save_metrics("tabnet", y_test, y_prob)

# Save the test labels and predicted probabilities for later cells
np.save("y_test_tabnet.npy", y_test)
np.save("probs_tabnet.npy", y_prob)

In [None]:
# This cell rebuilds the metric plots from the arrays saved by the training
# cell. numpy is imported here as well so the cell also runs on a fresh kernel
# without re-training (it previously relied on `np` leaking from that cell).
import numpy as np

from diabetes_utils import clean_diabetes_data, plot_and_save_metrics

# Test labels and predicted probabilities written by the training cell
y_test = np.load("y_test_tabnet.npy")

# Flatten to 1-D before plotting
y_prob = np.load("probs_tabnet.npy").ravel()

# Save plots
plot_and_save_metrics("tabnet", y_test, y_prob)

In [None]:
from matplotlib import pyplot as plt
import shap
import pandas as pd
import numpy as np

# Per-sample explanation matrix (and masks) from TabNet's own explain() method
explain_matrix, masks = tabnet_clf.explain(X_test)

# Explain the first test row
i = 0
contrib = explain_matrix[i]
feature_names = df_tab.columns

# Wrap that row in a shap.Explanation so shap's plotting functions accept it.
# TabNet has no base value, so the mean predicted probability stands in for it.
exp = shap.Explanation(
    contrib,
    base_values=np.mean(y_prob),
    data=X_test[i],
    feature_names=feature_names,
)

In [None]:
from pathlib import Path

# Draw the waterfall for the prepared explanation; show=False keeps the figure
# open so it can be restyled and saved below.
shap.waterfall_plot(exp, show=False, max_display=7)
plt.title("SHAP TabNet Waterfall Plot", fontsize=20)

fig, ax = plt.gcf(), plt.gca()

# Modifying main plot parameters: enlarge every text element in the figure
# except those whose text starts with a newline.
for text in fig.findobj(match=plt.Text):
    if not text.get_text().startswith("\n"):
        text.set_fontsize(18)

# savefig does not create missing directories, so make sure the output
# folder exists before writing the image.
figure_path = Path("figures") / "shap_tabnet_waterfall_plot.png"
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, bbox_inches='tight')
plt.show()