In [1]:
import datetime
import numpy as np
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow_probability as tfp
import tensorflow.keras.layers as tfkl
tfd,tfpl = tfp.distributions,tfp.layers
import tensorflow.keras.backend as tfkb
from tensorflow.keras.callbacks import Callback
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.optimizers import SGD
from evaluation import *
from networks import *
################################################
import argparse
def _build_parser():
    """Assemble the hyperparameter parser used throughout the notebook.

    Options are registered from a single table so that flag, type, default
    and help text for every setting sit side by side.
    """
    cli = argparse.ArgumentParser(description='')
    # (flag, type, default, help) -- kept in the original registration order.
    option_table = (
        ('--scale_penalize',  float, 0.001, ''),
        ('--learning_rate',   float, 0.001, ''),
        ('--default_y_scale', float, 1.,    ''),
        ('--t_dim',           int,   1,     ''),
        ('--y_dim',           int,   1,     ''),
        ('--x_dim',           int,   25,    ''),
        ('--z_dim',           int,   20,    ''),
        ('--x_num_dim',       int,   6,     ''),
        ('--x_bin_dim',       int,   19,    ''),
        ('--nh',              int,   3,     'number of hidden layers'),
        ('--h',               int,   200,   'number of hidden units'),
    )
    for flag, kind, default, text in option_table:
        cli.add_argument(flag, type=kind, default=default, help=text)
    return cli


parser = _build_parser()
# Parse an explicit empty argv so every option takes its default value
# regardless of how the kernel process itself was launched.
args = parser.parse_args([])
################################################
# Fetch the IHDP train/test .npz archives read by the loader below.
# `-nc` (no-clobber) skips the download when the file already exists locally.
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.train.npz
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.test.npz 

def load_IHDP_data(training_data,testing_data,i):
    """Load replication `i` of the IHDP archives and merge train and test rows.

    Args:
        training_data: path to the training `.npz` archive.
        testing_data: path to the testing `.npz` archive.
        i: index selected along the last axis of every stored array.

    Returns:
        dict with float32 arrays
            'x'          covariates (train rows followed by test rows),
            't', 'y'     treatment and factual outcome, reshaped to (-1, 1),
            'mu_0','mu_1' the archives' 'mu0' / 'mu1' fields (1-D),
            'ycf'        the archives' 'ycf' field, reshaped to (-1, 1),
        plus
            'y_scaler'   a StandardScaler fitted on 'y',
            'ys'         'y' transformed by that scaler (standardised, not
                         squeezed into [0, 1]).
    """
    with open(training_data, 'rb') as train_fh, open(testing_data, 'rb') as test_fh:
        # Arrays must be read while both file handles are still open.
        splits = (np.load(train_fh), np.load(test_fh))

        def merged(field):
            # Train rows first, then test rows; float32 keeps everything GPU friendly.
            return np.concatenate([part[field][:, i] for part in splits]).astype('float32')

        covariates = np.concatenate(
            [part['x'][:, :, i] for part in splits], axis=0
        ).astype('float32')

        data = {
            'x': covariates,
            # 1-D vectors are padded with a trailing unit dimension.
            't': merged('t').reshape(-1, 1),
            'y': merged('yf').reshape(-1, 1),
            'mu_0': merged('mu0'),
            'mu_1': merged('mu1'),
            'ycf': merged('ycf').reshape(-1, 1),
        }
        # A standardised outcome often makes training of DL regressors easier.
        data['y_scaler'] = StandardScaler().fit(data['y'])
        data['ys'] = data['y_scaler'].transform(data['y'])
    return data

# Which replication (last-axis index) of the IHDP archives to load.
ind = 7
# Row-wise repetition factor applied to every array; 1 leaves the row count unchanged.
rep = 1
data = load_IHDP_data(
    training_data='./ihdp_npci_1-100.train.npz',
    testing_data='./ihdp_npci_1-100.test.npz',
    i=ind,
)
for key in data:
    # The fitted scaler is an object, not an array, so it is left untouched.
    if key == 'y_scaler':
        continue
    data[key] = np.repeat(data[key], repeats=rep, axis=0)
# Show the fitted outcome mean and scale as the cell's output.
data['y_scaler'].mean_, data['y_scaler'].scale_

