In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import seaborn as sns
import pickle
import json
import sys
import argparse
import torch
import random
import os

from pathlib import Path
# Add the parent directory to the system path
sys.path.append(str(Path().resolve().parent))

from causal_meta_learners.causal_inference_modeling import *
from causal_meta_learners.experiment_setup import *
from causal_meta_learners.survival_models import *

# Load the results

In [2]:
# Index every pickled result in "results_full" by its snapshot key
# (hyper_params['minimum_num_time_steps'] - 1). If two files share a key,
# whichever os.listdir yields last overwrites the earlier one.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
results_per_time_snapshot = {}

for file in os.listdir("results_full"):
    if not file.endswith(".pickle"):
        continue
    with open(f"results_full/{file}", "rb") as f:
        result = pickle.load(f)
    results_per_time_snapshot[result['hyper_params']['minimum_num_time_steps']-1] = result

In [3]:
# Index every pickled result in "results_ablation" by its snapshot key
# (hyper_params['minimum_num_time_steps'] - 1). If two files share a key,
# whichever os.listdir yields last overwrites the earlier one.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
results_per_time_snapshot_ablation = {}

for file in os.listdir("results_ablation"):
    if not file.endswith(".pickle"):
        continue
    with open(f"results_ablation/{file}", "rb") as f:
        result = pickle.load(f)
    results_per_time_snapshot_ablation[result['hyper_params']['minimum_num_time_steps']-1] = result

In [4]:
# Index every pickled result in "results_causal_survival_forest" by its
# snapshot key (hyper_params['minimum_num_time_steps'] - 1). If two files share
# a key, whichever os.listdir yields last overwrites the earlier one.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
results_causal_survival_forest_per_time_snapshot = {}

for file in os.listdir("results_causal_survival_forest"):
    if not file.endswith(".pickle"):
        continue
    with open(f"results_causal_survival_forest/{file}", "rb") as f:
        result = pickle.load(f)
    results_causal_survival_forest_per_time_snapshot[result['hyper_params']['minimum_num_time_steps']-1] = result

In [5]:
# Index every pickled result in "results_causal_survival_forest_ablation" by
# its snapshot key (hyper_params['minimum_num_time_steps'] - 1). If two files
# share a key, whichever os.listdir yields last overwrites the earlier one.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
results_causal_survival_forest_per_time_snapshot_ablation = {}

for file in os.listdir("results_causal_survival_forest_ablation"):
    if not file.endswith(".pickle"):
        continue
    with open(f"results_causal_survival_forest_ablation/{file}", "rb") as f:
        result = pickle.load(f)
    results_causal_survival_forest_per_time_snapshot_ablation[result['hyper_params']['minimum_num_time_steps']-1] = result

In [6]:
# Survival models reported in the tables, in row order.
survival_analysis_model_names = [
    'CoxPH',
    'RandomSurvivalForest',
    'DeepSurv',
    'DeepHit',
]

# Display name -> key under which the estimator's output is stored in the
# loaded result dictionaries.
causal_inference_model_names = {
    'T-learner': 't-learner',
    'S-learner': 's-learner',
    'Matching': 'matching',
}

In [7]:
# Display name -> metric identifier.
# NOTE(review): presumably the identifiers are keys in the saved results; confirm against the experiment code.
survival_analysis_metrics = {
    'C-index': 'concordance_td',
    'IBS': 'integrated_brier_score',
    'Time-Dependent AUC': 'td_auc',
}

# Sign attached to each metric.
# NOTE(review): presumably +1 means higher is better and -1 lower is better; confirm.
survival_analysis_metrics_direction = {
    'C-index': 1,
    'IBS': -1,
    'Time-Dependent AUC': 1,
}

# Causal Inference ATE from Different Survival Analysis Models at Different Time Snapshots

In [8]:
# Subset of the available match counts that the tables report.
num_matches_displayed = [1, 5, 20]

