#  Other datasets

For the other dataset: IHDP B and ACIC 2, the obtained models are TARNKAAM and DragonKAM, which yield more complex functions, and therefore
more advanced visualization tools should be developed

# IHDP B

In [1]:
#imports
import numpy as np
import pandas as pd
import torch
import time

from utils import load_data
from sklearn.model_selection import train_test_split
from kan import ex_round
from kan_model import kan_net
from mlp_model import mlp_net
from utils import get_width, get_dims_mlp
from representation import *
from tueplots import bundles, figsizes, axes
import copy
import matplotlib.pyplot as plt
from experiment_symbolic import *

In [2]:
# seed
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x1f9a2b5d1d0>

In [3]:
#load the data
i = 2
split = 0.2

data_train, data_test = load_data('IHDP_B', i)
col_names = data_train.columns
x_cols = [col for col in data_train.columns if 'x' in col]

x_train, y_train, t_train = data_train[x_cols].values, data_train['y_factual'].values[:, None], data_train['treatment'].values[:, None]
x_test, y_test, t_test = data_test[[col for col in data_test.columns if 'x' in col]].values, data_test['y_factual'].values[:, None], data_test['treatment'].values[:, None]

x_train, x_val, y_train, y_val, t_train, t_val = train_test_split(x_train, y_train, t_train, test_size=split, shuffle=False)

real_ite_train = data_train['mu1'].values - data_train['mu0'].values
real_ite_test = data_test['mu1'].values - data_test['mu0'].values

real_ite_train, real_ite_val = real_ite_train[:x_train.shape[0]], real_ite_train[x_train.shape[0]:]
print(col_names)
print(x_cols)

data_test_t0 = data_test.copy()
data_test_t0['treatment'] = 0
data_test_t1 = data_test.copy()
data_test_t1['treatment'] = 1

Index(['treatment', 'y_factual', 'y_cfactual', 'mu0', 'mu1', 'x1', 'x2', 'x3',
       'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14',
       'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24',
       'x25'],
      dtype='object')
['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25']


In [4]:
# preprocess the data
print(x_train.shape, y_train.shape, t_train.shape, x_val.shape, y_val.shape, t_val.shape, x_test.shape, y_test.shape, t_test.shape, real_ite_train.shape, real_ite_test.shape, real_ite_val.shape)

(418, 25) (418, 1) (418, 1) (105, 25) (105, 1) (105, 1) (224, 25) (224, 1) (224, 1) (418,) (224,) (105,)


In [5]:
# instanciate the KAN model
# T-KAAM
width = get_width('dragonnet', x_train.shape[1], hidden_dims = [5], mult_kan=True)
dragon_kaam_ihdp = kan_net('dragonnet', width, grid=3, k=3, sparse_init=False, try_gpu=False, real_ite_train=real_ite_train, real_ite_test=real_ite_val, save_folder='visualization/checkpoint')

checkpoint directory created: C:\Users\Alex\CODE\causalkans\visualization/checkpoint\kan_345353\
saving model version 0.0


In [6]:
#train the model
tic = time.time()
results = dragon_kaam_ihdp.fit(x_train, y_train, t_train, x_val, y_val, t_val, early_stop=True, patience=30, batch=1000,
                           steps=10000, lamb=0.01, lamb_entropy=0.1,  lr=0.001, verbose=0)
tac = time.time()
print(results.keys())
print('Training time: ', tac - tic)

Early stopping at step 1192
saving model version 0.1
dict_keys(['train_loss', 'test_loss', 'reg', 'train_metrics', 'test_metrics'])
Training time:  49.44813585281372


In [7]:
prediction = dragon_kaam.predict(x_test, t_test)
print(prediction.keys())
prediction['ite']= prediction['y_pred_1'] - prediction['y_pred_0']

dict_keys(['y_pred_0', 'y_pred_1', 'y_pred_f', 'y_pred_cf', 'ps_pred', 'pred_best_treatment'])


In [8]:
# Compute prediction metrics in MSE of factual outcome and
mse = np.mean((y_test - prediction['y_pred_f'])**2)
pehe = np.sqrt(np.mean((real_ite_test - prediction['ite'])**2))
print(f'MSE original KAN: {mse}, PEHE original KAN: {pehe}')
print(f'Predicted ATE: {prediction["ite"].mean()}, Real ATE: {real_ite_test.mean()}')

MSE original KAN: 23.81933951374525, PEHE original KAN: 2.7240276773958003
Predicted ATE: 3.559431314468384, Real ATE: 3.826578597130422


In [9]:
# MLP params = hidden_dims= [100, 100, 100], dropout=0.0, lr=1e-05, activation=leaky_relu
dims_mlp = get_dims_mlp('dragonnet', x_train.shape[1], hidden_dims=[[200, 200, 200], [100, 100], []])
dragon_nn = mlp_net('dragonnet', dims_mlp, dropout=0.0, activation='elu', try_gpu=False, save_folder='visualization/checkpoint', real_ite_train = real_ite_train, real_ite_test = real_ite_val)
tic = time.time()
results_mlp = dragon_nn.fit(x_train, y_train, t_train, x_val, y_val, t_val, early_stop=True, patience=30, batch=1000,
                           steps=10000, lr=1e-5, verbose=0)
tac = time.time()
print(results_mlp.keys())
print('Training time MLP: ', tac - tic)


Early stopping at step 8994
dict_keys(['train_loss', 'test_loss', 'reg', 'train_metrics', 'test_metrics'])
Training time MLP:  198.2119972705841


