In [2]:
import sys
import os

sys.path.append("../..")

# Ensure the results directory exists. `exist_ok=True` makes this idempotent
# and avoids the check-then-create race of a separate `os.path.exists` test.
os.makedirs('power_plot_results', exist_ok=True)

from tqdm import tqdm
from datetime import datetime

import pandas as pd
import statsmodels.api as sm
import numpy as np
from sklearn.linear_model import (
    LogisticRegressionCV,
    LinearRegression,
)

from src.data.synthetic import SparseModifierSynthetic
from src.randomization_aware.learners import (
    DRLearner,
    QuasiOptimizedLearner,
)
from src.baselines.asaiee import AsaieeCATE


In [3]:
def trial_only_test(data, covariates, alpha=0.05, pool_data = False):
    """Test for an ``A`` x ``X0`` interaction with a covariate-adjusted OLS fit.

    Regresses ``Y`` on an intercept, ``A``, ``X0``, ``interaction_term`` and the
    requested adjustment covariates, then reports whether the two-sided
    ``1 - alpha`` confidence interval of the ``interaction_term`` coefficient
    excludes zero.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns ``Y``, ``A``, ``X0``, ``interaction_term``,
        every name in ``covariates`` and, unless ``pool_data`` is truthy, ``S``.
    covariates : iterable of str
        Extra adjustment columns. Names that repeat, or that are already part
        of the base design (``A``, ``X0``, ``interaction_term``), are used once.
    alpha : float, default 0.05
        Significance level passed to ``conf_int``.
    pool_data : bool, default False
        When falsy, only rows with ``S == 1`` are used (presumably the trial
        sample -- confirm against the data generator); when truthy, all rows.

    Returns
    -------
    bool
        ``True`` if the confidence interval lies entirely above or below zero.
    """
    # Truthiness rather than `is False`, so np.False_ / 0 also select the subset.
    if not pool_data:
        data = data[data["S"] == 1]

    # De-duplicate while preserving order: callers pass covariates such as
    # ["X0"], which would otherwise put the same column into the design matrix
    # twice and make it rank-deficient.
    regressors = list(dict.fromkeys(["A", "X0", "interaction_term", *covariates]))

    X = sm.add_constant(data[regressors])  # Add intercept
    y = data["Y"]

    # Fit OLS regression model
    model = sm.OLS(y, X).fit()

    # Reject when the interval for the interaction coefficient excludes 0.
    lower, upper = model.conf_int(alpha=alpha).loc["interaction_term"]
    return bool(lower > 0 or upper < 0)


def randomization_aware_test(data, covariates, cate_model, z_crit=1.96):
    """Test for effect modification by ``X0`` using a learner's pseudo-outcomes.

    Fits ``cate_model`` on all rows, then, within each fold reported by the
    model, regresses the pseudo-outcome on ``X0`` (with intercept) using only
    rows with ``S == 1``. The per-fold ``X0`` slopes are averaged and a normal
    approximation interval is formed around the mean.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``S``, ``A``, ``Y``, ``X0`` and every name in
        ``covariates``. The frame is copied; the caller's object is not
        modified.
    covariates : list of str
        Columns passed to ``cate_model.fit`` as the feature matrix.
    cate_model : object
        Must provide ``fit(X, S, A, Y)`` and ``get_computed_pseudo_outcomes()``.
        The latter is indexed as a 2-D array whose column 0 is used as the
        pseudo-outcome and column 1 as the fold label, row-aligned with
        ``data``.
    z_crit : float, default 1.96
        Critical value that multiplies the aggregated standard error.

    Returns
    -------
    bool
        ``True`` if the interval for the mean ``X0`` slope excludes zero.
    """
    data = data.copy()

    X = data[covariates].values
    S = data["S"].values.reshape(-1, 1)
    A = data["A"].values.reshape(-1, 1)
    Y = data["Y"].values.reshape(-1, 1)
    cate_model.fit(X, S, A, Y)

    # Combine pseudo-outcomes from all folds
    computed_pseudo_outcomes = cate_model.get_computed_pseudo_outcomes()
    data["pseudo_outcome"] = computed_pseudo_outcomes[:, 0]
    data["fold"] = computed_pseudo_outcomes[:, 1]

    # Fit linear regression model for each fold
    param_value = []
    param_se = []
    for fold in data["fold"].unique():
        observed_data_S1 = data[(data["S"] == 1) & (data["fold"] == fold)]
        obs_X = sm.add_constant(observed_data_S1[["X0"]])  # Add intercept
        response = observed_data_S1["pseudo_outcome"]
        model = sm.OLS(response, obs_X).fit()
        param_value.append(model.params["X0"])
        param_se.append(model.bse["X0"])

    n_folds = len(param_value)
    mean_param = np.mean(param_value)
    # Standard error of the mean of n_folds estimates, treating the per-fold
    # estimates as independent: sqrt(sum(se_k^2) / n_folds^2). The divisor was
    # previously hard-coded to 4, which is only correct for exactly two folds.
    aggregated_se = np.sqrt(np.sum(np.array(param_se) ** 2) / n_folds ** 2)

    lcl = mean_param - z_crit * aggregated_se
    ucl = mean_param + z_crit * aggregated_se

    return bool(lcl > 0 or ucl < 0)

