In [3]:
import pandas as pd
import numpy as np
import os

from utils.diabetes_utils import clean_diabetes_data
from utils.diabetes_utils import plot_and_save_metrics

from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    roc_auc_score,
    f1_score,
    average_precision_score,
    precision_recall_curve,
)
from sklearn.utils.class_weight import compute_class_weight

import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Dense, Dropout, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping


# Load dataset
df = pd.read_csv("data/diabetic_data.csv")

df_clean = clean_diabetes_data(df)
print("Cleaned shape:", df_clean.shape)


# Longitudinal numeric features
long_feats = [
    "time_in_hospital",
    "num_lab_procedures",
    "num_procedures",
    "num_medications",
    "number_outpatient",
    "number_emergency",
    "number_inpatient",
    "number_diagnoses",
]

# Demographics / categorical
demo_cols = [
    "race",
    "gender",
    "age",
    "diag_1_group",
    "diag_2_group",
    "diag_3_group",
    "insulin",
    "change",
    "diabetesMed",
]

model_df = df_clean[long_feats + demo_cols + ["readmit_30d", "patient_nbr"]].dropna()

groups = model_df["patient_nbr"].values
y_seq = model_df["readmit_30d"].values.astype(int)

print("Positive rate:", round(y_seq.mean(), 3))


# Longitudinal tensor
X_long = model_df[long_feats].values.astype(np.float32)
n_samples = X_long.shape[0]
X_long = X_long.reshape(n_samples, 1, -1)

# Time channel
X_time = np.ones((n_samples, 1, 1), dtype=np.float32)

# One-hot encode demographics
demo_df = pd.get_dummies(model_df[demo_cols], columns=demo_cols, drop_first=True)
X_demo = demo_df.values.astype(np.float32)


def build_ta_rnn_model(n_long_feats, n_demo_features, T=1):

    long_input = Input(shape=(T, n_long_feats))
    time_input = Input(shape=(T, 1))
    demo_input = Input(shape=(n_demo_features,))

    seq_input = Concatenate(axis=-1)([long_input, time_input])

    x = LSTM(64)(seq_input)
    x = Dropout(0.3)(x)
    x = Dense(32, activation="relu")(x)

    x = Concatenate(axis=-1)([x, demo_input])
    output = Dense(1, activation="sigmoid")(x)

    model = Model(
        inputs=[long_input, time_input, demo_input],
        outputs=output,
    )

    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
    )

    return model


gkf = GroupKFold(n_splits=5)

n_long_feats = X_long.shape[2]
n_demo_features = X_demo.shape[1]
T = 1

cv_metrics = []

oof_probs = np.zeros(len(y_seq))
oof_targets = np.zeros(len(y_seq))

