In [1]:
import toml
import sys
import numpy as np
import pandas as pd

from scipy.stats import norm
from scipy.special import logsumexp

import os

# Change into the bartpy checkout only once per kernel session: the module-level
# flag `chdir` exists after the first run, so re-running this cell does not
# descend into ./bartpy/bartpy/.
try:
    chdir
    print("Changed")
except NameError:
    # Only "flag not defined yet" means "not changed yet"; any other error
    # (e.g. a missing ./bartpy/ directory below) should surface, not be swallowed.
    os.chdir("./bartpy/")
    chdir = True

from bartpy.sklearnmodel import SklearnModel

In [2]:
def logmeanexp(x, axis=None):
    """Numerically stable log(mean(exp(x))).

    Parameters
    ----------
    x : array_like
        Values in log space (lists are accepted as well as ndarrays).
    axis : int or tuple of ints, optional
        Axis/axes to reduce over. ``None`` (default) reduces over all elements,
        matching the previous behaviour.

    Returns
    -------
    float or ndarray
        ``logsumexp(x, axis) - log(count)``, where ``count`` is the number of
        elements reduced.
    """
    x = np.asarray(x)
    if axis is None:
        count = x.size
    else:
        axes = (axis,) if np.ndim(axis) == 0 else tuple(axis)
        count = int(np.prod([x.shape[a] for a in axes]))
    return logsumexp(x, axis=axis) - np.log(count)

In [180]:
# Which synthetic configuration to run; the Julia cell interpolates it into
# the config path "../../data/synthetic/$(experiment).toml".
experiment = 2

In [181]:
from julia import Julia
jl = Julia(compiled_modules=False)
%load_ext julia.magic

%julia using JLD
%julia using TOML
%julia using Random
%julia include("../../data/synthetic.jl")
%julia using .Synthetic

%julia Random.seed!(1234)
%julia experiment = $experiment
%julia config_path = "../../data/synthetic/$(experiment).toml"
obj_size = %julia TOML.parsefile(config_path)["data"]["obj_size"]

%julia SigmaU, U_, T_, X_, Y_, epsY, ftxu = generate_synthetic_confounder(config_path)
T, X, Y = %julia T_, X_, Y_
nObjects = int(len(T_)/obj_size)
n = len(T_)
object_ids = np.zeros((n, nObjects))

for i in range(nObjects):
    object_ids[i*obj_size:(i+1)*obj_size, i] = 1


Z = np.concatenate([T.reshape(-1, 1), X, object_ids], axis=1)

The julia.magic extension is already loaded. To reload it, use:
  %reload_ext julia.magic


In [182]:
# Intervention values to evaluate.
#  - Binary treatment (every observed value in {0, 1}): both arms, True and False.
#  - Otherwise: `doTnSteps` evenly spaced values inside the observed treatment
#    range, each endpoint pulled toward the interior by 5% of its magnitude.
# Uses the Python-side copy `T`; `T_` is the Julia-side array.
T_observed = np.asarray(T)
if np.isin(T_observed, (0, 1)).all():
    binary = True
    doTs = [True, False]
else:
    binary = False
    doTnSteps = 20
    t_min, t_max = T_observed.min(), T_observed.max()
    # Equivalent to min*1.05 / max*0.95 for positive endpoints, but also moves
    # inward when an endpoint is negative (plain multiplication would push the
    # grid outside the observed support there).
    lower = t_min + 0.05 * abs(t_min)
    upper = t_max - 0.05 * abs(t_max)
    doTs = np.linspace(lower, upper, doTnSteps)

In [183]:
# Fit a BART model on the design matrix Z = [T | X | object one-hots] -> Y.
# Small chain settings (few trees/samples) keep this cell quick; the fitted
# model is the cell's displayed result because fit() is the last expression.
model = SklearnModel(
    n_trees=10,
    n_chains=10,
    n_burn=50,
    n_samples=10,
    thin=1,
    store_in_sample_predictions=False,
)
model.fit(Z, Y)

Starting burn


  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  2%|▏         | 1/50 [00:00<00:05,  9.17it/s]

Starting burn


  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)
  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)


Starting burn


  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  2%|▏         | 1/50 [00:00<00:05,  8.83it/s]