In [10]:
prediction_mlp = dragon_nn.predict(x_test, t_test)
prediction_mlp['ite']= prediction_mlp['y_pred_1'] - prediction_mlp['y_pred_0']
mse_mlp = np.mean((y_test - prediction_mlp['y_pred_f'])**2)
pehe_mlp = np.sqrt(np.mean((real_ite_test - prediction_mlp['ite'])**2))
print(f'MSE MLP: {mse_mlp}, PEHE MLP: {pehe_mlp}')
print(f'ATE MLP: {prediction_mlp["ite"].mean()}, Real ATE: {real_ite_test.mean()}')

MSE MLP: 25.93584607590749, PEHE MLP: 2.455876843100015
ATE MLP: 3.401909351348877, Real ATE: 3.826578597130422


In [11]:
print(results.keys())

dict_keys(['train_loss', 'test_loss', 'reg', 'train_metrics', 'test_metrics'])


In [12]:
# formula without pruning
# symb_s_kaam_orig = symbolic_kan_regressor(x_names = x_cols, y_names=['mu_0', 'mu_1', 'propensity'])

In [13]:
# with plt.rc_context({**bundles.iclr2024()}):
#     tic = time.time()
#     xt_train = np.concatenate((x_train, t_train), axis=1)
#     xt_val = np.concatenate((x_val, t_val), axis=1)
#     # out_0 = symb_s_kaam_orig.fit(dragon_kaam.model, xt_train, y_train, xt_val, y_val, denorm_function=None, stochastic=False, r2_threshold=0.98, show_results=False, save_dir = 'visualization/formulas')
#     out_0 = symb_s_kaam_orig.fit(dragon_kaam.model, x_train, y_train, x_val, y_val, denorm_function=None, stochastic=False, r2_threshold=0.98, show_results=False, save_dir = 'visualization/formulas')
#     tac = time.time()
#     print('Formula extraction time (control): ', tac - tic)
#
# formulas_orig = symb_s_kaam_orig.get_formula()

In [14]:
# print(formulas_orig[0])

In [15]:
# print(formulas_orig[0])

 # PRUNING AND FORMULA EXTRACTION

In [16]:
# pruning and evaluation
# t_kaam_original = copy.deepcopy(t_kaam)
tic = time.time()
dragon_kaam_ihdp_pruned = dragon_kaam_ihdp.prune(node_th=1e-2, edge_th=1e-2)
tac = time.time()
print('Pruning time: ', tac - tic)

saving model version 0.2
Pruning time:  0.05620169639587402


In [17]:
#evaluate
predictions_pruned = dragon_kaam_ihdp.predict(x_test, t_test)
predictions_pruned['ite']= predictions_pruned['y_pred_1'] - predictions_pruned['y_pred_0']
mse_pruned = np.mean((y_test - predictions_pruned['y_pred_f'])**2)
pehe_pruned = np.sqrt(np.mean((real_ite_test - predictions_pruned['ite'])**2))
print(f'MSE pruned KAN: {mse_pruned}, PEHE pruned KAN: {pehe_pruned}')
print(f'Predicted ATE: {predictions_pruned["ite"].mean()}, Real ATE: {real_ite_test.mean()}')

MSE pruned KAN: 27.874434246228944, PEHE pruned KAN: 2.6958143339238805
Predicted ATE: 3.3271918296813965, Real ATE: 3.826578597130422


In [18]:
# # formula extraction
# symb_s_kaam = symbolic_kan_regressor(x_names = x_cols, y_names=['mu_0', 'mu_1', 'propensity'])
# with plt.rc_context({**bundles.iclr2024()}):
#     tic = time.time()
#     xt_train = np.concatenate((x_train, t_train), axis=1)
#     xt_val = np.concatenate((x_val, t_val), axis=1)
#     # out_0 = symb_s_kaam.fit(dragon_kaam.model, xt_train, y_train, xt_val, y_val, denorm_function=None, stochastic=False, r2_threshold=0.98, show_results=False, save_dir = 'visualization/formulas')
#     out_0 = symb_s_kaam.fit(dragon_kaam.model, x_train, y_train, x_val, y_val, denorm_function=None, stochastic=False, r2_threshold=0.98, show_results=False, save_dir = 'visualization/formulas')
#     tac = time.time()
#     print('Formula extraction time (control): ', tac - tic)
# formulas = symb_s_kaam.get_formula()


In [19]:
formulas = dragon_kaam_ihdp.interprete('all', n_digit=4)

Using library all
fixing (0,0,0) with sin, r2=0.9976866245269775, c=2
fixing (0,0,1) with sin, r2=0.9992488026618958, c=2
fixing (0,0,2) with sin, r2=0.9837432503700256, c=2
fixing (0,0,3) with x, r2=0.988549530506134, c=1
fixing (0,0,4) with sin, r2=0.9991015791893005, c=2
fixing (0,0,5) with x, r2=0.9428504705429077, c=1
fixing (0,0,6) with x^2, r2=0.9985201954841614, c=2
fixing (0,0,7) with sin, r2=0.9947016835212708, c=2
fixing (0,0,8) with 0
fixing (0,0,9) with 0
fixing (0,0,10) with tanh, r2=0.999727189540863, c=3
fixing (0,0,11) with 0
fixing (0,0,12) with sin, r2=0.9998430609703064, c=2
fixing (0,0,13) with sin, r2=0.9996123313903809, c=2
fixing (0,1,0) with 0
fixing (0,1,1) with x, r2=0.9902762770652771, c=1
fixing (0,1,2) with sin, r2=0.9884856343269348, c=2
fixing (0,1,3) with sin, r2=0.9909999370574951, c=2
fixing (0,1,4) with sin, r2=0.9946965575218201, c=2
fixing (0,1,5) with sin, r2=0.994387149810791, c=2
fixing (0,1,6) with sin, r2=0.9951779246330261, c=2
fixing (0,1,7)