In [4]:


# Estimator class used for the treated/control outcome regressions. It is
# looked up and instantiated afresh each time a learner-based method runs.
regressor = LinearRegression


def _trial_only_adjustment(data, covariates):
    """Run `trial_only_test` adjusting for X0 with default pooling; `covariates` is unused."""
    return trial_only_test(data, ["X0"])


def _pooled_adjustment(data, covariates):
    """Run `trial_only_test` adjusting for X0 with `pool_data=True`; `covariates` is unused."""
    return trial_only_test(data, ["X0"], pool_data=True)


def _dr_learner(data, covariates):
    """Run `randomization_aware_test` with a freshly built `DRLearner`."""
    learner = DRLearner(
        propensity_score=1 / 2,
        regressor_cate=LinearRegression(),
        regressor_treated=regressor(),
        regressor_control=regressor(),
        crossfit_folds=2,
    )
    return randomization_aware_test(data, covariates, learner)


def _quasi_optimized(data, covariates):
    """Run `randomization_aware_test` with a freshly built `QuasiOptimizedLearner`."""
    learner = QuasiOptimizedLearner(
        propensity_score=1 / 2,
        regressor_cate=LinearRegression(),
        regressor_treated=regressor(),
        regressor_control=regressor(),
        study_classifier=LogisticRegressionCV(max_iter=1000),
        crossfit_folds=2,
    )
    return randomization_aware_test(data, covariates, learner)


def _asaiee(data, covariates):
    """Run `randomization_aware_test` with a freshly built `AsaieeCATE`."""
    learner = AsaieeCATE(
        propensity_score=1 / 2,
        regressor_cate=LinearRegression(),
        regressor_treated=regressor(),
        regressor_control=regressor(),
        crossfit_folds=2,
    )
    return randomization_aware_test(data, covariates, learner)


# Registry mapping a display name to a callable `(data, covariates) -> reject`.
# Insertion order is the order in which `run_experiment` evaluates the methods.
methods = {
    "Trial-only covariate adjustment": _trial_only_adjustment,
    "Pooled covariate adjustment": _pooled_adjustment,
    "DR-learner": _dr_learner,
    "Quasi-optimized": _quasi_optimized,
    "Asaiee": _asaiee,
}


def run_experiment(n_trials, n_obs, iterations, dgp):
    """Apply every registered method to repeated samples drawn from ``dgp``.

    For each value in ``n_trials`` the generator is sampled ``iterations``
    times; an ``interaction_term`` column (``A * X0``) is added to each sample
    and every entry of the module-level ``methods`` mapping is called on it.

    Parameters
    ----------
    n_trials : iterable
        Values forwarded one at a time as ``n_trial`` to ``dgp.sample``.
    n_obs
        Forwarded unchanged as ``n_obs`` to ``dgp.sample``.
    iterations : int
        Number of samples drawn per ``n_trial`` value.
    dgp : object
        Must provide ``sample(n_trial=..., n_obs=...)`` returning a frame with
        ``A`` and ``X0`` columns, and ``get_covar()``.

    Returns
    -------
    pandas.DataFrame
        One row per (n_trial, iteration, method) with columns ``iteration``,
        ``n_trial``, ``estimator`` and ``reject``.
    """
    records = []

    for n_trial in n_trials:
        print(f"n_trial = {n_trial}")
        for rep in tqdm(range(iterations)):
            sample = dgp.sample(n_trial=n_trial, n_obs=n_obs)
            covariate_names = dgp.get_covar()
            sample["interaction_term"] = sample["A"] * sample["X0"]

            # One record per registered method, evaluated in registry order.
            records.extend(
                {
                    "iteration": rep,
                    "n_trial": n_trial,
                    "estimator": label,
                    "reject": test(sample, covariate_names),
                }
                for label, test in methods.items()
            )

    return pd.DataFrame(records)

In [None]:
# Timestamp used to give each run its own output file.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
print(f"timestamp: {timestamp}")

