In [49]:
import random
from os import path

import numpy as np
import polars as pl
import torch
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

In [50]:
# One seed shared by every random number generator this notebook touches.
SEED = 491

# Ask cuDNN for deterministic algorithms and turn its benchmark mode off so
# repeated runs do not pick different kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed Python's `random`, NumPy's legacy global RNG, and PyTorch (CPU + CUDA).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [51]:
# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
cuda = torch.cuda.is_available()
if cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda:0


In [52]:
# Directory (relative to the notebook's working directory) holding the input CSVs.
data_path = "../../data/previous"

# Resolve both file locations up front.
syn_cohort = path.join(data_path, "synthetic_patients_2.0_cohort.csv")
processed_seer = path.join(data_path, "processed_seer_with_age.csv")

# Read each CSV into its own polars DataFrame.
syn_cohort_dataset = pl.read_csv(source=syn_cohort)
seer_dataset = pl.read_csv(source=processed_seer)

In [53]:
# Name of the label column; every other column is treated as a feature.
label_col = "target"

# NOTE: despite the `_seer` suffix, X_seer / y_seer are built from
# `syn_cohort_dataset` (the synthetic cohort CSV). The names are kept because
# later cells read `y_seer`.
X_seer = syn_cohort_dataset.select(pl.exclude(label_col)).to_numpy().astype(np.float32)
y_seer = syn_cohort_dataset.get_column(label_col).to_numpy().astype(np.int64).ravel()

# Stratified 80/20 train/validation split. Use the notebook-wide SEED rather
# than repeating the literal, so changing SEED changes the split as well.
X_train, X_val, y_train, y_val = train_test_split(
    X_seer, y_seer, test_size=0.2, stratify=y_seer, random_state=SEED
)

# The held-out test set comes from `seer_dataset`.
# NOTE(review): features are taken positionally; this assumes both CSVs list
# their feature columns in the same order — TODO confirm.
X_test = seer_dataset.select(pl.exclude(label_col)).to_numpy().astype(np.float32)
y_test = seer_dataset.get_column(label_col).to_numpy().astype(np.int64).ravel()

In [54]:
# Fit the standardisation statistics on the training split only, then apply the
# same transform to every split (cast back to float32 for PyTorch).
scaler = StandardScaler().fit(X_train)
X_train, X_val, X_test = (
    scaler.transform(split).astype(np.float32) for split in (X_train, X_val, X_test)
)

In [55]:
# Convert every split to tensors and place them on the selected device.
X_train_t, y_train_t = torch.from_numpy(X_train).to(device), torch.from_numpy(y_train).to(device)
X_val_t, y_val_t = torch.from_numpy(X_val).to(device), torch.from_numpy(y_val).to(device)
X_test_t, y_test_t = torch.from_numpy(X_test).to(device), torch.from_numpy(y_test).to(device)

# Dedicated seeded generator passed to the training loader, which shuffles.
g = torch.Generator()
g.manual_seed(SEED)

# Only the training loader shuffles; validation and test keep their row order.
train_loader = DataLoader(
    dataset=TensorDataset(X_train_t, y_train_t),
    batch_size=64,
    shuffle=True,
    generator=g,
)
val_loader = DataLoader(
    dataset=TensorDataset(X_val_t, y_val_t),
    batch_size=64,
    shuffle=False,
)
test_loader = DataLoader(
    dataset=TensorDataset(X_test_t, y_test_t),
    batch_size=64,
    shuffle=False,
)

In [56]:
# Input width is the column count of the training matrix; the class count is
# the number of distinct label values in `y_seer`.
n_features = X_train.shape[1]
n_classes = len(np.unique(y_seer))
print(f"n_features: {n_features}\nn_classes: {n_classes}")

n_features: 6
n_classes: 2


In [57]:
# Set the CUDA matmul fp32 precision mode to "ieee".
# NOTE(review): presumably this requests full-precision fp32 matmuls (no TF32);
# confirm against the PyTorch docs for the installed version.
torch.backends.cuda.matmul.fp32_precision = "ieee"


