# DragonNet

__Reference__: [Claudia Shi et al., Adapting Neural Networks for the Estimation of Treatment Effects, NeurIPS 2019](https://arxiv.org/pdf/1906.02120v2.pdf)

__Implementation remarks__: our implementation is the same as that of the original paper, with the exception 
    of a _sklearn.preprocessing.StandardScaler_, which the original used to scale predictions. 
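
To illustrate the kind of scaling step mentioned above, here is a minimal sketch of how outcomes are commonly standardized before training and how predictions are mapped back afterwards. It assumes the scaler is fit on the training outcomes only. The array `y_train` and the stand-in `pred_scaled` are hypothetical and are not taken from either implementation.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical training outcomes (illustrative values only)
y_train = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# StandardScaler expects 2-D input, so reshape to a single column
scaler = StandardScaler()
y_scaled = scaler.fit_transform(y_train.reshape(-1, 1))

# A model would be trained on y_scaled; here the scaled targets stand in
# for its predictions so that the round trip can be checked
pred_scaled = y_scaled.copy()

# Map the predictions back to the original outcome units
pred = scaler.inverse_transform(pred_scaled).ravel()
print(pred)
```

Errors computed on scaled outcomes are in units of standard deviations, so predictions should be mapped back before they are compared with the ground-truth effects.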

## DragonNet on IHDP 

In [1]:
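from causalforge.model import Model, PROBLEM_TYPE
from causalforge.data_loader import DataLoader

# Load the IHDP dataset: covariates (X), treatment (T), factual and
# counterfactual outcomes (YF, YCF) and noiseless potential outcomes
# (mu_0, mu_1), for the train (tr) and test (te) splits
r = DataLoader.get_loader('IHDP').load()
X_tr, T_tr, YF_tr, YCF_tr, mu_0_tr, mu_1_tr, X_te, T_te, YF_te, YCF_te, mu_0_te, mu_1_te = r

# Model: the input dimension is the number of covariates
params = {}
params['input_dim'] = X_tr.shape[1]

dragonnet = Model.create_model("dragonnet",
                               params,
                               problem_type=PROBLEM_TYPE.CAUSAL_TREATMENT_EFFECT_ESTIMATION,
                               multiple_treatments=False)

# Print the architecture of the underlying Keras model
dragonnet.model.summary()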

2023-05-05 16:08:33.088174: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input (InputLayer)             [(None, 25)]         0           []                               
                                                                                                  
 dense (Dense)                  (None, 200)          5200        ['input[0][0]']                  
                                                                                                  
 dense_1 (Dense)                (None, 200)          40200       ['dense[0][0]']                  
                                                                                                  
 dense_2 (Dense)                (None, 200)          40200       ['dense_1[0][0]']                
                                                                                              

 dense_4 (Dense)                (None, 100)          20100       ['dense_2[0][0]']                
                                                                                                  
 dense_5 (Dense)                (None, 100)          20100       ['dense_2[0][0]']                
                                                                                                  
 dense_6 (Dense)                (None, 100)          10100       ['dense_4[0][0]']                
                                                                                                  
 dense_7 (Dense)                (None, 100)          10100       ['dense_5[0][0]']                
                                                                                                  
 dense_3 (Dense)                (None, 1)            201         ['dense_2[0][0]']                
                                                                                                  
 y0_predic
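
The `Param #` column of the summary above can be checked by hand. The sketch below assumes every `Dense` layer carries a bias vector, so a layer with `n_in` inputs and `n_out` units has `n_in * n_out + n_out` parameters. The layer widths are read off the `Output Shape` and `Connected to` columns; the helper `dense_params` is hypothetical and not part of causalforge.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Parameters of a fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

# (layer name, input width, output width) as listed in the summary
layers = [
    ("dense", 25, 200),     # input -> shared representation
    ("dense_1", 200, 200),
    ("dense_2", 200, 200),
    ("dense_3", 200, 1),    # single-unit head on top of dense_2
    ("dense_4", 200, 100),  # first branch on top of dense_2
    ("dense_5", 200, 100),  # second branch on top of dense_2
    ("dense_6", 100, 100),
    ("dense_7", 100, 100),
]

counts = {name: dense_params(n_in, n_out) for name, n_in, n_out in layers}
for name, n in counts.items():
    print(f"{name}: {n}")
```

The three layers connected to `dense_2` are consistent with the paper's design of a shared representation feeding one treatment (propensity) head and two outcome heads.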

In [2]:
from causalforge.metrics import eps_ATE_diff, PEHE_with_ite
import numpy as np

# Indices of the IHDP replications to evaluate
experiment_ids = [1, 10, 400]

eps_ATE_tr, eps_ATE_te = [], []
eps_PEHE_tr, eps_PEHE_te = [], []

for idx in experiment_ids:
    # Select the data of replication idx
    t_tr, y_tr, x_tr, mu0tr, mu1tr = T_tr[:,idx], YF_tr[:,idx], X_tr[:,:,idx], mu_0_tr[:,idx], mu_1_tr[:,idx]
    t_te, y_te, x_te, mu0te, mu1te = T_te[:,idx], YF_te[:,idx], X_te[:,:,idx], mu_0_te[:,idx], mu_1_te[:,idx]

    # Train the model on the training set
    dragonnet.fit(x_tr, t_tr, y_tr)

    # Ground-truth individual treatment effects (ITE), from the noiseless potential outcomes
    ITE_truth_tr = mu1tr - mu0tr
    ITE_truth_te = mu1te - mu0te

    # Predict once per split and reuse the predictions for both metrics
    ITE_pred_tr = dragonnet.predict_ite(x_tr)
    ITE_pred_te = dragonnet.predict_ite(x_te)

    # Error on the average treatment effect (ATE), on the train and test sets
    eps_ATE_tr.append(eps_ATE_diff(ITE_pred_tr, ITE_truth_tr))
    eps_ATE_te.append(eps_ATE_diff(ITE_pred_te, ITE_truth_te))

    # PEHE with sqrt=True, on the train and test sets
    eps_PEHE_tr.append(PEHE_with_ite(ITE_pred_tr, ITE_truth_tr, sqrt=True))
    eps_PEHE_te.append(PEHE_with_ite(ITE_pred_te, ITE_truth_te, sqrt=True))
        

Epoch 1/30




Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 1/100




Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100


Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
1/9 [==>...........................] - ETA: 0s - loss: 151.6492 - regression_loss: 59.9757 - binary_classification_loss: 27.0230 - treatment_accuracy: 0.8125 - track_epsilon: 0.0037
Epoch 42: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-06.
Epoch 43/100
Epoch 44/100


Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
1/9 [==>...........................] - ETA: 0s - loss: 192.3869 - regression_loss: 78.4169 - binary_classification_loss: 30.3007 - treatment_accuracy: 0.8125 - track_epsilon: 0.0048
Epoch 51: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-06.
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
1/9 [==>...........................] - ETA: 0s - loss: 146.3898 - regression_loss: 55.3150 - binary_classification_loss: 30.9340 - treatment_accuracy: 0.7812 - track_epsilon: 0.0024
Epoch 58: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-06.
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
1/9 [==>...........................] - ETA: 0s - loss: 198.6789 - regression_loss: 80.6731 - binary_classification_loss: 32.6109 - treatment_accuracy: 0.7812 - track_epsilon: 0.0027
Epoch 63: ReduceLROnPlateau reducing learning rate to

Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
1/9 [==>...........................] - ETA: 0s - loss: 145.4354 - regression_loss: 57.7047 - binary_classification_loss: 25.1966 - treatment_accuracy: 0.8594 - track_epsilon: 0.0025
Epoch 70: ReduceLROnPlateau reducing learning rate to 3.12499992105586e-07.
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 1/100
Epoch 2/100


Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
1/9 [==>...........................] - ETA: 0s - loss: 149.7285 - regression_loss: 56.0357 - binary_classification_loss: 32.9105 - treatment_accuracy: 0.7656 - track_epsilon: 0.0060
Epoch 14: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-06.
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
1/9 [==>...........................] - ETA: 0s - loss: 149.1883 - regression_loss: 57.4775 - binary_classification_loss: 29.5373 - treatment_accuracy: 0.7969 - track_epsilon: 0.0037
Epoch 20: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-06.
Epoch 21/100
Epoch 22/100
Epoch 23/100


Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
1/9 [==>...........................] - ETA: 0s - loss: 130.1864 - regression_loss: 44.9679 - binary_classification_loss: 35.5117 - treatment_accuracy: 0.7188 - track_epsilon: 0.0030
Epoch 38: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-06.
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100


Epoch 46/100
Epoch 47/100
1/9 [==>...........................] - ETA: 0s - loss: 110.4949 - regression_loss: 38.3722 - binary_classification_loss: 28.8350 - treatment_accuracy: 0.8125 - track_epsilon: 0.0028
Epoch 47: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-07.
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
1/9 [==>...........................] - ETA: 0s - loss: 138.6945 - regression_loss: 53.9106 - binary_classification_loss: 26.0309 - treatment_accuracy: 0.8438 - track_epsilon: 0.0032
Epoch 52: ReduceLROnPlateau reducing learning rate to 3.12499992105586e-07.
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
1/9 [==>...........................] - ETA: 0s - loss: 152.7404 - regression_loss: 58.0996 - binary_classification_loss: 31.7209 - treatment_accuracy: 0.7812 - track_epsilon: 0.0028
Epoch 57: ReduceLROnPlateau reducing learning rate to 1.56249996052793e-07.
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch

Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
1/9 [==>...........................] - ETA: 0s - loss: 152.2102 - regression_loss: 58.5109 - binary_classification_loss: 30.3595 - treatment_accuracy: 0.7969 - track_epsilon: 0.0028
Epoch 68: ReduceLROnPlateau reducing learning rate to 3.906249901319825e-08.
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30


Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100


Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100


1/9 [==>...........................] - ETA: 0s - loss: 167.4697 - regression_loss: 67.0101 - binary_classification_loss: 28.9756 - treatment_accuracy: 0.8281 - track_epsilon: 0.0149
Epoch 34: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-06.
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
1/9 [==>...........................] - ETA: 0s - loss: 131.5742 - regression_loss: 49.9464 - binary_classification_loss: 26.5770 - treatment_accuracy: 0.7812 - track_epsilon: 0.0155
Epoch 43: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-06.
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
1/9 [==>...........................] - ETA: 0s - loss: 160.3855 - regression_loss: 65.7391 - binary_classification_loss: 23.7671 - treatment_accuracy: 0.8750 - track_epsilon: 0.0145
Epoch 50: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-06.
Epoch 51/100
Ep

Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
1/9 [==>...........................] - ETA: 0s - loss: 152.0521 - regression_loss: 58.7837 - binary_classification_loss: 31.3957 - treatment_accuracy: 0.7656 - track_epsilon: 0.0142
Epoch 61: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-07.
Epoch 62/100
Epoch 63/100
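
The learning rates printed by `ReduceLROnPlateau` in the log above halve at each reduction. The sketch below assumes an initial rate of `1e-5` stored as a 32-bit float and a reduction factor of `0.5`. Both values are inferred from the logged numbers, not read from the causalforge source. Under those assumptions the odd trailing digits in the log are float32 rounding of `1e-5`, and the sketch should reproduce the first logged rates.

```python
import numpy as np

lr = np.float32(1e-5)      # assumed initial learning rate (float32)
factor = np.float32(0.5)   # assumed reduction factor

schedule = []
for _ in range(4):
    # Each plateau multiplies the current rate by the factor
    lr = np.float32(lr * factor)
    schedule.append(float(lr))

print(schedule)
```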


## Results 

In [3]:
import pandas as pd 

pd.DataFrame([[np.mean(eps_ATE_tr),np.mean(eps_ATE_te),np.mean(eps_PEHE_tr),np.mean(eps_PEHE_te)]],
             columns=['eps_ATE_tr','eps_ATE_te','eps_PEHE_tr','eps_PEHE_te'], 
             index=['DragonNet'])

| | eps_ATE_tr | eps_ATE_te | eps_PEHE_tr | eps_PEHE_te |
|---|---|---|---|---|
| DragonNet | 0.091358 | 0.080693 | 0.654876 | 0.649264 |
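
For readers unfamiliar with the two metrics, the following is a minimal NumPy sketch of their usual definitions: the ATE error as the absolute difference between the mean predicted and mean true ITE, and the root PEHE as the root mean squared error between predicted and true ITE. The functions `eps_ate` and `root_pehe` are hypothetical stand-ins. They are assumed, not confirmed, to match what `eps_ATE_diff` and `PEHE_with_ite(..., sqrt=True)` compute, and the toy arrays are illustrative only.

```python
import numpy as np

def eps_ate(ite_pred, ite_true):
    """Absolute error on the average treatment effect."""
    return float(np.abs(np.mean(ite_pred) - np.mean(ite_true)))

def root_pehe(ite_pred, ite_true):
    """Root of the mean squared error between predicted and true ITE."""
    diff = np.asarray(ite_pred) - np.asarray(ite_true)
    return float(np.sqrt(np.mean(diff ** 2)))

# Toy effects: individual errors of +0.5 and -0.5 cancel in the average
ite_true = np.array([1.0, 2.0, 3.0, 4.0])
ite_pred = np.array([1.5, 1.5, 3.5, 3.5])

print(eps_ate(ite_pred, ite_true))    # 0.0
print(root_pehe(ite_pred, ite_true))  # 0.5
```

The toy example shows why both metrics are reported: individual errors can cancel in the ATE error while still showing up in the PEHE.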
