In [37]:
# import packages
import os
from pathlib import Path

import numpy as np
from numpy import random
from scipy import stats
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split


In [2]:
# Load the example dataset (first CSV column is used as the row index, no header row).
# Point the DATA_DIR environment variable at the folder that holds the CSV;
# it defaults to the current working directory so no machine-specific path is baked in.
DATA_PATH = Path(os.environ.get("DATA_DIR", ".")) / "dataset_example_100.csv"
data_raw = pd.read_csv(DATA_PATH, header=None, index_col=[0])
data = data_raw.to_numpy()
data

array([[ 0.718867,  0.648899,  0.724323, ...,  0.195425,  0.179164,
         0.      ],
       [ 0.118287,  0.265923,  0.685177, ...,  0.604447,  0.117696,
         1.      ],
       [ 4.359214,  0.005596,  0.94837 , ...,  0.647872,  0.101803,
         1.      ],
       ...,
       [ 0.972673,  0.022176,  0.441456, ...,  0.452264,  0.887922,
         0.      ],
       [-0.937569,  0.302652,  0.525748, ...,  0.838342,  0.133374,
         0.      ],
       [ 0.861966,  0.721366,  0.108411, ...,  0.181184,  0.202419,
         0.      ]])

In [25]:
# Train/test split.
# Shuffle rows with a fixed seed so the split (and every estimate downstream)
# is reproducible under Restart & Run All. `data` itself is not mutated, so
# re-running this cell always yields the same split.
SEED = 0
N_TRAIN = 70  # number of rows used for training; the remainder is the test set
training, test = train_test_split(data, train_size=N_TRAIN, random_state=SEED)

array([[ 0.096022,  0.722591,  0.605033, ...,  0.994083,  0.540669,
         0.      ],
       [ 2.153544,  0.435643,  0.269808, ...,  0.062082,  0.461419,
         1.      ],
       [ 1.059121,  0.861333,  0.593707, ...,  0.284559,  0.14764 ,
         0.      ],
       ...,
       [ 0.537154,  0.963096,  0.31989 , ...,  0.017404,  0.421527,
         1.      ],
       [-0.515086,  0.402986,  0.752111, ...,  0.359433,  0.498578,
         1.      ],
       [ 0.199952,  0.403424,  0.998301, ...,  0.123431,  0.538963,
         1.      ]])

In [34]:
# slice dataset by treatment status
# Column layout used throughout this notebook:
#   column 0      -> outcome Y
#   columns 1..25 -> covariates X
#   column 26     -> treatment indicator W (0 = control, 1 = treated)
Y_COL = 0
X_COLS = slice(1, 26)
W_COL = 26

# Rows with a treatment value other than 0/1 would silently fall out of both
# arms below, so fail loudly instead.
assert np.isin(training[:, W_COL], (0, 1)).all(), "treatment column must be binary (0/1)"

training_control = training[training[:, W_COL] == 0]
training_treatment = training[training[:, W_COL] == 1]

Y_train_control = training_control[:, Y_COL]
Y_train_treatment = training_treatment[:, Y_COL]

X_train_control = training_control[:, X_COLS]
X_train_treatment = training_treatment[:, X_COLS]

X_test = test[:, X_COLS]
Y_test = test[:, Y_COL]
# Test-set treatment assignments, kept for later evaluation of the learners.
W_test = test[:, W_COL]

In [27]:
# Inspect the control-arm covariates (training rows with treatment == 0, columns 1..25).
X_train_control