for fold, (train_idx, val_idx) in enumerate(
        gkf.split(X_demo, y_seq, groups=groups), start=1):

    X_tr_long = X_long[train_idx]
    X_val_long = X_long[val_idx]

    X_tr_time = X_time[train_idx]
    X_val_time = X_time[val_idx]

    X_tr_demo = X_demo[train_idx]
    X_val_demo = X_demo[val_idx]

    y_tr = y_seq[train_idx]
    y_val = y_seq[val_idx]

    scaler = StandardScaler()

    X_tr_2d = X_tr_long.reshape(-1, n_long_feats)
    X_val_2d = X_val_long.reshape(-1, n_long_feats)

    X_tr_2d = scaler.fit_transform(X_tr_2d)
    X_val_2d = scaler.transform(X_val_2d)

    X_tr_long = X_tr_2d.reshape(X_tr_long.shape)
    X_val_long = X_val_2d.reshape(X_val_long.shape)

    class_weights_arr = compute_class_weight(
        class_weight="balanced",
        classes=np.array([0, 1]),
        y=y_tr
    )
    class_weight_dict = {0: class_weights_arr[0], 1: class_weights_arr[1]}

    model = build_ta_rnn_model(n_long_feats, n_demo_features, T=T)

    early_stop = EarlyStopping(
        monitor="val_auc",
        patience=2,
        mode="max",
        restore_best_weights=True,
    )

    model.fit(
        [X_tr_long, X_tr_time, X_tr_demo],
        y_tr,
        validation_data=([X_val_long, X_val_time, X_val_demo], y_val),
        epochs=10,
        batch_size=256,
        callbacks=[early_stop],
        class_weight=class_weight_dict,
        verbose=0,
    )

    y_val_prob = model.predict(
        [X_val_long, X_val_time, X_val_demo],
        verbose=0
    ).ravel()

    oof_probs[val_idx] = y_val_prob
    oof_targets[val_idx] = y_val

    y_val_pred_default = (y_val_prob >= 0.5).astype(int)

    auprc = average_precision_score(y_val, y_val_prob)

    precision, recall, thresholds = precision_recall_curve(y_val, y_val_prob)
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    best_idx = np.argmax(f1_scores[:-1])
    best_threshold = thresholds[best_idx]

    y_val_pred_tuned = (y_val_prob >= best_threshold).astype(int)

    fold_result = {
        "fold": fold,
        "roc_auc": roc_auc_score(y_val, y_val_prob),
        "auprc": auprc,
        "f1_default": f1_score(y_val, y_val_pred_default, zero_division=0),
        "f1_tuned": f1_score(y_val, y_val_pred_tuned, zero_division=0),
        "best_threshold": best_threshold,
    }

    cv_metrics.append(fold_result)

    print(f"\nFold {fold}:")
    print(f"  AUC:        {fold_result['roc_auc']:.3f}")
    print(f"  AUPRC:      {fold_result['auprc']:.3f}")
    print(f"  F1 (0.5):   {fold_result['f1_default']:.3f}")
    print(f"  F1 (tuned): {fold_result['f1_tuned']:.3f}")
    print(f"  Best thr:   {best_threshold:.3f}")

cv_df = pd.DataFrame(cv_metrics)

print("\n5-fold CV summary (TA-RNN – patient split + tuned threshold)")
print("Mean AUC:", round(cv_df["roc_auc"].mean(), 3))
print("Mean AUPRC:", round(cv_df["auprc"].mean(), 3))
print("Mean F1 (0.5):", round(cv_df["f1_default"].mean(), 3))
print("Mean F1 (tuned):", round(cv_df["f1_tuned"].mean(), 3))


os.makedirs("oof_predictions", exist_ok=True)

np.save("oof_predictions/oof_ta_rnn.npy", oof_probs)
np.save("oof_predictions/y_oof_ta_rnn.npy", oof_targets)

print("Saved OOF predictions to oof_predictions/")


plot_and_save_metrics(
    model_name="ta_rnn_oof",
    y_test=oof_targets,
    y_prob=oof_probs,
    threshold=0.5
)

Cleaned shape: (101766, 51)
Positive rate: 0.112

Fold 1:
  AUC:        0.644
  AUPRC:      0.204
  F1 (0.5):   0.254
  F1 (tuned): 0.269
  Best thr:   0.566

Fold 2:
  AUC:        0.641
  AUPRC:      0.186
  F1 (0.5):   0.247
  F1 (tuned): 0.253
  Best thr:   0.550

Fold 3:
  AUC:        0.640
  AUPRC:      0.202
  F1 (0.5):   0.252
  F1 (tuned): 0.258
  Best thr:   0.528

Fold 4:
  AUC:        0.647
  AUPRC:      0.202
  F1 (0.5):   0.255
  F1 (tuned): 0.266
  Best thr:   0.528

Fold 5:
  AUC:        0.641
  AUPRC:      0.193
  F1 (0.5):   0.248
  F1 (tuned): 0.256
  Best thr:   0.527

5-fold CV summary (TA-RNN – patient split + tuned threshold)
Mean AUC: 0.642
Mean AUPRC: 0.198
Mean F1 (0.5): 0.251
Mean F1 (tuned): 0.26
Saved OOF predictions to oof_predictions/
