In [1]:
#!/usr/bin/env python
"""CEVAE model on IHDP
"""

import edward2 as ed
import tensorflow as tf
import tf_slim as slim
# import tensorflow_probability as tfp
# tfd = tf

from progressbar import ETA, Bar, Percentage, ProgressBar

from datasets import IHDP
from evaluation import Evaluator
import numpy as np
import time
from scipy.stats import sem

from utils import fc_net, get_y0_y1
from argparse import ArgumentParser
# from earlystopping import EarlyStopping

class Args:
    reps = 1 #10
    earl = 10 #10
    lr = 0.001 #0.001
    opt = "adam"
    epochs = 100 #100
    print_every = 10 #10
    true_post = True
    
args = Args()

args.true_post = True

dataset = IHDP(replications=args.reps)
dimx = 25
scores = np.zeros((args.reps, 3))
scores_test = np.zeros((args.reps, 3))

M = None  # batch size during training
d = 20  # latent dimension
lamba = 1e-4  # weight decay
nh, h = 3, 200  # number and size of hidden layers

In [25]:
for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test()):
    print('\nReplication {}/{}'.format(i + 1, args.reps))
    (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
    (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
    (xte, tte, yte), (y_cfte, mu0te, mu1te) = test
    evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)

    # reorder features with binary first and continuous after
    perm = binfeats + contfeats
    xtr, xva, xte = xtr[:, perm], xva[:, perm], xte[:, perm]

    xalltr, talltr, yalltr = np.concatenate([xtr, xva], axis=0), np.concatenate([ttr, tva], axis=0), np.concatenate([ytr, yva], axis=0)
    evaluator_train = Evaluator(yalltr, talltr, y_cf=np.concatenate([y_cftr, y_cfva], axis=0),
                                mu0=np.concatenate([mu0tr, mu0va], axis=0), mu1=np.concatenate([mu1tr, mu1va], axis=0))

    # zero mean, unit variance for y during training
    ym, ys = np.mean(ytr), np.std(ytr)
    ytr, yva = (ytr - ym) / ys, (yva - ym) / ys
    best_logpvalid = - np.inf
    best_avg_loss = - np.inf

    with tf.Graph().as_default():
        sess = tf.compat.v1.Session()
        # sess = tf.Session()

        # ed.set_seed(1)
        initializer = tf.keras.initializers.GlorotNormal(seed = 0)
        np.random.seed(1)
        tf.compat.v1.set_random_seed(1)
        
        # x_ph_bin = tf.Variable([0,0], dtype=float, shape=[M, len(binfeats)], name='x_bin') # binary inputs

        x_ph_bin = tf.compat.v1.placeholder(tf.float32, [M, len(binfeats)], name='x_bin')  # binary inputs
        x_ph_cont = tf.compat.v1.placeholder(tf.float32, [M, len(contfeats)], name='x_cont')  # continuous inputs
        t_ph = tf.compat.v1.placeholder(tf.float32, [M, 1])
        y_ph = tf.compat.v1.placeholder(tf.float32, [M, 1])

        x_ph = tf.concat([x_ph_bin, x_ph_cont], 1)
        activation = tf.nn.elu

        # CEVAE model (decoder)
        # p(z)
        z = ed.Normal(loc=tf.zeros([tf.shape(input=x_ph)[0], d]), scale=tf.ones([tf.shape(input=x_ph)[0], d]))

        # p(x|z)
        hx = fc_net(z, (nh - 1) * [h], [], 'px_z_shared', lamba=lamba, activation=activation)
        logits = fc_net(hx, [h], [[len(binfeats), None]], 'px_z_bin'.format(i + 1), lamba=lamba, activation=activation)
        x1 = ed.Bernoulli(logits=logits, dtype=tf.float32, name='bernoulli_px_z')

        mu, sigma = fc_net(hx, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont', lamba=lamba,
                           activation=activation)
        x2 = ed.Normal(loc=mu, scale=sigma, name='gaussian_px_z')

        # p(t|z)
        logits = fc_net(z, [h], [[1, None]], 'pt_z', lamba=lamba, activation=activation)
        t = ed.Bernoulli(logits=logits, dtype=tf.float32)

        # p(y|t,z)
        mu2_t0 = fc_net(z, nh * [h], [[1, None]], 'py_t0z', lamba=lamba, activation=activation)
        mu2_t1 = fc_net(z, nh * [h], [[1, None]], 'py_t1z', lamba=lamba, activation=activation)
        y = ed.Normal(loc=t * mu2_t1 + (1. - t) * mu2_t0, scale=tf.ones_like(mu2_t0))

        # CEVAE variational approximation (encoder)
        # q(t|x)
        logits_t = fc_net(x_ph, [d], [[1, None]], 'qt', lamba=lamba, activation=activation)
        qt = ed.Bernoulli(logits=logits_t, dtype=tf.float32)
        # q(y|x,t)
        hqy = fc_net(x_ph, (nh - 1) * [h], [], 'qy_xt_shared', lamba=lamba, activation=activation)
        mu_qy_t0 = fc_net(hqy, [h], [[1, None]], 'qy_xt0', lamba=lamba, activation=activation)
        mu_qy_t1 = fc_net(hqy, [h], [[1, None]], 'qy_xt1', lamba=lamba, activation=activation)
        qy = ed.Normal(loc=qt * mu_qy_t1 + (1. - qt) * mu_qy_t0, scale=tf.ones_like(mu_qy_t0))
        # q(z|x,t,y)
        inpt2 = tf.concat([x_ph, qy], 1)
        hqz = fc_net(inpt2, (nh - 1) * [h], [], 'qz_xty_shared', lamba=lamba, activation=activation)
        muq_t0, sigmaq_t0 = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0', lamba=lamba,
                                   activation=activation)
        muq_t1, sigmaq_t1 = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1', lamba=lamba,
                                   activation=activation)
        qz = ed.Normal(loc=qt * muq_t1 + (1. - qt) * muq_t0, scale=qt * sigmaq_t1 + (1. - qt) * sigmaq_t0)
        
        # Create data dictionary for edward
        data = {x1: x_ph_bin, x2: x_ph_cont, y: y_ph, qt: t_ph, t: t_ph, qy: y_ph}
        
        # Compute expected log-likelihood. First, sample from the variational distribution; second, compute the log-likelihood given the sample.
        
        # sample posterior predictive for p(y|z,t)
        # y_post = ed.copy(y, {z: qz, t: t_ph}, scope='y_post')
        mu2_t0_y_post = fc_net(qz, nh * [h], [[1, None]], 'py_t0z_y_post', lamba=lamba, activation=activation)
        mu2_t1_y_post = fc_net(qz, nh * [h], [[1, None]], 'py_t1z_y_post', lamba=lamba, activation=activation)
        y_post_dist = ed.Normal(loc=t_ph * mu2_t1_y_post + (1. - t_ph) * mu2_t0_y_post, scale=tf.ones_like(mu2_t0_y_post))
        # y_post = y_post_dist.distribution.sample(seed = 0)
        
        # crude approximation of the above
        # y_post_mean = ed.copy(y, {z: qz.mean(), t: t_ph}, scope='y_post_mean')
        mu2_t0_y_post_mean = fc_net(qz.distribution.mean(), nh * [h], [[1, None]], 'py_t0z_y_post_mean', lamba=lamba, activation=activation)
        mu2_t1_y_post_mean = fc_net(qz.distribution.mean(), nh * [h], [[1, None]], 'py_t1z_y_post_mean', lamba=lamba, activation=activation)
        y_post_mean_dist = ed.Normal(loc=t_ph * mu2_t1_y_post_mean + (1. - t_ph) * mu2_t0_y_post_mean, scale=tf.ones_like(mu2_t0_y_post_mean))
        y_post_mean = y_post_mean_dist.distribution.sample()
        
        # construct a deterministic version (i.e. use the mean of the approximate posterior) of the lower bound
        # for early stopping according to a validation set
        #  y_post_eval = ed.copy(y, {z: qz.mean(), qt: t_ph, qy: y_ph, t: t_ph}, scope='y_post_eval')
        inpt2_y_post_eval = tf.concat([x_ph, y_ph], 1)
        hqz_y_post_eval = fc_net(inpt2_y_post_eval, (nh - 1) * [h], [], 'qz_xty_shared_y_post_eval', lamba=lamba, activation=activation)
        muq_t0_y_post_eval, sigmaq_t0_y_post_eval = fc_net(hqz_y_post_eval, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_y_post_eval', lamba=lamba,
                                   activation=activation)
        muq_t1_y_post_eval, sigmaq_t1_y_post_eval = fc_net(hqz_y_post_eval, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_y_post_eval', lamba=lamba,
                                   activation=activation)
        qz_y_post_eval = ed.Normal(loc=t_ph * muq_t1_y_post_eval + (1. - t_ph) * muq_t0_y_post_eval, scale=t_ph * sigmaq_t1_y_post_eval + (1. - t_ph) * sigmaq_t0_y_post_eval)
        mu2_t0_y_post_eval = fc_net(qz_y_post_eval.distribution.mean(), nh * [h], [[1, None]], 'py_t0z_y_post_eval', lamba=lamba, activation=activation)
        mu2_t1_y_post_eval = fc_net(qz_y_post_eval.distribution.mean(), nh * [h], [[1, None]], 'py_t1z_y_post_eval', lamba=lamba, activation=activation)
        y_post_eval_dist = ed.Normal(loc=t_ph * mu2_t1_y_post_eval + (1. - t_ph) * mu2_t0_y_post_eval, scale=tf.ones_like(mu2_t0))
        # y_post_eval = y_post_eval_dist.distribution.sample()
        
        # x1_post_eval = x1, {z: qz.mean(), qt: t_ph, qy: y_ph}
        inpt2_x_post_eval = tf.concat([x_ph, y_ph], 1)
        hqz_x_post_eval = fc_net(inpt2_x_post_eval, (nh - 1) * [h], [], 'qz_xty_shared_x_post_eval', lamba=lamba, activation=activation)
        muq_t0_x_post_eval, sigmaq_t0_x_post_eval = fc_net(hqz_x_post_eval, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_x_post_eval', lamba=lamba,
                                   activation=activation)
        muq_t1_x_post_eval, sigmaq_t1_x_post_eval = fc_net(hqz_x_post_eval, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_x_post_eval', lamba=lamba,
                                   activation=activation)
        qz_x_post_eval = ed.Normal(loc=t_ph * muq_t1_x_post_eval + (1. - t_ph) * muq_t0_x_post_eval, scale=t_ph * sigmaq_t1_x_post_eval + (1. - t_ph) * sigmaq_t0_x_post_eval)
        hx_x_post_eval = fc_net(qz_x_post_eval.distribution.mean(), (nh - 1) * [h], [], 'px_z_shared_x_post_eval', lamba=lamba, activation=activation)
        logits_post_eval = fc_net(hx_x_post_eval, [h], [[len(binfeats), None]], 'px_z_bin_x_post_eval'.format(i + 1), lamba=lamba, activation=activation)
        x1_post_eval_dist = ed.Bernoulli(logits=logits_post_eval, dtype=tf.float32, name='bernoulli_px_z_post_eval')
        # x1_post_eval = x1_post_eval_dist.distribution.sample()
        
        # x2_post_eval = x2, {z: qz.mean(), qt: t_ph, qy: y_ph}
        mu_x2_post_eval, sigma_x2_post_eval = fc_net(hx_x_post_eval, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_x2_post_eval', lamba=lamba, activation=activation)
        x2_post_eval_dist = ed.Normal(loc=mu_x2_post_eval, scale=sigma_x2_post_eval, name='gaussian_px_z_x2_post_eval')
        # x2_post_eval = x2_post_eval_dist.distribution.sample()
        
        # t_post_eval = ed.copy(t, {z: qz.mean(), qt: t_ph, qy: y_ph}, scope='t_post_eval')
        # logits_t_post_eval = fc_net(t_ph * muq_t1_t_post_eval + (1. - t_ph) * muq_t0_t_post_eval, [h], [[1, None]], 'pt_z_t_post_eval', lamba=lamba, activation=activation)
        inpt2_t_post_eval = tf.concat([x_ph, y_ph], 1)
        hqz_t_post_eval = fc_net(inpt2_t_post_eval, (nh - 1) * [h], [], 'qz_xty_shared_t_post_eval', lamba=lamba, activation=activation)
        muq_t0_t_post_eval, sigmaq_t0_t_post_eval = fc_net(hqz_t_post_eval, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_t_post_eval', lamba=lamba,
                                   activation=activation)
        muq_t1_t_post_eval, sigmaq_t1_t_post_eval = fc_net(hqz_t_post_eval, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_t_post_eval', lamba=lamba,
                                   activation=activation)
        qz_t_post_eval = ed.Normal(loc=t_ph * muq_t1_t_post_eval + (1. - t_ph) * muq_t0_t_post_eval, scale=t_ph * sigmaq_t1_t_post_eval + (1. - t_ph) * sigmaq_t0_t_post_eval)
        logits_t_post_eval = fc_net(qz_t_post_eval.distribution.mean(), [h], [[1, None]], 'pt_z_t_post_eval', lamba=lamba, activation=activation)
        t_post_eval_dist = ed.Bernoulli(logits=logits_t_post_eval, dtype=tf.float32)
        # t_post_eval = y_post_eval_dist.distribution.sample()
    
        logp_valid = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=y_post_eval_dist.distribution.log_prob(y_ph) + t_post_eval_dist.distribution.log_prob(t_ph), axis=1) +
                                    tf.reduce_sum(input_tensor=x1_post_eval_dist.distribution.log_prob(x_ph_bin), axis=1) +
                                    tf.reduce_sum(input_tensor=x2_post_eval_dist.distribution.log_prob(x_ph_cont), axis=1) +
                                    tf.reduce_sum(input_tensor=z.distribution.log_prob(qz.distribution.mean()) - qz.distribution.log_prob(qz.distribution.mean()), axis=1)) # tf.reduce_sum(input_tensor=z.distribution.log_prob(qt * muq_t1 + (1. - qt) * muq_t0) - qz.distribution.log_prob(qt * muq_t1 + (1. - qt) * muq_t0), axis=1))
      
        # inference用の分布
        logits_t_inf = fc_net(x_ph, [d], [[1, None]], 'qt_inf', lamba=lamba, activation=activation)
        qt_inf = ed.Bernoulli(logits=logits_t_inf, dtype=tf.float32)
        hqy_inf = fc_net(x_ph, (nh - 1) * [h], [], 'qy_xt_shared_inf', lamba=lamba, activation=activation)
        mu_qy_t0_inf = fc_net(hqy_inf, [h], [[1, None]], 'qy_xt0_inf', lamba=lamba, activation=activation)
        mu_qy_t1_inf = fc_net(hqy_inf, [h], [[1, None]], 'qy_xt1_inf', lamba=lamba, activation=activation)
        qy_inf = ed.Normal(loc=t_ph * mu_qy_t1_inf + (1. - t_ph) * mu_qy_t0_inf, scale=tf.ones_like(mu_qy_t0_inf))
        inpt2_inf = tf.concat([x_ph, y_ph], 1)
        hqz_inf = fc_net(inpt2_inf, (nh - 1) * [h], [], 'qz_xty_shared_inf', lamba=lamba, activation=activation)
        muq_t0_inf, sigmaq_t0_inf = fc_net(hqz_inf, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_inf', lamba=lamba,
                                   activation=activation)
        muq_t1_inf, sigmaq_t1_inf = fc_net(hqz_inf, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_inf', lamba=lamba,
                                   activation=activation)
        qz_inf = ed.Normal(loc=t_ph * muq_t1_inf + (1. - t_ph) * muq_t0_inf, scale=t_ph * sigmaq_t1_inf + (1. - t_ph) * sigmaq_t0_inf)
        
        hx_inf = fc_net(qz_inf.distribution.sample(seed=0), (nh - 1) * [h], [], 'px_z_shared_inf', lamba=lamba, activation=activation)
        logits_inf = fc_net(hx_inf, [h], [[len(binfeats), None]], 'px_z_bin_inf'.format(i + 1), lamba=lamba, activation=activation)
        x1_inf = ed.Bernoulli(logits=logits_inf, dtype=tf.float32, name='bernoulli_px_z_inf')
        mu_inf, sigma_inf = fc_net(hx_inf, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_inf', lamba=lamba,
                           activation=activation)
        x2_inf = ed.Normal(loc=mu_inf, scale=sigma_inf, name='gaussian_px_z_inf')
        logits_t_inf = fc_net(qz_inf.distribution.sample(seed=0), [h], [[1, None]], 'pt_z_inf', lamba=lamba, activation=activation)
        t_inf = ed.Bernoulli(logits=logits_t_inf, dtype=tf.float32)
        mu2_t0_inf = fc_net(qz_inf.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t0z_inf', lamba=lamba, activation=activation)
        mu2_t1_inf = fc_net(qz_inf.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t1z_inf', lamba=lamba, activation=activation)
        y_inf = ed.Normal(loc=t_inf * mu2_t1_inf + (1. - t_inf) * mu2_t0_inf, scale=tf.ones_like(mu2_t0_inf))
        
        inpt2_inf2 = tf.concat([x_ph, qy_inf], 1)
        hqz_inf2 = fc_net(inpt2_inf2, (nh - 1) * [h], [], 'qz_xty_shared_inf2', lamba=lamba, activation=activation)
        muq_t0_inf2, sigmaq_t0_inf2 = fc_net(hqz_inf2, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_inf2', lamba=lamba,
                                   activation=activation)
        muq_t1_inf2, sigmaq_t1_inf2 = fc_net(hqz_inf2, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_inf2', lamba=lamba,
                                   activation=activation)
        qz_inf2 = ed.Normal(loc=t_ph * muq_t1_inf2 + (1. - t_ph) * muq_t0_inf2, scale=t_ph * sigmaq_t1_inf2 + (1. - t_ph) * sigmaq_t0_inf2)
        hx_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), (nh - 1) * [h], [], 'px_z_shared_inf2', lamba=lamba, activation=activation)
        logits_inf2 = fc_net(hx_inf2, [h], [[len(binfeats), None]], 'px_z_bin_inf2'.format(i + 1), lamba=lamba, activation=activation)
        x1_inf2 = ed.Bernoulli(logits=logits_inf2, dtype=tf.float32, name='bernoulli_px_z_inf2')
        mu_inf2, sigma_inf2 = fc_net(hx_inf2, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_inf2', lamba=lamba,
                           activation=activation)
        x2_inf2 = ed.Normal(loc=mu_inf2, scale=sigma_inf2, name='gaussian_px_z_inf2')
        logits_t_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), [h], [[1, None]], 'pt_z_inf2', lamba=lamba, activation=activation)
        t_inf2 = ed.Bernoulli(logits=logits_t_inf2, dtype=tf.float32)
        mu2_t0_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t0z_inf2', lamba=lamba, activation=activation)
        mu2_t1_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t1z_inf2', lamba=lamba, activation=activation)
        y_inf2 = ed.Normal(loc=t_inf2 * mu2_t1_inf2 + (1. - t_inf2) * mu2_t0_inf2, scale=tf.ones_like(mu2_t0_inf2))
        
        inference = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=x1_inf.distribution.log_prob(x_ph_bin),axis=1) + tf.reduce_sum(input_tensor=x2_inf.distribution.log_prob(x_ph_cont),axis=1) + tf.reduce_sum(input_tensor=y_inf2.distribution.log_prob(y_ph),axis=1) + tf.reduce_sum(input_tensor=qt_inf.distribution.log_prob(t_ph),axis=1) + tf.reduce_sum(input_tensor=t_inf.distribution.log_prob(t_ph),axis=1) + tf.reduce_sum(input_tensor=qy_inf.distribution.log_prob(y_ph),axis=1) + tf.reduce_sum(input_tensor=z.distribution.log_prob(qz_inf.distribution.sample(seed=0))) - tf.reduce_sum(input_tensor=qz_inf.distribution.log_prob(qz_inf.distribution.sample(seed=0)),axis=1))
        global_step = tf.compat.v1.train.get_or_create_global_step()
        train_op = tf.compat.v1.train.AdamOptimizer(args.lr).minimize(-inference,global_step=global_step)
        
        """
        inference = ed.KLqp({z: qz}, data)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=args.lr)
        inference.initialize(optimizer=optimizer)
        """
       
        sess.run(tf.compat.v1.global_variables_initializer())
        # tf.compat.v1.global_variables_initializer().run(session=sess)
        saver = tf.compat.v1.train.Saver()
        # saver = tf.compat.v1.train.Saver(slim.get_variables())
        # kernel_initializer=initializers.glorot_uniform(seed=0)))

        n_epoch, n_iter_per_epoch, idx = args.epochs, 10 * int(xtr.shape[0] / 100), np.arange(xtr.shape[0])

        # dictionaries needed for evaluation
        tr0, tr1 = np.zeros((xalltr.shape[0], 1)), np.ones((xalltr.shape[0], 1))
        tr0t, tr1t = np.zeros((xte.shape[0], 1)), np.ones((xte.shape[0], 1))
        f1 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr1}
        f0 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr0}
        f1t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr1t}
        f0t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr0t}
        
        # early_stopping = EarlyStopping(patience=10, verbose=1)

        # loss = np.zeros(n_epoch*n_iter_per_epoch)
        logpvalid_list = np.zeros(n_iter_per_epoch*n_epoch)
        
        for epoch in range(n_epoch):
            avg_loss_list = np.zeros(n_iter_per_epoch)
            # avg_loss = 0.0
            t0 = time.time()
            widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
            pbar = ProgressBar(n_iter_per_epoch, widgets=widgets)
            pbar.start()
            np.random.shuffle(idx)
            
            for j in range(n_iter_per_epoch):
              # info_dict = 0.0
              pbar.update(j)
              batch = np.random.choice(idx, 100)
              x_train, y_train, t_train = xtr[batch], ytr[batch], ttr[batch] 
              """
              sess.run(train_op,feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                                                        x_ph_cont: x_train[:, len(binfeats):],
                                                        t_ph: t_train, y_ph: y_train})
              """
              _, info_dict = sess.run([train_op, inference], feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                                                        x_ph_cont: x_train[:, len(binfeats):],
                                                        t_ph: t_train, y_ph: y_train}) 
              step = sess.run(global_step)
              # print(info_dict)
              logpvalid = sess.run(logp_valid, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):],
                                                            t_ph: tva, y_ph: yva})
              # print(logpvalid)
              # avg_loss += info_dict['loss']
              avg_loss_list[j] = info_dict
              logpvalid_list[epoch*n_iter_per_epoch+j]=logpvalid
              # avg_loss += info_dict
              # print(avg_loss_list[j])
              # loss[epoch*n_iter_per_epoch+j] = avg_loss_list[j]
              # print(loss[epoch*n_iter_per_epoch+j])
            # avg_loss = avg_loss / n_iter_per_epoch
            avg_loss = np.mean(avg_loss_list) / n_iter_per_epoch
            avg_loss = avg_loss / 100
            """
            y0, y1 = get_y0_y1(sess, y_post_dist, f0, f1, shape=yalltr.shape, L=1)
            y0, y1 = y0 * ys + ym, y1 * ys + ym
            score_train = evaluator_train.calc_stats(y1, y0)
            rmses_train = evaluator_train.y_errors(y0, y1)
            print("Epoch: {}/{}, avg_loss: {:0.3f}, logpvalid: {:0.3f}, ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f},"
                      "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, ".format(epoch + 1, n_epoch, avg_loss, logpvalid, score_train[0], score_train[1], score_train[2],
                                           rmses_train[0], rmses_train[1]))
            """
            # print(avg_loss)
            # saver.save(sess, 'ckpt/model.ckpt', step)
            
            """
            if avg_loss >= best_avg_loss:
                  print('Improved bound, old: {:0.3f}, new: {:0.3f}'.format(best_avg_loss, avg_loss))
                  best_avg_loss = avg_loss
                  saver.save(sess, "models/m6-ihdp")


            """
            if epoch % args.earl == 0 or epoch == (n_epoch - 1):
                # logpvalid = sess.run(logp_valid, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: tva, y_ph: yva})
                # print(logpvalid)
                saver.save(sess, 'ckpt/model.ckpt', step)
                """
                if early_stopping.validate(logpvalid):
                      break
                print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_logpvalid, logpvalid))
                best_logpvalid = logpvalid
                saver.save(sess, "models/m6-ihdp")
                """
                if logpvalid >= best_logpvalid:
                  print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_logpvalid, logpvalid))
                  best_logpvalid = logpvalid
                  # saver.save(sess, "models/m6-ihdp")
                  # saver.save(sess, 'ckpt/model.ckpt', step)
                  # checkpoint = tf.train.Checkpoint(sess)
                  # manager = tf.train.CheckpointManager(checkpoint)
                  # tf.keras.set_session(sess)
                  # tf.keras.Model.save_weights(sess, 'models/m6-ihdp')
            # """
            
            if epoch % args.print_every == 0:
                # y_post_train = sess.run(y_post)
                # , feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)], x_ph_cont: x_train[:, len(binfeats):], y_ph: y_train}
                y0, y1 = get_y0_y1(sess, y_inf2, f0, f1, shape=yalltr.shape, L=1)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                # print(np.mean(y1-y0))
                score_train = evaluator_train.calc_stats(y1, y0)
                rmses_train = evaluator_train.y_errors(y0, y1)

                # y_post_test = sess.run(y_post, feed_dict={x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):],  y_ph: yte})
                y0, y1 = get_y0_y1(sess, y_inf2, f0t, f1t, shape=yte.shape, L=1)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                score_test = evaluator_test.calc_stats(y1, y0)

                print("Epoch: {}/{}, log p(x) >= {:0.3f}, ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, " \
                      "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, ite_te: {:0.3f}, ate_te: {:0.3f}, pehe_te: {:0.3f}, " \
                      "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss, score_train[0], score_train[1], score_train[2],
                                           rmses_train[0], rmses_train[1], score_test[0], score_test[1], score_test[2],
                                           time.time() - t0))

        ckpt_path = tf.train.latest_checkpoint('ckpt/')
        # saver = tf.train.import_meta_graph(ckpt_path + '.meta')
        saver.restore(sess, ckpt_path)
        # checkpoint = tf.train.Checkpoint(sess)
        # manager = tf.train.CheckpointManager(checkpoint)
        # status = checkpoint.restore(manager.latest_checkpoint)
        # tf.keras.Model.load_weights
        # saver.restore(sess, "models/m6-ihdp")
       
        # y_post_train2 = sess.run(y_post, feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)], x_ph_cont: x_train[:, len(binfeats):], y_ph: y_train})
        y0, y1 = get_y0_y1(sess, y_inf2, f0, f1, shape=yalltr.shape, L=100)
        y0, y1 = y0 * ys + ym, y1 * ys + ym
        score = evaluator_train.calc_stats(y1, y0)
        scores[i, :] = score

        # y_post_test2 = sess.run(y_post, feed_dict={x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], y_ph: yte})
        y0t, y1t = get_y0_y1(sess, y_inf2, f0t, f1t, shape=yte.shape, L=100)
        y0t, y1t = y0t * ys + ym, y1t * ys + ym
        score_test = evaluator_test.calc_stats(y1t, y0t)
        scores_test[i, :] = score_test

        print('Replication: {}/{}, tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}' \
              ', te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}'.format(i + 1, args.reps,
                                                                            score[0], score[1], score[2],
                                                                            score_test[0], score_test[1], score_test[2]))
        sess.close()