array([[0.722591, 0.605033, 0.148307, 0.294102, 0.99058 , 0.988063,
        0.559625, 0.347653, 0.590329, 0.640423, 0.843867, 0.294543,
        0.664212, 0.72776 , 0.957712, 0.424071, 0.185662, 0.080284,
        0.845604, 0.43544 , 0.392399, 0.056314, 0.264511, 0.994083,
        0.540669],
       [0.861333, 0.593707, 0.808938, 0.723056, 0.053708, 0.70119 ,
        0.353835, 0.488408, 0.235718, 0.418138, 0.332183, 0.33083 ,
        0.953591, 0.885608, 0.498231, 0.383914, 0.787306, 0.158138,
        0.335693, 0.806131, 0.483416, 0.137373, 0.674744, 0.284559,
        0.14764 ],
       [0.145459, 0.413449, 0.712938, 0.352809, 0.885702, 0.981943,
        0.057673, 0.347297, 0.905632, 0.090792, 0.968038, 0.588586,
        0.677071, 0.720137, 0.648047, 0.205447, 0.588443, 0.525901,
        0.392362, 0.04025 , 0.639459, 0.993659, 0.938   , 0.194606,
        0.667146],
       [0.721366, 0.108411, 0.142024, 0.954121, 0.835531, 0.47677 ,
        0.149934, 0.727651, 0.261154, 0.951585, 0.424136, 0

In [28]:
# Inspect the treated-arm covariates (training rows with treatment == 1, columns 1..25).
X_train_treatment

array([[0.435643, 0.269808, 0.74557 , 0.630237, 0.501181, 0.812563,
        0.104974, 0.221112, 0.771683, 0.881971, 0.726509, 0.762375,
        0.683061, 0.351708, 0.87476 , 0.222745, 0.646945, 0.295711,
        0.678499, 0.045482, 0.380957, 0.352887, 0.379375, 0.062082,
        0.461419],
       [0.265923, 0.685177, 0.319284, 0.998351, 0.568025, 0.481667,
        0.866343, 0.395225, 0.74766 , 0.195265, 0.623567, 0.848222,
        0.690514, 0.719151, 0.245627, 0.881458, 0.592953, 0.170541,
        0.022372, 0.033906, 0.983374, 0.212436, 0.787805, 0.604447,
        0.117696],
       [0.260888, 0.207341, 0.985302, 0.812787, 0.64505 , 0.570579,
        0.907508, 0.822555, 0.157392, 0.840672, 0.58867 , 0.818344,
        0.92251 , 0.558281, 0.534803, 0.120838, 0.31206 , 0.06278 ,
        0.145352, 0.818076, 0.312634, 0.447392, 0.817622, 0.269084,
        0.76826 ],
       [0.390691, 0.814292, 0.632031, 0.809357, 0.333434, 0.472062,
        0.084298, 0.032439, 0.181092, 0.933073, 0.420252, 0

In [29]:
# Inspect the control-arm outcomes (column 0 of the control training rows).
Y_train_control

array([ 0.096022,  1.059121,  1.443843,  0.861966,  1.411428, -0.105323,
       -0.792367,  0.399431, -1.588222,  1.26759 ,  0.162921,  0.685683,
       -1.692155,  2.660242, -1.988907, -0.64844 , -0.225472, -0.076656,
        0.718867,  2.339844, -0.110206,  0.021245,  1.199751, -1.091402,
        1.676444, -0.997404, -1.391332,  0.436795,  2.207135,  2.078643,
        0.823328,  1.488837,  1.021495,  0.972673,  1.094586,  0.343313,
        0.651447])

In [30]:
# Inspect the treated-arm outcomes (column 0 of the treated training rows).
Y_train_treatment

array([ 2.153544,  0.118287, -0.063521, -1.719049,  2.731482,  0.020138,
       -1.131438, -0.718712, -1.354425,  4.008102,  4.359214,  0.759627,
       -0.194273,  2.295531, -0.899656, -1.659883, -0.887771,  1.494098,
        1.978757, -0.342572,  2.056347, -0.198546,  1.066908,  0.649351,
       -1.827348, -1.050771,  3.651505,  2.179798,  0.506924, -0.058619,
        0.537154, -0.515086,  0.199952])

In [32]:
# Inspect the test covariates (columns 1..25; the treatment column is not included).
X_test

array([[3.02652e-01, 5.25748e-01, 5.37146e-01, 1.91750e-01, 4.51150e-02,
        3.82480e-02, 4.76103e-01, 7.98595e-01, 8.23879e-01, 5.81834e-01,
        7.30635e-01, 7.67481e-01, 7.18300e-01, 2.52827e-01, 7.35375e-01,
        8.77869e-01, 2.62073e-01, 7.36757e-01, 9.80805e-01, 2.42059e-01,
        1.31260e-02, 5.55227e-01, 1.87711e-01, 8.38342e-01, 1.33374e-01],
       [8.17221e-01, 6.50119e-01, 5.84933e-01, 6.41077e-01, 1.99870e-02,
        7.55700e-03, 3.93681e-01, 5.14840e-02, 6.69315e-01, 3.48739e-01,
        8.73741e-01, 5.62906e-01, 2.81863e-01, 9.23037e-01, 9.47241e-01,
        7.39086e-01, 4.27607e-01, 7.07159e-01, 3.72451e-01, 5.74878e-01,
        8.15337e-01, 4.09357e-01, 9.25086e-01, 1.53485e-01, 9.82997e-01],
       [2.95626e-01, 4.54680e-02, 2.52849e-01, 2.96321e-01, 5.41435e-01,
        1.12139e-01, 6.40980e-02, 7.87582e-01, 3.79925e-01, 7.71478e-01,
        8.25936e-01, 8.16994e-01, 3.56166e-01, 6.96463e-01, 5.87616e-01,
        4.81193e-01, 5.67686e-01, 6.83591e-01, 7.

In [35]:
# Inspect the test outcomes (column 0 of the test rows).
Y_test

array([-0.937569,  0.39078 ,  0.610402,  1.653261, -0.342207,  0.292602,
       -0.82958 ,  3.55141 , -2.901602,  2.354513, -1.56435 , -1.16988 ,
       -0.892039,  0.422965, -0.150666, -0.520459, -2.302113,  0.709951,
       -0.624427,  2.776141, -0.842584, -1.638599,  0.218614,  1.736505,
        0.246882,  0.216607,  0.302452,  0.563452,  0.633437, -0.572658])

# T-Learner

In [96]:
# T-Learner (example with Random Forest)
# Fit one outcome model per treatment arm, then estimate the CATE on the test
# covariates as tau_hat(x) = mu_1_hat(x) - mu_0_hat(x).

RF_MAX_DEPTH = 100  # far deeper than the ~35-row arms can use, i.e. effectively unlimited


def fit_arm_model(X, y, random_state=0):
    """Fit a random-forest outcome model on the rows of a single treatment arm.

    Parameters
    ----------
    X : array-like
        Covariates of the arm's training rows.
    y : array-like
        Outcomes of the arm's training rows.
    random_state : int, default 0
        Seed forwarded to RandomForestRegressor so results are reproducible.

    Returns
    -------
    RandomForestRegressor
        The fitted model.
    """
    model = RandomForestRegressor(max_depth=RF_MAX_DEPTH, random_state=random_state)
    return model.fit(X, y)


# mu_0: outcome model for the control arm
t_learner_mu0 = fit_arm_model(X_train_control, Y_train_control)
mu_0_hat = t_learner_mu0.predict(X_test)
print(mu_0_hat)

# mu_1: outcome model for the treated arm
t_learner_mu1 = fit_arm_model(X_train_treatment, Y_train_treatment)
mu_1_hat = t_learner_mu1.predict(X_test)
print(mu_1_hat)

# Prediction = mu_1 - mu_0
tau_hat = mu_1_hat - mu_0_hat
print(tau_hat)

[-0.0249546   0.72419391  0.85978762  0.35031942  0.10363848  0.47603894
  0.70959731 -0.01555471  0.23261559  0.35254889  0.5349197   0.25480334
  0.38294181  0.54537079  0.24786227  0.78324506  0.25738765  0.71397781
  0.51648798  0.6916348   0.26005858 -0.19097463  0.19875135  0.55794635
  0.41444289  0.67843834  0.52330465  0.50363954  0.4532404   0.89167669]
[ 1.33669887  0.77153334  0.59200429  0.42061397  1.03830598  0.403372
  0.5536999   1.41196115  1.65462477  0.2502382   0.22707517  0.44730257
  1.04115177  0.82383456 -0.03616677  0.56881277  0.4270187   0.70489063
  0.69681991  0.40105748  1.120523    1.7622537   1.44370243  1.59588783
  1.08073792  0.27019903  0.95505191  0.6806814  -0.01834505  0.41374704]
[ 1.36165347  0.04733943 -0.26778333  0.07029455  0.9346675  -0.07266694
 -0.15589741  1.42751586  1.42200918 -0.10231069 -0.30784453  0.19249923
  0.65820996  0.27846377 -0.28402904 -0.21443229  0.16963105 -0.00908718
  0.18033193 -0.29057732  0.86046442  1.95322833  1

**TODO:** Create an example CSV dataset that also contains the true CATE values ($\tau$), so that each learner's estimates can be evaluated against ground truth.

In [None]:
# S-Learner

In [None]:
# X-Learner

In [None]:
# R-Learner

In [None]:
# DR-Learner

In [None]:
# RA-Learner

In [None]:
# PW-Learner

In [None]:
# F-Learner

In [None]:
# U-Learner