In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
import warnings
import math
from dataclasses import dataclass
from typing import Tuple, Optional
warnings.filterwarnings('ignore')


## Proxy strength (how informative Z and V are about U): `sigma_z`, `sigma_v`

| Corr(Z, U) | σ (sigma) |
|------------|-----------|
| 0.9        | 0.73      |
| 0.8        | 1.125     |
| 0.7        | 1.53      |
| 0.6        | 2.0       |
| 0.5        | 2.6       |
| 0.4        | 3.44      |
| 0.3        | 4.77      |
| 0.2        | 7.35      |
| 0.1        | 14.93     |

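* We fix $a_z = a_v = 1.5$ in all cases. Each $\sigma$ in the table then follows from $\mathrm{Corr} = a / \sqrt{a^2 + \sigma^2}$, i.e. the correlation between $U$ and the proxy after its $X$ component has been removed (e.g. $1.5 / \sqrt{1.5^2 + 2.0^2} = 0.6$).
* The same table applies to `sigma_v` (read Corr(V, U) in place of Corr(Z, U)).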

## strength of confounding (how much U affects treatment assignment W and outcome T): gamma_u_in_w, beta_u_in_t


| level    | gamma_u_in_w | beta_u_in_t |
|----------|---------|--------|
| none     | 0.0     | 0.0    |
| weak     | 0.2     | 0.2    |
| moderate | 0.5     | 0.5    |
| strong   | 1.0     | 1.0    |
| extreme  | 2.0     | 2.0    |

## strength of treatment effect: tau_log_hr

| tau_log_hr value | Interpretation                          |
|------------------|------------------------------------------|
| 0                | No treatment effect                      |
| < 0              | Beneficial (lower hazard)      |
| > 0              | Harmful (higher hazard)        |

**Plausible ranges**

| Effect size category        | tau_log_hr range | Hazard ratio range (exp(tau)) | Interpretation |
|----------------------------|------------------|-------------------------------|----------------|
| Null / negligible          | [-0.1, 0.1]      | [0.90, 1.11]                  | Little to no effect |
| Small to moderate (typical)| [-0.3, 0.3]      | [0.74, 1.35]                  | Plausible clinical effects |
| Moderate to large          | [-0.7, 0.7]      | [0.50, 2.01]                  | Strong but still realistic |
| Extreme (use with caution) | < -1 or > 1      | < 0.37 or > 2.72              | Often unrealistic / unstable |

## code

In [None]:
def sigmoid(x: np.ndarray) -> np.ndarray:
    """
    Logistic function 1 / (1 + exp(-x)), applied element-wise.
    """
    exp_neg_x = np.exp(-x)
    return 1.0 / (1.0 + exp_neg_x)

def calibrate_intercept_for_prevalence(
    linpred_no_intercept: np.ndarray,
    target_prevalence: float,
    max_iter: int = 60,
) -> float:
    """
    Bisection search for the intercept b0 at which the mean of
    sigmoid(b0 + linpred_no_intercept) matches target_prevalence.

    The search starts from the bracket [-20, 20], halves it max_iter times and
    returns the midpoint of the final bracket.
    """
    lower, upper = -20.0, 20.0
    for _step in range(max_iter):
        candidate = 0.5 * (lower + upper)
        mean_prob = sigmoid(candidate + linpred_no_intercept).mean()
        # The mean probability grows with the intercept: too low -> raise the
        # lower bound, otherwise pull the upper bound down.
        lower, upper = (
            (candidate, upper) if mean_prob < target_prevalence else (lower, candidate)
        )
    return 0.5 * (lower + upper)

def weibull_ph_time_paper(
    u01: np.ndarray,
    k: float,
    lam: float,
    eta: np.ndarray) -> np.ndarray:
    """
    Inverse-CDF sampling of event times from a Weibull proportional hazards model.

    Each uniform draw in u01 is mapped to
        lam * exp(-eta / k) * (-log(u01)) ** (1 / k)
    where k is the Weibull shape, lam the baseline scale and eta the linear
    predictor. Draws are first clipped to [1e-12, 1 - 1e-12] so the logarithm
    stays finite.
    """
    safe_u = np.clip(u01, 1e-12, 1 - 1e-12)
    neg_log_u = -np.log(safe_u)
    row_scale = lam * np.exp(-eta / k)
    return row_scale * neg_log_u ** (1.0 / k)

@dataclass
class SynthConfig:
    """
    Settings for the synthetic negative-control survival data generator.

    Read by generate_synthetic_nc_cox (all fields) and by add_eq8_eq9_columns
    (az, av, sigma_z, sigma_v, k_t, lam_t, tau_log_hr, beta_u_in_t).
    """
    n: int = 5000                               # sample size (number of rows)
    p_x: int = 10                               # number of observed covariates (columns X0..X{p_x-1})
    seed: int = 123                             # seed passed to np.random.default_rng

    # Treatment (W_i) assignment parameters
    w_prevalence: float = 0.5                   # target mean treatment probability (used to calibrate the intercept)
    gamma_u_in_w: float = 1.0                   # coefficient of U in the treatment logit (gamma_U)

    # Survival (T_i(w)) outcome parameters
    k_t: float = 1.5                            # Weibull shape of the event-time model
    lam_t: float = 0.4                          # Weibull scale of the event-time model
    tau_log_hr: float = -0.6                    # log hazard ratio for treatment effect (tau)
    beta_u_in_t: float = 0.8                    # coefficient of U in the event-time linear predictor (beta_U)
    
    # Censoring parameters
    k_c: float = 1.2                            # Weibull shape of the censoring model
    lam_c: Optional[float] = None               # Weibull scale of the censoring model (if None, calibrated by bisection)
    beta_u_in_c: float = 0.3                    # coefficient of U in the censoring linear predictor
    target_censor_rate: float = 0.35            # share of rows with C < T targeted when lam_c is calibrated
    max_censor_calib_iter: int = 60             # number of bisection iterations when calibrating lam_c
    censor_lam_lo: float = 1e-8                 # lower end of the bisection bracket for lam_c
    censor_lam_hi: float = 1e6                  # upper end of the bisection bracket for lam_c
    admin_censor_time: Optional[float] = None   # fixed administrative cutoff; rows with time beyond it are censored at the cutoff

    # Negative control variables parameters
    az: float = 1.5                             # coefficient for U in Z
    av: float = 1.5                             # coefficient for U in V
    sigma_z: float = 0.8                        # std dev of the Gaussian noise added to Z
    sigma_v: float = 0.8                        # std dev of the Gaussian noise added to V

@dataclass
class SynthParams:
    """
    Coefficient vectors on the observed covariates X that
    generate_synthetic_nc_cox draws at random and returns, so that downstream
    code (e.g. add_eq8_eq9_columns) can reuse the true values.
    Each vector has length p_x.
    """

    # Coefficient vectors of observed covariates X
    b_z: np.ndarray      # X coefficients in the negative control Z
    b_v: np.ndarray      # X coefficients in the negative control V
    beta_t: np.ndarray   # X coefficients in the event-time linear predictor

def generate_synthetic_nc_cox(cfg: SynthConfig) -> Tuple[pd.DataFrame, pd.DataFrame, SynthParams]:
    """
    Simulate a survival dataset with a latent confounder and two negative controls.

    Data-generating process (all randomness comes from a single generator
    seeded with cfg.seed):
      - covariates X (n x p_x) and the latent confounder U are standard normal
      - negative controls Z and V are linear in U and X plus Gaussian noise
      - treatment W is Bernoulli under a logistic model in X and U whose
        intercept is calibrated to cfg.w_prevalence
      - potential event times T0, T1 come from a Weibull proportional hazards
        model and share the same uniform draw
      - potential censoring times C0, C1 come from a Weibull model depending on
        X and U; its scale is cfg.lam_c or, when that is None, is found by
        bisection so that the share of rows with C < T matches
        cfg.target_censor_rate
      - optional administrative censoring at cfg.admin_censor_time, applied
        after the censoring scale has been fixed

    Returns:
      observed_df: pd.DataFrame
          What an analyst observes: time, event, W, A (same values as W), Z, V,
          X0..X{p_x-1}
      truth_df: pd.DataFrame
          observed_df plus U, T0, T1, C0, C1, T, C (for evaluation only); the
          censoring scale actually used is stored in attrs["lam_c_used"]
      params: SynthParams
          True X-coefficient vectors b_z, b_v, beta_t used in the DGP
    """
    rng = np.random.default_rng(cfg.seed)
    n_rows, n_cov = cfg.n, cfg.p_x

    # NOTE: the order of the rng draws below fixes the output for a given seed.
    X = rng.normal(size=(n_rows, n_cov))
    U = rng.normal(size=n_rows)

    # ---- Negative controls Z and V: linear in U and X plus Gaussian noise ----
    b_z = rng.normal(scale=0.3, size=n_cov)
    b_v = rng.normal(scale=0.3, size=n_cov)

    noise_z = rng.normal(scale=cfg.sigma_z, size=n_rows)
    Z = cfg.az * U + X @ b_z + noise_z
    noise_v = rng.normal(scale=cfg.sigma_v, size=n_rows)
    V = cfg.av * U + X @ b_v + noise_v

    # ---- Treatment W: logistic in X and U, intercept tuned to the target prevalence ----
    alpha = rng.normal(scale=0.5, size=n_cov)
    w_linpred = X @ alpha + cfg.gamma_u_in_w * U
    intercept = calibrate_intercept_for_prevalence(w_linpred, cfg.w_prevalence)
    propensity = sigmoid(intercept + w_linpred)
    W = rng.binomial(1, propensity, size=n_rows).astype(int)
    treated = W == 1

    # ---- Potential event times: eta_i(w) = beta_t^T X_i + beta_u U_i + tau w ----
    beta_t = rng.normal(scale=0.4, size=n_cov)
    u_t = rng.random(n_rows)

    eta_t0, eta_t1 = (
        X @ beta_t + cfg.beta_u_in_t * U + cfg.tau_log_hr * arm
        for arm in (0.0, 1.0)
    )
    # Both arms reuse the same uniform draw u_t.
    T0, T1 = (
        weibull_ph_time_paper(u_t, k=cfg.k_t, lam=cfg.lam_t, eta=eta_arm)
        for eta_arm in (eta_t0, eta_t1)
    )
    T = np.where(treated, T1, T0)

    # ---- Potential censoring times ----
    beta_c = rng.normal(scale=0.3, size=n_cov)
    u_c = rng.random(n_rows)

    eta_c0 = X @ beta_c + cfg.beta_u_in_c * U
    eta_c1 = X @ beta_c + cfg.beta_u_in_c * U

    def potential_censoring(lam_value):
        # Censoring times under control and under treatment for a given Weibull scale.
        c_ctrl = weibull_ph_time_paper(u_c, k=cfg.k_c, lam=lam_value, eta=eta_c0)
        c_trt = weibull_ph_time_paper(u_c, k=cfg.k_c, lam=lam_value, eta=eta_c1)
        return c_ctrl, c_trt

    lam_c_used = cfg.lam_c
    if lam_c_used is None:
        # Bisection on the censoring scale: a larger scale pushes censoring
        # times later and therefore lowers the censoring rate.
        lam_lo, lam_hi = float(cfg.censor_lam_lo), float(cfg.censor_lam_hi)
        for _step in range(cfg.max_censor_calib_iter):
            lam_mid = 0.5 * (lam_lo + lam_hi)
            c_ctrl_mid, c_trt_mid = potential_censoring(lam_mid)
            rate_mid = (np.where(treated, c_trt_mid, c_ctrl_mid) < T).mean()
            if rate_mid < cfg.target_censor_rate:
                lam_hi = lam_mid
            else:
                lam_lo = lam_mid
        lam_c_used = 0.5 * (lam_lo + lam_hi)

    C0, C1 = potential_censoring(lam_c_used)
    C = np.where(treated, C1, C0)

    # ---- Observed follow-up: whichever of event and censoring comes first ----
    time = np.minimum(T, C)
    event = (T <= C).astype(int)

    if cfg.admin_censor_time is not None:
        cutoff = float(cfg.admin_censor_time)
        past_cutoff = cutoff < time
        time = np.where(past_cutoff, cutoff, time)
        event = np.where(past_cutoff, 0, event).astype(int)

    # ---- Observed dataframe ----
    observed_cols = {"time": time, "event": event, "W": W, "A": W, "Z": Z, "V": V}
    for j in range(n_cov):
        observed_cols[f"X{j}"] = X[:, j]
    observed_df = pd.DataFrame(observed_cols)

    # ---- Truth dataframe: observed columns plus latent and counterfactual quantities ----
    truth_df = observed_df.copy()
    truth_df.insert(0, "U", U)
    for col_name, values in (
        ("T0", T0), ("T1", T1),
        ("C0", C0), ("C1", C1),
        ("T", T), ("C", C),
    ):
        truth_df[col_name] = values
    truth_df.attrs["lam_c_used"] = lam_c_used

    return observed_df, truth_df, SynthParams(b_z=b_z, b_v=b_v, beta_t=beta_t)

def add_eq8_eq9_columns(
    observed_df: pd.DataFrame,
    truth_df: pd.DataFrame,
    cfg: "SynthConfig",
    params: "SynthParams",
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Augment observed and truth dataframes with “oracle” / benchmarking quantities derived from the
    negative-control Gaussian measurement model and the Weibull PH event-time model.

    The inputs are not modified; both returned frames are copies. Rows of
    observed_df and truth_df are matched by position.

    Args:
      observed_df: pd.DataFrame
        Must contain Z, V and the covariate columns X0, X1, ... (columns that
        start with "X" but have no purely numeric suffix are ignored).
      truth_df: pd.DataFrame
        Must contain U, T0 and T1.
      cfg: SynthConfig
        Only az, av, sigma_z, sigma_v, k_t, lam_t, tau_log_hr and beta_u_in_t are read.
      params: SynthParams
        True coefficient vectors b_z, b_v and beta_t.

    Returns:
      obs: pd.DataFrame
        Copy of observed_df with additional columns:
        - tildeZ, tildeV
        - mu_U_post, var_U_post
        - CATE_XZV_eq9
      tru: pd.DataFrame
        Copy of truth_df with additional columns:
        - CATE_XU_eq7
        - ITE_T1_minus_T0
    """
    obs = observed_df.copy()
    tru = truth_df.copy()

    # Covariate columns in numeric order (X0, X1, ..., X10, ...). Requiring a
    # numeric suffix keeps unrelated "X..." columns from breaking the sort key.
    x_cols = sorted(
        (c for c in obs.columns if c.startswith("X") and c[1:].isdecimal()),
        key=lambda s: int(s[1:]),
    )
    X = obs[x_cols].to_numpy()
    Z = obs["Z"].to_numpy()
    V = obs["V"].to_numpy()

    # Residualised proxies: Z̃ = Z - X b_z, Ṽ = V - X b_v
    tildeZ = Z - X @ params.b_z
    tildeV = V - X @ params.b_v

    az, av = float(cfg.az), float(cfg.av)                           # coefficients for U in Z and V
    sz2, sv2 = float(cfg.sigma_z) ** 2, float(cfg.sigma_v) ** 2     # variances of measurement errors
    denom = (az**2) * sv2 + (av**2) * sz2 + sz2 * sv2               # common denominator

    # Gaussian posterior of U given (Z̃, Ṽ) under a standard-normal prior on U.
    mu_post = (az * sv2 * tildeZ + av * sz2 * tildeV) / denom
    var_post = (sz2 * sv2) / denom

    k = float(cfg.k_t)
    lam = float(cfg.lam_t)
    tau = float(cfg.tau_log_hr)
    beta_u = float(cfg.beta_u_in_t)

    G = math.gamma(1.0 + 1.0 / k)                                   # Γ(1 + 1/k)
    xb = X @ params.beta_t

    # --------------------------------------------------------------------------
    # oracle CATE formula conditioning on (X,Z,V) via the posterior of U:
    # the U term is averaged out with the log-normal mean
    # E[exp(-beta_u U / k)] = exp(-beta_u mu / k + beta_u^2 var / (2 k^2)).
    cate_xzv = (
        lam * G * np.exp(-(1.0 / k) * xb - (beta_u / k) * mu_post + 0.5 * (beta_u**2) * var_post / (k**2))
        * (np.exp(-tau / k) - 1.0)
    )

    # --------------------------------------------------------------------------
    # oracle CATE formula conditioning on (X,U)
    U = tru["U"].to_numpy()
    cate_xu = (lam * G * np.exp(-(1.0 / k) * (xb + beta_u * U)) * (np.exp(-tau / k) - 1.0))

    # --------------------------------------------------------------------------
    # true individual treatment effect on event time: T(1) - T(0)
    ite = tru["T1"].to_numpy() - tru["T0"].to_numpy()

    obs["tildeZ"] = tildeZ
    obs["tildeV"] = tildeV
    obs["mu_U_post"] = mu_post
    obs["var_U_post"] = var_post
    obs["CATE_XZV_eq9"] = cate_xzv

    tru["CATE_XU_eq7"] = cate_xu
    tru["ITE_T1_minus_T0"] = ite

    return obs, tru

In [4]:
# Simulation settings for this run; fields not listed keep their SynthConfig defaults.
cfg = SynthConfig(
    n=10000,
    p_x=10,
    # lam_c is set explicitly, so generate_synthetic_nc_cox skips its
    # censoring-scale calibration (target_censor_rate is not used). With a
    # scale this large the printed censoring rate below is 0.000.
    lam_c=1e6,
    seed=42,
)

# Simulate the data, then attach the closed-form oracle columns
# (posterior of U, CATE given (X, Z, V), CATE given (X, U), and T1 - T0).
obs_df, truth_df, params = generate_synthetic_nc_cox(cfg)
obs_df, truth_df = add_eq8_eq9_columns(obs_df, truth_df, cfg, params)

# Quick sanity summary: size, realised censoring and treatment rates, and how
# the proxy-based CATE compares with the CATE computed from the true U.
print(f"Dataset size: {len(obs_df)}")
print(f"Censoring rate: {1 - obs_df['event'].mean():.3f}")
print(f"Treatment prevalence: {obs_df['W'].mean():.3f}")
print(f"\nGround truth CATE (Eq.7, true U): mean={truth_df['CATE_XU_eq7'].mean():.4f}, std={truth_df['CATE_XU_eq7'].std():.4f}")
print(f"Proxy-based CATE (Eq.9): mean={obs_df['CATE_XZV_eq9'].mean():.4f}, std={obs_df['CATE_XZV_eq9'].std():.4f}")

Dataset size: 10000
Censoring rate: 0.000
Treatment prevalence: 0.500

Ground truth CATE (Eq.7, true U): mean=0.3266, std=0.4902
Proxy-based CATE (Eq.9): mean=0.3261, std=0.4618