print('CEVAE model total scores')
means, stds = np.mean(scores, axis=0), sem(scores, axis=0)
print('train ITE: {:.3f}+-{:.3f}, train ATE: {:.3f}+-{:.3f}, train PEHE: {:.3f}+-{:.3f}' \
      ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

means, stds = np.mean(scores_test, axis=0), sem(scores_test, axis=0)
print('test ITE: {:.3f}+-{:.3f}, test ATE: {:.3f}+-{:.3f}, test PEHE: {:.3f}+-{:.3f}' \
      ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))



Replication 1/1


epoch #0| 97%|#################################################  |ETA:  0:00:00

Improved validation bound, old: -inf, new: -34.603


epoch #1|  5%|##                                                 |ETA:  0:00:02

Epoch: 1/100, log p(x) >= -0.499, ite_tr: 1.317, ate_tr: 0.472, pehe_tr: 0.972, rmse_f_tr: 1.045, rmse_cf_tr: 1.336, ite_te: 1.322, ate_te: 0.673, pehe_te: 1.064, dt: 5.661


epoch #10| 97%|################################################  |ETA:  0:00:00

Improved validation bound, old: -34.603, new: -34.593


epoch #11|  7%|###                                               |ETA:  0:00:02

Epoch: 11/100, log p(x) >= -0.472, ite_tr: 1.485, ate_tr: 0.025, pehe_tr: 1.312, rmse_f_tr: 0.927, rmse_cf_tr: 1.537, ite_te: 1.516, ate_te: 0.027, pehe_te: 1.226, dt: 3.732


