In [None]:
import torch
import xgboost as xgb
import numpy as np
import random
import pickle
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    roc_auc_score, average_precision_score, 
    precision_recall_curve, auc
)
from opacus.accountants import RDPAccountant

from loaders_data import process_clinical_data
from utils import *
from metrics import *

In [None]:
# Seed every random number generator this notebook touches
# (Python's `random`, NumPy's legacy global RNG, PyTorch CPU and all CUDA devices).
SEED = 40
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

In [None]:
# ----------------------
# Load and Prepare Data
# ----------------------
# The loader's result is unpacked into three (features, labels) pairs.
TRAIN, VAL, TEST = process_clinical_data('./final_data/UHB.csv')
(X_train, y_train), (X_val, y_val), (X_test, y_test) = TRAIN, VAL, TEST


# ----------------------
# Train XGBoost Model
# ----------------------
# One DMatrix per split, built in train / validation / test order.
dtrain, dval, dtest = (
    xgb.DMatrix(features, label=labels)
    for features, labels in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

xgb_params = dict(
    objective="binary:logistic",
    learning_rate=0.01,
    max_depth=3,
    subsample=0.7,
    reg_lambda=2.0,
    reg_alpha=1.0,
    min_child_weight=10,
    eval_metric="auc",
)

print("\n=== Training XGBoost Binary Classifier ===")
# Early stopping monitors the validation split only; the test split is untouched here.
xgb_model = xgb.train(
    xgb_params,
    dtrain,
    num_boost_round=15000,
    evals=[(dval, "validation")],
    early_stopping_rounds=250,
    verbose_eval=False,
)


# ----------------------
# Initial Evaluation
# ----------------------
# Test-set scores of the model trained on the real data.
y_pred = xgb_model.predict(dtest)
auroc = roc_auc_score(y_test, y_pred)
loss = binary_cross_entropy(y_test, y_pred)

# Area under the precision-recall curve (recall on the x-axis).
precision, recall, _ = precision_recall_curve(y_test, y_pred)
svc_pr_auc = auc(recall, precision)

# Evaluate on separate validation and test sets
print("\n=== Evaluation with Threshold Optimization ===")
evaluator = BinaryClassifierEvaluator()

y_val_prob = xgb_model.predict(dval)
results = evaluator.evaluate_separately(
    y_val_true=y_val,
    y_val_prob=y_val_prob,
    y_test_true=y_test,
    y_test_prob=y_pred,
    target_metric='Recall',
    target_value=0.85,
    error_margin=0.1,
    n_bootstraps=1000,
    return_ci=True,
)

summary_fields = ["FULL DATASET---"]
summary_fields.extend(f"{k}: {v}" for k, v in results.items())
print(" | ".join(summary_fields))

In [None]:
# ----------------------
# DP settings: tune these for the required privacy budget
# and the number of condensed samples.
# ----------------------

# Size of the condensed (synthetic) dataset and number of optimisation steps.
num_syn = 200
T = 30000

# Privacy parameters.
delta = 1e-5                 # delta passed to the accountant when querying epsilon
noise_multiplier = 1         # base noise multiplier; the per-step noise levels are multiples of it
max_grad_norm = 1            # clipping norm applied to the estimated gradient
min_snr = 1.25               # gradient norm is divided by this to get the target noise std
min_noise_multiplier = 0.25  # scales the lower bound on the noise std

# Sampling rate reported to the accountant on every DP step.
sample_rate = num_syn / len(X_train)
accountant = SelectiveAccountant(RDPAccountant(), max_steps=0)

In [None]:
class DPFiniteDiffGradientEstimator(torch.autograd.Function):
    """Autograd bridge around the (non-differentiable) trained XGBoost model.

    The forward pass returns the model's predictions for ``X_syn``. The backward
    pass estimates the gradient of those predictions with respect to ``X_syn`` by
    central finite differences and, when ``use_dp`` is set, clips the estimate,
    adds Gaussian noise and records the step on the notebook-level accountant.

    This class reads the following notebook globals: ``xgb_model``,
    ``min_noise_multiplier``, ``min_snr``, ``accountant`` and ``sample_rate``.

    ``torch.autograd.Function.apply`` is called positionally in this notebook, so
    the argument order of ``forward`` matters to every caller.
    """

    @staticmethod
    def forward(ctx, X_syn, use_dp=False, noise_multiplier=noise_multiplier, clip_norm=max_grad_norm,
                epsilon_range=(0.25, 0.5), seed=None):
        """Predict with the XGBoost model and stash everything backward needs.

        Args:
            X_syn: 2-D tensor of inputs (backward unpacks its shape as ``(n, d)``).
            use_dp: if truthy, backward clips the gradient and adds noise.
            noise_multiplier: base multiplier for the discrete noise levels.
            clip_norm: norm threshold used for clipping in backward.
            epsilon_range: ``(low, high)`` bounds for the per-feature
                finite-difference step sizes.
            seed: seed for the NumPy generator created in backward.

        Returns:
            float32 tensor of model predictions on ``X_syn``'s device.
        """
        ctx.save_for_backward(X_syn)
        ctx.use_dp = use_dp
        ctx.noise_multiplier = noise_multiplier
        ctx.clip_norm = clip_norm
        ctx.epsilon_range = epsilon_range
        ctx.seed = seed
        ctx.model = xgb_model

        X_np = X_syn.detach().cpu().numpy()
        preds = ctx.model.predict(xgb.DMatrix(X_np))
        return torch.tensor(preds, dtype=torch.float32, device=X_syn.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Finite-difference gradient w.r.t. ``X_syn``, optionally clipped and noised.

        Returns one entry per ``forward`` input: the gradient for ``X_syn`` followed
        by ``None`` for the five non-tensor arguments.
        """
        X_syn, = ctx.saved_tensors
        device = X_syn.device
        model = ctx.model
        seed = ctx.seed
        eps_range = ctx.epsilon_range
        use_dp = ctx.use_dp
        noise_multiplier = ctx.noise_multiplier
        clip_norm = ctx.clip_norm

        X_np = X_syn.detach().cpu().numpy()
        n, d = X_np.shape
        # The same generator supplies the step sizes and, later, the DP noise.
        rng = np.random.default_rng(seed)
        epsilons = rng.uniform(low=eps_range[0], high=eps_range[1], size=(d,))
        grad_estimate = np.zeros_like(X_np)

        # Build every +eps / -eps perturbation up front so the model is queried once.
        X_batch = []
        for j in range(d):
            X_plus = X_np.copy(); X_plus[:, j] += epsilons[j]
            X_minus = X_np.copy(); X_minus[:, j] -= epsilons[j]
            X_batch.append(X_plus)
            X_batch.append(X_minus)

        X_batch = np.concatenate(X_batch, axis=0)
        preds_batch = model.predict(xgb.DMatrix(X_batch))  # shape: (2 * d * n,)
        preds_batch = preds_batch.reshape(2 * d, n)  # row 2j -> +eps_j, row 2j+1 -> -eps_j

        # Upstream gradient is loop-invariant: convert it to NumPy once.
        upstream = grad_output.detach().cpu().numpy()

        # Central differences, chained with the upstream gradient.
        for j in range(d):
            f_plus = preds_batch[2 * j]
            f_minus = preds_batch[2 * j + 1]
            df_dx = (f_plus - f_minus) / (2 * epsilons[j])
            grad_estimate[:, j] = df_dx * upstream

        if use_dp:
            # NOTE(review): the norm below is taken over the whole (n, d) gradient
            # matrix, not per sample, and the noise level is then selected from that
            # norm. Confirm the privacy analysis covers both choices.
            preclip_norm = np.linalg.norm(grad_estimate)
            if preclip_norm > clip_norm:
                grad_estimate *= clip_norm / (preclip_norm + 1e-6)
                grad_norm = clip_norm
            else:
                grad_norm = preclip_norm  # carry through unclipped value

            # Floor on the noise std, regardless of the gradient norm.
            min_noise_scale = clip_norm * (noise_multiplier * min_noise_multiplier)
            target_noise_std = grad_norm / min_snr
            # Discrete noise levels (multiples of base noise_multiplier)
            noise_options = np.array([0.25, 0.5, 0.75, 1.0, 1.5, 2.0]) * noise_multiplier * clip_norm
            valid_noise = noise_options[noise_options >= target_noise_std]

            if len(valid_noise) > 0:
                # Smallest listed level that is at least the target, subject to the floor.
                chosen_noise_std = max(valid_noise.min(), min_noise_scale)
            else:
                # No listed level reaches the target: use the largest, subject to the floor.
                chosen_noise_std = max(noise_options.max(), min_noise_scale)

            noise = rng.normal(0, chosen_noise_std, grad_estimate.shape)
            grad_estimate = grad_estimate + noise

            # Report the noise actually used, expressed relative to the clipping norm.
            effective_noise_multiplier = chosen_noise_std / clip_norm
            accountant.step(noise_multiplier=effective_noise_multiplier, sample_rate=sample_rate)

        return torch.tensor(grad_estimate, dtype=torch.float32, device=device), None, None, None, None, None


In [None]:
# ================================================
# Alternative (disabled): initialize the synthetic dataset from random-normal noise; the active cell below instead uses class-wise mean + noise of the real data
# ================================================
# num_features = X_train.shape[1]
# X_syn = torch.randn((num_syn, num_features), dtype=torch.float32, requires_grad=True)
# y_syn = torch.cat([torch.zeros(num_syn//2), torch.ones(num_syn//2)])
# evaluate(X_syn,y_syn,xgb_params,dval,X_val,y_val) 

In [None]:
# ----------------------
# Synthetic Data Init
# ----------------------
# Each synthetic row is drawn from a per-class, per-feature Gaussian whose mean and
# std come from the real training rows of that class, perturbed with small noise.
num_features = X_train.shape[1]
pos_ratio = 0.5
dp_noise_scale = 0.01
num_pos = int(num_syn * pos_ratio)
num_neg = num_syn - num_pos

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.to_numpy(), dtype=torch.long)
X_train_0, X_train_1 = (X_train_tensor[y_train_tensor == label] for label in (0, 1))


def _perturbed_moments(X_class):
    """Return the per-feature (mean, std) of ``X_class``, each with N(0, dp_noise_scale) noise added.

    The mean's noise is drawn before the std's noise; the std also gets a 1e-5 offset.
    """
    mean = X_class.mean(dim=0) + torch.normal(0, dp_noise_scale, size=(num_features,))
    std = X_class.std(dim=0) + torch.normal(0, dp_noise_scale, size=(num_features,)) + 1e-5
    return mean, std


# Class 0 first, then class 1, so random draws happen in a fixed order.
mu_0, std_0 = _perturbed_moments(X_train_0)
mu_1, std_1 = _perturbed_moments(X_train_1)

X_syn_0 = mu_0 + torch.randn((num_neg, num_features)) * std_0
X_syn_1 = mu_1 + torch.randn((num_pos, num_features)) * std_1

# Negatives are stacked before positives; the labels follow the same layout.
X_syn = torch.cat((X_syn_0, X_syn_1), dim=0).requires_grad_(True)
y_syn = torch.cat((torch.zeros(num_neg), torch.ones(num_pos)))

evaluate(X_syn, y_syn, xgb_params, dval, X_val, y_val)

In [None]:
# Model predictions for the real training examples, used as the target of the
# matching loss. Function.apply is positional: the second slot is `use_dp`, so
# every argument is spelled out to make sure SEED lands in the `seed` slot and
# DP stays off (this tensor does not require grad, so backward is never run).
pred_real = DPFiniteDiffGradientEstimator.apply(
    torch.tensor(X_train, dtype=torch.float32),
    False,             # use_dp
    noise_multiplier,  # noise_multiplier (same as the default)
    max_grad_norm,     # clip_norm (same as the default)
    (0.25, 0.5),       # epsilon_range (same as the default)
    SEED,              # seed
)

In [None]:
# Running sums of the losses since the last evaluation (reset after each report).
LOSS1 = LOSS2 = LOSS = 0
steps_in_window = 0  # number of steps accumulated into the running sums
eval_itr = 100
best_cindex = 0      # best validation score so far (printed as AUROC below)
best_X = None
best_step = 0        # defined up front so early stopping and later cells never hit a NameError
optimizer = torch.optim.Adam([X_syn], lr=0.01)

bce_loss = torch.nn.BCELoss()
desired_ratio = 0.01  # hyperparameter: target share of the matching term in the total loss

for step in range(T):

    step_rng = torch.Generator().manual_seed(SEED + step)
    pred_syn = DPFiniteDiffGradientEstimator.apply(X_syn, True, noise_multiplier, max_grad_norm, (0.25, 3), SEED + step)

    # Supervision loss: predictions on the synthetic rows vs. their assigned labels
    loss_pred = bce_loss(pred_syn, y_syn)

    # Distribution matching in prediction space
    pred_syn_cat, idx_groups = stratify_pred_binary(y_syn, pred_syn)
    pred_real_cat = stratify_real_pred_binary(pred_real.clone().detach(), idx_groups, X_syn, step_rng)
    loss_matching = torch.norm(pred_syn_cat - pred_real_cat, p=1) / (X_syn.shape[0] + 1e-8)

    # Weighted combination of the supervision and matching losses. alpha is
    # recomputed from detached values each step so that the weighted matching term
    # is desired_ratio / (1 - desired_ratio) times the supervision loss.
    alpha = (loss_pred.item() / (loss_matching.item() + 1e-9)) * (desired_ratio / (1 - desired_ratio))
    loss = loss_pred + alpha * loss_matching

    LOSS1 += loss_pred.item()
    LOSS2 += alpha * loss_matching.item()
    LOSS += loss.item()
    steps_in_window += 1

    # Backprop with DP update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluation
    if step > 0 and step % eval_itr == 0:
        c_index = evaluate(X_syn, y_syn, xgb_params, dval, X_val, y_val)
        # Average over the steps actually accumulated (the first window holds eval_itr + 1).
        print(f"Step {step:04d} | Total Loss: {LOSS/steps_in_window:.3f} | Pred Loss: {LOSS1/steps_in_window:.3f} | Grad Loss: {LOSS2/steps_in_window:.3f} | AUROC: {c_index:.4f}")
        LOSS = LOSS1 = LOSS2 = 0
        steps_in_window = 0
        if c_index > best_cindex:
            best_cindex = c_index
            best_X = X_syn.detach().clone()
            best_step = step
            accountant.max_steps = best_step
        # Early stopping: no improvement for more than 4000 steps
        if (step - best_step) > 4000:
            break


In [None]:
# Ask the accountant for the epsilon spent at the configured delta and report it
# together with the step at which the best synthetic set was recorded.
eps_best, _ = accountant.get_privacy_spent(delta)
budget_report = f"📊 Final Privacy Budget (up to best step {best_step}): ε = {eps_best:.3f}, δ = {delta}"
print(budget_report)

In [None]:
# Train a fresh XGBoost model on the best synthetic dataset and score it on the
# real test split.
if best_X is None:
    # The optimisation loop only stores a checkpoint when the validation score improves.
    raise RuntimeError(
        "best_X is None: no synthetic checkpoint was recorded. "
        "Run the optimisation cell until at least one evaluation improves."
    )

# Hand XGBoost plain CPU NumPy arrays instead of torch tensors.
d_syn = xgb.DMatrix(
    data=best_X.detach().cpu().numpy(),
    label=y_syn.detach().cpu().numpy(),
)
model_eval = xgb.train(
    params=xgb_params,
    dtrain=d_syn,
    evals=[(dval, "validation")],
    num_boost_round=5000,
    early_stopping_rounds=50,
    verbose_eval=False
)


print("\n================ FINAL METRICS ================")
y_pred = model_eval.predict(dtest)
auroc = roc_auc_score(y_test, y_pred)
loss = binary_cross_entropy(y_test, y_pred)
# Area under the precision-recall curve (recall on the x-axis).
precision, recall, _ = precision_recall_curve(y_test, y_pred)
svc_pr_auc = auc(recall, precision)
print(f"AUROC: {auroc:.4f} | AUPRC: {svc_pr_auc:.4f} | CE Loss: {loss:.4f}")