In [None]:
import datetime
import numpy as np
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow_probability as tfp
import tensorflow.keras.layers as tfkl
tfd,tfpl = tfp.distributions,tfp.layers
import tensorflow.keras.backend as tfkb
from tensorflow.keras.callbacks import Callback
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.optimizers import SGD
from evaluation import *
from cevae_networks import *
################################################
import argparse
def _build_arg_parser():
    """Assemble the hyper-parameter parser from a declarative option table."""
    # One row per option: (flag, type, default, help). Row order is the
    # declaration order, which is also the order shown by ``--help``.
    option_table = (
        ('--scale_penalize',  float, 0.001, ''),
        ('--learning_rate',   float, 0.001, ''),
        ('--default_y_scale', float, 1.,    ''),
        ('--t_dim',           int,   1,     ''),
        ('--y_dim',           int,   1,     ''),
        ('--x_dim',           int,   25,    ''),
        ('--z_dim',           int,   20,    ''),
        ('--x_num_dim',       int,   6,     ''),
        ('--x_bin_dim',       int,   19,    ''),
        ('--val_split',       float, 0.2,   ''),
        ('--batch_size',      int,   256,   ''),
        ('--nh',              int,   3,     'number of hidden layers'),
        ('--h',               int,   200,   'number of hidden units'),
    )
    built = argparse.ArgumentParser(description='')
    for flag, value_type, default_value, help_text in option_table:
        built.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return built


parser = _build_arg_parser()
# An explicit empty argv is parsed, so every option takes its default and the
# process' own command line is never consulted.
args = parser.parse_args([])
################################################
# Fetch the two IHDP .npz archives read by load_IHDP_data below.
# `-nc` (no-clobber) makes wget skip a file that is already present locally,
# so re-running this cell does not download it again.
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.train.npz
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.test.npz 

def load_IHDP_data(training_data,testing_data,i):
    """Load one replication of the IHDP benchmark from two ``.npz`` archives.

    For every field the rows of ``training_data`` come first, followed by the
    rows of ``testing_data``; all arrays are cast to float32.

    Args:
        training_data: path of the first ``.npz`` archive. It must contain the
            keys ``x``, ``t``, ``yf``, ``ycf``, ``mu0`` and ``mu1``.
        testing_data: path of the second archive, with the same keys.
        i: index selected on the last axis of every stored array.

    Returns:
        dict with the entries
            'x'        2-D array, one row per sample
            't'        column vector (n, 1)
            'y'        column vector (n, 1), taken from ``yf``
            'ycf'      column vector (n, 1)
            'mu_0'     1-D array, taken from ``mu0``
            'mu_1'     1-D array, taken from ``mu1``
            'y_scaler' ``StandardScaler`` fitted on 'y'
            'ys'       'y' transformed by 'y_scaler'
    """
    with open(training_data,'rb') as trf, open(testing_data,'rb') as tef:
        train_data=np.load(trf); test_data=np.load(tef)

        def stacked(key, index):
            # .npz members are read lazily, so this must run while both
            # files are still open.
            return np.concatenate(
                (train_data[key][index], test_data[key][index]), axis=0
            ).astype('float32')  # float32 keeps the data in the dtype used by the networks

        per_sample = (slice(None), i)            # arr[:, i]
        per_feature = (slice(None), slice(None), i)  # arr[:, :, i]

        data = {
            'x': stacked('x', per_feature),
            # t / y / ycf are stored as 1-D vectors; reshape them into columns.
            't': stacked('t', per_sample).reshape(-1,1),
            'y': stacked('yf', per_sample).reshape(-1,1),
            'mu_0': stacked('mu0', per_sample),
            'mu_1': stacked('mu1', per_sample),
            'ycf': stacked('ycf', per_sample).reshape(-1,1),
        }
        # Standardising y (zero mean, unit variance) often makes training of
        # DL regressors easier; the fitted scaler is kept so predictions can
        # be mapped back to the original scale.
        data['y_scaler'] = StandardScaler().fit(data['y'])
        data['ys'] = data['y_scaler'].transform(data['y'])
    return data

