In [11]:
import pandas as pd
import numpy as np

from helper_functions import *
from sklearn.model_selection import ParameterGrid
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from scipy.stats import expon, uniform

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [12]:
# Sample size
n = 1000

# Number of covariates
n_linear = 3
n_non_linear = 3
n_categorical = 3



num_covariates = (n_linear, n_non_linear, n_categorical)

num_latent_group = 3

num_simulation = 10

save_dict = {}

In [13]:
CoxPH_estimated = []
CoxPH_IPTW_estimated = []
PSM_CoxPH_estimated = []
DCM_latent_estimated = []
DCM_total_estimated = []
DSM_estimated = []

In [14]:
from statsmodels.duration.hazard_regression import PHReg

# Function to fit Cox PH model
def fit_cox_model(data):
    exog = data.drop(["time", "event"], axis=1)
    model = PHReg(endog=data["time"], exog=exog, status=data["event"])
    result = model.fit()
    return result

In [15]:
import statsmodels.api as sm

# Function to fit IPTW with Cox PH model
def fit_iptw_cox_model(data):
    # Step 1: Fit logistic regression for propensity scores (P(treatment=1 | X))
    exog_logit = sm.add_constant(data.drop(["treatment", "time", "event"], axis=1))  # Add intercept
    logit_model = sm.Logit(data["treatment"], exog_logit)
    logit_result = logit_model.fit(disp=0)  # disp=0 to suppress output

    # Predict propensity scores
    ps = logit_result.predict(exog_logit)

    # Step 2: Compute IPTW weights
    # Stabilized weights: weight = treatment / ps + (1 - treatment) / (1 - ps)
    weights = data["treatment"] / ps + (1 - data["treatment"]) / (1 - ps)

    # Step 3: Fit weighted Cox PH model (only include treatment as covariate, since confounding is handled by weights)
    exog_cox = data[["treatment"]]
    cox_model = PHReg(endog=data["time"], exog=exog_cox, status=data["event"], weights=weights)
    cox_result = cox_model.fit()

    return cox_result

In [16]:
from scipy.spatial import KDTree

# Function to fit PSM with Cox PH model (1:1 nearest neighbor matching with caliper)
def fit_psm_cox_model(data, caliper=0.05):
    # Step 1: Fit logistic regression for propensity scores (P(treatment=1 | X))
    exog_logit = sm.add_constant(data.drop(["treatment", "time", "event"], axis=1))  # Add intercept
    logit_model = sm.Logit(data["treatment"], exog_logit)
    logit_result = logit_model.fit(disp=0)  # disp=0 to suppress output

    # Predict propensity scores
    data["ps"] = logit_result.predict(exog_logit)

    # Separate treated and controls
    treated = data[data["treatment"] == 1].copy()
    controls = data[data["treatment"] == 0].copy()

    # Build KDTree for controls' PS (reshape to 2D)
    controls_ps = controls["ps"].values.reshape(-1, 1)
    tree = KDTree(controls_ps)

    # Match each treated to nearest control within caliper
    matched_indices = []
    used_controls = set()

    for idx_t, row_t in treated.iterrows():
        ps_t = row_t["ps"].reshape(1, -1)
        dist, idx_c = tree.query(ps_t)
        idx_c = idx_c[0]  # query returns array

        if dist <= caliper and idx_c not in used_controls:
            matched_indices.append(controls.iloc[idx_c].name)  # Original index
            used_controls.add(idx_c)

    # Get matched controls
    matched_controls = controls.loc[matched_indices]

    # Combine matched treated and controls
    matched_data = pd.concat([treated, matched_controls])

    # Drop 'ps' column if not needed
    matched_data = matched_data.drop(columns=["ps"])

    # Step 3: Fit Cox PH model on matched data (only treatment as covariate)
    exog_cox = matched_data[["treatment"]]
    cox_model = PHReg(endog=matched_data["time"], exog=exog_cox, status=matched_data["event"])
    cox_result = cox_model.fit()

    return cox_result

In [17]:
import numpy as np
from scipy.interpolate import interp1d

