In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import seaborn as sns
import pickle
import argparse
import os

import sys
from pathlib import Path
# Add the parent directory to the system path
sys.path.append(str(Path().resolve().parent))

from causal_meta_learners.causal_inference_modeling import *
from causal_meta_learners.experiment_setup import *
from causal_meta_learners.survival_models import *

## Initialize the Arguments

In [2]:
from datetime import datetime
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")


def _str_to_bool(value):
    """Convert a command-line token such as 'true'/'false', 'yes'/'no' or '1'/'0' to a bool.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True`` because any
    non-empty string is truthy, so the flag could never be switched off.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    if token in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


# Simulating command-line arguments in Jupyter Notebook
# (the kernel's own sys.argv would otherwise be rejected by the parser).
sys.argv = [
    "notebook", 
    "--data_address", "../data_splits/mental-health-full/people_dict_unfiltered_expanded.pickle", 
    "--dataframe_address", "../data_generation/adherence_export_expanded.csv", 
    "--output_address", f"./results_causal_survival_forest_{current_datetime}",
    "--num_repeats", "5"
]

parser = argparse.ArgumentParser(description="Run experiments with causal survival forest model.")
parser.add_argument("--data_address", type=str, required=True, help="Path to the data pickle file.")
parser.add_argument("--dataframe_address", type=str, required=True, help="Path to the dataframe CSV file.")
# The experiment appends '.pickle' to this value, so it is a path prefix rather than a JSON file.
parser.add_argument("--output_address", type=str, required=True,
                    help="Output path prefix; results are written to '<output_address>.pickle'.")
parser.add_argument("--non_adherence_threshold", type=float, default=1./3, help="Non-adherence threshold.")
parser.add_argument("--minimum_num_time_steps", type=int, default=4, help="Minimum number of time steps.")
parser.add_argument("--low_occurrency_threshold", type=int, default=2, help="Low occurrence threshold.")
parser.add_argument("--experiment_task", type=str, default="survival", help="Experiment task type.")
parser.add_argument("--experiment_type", type=str, default="Composite Event", help="Experiment type.")
parser.add_argument("--experiment_num", type=str, default="SA", help="Experiment number.")
parser.add_argument("--handle_imbalance", type=_str_to_bool, default=True, help="Handle imbalance in the data.")
parser.add_argument("--num_repeats", type=int, default=5, help="Number of random seeds to use.")

args = parser.parse_args()

print(args.output_address)

./results_causal_survival_forest_20250131224449


## Initial Randomization Setup

In [3]:
# Build the list of per-run seeds: draw a fixed pool of 10 integers in [0, 10000)
# from a NumPy generator seeded with 0, then keep the first `args.num_repeats`.
# The pool holds 10 values, so at most 10 seeds can be returned.
np.random.seed(0)
random_seeds = np.random.randint(0, 10000, 10).tolist()[:args.num_repeats]
print(random_seeds)

[2732, 9845, 3264, 4859, 9225]


In [4]:
def set_all_seeds(seed):
    """Seed every random number generator this notebook touches from Python.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally so the function does not rely on `random` / `torch`
    # leaking into the namespace through the notebook's `from ... import *` lines.
    import random

    import torch

    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # Python processes launched after this point, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one; a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

### Experiment Setup "SA":  $((X, f(\bar{A}_{T-1})), A_T, Y)$ where $f(\bar{A}_{T-1})=[A_{1},..., A_{T-1}]$

In [5]:
import pickle
import pandas as pd
import numpy as np
import os
import rpy2.robjects as ro
from rpy2.robjects.packages import importr, PackageNotInstalledError
from rpy2.robjects import numpy2ri, FloatVector, r