epoch #21|  5%|##                                                |ETA:  0:00:02

Epoch: 21/100, log p(x) >= -0.471, ite_tr: 1.462, ate_tr: 0.144, pehe_tr: 1.357, rmse_f_tr: 0.835, rmse_cf_tr: 1.517, ite_te: 1.495, ate_te: 0.223, pehe_te: 1.215, dt: 3.920


epoch #31|  0%|                                                 |ETA:  --:--:--

Epoch: 31/100, log p(x) >= -0.471, ite_tr: 1.460, ate_tr: 0.089, pehe_tr: 1.421, rmse_f_tr: 0.758, rmse_cf_tr: 1.512, ite_te: 1.431, ate_te: 0.107, pehe_te: 1.193, dt: 6.332


epoch #41|  5%|##                                                |ETA:  0:00:03

Epoch: 41/100, log p(x) >= -0.471, ite_tr: 1.429, ate_tr: 0.071, pehe_tr: 1.423, rmse_f_tr: 0.751, rmse_cf_tr: 1.482, ite_te: 1.402, ate_te: 0.119, pehe_te: 1.217, dt: 3.910


epoch #51|  5%|##                                                |ETA:  0:00:02

Epoch: 51/100, log p(x) >= -0.471, ite_tr: 1.418, ate_tr: 0.019, pehe_tr: 1.410, rmse_f_tr: 0.763, rmse_cf_tr: 1.474, ite_te: 1.375, ate_te: 0.050, pehe_te: 1.191, dt: 4.565


epoch #60| 97%|################################################  |ETA:  0:00:00

Improved validation bound, old: -34.593, new: -34.577


epoch #61|  5%|##                                                |ETA:  0:00:02

Epoch: 61/100, log p(x) >= -0.471, ite_tr: 1.389, ate_tr: 0.003, pehe_tr: 1.358, rmse_f_tr: 0.744, rmse_cf_tr: 1.450, ite_te: 1.357, ate_te: 0.035, pehe_te: 1.120, dt: 4.376


epoch #71|  5%|##                                                |ETA:  0:00:03

Epoch: 71/100, log p(x) >= -0.471, ite_tr: 1.406, ate_tr: 0.060, pehe_tr: 1.377, rmse_f_tr: 0.737, rmse_cf_tr: 1.470, ite_te: 1.398, ate_te: 0.006, pehe_te: 1.130, dt: 7.244