In [20]:
y_pred = dragon_kaam_ihdp.predict(x_test, t_test)
y_pred_0 = y_pred['y_pred_0']
y_pred_1 = y_pred['y_pred_1']
y_pred_f = y_pred['y_pred_f']
ite_pred = y_pred_1 - y_pred_0
mse_kaam_formula = np.mean((y_test - y_pred_f)**2)
pehe_kaam_formula = np.sqrt(np.mean((real_ite_test - y_pred_f)**2))
print(f'MSE KAN formula: {mse_kaam_formula}, PEHE KAN formula: {pehe_kaam_formula}')
print(f'Predicted ATE formula: {ite_pred.mean()}, Real ATE: {real_ite_test.mean()}')

MSE KAN formula: 26.79343985906936, PEHE KAN formula: 4.824187335295197
Predicted ATE formula: 3.1782314777374268, Real ATE: 3.826578597130422


In [21]:
# evaluate formula performance

# y_pred = symb_s_kaam.predict(x_test)
# y_pred_0 = y_pred[:, 0]
# y_pred_1 = y_pred[:, 1]
# y_pred_f = t_test * y_pred_1 + (1 - t_test) * y_pred_0
# ite_pred = y_pred_1 - y_pred_0
# mse_formula = np.mean((y_test - y_pred_f)**2)
# pehe_formula = np.sqrt(np.mean((real_ite_test - ite_pred)**2))
# print(f'MSE formula: {mse_formula}, PEHE formula: {pehe_formula}')


In [22]:
print(formulas)

