In [18]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from auton_survival.estimators import SurvivalModel
from auton_survival.metrics import survival_regression_metric
from auton_survival.preprocessing import Preprocessor
from scipy.stats import expon, uniform
from helper_functions import *

# Define causal metrics functions
def pehe(pred_risk_a1, pred_risk_a0, true_risk_a1, true_risk_a0):
    cate_pred = pred_risk_a1 - pred_risk_a0
    cate_true = true_risk_a1 - true_risk_a0
    return np.sqrt(np.mean((cate_pred - cate_true)**2, axis=0))  # Per time, then can average

def ate_error(pred_risk_a1, pred_risk_a0, true_risk_a1, true_risk_a0):
    ate_pred = np.mean(pred_risk_a1 - pred_risk_a0, axis=0)
    ate_true = np.mean(true_risk_a1 - true_risk_a0, axis=0)
    return np.abs(ate_pred - ate_true)

def policy_risk(pred_risk_a1, pred_risk_a0, true_risk_a1, true_risk_a0, threshold=0.0):
    cate_pred = pred_risk_a1 - pred_risk_a0
    # Assuming lower risk is better; treat if cate_pred < 0 (treatment reduces risk)
    treat_decision = (cate_pred < 0).astype(int)
    risk_under_policy = treat_decision * true_risk_a1 + (1 - treat_decision) * true_risk_a0
    return np.mean(risk_under_policy, axis=0)

# Main simulation
times = np.array([1, 2, 5, 10])  # Evaluation time points

# Generate data using the new function
df, save_dict, lambda0, lambda1 = generate_simulated_data(num_sample=200, num_group=3)

true_S0 = np.exp(-np.outer(lambda0, times))
true_S1 = np.exp(-np.outer(lambda1, times))
true_risk0 = 1 - true_S0
true_risk1 = 1 - true_S1

# Extract features, treatment, outcomes
covariate_cols = [col for col in df.columns if col.startswith('X')]
X = df[covariate_cols]  # Keep as DataFrame for Preprocessor
A = df['treatment'].values
outcomes = df[['time', 'event']]

# Define cat_feats and num_feats based on column names
cat_feats = [col for col in covariate_cols if col.startswith('X_c')]
num_feats = [col for col in covariate_cols if not col.startswith('X_c')]

# Train/test split
X_train_raw, X_test_raw, A_train, A_test, outcomes_train, outcomes_test, true_risk0_train, true_risk0_test, true_risk1_train, true_risk1_test = train_test_split(
    X, A, outcomes, true_risk0, true_risk1, test_size=0.2, random_state=42
)

# Preprocess features (scaling) - Pass cat_feats and num_feats to Preprocessor
preprocessor = Preprocessor().fit(X_train_raw, cat_feats=cat_feats, num_feats=num_feats)
X_train = preprocessor.transform(X_train_raw)
X_test = preprocessor.transform(X_test_raw)

# Models to compare: DSM and DCM
models = ['dsm', 'dcm']
results = {}

for model_name in models:
    print(f"Evaluating {model_name.upper()}...")

    # Manually implement 'direct' method: fit separate models for each arm
    # Split training data by treatment
    idx_control = A_train == 0
    idx_treated = A_train == 1

    x_train_control = X_train[idx_control]
    outcomes_train_control = outcomes_train[idx_control]

    x_train_treated = X_train[idx_treated]
    outcomes_train_treated = outcomes_train[idx_treated]

    # Model hyperparameters for DSM and DCM
    model_params = {'k': 3, 'layers': [100, 100]}

    # Fit model for control (A=0)
    model0 = SurvivalModel(model_name, **model_params)
    model0.fit(x_train_control, outcomes_train_control)

    # Fit model for treated (A=1)
    model1 = SurvivalModel(model_name, **model_params)
    model1.fit(x_train_treated, outcomes_train_treated)

    # Predict counterfactual survivals on full test set
    cf_S0 = model0.predict_survival(X_test, times)
    cf_S1 = model1.predict_survival(X_test, times)
    cf_risk0 = 1 - cf_S0
    cf_risk1 = 1 - cf_S1

    # Causal metrics
    pehe_score = pehe(cf_risk1, cf_risk0, true_risk1_test, true_risk0_test)
    ate_err = ate_error(cf_risk1, cf_risk0, true_risk1_test, true_risk0_test)
    pol_risk = policy_risk(cf_risk1, cf_risk0, true_risk1_test, true_risk0_test)

    # Factual predictions: based on actual A_test
    factual_S = np.zeros_like(cf_S0)
    for i, a in enumerate(A_test):
        factual_S[i] = cf_S1[i] if a == 1 else cf_S0[i]

    # Standard survival metrics (using full outcomes_train for IPCW)
    brs = survival_regression_metric('brs', outcomes_test, factual_S, times, outcomes_train)
    ibs = survival_regression_metric('ibs', outcomes_test, factual_S, times, outcomes_train)
    auc = survival_regression_metric('auc', outcomes_test, factual_S, times, outcomes_train)
    ctd = survival_regression_metric('ctd', outcomes_test, factual_S, times, outcomes_train)

    results[model_name] = {
        'PEHE (mean)': np.mean(pehe_score),
        'ATE Error (mean)': np.mean(ate_err),
        'Policy Risk (mean)': np.mean(pol_risk),
        'IBS': ibs,
        'AUC (mean)': np.mean(auc),
        'CTD': ctd,
        'BRS (mean)': np.mean(brs)
    }

# Display results
results_df = pd.DataFrame(results).T
print("Simulation Results Comparison:")
print(results_df)

Evaluating DSM...


  5%|▍         | 454/10000 [00:00<00:09, 976.01it/s] 
100%|██████████| 50/50 [00:00<00:00, 175.65it/s]
  4%|▍         | 395/10000 [00:00<00:08, 1185.26it/s]
  6%|▌         | 3/50 [00:00<00:02, 18.87it/s]


ValueError: Shape of passed values is (40, 1), indices imply (40, 4)