def compute_rmst_difference(result, df):
    """
    Compute RMST difference (treated - control) from fitted statsmodels PHReg model.
    - result: Fitted PHRegResults object from statsmodels.
    - df: DataFrame with covariates (including 'treatment').
    - times: Array of time points for integration (e.g., np.linspace(0, tau, 1000)).
    - tau: Restriction time.
    """
    times = np.linspace(0, np.max(df["time"]), 1000)
    # Get baseline cumulative hazard (list of triples; use first for unstratified)
    baseline_list = result.baseline_cumulative_hazard
    if len(baseline_list) == 0:
        raise ValueError("No baseline cumulative hazard available.")
    # Assume unstratified (single stratum); unpack times and cumhaz from the triple
    baseline_times, baseline_cumhaz, _ = baseline_list[0]  # Ignore survival

    # Interpolate cumulative hazard for desired times
    interp_cumhaz = interp1d(baseline_times, baseline_cumhaz, kind='linear', fill_value='extrapolate')
    cumhaz_values = interp_cumhaz(times)

    # Exog data (covariates including treatment) - use result.model.exog_names
    exog = df[result.model.exog_names]
    beta = result.params
    treatment_idx = result.model.exog_names.index('treatment')  # Assumes 'treatment' column

    def average_survival(a):
        # Set treatment to a for all observations
        exog_copy = exog.copy()
        exog_copy['treatment'] = a

        # Linear predictors
        lp = np.dot(exog_copy, beta)

        # Average S(t | A=a) at each time
        avg_s = np.mean(np.exp(-cumhaz_values[:, np.newaxis] * np.exp(lp)), axis=1)
        return avg_s

    # Compute average survival curves
    avg_s_treated = average_survival(1)
    avg_s_control = average_survival(0)

    # RMST via trapezoidal integration
    rmst_treated = np.trapz(avg_s_treated, times)
    rmst_control = np.trapz(avg_s_control, times)

    # Causal effect: RMST difference
    return rmst_treated - rmst_control

