In [None]:
import torch
import numpy as np
import random
import pandas as pd
import pickle
from lifelines import CoxPHFitter
from lifelines.utils import concordance_index
from tqdm import tqdm
from opacus.accountants import RDPAccountant

from load_survival_data import load_seer
from utils_surv import (
    bootstrap_cindex_ci,
    stratify_syn_pred,
    stratify_real_pred,
    SelectiveAccountant,
    evaluate_coxph,
    coxph_loss
)

In [None]:
# Seed every random number generator this notebook touches (Python's `random`,
# NumPy's legacy global RNG, PyTorch CPU, and all CUDA devices) for reproducibility.
SEED = 40
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

In [None]:
# ----------------------------- #
# Load and Preprocess Data
# ----------------------------- #
# The fourth element returned by load_seer is not used here; it is bound to a
# named variable rather than `_` so IPython's last-output variable is not shadowed.
TRAIN, VAL, TEST, _load_seer_extra = load_seer()
X_train, (dur_train, evt_train) = TRAIN
X_val, (dur_val, evt_val) = VAL
X_test, (dur_test, evt_test) = TEST

# Convert months to years. Out-of-place division leaves the loader's arrays
# untouched and also works if durations come back with an integer dtype
# (in-place `/=` on an integer array raises a casting error).
dur_train = dur_train / 12
dur_val = dur_val / 12
dur_test = dur_test / 12

# Robust statistics of observed event times: median and inter-quartile range,
# with the median as a fallback scale when the IQR is zero.
# NOTE: `time_scale` is not referenced anywhere else in this notebook; the
# durations are used unscaled.
event_times_train = dur_train[evt_train == 1]
time_median = np.median(event_times_train)
time_q25, time_q75 = np.percentile(event_times_train, [25, 75])
time_iqr = time_q75 - time_q25
time_scale = time_iqr if time_iqr > 0 else time_median

# Row masks keeping finite durations and non-negative event indicators
valid_idx = np.isfinite(dur_val) & (evt_val >= 0)
test_idx = np.isfinite(dur_test) & (evt_test >= 0)

In [None]:
# ----------------------------- #
# Fit CoxPH Model on Real Dataset
# ----------------------------- #
# Reference model: penalised Cox regression on the full real training split.
train_df = pd.DataFrame(X_train).assign(duration=dur_train, event=evt_train)

# Held-out frame restricted to the rows kept by `test_idx`
test_df = pd.DataFrame(X_test[test_idx]).assign(
    duration=dur_test[test_idx],
    event=evt_test[test_idx],
)

cph_real = CoxPHFitter(penalizer=0.1)
cph_real.fit(train_df, duration_col="duration", event_col="event")

# Partial hazards on the test split (kept under the name `val_risk`)
val_risk = cph_real.predict_partial_hazard(test_df)
mean_cindex, lower, upper = bootstrap_cindex_ci(test_df["duration"], test_df["event"], val_risk)

print(f"Full Dataset C-index: {mean_cindex:.4f} (95% CI: {lower:.4f} – {upper:.4f})")

In [None]:
# ----------------------------- #
# Parameter setting
# ----------------------------- #
# Synthetic-set size and optimisation length
num_syn = 200   # Number of synthetic samples
T = 20000       # Training iterations

# Differential-privacy settings used by DPCoxGradientEstimator and the training loop
noise_multiplier = 3
max_grad_norm = 1
min_snr = 1                  # noise std target is (mean gradient norm) / min_snr; tune as needed
min_noise_multiplier = 0.25  # lower bound on noise std, as a fraction of the clipping norm
delta = 1e-5                 # delta at which the final epsilon is reported

# Ratio of synthetic rows to real training rows, passed to the accountant as its sample rate
sample_rate = num_syn / len(X_train)
accountant = SelectiveAccountant(RDPAccountant(), max_steps=0)