epoch #81|  5%|##                                                |ETA:  0:00:03

Epoch: 81/100, log p(x) >= -0.471, ite_tr: 1.366, ate_tr: 0.048, pehe_tr: 1.335, rmse_f_tr: 0.758, rmse_cf_tr: 1.428, ite_te: 1.342, ate_te: 0.088, pehe_te: 1.071, dt: 5.388


epoch #91|  2%|#                                                 |ETA:  0:00:04

Epoch: 91/100, log p(x) >= -0.471, ite_tr: 1.354, ate_tr: 0.049, pehe_tr: 1.307, rmse_f_tr: 0.733, rmse_cf_tr: 1.419, ite_te: 1.326, ate_te: 0.034, pehe_te: 1.060, dt: 5.105


epoch #99| 97%|################################################  |ETA:  0:00:00

INFO:tensorflow:Restoring parameters from ckpt/model.ckpt-4000
 Sample 100/100
 Sample 100/100
Replication: 1/1, tr_ite: 1.346, tr_ate: 0.059, tr_pehe: 1.313, te_ite: 1.338, te_ate: 0.099, te_pehe: 1.078
CEVAE model total scores
train ITE: 1.346+-nan, train ATE: 0.059+-nan, train PEHE: 1.313+-nan
test ITE: 1.338+-nan, test ATE: 0.099+-nan, test PEHE: 1.078+-nan


  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  ret = um.true_divide(


In [2]:
for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test()):
    print('\nReplication {}/{}'.format(i + 1, args.reps))
    (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
    (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
    (xte, tte, yte), (y_cfte, mu0te, mu1te) = test
    evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)

    # reorder features with binary first and continuous after
    perm = binfeats + contfeats
    xtr, xva, xte = xtr[:, perm], xva[:, perm], xte[:, perm]

    xalltr, talltr, yalltr = np.concatenate([xtr, xva], axis=0), np.concatenate([ttr, tva], axis=0), np.concatenate([ytr, yva], axis=0)
    evaluator_train = Evaluator(yalltr, talltr, y_cf=np.concatenate([y_cftr, y_cfva], axis=0),
                                mu0=np.concatenate([mu0tr, mu0va], axis=0), mu1=np.concatenate([mu1tr, mu1va], axis=0))

    # zero mean, unit variance for y during training
    ym, ys = np.mean(ytr), np.std(ytr)
    ytr, yva = (ytr - ym) / ys, (yva - ym) / ys
    best_logpvalid = - np.inf
    best_avg_loss = - np.inf

    with tf.Graph().as_default():
        sess = tf.compat.v1.Session()
        # sess = tf.Session()

        # ed.set_seed(1)
        initializer = tf.keras.initializers.GlorotNormal(seed = 0)
        np.random.seed(1)
        tf.compat.v1.set_random_seed(1)
        
        # x_ph_bin = tf.Variable([0,0], dtype=float, shape=[M, len(binfeats)], name='x_bin') # binary inputs

        x_ph_bin = tf.compat.v1.placeholder(tf.float32, [M, len(binfeats)], name='x_bin')  # binary inputs
        x_ph_cont = tf.compat.v1.placeholder(tf.float32, [M, len(contfeats)], name='x_cont')  # continuous inputs
        t_ph = tf.compat.v1.placeholder(tf.float32, [M, 1])
        y_ph = tf.compat.v1.placeholder(tf.float32, [M, 1])

        x_ph = tf.concat([x_ph_bin, x_ph_cont], 1)
        activation = tf.nn.elu

        # もとの分布
        # CEVAE model (decoder)
        # p(z)
        z = ed.Normal(loc=tf.zeros([tf.shape(input=x_ph)[0], d]), scale=tf.ones([tf.shape(input=x_ph)[0], d]))

        # p(x|z)
        hx = fc_net(z, (nh - 1) * [h], [], 'px_z_shared', lamba=lamba, activation=activation)
        logits = fc_net(hx, [h], [[len(binfeats), None]], 'px_z_bin'.format(i + 1), lamba=lamba, activation=activation)
        x1 = ed.Bernoulli(logits=logits, dtype=tf.float32, name='bernoulli_px_z')

        mu, sigma = fc_net(hx, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont', lamba=lamba,
                           activation=activation)
        x2 = ed.Normal(loc=mu, scale=sigma, name='gaussian_px_z')

        # p(t|z)
        logitst = fc_net(z, [h], [[1, None]], 'pt_z', lamba=lamba, activation=activation)
        t = ed.Bernoulli(logits=logitst, dtype=tf.float32)

        # p(y|t,z)
        mu2_t0 = fc_net(z, nh * [h], [[1, None]], 'py_t0z', lamba=lamba, activation=activation)
        mu2_t1 = fc_net(z, nh * [h], [[1, None]], 'py_t1z', lamba=lamba, activation=activation)
        y = ed.Normal(loc=t * mu2_t1 + (1. - t) * mu2_t0, scale=tf.ones_like(mu2_t0))

        # CEVAE variational approximation (encoder)
        # q(t|x)
        logits_t = fc_net(x_ph, [d], [[1, None]], 'qt', lamba=lamba, activation=activation)
        qt = ed.Bernoulli(logits=logits_t, dtype=tf.float32)
        # q(y|x,t)
        hqy = fc_net(x_ph, (nh - 1) * [h], [], 'qy_xt_shared', lamba=lamba, activation=activation)
        mu_qy_t0 = fc_net(hqy, [h], [[1, None]], 'qy_xt0', lamba=lamba, activation=activation)
        mu_qy_t1 = fc_net(hqy, [h], [[1, None]], 'qy_xt1', lamba=lamba, activation=activation)
        qy = ed.Normal(loc=qt * mu_qy_t1 + (1. - qt) * mu_qy_t0, scale=tf.ones_like(mu_qy_t0))
        # q(z|x,t,y)
        inpt2 = tf.concat([x_ph, qy], 1)
        hqz = fc_net(inpt2, (nh - 1) * [h], [], 'qz_xty_shared', lamba=lamba, activation=activation)
        muq_t0, sigmaq_t0 = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0', lamba=lamba,
                                   activation=activation)
        muq_t1, sigmaq_t1 = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1', lamba=lamba,
                                   activation=activation)
        qz = ed.Normal(loc=qt * muq_t1 + (1. - qt) * muq_t0, scale=qt * sigmaq_t1 + (1. - qt) * sigmaq_t0)
        
        # Create data dictionary for edward
        data = {x1: x_ph_bin, x2: x_ph_cont, y: y_ph, qt: t_ph, t: t_ph, qy: y_ph}
        
        # Compute expected log-likelihood. First, sample from the variational distribution; second, compute the log-likelihood given the sample.
        
        # sample posterior predictive for p(y|z,t)
        # y_post = ed.copy(y, {z: qz, t: t_ph}, scope='y_post')
        mu2_t0_y_post = fc_net(qz, nh * [h], [[1, None]], 'py_t0z_y_post', lamba=lamba, activation=activation)
        mu2_t1_y_post = fc_net(qz, nh * [h], [[1, None]], 'py_t1z_y_post', lamba=lamba, activation=activation)
        y_post_dist = ed.Normal(loc=t_ph * mu2_t1_y_post + (1. - t_ph) * mu2_t0_y_post, scale=tf.ones_like(mu2_t0_y_post))
        # y_post = y_post_dist.distribution.sample(seed = 0)
        
        # crude approximation of the above
        # y_post_mean = ed.copy(y, {z: qz.mean(), t: t_ph}, scope='y_post_mean')
        mu2_t0_y_post_mean = fc_net(qz.distribution.mean(), nh * [h], [[1, None]], 'py_t0z_y_post_mean', lamba=lamba, activation=activation)
        mu2_t1_y_post_mean = fc_net(qz.distribution.mean(), nh * [h], [[1, None]], 'py_t1z_y_post_mean', lamba=lamba, activation=activation)
        y_post_mean_dist = ed.Normal(loc=t_ph * mu2_t1_y_post_mean + (1. - t_ph) * mu2_t0_y_post_mean, scale=tf.ones_like(mu2_t0_y_post_mean))
        y_post_mean = y_post_mean_dist.distribution.sample()
        
        # logpvalid計算用の分布
        # construct a deterministic version (i.e. use the mean of the approximate posterior) of the lower bound
        # for early stopping according to a validation set
        # qz
        inpt2_post_eval_and_inf = tf.concat([x_ph, y_ph], 1)
        hqz_post_eval_and_inf = fc_net(inpt2_post_eval_and_inf, (nh - 1) * [h], [], 'qz_xty_shared_post_eval_and_inf', lamba=lamba, activation=activation)
        muq_t0_post_eval_and_inf, sigmaq_t0_post_eval_and_inf = fc_net(hqz_post_eval_and_inf, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_post_eval_and_inf', lamba=lamba,
                                   activation=activation)
        muq_t1_post_eval_and_inf, sigmaq_t1_post_eval_and_inf = fc_net(hqz_post_eval_and_inf, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_post_eval_and_inf', lamba=lamba,
                                   activation=activation)
        qz_post_eval_and_inf = ed.Normal(loc=t_ph * muq_t1_post_eval_and_inf + (1. - t_ph) * muq_t0_post_eval_and_inf, scale=t_ph * sigmaq_t1_post_eval_and_inf + (1. - t_ph) * sigmaq_t0_post_eval_and_inf)
        
        # y_post_eval = ed.copy(y, {z: qz.mean(), qt: t_ph, qy: y_ph, t: t_ph}, scope='y_post_eval')
        mu2_t0_post_eval = fc_net(qz_post_eval_and_inf.distribution.mean(), nh * [h], [[1, None]], 'py_t0z_post_eval', lamba=lamba, activation=activation)
        mu2_t1_post_eval = fc_net(qz_post_eval_and_inf.distribution.mean(), nh * [h], [[1, None]], 'py_t1z_post_eval', lamba=lamba, activation=activation)
        y_post_eval_dist = ed.Normal(loc=t_ph * mu2_t1_post_eval + (1. - t_ph) * mu2_t0_post_eval, scale=tf.ones_like(mu2_t0_post_eval))
        # y_post_eval = y_post_eval_dist.distribution.sample()
        
        # x1_post_eval = x1, {z: qz.mean(), qt: t_ph, qy: y_ph}
        hx_x_post_eval = fc_net(qz_post_eval_and_inf.distribution.mean(), (nh - 1) * [h], [], 'px_z_shared_post_eval', lamba=lamba, activation=activation)
        logits_post_eval = fc_net(hx_x_post_eval, [h], [[len(binfeats), None]], 'px_z_bin_x_post_eval'.format(i + 1), lamba=lamba, activation=activation)
        x1_post_eval_dist = ed.Bernoulli(logits=logits_post_eval, dtype=tf.float32, name='bernoulli_px_z_post_eval')
        # x1_post_eval = x1_post_eval_dist.distribution.sample()
        
        # x2_post_eval = x2, {z: qz.mean(), qt: t_ph, qy: y_ph}
        mu_x2_post_eval, sigma_x2_post_eval = fc_net(hx_x_post_eval, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_post_eval', lamba=lamba, activation=activation)
        x2_post_eval_dist = ed.Normal(loc=mu_x2_post_eval, scale=sigma_x2_post_eval, name='gaussian_px_z_post_eval')
        # x2_post_eval = x2_post_eval_dist.distribution.sample()
        
        # t_post_eval = ed.copy(t, {z: qz.mean(), qt: t_ph, qy: y_ph}, scope='t_post_eval')
        # logits_t_post_eval = fc_net(t_ph * muq_t1_t_post_eval + (1. - t_ph) * muq_t0_t_post_eval, [h], [[1, None]], 'pt_z_t_post_eval', lamba=lamba, activation=activation)
        logitst_post_eval = fc_net(qz_post_eval_and_inf.distribution.mean(), [h], [[1, None]], 'pt_z_post_eval', lamba=lamba, activation=activation)
        t_post_eval_dist = ed.Bernoulli(logits=logitst_post_eval, dtype=tf.float32)
        # t_post_eval = y_post_eval_dist.distribution.sample()
    
        logp_valid = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=y_post_eval_dist.distribution.log_prob(y_ph) + t_post_eval_dist.distribution.log_prob(t_ph), axis=1) +
                                    tf.reduce_sum(input_tensor=x1_post_eval_dist.distribution.log_prob(x_ph_bin), axis=1) +
                                    tf.reduce_sum(input_tensor=x2_post_eval_dist.distribution.log_prob(x_ph_cont), axis=1) +
                                    tf.reduce_sum(input_tensor=z.distribution.log_prob(qz_post_eval_and_inf.distribution.mean()) - qz_post_eval_and_inf.distribution.log_prob(qz_post_eval_and_inf.distribution.mean()), axis=1)) # tf.reduce_sum(input_tensor=z.distribution.log_prob(qt * muq_t1 + (1. - qt) * muq_t0) - qz.distribution.log_prob(qt * muq_t1 + (1. - qt) * muq_t0), axis=1))
      
        # inference用の分布
        # qz_data = qz.distribution.sample()
        
        # 補助分布
        logits_t_inf = fc_net(x_ph, [d], [[1, None]], 'qt_inf', lamba=lamba, activation=activation)
        qt_inf = ed.Bernoulli(logits=logits_t_inf, dtype=tf.float32)
        hqy_inf = fc_net(x_ph, (nh - 1) * [h], [], 'qy_xt_shared_inf', lamba=lamba, activation=activation)
        mu_qy_t0_inf = fc_net(hqy_inf, [h], [[1, None]], 'qy_xt0_inf', lamba=lamba, activation=activation)
        mu_qy_t1_inf = fc_net(hqy_inf, [h], [[1, None]], 'qy_xt1_inf', lamba=lamba, activation=activation)
        qy_inf = ed.Normal(loc=qt_inf * mu_qy_t1_inf + (1. - qt_inf) * mu_qy_t0_inf, scale=tf.ones_like(mu_qy_t0_inf))
        
        # 推論ネットワーク
        hx_inf = fc_net(qz_post_eval_and_inf.distribution.sample(seed=0), (nh - 1) * [h], [], 'px_z_shared_inf', lamba=lamba, activation=activation)
        logits_inf = fc_net(hx_inf, [h], [[len(binfeats), None]], 'px_z_bin_inf'.format(i + 1), lamba=lamba, activation=activation)
        x1_inf = ed.Bernoulli(logits=logits_inf, dtype=tf.float32, name='bernoulli_px_z_inf')
        mu_inf, sigma_inf = fc_net(hx_inf, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_inf', lamba=lamba,
                           activation=activation)
        x2_inf = ed.Normal(loc=mu_inf, scale=sigma_inf, name='gaussian_px_z_inf')
        logits_t_inf = fc_net(qz_post_eval_and_inf.distribution.sample(seed=0), [h], [[1, None]], 'pt_z_inf', lamba=lamba, activation=activation)
        t_inf = ed.Bernoulli(logits=logits_t_inf, dtype=tf.float32)
        mu2_t0_inf = fc_net(qz_post_eval_and_inf.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t0z_inf', lamba=lamba, activation=activation)
        mu2_t1_inf = fc_net(qz_post_eval_and_inf.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t1z_inf', lamba=lamba, activation=activation)
        y_inf = ed.Normal(loc=t_inf * mu2_t1_inf + (1. - t_inf) * mu2_t0_inf, scale=tf.ones_like(mu2_t0_inf))
        
        # 推論ネットワーク評価指標計算用
        inpt2_inf2 = tf.concat([x_ph, qy_inf], 1)
        hqz_inf2 = fc_net(inpt2_inf2, (nh - 1) * [h], [], 'qz_xty_shared_inf2', lamba=lamba, activation=activation)
        muq_t0_inf2, sigmaq_t0_inf2 = fc_net(hqz_inf2, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_inf2', lamba=lamba,
                                   activation=activation)
        muq_t1_inf2, sigmaq_t1_inf2 = fc_net(hqz_inf2, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_inf2', lamba=lamba,
                                   activation=activation)
        qz_inf2 = ed.Normal(loc=t_ph * muq_t1_inf2 + (1. - t_ph) * muq_t0_inf2, scale=t_ph * sigmaq_t1_inf2 + (1. - t_ph) * sigmaq_t0_inf2)
        hx_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), (nh - 1) * [h], [], 'px_z_shared_in2f', lamba=lamba, activation=activation)
        logits_inf2 = fc_net(hx_inf2, [h], [[len(binfeats), None]], 'px_z_bin_inf2'.format(i + 1), lamba=lamba, activation=activation)
        x1_inf2 = ed.Bernoulli(logits=logits_inf2, dtype=tf.float32, name='bernoulli_px_z_inf2')
        mu_inf2, sigma_inf2 = fc_net(hx_inf2, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_inf2', lamba=lamba,
                           activation=activation)
        x2_inf2 = ed.Normal(loc=mu_inf2, scale=sigma_inf2, name='gaussian_px_z_inf2')
        logits_t_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), [h], [[1, None]], 'pt_z_inf2', lamba=lamba, activation=activation)
        t_inf2 = ed.Bernoulli(logits=logits_t_inf2, dtype=tf.float32)
        mu2_t0_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t0z_inf2', lamba=lamba, activation=activation)
        mu2_t1_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t1z_inf2', lamba=lamba, activation=activation)
        y_inf2 = ed.Normal(loc=t_inf2 * mu2_t1_inf2 + (1. - t_inf2) * mu2_t0_inf2, scale=tf.ones_like(mu2_t0_inf2))
        
        inference = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=x1_inf2.distribution.log_prob(x_ph_bin),axis=1) + tf.reduce_sum(input_tensor=x2_inf2.distribution.log_prob(x_ph_cont),axis=1) + tf.reduce_sum(input_tensor=y_inf2.distribution.log_prob(y_ph),axis=1) + tf.reduce_sum(input_tensor=qt_inf.distribution.log_prob(t_ph),axis=1) + tf.reduce_sum(input_tensor=t_inf2.distribution.log_prob(t_ph),axis=1) + tf.reduce_sum(input_tensor=qy_inf.distribution.log_prob(y_ph),axis=1) + tf.reduce_sum(input_tensor=z.distribution.log_prob(qz_inf2.distribution.sample(seed=0))) - tf.reduce_sum(input_tensor=qz_inf2.distribution.log_prob(qz_inf2.distribution.sample(seed=0)),axis=1))
        global_step = tf.compat.v1.train.get_or_create_global_step()
        train_op = tf.compat.v1.train.AdamOptimizer(args.lr).minimize(-inference,global_step=global_step)
        
        """
        inference = ed.KLqp({z: qz}, data)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=args.lr)
        inference.initialize(optimizer=optimizer)
        """
       
        sess.run(tf.compat.v1.global_variables_initializer())
        # tf.compat.v1.global_variables_initializer().run(session=sess)
        saver = tf.compat.v1.train.Saver()
        # saver = tf.compat.v1.train.Saver(slim.get_variables())
        # kernel_initializer=initializers.glorot_uniform(seed=0)))

        n_epoch, n_iter_per_epoch, idx = args.epochs, 10 * int(xtr.shape[0] / 100), np.arange(xtr.shape[0])

        # dictionaries needed for evaluation
        tr0, tr1 = np.zeros((xalltr.shape[0], 1)), np.ones((xalltr.shape[0], 1))
        tr0t, tr1t = np.zeros((xte.shape[0], 1)), np.ones((xte.shape[0], 1))
        f1 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr1}
        f0 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr0}
        f1t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr1t}
        f0t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr0t}
        
        # early_stopping = EarlyStopping(patience=10, verbose=1)

        # loss = np.zeros(n_epoch*n_iter_per_epoch)
        logpvalid_list = np.zeros(n_iter_per_epoch*n_epoch)
        
        for epoch in range(n_epoch):
            avg_loss_list = np.zeros(n_iter_per_epoch)
            # avg_loss = 0.0
            t0 = time.time()
            widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
            pbar = ProgressBar(n_iter_per_epoch, widgets=widgets)
            pbar.start()
            np.random.shuffle(idx)
            
            for j in range(n_iter_per_epoch):
              # info_dict = 0.0
              pbar.update(j)
              batch = np.random.choice(idx, 100)
              x_train, y_train, t_train = xtr[batch], ytr[batch], ttr[batch] 
              """
              sess.run(train_op,feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                                                        x_ph_cont: x_train[:, len(binfeats):],
                                                        t_ph: t_train, y_ph: y_train})
              """
              _, info_dict = sess.run([train_op, inference], feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                                                        x_ph_cont: x_train[:, len(binfeats):],
                                                        t_ph: t_train, y_ph: y_train}) 
              step = sess.run(global_step)
              # print(info_dict)
              logpvalid = sess.run(logp_valid, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):],
                                                            t_ph: tva, y_ph: yva})
              # print(logpvalid)
              # avg_loss += info_dict['loss']
              avg_loss_list[j] = info_dict
              logpvalid_list[epoch*n_iter_per_epoch+j]=logpvalid
              # avg_loss += info_dict
              # print(avg_loss_list[j])
              # loss[epoch*n_iter_per_epoch+j] = avg_loss_list[j]
              # print(loss[epoch*n_iter_per_epoch+j])
            # avg_loss = avg_loss / n_iter_per_epoch
            avg_loss = np.mean(avg_loss_list) / n_iter_per_epoch
            avg_loss = avg_loss / 100
            """
            y0, y1 = get_y0_y1(sess, y_post_dist, f0, f1, shape=yalltr.shape, L=1)
            y0, y1 = y0 * ys + ym, y1 * ys + ym
            score_train = evaluator_train.calc_stats(y1, y0)
            rmses_train = evaluator_train.y_errors(y0, y1)
            print("Epoch: {}/{}, avg_loss: {:0.3f}, logpvalid: {:0.3f}, ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f},"
                      "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, ".format(epoch + 1, n_epoch, avg_loss, logpvalid, score_train[0], score_train[1], score_train[2],
                                           rmses_train[0], rmses_train[1]))
            """
            # print(avg_loss)
            # saver.save(sess, 'ckpt/model.ckpt', step)
            
            """
            if avg_loss >= best_avg_loss:
                  print('Improved bound, old: {:0.3f}, new: {:0.3f}'.format(best_avg_loss, avg_loss))
                  best_avg_loss = avg_loss
                  saver.save(sess, "models/m6-ihdp")


            """
            if epoch % args.earl == 0 or epoch == (n_epoch - 1):
                # logpvalid = sess.run(logp_valid, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: tva, y_ph: yva})
                # print(logpvalid)
                saver.save(sess, 'ckpt/model.ckpt', step)
                """
                if early_stopping.validate(logpvalid):
                      break
                print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_logpvalid, logpvalid))
                best_logpvalid = logpvalid
                saver.save(sess, "models/m6-ihdp")
                """
                if logpvalid >= best_logpvalid:
                  print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_logpvalid, logpvalid))
                  best_logpvalid = logpvalid
                  # saver.save(sess, "models/m6-ihdp")
                  # saver.save(sess, 'ckpt/model.ckpt', step)
                  # checkpoint = tf.train.Checkpoint(sess)
                  # manager = tf.train.CheckpointManager(checkpoint)
                  # tf.keras.set_session(sess)
                  # tf.keras.Model.save_weights(sess, 'models/m6-ihdp')
            # """
            
            if epoch % args.print_every == 0:
                # y_post_train = sess.run(y_post)
                # , feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)], x_ph_cont: x_train[:, len(binfeats):], y_ph: y_train}
                y0, y1 = get_y0_y1(sess, qy_inf, y_inf2, f0, f1, shape=yalltr.shape, L=1)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                # print(np.mean(y1-y0))
                score_train = evaluator_train.calc_stats(y1, y0)
                rmses_train = evaluator_train.y_errors(y0, y1)

                # y_post_test = sess.run(y_post, feed_dict={x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):],  y_ph: yte})
                y0, y1 = get_y0_y1(sess, qy_inf, y_inf2, f0t, f1t, shape=yte.shape, L=1)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                score_test = evaluator_test.calc_stats(y1, y0)

                print("Epoch: {}/{}, log p(x) >= {:0.3f}, ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, " \
                      "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, ite_te: {:0.3f}, ate_te: {:0.3f}, pehe_te: {:0.3f}, " \
                      "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss, score_train[0], score_train[1], score_train[2],
                                           rmses_train[0], rmses_train[1], score_test[0], score_test[1], score_test[2],
                                           time.time() - t0))

        ckpt_path = tf.train.latest_checkpoint('ckpt/')
        # saver = tf.train.import_meta_graph(ckpt_path + '.meta')
        saver.restore(sess, ckpt_path)
        # checkpoint = tf.train.Checkpoint(sess)
        # manager = tf.train.CheckpointManager(checkpoint)
        # status = checkpoint.restore(manager.latest_checkpoint)
        # tf.keras.Model.load_weights
        # saver.restore(sess, "models/m6-ihdp")
       
        # y_post_train2 = sess.run(y_post, feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)], x_ph_cont: x_train[:, len(binfeats):], y_ph: y_train})
        y0, y1 = get_y0_y1(sess, qy_inf, y_inf2, f0, f1, shape=yalltr.shape, L=100)
        y0, y1 = y0 * ys + ym, y1 * ys + ym
        score = evaluator_train.calc_stats(y1, y0)
        scores[i, :] = score

        # y_post_test2 = sess.run(y_post, feed_dict={x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], y_ph: yte})
        y0t, y1t = get_y0_y1(sess, qy_inf, y_inf2, f0t, f1t, shape=yte.shape, L=100)
        y0t, y1t = y0t * ys + ym, y1t * ys + ym
        score_test = evaluator_test.calc_stats(y1t, y0t)
        scores_test[i, :] = score_test

        print('Replication: {}/{}, tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}' \
              ', te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}'.format(i + 1, args.reps,
                                                                            score[0], score[1], score[2],
                                                                            score_test[0], score_test[1], score_test[2]))
        sess.close()