# List the match counts stored for one reference run (snapshot 3, CoxPH,
# entry '2732' -- presumably a random-seed key; verify).
print("Number of Matches in the Matching Model:")
print(list(results_per_time_snapshot[3]['CoxPH']['2732']['matching']))
print()

print("Number of Matches Displayed:")
print(num_matches_displayed)

Number of Matches in the Matching Model:
[1, 2, 5, 10, 20, 50, 100]

Number of Matches Displayed:
[1, 5, 20]


In [9]:
# snapshot -> survival model -> column name -> (mean ATE, std ATE) over the
# per-seed repeats, or None when the survival model has no entry in that
# snapshot's results.
causal_inference_ate_table_per_time_snapshot = {}

for time_snapshot in sorted(results_per_time_snapshot):
    result = results_per_time_snapshot[time_snapshot]
    causal_inference_ate_table = {}
    for model_name in survival_analysis_model_names:
        model_available = model_name in result
        current_model_table = {}
        for causal_model_name, causal_model_saved in causal_inference_model_names.items():
            if causal_model_saved == 'matching':
                # Matching contributes one column per value in num_matches_displayed.
                for num_matches_ in num_matches_displayed:
                    matching_str = f"{causal_model_name} ({num_matches_})"
                    if model_available:
                        ate_repeats = [
                            seed_result[causal_model_saved][num_matches_]['ATE']
                            for seed_result in result[model_name].values()
                        ]
                        current_model_table[matching_str] = (np.mean(ate_repeats), np.std(ate_repeats))
                    else:
                        current_model_table[matching_str] = None
            elif model_available:
                ate_repeats = [
                    seed_result[causal_model_saved]['ATE']
                    for seed_result in result[model_name].values()
                ]
                current_model_table[causal_model_name] = (np.mean(ate_repeats), np.std(ate_repeats))
            else:
                current_model_table[causal_model_name] = None
        causal_inference_ate_table[model_name] = current_model_table
    causal_inference_ate_table_per_time_snapshot[time_snapshot] = causal_inference_ate_table

In [10]:
# snapshot -> {'CausalSurvivalForest': {'ATE': (mean, std)}}, where each
# repeat's ATE is the mean of that run's 'ITE' values.
causal_survival_forest_ate_table_per_time_snapshot = {}

for time_snapshot in sorted(results_causal_survival_forest_per_time_snapshot):
    result = results_causal_survival_forest_per_time_snapshot[time_snapshot]
    ate_repeats = [np.mean(seed_result['ITE']) for seed_result in result['CausalSurvivalForest'].values()]
    ate_summary = (np.mean(ate_repeats), np.std(ate_repeats))
    causal_survival_forest_ate_table_per_time_snapshot[time_snapshot] = {
        'CausalSurvivalForest': {'ATE': ate_summary},
    }

In [11]:
def format_value(value, decimal_places=4):
    """Render a ``(mean, std)`` pair as ``"mean ± std"``.

    Both numbers are printed in fixed-point notation with ``decimal_places``
    digits after the decimal point. ``None`` is rendered as ``"--"``.
    """
    if value is not None:
        center, spread = value
        return "{:.{p}f} ± {:.{p}f}".format(center, spread, p=decimal_places)
    return "--"