Starting burn


  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)
  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)


Starting burn


  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
 10%|█         | 5/50 [00:00<00:04,  9.79it/s]]

Starting burn


  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)
  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  2%|▏         | 1/50 [00:00<00:06,  7.58it/s]

Starting burn


  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)
  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  2%|▏         | 1/50 [00:00<00:05,  8.41it/s]]

Starting burn


  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)
  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)
 64%|██████▍   | 32/50 [00:02<00:01, 14.05it/s]


Starting sampling


 84%|████████▍ | 42/50 [00:03<00:00, 14.42it/s]


Starting sampling


100%|██████████| 50/50 [00:03<00:00, 13.25it/s]


Starting sampling


100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
 80%|████████  | 40/50 [00:03<00:00, 13.41it/s]

Starting sampling


100%|██████████| 10/10 [00:00<00:00, 12.49it/s]
100%|██████████| 50/50 [00:03<00:00, 13.43it/s]


Starting sampling


100%|██████████| 10/10 [00:00<00:00, 12.26it/s]
100%|██████████| 50/50 [00:03<00:00, 12.76it/s]


Starting sampling


100%|██████████| 50/50 [00:03<00:00, 12.82it/s]


Starting sampling


100%|██████████| 10/10 [00:00<00:00, 12.22it/s]
100%|██████████| 50/50 [00:03<00:00, 12.98it/s]


Starting sampling


100%|██████████| 10/10 [00:00<00:00, 12.62it/s]
100%|██████████| 10/10 [00:00<00:00, 12.89it/s]
 80%|████████  | 8/10 [00:00<00:00, 15.02it/s]]
100%|██████████| 10/10 [00:00<00:00, 15.36it/s]
100%|██████████| 10/10 [00:00<00:00, 16.81it/s]


Starting burn


  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)


Starting burn


  prob_value_selected_within_variable = np.log(mutation.existing_node.data.proportion_of_value_in_variable(splitting_variable, splitting_value))
  return self.log_transition_ratio(tree, mutation) + self.log_likihood_ratio(model, tree, mutation) + self.log_tree_ratio(model, tree, mutation)
  prob_splitting_variable_selected = - np.log(mutation.existing_node.data.n_splittable_variables)
100%|██████████| 50/50 [00:01<00:00, 30.55it/s]


Starting sampling


100%|██████████| 50/50 [00:01<00:00, 30.58it/s]


Starting sampling


100%|██████████| 10/10 [00:00<00:00, 31.17it/s]
100%|██████████| 10/10 [00:00<00:00, 31.65it/s]


SklearnModel(alpha=0.95, beta=2.0,
       initializer=<bartpy.initializers.sklearntreeinitializer.SklearnTreeInitializer object at 0x123456208>,
       n_burn=50, n_chains=10, n_jobs=-1, n_samples=10, n_trees=10,
       sigma_a=0.001, sigma_b=0.001, store_acceptance_trace=True,
       store_in_sample_predictions=False, thin=1,
       tree_sampler=<bartpy.samplers.unconstrainedtree.treemutation.UnconstrainedTreeMutationSampler object at 0x1234561d0>)

In [184]:
# Independent counterfactual copy of the design matrix; the loop below
# overwrites a column of it with each intervention value before predicting.
Zcf = np.array(Z, copy=True)

def PEHE(effect, effect_pred):
    """Precision in Estimation of Heterogeneous Effect: the root-mean-squared
    difference between true and predicted individual effects (ndarray inputs)."""
    residual = effect - effect_pred
    return np.sqrt(np.mean(residual ** 2))

PEHEs = np.zeros_like(doTs)


for i, doT in enumerate(doTs):
    mask = %julia mask = T_ .!= $doT
    Ycf = %julia ftxu(fill($doT, sum(mask)), X_[mask, :], U_[mask, :], epsY[mask])
    mask = T != doT
    
    Zcf[:, 1] = doT
    Y_pred = model.predict(Z[mask])
    Ycf_pred = model.predict(Zcf[mask])
    effect = Y[mask] - Ycf
    effect_pred = Y_pred - Ycf_pred
    PEHEs[i] = PEHE(effect, effect_pred)

print(PEHEs.mean())



1.5390677818564713