[-0.5617*x_10 + 0.5772*x_11 + 1.1103*x_12 - 0.2464*x_13 - 0.3154*x_14 + 0.7025*x_15 + 0.6174*x_16 + 0.6059*x_17 + 0.7961*x_18 + 1.2552*x_19 - 0.5493*x_2 + 1.1921*x_21 - 0.0546*x_22 + 0.8908*x_24 - 0.4129*x_25 + 0.1991*x_6 - 0.7544*x_7 - 1.2144*x_8 + 0.3645*x_9 + 0.4992*(0.216*x_10 + 0.1*x_12 + 0.2278*x_14 + 0.5113*x_15 + 0.2005*x_16 + 0.178*x_17 - 0.0616*x_19 + 0.2092*x_21 + 0.2484*x_22 - 0.0734*x_24 + 0.1464*x_25 + 0.3171*x_7 + 0.1775*x_9 - 0.4904*sin(0.5807*x_1 + 2.5804) + 0.2358*sin(0.813*x_2 - 0.7861) + 0.3768*sin(0.6902*x_5 + 5.0122) + 1.6899*sin(0.2259*x_6 - 7.3415) - 0.0981*tan(2.579*x_4 + 6.0114) + 3.3531)*(0.0582*x_1 + 0.1759*x_10 + 0.1813*x_11 + 0.094*x_12 + 0.0844*x_13 + 0.0933*x_14 + 0.5246*x_15 + 0.3394*x_17 + 0.206*x_18 + 0.0636*x_19 + 0.0767*x_22 + 0.3497*x_25 + 0.0207*x_3 + 0.135*x_7 + 0.1345*x_8 + 0.0602*x_9 + 0.0025*(-3.3751*x_6 - 8.9225)**2 - 0.1165*sin(0.7342*x_2 + 8.5555) + 0.5807*sin(0.763*x_5 + 5.3909) - 0.1122*tan(8.605*x_4 - 4.3994) + 1.5381) - 0.0632*(0.1389*x

In [23]:
# extract the CATE
# in this case the CATE is the term that corresponds to the treatment 't'
# Break into additive terms

# Rebuild the expression
cate_formula = formulas[1] - formulas[0]
print(cate_formula)

0.032*x_1 + 0.8949*x_10 - 0.6442*x_11 - 0.781*x_12 + 0.5064*x_13 + 0.1099*x_14 - 0.8864*x_15 - 0.689*x_16 - 0.7059*x_17 - 0.8885*x_18 - 1.5474*x_19 + 0.613*x_2 + 0.072*x_20 - 1.3599*x_21 + 0.3073*x_22 - 1.502*x_24 + 0.8482*x_25 - 0.1752*x_6 + 1.142*x_7 + 1.5374*x_8 - 0.3751*x_9 - 0.9111*(0.0202*x_2 - 0.1481*x_21 + 0.0418*x_5 - 0.0267*x_6 - 0.1345*sin(0.678*x_3 - 7.5796) - 0.4471)*(0.3337*x_14 + 0.2149*x_21 - 0.6507*x_22 - 0.142*x_5 - 0.2025*x_7 - 0.2691*x_8 + 0.2016*sin(4.5803*x_4 - 2.4094) - 0.1806*tanh(1.0222*x_1 - 0.925) - 0.2402) - 0.2971*(0.3689*x_11 - 0.4167*x_20 + 0.2475*x_21 + 0.3691*x_24 - 0.3129*x_7 + 0.3479*tanh(0.7809*x_5 - 1.0274) + 0.5601)*(-0.1934*x_11 - 0.2308*x_14 - 0.2392*x_21 - 0.404*x_24 + 0.2046*x_7 - 0.1444*sin(0.9677*x_3 + 5.2002) - 0.4907*sin(6.4178*x_4 - 2.1933) - 0.3958*tanh(0.5055*x_5 - 0.7037) - 1.1516) - 0.1955*(0.216*x_10 + 0.1*x_12 + 0.2278*x_14 + 0.5113*x_15 + 0.2005*x_16 + 0.178*x_17 - 0.0616*x_19 + 0.2092*x_21 + 0.2484*x_22 - 0.0734*x_24 + 0.1464*x_25 

In [24]:
# ite_pred_formula = eval_expr_on_df(cate_formula, data_test)
# pehe_formula_cate = np.sqrt(np.mean((real_ite_test - ite_pred_formula)**2))
# print(f'PEHE formula cate: {pehe_formula_cate}')

In [25]:
# chage x1, x2, x3... by x_1, x_2,...
x_cols_ = ['x_' + col[1:] for col in x_cols]
test_df = pd.DataFrame(data_test.loc[:, x_cols].to_numpy(), columns=x_cols_)

In [26]:
ite_pred_formula = get_formula_values(cate_formula, test_df.copy())
print(ite_pred_formula)

100%|██████████| 224/224 [00:46<00:00,  4.77it/s]

[[ 9.72983867e-02]
 [-1.91365111e+00]
 [ 2.99473763e+00]
 [ 1.06877843e+00]
 [ 5.02938050e+00]
 [ 6.56110193e+00]
 [ 4.75688639e+00]
 [-1.23992864e+00]
 [ 2.42613228e+00]
 [ 4.13071561e+00]
 [ 4.81018821e+00]
 [-3.23527895e+00]
 [ 2.56129162e+00]
 [ 6.73980826e+00]
 [ 4.94533489e+00]
 [ 1.93772022e+00]
 [ 2.25781061e+00]
 [ 5.74183932e+00]
 [ 9.87932155e-01]
 [ 3.20204488e+00]
 [ 2.67093152e+00]
 [ 6.10581850e-01]
 [-3.03539825e+00]
 [ 5.06167721e+00]
 [ 7.02844022e-01]
 [ 7.00658813e+00]
 [ 2.70997050e+00]
 [ 4.80829376e+00]
 [ 5.58252341e+00]
 [ 1.80809406e+00]
 [ 5.01951759e+00]
 [ 4.11107823e+00]
 [ 2.51142856e+00]
 [ 1.27612918e+00]
 [-8.73479084e-01]
 [ 6.26881357e+00]
 [ 4.49327239e+00]
 [ 1.69859359e+00]
 [ 4.64637516e+00]
 [ 2.00401410e+00]
 [ 3.42353211e+00]
 [-1.42810966e-01]
 [ 6.73149789e+00]
 [ 4.06682599e+00]
 [ 7.29918762e+00]
 [ 4.61827679e+00]
 [ 7.99410950e+00]
 [ 5.10866758e+00]
 [ 3.32340916e+00]
 [ 5.99032871e+00]
 [ 3.89804280e+00]
 [ 7.61770243e+00]
 [ 6.5987401




In [27]:
pehe_formula_cate = np.sqrt(np.mean((real_ite_test - ite_pred_formula)**2))
print(f'PEHE formula cate: {pehe_formula_cate}')

PEHE formula cate: 4.296792070654834


In [28]:
# formula_0_rounded = sp.sympify(sp.nsimplify(formulas[0], tolerance=1e-2))
# y_pred_0_trunc = eval_expr_on_df(formula_0_rounded, data_test)
# formula_1_rounded = sp.sympify(sp.nsimplify(formulas[0], tolerance=1e-2))
# y_pred_1_trunc = eval_expr_on_df(formula_1_rounded, data_test)
# y_pred_f_trunc = y_pred_0_trunc * (1 - t_test[:,0]) + y_pred_1_trunc * t_test[:,0]
# ite_pred_trunc = y_pred_1_trunc - y_pred_0_trunc
# mse_formula_trunc = np.mean((y_test[:,0] - y_pred_f_trunc)**2)
# pehe_formula_trunc = np.sqrt(np.mean((real_ite_test - ite_pred_trunc)**2))
# print(f'MSE formula truncated: {mse_formula_trunc}, PEHE formula truncated: {pehe_formula_trunc}')



In [29]:
# truncate
formulas_trun = [ex_round(f, 2) for f in formulas]
y_pred_0_trunc = get_formula_values(formulas_trun[0], test_df.copy())
y_pred_1_trunc = get_formula_values(formulas_trun[1],  test_df.copy())
y_pred_f_trunc = y_pred_0_trunc * (1 - t_test[:,0]) + y_pred_1_trunc * t_test[:,0]
ite_pred_trunc = y_pred_1_trunc - y_pred_0_trunc
mse_formula_trunc = np.mean((y_test[:,0] - y_pred_f_trunc)**2)
pehe_formula_trunc = np.sqrt(np.mean((real_ite_test - ite_pred_trunc)**2))
print(f'MSE formula truncated: {mse_formula_trunc}, PEHE formula truncated: {pehe_formula_trunc}')
print(f'Predicted ATE formula truncated: {ite_pred_trunc.mean()}, Real ATE: {real_ite_test.mean()}')


100%|██████████| 224/224 [00:29<00:00,  7.56it/s]
100%|██████████| 224/224 [00:35<00:00,  6.37it/s]

MSE formula truncated: 27.404776595898653, PEHE formula truncated: 4.2694041498261885
Predicted ATE formula truncated: 3.2396191189480406, Real ATE: 3.826578597130422





In [30]:
# #print pehe with cate formula truncated
# cate_formula_trunc = formulas_1_rounded - formulas_0_rounded
# ite_pred_formula_trunc = eval_expr_on_df(cate_formula_trunc, data_test)
# pehe_formula_cate_trunc = np.sqrt(np.mean((real_ite_test - ite_pred_formula_trunc)**2))
# print(f'PEHE formula cate truncated: {pehe_formula_cate_trunc}')


In [31]:
#print pehe with cate formula truncated
cate_formula_trunc = formulas_trun[1] - formulas_trun[0]
ite_pred_formula_trunc = get_formula_values(cate_formula_trunc, test_df.copy())
pehe_formula_cate_trunc = np.sqrt(np.mean((real_ite_test - ite_pred_formula_trunc)**2))
print(f'PEHE formula cate truncated: {pehe_formula_cate_trunc}')
print(f'ATE formula cate truncated: {ite_pred_formula_trunc.mean()}, Real ATE: {real_ite_test.mean()}')

100%|██████████| 224/224 [00:33<00:00,  6.62it/s]

PEHE formula cate truncated: 4.269112252741745
ATE formula cate truncated: 3.240369020285241, Real ATE: 3.826578597130422





In [32]:
present = {s.name for s in formulas[0].free_symbols}
print(present)

{'x_21', 'x_4', 'x_14', 'x_22', 'x_15', 'x_20', 'x_6', 'x_3', 'x_23', 'x_17', 'x_12', 'x_13', 'x_16', 'x_10', 'x_9', 'x_8', 'x_2', 'x_11', 'x_19', 'x_1', 'x_18', 'x_24', 'x_7', 'x_5', 'x_25'}


In [33]:
ausent = set(x_cols) - present
print(ausent)

{'x6', 'x4', 'x7', 'x8', 'x15', 'x9', 'x20', 'x11', 'x21', 'x1', 'x10', 'x5', 'x17', 'x23', 'x25', 'x2', 'x13', 'x22', 'x19', 'x3', 'x18', 'x24', 'x14', 'x16', 'x12'}


In [34]:
n_feats = len(list(present))

# ACIC 2

In [35]:
#load the data
i = 3
split = 0.2

data_train, data_test = load_data('ACIC_2', i)
col_names = data_train.columns
x_cols = [col for col in data_train.columns if 'x' in col]

x_train, y_train, t_train = data_train[x_cols].values, data_train['y_factual'].values[:, None], data_train[
                                                                                                    'treatment'].values[
                                                                                                :, None]
x_test, y_test, t_test = data_test[[col for col in data_test.columns if 'x' in col]].values, data_test[
                                                                                                 'y_factual'].values[:,
                                                                                             None], data_test[
                                                                                                        'treatment'].values[
                                                                                                    :, None]

x_train, x_val, y_train, y_val, t_train, t_val = train_test_split(x_train, y_train, t_train, test_size=split,
                                                                  shuffle=False)

real_ite_train = data_train['mu1'].values - data_train['mu0'].values
real_ite_test = data_test['mu1'].values - data_test['mu0'].values

real_ite_train, real_ite_val = real_ite_train[:x_train.shape[0]], real_ite_train[x_train.shape[0]:]
print(col_names)
print(x_cols)

data_test_t0 = data_test.copy()
data_test_t0['treatment'] = 0
data_test_t1 = data_test.copy()
data_test_t1['treatment'] = 1
# preprocess the data
print(x_train.shape, y_train.shape, t_train.shape, x_val.shape, y_val.shape, t_val.shape, x_test.shape, y_test.shape,
      t_test.shape, real_ite_train.shape, real_ite_test.shape, real_ite_val.shape)
# instanciate the KAN model
# T-KAAM
width = get_width('dragonnet', x_train.shape[1], hidden_dims=[5], mult_kan=True)
dragon_kaam = kan_net('dragonnet', width, grid=3, k=3, sparse_init=False, try_gpu=False, real_ite_train=real_ite_train,
                      real_ite_test=real_ite_val, save_folder='visualization/checkpoint')
#train the model
tic = time.time()
results = dragon_kaam.fit(x_train, y_train, t_train, x_val, y_val, t_val, early_stop=True, patience=30, batch=1000,
                          steps=10000, lamb=0.01, lamb_entropy=0.1, lr=0.001, verbose=0)
tac = time.time()
print(results.keys())
print('Training time: ', tac - tic)
prediction = dragon_kaam.predict(x_test, t_test)
print(prediction.keys())
prediction['ite'] = prediction['y_pred_1'] - prediction['y_pred_0']
# Compute prediction metrics in MSE of factual outcome and
mse = np.mean((y_test - prediction['y_pred_f']) ** 2)
pehe = np.sqrt(np.mean((real_ite_test - prediction['ite']) ** 2))
print(f'MSE original KAN: {mse}, PEHE original KAN: {pehe}')
print(f'Predicted ATE: {prediction["ite"].mean()}, Real ATE: {real_ite_test.mean()}')
# MLP params = hidden_dims= [100, 100, 100], dropout=0.0, lr=1e-05, activation=leaky_relu
dims_mlp = get_dims_mlp('dragonnet', x_train.shape[1], hidden_dims=[[200, 200, 200], [100, 100], []])
dragon_nn = mlp_net('dragonnet', dims_mlp, dropout=0.0, activation='elu', try_gpu=False,
                    save_folder='visualization/checkpoint', real_ite_train=real_ite_train, real_ite_test=real_ite_val)
tic = time.time()
results_mlp = dragon_nn.fit(x_train, y_train, t_train, x_val, y_val, t_val, early_stop=True, patience=30, batch=1000,
                            steps=10000, lr=1e-5, verbose=0)
tac = time.time()
print(results_mlp.keys())
print('Training time MLP: ', tac - tic)

prediction_mlp = dragon_nn.predict(x_test, t_test)
prediction_mlp['ite'] = prediction_mlp['y_pred_1'] - prediction_mlp['y_pred_0']
mse_mlp = np.mean((y_test - prediction_mlp['y_pred_f']) ** 2)
pehe_mlp = np.sqrt(np.mean((real_ite_test - prediction_mlp['ite']) ** 2))
print(f'MSE MLP: {mse_mlp}, PEHE MLP: {pehe_mlp}')
print(f'ATE MLP: {prediction_mlp["ite"].mean()}, Real ATE: {real_ite_test.mean()}')
print(results.keys())
# formula without pruning
# symb_s_kaam_orig = symbolic_kan_regressor(x_names=x_cols, y_names=['mu_0', 'mu_1', 'propensity'])
# with plt.rc_context({**bundles.iclr2024()}):
#     tic = time.time()
#     xt_train = np.concatenate((x_train, t_train), axis=1)
#     xt_val = np.concatenate((x_val, t_val), axis=1)
#     # out_0 = symb_s_kaam_orig.fit(dragon_kaam.model, xt_train, y_train, xt_val, y_val, denorm_function=None, stochastic=False, r2_threshold=0.98, show_results=False, save_dir = 'visualization/formulas')
#     out_0 = symb_s_kaam_orig.fit(dragon_kaam.model, x_train, y_train, x_val, y_val, denorm_function=None,
#                                  stochastic=False, r2_threshold=0.98, show_results=False,
#                                  save_dir='visualization/formulas')
#     tac = time.time()
#     print('Formula extraction time (control): ', tac - tic)
#
# formulas_orig = symb_s_kaam_orig.get_formula()
# print(formulas_orig[0])
# print(formulas_orig[0])
# PRUNING AND FORMULA EXTRACTION
# pruning and evaluation
# t_kaam_original = copy.deepcopy(t_kaam)
tic = time.time()
dragon_kaam_pruned = dragon_kaam.prune(node_th=1e-2, edge_th=1e-2)
tac = time.time()
print('Pruning time: ', tac - tic)
#evaluate
predictions_pruned = dragon_kaam.predict(x_test, t_test)
predictions_pruned['ite'] = predictions_pruned['y_pred_1'] - predictions_pruned['y_pred_0']
mse_pruned = np.mean((y_test - predictions_pruned['y_pred_f']) ** 2)
pehe_pruned = np.sqrt(np.mean((real_ite_test - predictions_pruned['ite']) ** 2))
print(f'MSE pruned KAN: {mse_pruned}, PEHE pruned KAN: {pehe_pruned}')
print(f'Predicted ATE: {predictions_pruned["ite"].mean()}, Real ATE: {real_ite_test.mean()}')
# # formula extraction
# symb_s_kaam = symbolic_kan_regressor(x_names = x_cols, y_names=['mu_0', 'mu_1', 'propensity'])
# with plt.rc_context({**bundles.iclr2024()}):
#     tic = time.time()
#     xt_train = np.concatenate((x_train, t_train), axis=1)
#     xt_val = np.concatenate((x_val, t_val), axis=1)
#     # out_0 = symb_s_kaam.fit(dragon_kaam.model, xt_train, y_train, xt_val, y_val, denorm_function=None, stochastic=False, r2_threshold=0.98, show_results=False, save_dir = 'visualization/formulas')
#     out_0 = symb_s_kaam.fit(dragon_kaam.model, x_train, y_train, x_val, y_val, denorm_function=None, stochastic=False, r2_threshold=0.98, show_results=False, save_dir = 'visualization/formulas')
#     tac = time.time()
#     print('Formula extraction time (control): ', tac - tic)
# formulas = symb_s_kaam.get_formula()

Index(['treatment', 'y_factual', 'y_cfactual', 'mu0', 'mu1', 'x1', 'x2', 'x3',
       'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14',
       'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24',
       'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x31', 'x32', 'x33', 'x34',
       'x35', 'x36', 'x37', 'x38', 'x39', 'x40', 'x41', 'x42', 'x43', 'x44',
       'x45', 'x46', 'x47', 'x48', 'x49', 'x50', 'x51', 'x52', 'x53', 'x54',
       'x55', 'x56', 'x57', 'x58'],
      dtype='object')
['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x31', 'x32', 'x33', 'x34', 'x35', 'x36', 'x37', 'x38', 'x39', 'x40', 'x41', 'x42', 'x43', 'x44', 'x45', 'x46', 'x47', 'x48', 'x49', 'x50', 'x51', 'x52', 'x53', 'x54', 'x55', 'x56', 'x57', 'x58']
(800, 58) (800, 1) (800, 1) (200, 58) (200, 1) (200, 1) (960, 58) (960, 1) (960

In [36]:
formulas = dragon_kaam.interprete('all', n_digit=4)

Using library all
fixing (0,0,0) with x, r2=0.06568443775177002, c=1
fixing (0,0,1) with sin, r2=0.9997977018356323, c=2
fixing (0,0,2) with 0
fixing (0,0,3) with 0
fixing (0,0,4) with x^2, r2=0.9998084902763367, c=2
fixing (0,0,5) with 0
fixing (0,0,6) with 0
fixing (0,0,7) with x, r2=0.19280271232128143, c=1
fixing (0,0,8) with tanh, r2=0.9998217225074768, c=3
fixing (0,0,9) with x, r2=0.7966095805168152, c=1
fixing (0,0,10) with 0
fixing (0,0,11) with x^2, r2=0.9751867055892944, c=2
fixing (0,0,12) with x, r2=0.37401631474494934, c=1
fixing (0,0,13) with x, r2=0.07438771426677704, c=1
fixing (0,0,14) with x, r2=0.4498281478881836, c=1
fixing (0,1,0) with 0
fixing (0,1,1) with 0
fixing (0,1,2) with 0
fixing (0,1,3) with 0
fixing (0,1,4) with 0
fixing (0,1,5) with 0
fixing (0,1,6) with x, r2=1.000001072883606, c=1
fixing (0,1,7) with 0
fixing (0,1,8) with 0
fixing (0,1,9) with 0
fixing (0,1,10) with x, r2=1.000001072883606, c=1
fixing (0,1,11) with x, r2=1.0000020265579224, c=1
fixing

In [37]:
y_pred = dragon_kaam.predict(x_test, t_test)
y_pred_0 = y_pred['y_pred_0']
y_pred_1 = y_pred['y_pred_1']
y_pred_f = y_pred['y_pred_f']
ite_pred = y_pred_1 - y_pred_0
mse_kaam_formula = np.mean((y_test - y_pred_f) ** 2)
pehe_kaam_formula = np.sqrt(np.mean((real_ite_test - y_pred_f) ** 2))
print(f'MSE KAN formula: {mse_kaam_formula}, PEHE KAN formula: {pehe_kaam_formula}')
print(f'Predicted ATE formula: {ite_pred.mean()}, Real ATE: {real_ite_test.mean()}')
# evaluate formula performance

# y_pred = symb_s_kaam.predict(x_test)
# y_pred_0 = y_pred[:, 0]
# y_pred_1 = y_pred[:, 1]
# y_pred_f = t_test * y_pred_1 + (1 - t_test) * y_pred_0
# ite_pred = y_pred_1 - y_pred_0
# mse_formula = np.mean((y_test - y_pred_f)**2)
# pehe_formula = np.sqrt(np.mean((real_ite_test - ite_pred)**2))
# print(f'MSE formula: {mse_formula}, PEHE formula: {pehe_formula}')

print(formulas)
# extract the CATE
# in this case the CATE is the term that corresponds to the treatment 't'
# Break into additive terms

# Rebuild the expression
cate_formula = formulas[1] - formulas[0]

x_cols_ = ['x_' + col[1:] for col in x_cols]
test_df = pd.DataFrame(data_test.loc[:, x_cols].to_numpy(), columns=x_cols_)

# ite_pred_formula = eval_expr_on_df(cate_formula, data_test)
# pehe_formula_cate = np.sqrt(np.mean((real_ite_test - ite_pred_formula)**2))
# print(f'PEHE formula cate: {pehe_formula_cate}')
ite_pred_formula = get_formula_values(cate_formula,  test_df.copy())
pehe_formula_cate = np.sqrt(np.mean((real_ite_test - ite_pred_formula) ** 2))
print(f'PEHE formula cate: {pehe_formula_cate}')
print(f'ATE formula cate: {ite_pred_formula.mean()}, Real ATE: {real_ite_test.mean()}')
# formula_0_rounded = sp.sympify(sp.nsimplify(formulas[0], tolerance=1e-2))
# y_pred_0_trunc = eval_expr_on_df(formula_0_rounded, data_test)
# formula_1_rounded = sp.sympify(sp.nsimplify(formulas[0], tolerance=1e-2))
# y_pred_1_trunc = eval_expr_on_df(formula_1_rounded, data_test)
# y_pred_f_trunc = y_pred_0_trunc * (1 - t_test[:,0]) + y_pred_1_trunc * t_test[:,0]
# ite_pred_trunc = y_pred_1_trunc - y_pred_0_trunc
# mse_formula_trunc = np.mean((y_test[:,0] - y_pred_f_trunc)**2)
# pehe_formula_trunc = np.sqrt(np.mean((real_ite_test - ite_pred_trunc)**2))
# print(f'MSE formula truncated: {mse_formula_trunc}, PEHE formula truncated: {pehe_formula_trunc}')


# truncate
formulas_trun = [ex_round(f, 2) for f in formulas]
y_pred_0_trunc = get_formula_values(formulas_trun[0],  test_df.copy())
y_pred_1_trunc = get_formula_values(formulas_trun[1],  test_df.copy())
y_pred_f_trunc = y_pred_0_trunc * (1 - t_test[:, 0]) + y_pred_1_trunc * t_test[:, 0]
ite_pred_trunc = y_pred_1_trunc - y_pred_0_trunc
mse_formula_trunc = np.mean((y_test[:, 0] - y_pred_f_trunc) ** 2)
pehe_formula_trunc = np.sqrt(np.mean((real_ite_test - ite_pred_trunc) ** 2))
print(f'MSE formula truncated: {mse_formula_trunc}, PEHE formula truncated: {pehe_formula_trunc}')
print(f'Predicted ATE formula truncated: {ite_pred_trunc.mean()}, Real ATE: {real_ite_test.mean()}')

# #print pehe with cate formula truncated
# cate_formula_trunc = formulas_1_rounded - formulas_0_rounded
# ite_pred_formula_trunc = eval_expr_on_df(cate_formula_trunc, data_test)
# pehe_formula_cate_trunc = np.sqrt(np.mean((real_ite_test - ite_pred_formula_trunc)**2))
# print(f'PEHE formula cate truncated: {pehe_formula_cate_trunc}')

#print pehe with cate formula truncated
cate_formula_trunc = formulas_trun[1] - formulas_trun[0]
ite_pred_formula_trunc = get_formula_values(cate_formula_trunc, test_df.copy())
pehe_formula_cate_trunc = np.sqrt(np.mean((real_ite_test - ite_pred_formula_trunc) ** 2))
print(f'PEHE formula cate truncated: {pehe_formula_cate_trunc}')
print(f'ATE formula cate truncated: {ite_pred_formula_trunc.mean()}, Real ATE: {real_ite_test.mean()}')
present = {s.name for s in formulas[0].free_symbols}
print(present)
ausent = set(x_cols) - present
print(ausent)
n_feats = len(list(present))

MSE KAN formula: 9.452967830202796, PEHE KAN formula: 5.077772417046369
Predicted ATE formula: 2.080573558807373, Real ATE: 4.229490757544409
[-0.0138*x_5 - 0.0161*x_53 - 0.0007*(-1.5942*x_1 - 8.7925)**2 + 0.0041*(-0.409*x_16 - 7.2934)**2 - 0.217*(-0.099*x_2 + 0.0607*x_26 + 0.102*x_32 - 0.0988*x_38 - 0.1456*x_53 + 0.009*(-1.0377*x_16 - 6.5972)**2 + 0.0016*(-4.4941*x_18 - 7.4151)**2 - 0.001*(-4.9898*x_39 - 8.5834)**2 + 0.0776*(-0.12*x_48 - 9.9959)**2 - 2.0393*exp(0.0789*x_52) + 0.4386*sin(5.0066*x_10 - 9.2191) + 0.6519*sin(0.4076*x_12 + 1.9907) + 0.3954*sin(0.3852*x_13 + 2.0046) + 0.135*sin(0.6656*x_20 + 8.4305) + 0.2448*sin(0.4446*x_27 - 7.39) + 0.5021*sin(0.3758*x_36 + 5.1904) + 0.3711*sin(0.5137*x_4 + 5.1971) + 0.522*sin(0.3226*x_45 - 1.2654) - 0.2976*sin(8.6072*x_46 - 9.0207) - 0.9974*sin(9.8146*x_49 - 5.4146) + 0.5101*sin(4.6011*x_9 - 8.5994) - 6.8547)*(-0.0205*x_21 + 0.0496*x_35 - 0.1237*x_49 + 0.0695*x_6 - 0.0754*x_7 + 0.0054*(-1.206*x_16 - 6.367)**2 + 0.0006*(-7.4*x_18 - 9.9998)

100%|██████████| 960/960 [44:17<00:00,  2.77s/it]


PEHE formula cate: 2.739394304776883
ATE formula cate: 2.0724885129080115, Real ATE: 4.229490757544409


100%|██████████| 960/960 [27:12<00:00,  1.70s/it]
100%|██████████| 960/960 [28:46<00:00,  1.80s/it]


MSE formula truncated: 10.527961262884023, PEHE formula truncated: 3.0660576165487687
Predicted ATE formula truncated: 1.394575236644213, Real ATE: 4.229490757544409


100%|██████████| 960/960 [36:44<00:00,  2.30s/it]

PEHE formula cate truncated: 3.066057616548771
ATE formula cate truncated: 1.3945752366442112, Real ATE: 4.229490757544409
{'x_21', 'x_45', 'x_35', 'x_4', 'x_14', 'x_38', 'x_43', 'x_15', 'x_27', 'x_56', 'x_20', 'x_6', 'x_3', 'x_29', 'x_23', 'x_34', 'x_17', 'x_12', 'x_48', 'x_44', 'x_50', 'x_13', 'x_30', 'x_55', 'x_16', 'x_10', 'x_9', 'x_8', 'x_37', 'x_2', 'x_11', 'x_19', 'x_1', 'x_32', 'x_18', 'x_53', 'x_31', 'x_33', 'x_24', 'x_57', 'x_47', 'x_7', 'x_54', 'x_5', 'x_26', 'x_41', 'x_40', 'x_58', 'x_52', 'x_25', 'x_39', 'x_42', 'x_36', 'x_46', 'x_49', 'x_51', 'x_28'}
{'x58', 'x38', 'x27', 'x34', 'x6', 'x4', 'x7', 'x8', 'x52', 'x56', 'x9', 'x15', 'x41', 'x54', 'x16', 'x29', 'x20', 'x28', 'x55', 'x47', 'x49', 'x35', 'x11', 'x21', 'x1', 'x10', 'x48', 'x5', 'x17', 'x51', 'x23', 'x25', 'x42', 'x2', 'x13', 'x22', 'x44', 'x40', 'x43', 'x32', 'x19', 'x37', 'x26', 'x39', 'x3', 'x18', 'x50', 'x24', 'x33', 'x14', 'x30', 'x36', 'x45', 'x31', 'x57', 'x46', 'x12', 'x53'}