# Index of the IHDP replication to load (forwarded as `i` to load_IHDP_data).
ind = 7
# Earlier experiment, kept for reference: load train+test together and repeat
# every row `rep` times.
# rep = 5
# rep = 1
# data = load_IHDP_data(training_data='./ihdp_npci_1-100.train.npz',testing_data='./ihdp_npci_1-100.test.npz',i = ind)
# for key in data:
#     if key != 'y_scaler':
#         data[key] = np.repeat(data[key],repeats = rep, axis = 0)
# np.shape(data['x'])
# NOTE(review): load_IHDP_data concatenates its two input files. Passing the
# same file for both arguments therefore yields every row twice in each split
# -- confirm this duplication is intended.
# NOTE(review): each call fits its own StandardScaler, so data_valid['ys'] is
# scaled with the validation split's statistics, not the training split's.
data_train = load_IHDP_data(training_data='./ihdp_npci_1-100.train.npz',testing_data='./ihdp_npci_1-100.train.npz',i = ind)
data_valid = load_IHDP_data(training_data='./ihdp_npci_1-100.test.npz',testing_data='./ihdp_npci_1-100.test.npz',i = ind)
# Cell output: shape of the training covariate matrix.
np.shape(data_train['x'])

In [None]:
class EpsilonLayer(tfkl.Layer):
    """Layer owning one trainable scalar that is emitted once per input row."""

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        """Create the single trainable weight ``epsilon`` (stored as a 1x1 matrix)."""
        self.epsilon = self.add_weight(
            name='epsilon',
            shape=[1, 1],
            initializer='RandomNormal',
            trainable=True,
        )
        # The parent build() is invoked last, after the weight exists.
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        """Return ``epsilon`` repeated down a column with one entry per row of ``inputs``."""
        # A column of ones with the batch length of `inputs`; multiplying by
        # the 1x1 weight broadcasts the same epsilon to every row.
        ones_column = tf.ones_like(inputs)[:, 0:1]
        return self.epsilon * ones_column