print('CEVAE model total scores')
means, stds = np.mean(scores, axis=0), sem(scores, axis=0)
print('train ITE: {:.3f}+-{:.3f}, train ATE: {:.3f}+-{:.3f}, train PEHE: {:.3f}+-{:.3f}' \
      ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

means, stds = np.mean(scores_test, axis=0), sem(scores_test, axis=0)
print('test ITE: {:.3f}+-{:.3f}, test ATE: {:.3f}+-{:.3f}, test PEHE: {:.3f}+-{:.3f}' \
      ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))



Replication 1/1


epoch #0| 97%|#################################################  |ETA:  0:00:00

Improved validation bound, old: -inf, new: -34.756


AttributeError: 'Tensor' object has no attribute 'numpy'

In [54]:
for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test()):
    print('\nReplication {}/{}'.format(i + 1, args.reps))
    (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
    (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
    (xte, tte, yte), (y_cfte, mu0te, mu1te) = test
    evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)

    # reorder features with binary first and continuous after
    perm = binfeats + contfeats
    xtr, xva, xte = xtr[:, perm], xva[:, perm], xte[:, perm]

    xalltr, talltr, yalltr = np.concatenate([xtr, xva], axis=0), np.concatenate([ttr, tva], axis=0), np.concatenate([ytr, yva], axis=0)
    evaluator_train = Evaluator(yalltr, talltr, y_cf=np.concatenate([y_cftr, y_cfva], axis=0),
                                mu0=np.concatenate([mu0tr, mu0va], axis=0), mu1=np.concatenate([mu1tr, mu1va], axis=0))

    # zero mean, unit variance for y during training
    ym, ys = np.mean(ytr), np.std(ytr)
    ytr, yva = (ytr - ym) / ys, (yva - ym) / ys
    best_logpvalid = - np.inf
    best_avg_loss = - np.inf

    with tf.Graph().as_default():
        sess = tf.compat.v1.Session()
        # sess = tf.Session()

        # ed.set_seed(1)
        initializer = tf.keras.initializers.GlorotNormal(seed = 0)
        np.random.seed(1)
        tf.compat.v1.set_random_seed(1)
        
        # x_ph_bin = tf.Variable([0,0], dtype=float, shape=[M, len(binfeats)], name='x_bin') # binary inputs

        x_ph_bin = tf.compat.v1.placeholder(tf.float32, [M, len(binfeats)], name='x_bin')  # binary inputs
        x_ph_cont = tf.compat.v1.placeholder(tf.float32, [M, len(contfeats)], name='x_cont')  # continuous inputs
        t_ph = tf.compat.v1.placeholder(tf.float32, [M, 1])
        y_ph = tf.compat.v1.placeholder(tf.float32, [M, 1])

        x_ph = tf.concat([x_ph_bin, x_ph_cont], 1)
        activation = tf.nn.elu

        # もとの分布
        # CEVAE model (decoder)
        # p(z)
        z = ed.Normal(loc=tf.zeros([tf.shape(input=x_ph)[0], d]), scale=tf.ones([tf.shape(input=x_ph)[0], d]))

        # p(x|z)
        hx = fc_net(z, (nh - 1) * [h], [], 'px_z_shared', lamba=lamba, activation=activation)
        logits = fc_net(hx, [h], [[len(binfeats), None]], 'px_z_bin'.format(i + 1), lamba=lamba, activation=activation)
        x1 = ed.Bernoulli(logits=logits, dtype=tf.float32, name='bernoulli_px_z')

        mu, sigma = fc_net(hx, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont', lamba=lamba,
                           activation=activation)
        x2 = ed.Normal(loc=mu, scale=sigma, name='gaussian_px_z')

        # p(t|z)
        logitst = fc_net(z, [h], [[1, None]], 'pt_z', lamba=lamba, activation=activation)
        t = ed.Bernoulli(logits=logitst, dtype=tf.float32)

        # p(y|t,z)
        mu2_t0 = fc_net(z, nh * [h], [[1, None]], 'py_t0z', lamba=lamba, activation=activation)
        mu2_t1 = fc_net(z, nh * [h], [[1, None]], 'py_t1z', lamba=lamba, activation=activation)
        y = ed.Normal(loc=t * mu2_t1 + (1. - t) * mu2_t0, scale=tf.ones_like(mu2_t0))

        # CEVAE variational approximation (encoder)
        # q(t|x)
        logits_t = fc_net(x_ph, [d], [[1, None]], 'qt', lamba=lamba, activation=activation)
        qt = ed.Bernoulli(logits=logits_t, dtype=tf.float32)
        # q(y|x,t)
        hqy = fc_net(x_ph, (nh - 1) * [h], [], 'qy_xt_shared', lamba=lamba, activation=activation)
        mu_qy_t0 = fc_net(hqy, [h], [[1, None]], 'qy_xt0', lamba=lamba, activation=activation)
        mu_qy_t1 = fc_net(hqy, [h], [[1, None]], 'qy_xt1', lamba=lamba, activation=activation)
        qy = ed.Normal(loc=qt * mu_qy_t1 + (1. - qt) * mu_qy_t0, scale=tf.ones_like(mu_qy_t0))
        # q(z|x,t,y)
        inpt2 = tf.concat([x_ph, qy], 1)
        hqz = fc_net(inpt2, (nh - 1) * [h], [], 'qz_xty_shared', lamba=lamba, activation=activation)
        muq_t0, sigmaq_t0 = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0', lamba=lamba,
                                   activation=activation)
        muq_t1, sigmaq_t1 = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1', lamba=lamba,
                                   activation=activation)
        qz = ed.Normal(loc=qt * muq_t1 + (1. - qt) * muq_t0, scale=qt * sigmaq_t1 + (1. - qt) * sigmaq_t0)
        
        # Create data dictionary for edward
        data = {x1: x_ph_bin, x2: x_ph_cont, y: y_ph, qt: t_ph, t: t_ph, qy: y_ph}
        
        # Compute expected log-likelihood. First, sample from the variational distribution; second, compute the log-likelihood given the sample.
        
        # sample posterior predictive for p(y|z,t)
        # y_post = ed.copy(y, {z: qz, t: t_ph}, scope='y_post')
        mu2_t0_y_post = fc_net(qz, nh * [h], [[1, None]], 'py_t0z_y_post', lamba=lamba, activation=activation)
        mu2_t1_y_post = fc_net(qz, nh * [h], [[1, None]], 'py_t1z_y_post', lamba=lamba, activation=activation)
        y_post_dist = ed.Normal(loc=t_ph * mu2_t1_y_post + (1. - t_ph) * mu2_t0_y_post, scale=tf.ones_like(mu2_t0_y_post))
        # y_post = y_post_dist.distribution.sample(seed = 0)
        
        # crude approximation of the above
        # y_post_mean = ed.copy(y, {z: qz.mean(), t: t_ph}, scope='y_post_mean')
        mu2_t0_y_post_mean = fc_net(qz.distribution.mean(), nh * [h], [[1, None]], 'py_t0z_y_post_mean', lamba=lamba, activation=activation)
        mu2_t1_y_post_mean = fc_net(qz.distribution.mean(), nh * [h], [[1, None]], 'py_t1z_y_post_mean', lamba=lamba, activation=activation)
        y_post_mean_dist = ed.Normal(loc=t_ph * mu2_t1_y_post_mean + (1. - t_ph) * mu2_t0_y_post_mean, scale=tf.ones_like(mu2_t0_y_post_mean))
        y_post_mean = y_post_mean_dist.distribution.sample()
        
        # logpvalid計算用の分布
        # construct a deterministic version (i.e. use the mean of the approximate posterior) of the lower bound
        # for early stopping according to a validation set
        # qz
        inpt2_post_eval_and_inf = tf.concat([x_ph, y_ph], 1)
        hqz_post_eval_and_inf = fc_net(inpt2_post_eval_and_inf, (nh - 1) * [h], [], 'qz_xty_shared_post_eval_and_inf', lamba=lamba, activation=activation)
        muq_t0_post_eval_and_inf, sigmaq_t0_post_eval_and_inf = fc_net(hqz_post_eval_and_inf, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_post_eval_and_inf', lamba=lamba,
                                   activation=activation)
        muq_t1_post_eval_and_inf, sigmaq_t1_post_eval_and_inf = fc_net(hqz_post_eval_and_inf, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_post_eval_and_inf', lamba=lamba,
                                   activation=activation)
        qz_post_eval_and_inf = ed.Normal(loc=t_ph * muq_t1_post_eval_and_inf + (1. - t_ph) * muq_t0_post_eval_and_inf, scale=t_ph * sigmaq_t1_post_eval_and_inf + (1. - t_ph) * sigmaq_t0_post_eval_and_inf)
        
        # y_post_eval = ed.copy(y, {z: qz.mean(), qt: t_ph, qy: y_ph, t: t_ph}, scope='y_post_eval')
        mu2_t0_post_eval = fc_net(qz_post_eval_and_inf.distribution.mean(), nh * [h], [[1, None]], 'py_t0z_post_eval', lamba=lamba, activation=activation)
        mu2_t1_post_eval = fc_net(qz_post_eval_and_inf.distribution.mean(), nh * [h], [[1, None]], 'py_t1z_post_eval', lamba=lamba, activation=activation)
        y_post_eval_dist = ed.Normal(loc=t_ph * mu2_t1_post_eval + (1. - t_ph) * mu2_t0_post_eval, scale=tf.ones_like(mu2_t0_post_eval))
        # y_post_eval = y_post_eval_dist.distribution.sample()
        
        # x1_post_eval = x1, {z: qz.mean(), qt: t_ph, qy: y_ph}
        hx_x_post_eval = fc_net(qz_post_eval_and_inf.distribution.mean(), (nh - 1) * [h], [], 'px_z_shared_post_eval', lamba=lamba, activation=activation)
        logits_post_eval = fc_net(hx_x_post_eval, [h], [[len(binfeats), None]], 'px_z_bin_x_post_eval'.format(i + 1), lamba=lamba, activation=activation)
        x1_post_eval_dist = ed.Bernoulli(logits=logits_post_eval, dtype=tf.float32, name='bernoulli_px_z_post_eval')
        # x1_post_eval = x1_post_eval_dist.distribution.sample()
        
        # x2_post_eval = x2, {z: qz.mean(), qt: t_ph, qy: y_ph}
        mu_x2_post_eval, sigma_x2_post_eval = fc_net(hx_x_post_eval, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_post_eval', lamba=lamba, activation=activation)
        x2_post_eval_dist = ed.Normal(loc=mu_x2_post_eval, scale=sigma_x2_post_eval, name='gaussian_px_z_post_eval')
        # x2_post_eval = x2_post_eval_dist.distribution.sample()
        
        # t_post_eval = ed.copy(t, {z: qz.mean(), qt: t_ph, qy: y_ph}, scope='t_post_eval')
        # logits_t_post_eval = fc_net(t_ph * muq_t1_t_post_eval + (1. - t_ph) * muq_t0_t_post_eval, [h], [[1, None]], 'pt_z_t_post_eval', lamba=lamba, activation=activation)
        logitst_post_eval = fc_net(qz_post_eval_and_inf.distribution.mean(), [h], [[1, None]], 'pt_z_post_eval', lamba=lamba, activation=activation)
        t_post_eval_dist = ed.Bernoulli(logits=logitst_post_eval, dtype=tf.float32)
        # t_post_eval = y_post_eval_dist.distribution.sample()
    
        logp_valid = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=y_post_eval_dist.distribution.log_prob(y_ph) + t_post_eval_dist.distribution.log_prob(t_ph), axis=1) +
                                    tf.reduce_sum(input_tensor=x1_post_eval_dist.distribution.log_prob(x_ph_bin), axis=1) +
                                    tf.reduce_sum(input_tensor=x2_post_eval_dist.distribution.log_prob(x_ph_cont), axis=1) +
                                    tf.reduce_sum(input_tensor=z.distribution.log_prob(qz_post_eval_and_inf.distribution.mean()) - qz_post_eval_and_inf.distribution.log_prob(qz_post_eval_and_inf.distribution.mean()), axis=1)) # tf.reduce_sum(input_tensor=z.distribution.log_prob(qt * muq_t1 + (1. - qt) * muq_t0) - qz.distribution.log_prob(qt * muq_t1 + (1. - qt) * muq_t0), axis=1))
      
        # inference用の分布
        # qz_data = qz.distribution.sample()
        
        # 補助分布
        logits_t_inf = fc_net(x_ph, [d], [[1, None]], 'qt_inf', lamba=lamba, activation=activation)
        qt_inf = ed.Bernoulli(logits=logits_t_inf, dtype=tf.float32)
        hqy_inf = fc_net(x_ph, (nh - 1) * [h], [], 'qy_xt_shared_inf', lamba=lamba, activation=activation)
        mu_qy_t0_inf = fc_net(hqy_inf, [h], [[1, None]], 'qy_xt0_inf', lamba=lamba, activation=activation)
        mu_qy_t1_inf = fc_net(hqy_inf, [h], [[1, None]], 'qy_xt1_inf', lamba=lamba, activation=activation)
        qy_inf = ed.Normal(loc=t_ph * mu_qy_t1_inf + (1. - t_ph) * mu_qy_t0_inf, scale=tf.ones_like(mu_qy_t0_inf))
        
        # 推論ネットワーク
        hx_inf = fc_net(qz_post_eval_and_inf.distribution.sample(seed=0), (nh - 1) * [h], [], 'px_z_shared_inf', lamba=lamba, activation=activation)
        logits_inf = fc_net(hx_inf, [h], [[len(binfeats), None]], 'px_z_bin_inf'.format(i + 1), lamba=lamba, activation=activation)
        x1_inf = ed.Bernoulli(logits=logits_inf, dtype=tf.float32, name='bernoulli_px_z_inf')
        mu_inf, sigma_inf = fc_net(hx_inf, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_inf', lamba=lamba,
                           activation=activation)
        x2_inf = ed.Normal(loc=mu_inf, scale=sigma_inf, name='gaussian_px_z_inf')
        logits_t_inf = fc_net(qz_post_eval_and_inf.distribution.sample(seed=0), [h], [[1, None]], 'pt_z_inf', lamba=lamba, activation=activation)
        t_inf = ed.Bernoulli(logits=logits_t_inf, dtype=tf.float32)
        mu2_t0_inf = fc_net(qz_post_eval_and_inf.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t0z_inf', lamba=lamba, activation=activation)
        mu2_t1_inf = fc_net(qz_post_eval_and_inf.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t1z_inf', lamba=lamba, activation=activation)
        y_inf = ed.Normal(loc=t_inf * mu2_t1_inf + (1. - t_inf) * mu2_t0_inf, scale=tf.ones_like(mu2_t0_inf))
        
        inference = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=x1_inf.distribution.log_prob(x_ph_bin),axis=1) + tf.reduce_sum(input_tensor=x2_inf.distribution.log_prob(x_ph_cont),axis=1) + tf.reduce_sum(input_tensor=y_inf.distribution.log_prob(y_ph),axis=1) + tf.reduce_sum(input_tensor=qt_inf.distribution.log_prob(t_ph),axis=1) + tf.reduce_sum(input_tensor=t_inf.distribution.log_prob(t_ph),axis=1) + tf.reduce_sum(input_tensor=qy_inf.distribution.log_prob(y_ph),axis=1) + tf.reduce_sum(input_tensor=z.distribution.log_prob(qz_post_eval_and_inf.distribution.sample(seed=0))) - tf.reduce_sum(input_tensor=qz_post_eval_and_inf.distribution.log_prob(qz_post_eval_and_inf.distribution.sample(seed=0)),axis=1))
        global_step = tf.compat.v1.train.get_or_create_global_step()
        train_op = tf.compat.v1.train.AdamOptimizer(args.lr).minimize(-inference,global_step=global_step)
        
        # 推論ネットワーク評価指標計算用
        inpt2_inf2 = tf.concat([x_ph, qy_inf], 1)
        hqz_inf2 = fc_net(inpt2_inf2, (nh - 1) * [h], [], 'qz_xty_shared_inf2', lamba=lamba, activation=activation)
        muq_t0_inf2, sigmaq_t0_inf2 = fc_net(hqz_inf2, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_inf2', lamba=lamba,
                                   activation=activation)
        muq_t1_inf2, sigmaq_t1_inf2 = fc_net(hqz_inf2, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_inf2', lamba=lamba,
                                   activation=activation)
        qz_inf2 = ed.Normal(loc=t_ph * muq_t1_inf2 + (1. - t_ph) * muq_t0_inf2, scale=t_ph * sigmaq_t1_inf2 + (1. - t_ph) * sigmaq_t0_inf2)
        hx_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), (nh - 1) * [h], [], 'px_z_shared_in2f', lamba=lamba, activation=activation)
        logits_inf2 = fc_net(hx_inf2, [h], [[len(binfeats), None]], 'px_z_bin_inf2'.format(i + 1), lamba=lamba, activation=activation)
        x1_inf2 = ed.Bernoulli(logits=logits_inf2, dtype=tf.float32, name='bernoulli_px_z_inf2')
        mu_inf2, sigma_inf2 = fc_net(hx_inf2, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_inf2', lamba=lamba,
                           activation=activation)
        x2_inf2 = ed.Normal(loc=mu_inf2, scale=sigma_inf2, name='gaussian_px_z_inf2')
        logits_t_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), [h], [[1, None]], 'pt_z_inf2', lamba=lamba, activation=activation)
        t_inf2 = ed.Bernoulli(logits=logits_t_inf2, dtype=tf.float32)
        mu2_t0_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t0z_inf2', lamba=lamba, activation=activation)
        mu2_t1_inf2 = fc_net(qz_inf2.distribution.sample(seed=0), nh * [h], [[1, None]], 'py_t1z_inf2', lamba=lamba, activation=activation)
        y_inf2 = ed.Normal(loc=t_inf2 * mu2_t1_inf2 + (1. - t_inf2) * mu2_t0_inf2, scale=tf.ones_like(mu2_t0_inf2))
        
        """
        inference = ed.KLqp({z: qz}, data)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=args.lr)
        inference.initialize(optimizer=optimizer)
        """
       
        sess.run(tf.compat.v1.global_variables_initializer())
        # tf.compat.v1.global_variables_initializer().run(session=sess)
        saver = tf.compat.v1.train.Saver()
        # saver = tf.compat.v1.train.Saver(slim.get_variables())
        # kernel_initializer=initializers.glorot_uniform(seed=0)))

        n_epoch, n_iter_per_epoch, idx = args.epochs, 10 * int(xtr.shape[0] / 100), np.arange(xtr.shape[0])

        # dictionaries needed for evaluation
        tr0, tr1 = np.zeros((xalltr.shape[0], 1)), np.ones((xalltr.shape[0], 1))
        tr0t, tr1t = np.zeros((xte.shape[0], 1)), np.ones((xte.shape[0], 1))
        qy_list = y_inf.distribution.sample(len(xalltr),seed=0)
        """
        f1 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr1}
        f0 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr0}
        f1t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr1t}
        f0t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr0t}
        """
        f1 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr1, y_ph: qy_list}
        f0 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr0, y_ph: qy_list}
        f1t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr1t, y_ph: qy_list}
        f0t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr0t, y_ph: qy_list}
        
        # loss = np.zeros(n_epoch*n_iter_per_epoch)
        # logpvalid_list = np.zeros(n_iter_per_epoch*n_epoch)
        
        for epoch in range(n_epoch):
            # avg_loss_list = np.zeros(n_iter_per_epoch)
            avg_loss = 0.0
            t0 = time.time()
            widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
            pbar = ProgressBar(n_iter_per_epoch, widgets=widgets)
            pbar.start()
            np.random.shuffle(idx)
            
            for j in range(n_iter_per_epoch):
              # info_dict = 0.0
              pbar.update(j)
              batch = np.random.choice(idx, 100)
              x_train, y_train, t_train = xtr[batch], ytr[batch], ttr[batch] 
              """
              sess.run(train_op,feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                                                        x_ph_cont: x_train[:, len(binfeats):],
                                                        t_ph: t_train, y_ph: y_train})
              """
              _, info_dict = sess.run([train_op, inference], feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                                                        x_ph_cont: x_train[:, len(binfeats):],
                                                        t_ph: t_train, y_ph: y_train}) 
              step = sess.run(global_step)
              avg_loss += info_dict
            
            avg_loss = avg_loss / n_iter_per_epoch
            # avg_loss = np.mean(avg_loss_list) / n_iter_per_epoch
            avg_loss = avg_loss / 100
            
            if epoch % args.earl == 0 or epoch == (n_epoch - 1):
                logpvalid = sess.run(logp_valid, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: tva, y_ph: yva})
                if logpvalid >= best_logpvalid:
                  print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_logpvalid, logpvalid))
                  best_logpvalid = logpvalid
                  saver.save(sess, "models/m6-ihdp")
            
            if epoch % args.print_every == 0:
                y0, y1 = get_y0_y1(sess, y_inf, f0, f1, shape=yalltr.shape, L=1)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                score_train = evaluator_train.calc_stats(y1, y0)
                rmses_train = evaluator_train.y_errors(y0, y1)
                
                y0, y1 = get_y0_y1(sess, y_inf, f0t, f1t, shape=yte.shape, L=1)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                score_test = evaluator_test.calc_stats(y1, y0)

                print("Epoch: {}/{}, log p(x) >= {:0.3f}, ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, " \
                      "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, ite_te: {:0.3f}, ate_te: {:0.3f}, pehe_te: {:0.3f}, " \
                      "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss, score_train[0], score_train[1], score_train[2],
                                           rmses_train[0], rmses_train[1], score_test[0], score_test[1], score_test[2],
                                           time.time() - t0))

        saver.restore(sess, "models/m6-ihdp")
       
        y0, y1 = get_y0_y1(sess, y_inf, f0, f1, shape=yalltr.shape, L=100)
        y0, y1 = y0 * ys + ym, y1 * ys + ym
        score = evaluator_train.calc_stats(y1, y0)
        scores[i, :] = score

        y0t, y1t = get_y0_y1(sess, y_inf, f0t, f1t, shape=yte.shape, L=100)
        y0t, y1t = y0t * ys + ym, y1t * ys + ym
        score_test = evaluator_test.calc_stats(y1t, y0t)
        scores_test[i, :] = score_test

        print('Replication: {}/{}, tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}' \
              ', te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}'.format(i + 1, args.reps,
                                                                            score[0], score[1], score[2],
                                                                            score_test[0], score_test[1], score_test[2]))
        sess.close()

