In [1]:
import ast
import os

import numpy as np
import pandas as pd
import torch

from main import *
from eval_cf_MISE import *

In [2]:
#%% Data
# Load the training set, the counterfactual test arrays and the per-model
# hyperparameters that the evaluation cell consumes.

# Data train: the CSV is converted to a plain NumPy array
df = pd.read_csv('data/Sim_data_train.csv')
data_train = df.to_numpy()

# Data test counterfactual: three arrays stored under the keys below
data_cf = np.load('data/data_cf.npz')
A_cf = data_cf['A_cf']
Y_cf = data_cf['Y_cf']
X_cf = data_cf['X_cf']

# Hyperpar list: the file holds a Python literal; read it inside a context
# manager so the file handle is closed as soon as the text has been read.
with open("hyperpars/hyperpars_opt_real_large.txt", "r") as hyper_file:
    hyper_opt_list = hyper_file.read()
# literal_eval only accepts Python literals, so the file cannot execute code
hyper_opt = ast.literal_eval(hyper_opt_list)

# Convert hyper_opt so that its values are iterable: every value is wrapped in
# a one-element list, in place (only values change, so iterating keys is safe).
for i in range(len(hyper_opt)):
    model_pars = hyper_opt[i]
    for key in model_pars.keys():
        model_pars[key] = [model_pars[key]]

In [3]:
#%% Counterfactual point estimate and variance (averaged over 10 runs)
# This cell only collects the per-run losses; the summary statistics
# (mean, median, sd) are computed in the export cell.

models = ['lm', 'nn', 'gps', 'dr', 'sci', 'cgct_gps', 'rf', 'cgct_rf']
N_RUNS = 10  # number of repeated evaluations per model

# hyper_opt is indexed by position in `models`; fail before the expensive
# runs start instead of partway through the first one.
if len(hyper_opt) < len(models):
    raise ValueError(
        f"hyper_opt has {len(hyper_opt)} entries but {len(models)} models are evaluated"
    )

# Set all seeds (once, before the first run)
np.random.seed(123)
torch.manual_seed(123)

# Get results: one row per model, one column per run. The shape is derived
# from `models` and N_RUNS so editing either keeps the table consistent.
res_table = np.empty(shape=(len(models), N_RUNS))
for l in range(N_RUNS):
    test_loss = []
    for i, model in enumerate(models):
        cv_results = eval_MISE(data_train, X_cf, A_cf, Y_cf, model, hyper_opt[i])
        # Keep the 'loss' entry of the first result returned by eval_MISE
        test_loss.append(cv_results[0]['loss'])
    res_table[:, l] = np.array(test_loss)

In [4]:
# Get results into format for export: one row per run followed by three
# summary rows (mean, median, sd), one column per model.

# NOTE: `df` is rebound here; it previously held the training data frame.
# Transpose so that runs become rows and models become columns.
df = pd.DataFrame(np.transpose(res_table), columns=models)
df.insert(0, "measure", [f"run {i+1}" for i in range(res_table.shape[1])])

# Summary statistics per model column; `std()` is pandas' default
# (sample standard deviation, ddof=1).
stats = {
    "measure": ["mean", "median", "sd"],
    **{
        model: [df[model].mean(), df[model].median(), df[model].std()]
        for model in df.columns
        if model != "measure"
    },
}

stats_df = pd.DataFrame(stats)
result_df = pd.concat([df, stats_df], ignore_index=True)

# to_csv does not create missing directories; make sure the target exists so
# the results of the long evaluation are not lost to a write error.
os.makedirs("outputs", exist_ok=True)
result_df.to_csv("outputs/model_perf_sim.csv")

print(result_df)

   measure       lm        nn       gps        dr        sci  cgct_gps  \
0    run 1  0.17318  0.600140  0.210394  0.806490   4.428555  0.332491   
1    run 2  0.17318  0.378943  0.210394  0.696156   3.802102  0.337377   
2    run 3  0.17318  0.253374  0.210394  0.552711   1.726807  0.340213   
3    run 4  0.17318  0.200030  0.210394  0.449337   5.768789  0.338559   
4    run 5  0.17318  0.524151  0.210394  0.742749   5.610894  0.340803   
5    run 6  0.17318  0.268141  0.210394  0.452694   2.277125  0.333515   
6    run 7  0.17318  0.608851  0.210394  0.478910   8.484173  0.338309   
7    run 8  0.17318  0.230941  0.210394  0.635279   7.702509  0.333403   
8    run 9  0.17318  0.222312  0.210394  0.268998  11.097772  0.333201   
9   run 10  0.17318  0.474266  0.210394  0.338369  15.791056  0.326564   
10    mean  0.17318  0.376115  0.210394  0.542169   6.668978  0.335444   
11  median  0.17318  0.323542  0.210394  0.515810   5.689841  0.335446   
12      sd  0.00000  0.162744  0.00000