def run_experiment(
    data_address, 
    dataframe_address, 
    output_address,
    non_adherence_threshold=1./3, 
    minimum_num_time_steps=4, 
    low_occurrency_threshold=2, 
    experiment_task="survival", 
    experiment_type="Composite Event", 
    experiment_num="SA", 
    handle_imbalance=True,
    continuous_covariates_lst=None, 
    post_hoc_covariates_lst=None,
    random_seeds=None
):
    """Fit grf's causal survival forest once per random seed and pickle the results.

    For every seed a ``PatientData`` object is built, the causal data set for
    ``experiment_num`` is fetched, and R's ``grf::causal_survival_forest``
    (target ``"RMST"``, horizon = largest value in column 0 of ``Y['total']``)
    is fitted on the ``'total'`` split. The forest's predictions are stored under
    ``results['CausalSurvivalForest'][str(seed)]['ITE']`` next to the causal data
    dict, and the whole results dict is re-pickled after each seed to
    ``output_address + '.pickle'``. Seeds already present in an existing results
    file are skipped, so an interrupted run can be resumed.

    Args:
        data_address: Path to the pickle file holding the people dictionary.
        dataframe_address: Path to the CSV file read with ``pandas.read_csv``.
        output_address: Output path prefix; ``'.pickle'`` is appended.
        non_adherence_threshold: Forwarded to ``PatientData``.
        minimum_num_time_steps: Forwarded to ``PatientData``.
        low_occurrency_threshold: Forwarded to ``PatientData``.
        experiment_task: Forwarded to ``PatientData`` as ``task``.
        experiment_type: Forwarded to ``PatientData``.
        experiment_num: Passed to ``get_causal_data_setup_for_each_experiment``.
        handle_imbalance: Recorded in the saved hyper-parameters only; it is not
            otherwise used by this function.
        continuous_covariates_lst: Continuous covariate names; a built-in default
            list is used when ``None``.
        post_hoc_covariates_lst: Post-hoc covariate names; a built-in default
            list is used when ``None``.
        random_seeds: Iterable of integer seeds, one run per seed. Defaults to ``[42]``.

    Returns:
        None. Results are written to disk.

    Raises:
        PackageNotInstalledError: If the R package ``grf`` is not installed.
    """
    # Set default covariates if not provided
    if continuous_covariates_lst is None:
        continuous_covariates_lst = [
            'age', 'predicted_PRO_MORTALITY_12MO', 'predicted_PRO_JAILSTAY_12MO', 
            'predicted_PRO_OVERDOSE_12MO', 'predicted_PRO_302_12MO', 'predicted_PRO_SHELTER_STAY_12MO'
        ]
    if post_hoc_covariates_lst is None:
        post_hoc_covariates_lst = ['covered_by', 'covered_by_injectable']
    # A None sentinel replaces the former mutable default argument ([42]).
    random_seeds = [42] if random_seeds is None else list(random_seeds)

    # Load data
    # SECURITY: pickle.load executes arbitrary code; only load files you trust.
    with open(data_address, 'rb') as handle:
        people_dict = pickle.load(handle)
    data_df = pd.read_csv(dataframe_address)

    # Initialize or load existing results (enables resuming a partial run)
    results_file = output_address + '.pickle'
    if os.path.exists(results_file):
        with open(results_file, 'rb') as f:
            results = pickle.load(f)
    else:
        results = {}

    # Save hyperparameters
    results["hyper_params"] = {
        "data_address": data_address,
        "dataframe_address": dataframe_address,
        "output_address": output_address,
        "non_adherence_threshold": non_adherence_threshold,
        "minimum_num_time_steps": minimum_num_time_steps,
        "low_occurrency_threshold": low_occurrency_threshold,
        "experiment_task": experiment_task,
        "experiment_type": experiment_type,
        "experiment_num": experiment_num,
        "handle_imbalance": handle_imbalance,
        "continuous_covariates_lst": continuous_covariates_lst,
        "post_hoc_covariates_lst": post_hoc_covariates_lst,
        "random_seeds": random_seeds
    }

    # Activate automatic numpy array conversion (global rpy2 setting; stays active afterwards)
    numpy2ri.activate()
    stats = importr("stats")  # stats package provides the generic predict function
    try:
        grf = importr('grf')
    except PackageNotInstalledError:
        print("The 'grf' package is not installed in R. Please install it by running: install.packages('grf') in R.")
        # Fail here with the real cause instead of hitting an unbound `grf` further down.
        raise

    # Run experiments for each model and seed
    model_name = 'CausalSurvivalForest'
    if model_name not in results:
        results[model_name] = {}

    for random_seed in random_seeds:
        seed_key = str(random_seed)
        if seed_key in results[model_name]:
            print(f"Skipping {model_name} with random seed {random_seed} as it already exists.")
            continue

        set_all_seeds(random_seed)

        # Initialize patient data
        patient_data = PatientData(
            people_dict, data_df, 
            experiment_type=experiment_type,
            task=experiment_task, 
            non_adherence_threshold=non_adherence_threshold, 
            minimum_num_time_steps=minimum_num_time_steps, 
            low_occurrency_threshold=low_occurrency_threshold,
            continuous_covariates_lst=continuous_covariates_lst,
            post_hoc_covariates_lst=post_hoc_covariates_lst,
            random_seed=random_seed
        )
        causal_data_dict = patient_data.get_causal_data_setup_for_each_experiment(experiment_num, random_state=random_seed)

        # Run the causal survival forest model
        print(f"Running {model_name} with random seed {random_seed}...")

        covariates = causal_data_dict['X']['total']
        outcomes = causal_data_dict['Y']['total']

        # Convert Python arrays to R objects. Column 0 of the outcome array is passed
        # as grf's Y argument and column 1 as its D argument.
        X_r = r.matrix(covariates, nrow=covariates.shape[0], ncol=covariates.shape[1])
        Y_r = FloatVector(outcomes[:, 0])
        W_r = FloatVector(causal_data_dict['A']['total'])
        D_r = FloatVector(outcomes[:, 1])

        # Call causal_survival_forest from grf
        cs_forest = grf.causal_survival_forest(
            X_r,
            Y_r,
            W_r,
            D_r,
            target="RMST",    # specify target as RMST (Restricted Mean Survival Time)
            horizon=float(np.max(outcomes[:, 0])),   # set horizon to maximum time point
            # grf draws its own seed from R's RNG unless told otherwise; the Python-side
            # seeding above does not reach it, so pass the run's seed explicitly.
            seed=int(random_seed),
        )

        # Use the 'stats' package's generic predict, which dispatches on the forest object.
        cs_pred = stats.predict(cs_forest)

        # The result is an R list with a "predictions" element. Extract it:
        predictions_np = np.array(cs_pred.rx2("predictions"))

        # Store the predictions together with the causal data dict used for this run
        results[model_name][seed_key] = {'ITE': predictions_np, 'causal_data_dict': causal_data_dict}
        print("-" * 100)

        # Save results incrementally: write to a temp file, then swap it in atomically
        temp_file = results_file + '.tmp'
        with open(temp_file, 'wb') as f:
            pickle.dump(results, f)
        os.replace(temp_file, results_file)

    print(f"Results saved to {results_file}")