def generate_causal_inference_table(time_snapshot, causal_inference_ate_table_per_time_snapshot,
                                    survival_analysis_model_names, causal_inference_meta_learners,
                                    causal_survival_forest_ate_table_per_time_snapshot=None,
                                    table_font_size="small", decimal_places=4):
    """Generate the LaTeX ATE table for a single snapshot time.

    Args:
        time_snapshot: Key of the snapshot to render. It is also interpolated
            into the table label and into the caption ("... at N Months").
        causal_inference_ate_table_per_time_snapshot: Mapping
            snapshot -> model name -> column name -> ``(mean, std)`` or ``None``.
            Missing snapshots, models or columns are rendered as "--".
        survival_analysis_model_names: Model names; one table row is emitted
            per name, in iteration order.
        causal_inference_meta_learners: Column names (strings). Must support
            ``len()`` and repeated iteration (a list or a dict keys view).
        causal_survival_forest_ate_table_per_time_snapshot: Optional mapping
            snapshot -> ``{'CausalSurvivalForest': {'ATE': (mean, std)}}``.
            When not ``None`` an extra row is appended after a midrule; a
            snapshot or entry missing from this mapping is rendered as "--".
        table_font_size: Name of a LaTeX size command, without the backslash.
        decimal_places: Number of decimals passed to ``format_value``.

    Returns:
        The LaTeX source of the table as a string.
    """
    table = f"""
    \\begin{{table*}}[hbtp]
    \\setlength{{\\tabcolsep}}{{4pt}}
    \\{table_font_size}
    \\floatconts
      {{tab:table_time_{time_snapshot}}}
      {{\\caption{{Causal Inference ATE Estimates at {time_snapshot} Months}}}}
      {{\\begin{{tabular}}{{l {'c ' * len(causal_inference_meta_learners)}}}
    \\toprule"""

    # Define column headers
    table += "\n    &" + " & ".join(causal_inference_meta_learners) + " \\\\"
    table += "\n    \\midrule"

    # Fill in table values
    snapshot_table = causal_inference_ate_table_per_time_snapshot.get(time_snapshot, {})
    for model_name in survival_analysis_model_names:
        result = snapshot_table.get(model_name, {})
        cells = [format_value(result.get(causal_model, None), decimal_places)
                 for causal_model in causal_inference_meta_learners]
        row = f"{model_name}" + " & " + " & ".join(cells)
        table += "\n    " + row + " \\\\"

    if causal_survival_forest_ate_table_per_time_snapshot is not None:
        table += "\n    \\midrule"
        # Fall back to an empty dict at every level so that a snapshot (or
        # entry) missing from the forest table yields "--" instead of raising
        # AttributeError on None.
        forest_entry = causal_survival_forest_ate_table_per_time_snapshot.get(time_snapshot, {}).get('CausalSurvivalForest') or {}
        ate_value = forest_entry.get('ATE', None)
        # The value always spans two columns, regardless of the table width.
        row = "CausalSurvivalForest" + f" &  \\multicolumn{{2}}{{c}}{{{format_value(ate_value, decimal_places)}}} "
        table += "\n    " + row + " \\\\"

    table += "\n    \\bottomrule"
    # The extra closing brace ends the third \floatconts argument opened
    # before \begin{tabular}.
    table += "\n    \\end{tabular}}"
    table += "\n    \\end{table*}"

    return table

In [12]:
# Rendering options for the tables printed by this cell
DECIMAL_PLACES = 3  # decimals shown for both mean and std
TABLE_FONT_SIZE = "footnotesize"  # Options: tiny, scriptsize, footnotesize, small, normalsize, large, etc.