class BinaryMLP(nn.Module):
    """Fully connected network that emits one raw logit per input row.

    The hidden stack is 512 -> 256 -> 128 -> 64 units; each hidden ``Linear``
    is followed by ``ReLU`` and ``Dropout`` (0.2, 0.2, 0.1, 0.1 respectively).
    A final ``Linear(64, 1)`` produces the logit. No sigmoid is applied here.

    Args:
        in_dim: number of input features per row.
    """

    def __init__(self, in_dim):
        super().__init__()
        hidden_widths = (512, 256, 128, 64)
        dropout_rates = (0.2, 0.2, 0.1, 0.1)

        # Build the layers in order so the Sequential indices (and therefore the
        # state_dict keys) are: Linear, ReLU, Dropout, ..., final Linear.
        stack = []
        fan_in = in_dim
        for fan_out, drop_p in zip(hidden_widths, dropout_rates):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(drop_p)]
            fan_in = fan_out
        stack.append(nn.Linear(fan_in, 1))

        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits with the trailing size-1 dimension (dim 1) removed."""
        scores = self.mlp(x)
        return scores.squeeze(dim=1)


# Instantiate the network on the target device, then wrap it with torch.compile.
model = BinaryMLP(n_features).to(device)
model = torch.compile(model)

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [58]:
def optimal_threshold(y_true, probs):
    """Return the ROC threshold that maximises ``tpr - fpr`` (Youden's J).

    Args:
        y_true: ground-truth binary labels.
        probs: predicted scores/probabilities for the positive class.

    Returns:
        The entry of ``roc_curve``'s threshold array at the index where
        ``tpr - fpr`` is largest (first such index on ties).
    """
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y_true, probs)
    best_idx = np.argmax(true_pos_rate - false_pos_rate)
    # NOTE(review): roc_curve's first threshold is a sentinel meaning "predict
    # nothing positive"; it is only selected when no cut-off beats J = 0.
    return cutoffs[best_idx]

In [59]:
def eval_metrics(loader, name="Validation", threshold=None):
    """Score the module-level ``model`` on ``loader``; print and return binary metrics.

    The model is switched to eval mode (and left there), run under
    ``torch.no_grad()``, and its logits are passed through a sigmoid to obtain
    positive-class probabilities.

    Args:
        loader: iterable yielding ``(x_batch, y_batch)`` pairs accepted by ``model``.
        name: label used in the printed report.
        threshold: optional probability cut-off for turning probabilities into
            class predictions. When ``None`` (the default, original behaviour)
            the cut-off is derived with ``optimal_threshold`` from the labels of
            ``loader`` itself. Pass a cut-off tuned on validation data when
            scoring a held-out set, so the threshold is not fitted to the very
            labels being evaluated.

    Returns:
        dict with keys ``"Accuracy"``, ``"Sensitivity"``, ``"Specificity"``,
        ``"Precision"``, ``"NPV"``, ``"F1"`` and ``"AUC"``.
    """
    model.eval()
    batch_probs, batch_targets = [], []
    with torch.no_grad():
        for x_batch, y_batch in loader:
            logits_v = model(x_batch)
            # Store probabilities (sigmoid of the logits), moved to the CPU.
            batch_probs.append(torch.sigmoid(logits_v).cpu())
            batch_targets.append(y_batch.cpu())

    y_prob = torch.cat(batch_probs).numpy()
    y_true = torch.cat(batch_targets).numpy()

    # AUC is threshold-independent, so it is computed before any cut-off is chosen.
    auc = roc_auc_score(y_true, y_prob)
    if threshold is None:
        threshold = optimal_threshold(y_true, y_prob)
        print(f"Optimal threshold for {name}: {threshold:.3f}")
    else:
        print(f"Using supplied threshold for {name}: {threshold:.3f}")

    y_pred = (y_prob >= threshold).astype(int)
    # Pin the label set so the matrix is always 2x2 and unpacks into four
    # counts, even if one class is absent from both y_true and y_pred.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    acc = accuracy_score(y_true, y_pred)
    sens = recall_score(y_true, y_pred, zero_division=0)
    spec = tn / (tn + fp) if (tn + fp) else 0
    prec = precision_score(y_true, y_pred, zero_division=0)
    npv = tn / (tn + fn) if (tn + fn) else 0
    f1 = f1_score(y_true, y_pred, zero_division=0)

    print(f"\n{name} metrics:")
    print(f"Accuracy              : {acc:.4f}")
    print(f"Sensitivity (Recall)  : {sens:.4f}")
    print(f"Specificity           : {spec:.4f}")
    print(f"Precision (PPV)       : {prec:.4f}")
    print(f"NPV                   : {npv:.4f}")
    print(f"F1 Score              : {f1:.4f}")
    print(f"AUC                   : {auc:.4f}")

    return {
        "Accuracy": acc,
        "Sensitivity": sens,
        "Specificity": spec,
        "Precision": prec,
        "NPV": npv,
        "F1": f1,
        "AUC": auc,
    }

In [60]:
# Number of full passes over the training loader.
EPOCHS = 20

# Positive-class weight for the loss: len / sum - 1, which equals
# (#negatives / #positives) assuming the labels are 0/1.
pos_weight = torch.tensor(
    [len(y_train) / np.sum(y_train) - 1],
    dtype=torch.float32,
    device=device,
)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch}")

    # eval_metrics() leaves the model in eval mode, so re-enable training mode
    # at the start of every epoch.
    model.train()
    for batch_x_t, batch_y_t in train_loader:
        optimizer.zero_grad(set_to_none=True)
        logits = model(batch_x_t)
        # Targets are stored as int64; the BCE loss needs them as floats.
        loss = criterion(logits, batch_y_t.float())
        loss.backward()
        optimizer.step()

    # Report validation metrics after each epoch.
    eval_metrics(val_loader, name="Validation")

print("\n" + "-" * 50)
print("Final metrics:")
# NOTE: eval_metrics derives its decision threshold from the labels of the
# loader it scores, so the threshold-dependent test metrics below use a
# cut-off tuned on the test labels themselves.
val_metrics = eval_metrics(val_loader, name="Validation")
test_metrics = eval_metrics(test_loader, name="Test (cohort)")


Epoch 0
Optimal threshold for Validation: 0.510

Validation metrics:
Accuracy              : 0.6185
Sensitivity (Recall)  : 0.5669
Specificity           : 0.6710
Precision (PPV)       : 0.6370
NPV                   : 0.6034
F1 Score              : 0.5999
AUC                   : 0.6506

Epoch 1
Optimal threshold for Validation: 0.443

Validation metrics:
Accuracy              : 0.6075
Sensitivity (Recall)  : 0.6026
Specificity           : 0.6125
Precision (PPV)       : 0.6129
NPV                   : 0.6022
F1 Score              : 0.6077
AUC                   : 0.6453

Epoch 2
Optimal threshold for Validation: 0.459

Validation metrics:
Accuracy              : 0.6030
Sensitivity (Recall)  : 0.6472
Specificity           : 0.5580
Precision (PPV)       : 0.5985
NPV                   : 0.6084
F1 Score              : 0.6219
AUC                   : 0.6428

Epoch 3
Optimal threshold for Validation: 0.488

Validation metrics:
Accuracy              : 0.6080
Sensitivity (Recall)  : 0.6581
Specifi