class CEWAE(tf.keras.Model):
    """Causal-effect autoencoder trained with reconstruction terms plus an MMD penalty.

    The model wires together three inference networks (``q_t_x``, ``q_y_tx``,
    ``q_z_txy``) and three generative networks (``p_x_z``, ``p_t_z``,
    ``p_y_tz``). Their constructors are not defined in this file (presumably
    they come from the ``cevae_networks`` star import -- confirm); all sizes
    are read from the module-level ``args`` namespace.

    Args:
        kernel: kernel used by :meth:`mmd_penalty`, either ``"IMQ"`` or ``"RBF"``.
    """
    def __init__(self, kernel = "IMQ"):
        super(CEWAE, self).__init__()
        ########################################
        # networks
        self.activation = 'elu'   # stored only; not read inside this class
        self.kernel = kernel
        # CEVAE Model
        ## (encoder)
        self.q_y_tx = q_y_tx(args.x_bin_dim, args.x_num_dim, args.y_dim, args.t_dim, args.nh, args.h)
        self.q_t_x = q_t_x(args.x_bin_dim, args.x_num_dim, args.t_dim, args.nh, args.h)
        self.q_z_txy = q_z_txy(args.x_bin_dim, args.x_num_dim, args.y_dim, args.t_dim, args.z_dim, args.nh, args.h)
        ## (decoder)
        self.p_x_z = p_x_z(args.x_bin_dim, args.x_num_dim, args.z_dim, args.nh, args.h)
        self.p_t_z = p_t_z(args.t_dim, args.z_dim, args.nh, args.h)
        self.p_y_tz = p_y_tz(args.y_dim, args.t_dim, args.z_dim, args.nh, args.h)
        self.epsilon_layer = EpsilonLayer()
        self.beta = 1    # stored only; not read inside this class
        self.lmbda = 1   # weight of the MMD penalty in cewae_loss

    def _infer_latent(self, x):
        """Run the inference networks on covariates ``x``.

        A treatment is sampled from ``q_t_x``, an outcome is sampled from the
        arm of ``q_y_tx`` selected by that treatment, and both are concatenated
        with ``x`` to feed ``q_z_txy``.

        Returns:
            ``(y_infer, t_infer, t_infer_sample, z_infer)`` where ``y_infer``
            is the ``(y0, y1)`` pair returned by ``q_y_tx``,
            ``t_infer_sample`` is the float32 treatment sample and the other
            two are the raw outputs of ``q_t_x`` and ``q_z_txy``.
        """
        t_infer = self.q_t_x(x)
        t_infer_sample = tf.cast(t_infer.sample(), tf.float32)

        y_infer = self.q_y_tx(x)
        y0_infer, y1_infer = y_infer
        # Pick the sampled outcome of the arm matching the sampled treatment.
        y_infer_sample = y0_infer.sample() * (1-t_infer_sample) + y1_infer.sample() * t_infer_sample

        txy = tf.concat([t_infer_sample, y_infer_sample, x],-1)
        z_infer = self.q_z_txy(txy)
        return y_infer, t_infer, t_infer_sample, z_infer

    def call(self, data, training=False):
        """Forward pass.

        Args:
            data: ``[x, t]`` when ``training`` is truthy, otherwise just ``x``.
            training: selects which of the two output structures is produced.

        Returns:
            Training: ``(y_infer, t_infer, z_infer, [y0, y1], t, x_num, x_bin,
            epsilon)`` -- the inference outputs, the two potential-outcome
            outputs of ``p_y_tz``, the outputs of ``p_t_z`` and ``p_x_z``, and
            the ``EpsilonLayer`` output.
            Otherwise: ``([y0, y1], t_infer, z_infer)``, with the decoder fed
            ``z_infer.loc`` instead of a sample.
        """
        if training:
            x_train,t_train = data
            # encoder
            y_infer, t_infer, t_infer_sample, z_infer = self._infer_latent(x_train)
            z_infer_sample = z_infer.sample()
            # decoder
            ## p(x|z)
            x_num,x_bin = self.p_x_z(z_infer_sample)
            ## p(t|z)
            t = self.p_t_z(z_infer_sample)
            ## p(y|t,z) evaluated for both treatment arms
            t0z = tf.concat([tf.zeros_like(t_train),z_infer_sample],-1)
            t1z = tf.concat([tf.ones_like(t_train),z_infer_sample],-1)
            y0 = self.p_y_tz(t0z)
            y1 = self.p_y_tz(t1z)
            y = [y0,y1]
            epsilon = self.epsilon_layer(t_infer_sample)

            return y_infer,t_infer,z_infer,y,t,x_num,x_bin,epsilon

        x_train = data
        # encoder
        # NOTE(review): t and y are still *sampled* here, so this branch is
        # stochastic even though z uses the distribution's `loc`.
        _, t_infer, t_infer_sample, z_infer = self._infer_latent(x_train)
        z_infer_sample = z_infer.loc

        t1z = tf.concat([tf.ones_like(t_infer_sample),z_infer_sample],-1)
        t0z = tf.concat([tf.zeros_like(t_infer_sample),z_infer_sample],-1)
        y0 = self.p_y_tz(t0z)
        y1 = self.p_y_tz(t1z)
        y = [y0,y1]
        return y,t_infer,z_infer

    def mmd_penalty(self, sample_qz, sample_pz, batch_size = args.batch_size):
        """Estimate the MMD between two samples using ``self.kernel``.

        Args:
            sample_qz: 2-D tensor, one encoded latent vector per row.
            sample_pz: 2-D tensor of prior samples with the same number of
                rows as ``sample_qz``.
            batch_size: unused; kept so existing callers keep working.

        Returns:
            Scalar tensor. For ``"IMQ"`` it is the sum of the statistic over
            seven kernel scales.

        Raises:
            ValueError: if ``self.kernel`` is neither ``"RBF"`` nor ``"IMQ"``
                (previously this surfaced as an unbound-variable error on
                ``return``), or if the prior family is unknown.
        """
        opts = {'kernel': self.kernel, 'verbose':True, "zdim":args.z_dim, "pz":"normal"}
        sigma2_p = 1 ** 2
        kernel = opts['kernel']
        if kernel not in ('RBF', 'IMQ'):
            raise ValueError("Unknown kernel %r; expected 'RBF' or 'IMQ'" % (kernel,))

        n = tf.shape(sample_qz)[0]
        n = tf.cast(n, tf.int32)
        nf = tf.cast(n, tf.float32)

        # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
        norms_pz = tf.reduce_sum(tf.square(sample_pz), axis=1, keepdims=True)
        dotprods_pz = tf.matmul(sample_pz, sample_pz, transpose_b=True)
        distances_pz = norms_pz + tf.transpose(norms_pz) - 2. * dotprods_pz

        norms_qz = tf.reduce_sum(tf.square(sample_qz), axis=1, keepdims=True)
        dotprods_qz = tf.matmul(sample_qz, sample_qz, transpose_b=True)
        distances_qz = norms_qz + tf.transpose(norms_qz) - 2. * dotprods_qz

        dotprods = tf.matmul(sample_qz, sample_pz, transpose_b=True)
        distances = norms_qz + tf.transpose(norms_pz) - 2. * dotprods

        # Mask that removes the diagonal (i == j) from the within-sample sums.
        off_diagonal = 1. - tf.eye(n)

        if kernel == 'RBF':
            # Integer division: half_size is used both as `k` and as an index,
            # so it has to stay an int32 tensor.
            half_size = (n * n - n) // 2
            # Median heuristic for the sigma^2 of Gaussian kernel
            sigma2_k = tf.math.top_k(
                tf.reshape(distances, [-1]), half_size).values[half_size - 1]
            sigma2_k += tf.math.top_k(
                tf.reshape(distances_qz, [-1]), half_size).values[half_size - 1]
            if opts['verbose']:
                # tf.print is the TF2 replacement for the removed tf.Print.
                tf.print('Kernel width:', sigma2_k)
            res1 = tf.exp( - distances_qz / 2. / sigma2_k)
            res1 += tf.exp( - distances_pz / 2. / sigma2_k)
            res1 = tf.multiply(res1, off_diagonal)
            res1 = tf.reduce_sum(res1) / (nf * nf - nf)
            res2 = tf.exp( - distances / 2. / sigma2_k)
            res2 = tf.reduce_sum(res2) * 2. / (nf * nf)
            stat = res1 - res2
        else:  # kernel == 'IMQ'
            # k(x, y) = C / (C + ||x - y||^2)
            if opts['pz'] == 'normal':
                Cbase = 2. * opts['zdim'] * sigma2_p
            elif opts['pz'] == 'sphere':
                Cbase = 2.
            elif opts['pz'] == 'uniform':
                # E ||x - y||^2 = E[sum (xi - yi)^2]
                #               = zdim E[(xi - yi)^2]
                #               = const * zdim
                Cbase = opts['zdim']
            else:
                raise ValueError("Unknown prior family %r" % (opts['pz'],))
            stat = 0.
            for scale in [.1, .2, .5, 1., 2., 5., 10.]:
                C = Cbase * scale
                res1 = C / (C + distances_qz)
                res1 += C / (C + distances_pz)
                res1 = tf.multiply(res1, off_diagonal)
                res1 = tf.reduce_sum(res1) / (nf * nf - nf)
                res2 = C / (C + distances)
                res2 = tf.reduce_sum(res2) * 2. / (nf * nf)
                stat += res1 - res2
        return stat

    def cewae_loss(self, data, pred, training = False):
        """Reconstruction terms plus ``self.lmbda`` times the MMD penalty.

        Args:
            data: ``(x, t, y)``. The first ``args.x_num_dim`` columns of ``x``
                are treated as numerical covariates, the rest as binary.
            pred: the 8-tuple returned by ``call(..., training=True)``.
            training: unused; kept for interface compatibility.

        Returns:
            Scalar loss tensor.
        """
        x_train, t_train, y_train = data[0],data[1],data[2]
        x_train_num, x_train_bin = x_train[:,:args.x_num_dim],x_train[:,args.x_num_dim:]
        # NOTE(review): `t` (output of p_t_z) and `epsilon` are unpacked but do
        # not enter the loss; the treatment term below uses `t_infer` instead
        # -- confirm this is intended.
        y_infer,t_infer,z_infer,y,t,x_num,x_bin,epsilon = pred
        y0,y1 = y
        # reconstruct loss
        # NOTE(review): every term is evaluated on a `.sample()` draw; whether
        # gradients flow through those draws depends on the distribution types
        # built by the sub-networks, which are not visible here.
        rec_x_num = tfkb.mean(tf.math.square(x_train_num - x_num.sample()))
        rec_x_bin = tf.reduce_sum(
            tfk.losses.binary_crossentropy(
                x_train_bin,
                tf.cast(x_bin.sample(),tf.float32),
                from_logits=False))
        rec_t_bin = tf.reduce_sum(
            tfk.losses.binary_crossentropy(
                t_train,
                tf.cast(t_infer.sample(),tf.float32),
                from_logits=False))
        rec_y0 = tf.math.square(y0.sample() - y_train)
        rec_y1 = tf.math.square(y1.sample() - y_train)
        # Only the arm matching the observed treatment contributes.
        rec_y = tfkb.mean(t_train * rec_y1 + (1-t_train)* rec_y0)
        # regularization
        # mmd penalty between encoded latents and a standard-normal sample of the same shape
        pz = tfd.Normal(loc = tf.zeros_like(z_infer.sample()), scale = tf.ones_like(z_infer.sample()))
        reg_mmd = self.mmd_penalty(z_infer.sample(), pz.sample())

        # Single assignment that honours self.lmbda (the original computed this
        # and then overwrote it with an unweighted copy).
        loss = rec_x_num + rec_x_bin + rec_t_bin + rec_y + reg_mmd * self.lmbda
        return loss

    def train_step(self, data):
        """One optimisation step; returns ``{"loss": loss}``.

        ``data`` is the structure Keras hands over from ``fit``; its first
        element is expected to be the ``[x, t, y]`` list.
        """
        data = data[0]
        x,t,y = data
        with tf.GradientTape() as tape:
            pred = self([x,t], training=True)  # Forward pass
            loss = self.cewae_loss(data = data, pred = pred, training = True)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        metrics = {"loss": loss}
        return metrics

    def test_step(self, data):
        """Evaluation step: loss plus mean potential outcomes and their difference.

        Returns a dict with the keys ``loss``, ``y0``, ``y1`` and
        ``ate_afte_scaled``; the outcome statistics are computed from samples
        of the inference network's two outcome heads.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        data = data[0]
        x,t,y = data
        # training=True is required here: cewae_loss needs the 8-tuple that
        # call() only returns in its training branch. No gradient tape is
        # opened because nothing is differentiated during evaluation.
        pred = self([x,t], training=True)  # Forward pass
        y_infer = pred[0]
        loss = self.cewae_loss(data = data, pred = pred, training = False)
        y0, y1 = y_infer[0].sample(),y_infer[1].sample()
        ate = tfkb.mean(y1) - tfkb.mean(y0)
        metrics = {"loss":loss,"y0": tfkb.mean(y0),"y1": tfkb.mean(y1),'ate_afte_scaled': ate}
        return metrics


In [None]:
# Start from a clean slate: delete TensorBoard logs left by earlier runs.
!rm -rf ./logs/ 
# Each run writes into its own timestamped directory under logs/fit/.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Default summary writer, so tf.summary calls made elsewhere land in <log_dir>/metrics.
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
# Keras callback that logs to the same run directory; histogram_freq=1 records
# weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# (Re)load the TensorBoard notebook extension used by the %tensorboard magic.
%reload_ext tensorboard 

### MAIN CODE ####
verbose=True
i = 0
# Seed both RNGs *before* the model is constructed, so that any weights
# created inside the constructors are initialised reproducibly as well.
tf.random.set_seed(i)
np.random.seed(i)

model = CEWAE(kernel = 'IMQ')

callbacks = [
        TerminateOnNaN(),
        # Keras exposes the "loss" entry returned by CEWAE.test_step as
        # 'val_loss'. The previous name 'var_loss' matched no metric, so
        # early stopping never triggered.
        EarlyStopping(monitor='val_loss', patience=40, min_delta=0),
        #40 is Shi's recommendation patience for this dataset, but you should tune for your data
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                          min_delta=0, cooldown=0, min_lr=0),
        #This learning rate scheduling is quite aggressive which seems good for this dataset
        metrics_for_cevae(data_train,'train',verbose),
        metrics_for_cevae(data_valid,'valid',verbose),
        tensorboard_callback
    ]

#optimizer hyperparameters
learning_rate = 5e-4
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate = learning_rate))

# Inputs are passed as a single [x, t, y_scaled] list with no separate target:
# CEWAE.train_step / test_step unpack that list themselves.
model.fit(
    [data_train['x'],data_train['t'],data_train['ys']],
    callbacks=callbacks,
    validation_data=[[data_valid['x'],data_valid['t'],data_valid['ys']]],
    epochs=300,
    batch_size=args.batch_size,
    verbose=verbose
    )
print("Done!")

In [None]:
# Open TensorBoard inline on the directory the training runs log into.
%tensorboard --logdir logs/fit