In [None]:
# Custom autograd Function that evaluates the real-data Cox model (cph_real) on
# synthetic features and supplies a zero-order (finite-difference) gradient in
# backward, optionally clipped and noised for differential privacy.
class DPCoxGradientEstimator(torch.autograd.Function):
    """Black-box bridge between a fitted lifelines Cox model and torch autograd.

    forward:  X_syn (n rows, d features) -> log partial hazard of each row
              under the module-level ``cph_real`` model.
    backward: central finite differences per feature, multiplied row-wise by
              ``grad_output``; when ``use_dp`` is true, per-row L2 clipping,
              Gaussian noise, and one ``accountant.step`` per backward call.

    Reads module-level globals ``cph_real``, ``accountant``, ``min_snr``,
    ``min_noise_multiplier`` and ``sample_rate``. The defaults of
    ``noise_multiplier`` and ``clip_norm`` are bound to the global
    ``noise_multiplier`` / ``max_grad_norm`` values at class-definition time.
    """

    @staticmethod
    def forward(ctx, X_syn, use_dp=False, noise_multiplier=noise_multiplier, 
                clip_norm=max_grad_norm, epsilon_range=(0.25, 0.5), seed=None):
        """Return ``cph_real``'s log partial hazard for each row of ``X_syn``.

        Args:
            ctx: autograd context; stores X_syn and the options below for backward.
            X_syn: 2-D tensor of synthetic features, one row per sample.
            use_dp: if true, backward clips, noises and accounts the gradient.
            noise_multiplier: scales the menu of candidate noise stds in backward.
            clip_norm: per-row L2 clipping bound used in backward.
            epsilon_range: (low, high) bounds of the uniform finite-difference steps.
            seed: seed for the NumPy generator used in backward (steps and noise).

        Returns:
            1-D float32 tensor of predictions on ``X_syn.device``.
        """
        ctx.save_for_backward(X_syn)
        ctx.use_dp = use_dp
        ctx.noise_multiplier = noise_multiplier
        ctx.clip_norm = clip_norm
        ctx.epsilon_range = epsilon_range
        ctx.seed = seed
        X_np = X_syn.detach().cpu().numpy()
        preds = cph_real.predict_log_partial_hazard(pd.DataFrame(X_np)).values.flatten() # cph_real is a module-level global, so it is available here
        return torch.tensor(preds, dtype=torch.float32, device=X_syn.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Finite-difference gradient of the log partial hazard w.r.t. X_syn.

        For each feature j a step eps_j ~ U(epsilon_range) is drawn, and
        grad[i, j] ~= (f(x_i + eps_j e_j) - f(x_i - eps_j e_j)) / (2 eps_j).
        The result is multiplied row-wise by ``grad_output`` (chain rule).

        Returns:
            The gradient for X_syn, followed by None for each of the five
            non-tensor forward arguments.
        """
        X_syn, = ctx.saved_tensors
        X_np = X_syn.detach().cpu().numpy()
        n, d = X_np.shape
        # Deterministic given ctx.seed; draws the steps first, then (if DP) the noise
        rng = np.random.default_rng(ctx.seed)
        
        # Vectorized perturbation: one step size per feature, applied along the diagonal
        epsilons = rng.uniform(low=ctx.epsilon_range[0], high=ctx.epsilon_range[1], size=d)
        eye = np.eye(d)
        X_plus = X_np[:, None, :] + epsilons * eye  # shape (n, d, d)
        X_minus = X_np[:, None, :] - epsilons * eye  # shape (n, d, d)
        
        # Reshape for batch prediction: all plus-perturbations first, then all minus
        X_batch = np.vstack([X_plus.reshape(-1, d), X_minus.reshape(-1, d)])
        hr_batch = cph_real.predict_log_partial_hazard(pd.DataFrame(X_batch)).values.flatten()
        
        # Reshape and compute gradients (central differences, then chain rule with grad_output)
        hr_batch_plus = hr_batch[:n*d].reshape(n, d)
        hr_batch_minus = hr_batch[n*d:].reshape(n, d)
        hr_diff = (hr_batch_plus - hr_batch_minus)
        grad_estimate = (hr_diff / (2 * epsilons)) * grad_output.cpu().numpy()[:, None]

        # DP processing
        if ctx.use_dp:
            # Per-row L2 norms, measured before clipping
            grad_norms = np.linalg.norm(grad_estimate, axis=1, keepdims=True)  # (n, 1)
            clip_mask = grad_norms > ctx.clip_norm  # (n, 1)
            
            # Rows whose norm exceeds clip_norm are rescaled down to (about) clip_norm
            scaling = (ctx.clip_norm / (grad_norms + 1e-8))  # (n, 1)
            grad_estimate = np.where(clip_mask, grad_estimate * scaling, grad_estimate)  # (n, d)

            # NOTE(review): the noise level is chosen from the *unclipped* row norms,
            # so the choice itself depends on the data; confirm that the privacy
            # analysis (SelectiveAccountant) accounts for this selection.
            target_noise_std = np.mean(grad_norms) / min_snr

            # Candidate stds: take the smallest that reaches the target (else the
            # largest), never going below clip_norm * min_noise_multiplier.
            noise_options = np.array([0.25, 0.5, 0.75, 1.0, 1.5, 2.0]) * ctx.noise_multiplier * ctx.clip_norm
            valid_noise = noise_options[noise_options >= target_noise_std]
            if len(valid_noise) > 0:
                chosen_noise_std = max(valid_noise.min(), ctx.clip_norm * min_noise_multiplier)
            else:
                chosen_noise_std = max(noise_options.max(), ctx.clip_norm * min_noise_multiplier)

            noise = rng.normal(0, chosen_noise_std, grad_estimate.shape)
            grad_estimate += noise

            # Record this step with the multiplier actually used (std / clipping norm)
            effective_noise_multiplier = chosen_noise_std / ctx.clip_norm
            accountant.step(noise_multiplier=effective_noise_multiplier, sample_rate=sample_rate)


        return torch.tensor(grad_estimate, dtype=torch.float32, device=X_syn.device), None, None, None, None, None


In [None]:
# ================================================
# Initialising synthetic dataset
# ================================================
num_features = X_train.shape[1]
num_bins = 20  # number of equal-width time bins on [t_min, cutt] (not quantile bins)

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
t_min = dur_train.min()
cutt = dur_train[evt_train == 0].max()  # latest censoring time in the real training set

# Split the synthetic budget between event rows and censored rows. The censored
# part takes the remainder, so len(T_syn) == len(E_syn) == num_syn (the row
# count of X_syn below) even when num_syn is odd.
num_syn_events = num_syn // 2
num_syn_censored = num_syn - num_syn_events

bin_edges = np.linspace(t_min, cutt, num_bins + 1)

# Allocate event times evenly across the bins
samples_per_bin = num_syn_events // num_bins

T_events = [np.random.uniform(bin_edges[i], bin_edges[i + 1], samples_per_bin) for i in range(num_bins)]
T_events = np.concatenate(T_events)
# Any shortfall left by the integer division is filled with the cut-off time
T_events = np.pad(T_events, (0, max(0, num_syn_events - len(T_events))), constant_values=cutt)
# Every censored row is censored at the cut-off time
T_censored = np.full((num_syn_censored,), cutt, dtype=np.float32)

T_syn = np.concatenate([T_events, T_censored])
E_syn = np.concatenate([np.ones_like(T_events), np.zeros_like(T_censored)])

T_syn_tensor = torch.tensor(T_syn, dtype=torch.float32)
E_syn_tensor = torch.tensor(E_syn, dtype=torch.float32)

# Initialise X_syn around the real-data feature mean/std, each perturbed by
# N(0, 0.001) noise (X_syn can also be initialised randomly).
# NOTE(review): this perturbation is not calibrated to a privacy budget and is
# not recorded by `accountant`, so it should not be read as a DP release.
mu = X_train_tensor.mean(dim=0) + torch.normal(0, 0.001, size=(num_features,))
std = X_train_tensor.std(dim=0) + torch.normal(0, 0.001, size=(num_features,)) + 1e-5
X_syn = mu + torch.randn((num_syn, num_features), dtype=torch.float32) * std

# Leaf tensor that the optimiser updates directly
X_syn = X_syn.requires_grad_(True)


In [None]:
# ----------------------------- #
# Converting to torch tensors and setting up optimiser
# ----------------------------- #
optimizer = torch.optim.Adam([X_syn], lr=0.01)
dur_tensor = torch.tensor(dur_train, dtype=torch.float32)
evt_tensor = torch.tensor(evt_train, dtype=torch.float32)
indices_evt = (evt_tensor == 1).nonzero(as_tuple=True)[0]   # real rows with an observed event
indices_cens = (evt_tensor == 0).nonzero(as_tuple=True)[0]  # real censored rows

### Computing Trained model preds for real dataset, used in matching loss
# Arguments are passed positionally because older PyTorch versions'
# Function.apply does not accept keywords. use_dp is explicitly False and the
# seed goes in the `seed` slot. X_real_sample does not require grad, so
# backward (and therefore DP accounting) never runs for these predictions.
X_real_sample = torch.tensor(X_train, dtype=torch.float32)
pred_real = DPCoxGradientEstimator.apply(
    X_real_sample, False, noise_multiplier, max_grad_norm, (0.25, 0.5), SEED
)

In [None]:
# Training loop: optimise X_syn with a Cox survival loss plus a prediction-space
# matching loss, validating every `eval_itr` steps and keeping the best X_syn.
LOSS1 = LOSS2 = LOSS = 0
n_accum = 0          # number of steps accumulated into LOSS/LOSS1/LOSS2 since the last report

eval_itr = 100       # Validation after eval_itr iterations
patience = 3000      # Early stopping: max steps allowed without a validation improvement
desired_ratio = 0.1  # Target share of the matching term in the total loss (tuned manually)
best_cindex = 0
best_X = None
best_step = 0        # initialised so the early-stopping check never reads an undefined name

# Validation split restricted to the rows kept by `valid_idx` (finite durations,
# non-negative events), mirroring how the test split is filtered with `test_idx`.
X_val_ok = X_val[valid_idx]
dur_val_ok = dur_val[valid_idx]
evt_val_ok = evt_val[valid_idx]

for step in range(T):

    step_rng = torch.Generator().manual_seed(SEED + step)
    # Zero-order DP gradient through the real-data Cox model
    pred_syn = DPCoxGradientEstimator.apply(X_syn, True, noise_multiplier, max_grad_norm, (0.25, 3), SEED + step)  # Epsilon range tuned manually

    # Supervision loss
    loss_surv = coxph_loss(pred_syn, T_syn_tensor, E_syn_tensor)

    # Distribution matching in prediction space
    pred_syn_cat, q1, q2 = stratify_syn_pred(T_syn_tensor, E_syn_tensor, pred_syn)
    pred_real_cat = stratify_real_pred(dur_tensor, indices_evt, indices_cens, q1, q2, X_syn, pred_real, step_rng)
    loss_matching = torch.norm(pred_syn_cat - pred_real_cat, p=1) / (X_syn.shape[0] + 1e-8)

    LOSS1 += loss_surv.item()
    LOSS2 += loss_matching.item()

    # Weighted combination of survival and matching loss. alpha is built from
    # .item() values, so it is a plain float that rescales the matching term
    # without being differentiated.
    alpha = (loss_surv.item() / (loss_matching.item() + 1e-9)) * (desired_ratio / (1 - desired_ratio))
    loss = loss_surv + alpha * loss_matching

    LOSS += loss.item()
    n_accum += 1

    # Backprop with DP update (clipping/noise happen in DPCoxGradientEstimator.backward)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluation
    if step % eval_itr == 0:
        c_index = evaluate_coxph(X_syn, T_syn, E_syn, dur_val_ok, evt_val_ok, X_val_ok)
        # Average over the steps actually accumulated (only 1 at step 0)
        print(f"Step {step:04d} | Total Loss: {LOSS/n_accum:.3f} | Surv Loss: {LOSS1/n_accum:.3f} | Grad Loss: {LOSS2/n_accum:.3f} | C-index: {c_index:.4f}")
        LOSS = LOSS1 = LOSS2 = 0
        n_accum = 0
        if c_index > best_cindex:
            best_cindex = c_index
            best_X = X_syn.detach().clone()
            best_step = step
            accountant.max_steps = best_step

        if (step - best_step) > patience:  # Early stopping
            break


In [None]:
# Computing privacy budget
# Epsilon at `delta` as reported by the SelectiveAccountant, whose max_steps the
# training loop set to the best validation step. The second return value
# (presumably the optimal RDP order, as in opacus) is unused. It is bound to a
# named variable rather than `_` so IPython's last-output variable is not shadowed.
eps_best, _rdp_order = accountant.get_privacy_spent(delta)
print(f"📊 Final Privacy Budget (up to best step {best_step}): ε = {eps_best:.3f}, δ = {delta}")

In [None]:
# ----------------------------- #
# Evaluate Condensed Dataset
# ----------------------------- #
# Fit a fresh Cox model on the best synthetic set found during training and
# score it on the same filtered test split used for the real-data baseline.
if best_X is None:
    raise RuntimeError("No synthetic dataset was selected during training (best_X is None).")

# best_X is a torch tensor. Converting explicitly to NumPy gives a frame of plain
# floats with default integer column labels, regardless of device or pandas
# version. The frame gets its own name so the real-data `train_df` is not overwritten.
syn_train_df = pd.DataFrame(best_X.cpu().numpy())
syn_train_df["duration"] = T_syn
syn_train_df["event"] = E_syn

cph_syn = CoxPHFitter(penalizer=0.1)
cph_syn.fit(syn_train_df, duration_col="duration", event_col="event")
syn_test_risk = cph_syn.predict_partial_hazard(test_df)

mean_cindex, lower, upper = bootstrap_cindex_ci(test_df["duration"], test_df["event"], syn_test_risk)
print(f"Condensed Model C-index: {mean_cindex:.4f} (95% CI: {lower:.4f} – {upper:.4f})")

In [None]:
# dataset_dict = {
#     'data': best_X,
#     'time': T_syn,
#     'event':E_syn
# }
# import pickle

# with open('./condensed/COX_SEER_'+str(int(num_syn/2))+'.pkl', 'wb') as f:
#     pickle.dump(dataset_dict, f)