In [3]:
class CEVAE(tf.keras.Model):
    """Causal-effect VAE wired together from inference and generative networks.

    The sub-networks (`q_*` for inference, `p_*` for generation) are built
    from the sizes in the module-level `args`. Judging by how their outputs
    are used here, each network returns objects exposing `.sample()` and
    `.log_prob()`; `q_y_tx` returns a pair (control arm, treated arm) and
    `p_x_z` returns a pair (numeric part, binary part).
    NOTE(review): the network constructors are not defined in this file
    (presumably they come from `networks`) -- confirm their return types.
    """

    def __init__(self):
        super(CEVAE, self).__init__()
        ########################################
        # networks
        self.activation = 'elu'
        # CEVAE Model
        ## (encoder)
        self.q_y_tx = q_y_tx(args.x_bin_dim, args.x_num_dim, args.y_dim, args.t_dim, args.nh, args.h)
        self.q_t_x = q_t_x(args.x_bin_dim, args.x_num_dim, args.t_dim, args.nh, args.h)
        self.q_z_txy = q_z_txy(args.x_bin_dim, args.x_num_dim, args.y_dim, args.t_dim, args.z_dim, args.nh, args.h)
        ## (decoder)
        self.p_x_z = p_x_z(args.x_bin_dim, args.x_num_dim, args.z_dim, args.nh, args.h)
        self.p_t_z = p_t_z(args.t_dim, args.z_dim, args.nh, args.h)
        self.p_y_tz = p_y_tz(args.y_dim, args.t_dim, args.z_dim, args.nh, args.h)

    def _encode(self, x):
        """Run the inference networks on covariates `x`.

        Returns:
            (y_infer, t_infer, z_infer, z_sample) where `y_infer` is the
            (control, treated) pair from `q_y_tx`, `t_infer` comes from
            `q_t_x`, `z_infer` from `q_z_txy`, and `z_sample` is one draw
            from `z_infer`.
        """
        t_infer = self.q_t_x(x)
        t_sample = tf.cast(t_infer.sample(), tf.float32)

        y_infer = self.q_y_tx(x)
        y0_infer, y1_infer = y_infer
        # Blend the two arms' samples using the sampled treatment as the selector.
        y_sample = y0_infer.sample() * (1 - t_sample) + y1_infer.sample() * t_sample

        z_infer = self.q_z_txy(tf.concat([t_sample, y_sample, x], -1))
        return y_infer, t_infer, z_infer, z_infer.sample()

    def call(self, data, training=False):
        """Forward pass.

        Args:
            data: `(x, t)` when `training` is true, otherwise just `x`.
            training: selects which outputs are produced.

        Returns:
            training: `(y_infer, t_infer, z_infer, y, t, x_num, x_bin)`,
                i.e. the inference outputs followed by the generative outputs
                computed from a sample of `z_infer`.
            otherwise: `(y_infer, t_infer, z_infer)`.
        """
        if training:
            x_train, t_train = data
            y_infer, t_infer, z_infer, z_sample = self._encode(x_train)
            # decoder
            ## p(x|z)
            x_num, x_bin = self.p_x_z(z_sample)
            ## p(t|z)
            t = self.p_t_z(z_sample)
            ## p(y|t,z) -- conditioned on the observed treatment, not the sampled one
            y = self.p_y_tz(tf.concat([t_train, z_sample], -1))
            return y_infer, t_infer, z_infer, y, t, x_num, x_bin

        y_infer, t_infer, z_infer, _ = self._encode(data)
        return y_infer, t_infer, z_infer

    def cevae_loss(self, data, pred, training = False):
        """Negative batch mean of reconstruction, auxiliary and prior terms.

        Args:
            data: indexable as `(x, t, y)`; the first `args.x_num_dim` columns
                of `x` are scored against `x_num`, the remaining ones against
                `x_bin`.
            pred: the 7-tuple returned by `call(..., training=True)`.
            training: accepted for interface compatibility; not used.

        Returns:
            Scalar loss tensor.
        """
        x_train, t_train, y_train = data[0],data[1],data[2]
        x_train_num, x_train_bin = x_train[:,:args.x_num_dim],x_train[:,args.x_num_dim:]
        y_infer,t_infer,z_infer,y,t,x_num,x_bin = pred
        y0,y1 = y_infer
        # reconstruct loss
        recon_x_num = tfkb.sum(x_num.log_prob(x_train_num), 1)
        recon_x_bin = tfkb.sum(x_bin.log_prob(x_train_bin), 1)
        recon_y = tfkb.sum(y.log_prob(y_train), 1)
        recon_t = tfkb.sum(t.log_prob(t_train), 1)
        # kl loss: single-sample estimate of log p(z) - log q(z|.), using a
        # fresh draw from z_infer. The standard-normal prior follows the
        # configured latent size instead of a hard-coded 20.
        z_infer_sample = z_infer.sample()
        z = tfd.Normal(loc = tf.zeros(args.z_dim), scale = tf.ones(args.z_dim))
        kl_z = tfkb.sum((z.log_prob(z_infer_sample) - z_infer.log_prob(z_infer_sample)), -1)
        # aux loss: score the inference heads against the observed y and t
        aux_y = tfkb.sum(y0.log_prob(y_train)*(1-t_train) + y1.log_prob(y_train)* t_train, 1)
        aux_t = tfkb.sum(t_infer.log_prob(t_train), 1)
        loss = -tfkb.mean(recon_x_bin + recon_x_num + recon_y + recon_t + aux_y + aux_t + kl_z)
        return loss

    def train_step(self, data):
        """One optimisation step; `data[0]` must unpack to `(x, t, y)`."""
        data = data[0]
        x,t,y = data
        with tf.GradientTape() as tape:
            pred = self([x,t], training=True)  # Forward pass
            loss = self.cevae_loss(data = data, pred = pred, training = True)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        metrics = {"loss": loss}
        return metrics

    def test_step(self, data):
        """Evaluation step; `data[0]` must unpack to `(x, t, y)`.

        Returns the loss together with the batch means of one sample from
        each outcome arm and their difference.
        """
        data = data[0]
        x,t,y = data
        # training=True is required here: only that branch of call() returns
        # the generative outputs that cevae_loss needs. No gradient tape is
        # opened because nothing is differentiated during evaluation.
        pred = self([x,t], training=True)
        y_infer = pred[0]
        loss = self.cevae_loss(data = data, pred = pred, training = False)
        y0, y1 = y_infer[0].sample(),y_infer[1].sample()
        ate = tfkb.mean(y1) - tfkb.mean(y0)
        # Metric keys are kept exactly as they were so existing logs and
        # callbacks that look them up keep working.
        metrics = {"loss":loss,"y0": tfkb.mean(y0),"y1": tfkb.mean(y1),'ate_afte_scaled': ate}
        return metrics