In [18]:
for i in range(num_simulation):
    df, save_dict, lambda0, lambda1 = generate_simulated_data(num_sample=n,
                                 num_group=num_latent_group,
                                 num_covariates=num_covariates,
                                 **save_dict)

    """
    Cox PH model
    """
    result = fit_cox_model(df)
    # CoxPH_estimated.append(result.params[0])
    CoxPH_estimated.append(compute_rmst_difference(result, df))

    """
    Cox PH model with IPTW
    """
    result = fit_iptw_cox_model(df)
    # CoxPH_IPTW_estimated.append(result.params[0])
    CoxPH_IPTW_estimated.append(compute_rmst_difference(result, df))

    """
    PSM with Cox PH model
    """
    result = fit_psm_cox_model(df)
    # PSM_CoxPH_estimated.append(result.params[0])
    PSM_CoxPH_estimated.append(compute_rmst_difference(result, df))

    """
    DCM model
    """
    categorical_covariates = list(df.columns[n_linear + n_non_linear: n_linear + n_non_linear + n_categorical])
    # Data preparation
    cleaned_df = df.copy()

    # Extract treatment
    treatment = cleaned_df["treatment"]

    # Names of all categorical features
    processed_list = categorical_covariates

    # Exclude treatment during the training
    cleaned_df.drop("treatment", axis=1, inplace=True)

    # Processed data for DCM
    X_data_DCM, t_data_DCM, e_data_DCM, categorical_features_list_DCM, numerical_features_list_DCM = (
        processing_data_2_DCM(df=cleaned_df,
                              categorical_features_list=processed_list,
                              train_test_val_size=(0.7, 0.2, 0.1),
                              random_seed=RANDOM_SEED,
                              clustering=False,  # No need to cluster
                              )
    )

    X_train_DCM, X_val_DCM, X_test_DCM = X_data_DCM
    t_train_DCM, t_val_DCM, t_test_DCM = t_data_DCM
    e_train_DCM, e_val_DCM, e_test_DCM = e_data_DCM

    whole_X_set_DCM = pd.concat([X_train_DCM, X_val_DCM, X_test_DCM], axis=0)
    whole_t_set_DCM = pd.concat([t_train_DCM, t_val_DCM, t_test_DCM], axis=0)
    whole_e_set_DCM = pd.concat([e_train_DCM, e_val_DCM, e_test_DCM], axis=0)

    covariates_DCM = list(whole_X_set_DCM.columns)

    # Resample data with IPTW
    df_ps_DCM = pd.concat([whole_X_set_DCM, treatment], axis=1)

    df_ps_DCM = compute_PS_and_IPTW(df=df_ps_DCM,
                                covariates=covariates_DCM,
                                treatment="treatment")

    prob_DCM = df_ps_DCM["iptw_weight"] / df_ps_DCM["iptw_weight"].sum()

    resampled_indices_DCM = np.random.choice(len(df_ps_DCM), size=len(df_ps_DCM), replace=True, p=prob_DCM)

    x_resampled_DCM = whole_X_set_DCM.values[resampled_indices_DCM]
    t_resampled_DCM = whole_t_set_DCM.values[resampled_indices_DCM]
    e_resampled_DCM = whole_e_set_DCM.values[resampled_indices_DCM]
    treatment_resampled_DCM = treatment.values[resampled_indices_DCM]

    x_resampled_df_DCM = pd.DataFrame(x_resampled_DCM, columns=whole_X_set_DCM.columns)
    t_resampled_df_DCM = pd.DataFrame(t_resampled_DCM, columns=["time"])
    e_resampled_df_DCM = pd.DataFrame(e_resampled_DCM, columns=["event"])
    treatment_resampled_df_DCM = pd.DataFrame(treatment_resampled_DCM, columns=["treatment"])

    df_resampled_DCM = pd.concat([x_resampled_df_DCM,
                                 t_resampled_df_DCM,
                                 e_resampled_df_DCM,
                                 treatment_resampled_df_DCM],
                                axis=1
                                )

    # DCM hyper-parameters
    DCM_param_grid = {"k" : [3],
                      "learning_rate" : [1e-3],
                      "layers": [[50, 50]],
                      "iters": [100]
                 }
    DCM_params = ParameterGrid(DCM_param_grid)

    # Define and train DCM model
    dcm_wrap = DCM_Wrapper(DCM_params)

    dcm_wrap.fit(train_set=[df_resampled_DCM[covariates_DCM],
                            df_resampled_DCM["time"],
                            df_resampled_DCM["event"]]
                 )

    dcm_model = dcm_wrap.model

    train_latent_DCM = dcm_model.predict_latent_z(df_resampled_DCM[covariates_DCM])

    train_group_DCM = np.argmax(train_latent_DCM, axis=1)

    dcm_causal_effects_latent = plot_avg_survival_curve(df=df_resampled_DCM,
                                                        group_index=train_group_DCM,
                                                        model_wrapper=dcm_wrap,
                                                        covariates=covariates_DCM,
                                                        treatment="treatment",
                                                        num_time=1000,
                                                        show_figure=False
                                                        )

    DCM_latent_estimated.append(dcm_causal_effects_latent)

    dcm_causal_effects_total = plot_avg_survival_curve(df=df_resampled_DCM,
                                                       group_index=np.zeros_like(train_group_DCM),
                                                       model_wrapper=dcm_wrap,
                                                       covariates=covariates_DCM,
                                                       treatment="treatment",
                                                       num_time=1000,
                                                       given_title="Total Survival Curve (Treated v.s. Untreated)",
                                                       show_figure=False
                                                       )

    DCM_total_estimated.append(dcm_causal_effects_total[0])

    """
    DSM model
    """
    df_resampled_DSM = df_resampled_DCM
    covariates_DSM = covariates_DCM

    # DSM hyper-parameters
    DSM_param_grid = {"distribution": ["Weibull"],
                      "k": [3],
                      "layers": [[50, 50]],
                      "learning_rate": [1e-3],
                      "iters": [100]
                 }
    DSM_params = ParameterGrid(DSM_param_grid)

    # Define and train DCM model
    dsm_wrap = DSM_Wrapper(DSM_params)

    dsm_wrap.fit(train_set=[df_resampled_DSM[covariates_DSM],
                            df_resampled_DCM["time"],
                            df_resampled_DSM["event"]]
                 )

    dsm_causal_effects = plot_avg_survival_curve(df=df_resampled_DSM,
                                                group_index=np.zeros_like(train_group_DCM),
                                                model_wrapper=dsm_wrap,
                                                covariates=covariates_DSM,
                                                treatment="treatment",
                                                num_time=1000,
                                                given_title="Survival Curve (Treated v.s. Untreated)",
                                                show_figure=False
                                                )

    DSM_estimated.append(dsm_causal_effects)

  probs = gates+np.log(event_probs)
  return spl(ts)**risks
  s0ts = (-risks)*(spl(ts)**(risks-1))
  probs = gates+np.log(event_probs)
 45%|████▌     | 45/100 [00:01<00:01, 28.82it/s]
  return spl(ts)**risks
  S1_avg = np.nanmean(S1, axis=0)
  S0_avg = np.nanmean(S0, axis=0)
  return spl(ts)**risks
  S1_avg = np.nanmean(S1, axis=0)
  S0_avg = np.nanmean(S0, axis=0)
  return spl(ts)**risks
  S1_avg = np.nanmean(S1, axis=0)
  S0_avg = np.nanmean(S0, axis=0)
  return spl(ts)**risks
  S1_avg = np.nanmean(S1, axis=0)
  S0_avg = np.nanmean(S0, axis=0)
  2%|▏         | 197/10000 [00:00<00:09, 1028.96it/s]
 26%|██▌       | 26/100 [00:01<00:05, 14.51it/s]
  probs = gates+np.log(event_probs)
  probs = gates+np.log(event_probs)
  return spl(ts)**risks
  s0ts = (-risks)*(spl(ts)**(risks-1))
 10%|█         | 10/100 [00:00<00:04, 19.88it/s]
  2%|▏         | 162/10000 [00:00<00:09, 1039.17it/s]
  5%|▌         | 5/100 [00:00<00:06, 14.19it/s]
  probs = gates+np.log(event_probs)
  probs = gates+np.log(

In [19]:
def create_estimates_dict(estimates_list,
                          num_groups=3,
                          total_estimates=None):
    """
    Create the estimates_dict for compute_bias from overall and per-group estimates lists.
    :param overall_estimates: List of estimated effects from replicates.
    :param num_groups: Number of groups to use.
    :param total_estimates: List of estimated effects for total.
    :return estimates_dict (dict): {'overall': overall_estimates, group_id: list_for_group, ...}
    """
    if total_estimates is None:
        estimates_dict = {"average": estimates_list}

        for i in range(num_groups):
            estimates_dict[f"group {i}"] = estimates_list
    else:
        estimates_dict = {"average": total_estimates}

        for i in range(num_groups):
            estimates_dict[f"group {i}"] = [row[i] for row in estimates_list]

    return estimates_dict


In [20]:
CoxPH_estimated_dict = create_estimates_dict(CoxPH_estimated)
CoxPH_IPTW_estimated_dict = create_estimates_dict(CoxPH_IPTW_estimated)
PSM_CoxPH_estimated_dict = create_estimates_dict(PSM_CoxPH_estimated)
DSM_estimated_dict = create_estimates_dict(DSM_estimated)
DCM_latent_estimated_dict = create_estimates_dict(DCM_latent_estimated, total_estimates=DCM_total_estimated)

In [21]:
true_value = save_dict["beta_A_k"]
avg_true_value = np.mean(true_value)
true_value_dict = {"average": avg_true_value,
                   "group 0": true_value[0],
                   "group 1": true_value[1],
                   "group 2": true_value[2]}

## Different metrics

In [23]:
def compute_bias(estimates_dict, true_values_dict):
    """
    Compute absolute and relative bias from a dict of estimates and true values
    :param estimates_dict: dict of estimated effects
    :param true_values_dict: dict of true values
    :return: biases
    """
    biases = {}
    for key in true_values_dict:

        true_value = true_values_dict[key]
        estimates = estimates_dict[key]

        mean_est = np.mean(estimates)
        abs_bias = np.abs(mean_est - true_value)

        if true_value == 0:
            rel_bias = None
        else:
            rel_bias = (mean_est - true_value) / true_value

        biases[key] = {"abs_bias": abs_bias, "rel_bias": rel_bias}

    return biases


In [24]:
def compute_mse(estimates_dict, true_values_dict):
    """
    Compute MSE from a dict of estimates and true values,
    to emphasize latent groups by computing per group and overall.
    :param estimates_dict: dict of estimated effects
    :param true_values_dict: dict of true values
    :return
    """
    mses = {}
    for key in estimates_dict:

        estimates = estimates_dict[key]
        true_value = true_values_dict[key]

        mse = np.mean([(e - true_value) ** 2 for e in estimates])

        mses[key] = mse

    return mses

## Get df

In [25]:
"""
compute_bias
"""
CoxPH_bias = compute_bias(CoxPH_estimated_dict, true_value_dict)
CoxPH_IPTW_bias = compute_bias(CoxPH_IPTW_estimated_dict, true_value_dict)
PSM_CoxPH_bias = compute_bias(PSM_CoxPH_estimated_dict, true_value_dict)
DSM_estimated_bias = compute_bias(DSM_estimated_dict, true_value_dict)
DCM_latent_bias = compute_bias(DCM_latent_estimated_dict, true_value_dict)

In [26]:
"""
compute_mse
"""
CoxPH_mse = compute_mse(CoxPH_estimated_dict, true_value_dict)
CoxPH_IPTW_mse = compute_mse(CoxPH_IPTW_estimated_dict, true_value_dict)
PSM_CoxPH_mse = compute_mse(PSM_CoxPH_estimated_dict, true_value_dict)
DSM_estimated_mse = compute_mse(DSM_estimated_dict, true_value_dict)
DCM_latent_mse = compute_mse(DCM_latent_estimated_dict, true_value_dict)

In [27]:
df_dict = {"Model Name": ["CoxPH", "CoxPH_IPTW", "PSM_CoxPH", "DSM", "DCM"]}
for key in true_value_dict:

    df_dict[key + "_abs_bias"] = [
        CoxPH_bias[key]["abs_bias"],
        CoxPH_IPTW_bias[key]["abs_bias"],
        PSM_CoxPH_bias[key]["abs_bias"],
        DSM_estimated_bias[key]["abs_bias"],
        DCM_latent_bias[key]["abs_bias"]
    ]

    df_dict[key + "_rel_bias"] = [
        CoxPH_bias[key]["rel_bias"],
        CoxPH_IPTW_bias[key]["rel_bias"],
        PSM_CoxPH_bias[key]["rel_bias"],
        DSM_estimated_bias[key]["rel_bias"],
        DCM_latent_bias[key]["rel_bias"]
    ]

    df_dict[key + "_mse"] = [
        CoxPH_mse[key],
        CoxPH_IPTW_mse[key],
        PSM_CoxPH_mse[key],
        DSM_estimated_mse[key],
        DCM_latent_mse[key]
    ]

In [28]:
pd.DataFrame(df_dict)

Unnamed: 0,Model Name,average_abs_bias,average_rel_bias,average_mse,group 0_abs_bias,group 0_rel_bias,group 0_mse,group 1_abs_bias,group 1_rel_bias,group 1_mse,group 2_abs_bias,group 2_rel_bias,group 2_mse
0,CoxPH,8.738996,-18.419941,76.792775,8.846798,-15.194596,78.688557,8.843208,-15.282667,78.625046,8.526983,-32.493932,73.132152
1,CoxPH_IPTW,6.159667,-12.983264,38.976746,6.267469,-10.764533,40.316415,6.263878,-10.825117,40.271425,5.947653,-22.664832,36.40983
2,PSM_CoxPH,6.965882,-14.682594,48.972078,7.073684,-12.149228,50.48557,7.070094,-12.218404,50.434791,6.753869,-25.737093,46.063304
3,DSM,0.488848,-1.030388,0.244314,0.59665,-1.024761,0.361333,0.59306,-1.024915,0.357061,0.276835,-1.054938,0.081979
4,DCM,0.329741,-0.695023,0.112533,0.538712,-0.925252,0.296198,0.526072,-0.909148,0.278432,1.74861,6.663461,3.326174