print('CEVAE model total scores')
means, stds = np.mean(scores, axis=0), sem(scores, axis=0)
print('train ITE: {:.3f}+-{:.3f}, train ATE: {:.3f}+-{:.3f}, train PEHE: {:.3f}+-{:.3f}' \
      ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

means, stds = np.mean(scores_test, axis=0), sem(scores_test, axis=0)
print('test ITE: {:.3f}+-{:.3f}, test ATE: {:.3f}+-{:.3f}, test PEHE: {:.3f}+-{:.3f}' \
      ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))



Replication 1/1


epoch #0| 97%|#################################################  |ETA:  0:00:00

Improved validation bound, old: -inf, new: -84.743


TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles. For reference, the tensor object was Tensor("Normal_9/sample_1/Reshape:0", shape=(672, None, 1), dtype=float32) which was passed to the argument `feed_dict` with key Tensor("Placeholder_1:0", shape=(None, 1), dtype=float32).

In [53]:
print(type(qy_inf))

<class 'edward2.tensorflow.random_variable.RandomVariable'>


class EarlyStopping(object):
    def __init__(self, patience=0, verbose=0):
        self._step= 0
        self._loss=float('inf')
        self._patience=patience
        self.verbose=verbose

    def validate(self,loss):
        if self._loss < loss:
            self._step += 1
            if self._step > self._patience:
                if self.verbose:
                    print('early stopping')
                return True
        else:
            self.step = 0
            self.loss = loss
       
        return False


