In [2]:
import pandas as pd
import numpy as np

from ucimlrepo import fetch_ucirepo
from diabetes_utils import clean_diabetes_data, plot_and_save_metrics

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score

import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Dense, Dropout, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping


# Load and clean data
diabetes_data = fetch_ucirepo(id=296)
X = diabetes_data.data.features
y = diabetes_data.data.targets

# Give the (assumed single-column) target frame a stable name before merging;
# a multi-column target would make this assignment raise.
if "readmitted" not in y.columns:
    y.columns = ["readmitted"]

df = pd.concat([X, y], axis=1)
df_clean = clean_diabetes_data(df)

print("Cleaned shape:", df_clean.shape)
print(df_clean.columns)

# longitudinal numeric features used at each encounter
long_feats = [
    "time_in_hospital",
    "num_lab_procedures",
    "num_procedures",
    "num_medications",
    "number_outpatient",
    "number_emergency",
    "number_inpatient",
    "number_diagnoses",
]

# demographic / categorical features (one-hot encoded below)
demo_cols = [
    "race",
    "gender",
    "age",
    "diag_1_group",
    "diag_2_group",
    "diag_3_group",
    "insulin",
    "change",
    "diabetesMed",
]

# target label; treated as a 0/1 flag (its mean is reported as the positive rate)
y_seq = df_clean["readmit_30d"].to_numpy()
n_samples = len(df_clean)

# Longitudinal numeric features as a 3D tensor (samples, timesteps, features).
# Every encounter is its own length-1 sequence, so T=1.
# NOTE(review): the UCI source can contain several encounters per patient and
# no patient id is used here, so all later splits are per encounter and may put
# one patient on both sides -- confirm whether clean_diabetes_data handles this.
X_long = df_clean[long_feats].to_numpy().reshape(n_samples, 1, -1)  # (N, 1, F)

# Time-gap channel: a constant 1 for each real step (no padding). With T=1 and
# a constant value it carries no information beyond an extra bias-like input.
X_time = np.ones((n_samples, 1, 1), dtype=np.float32)  # shape (N, 1, 1)

# Demographics as a one-hot 2D matrix. pd.get_dummies emits bool columns by
# default in pandas >= 2.0; request float32 so the model input is numeric
# instead of relying on Keras to cast booleans implicitly.
demo_df = pd.get_dummies(
    df_clean[demo_cols],
    columns=demo_cols,
    drop_first=True,
    dtype=np.float32,
)
X_demo = demo_df.to_numpy()  # shape (N, n_demo_features)

print("X_long shape:", X_long.shape)
print("X_time shape:", X_time.shape)
print("X_demo shape:", X_demo.shape)
print("Positive rate:", y_seq.mean().round(3))


# helper to build a fresh TA-RNN model (used in both CV + final train)
def build_ta_rnn_model(
    n_long_feats,
    n_demo_features,
    T=1,
    *,
    lstm_units=64,
    dropout_rate=0.3,
    dense_units=32,
):
    """Build and compile a simplified time-aware RNN (TA-RNN) binary classifier.

    The model takes three inputs, passed to ``fit``/``predict`` in this order:
    ``longitudinal`` with shape (T, n_long_feats), ``time_gaps`` with shape
    (T, 1), and ``demographics`` with shape (n_demo_features,). The time-gap
    channel is concatenated onto every time step. The sequence is summarised by
    an LSTM, and that summary is concatenated with the demographics before a
    single sigmoid output unit.

    Args:
        n_long_feats: Number of longitudinal features per time step.
        n_demo_features: Width of the demographics vector.
        T: Number of time steps per sample.
        lstm_units: Hidden size of the LSTM layer (default keeps the original 64).
        dropout_rate: Dropout applied to the LSTM output (default 0.3).
        dense_units: Width of the ReLU layer after the LSTM (default 32).

    Returns:
        A Keras ``Model`` compiled with Adam and binary cross-entropy that tracks
        accuracy and ROC AUC (logged as ``auc`` / ``val_auc``).
    """
    # inputs: sequence features, time gap channel, demographics
    long_input = Input(shape=(T, n_long_feats), name="longitudinal")
    time_input = Input(shape=(T, 1), name="time_gaps")
    demo_input = Input(shape=(n_demo_features,), name="demographics")

    # attach the time-gap value to each time step as an extra feature
    seq_input = Concatenate(axis=-1)([long_input, time_input])

    # LSTM over the sequence; only the final hidden state is kept
    x = LSTM(lstm_units)(seq_input)
    x = Dropout(dropout_rate)(x)
    x = Dense(dense_units, activation="relu")(x)

    # combine the sequence representation with the (unprocessed) demographics
    x = Concatenate(axis=-1)([x, demo_input])

    output = Dense(1, activation="sigmoid")(x)

    model = Model(
        inputs=[long_input, time_input, demo_input],
        outputs=output,
        name="ta_rnn_simplified",
    )

    # explicit metric name so EarlyStopping can monitor "val_auc"
    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
    )
    return model