In [6]:
# Baseline run with the arguments exactly as they currently stand in `args`.
# To write to a fresh file or change the history length, set
# args.output_address / args.minimum_num_time_steps before this call.
run_experiment(
    data_address=args.data_address,
    dataframe_address=args.dataframe_address,
    output_address=args.output_address,
    non_adherence_threshold=args.non_adherence_threshold,
    minimum_num_time_steps=args.minimum_num_time_steps,
    low_occurrency_threshold=args.low_occurrency_threshold,
    experiment_task=args.experiment_task,
    experiment_type=args.experiment_type,
    experiment_num=args.experiment_num,
    handle_imbalance=args.handle_imbalance,
    random_seeds=random_seeds,
)

Running CausalSurvivalForest with random seed 2732...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 9845...
----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 3264...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
[Random-Seed:4859] Standard deviation of columns in the training set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 4859...
----------------------------------------------------------------------------------------------------
[Random-Seed:9225] Standard deviation of columns in the training set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 9225...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
Results saved to ./results_causal_survival_forest_20250131224449.pickle


In [12]:
# Repeat the experiment with minimum_num_time_steps = 7, writing to a new timestamped file.
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
args.minimum_num_time_steps = 7
args.output_address = f"./results_causal_survival_forest_{current_datetime}"
print(args.output_address)

run_experiment(
    data_address=args.data_address,
    dataframe_address=args.dataframe_address,
    output_address=args.output_address,
    non_adherence_threshold=args.non_adherence_threshold,
    minimum_num_time_steps=args.minimum_num_time_steps,
    low_occurrency_threshold=args.low_occurrency_threshold,
    experiment_task=args.experiment_task,
    experiment_type=args.experiment_type,
    experiment_num=args.experiment_num,
    handle_imbalance=args.handle_imbalance,
    random_seeds=random_seeds,
)