inpt2_inf = tf.concat([x_ph, qy], 1)
        hqz_inf = fc_net(inpt2_inf, (nh - 1) * [h], [], 'qz_xty_shared_inf', lamba=lamba, activation=activation)
        muq_t0_inf, sigmaq_t0_inf = fc_net(hqz_inf, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0_inf', lamba=lamba,
                                   activation=activation)
        muq_t1_inf, sigmaq_t1_inf = fc_net(hqz_inf, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1_inf', lamba=lamba,
                                   activation=activation)
        qz_inf = ed.Normal(loc=qt * muq_t1_inf + (1. - qt) * muq_t0_inf, scale=qt * sigmaq_t1_inf + (1. - qt) * sigmaq_t0_inf)
        
        hx_inf = fc_net(qz, (nh - 1) * [h], [], 'px_z_shared_inf', lamba=lamba, activation=activation)
        logits_inf = fc_net(hx_inf, [h], [[len(binfeats), None]], 'px_z_bin_inf'.format(i + 1), lamba=lamba, activation=activation)
        x1_inf = ed.Bernoulli(logits=logits_inf, dtype=tf.float32, name='bernoulli_px_z_inf')
        mu_inf, sigma_inf = fc_net(hx_inf, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont_inf', lamba=lamba,
                           activation=activation)
        x2_inf = ed.Normal(loc=mu_inf, scale=sigma_inf, name='gaussian_px_z_inf')
        logits_t_inf = fc_net(qz, [h], [[1, None]], 'pt_z_inf', lamba=lamba, activation=activation)
        t_inf = ed.Bernoulli(logits=logits_t_inf, dtype=tf.float32)
        mu2_t0_inf = fc_net(qz, nh * [h], [[1, None]], 'py_t0z_inf', lamba=lamba, activation=activation)
        mu2_t1_inf = fc_net(qz, nh * [h], [[1, None]], 'py_t1z_inf', lamba=lamba, activation=activation)
        y_inf = ed.Normal(loc=t_inf * mu2_t1_inf + (1. - t_inf) * mu2_t0_inf, scale=tf.ones_like(mu2_t0_inf))