In [4]:
#Colab command to allow us to run Colab in TF2
!rm -rf ./logs/ 
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
%reload_ext tensorboard 

model = CEVAE()
### MAIN CODE ####
val_split=0.2
batch_size=64
verbose=True
i = 0
tf.random.set_seed(i)
np.random.seed(i)
 
callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='loss', patience=40, min_delta=0), 
        #40 is Shi's recommendation patience for this dataset, but you should tune for your data 
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                          min_delta=0, cooldown=0, min_lr=0),
        #This learning rate scheduling is quite agressive which seems good for this dataset
        metrics_for_cevae(data,verbose),
        tensorboard_callback
    ]
    
#optimizer hyperparameters
learning_rate = 5e-5
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate = learning_rate))

model.fit(
    [data['x'],data['t'],data['ys']],
    callbacks=callbacks,
    validation_split=val_split,
    epochs=300,
    batch_size=200,
    verbose=verbose
    )
print("Done!")

2022-03-08 22:46:47.769632: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Epoch 1/140
Epoch 2/140
Epoch 3/140
 — ite: 4.7613  — ate: 3.3621 — pehe: 5.3596 
Epoch 4/140
Epoch 5/140
Epoch 6/140
Epoch 7/140
Epoch 8/140
Epoch 9/140
Epoch 10/140
Epoch 11/140
Epoch 12/140
Epoch 13/140
Epoch 14/140
Epoch 15/140
Epoch 16/140
Epoch 17/140
Epoch 18/140
Epoch 19/140
Epoch 20/140
Epoch 21/140
Epoch 22/140
Epoch 23/140
Epoch 24/140
Epoch 25/140
Epoch 26/140
Epoch 27/140
Epoch 28/140
Epoch 29/140
Epoch 30/140
Epoch 31/140
Epoch 32/140
Epoch 33/140
Epoch 34/140
Epoch 35/140
Epoch 36/140
Epoch 37/140
Epoch 38/140
Epoch 39/140
Epoch 40/140
Epoch 41/140
Done!


In [5]:
# Open TensorBoard inline, pointed at the logs/fit directory.
%tensorboard --logdir logs/fit

Reusing TensorBoard on port 6006 (pid 11635), started 1 day, 6:38:38 ago. (Use '!kill 11635' to kill it.)