# Stratified 5-fold CV for TA-RNN
n_long_feats = X_long.shape[2]
n_demo_features = X_demo.shape[1]
T = X_long.shape[1]  # should be 1


def _scale_longitudinal(scaler, arr):
    """Transform an (N, T, F) array with a fitted scaler via an (N*T, F) view."""
    flat = arr.reshape(-1, arr.shape[-1])
    return scaler.transform(flat).reshape(arr.shape)


skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_metrics = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X_demo, y_seq), start=1):
    # The held-out fold (val_idx) is used ONLY for scoring. Early stopping with
    # restore_best_weights needs its own validation data, so a stratified 20%
    # early-stopping split is carved out of the training fold. Monitoring the
    # scoring fold instead would choose the checkpoint on the very data it is
    # scored on and bias the fold metrics upward.
    fit_idx, es_idx = train_test_split(
        train_idx,
        test_size=0.2,
        random_state=42,
        stratify=y_seq[train_idx],
    )

    # integer-array indexing returns copies, so the source arrays stay untouched
    X_fit_long, X_es_long, X_val_long = X_long[fit_idx], X_long[es_idx], X_long[val_idx]
    X_fit_time, X_es_time, X_val_time = X_time[fit_idx], X_time[es_idx], X_time[val_idx]
    X_fit_demo, X_es_demo, X_val_demo = X_demo[fit_idx], X_demo[es_idx], X_demo[val_idx]
    y_fit, y_es, y_val = y_seq[fit_idx], y_seq[es_idx], y_seq[val_idx]

    # scale longitudinal features using statistics from the fitting rows only
    scaler_long = StandardScaler().fit(X_fit_long.reshape(-1, n_long_feats))
    X_fit_long = _scale_longitudinal(scaler_long, X_fit_long)
    X_es_long = _scale_longitudinal(scaler_long, X_es_long)
    X_val_long = _scale_longitudinal(scaler_long, X_val_long)

    # build a fresh TA-RNN model for this fold
    model_cv = build_ta_rnn_model(n_long_feats, n_demo_features, T=T)

    early_stop_cv = EarlyStopping(
        monitor="val_auc",
        patience=2,
        mode="max",
        restore_best_weights=True,
    )

    history_cv = model_cv.fit(
        [X_fit_long, X_fit_time, X_fit_demo],
        y_fit,
        validation_data=([X_es_long, X_es_time, X_es_demo], y_es),
        epochs=10,
        batch_size=256,
        callbacks=[early_stop_cv],
        verbose=0,
    )

    # evaluate this fold on data that played no part in training or stopping
    y_val_prob = model_cv.predict(
        [X_val_long, X_val_time, X_val_demo], verbose=0
    ).ravel()
    y_val_pred = (y_val_prob >= 0.5).astype(int)

    fold_result = {
        "fold": fold,
        "accuracy": accuracy_score(y_val, y_val_pred),
        "roc_auc": roc_auc_score(y_val, y_val_prob),
        "f1_pos": f1_score(y_val, y_val_pred, zero_division=0),
    }
    cv_metrics.append(fold_result)

    print(f"\nFold {fold}:")
    print(f"  accuracy: {fold_result['accuracy']:.3f}")
    print(f"  roc_auc:  {fold_result['roc_auc']:.3f}")
    print(f"  f1_pos:   {fold_result['f1_pos']:.3f}")

cv_df = pd.DataFrame(cv_metrics)
print("\n5-fold CV summary (TA-RNN)")
print(cv_df[["accuracy", "roc_auc", "f1_pos"]].mean().round(3))