# Print one LaTeX table per time snapshot
for time_snapshot, snapshot_table in causal_inference_ate_table_per_time_snapshot.items():
    latex_table = generate_causal_inference_table(
        time_snapshot,
        causal_inference_ate_table_per_time_snapshot,
        survival_analysis_model_names,
        snapshot_table['CoxPH'].keys(),  # the CoxPH entries define the column set
        causal_survival_forest_ate_table_per_time_snapshot,
        table_font_size=TABLE_FONT_SIZE,
        decimal_places=DECIMAL_PLACES,
    )
    print(latex_table)
    print("\n\n")


    \begin{table*}[hbtp]
    \setlength{\tabcolsep}{4pt}
    \footnotesize
    \floatconts
      {tab:table_time_3}
      {\caption{Causal Inference ATE Estimates at 3 Months}}
      {\begin{tabular}{l c c c c c }
    \toprule
    &T-learner & S-learner & Matching (1) & Matching (5) & Matching (20) \\
    \midrule
    CoxPH & -3.524 ± 0.274 & -3.644 ± 0.393 & -4.245 ± 0.005 & -4.280 ± 0.003 & -4.406 ± 0.002 \\
    RandomSurvivalForest & -2.317 ± 0.602 & -1.183 ± 0.363 & -1.615 ± 0.073 & -1.538 ± 0.054 & -1.978 ± 0.050 \\
    DeepSurv & -2.366 ± 0.603 & -1.986 ± 0.382 & -2.160 ± 1.014 & -2.101 ± 1.051 & -2.312 ± 1.223 \\
    DeepHit & -2.956 ± 6.663 & 0.417 ± 0.967 & 0.215 ± 0.481 & 0.241 ± 0.476 & 0.200 ± 0.501 \\
    \midrule
    CausalSurvivalForest &  \multicolumn{2}{c}{-3.045 ± 0.128}  \\
    \bottomrule
    \end{tabular}}
    \end{table*}




    \begin{table*}[hbtp]
    \setlength{\tabcolsep}{4pt}
    \footnotesize
    \floatconts
      {tab:table_time_6}
      {\caption{Causal 

# Ablation Data Tables

In [13]:
# Ablation counterpart of the ATE table:
# snapshot -> survival model -> column name -> (mean ATE, std ATE) over the
# per-seed repeats, or None when the survival model has no entry in that
# snapshot's ablation results.
causal_inference_ate_table_per_time_snapshot_ablation = {}

for time_snapshot in sorted(results_per_time_snapshot_ablation):
    result = results_per_time_snapshot_ablation[time_snapshot]
    causal_inference_ate_table = {}
    for model_name in survival_analysis_model_names:
        model_available = model_name in result
        current_model_table = {}
        for causal_model_name, causal_model_saved in causal_inference_model_names.items():
            if causal_model_saved == 'matching':
                # Matching contributes one column per value in num_matches_displayed.
                for num_matches_ in num_matches_displayed:
                    matching_str = f"{causal_model_name} ({num_matches_})"
                    if model_available:
                        ate_repeats = [
                            seed_result[causal_model_saved][num_matches_]['ATE']
                            for seed_result in result[model_name].values()
                        ]
                        current_model_table[matching_str] = (np.mean(ate_repeats), np.std(ate_repeats))
                    else:
                        current_model_table[matching_str] = None
            elif model_available:
                ate_repeats = [
                    seed_result[causal_model_saved]['ATE']
                    for seed_result in result[model_name].values()
                ]
                current_model_table[causal_model_name] = (np.mean(ate_repeats), np.std(ate_repeats))
            else:
                current_model_table[causal_model_name] = None
        causal_inference_ate_table[model_name] = current_model_table
    causal_inference_ate_table_per_time_snapshot_ablation[time_snapshot] = causal_inference_ate_table

In [14]:
# Ablation counterpart: snapshot -> {'CausalSurvivalForest': {'ATE': (mean, std)}},
# where each repeat's ATE is the mean of that run's 'ITE' values.
causal_survival_forest_ate_table_per_time_snapshot_ablation = {}

for time_snapshot in sorted(results_causal_survival_forest_per_time_snapshot_ablation):
    result = results_causal_survival_forest_per_time_snapshot_ablation[time_snapshot]
    ate_repeats = [np.mean(seed_result['ITE']) for seed_result in result['CausalSurvivalForest'].values()]
    ate_summary = (np.mean(ate_repeats), np.std(ate_repeats))
    causal_survival_forest_ate_table_per_time_snapshot_ablation[time_snapshot] = {
        'CausalSurvivalForest': {'ATE': ate_summary},
    }

In [15]:
def generate_causal_inference_ablation_table(model_name,
                                             causal_inference_ate_table_per_time_snapshot,
                                             causal_inference_ate_table_per_time_snapshot_ablation,
                                             causal_inference_meta_learners,
                                             table_font_size="small", decimal_places=4):
    """Generate the LaTeX full-vs-ablation ATE table for a single model.

    For every snapshot key of ``causal_inference_ate_table_per_time_snapshot``
    (in iteration order) three rows are emitted: an italic "(N months)"
    heading row, a "<model> (Full)" row and a "<model> (Ablation)" row.

    Args:
        model_name: Model whose estimates are rendered; also interpolated into
            the table label and the caption.
        causal_inference_ate_table_per_time_snapshot: Mapping
            snapshot -> model name -> column name -> ``(mean, std)`` or ``None``
            for the full data. Its keys determine which snapshots appear.
        causal_inference_ate_table_per_time_snapshot_ablation: Same structure
            for the ablated data. Snapshots, models or columns missing here
            are rendered as "--".
        causal_inference_meta_learners: Column names (strings). Must support
            ``len()`` and repeated iteration (a list or a dict keys view).
        table_font_size: Name of a LaTeX size command, without the backslash.
        decimal_places: Number of decimals passed to ``format_value``.

    Returns:
        The LaTeX source of the table as a string.
    """
    table = f"""
    \\begin{{table*}}[hbtp]
    \\setlength{{\\tabcolsep}}{{4pt}}
    \\{table_font_size}
    \\floatconts
      {{tab:table_ablation_{model_name}}}
      {{\\caption{{Causal Inference Ablation ATE Estimates for {model_name}}}}}
      {{\\begin{{tabular}}{{l {'c ' * len(causal_inference_meta_learners)}}}
    \\toprule"""

    # Define column headers
    table += "\n    &" + " & ".join(causal_inference_meta_learners) + " \\\\"
    table += "\n    \\midrule"

    # Fill in table values
    for snapshot_time in causal_inference_ate_table_per_time_snapshot.keys():
        # Heading row: the italic label followed by empty cells.
        time_row = f"\\textit{{({snapshot_time} months)}}\n" + " & " * (len(causal_inference_meta_learners)-1)
        result = causal_inference_ate_table_per_time_snapshot.get(snapshot_time, {}).get(model_name, {})
        result_ablation = causal_inference_ate_table_per_time_snapshot_ablation.get(snapshot_time, {}).get(model_name, {})
        model_row = f"{model_name} (Full)"
        model_row += " & " + " & ".join([format_value(result.get(causal_model, None), decimal_places) for causal_model in causal_inference_meta_learners])
        ablated_row = f"{model_name} (Ablation)"
        ablated_row += " & " + " & ".join([format_value(result_ablation.get(causal_model, None), decimal_places) for causal_model in causal_inference_meta_learners])
        table += "\n    " + time_row + " \\\\"
        table += "\n    " + model_row + " \\\\"
        # The doubled line break leaves an empty row after each snapshot group.
        table += "\n    " + ablated_row + " \\\\\\\\"

    table += "\n    \\bottomrule"
    # The extra closing brace ends the third \floatconts argument opened
    # before \begin{tabular}.
    table += "\n    \\end{tabular}}"
    table += "\n    \\end{table*}"

    return table

In [16]:
# Rendering options for the tables printed by this cell
DECIMAL_PLACES = 3  # decimals shown for both mean and std
TABLE_FONT_SIZE = "footnotesize"  # Options: tiny, scriptsize, footnotesize, small, normalsize, large, etc.

# Print one full-vs-ablation LaTeX table per survival model
for model_name in survival_analysis_model_names:
    # Column names are read from the snapshot keyed 3.
    column_names = causal_inference_ate_table_per_time_snapshot[3][model_name].keys()
    latex_table = generate_causal_inference_ablation_table(
        model_name,
        causal_inference_ate_table_per_time_snapshot,
        causal_inference_ate_table_per_time_snapshot_ablation,
        column_names,
        TABLE_FONT_SIZE,
        DECIMAL_PLACES,
    )
    print(latex_table)
    print("\n\n")


    \begin{table*}[hbtp]
    \setlength{\tabcolsep}{4pt}
    \footnotesize
    \floatconts
      {tab:table_ablation_CoxPH}
      {\caption{Causal Inference Ablation ATE Estimates for CoxPH}}
      {\begin{tabular}{l c c c c c }
    \toprule
    &T-learner & S-learner & Matching (1) & Matching (5) & Matching (20) \\
    \midrule
    \textit{(3 months)}
 &  &  &  &  \\
    CoxPH (Full) & -3.524 ± 0.274 & -3.644 ± 0.393 & -4.245 ± 0.005 & -4.280 ± 0.003 & -4.406 ± 0.002 \\
    CoxPH (Ablation) & -5.432 ± 0.344 & -5.460 ± 0.429 & -6.020 ± 0.004 & -6.105 ± 0.004 & -6.184 ± 0.002 \\\\
    \textit{(6 months)}
 &  &  &  &  \\
    CoxPH (Full) & -2.910 ± 0.791 & -2.118 ± 0.751 & -2.578 ± 0.009 & -2.639 ± 0.006 & -2.900 ± 0.007 \\
    CoxPH (Ablation) & -4.933 ± 0.815 & -4.191 ± 0.720 & -4.220 ± 0.007 & -4.287 ± 0.003 & -4.453 ± 0.003 \\\\
    \textit{(9 months)}
 &  &  &  &  \\
    CoxPH (Full) & -2.285 ± 0.578 & -1.384 ± 0.512 & -1.658 ± 0.006 & -1.790 ± 0.002 & -2.032 ± 0.003 \\
    CoxPH (

In [17]:
# Rendering options for the table printed by this cell
DECIMAL_PLACES = 3  # decimals shown for both mean and std
TABLE_FONT_SIZE = "footnotesize"  # Options: tiny, scriptsize, footnotesize, small, normalsize, large, etc.

# Print the full-vs-ablation LaTeX table for the causal survival forest
model_name = 'CausalSurvivalForest'
# Column names are read from the snapshot keyed 3.
column_names = causal_survival_forest_ate_table_per_time_snapshot[3][model_name].keys()
latex_table = generate_causal_inference_ablation_table(
    model_name,
    causal_survival_forest_ate_table_per_time_snapshot,
    causal_survival_forest_ate_table_per_time_snapshot_ablation,
    column_names,
    TABLE_FONT_SIZE,
    DECIMAL_PLACES,
)
print(latex_table)
print("\n\n")


    \begin{table*}[hbtp]
    \setlength{\tabcolsep}{4pt}
    \footnotesize
    \floatconts
      {tab:table_ablation_CausalSurvivalForest}
      {\caption{Causal Inference Ablation ATE Estimates for CausalSurvivalForest}}
      {\begin{tabular}{l c }
    \toprule
    &ATE \\
    \midrule
    \textit{(3 months)}
 \\
    CausalSurvivalForest (Full) & -3.045 ± 0.128 \\
    CausalSurvivalForest (Ablation) & -7.816 ± 0.064 \\\\
    \textit{(6 months)}
 \\
    CausalSurvivalForest (Full) & -2.831 ± 0.101 \\
    CausalSurvivalForest (Ablation) & -5.189 ± 0.047 \\\\
    \textit{(9 months)}
 \\
    CausalSurvivalForest (Full) & -1.610 ± 0.057 \\
    CausalSurvivalForest (Ablation) & -2.833 ± 0.077 \\\\
    \textit{(12 months)}
 \\
    CausalSurvivalForest (Full) & 0.345 ± 0.097 \\
    CausalSurvivalForest (Ablation) & -1.561 ± 0.104 \\\\
    \textit{(18 months)}
 \\
    CausalSurvivalForest (Full) & 0.358 ± 0.115 \\
    CausalSurvivalForest (Ablation) & -0.019 ± 0.126 \\\\
    \textit{(24 mont