In [35]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from inspect import signature
import omegaconf
import shap


# PyTorch Tabular (modern API)
import torch
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, TrainerConfig, OptimizerConfig
try:
    from pytorch_tabular.models.tab_transformer.config import TabTransformerConfig
except Exception:
    from pytorch_tabular.models.tab_transformer import TabTransformerConfig
from unittest.mock import patch



In [36]:
# Load dataset
da = "smote"
df = pd.read_csv("./processed_datasets/dataset_preprocessed_" + da + ".csv", index_col=False)

df.head()

Unnamed: 0,AgeAtStartOfSpell,Ethnicity,IMD Decile,Body Mass Index at Booking,Obese?,Parity,Gravida,No_Of_previous_Csections,FolicAcidDose,GlucoseToleranceTest,Gestational Diabetes
0,2.000001,0,2.000001,2.742637,1.0,1e-06,0.500001,0.0,0,0,0.0
1,1.875001,1,1.666668,2.51473,1.0,1e-06,0.500001,0.0,0,0,0.0
2,1.000001,0,0.333334,1.392249,0.0,1e-06,0.500001,0.0,0,1,0.0
3,2.125001,2,0.666668,0.482172,0.0,1e-06,0.500001,0.0,0,2,0.0
4,1.250001,0,1.333334,0.679071,0.0,1e-06,0.500001,0.0,0,2,0.0


In [37]:
target = "Gestational Diabetes"

numeric_cols = [
    "AgeAtStartOfSpell",
    "IMD Decile",
    "Body Mass Index at Booking",
    "Parity",
    "Gravida",
    "No_Of_previous_Csections"
]

categorical_cols = [
    "Ethnicity",
    "Obese?",
    "FolicAcidDose",
    "GlucoseToleranceTest"
]

# Train/valid split (stratified)
train_df, test_df = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df[target]
)

# Class weights for imbalance
classes = np.array([0, 1], dtype=int)
weights = compute_class_weight(
    class_weight="balanced",
    classes=classes,
    y=train_df[target].astype(int).values
)
class_weight_dict = {int(c): float(w) for c, w in zip(classes, weights)}
weights_list = [class_weight_dict.get(0,1.0), class_weight_dict.get(1,1.0)]
print("class_weight_dict:", class_weight_dict)


class_weight_dict: {0: 1.0, 1: 1.0}


In [38]:
loss_params = {"weight": [class_weight_dict.get(0, 1.0), class_weight_dict.get(1, 1.0)]}


In [40]:
# model

# Version-safe config helpers
def safe_config(ConfigClass, kwargs):
    params = set(signature(ConfigClass).parameters)
    return ConfigClass(**{k: v for k, v in kwargs.items() if k in params})

# Keep configs MINIMAL to avoid version-specific args
data_config = safe_config(DataConfig, dict(
    target=[target],
    continuous_cols=numeric_cols,
    categorical_cols=categorical_cols,
))

# trainer
trainer_config = safe_config(TrainerConfig, dict(
    batch_size=512,
    max_epochs=30,
    seed=42,
    log_every_n_steps=10,
    enable_checkpointing=False,
))

opt_name = "AdamW" if hasattr(torch.optim, "AdamW") else "Adam"
optimizer_config = safe_config(OptimizerConfig, dict(
    optimizer=opt_name,
    optimizer_params={"weight_decay": 1e-5},
))

# No metrics passed here (older versions can choke); we’ll compute AUC with sklearn
model_config = safe_config(TabTransformerConfig, dict(
    task="classification",
    loss="CrossEntropyLoss",
    loss_params={"weight": weights_list},
    num_attn_blocks=4,
    num_heads=4,
    input_embed_dim=32,
    attn_dropout=0.1,
    ff_dropout=0.1,
    embedding_dropout=0.05,
    learning_rate=1e-3,
    seed=42,
))

# Train
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    trainer_config=trainer_config,
    optimizer_config=optimizer_config,
)

# --- patch torch.load so internal checkpoint reload uses weights_only=False ---
_orig_torch_load = torch.load

def _torch_load_weights_only_false(*args, **kwargs):
    # force legacy, trusted behavior for this call site
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)

# Apply patch only during fit()
with patch("torch.load", new=_torch_load_weights_only_false):
    tabular_model.fit(train=train_df, validation=test_df)

Seed set to 42


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  X_encoded[col].fillna(self._imputed, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  X_encoded[col].fillna(self._imputed, inplace=True)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


/Users/sindabesrour/Projects/gestationalDiabetesPrediction/env/lib/python3.13/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/sindabesrour/Projects/gestationalDiabetesPrediction/GDM-transformers/saved_models exists and is not empty.


Output()

In [41]:
# Evaluate (AUC + confusion matrix + report)
eval_res = tabular_model.evaluate(test_df)
print("Eval: ", eval_res)

pred_df = tabular_model.predict(test_df)
# robust grab of positive-class probability
prob_cols = [c for c in pred_df.columns if c.endswith("_probability")]
if not prob_cols:
    prob_cols = [c for c in pred_df.columns if "prob" in c.lower()]
assert len(prob_cols) >= 1, f"Couldn't find probability column in: {pred_df.columns.tolist()}"

# if both 0_probability and 1_probability exist, pick class 1
pos_col = None
for c in prob_cols:
    if c.startswith("1_") or c.lower().startswith("class1"):
        pos_col = c
        break
if pos_col is None:
    pos_col = prob_cols[-1]  # fallback

y_prob = pred_df[pos_col].to_numpy()
y_true = test_df[target].to_numpy().astype(int)
y_pred = (y_prob >= 0.5).astype(int)

print(f"AUC: {roc_auc_score(y_true, y_prob):.3f}")
print("Confusion matrix:\n", confusion_matrix(y_true, y_pred))
print("\nClassification report:\n", classification_report(y_true, y_pred, digits=3))

Output()

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  X_encoded[col].fillna(self._imputed, inplace=True)
/Users/sindabesrour/Projects/gestationalDiabetesPrediction/env/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Eval:  [{'test_loss_0': 0.6151891350746155, 'test_loss': 0.6151891350746155, 'test_accuracy': 0.6689847111701965}]


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  X_encoded[col].fillna(self._imputed, inplace=True)


AUC: 0.719
Confusion matrix:
 [[2802 1512]
 [1344 2970]]

Classification report:
               precision    recall  f1-score   support

           0      0.676     0.650     0.662      4314
           1      0.663     0.688     0.675      4314

    accuracy                          0.669      8628
   macro avg      0.669     0.669     0.669      8628
weighted avg      0.669     0.669     0.669      8628