# train-test split + final TA-RNN
(
    X_train_long, X_test_long,
    X_train_time, X_test_time,
    X_train_demo, X_test_demo,
    y_train, y_test,
) = train_test_split(
    X_long,
    X_time,
    X_demo,
    y_seq,
    test_size=0.2,
    random_state=42,
    stratify=y_seq,
)

# Scale longitudinal numeric features: fit on the training split only, then
# apply to both splits via an (N*T, features) view, reshaped back afterwards.
n_long_feats = X_train_long.shape[2]
scaler_long = StandardScaler()

X_train_long_2d = X_train_long.reshape(-1, n_long_feats)
X_test_long_2d = X_test_long.reshape(-1, n_long_feats)

X_train_long_2d = scaler_long.fit_transform(X_train_long_2d)
X_test_long_2d = scaler_long.transform(X_test_long_2d)

X_train_long = X_train_long_2d.reshape(X_train_long.shape)
X_test_long = X_test_long_2d.reshape(X_test_long.shape)

# Explicit, stratified early-stopping split of the training data. This
# replaces Keras' validation_split, which simply takes the last 20% of the
# arrays without regard to class balance. The proportion is the same (80/20),
# but the positive rate is now preserved in both parts.
(
    X_fit_long, X_es_long,
    X_fit_time, X_es_time,
    X_fit_demo, X_es_demo,
    y_fit, y_es,
) = train_test_split(
    X_train_long,
    X_train_time,
    X_train_demo,
    y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train,
)

# define final TA-RNN model (same architecture)
ta_rnn_model = build_ta_rnn_model(n_long_feats, n_demo_features, T=T)
ta_rnn_model.summary()

early_stop = EarlyStopping(
    monitor="val_auc",
    patience=2,
    mode="max",
    restore_best_weights=True,
)

history = ta_rnn_model.fit(
    [X_fit_long, X_fit_time, X_fit_demo],
    y_fit,
    validation_data=([X_es_long, X_es_time, X_es_demo], y_es),
    epochs=10,
    batch_size=256,
    callbacks=[early_stop],
    verbose=1,
)

# Evaluate and save plots on held-out test set
y_prob = ta_rnn_model.predict([X_test_long, X_test_time, X_test_demo]).ravel()
# NOTE(review): a fixed 0.5 threshold can yield few positive predictions when
# the positive rate is low (f1_pos is the metric affected); consider tuning it.
y_pred = (y_prob >= 0.5).astype(int)

ta_rnn_results = {
    "accuracy": round(accuracy_score(y_test, y_pred), 3),
    "roc_auc": round(roc_auc_score(y_test, y_prob), 3),
    "f1_pos": round(f1_score(y_test, y_pred, zero_division=0), 3),
}

print("\nTA-RNN-style model results (no k fold):")
for k, v in ta_rnn_results.items():
    print(f"  {k}: {v}")

# Save plots
plot_and_save_metrics("ta_rnn", y_test, y_prob)

# Save test labels and predicted probabilities
np.save("y_test_tarnn.npy", y_test)
np.save("prob_tarnn.npy", y_prob)

  df = pd.read_csv(data_url)


Cleaned shape: (101766, 49)
Index(['race', 'gender', 'age', 'admission_type_id',
       'discharge_disposition_id', 'admission_source_id', 'time_in_hospital',
       'num_lab_procedures', 'num_procedures', 'num_medications',
       'number_outpatient', 'number_emergency', 'number_inpatient', 'diag_1',
       'diag_2', 'diag_3', 'number_diagnoses', 'max_glu_serum', 'A1Cresult',
       'metformin', 'repaglinide', 'nateglinide', 'chlorpropamide',
       'glimepiride', 'acetohexamide', 'glipizide', 'glyburide', 'tolbutamide',
       'pioglitazone', 'rosiglitazone', 'acarbose', 'miglitol', 'troglitazone',
       'tolazamide', 'examide', 'citoglipton', 'insulin',
       'glyburide-metformin', 'glipizide-metformin',
       'glimepiride-pioglitazone', 'metformin-rosiglitazone',
       'metformin-pioglitazone', 'change', 'diabetesMed', 'readmitted',
       'diag_1_group', 'diag_2_group', 'diag_3_group', 'readmit_30d'],
      dtype='object')
X_long shape: (101766, 1, 8)
X_time shape: (101766, 1,