# Simulation grid.
# NOTE(review): 160 breaks the otherwise regular 25-step grid -- confirm it is
# not meant to be 150.
n_trial = [25, 50, 75, 100, 125, 160, 175, 200]
n_obs = 1000
# NOTE(review): 5 repetitions per setting gives very coarse rejection-rate
# estimates -- presumably a smoke-test value; raise it for real power curves.
iterations = 5
# NOTE(review): the output saved with this cell lists effect sizes 0.1 and 0.25
# as well, so it was produced by an earlier version of this list -- re-run to
# refresh it.
interaction_effect_size = [0.0, 0.5]

dfs = []
for effect_size in interaction_effect_size:
    print(f'effect_size = {effect_size}')
    dgp = SparseModifierSynthetic(
        population_shift=0.2, n_features=5, effect_modifiers=[effect_size]
    )
    results = run_experiment(n_trial, n_obs, iterations, dgp)
    results["interaction_effect_size"] = effect_size
    dfs.append(results)

# ignore_index=True gives the combined frame a clean 0..n-1 index instead of
# repeating each per-effect-size frame's own index.
result_df = pd.concat(dfs, ignore_index=True)

# Save the DataFrame as a CSV file. Re-ensure the directory exists so a long
# simulation is not lost if the folder was removed or this cell runs standalone.
os.makedirs('power_plot_results', exist_ok=True)
result_df.to_csv(f'power_plot_results/results_{timestamp}.csv', index=False)

timestamp: 20250519_104438
effect_size = 0.0
n_trial = 25


100%|██████████| 5/5 [00:02<00:00,  2.23it/s]


n_trial = 50


100%|██████████| 5/5 [00:02<00:00,  2.09it/s]


n_trial = 75


100%|██████████| 5/5 [00:02<00:00,  2.14it/s]


n_trial = 100


100%|██████████| 5/5 [00:02<00:00,  1.84it/s]


n_trial = 125


100%|██████████| 5/5 [00:01<00:00,  2.51it/s]


n_trial = 160


100%|██████████| 5/5 [00:02<00:00,  2.46it/s]


n_trial = 175


100%|██████████| 5/5 [00:01<00:00,  2.80it/s]


n_trial = 200


100%|██████████| 5/5 [00:01<00:00,  2.54it/s]


effect_size = 0.1
n_trial = 25


100%|██████████| 5/5 [00:01<00:00,  3.05it/s]


n_trial = 50


100%|██████████| 5/5 [00:01<00:00,  2.78it/s]


n_trial = 75


100%|██████████| 5/5 [00:01<00:00,  2.71it/s]


n_trial = 100


100%|██████████| 5/5 [00:02<00:00,  2.14it/s]


n_trial = 125


100%|██████████| 5/5 [00:02<00:00,  2.07it/s]


n_trial = 160


100%|██████████| 5/5 [00:02<00:00,  2.04it/s]


n_trial = 175


100%|██████████| 5/5 [00:02<00:00,  1.86it/s]


n_trial = 200


100%|██████████| 5/5 [00:02<00:00,  1.93it/s]


effect_size = 0.25
n_trial = 25


100%|██████████| 5/5 [00:02<00:00,  1.96it/s]


n_trial = 50


100%|██████████| 5/5 [00:02<00:00,  2.28it/s]


n_trial = 75


100%|██████████| 5/5 [00:02<00:00,  2.11it/s]


n_trial = 100


100%|██████████| 5/5 [00:02<00:00,  2.26it/s]


n_trial = 125


100%|██████████| 5/5 [00:02<00:00,  2.07it/s]


n_trial = 160


100%|██████████| 5/5 [00:02<00:00,  2.33it/s]


n_trial = 175


100%|██████████| 5/5 [00:02<00:00,  2.45it/s]


n_trial = 200


100%|██████████| 5/5 [00:02<00:00,  2.14it/s]


effect_size = 0.5
n_trial = 25


100%|██████████| 5/5 [00:02<00:00,  1.97it/s]


n_trial = 50


100%|██████████| 5/5 [00:02<00:00,  2.17it/s]


n_trial = 75


100%|██████████| 5/5 [00:02<00:00,  2.29it/s]


n_trial = 100


100%|██████████| 5/5 [00:01<00:00,  2.53it/s]


n_trial = 125


100%|██████████| 5/5 [00:01<00:00,  3.04it/s]


n_trial = 160


100%|██████████| 5/5 [00:02<00:00,  2.31it/s]


n_trial = 175


100%|██████████| 5/5 [00:01<00:00,  2.68it/s]


n_trial = 200


100%|██████████| 5/5 [00:02<00:00,  2.41it/s]