./results_causal_survival_forest_20250131224944
[Random-Seed:2732] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 2732...
----------------------------------------------------------------------------------------------------
[Random-Seed:9845] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 9845...
----------------------------------------------------------------------------------------------------
[Random-Seed:3264] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 3264...
----------------------------------------------------------------------------------------------------
[Random-Seed:4859] Standard deviation of columns in the total set is 0 (for one of

In [13]:
# Repeat the experiment with minimum_num_time_steps = 10, writing to a new timestamped file.
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
args.minimum_num_time_steps = 10
args.output_address = f"./results_causal_survival_forest_{current_datetime}"
print(args.output_address)

run_experiment(
    data_address=args.data_address,
    dataframe_address=args.dataframe_address,
    output_address=args.output_address,
    non_adherence_threshold=args.non_adherence_threshold,
    minimum_num_time_steps=args.minimum_num_time_steps,
    low_occurrency_threshold=args.low_occurrency_threshold,
    experiment_task=args.experiment_task,
    experiment_type=args.experiment_type,
    experiment_num=args.experiment_num,
    handle_imbalance=args.handle_imbalance,
    random_seeds=random_seeds,
)

./results_causal_survival_forest_20250131225024
Running CausalSurvivalForest with random seed 2732...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 9845...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 3264...
----------------------------------------------------------------------------------------------------
[Random-Seed:4859] Standard deviation of columns in the training set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 4859...
----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 9225...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
Results saved to ./results_causal_survival_forest_20250131225024.pickle


In [14]:
# Repeat the experiment with minimum_num_time_steps = 13, writing to a new timestamped file.
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
args.minimum_num_time_steps = 13
args.output_address = f"./results_causal_survival_forest_{current_datetime}"
print(args.output_address)

run_experiment(
    data_address=args.data_address,
    dataframe_address=args.dataframe_address,
    output_address=args.output_address,
    non_adherence_threshold=args.non_adherence_threshold,
    minimum_num_time_steps=args.minimum_num_time_steps,
    low_occurrency_threshold=args.low_occurrency_threshold,
    experiment_task=args.experiment_task,
    experiment_type=args.experiment_type,
    experiment_num=args.experiment_num,
    handle_imbalance=args.handle_imbalance,
    random_seeds=random_seeds,
)

./results_causal_survival_forest_20250131225046
[Random-Seed:2732] Standard deviation of columns in the training set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 2732...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 9845...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
[Random-Seed:3264] Standard deviation of columns in the training set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 3264...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 4859...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
[Random-Seed:9225] Standard deviation of columns in the training set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 9225...


R[write to console]: 
 



----------------------------------------------------------------------------------------------------
Results saved to ./results_causal_survival_forest_20250131225046.pickle


In [15]:
# Repeat the experiment with minimum_num_time_steps = 19, writing to a new timestamped file.
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
args.minimum_num_time_steps = 19
args.output_address = f"./results_causal_survival_forest_{current_datetime}"
print(args.output_address)

run_experiment(
    data_address=args.data_address,
    dataframe_address=args.dataframe_address,
    output_address=args.output_address,
    non_adherence_threshold=args.non_adherence_threshold,
    minimum_num_time_steps=args.minimum_num_time_steps,
    low_occurrency_threshold=args.low_occurrency_threshold,
    experiment_task=args.experiment_task,
    experiment_type=args.experiment_type,
    experiment_num=args.experiment_num,
    handle_imbalance=args.handle_imbalance,
    random_seeds=random_seeds,
)

./results_causal_survival_forest_20250131225108
[Random-Seed:2732] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 2732...
----------------------------------------------------------------------------------------------------
[Random-Seed:9845] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 9845...
----------------------------------------------------------------------------------------------------
[Random-Seed:3264] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 3264...
----------------------------------------------------------------------------------------------------
[Random-Seed:4859] Standard deviation of columns in the total set is 0 (for one of

In [16]:
# Repeat the experiment with minimum_num_time_steps = 25, writing to a new timestamped file.
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
args.minimum_num_time_steps = 25
args.output_address = f"./results_causal_survival_forest_{current_datetime}"
print(args.output_address)

run_experiment(
    data_address=args.data_address,
    dataframe_address=args.dataframe_address,
    output_address=args.output_address,
    non_adherence_threshold=args.non_adherence_threshold,
    minimum_num_time_steps=args.minimum_num_time_steps,
    low_occurrency_threshold=args.low_occurrency_threshold,
    experiment_task=args.experiment_task,
    experiment_type=args.experiment_type,
    experiment_num=args.experiment_num,
    handle_imbalance=args.handle_imbalance,
    random_seeds=random_seeds,
)

./results_causal_survival_forest_20250131225115
[Random-Seed:2732] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 2732...
----------------------------------------------------------------------------------------------------
[Random-Seed:9845] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 9845...
----------------------------------------------------------------------------------------------------
[Random-Seed:3264] Standard deviation of columns in the total set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 3264...
----------------------------------------------------------------------------------------------------
[Random-Seed:4859] Standard deviation of columns in the total set is 0 (for one of

In [17]:
# Repeat the experiment with minimum_num_time_steps = 37, writing to a new timestamped file.
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
args.minimum_num_time_steps = 37
args.output_address = f"./results_causal_survival_forest_{current_datetime}"
print(args.output_address)

run_experiment(
    data_address=args.data_address,
    dataframe_address=args.dataframe_address,
    output_address=args.output_address,
    non_adherence_threshold=args.non_adherence_threshold,
    minimum_num_time_steps=args.minimum_num_time_steps,
    low_occurrency_threshold=args.low_occurrency_threshold,
    experiment_task=args.experiment_task,
    experiment_type=args.experiment_type,
    experiment_num=args.experiment_num,
    handle_imbalance=args.handle_imbalance,
    random_seeds=random_seeds,
)

./results_causal_survival_forest_20250131225122
[Random-Seed:2732] Standard deviation of columns in the training set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 2732...
----------------------------------------------------------------------------------------------------
[Random-Seed:9845] Standard deviation of columns in the training set is 0 (for one of the treatment assignments). Rearranging the data...
Running CausalSurvivalForest with random seed 9845...
----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 3264...
----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 4859...
----------------------------------------------------------------------------------------------------
Running CausalSurvivalForest with random seed 9225...
--