In [1]:
import datetime
import numpy as np
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow_probability as tfp
import tensorflow.keras.layers as tfkl
tfd,tfpl = tfp.distributions,tfp.layers
import tensorflow.keras.backend as tfkb
from tensorflow.keras.callbacks import Callback
from sklearn.preprocessing import StandardScaler
from networks import p_x_z, p_t_z, p_y_tz, q_t_x, q_y_xt, q_z_xyt
from evaluation import Evaluator, pdist2sq, Full_Metrics
#################################IHDP Data
# data information 
t_bin_dim = 1
y_dim, default_y_scale = 1,tf.exp(0.)
M = None        # batch size during training
z_dim = 20          # latent z dimension
lamba = 1e-4    # weight decay
nh, h = 3, 200  # number and size of hidden layers
binfeats = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
numfeats = [i for i in range(25) if i not in binfeats]
x_bin_dim = len(binfeats)
x_num_dim = len(numfeats)

2022-03-04 18:29:57.764308: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.train.npz
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.test.npz 

def load_IHDP_data(training_data,testing_data,i=7):
    with open(training_data,'rb') as trf, open(testing_data,'rb') as tef:
        train_data=np.load(trf); test_data=np.load(tef)
        y=np.concatenate(   (train_data['yf'][:,i],   test_data['yf'][:,i])).astype('float32') #most GPUs only compute 32-bit floats
        t=np.concatenate(   (train_data['t'][:,i],    test_data['t'][:,i])).astype('float32')
        x=np.concatenate(   (train_data['x'][:,:,i],  test_data['x'][:,:,i]),axis=0).astype('float32')
        mu_0=np.concatenate((train_data['mu0'][:,i],  test_data['mu0'][:,i])).astype('float32')
        mu_1=np.concatenate((train_data['mu1'][:,i],  test_data['mu1'][:,i])).astype('float32')
        ycf=np.concatenate((train_data['ycf'][:,i],  test_data['ycf'][:,i])).astype('float32')

        data={'x':x,'t':t,'y':y,'t':t,'mu_0':mu_0,'mu_1':mu_1}
        data['t']=data['t'].reshape(-1,1) #we're just padding one dimensional vectors with an additional dimension 
        data['y']=data['y'].reshape(-1,1)
        data['ycf'] = ycf.reshape(-1,1)
        
        #rescaling y between 0 and 1 often makes training of DL regressors easier
        data['y_scaler'] = StandardScaler().fit(data['y'])
        data['ys'] = data['y_scaler'].transform(data['y'])

    return data

data=load_IHDP_data(training_data='./ihdp_npci_1-100.train.npz',testing_data='./ihdp_npci_1-100.test.npz')

文件 “ihdp_npci_1-100.train.npz” 已经存在；不获取。

文件 “ihdp_npci_1-100.test.npz” 已经存在；不获取。



In [3]:
#Colab command to allow us to run Colab in TF2
!rm -rf ./logs/ 
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
%reload_ext tensorboard 

In [8]:
class CEVAE(tf.keras.Model):
    def __init__(self):
        super(CEVAE, self).__init__()
        ########################################
        # networks
        self.activation = 'elu'
        # CEVAE Model (decoder)
        # p(z)
        self.z = tfd.Normal(loc = [0]*z_dim, scale = [1]*z_dim)
        # p(x|z)
        self.p_x_z = p_x_z(x_bin_dim, x_num_dim, z_dim, nh, h)
        # p(t|z)
        self.p_t_z = p_t_z(t_bin_dim, z_dim,h)
        # p(y|t,z)
        self.p_y_tz = p_y_tz(y_dim, t_bin_dim, z_dim,default_y_scale, nh, h )
        # CEVAE Model (encoder)
        # q(t|x)
        self.q_t_x = q_t_x(x_bin_dim, x_num_dim, t_bin_dim, z_dim, nh, h)
        # q(y|x,t)
        self.q_y_xt = q_y_xt(x_bin_dim, x_num_dim, y_dim, t_bin_dim, default_y_scale, nh, h)
        # q(z|x,t,y)
        self.q_z_xyt = q_z_xyt(x_bin_dim, x_num_dim, y_dim, t_bin_dim, z_dim, nh, h) 

    def call(self, data, training=False):
        if training:
            # when training need x,y,t
            x,y,t = data

            # Dataset_inp
            # CEVAE variational approximation (encoder)
            ## inferred distribution over z
            xyt_input = tf.concat([x, y, t], axis=-1)
            z_infer = self.q_z_xyt(xyt_input)
            z_infer_sample = z_infer.sample()
            ## q(t|x)
            qt = self.q_t_x(x)
            ## q(y|x,t)
            xt_inputs = tf.concat([x,qt],-1)
            qy = self.q_y_xt(xt_inputs)
            
        
            # CEVAE model (decoder)
            ## p(x|z)
            [x_bin,x_con] = self.p_x_z(z_infer_sample)
            ## p(t|z)
            t_infer = self.p_t_z(z_infer_sample)
            ## p(y|t,z)
            tz = tf.concat([t_infer,z_infer_sample],-1)
            y_infer = self.p_y_tz()
            output = [x_infer_bin,x_infer_num,t_infer,y_infer, qz,qt,qy]
            return output
        else:
            return data

    def get_y0_y1(self, x_train, t_train, L=1):
        y_infer = self.q_y_xt(tf.concat([x_train, t_train],-1))
        # use inferred y
        xyt = tf.concat([x_train, y_infer, t_train], -1)  # TODO take mean?
        z_infer = self.q_z_xyt(xyt)
        # Manually input zeros and ones
        y0 = self.p_y_tz(tf.concat([t_train,tf.zeros_like(z_infer)],-1)).mean()  # TODO take mean?
        y1 = self.p_y_tz(tf.concat([t_train,tf.ones_like(z_infer)],-1)).mean()  # TODO take mean?
        return tf.concat([y0,y1],-1)

    def cevae_loss(self, data, pred):
        # read labels
        x_bin_real,x_num_real = data[0][...,:x_bin_dim],data[0][...,x_bin_dim: (x_bin_dim+x_num_dim)]
        t_real, y_real = data[1],data[2]
        # get preds
        [x_bin,x_num,t,y,qz,qt,qy] = pred
        # Reconstruction loss
        ## p(x|z)
        loss = {}
        loss['loss_p_x_z_bin'] = tfkb.mean(x_bin.log_prob(x_bin_real))
        loss['loss_p_x_z_num'] = tfkb.mean(x_num.log_prob(x_num_real))
        ## p(t|z)
        loss['loss_p_t_z_bin'] = tfkb.mean(t.log_prob(t_real))
        ## p(y|t,z)
        loss['loss_p_y_tz'] = tfkb.mean(y.log_prob(y_real))
        # REGULARIZATION LOSS
        # AUXILIARY LOSS
        ## q(t|x)
        loss['aux_loss_q_t_x'] = tfkb.mean(qt.log_prob(t_real))
        ## q(y|x,t)
        loss['aux_loss_q_y_xt'] = tfkb.mean(qy.log_prob(y_real))
        loss_all = loss['loss_p_x_z_bin'] + loss['loss_p_x_z_num'] + loss['loss_p_t_z_bin'] + loss['loss_p_y_tz'] + loss['aux_loss_q_t_x'] + loss['aux_loss_q_y_xt']
        return -loss_all

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        # 这里data[0]因为会自动在外面拼接一层
        data = data[0]
        x,t,y = data
        with tf.GradientTape() as tape:
            pred = self(data, training=True)  # Forward pass
            loss = self.cevae_loss(data,pred)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        metrics = {"loss": loss}
        return metrics

    # def test_step(self, data):
    #     # Unpack the data. Its structure depends on your model and
    #     # on what you pass to `fit()`.
    #     x,_ = data
    #     with tf.GradientTape() as tape:
    #         pred = self(x, training=True)  # Forward pass
    #         loss = self.cevae_loss(data,pred)
    #     metrics = {"loss": loss}
    #     return metrics


In [None]:
model = CEVAE()
### MAIN CODE ####
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.optimizers import SGD
from evaluation import *
 
val_split=0.2
batch_size=64
verbose=True
i = 0
tf.random.set_seed(i)
np.random.seed(i)
yt = np.concatenate([data['ys'], data['t']], 1)
 
sgd_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='val_loss', patience=40, min_delta=0), 
        #40 is Shi's recommendation patience for this dataset, but you should tune for your data 
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                          min_delta=0, cooldown=0, min_lr=0),
        #This learning rate scheduling is quite agressive which seems good for this dataset
        # Full_Metrics(data,verbose),
        # metrics_for_cevae(data,verbose),
        tensorboard_callback
    ]
    
#optimizer hyperparameters
sgd_lr = 1e-5
momentum = 0.9
model.compile(
    optimizer=SGD(
        learning_rate=sgd_lr, 
        momentum=momentum, 
        nesterov=True
        )
    )

model.fit(
    [data['x'],data['t'],data['y']],
    callbacks=sgd_callbacks,
    validation_split=val_split,
    epochs=300,
    batch_size=batch_size,
    verbose=verbose
    )
print("Done!")

Epoch 1/300
success
success
([[0.208507463 -0.202945948 1.12855399 ... 1 0 0]
 [0.759718955 0.996346235 -0.733261 ... 0 0 0]
 [-1.67209649 -1.80200219 2.2456429 ... 0 0 1]
 ...
 [0.748910904 0.596582174 -0.360898 ... 0 0 0]
 [-1.60724807 -1.40223813 0.756190956 ... 0 0 0]
 [-0.677754164 -1.00247407 0.0114649916 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[3.48064399]
 [8.69245815]
 [2.48917842]
 ...
 [1.55222762]
 [-1.15384924]
 [7.43073893]])
 1/10 [==>...........................] - ETA: 31s - loss: 47.4693([[-0.872299433 -0.60271 0.756190956 ... 0 0 0]
 [-0.288663685 0.196818128 -0.733261 ... 0 0 1]
 [0.359820426 -0.202945948 -0.733261 ... 0 0 0]
 ...
 [1.48385954 1.79587436 -1.10562396 ... 0 0 0]
 [1.24608207 1.79587436 -0.733261 ... 0 1 0]
 [0.122042917 -0.202945948 0.0114649916 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[4.46477032]
 [10.8867645]
 [7.01332903]
 ...
 [2.33838511]
 [6.2661972]
 [4.20486832]])
([[1.15961754 0.596582174 -0.733261 ... 0 1 0]
 [-1.5640158

 [0.424668849 0.596582174 -1.47798693 ... 0 0 0]
 [0.748910904 0.596582174 -0.733261 ... 0 1 0]
 ...
 [1.13800144 0.996346235 0.756190956 ... 0 0 0]
 [0.208507463 -0.202945948 1.12855399 ... 1 0 0]
 [0.813759327 0.996346235 -0.733261 ... 0 0 1]], [[0]
 [0]
 [0]
 ...
 [1]
 [0]
 [0]], [[3.45381379]
 [6.42257643]
 [3.3428185]
 ...
 [8.874547]
 [3.48064399]
 [8.97282124]])
([[0.467901111 0.996346235 0.756190956 ... 0 0 0]
 [-1.45593512 -1.40223813 2.2456429 ... 0 0 0]
 [0.370628506 0.596582174 0.0114649916 ... 0 0 0]
 ...
 [-0.699370325 -0.60271 -0.360898 ... 0 0 0]
 [-0.439976662 0.196818128 0.0114649916 ... 0 0 0]
 [-1.17492533 -1.28056991 0.756190956 ... 0 0 1]], [[0]
 [1]
 [0]
 ...
 [0]
 [0]
 [0]], [[4.03279829]
 [6.49890327]
 [3.24240303]
 ...
 [4.25085831]
 [8.03814888]
 [6.35638714]])
([[-0.310279846 -0.202945948 -0.360898 ... 0 1 0]
 [-1.97472239 -2.20176625 2.61800575 ... 0 0 0]
 [0.640830219 -0.202945948 -1.47798693 ... 0 0 0]
 ...
 [-0.245431423 -0.202945948 0.0114649916 ... 1 0

 [0.208507463 0.196818128 0.0114649916 ... 0 0 1]
 [0.792143166 0.596582174 -0.360898 ... 0 0 0]
 ...
 [-0.375128239 0.196818128 -0.360898 ... 0 0 1]
 [1.13800144 0.996346235 0.756190956 ... 0 0 0]
 [0.770527065 0.596582174 0.383827984 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [1]
 [1]], [[6.21752501]
 [1.2875129]
 [7.77413464]
 ...
 [5.0234046]
 [8.874547]
 [8.41891289]])
([[-1.21815765 -1.80200219 2.2456429 ... 0 0 0]
 [-0.699370325 0.196818128 0.756190956 ... 0 0 0]
 [1.48385954 0.996346235 -0.360898 ... 0 0 0]
 ...
 [-1.02361238 -0.202945948 1.12855399 ... 0 0 0]
 [-0.0725023225 -0.202945948 -0.360898 ... 0 0 0]
 [0.25173974 1.3961103 -0.360898 ... 0 0 1]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[3.58761024]
 [5.72716045]
 [2.38152099]
 ...
 [1.12154269]
 [4.61986494]
 [3.1196363]])
([[0.165275186 0.196818128 -1.10562396 ... 0 0 0]
 [0.338204294 0.196818128 -0.360898 ... 0 1 0]
 [0.813759327 0.996346235 -0.360898 ... 0 0 0]
 ...
 [0.403052717 0.196818128 0.0114649916 ... 0 0 0]
 [0.921

([[-2.2989645 -0.663642347 2.2456429 ... 0 0 0]
 [0.359820426 0.196818128 -0.360898 ... 0 0 0]
 [0.802951276 0.996346235 -1.47798693 ... 0 0 0]
 ...
 [1.18123364 0.196818128 -1.47798693 ... 0 0 0]
 [-0.569673479 0.196818128 -1.10562396 ... 0 0 1]
 [0.92184 1.3961103 0.0114649916 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [1]], [[3.08380675]
 [4.81772947]
 [5.59601688]
 ...
 [1.78568709]
 [9.91706181]
 [8.28565311]])
([[0.511133373 -0.60271 -0.360898 ... 1 0 0]
 [0.87860775 0.996346235 -0.733261 ... 0 0 0]
 [1.28931439 1.79587436 -1.47798693 ... 0 0 0]
 ...
 [0.208507463 0.196818128 -1.10562396 ... 0 0 1]
 [0.965072274 1.79587436 -0.360898 ... 0 0 0]
 [-0.548057318 -1.00247407 0.0114649916 ... 0 0 1]], [[0]
 [1]
 [0]
 ...
 [0]
 [0]
 [1]], [[6.86179209]
 [8.10634327]
 [1.35772169]
 ...
 [4.12323475]
 [3.9785552]
 [9.915802]])
 [0.424668849 0.596582174 -1.10562396 ... 0 0 0]
 [-0.158966869 -0.60271 -0.360898 ... 0 0 0]
 ...
 [1.07315302 0.596582174 -0.733261 ... 0 0 0]
 [0.338204294 0.19

([[-0.137350738 -0.202945948 -1.10562396 ... 1 0 0]
 [1.07315302 0.596582174 -0.733261 ... 0 0 0]
 [0.381436557 0.196818128 -0.360898 ... 0 0 0]
 ...
 [-1.58563197 -1.40223813 1.12855399 ... 0 0 0]
 [0.87860775 0.996346235 -0.733261 ... 0 0 0]
 [1.46224344 0.996346235 0.0114649916 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [1]
 [1]], [[5.43475819]
 [3.16475439]
 [8.10910892]
 ...
 [3.26952362]
 [8.10634327]
 [7.29342413]])
([[-0.39674437 -0.60271 0.383827984 ... 0 0 0]
 [-0.267047554 0.196818128 0.383827984 ... 0 0 0]
 [1.28931439 0.596582174 0.383827984 ... 0 0 0]
 ...
 [1.20284975 0.596582174 -1.10562396 ... 0 0 1]
 [1.44062734 0.196818128 -1.10562396 ... 0 0 0]
 [1.13800144 1.3961103 -1.47798693 ... 0 0 0]], [[0]
 [1]
 [0]
 ...
 [1]
 [1]
 [0]], [[1.87898326]
 [7.9127593]
 [-0.179098472]
 ...
 [7.49862719]
 [7.03739119]
 [4.91432953]])
 [0.813759327 -0.37386933 -0.733261 ... 0 1 0]
 [-0.569673479 0.196818128 -0.360898 ... 0 0 1]
 ...
 [-0.180583 0.596582174 -0.360898 ... 0 0 1]
 [0.42466

 1/10 [==>...........................] - ETA: 0s - loss: 46.5159([[0.145820662 0.596582174 -1.10562396 ... 0 0 0]
 [-0.223815277 0.196818128 -1.47798693 ... 0 0 0]
 [0.316588163 0.196818128 -0.360898 ... 0 0 0]
 ...
 [-0.656138062 -1.00247407 -0.360898 ... 1 0 0]
 [-0.180583 0.196818128 -0.360898 ... 0 0 0]
 [1.18123364 -0.60271 -0.360898 ... 0 0 0]], [[1]
 [0]
 [0]
 ...
 [0]
 [1]
 [0]], [[8.78146744]
 [4.5718627]
 [3.18248415]
 ...
 [4.86824799]
 [7.74038124]
 [3.33471823]])
([[-0.504825056 -0.378263652 -0.360898 ... 1 0 0]
 [0.835375428 1.3961103 0.0114649916 ... 0 0 0]
 [-0.331895977 0.596582174 -0.360898 ... 0 0 0]
 ...
 [-1.28300607 -1.00247407 1.12855399 ... 0 0 0]
 [1.48385954 0.996346235 -0.360898 ... 0 0 0]
 [0.792143166 0.596582174 -0.360898 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [1]
 [0]
 [0]], [[3.72619772]
 [2.92256117]
 [4.60567665]
 ...
 [7.39705229]
 [2.38152099]
 [7.77413464]])
([[0.424668849 0.996346235 0.0114649916 ... 0 0 0]
 [0.359820426 0.196818128 -0.360898 ... 0 0 0]

Epoch 12/300
([[0.424668849 0.596582174 -1.10562396 ... 0 0 0]
 [1.39739501 0.996346235 -1.10562396 ... 0 0 1]
 [-0.764218748 -1.00247407 0.756190956 ... 0 0 1]
 ...
 [-0.742602587 -1.00247407 1.12855399 ... 0 1 0]
 [1.20284975 0.996346235 -1.47798693 ... 1 0 0]
 [-0.0292700436 -0.60271 0.0114649916 ... 0 0 0]], [[0]
 [1]
 [0]
 ...
 [0]
 [0]
 [0]], [[11.454339]
 [8.35765934]
 [3.79539299]
 ...
 [7.36864376]
 [4.13379097]
 [3.26158786]])
 1/10 [==>...........................] - ETA: 0s - loss: 44.1738([[0.532749534 0.596582174 -1.47798693 ... 0 1 0]
 [0.489517272 0.996346235 -0.733261 ... 0 0 0]
 [-0.158966869 -0.60271 -0.360898 ... 0 0 0]
 ...
 [0.294972032 -0.202945948 -0.360898 ... 0 0 1]
 [0.467901111 0.996346235 0.756190956 ... 0 0 0]
 [0.87860775 0.596582174 -0.360898 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[11.8775101]
 [4.46803474]
 [7.74218464]
 ...
 [5.60639811]
 [4.03279829]
 [2.97440529]])
([[-1.80179334 -2.20176625 1.87327993 ... 0 0 0]
 [0.230123609 -0.202945948

 [0.792143166 1.3961103 -0.360898 ... 0 0 0]
 [-0.742602587 -0.978717 1.12855399 ... 0 0 0]
 ...
 [0.705678642 0.596582174 -1.10562396 ... 1 0 0]
 [-1.28300607 -1.00247407 1.12855399 ... 0 0 0]
 [0.208507463 0.196818128 0.0114649916 ... 0 0 1]], [[1]
 [0]
 [0]
 ...
 [0]
 [1]
 [0]], [[7.27787971]
 [4.24617147]
 [5.00886679]
 ...
 [7.1778059]
 [7.39705229]
 [1.2875129]])
Epoch 14/300
([[-0.267047554 -0.60271 0.383827984 ... 0 1 0]
 [0.424668849 0.996346235 0.383827984 ... 0 0 0]
 [-1.52078354 -0.202945948 0.383827984 ... 0 0 0]
 ...
 [0.359820426 0.596582174 -1.10562396 ... 0 0 0]
 [-0.353512108 -0.202945948 -0.360898 ... 0 0 0]
 [-1.95310628 -2.60153031 2.61800575 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [1]
 [0]
 [0]], [[1.666026]
 [2.6918993]
 [4.29586077]
 ...
 [7.25801]
 [6.02166462]
 [2.95969939]])
 1/10 [==>...........................] - ETA: 0s - loss: 39.1765([[0.66244638 0.996346235 0.383827984 ... 0 1 0]
 [0.25173974 1.3961103 -0.360898 ... 0 0 1]
 [0.770527065 0.596582174 -0.360898 

([[0.489517272 -0.60271 0.0114649916 ... 0 1 0]
 [-1.28300607 -1.00247407 1.12855399 ... 0 0 0]
 [0.186891332 0.596582174 0.383827984 ... 0 0 0]
 ...
 [-0.504825056 -0.378263652 -0.360898 ... 1 0 0]
 [-1.47755122 -1.58296883 1.50091696 ... 0 0 1]
 [-0.548057318 -0.202945948 -0.360898 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [1]
 [1]], [[7.10239124]
 [2.09498644]
 [4.51536369]
 ...
 [3.72619772]
 [9.67978668]
 [6.18218088]])
([[0.511133373 -0.202945948 -0.733261 ... 0 0 1]
 [-0.288663685 0.196818128 0.0114649916 ... 0 0 0]
 [1.22446597 0.196818128 -1.47798693 ... 0 0 1]
 ...
 [-1.56401587 -1.80200219 1.50091696 ... 0 0 0]
 [-1.04522848 -1.00247407 0.756190956 ... 0 0 0]
 [-0.00765390694 0.196818128 -1.47798693 ... 0 0 0]], [[0]
 [0]
 [1]
 ...
 [0]
 [0]
 [0]], [[6.19283056]
 [3.1262033]
 [7.40529203]
 ...
 [7.28419256]
 [4.48962498]
 [-1.70303845]])
([[0.87860775 0.596582174 -0.360898 ... 0 0 0]
 [1.15961754 0.996346235 -0.360898 ... 0 0 0]
 [-0.39674437 -0.60271 0.0114649916 ... 0 0 0]
 .

([[0.835375428 1.3961103 0.0114649916 ... 0 0 0]
 [-1.44296539 -1.80200219 1.12855399 ... 0 0 1]
 [-1.0884608 -0.60271 0.383827984 ... 0 0 0]
 ...
 [0.208507463 -0.202945948 1.12855399 ... 1 0 0]
 [0.0788106397 -0.202945948 0.756190956 ... 0 0 0]
 [-0.0725023225 0.596582174 -0.733261 ... 0 0 0]], [[0]
 [1]
 [0]
 ...
 [0]
 [0]
 [0]], [[2.92256117]
 [7.52597904]
 [6.15846348]
 ...
 [3.48064399]
 [4.93399858]
 [5.470644]])
 [-2.16926765 -2.60153031 2.61800575 ... 0 0 0]
 [0.316588163 0.596582174 0.383827984 ... 0 0 0]
 ...
 [-1.95310628 -2.60153031 2.61800575 ... 0 0 0]
 [0.965072274 1.3961103 0.0114649916 ... 0 0 0]
 [-0.569673479 -0.202945948 -0.360898 ... 0 0 0]], [[0]
 [0]
 [1]
 ...
 [0]
 [1]
 [0]], [[2.9724381]
 [4.8345089]
 [9.49594593]
 ...
 [2.95969939]
 [7.03447676]
 [3.34323]])
([[0.381436557 -0.202945948 -1.10562396 ... 1 0 0]
 [1.31093049 0.596582174 -0.360898 ... 0 0 0]
 [1.07315302 0.596582174 -0.733261 ... 0 0 0]
 ...
 [1.0515368 1.79587436 -1.10562396 ... 0 0 1]
 [-1.45593

([[-0.850683272 -0.60271 0.756190956 ... 0 0 0]
 [-0.528602839 -0.34345451 1.12855399 ... 0 0 0]
 [-0.742602587 -0.60271 0.756190956 ... 0 0 0]
 ...
 [-1.71532881 -1.40223813 1.87327993 ... 0 0 0]
 [-1.73694491 -1.80277169 2.2456429 ... 0 0 0]
 [0.597597957 0.996346235 -0.360898 ... 0 0 0]], [[0]
 [1]
 [0]
 ...
 [0]
 [0]
 [0]], [[3.26035333]
 [8.38237095]
 [3.79391146]
 ...
 [7.98158264]
 [5.06827402]
 [3.42344379]])
 [-0.310279846 0.196818128 0.756190956 ... 0 0 0]
 [0.230123609 0.996346235 0.0114649916 ... 0 0 0]
 ...
 [-2.12603545 -1.40223813 1.12855399 ... 0 0 0]
 [0.100426778 0.596582174 0.0114649916 ... 0 0 1]
 [1.48385954 1.3961103 -0.360898 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [1]], [[7.19330645]
 [3.54221869]
 [-1.87661171]
 ...
 [5.97124147]
 [1.80595028]
 [8.02861881]])
([[-0.245431423 -0.202945948 0.0114649916 ... 1 0 0]
 [-2.08280301 -2.20176625 1.50091696 ... 1 0 0]
 [0.230123609 0.596582174 -0.733261 ... 1 0 0]
 ...
 [-0.158966869 -3.8008225 0.383827984 ... 0 0 1]

([[-0.504825056 -0.378263652 -0.360898 ... 1 0 0]
 [1.37577891 1.3961103 -0.733261 ... 0 0 0]
 [-0.699370325 -0.202945948 0.0114649916 ... 0 1 0]
 ...
 [-2.45027757 -2.60153031 2.61800575 ... 0 0 0]
 [0.0788106397 0.196818128 -0.733261 ... 0 0 0]
 [0.66244638 0.996346235 0.383827984 ... 0 1 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[3.72619772]
 [3.15155458]
 [6.56607485]
 ...
 [2.44576168]
 [4.4059186]
 [1.92785335]])
([[0.467901111 0.196818128 0.383827984 ... 0 0 0]
 [-1.15330923 -1.40223813 1.12855399 ... 0 0 0]
 [-1.03442049 0.196818128 1.12855399 ... 0 0 0]
 ...
 [0.100426778 -0.0771776736 0.0114649916 ... 0 0 1]
 [0.66244638 0.996346235 -0.733261 ... 0 1 0]
 [0.0788106397 0.196818128 1.12855399 ... 0 0 0]], [[1]
 [0]
 [0]
 ...
 [1]
 [0]
 [0]], [[9.32467556]
 [3.17840075]
 [4.06758451]
 ...
 [7.3092351]
 [8.69907856]
 [2.28223896]])
([[1.06450653 0.996346235 0.0114649916 ... 0 0 0]
 [-0.569673479 0.196818128 -0.360898 ... 0 0 1]
 [1.13800144 0.996346235 -1.10562396 ... 0 1 0]
 ..

 1/10 [==>...........................] - ETA: 0s - loss: 39.2201([[-1.15330923 -1.40223813 1.12855399 ... 0 0 0]
 [0.792143166 1.3961103 -0.360898 ... 0 0 0]
 [-1.04522848 -1.3372761 1.12855399 ... 0 0 0]
 ...
 [0.856991589 1.06491697 -1.10562396 ... 0 0 0]
 [0.467901111 0.196818128 0.383827984 ... 0 1 0]
 [-0.677754164 -1.00247407 0.0114649916 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [1]
 [0]], [[3.17840075]
 [4.24617147]
 [4.40149546]
 ...
 [2.52382302]
 [8.13000298]
 [7.43073893]])
([[-1.95310628 -1.80200219 2.2456429 ... 1 0 0]
 [-1.71532881 -1.00247407 -0.733261 ... 0 0 0]
 [-1.47755122 -0.60271 1.87327993 ... 0 0 1]
 ...
 [0.943456173 0.996346235 0.383827984 ... 0 0 0]
 [-1.93149018 2.59540248 0.756190956 ... 0 0 0]
 [1.46224344 0.996346235 0.0114649916 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [1]
 [0]
 [1]], [[4.1080761]
 [9.93134212]
 [3.00748587]
 ...
 [7.76731443]
 [7.16189766]
 [7.29342413]])
([[0.813759327 0.601994157 -1.10562396 ... 0 0 0]
 [-0.39674437 -0.60271 0.0114649916 ... 0 

Epoch 25/300
([[-0.158966869 0.596582174 -1.47798693 ... 0 0 0]
 [0.208507463 0.596582174 -0.360898 ... 0 0 0]
 [-0.288663685 0.196818128 -0.733261 ... 0 0 1]
 ...
 [-0.277855635 0.596582174 0.383827984 ... 0 0 0]
 [1.3433547 1.79587436 -1.10562396 ... 0 0 0]
 [0.965072274 0.196818128 0.0114649916 ... 0 0 1]], [[0]
 [0]
 [0]
 ...
 [0]
 [1]
 [0]], [[9.09922886]
 [5.14177036]
 [10.8867645]
 ...
 [4.12269402]
 [7.48435926]
 [4.25447369]])
 1/10 [==>...........................] - ETA: 0s - loss: 33.9000([[0.165275186 0.196818128 0.383827984 ... 0 0 0]
 [-1.86664176 -1.80200219 2.61800575 ... 0 0 0]
 [-1.17492533 -1.40223813 1.50091696 ... 0 0 0]
 ...
 [-0.00765390694 -0.202945948 0.756190956 ... 0 0 0]
 [1.15961754 0.996346235 -1.47798693 ... 0 1 0]
 [-0.353512108 -0.202945948 -0.360898 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[4.66521788]
 [1.94699252]
 [3.13030148]
 ...
 [5.47806168]
 [3.63773084]
 [6.02166462]])
([[1.0515368 0.596582174 -0.733261 ... 0 0 0]
 [0.403052717 0.596

([[0.316588163 0.196818128 -0.360898 ... 0 0 0]
 [0.338204294 0.596582174 -0.733261 ... 0 0 0]
 [0.165275186 0.196818128 0.383827984 ... 0 0 0]
 ...
 [-1.30462217 -1.40223813 1.12855399 ... 0 1 0]
 [0.186891332 0.596582174 0.383827984 ... 0 0 0]
 [1.22446597 0.996346235 -0.733261 ... 1 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[1.45547557]
 [3.69027686]
 [4.66521788]
 ...
 [8.77923679]
 [4.51536369]
 [4.25984478]])
Epoch 27/300
([[1.07315302 0.196818128 -0.360898 ... 0 0 1]
 [0.208507463 -0.202945948 1.12855399 ... 1 0 0]
 [-2.32058072 -3.00129437 1.87327993 ... 0 0 0]
 ...
 [0.87860775 0.996346235 -0.733261 ... 0 0 0]
 [-0.418360531 -0.60271 0.0114649916 ... 0 1 0]
 [0.186891332 0.596582174 0.383827984 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [1]
 [0]
 [0]], [[4.61736584]
 [3.48064399]
 [4.99427843]
 ...
 [8.10634327]
 [3.65822411]
 [4.51536369]])
 1/10 [==>...........................] - ETA: 0s - loss: 31.0987([[-2.08280301 -1.80200219 1.12855399 ... 0 0 0]
 [-0.980380118 -1.00247407 1.5

([[0.835375428 -0.139101893 0.383827984 ... 1 0 0]
 [1.13800144 0.196818128 0.0114649916 ... 1 0 0]
 [0.748910904 0.996346235 0.0114649916 ... 0 0 0]
 ...
 [-0.223815277 -0.60271 -0.733261 ... 1 0 0]
 [0.965072274 0.196818128 -0.733261 ... 0 1 0]
 [-0.0508861803 0.196818128 -0.360898 ... 0 0 1]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[0.777918279]
 [1.21202612]
 [2.25538254]
 ...
 [5.67937088]
 [5.89904881]
 [2.78665566]])
 [-1.00199628 -0.60271 1.12855399 ... 0 0 0]
 [0.813759327 0.596582174 0.0114649916 ... 0 1 0]
 ...
 [0.727294743 0.596582174 -0.360898 ... 0 0 1]
 [-0.0292700436 -0.60271 0.0114649916 ... 0 0 0]
 [1.33254659 1.3961103 -1.47798693 ... 0 0 0]], [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[1.80595028]
 [5.13891411]
 [4.85384178]
 ...
 [3.61776757]
 [3.26158786]
 [5.62686062]])
([[-1.56401587 -1.00247407 1.87327993 ... 0 0 0]
 [0.338204294 0.196818128 -0.733261 ... 0 1 0]
 [-0.0379165 0.596582174 0.383827984 ... 0 0 0]
 ...
 [-0.461592793 0.596582174 0.0114649916 ... 0 0 0]
 

([[-0.288663685 0.196818128 -0.360898 ... 0 0 1]
 [-0.267047554 0.196818128 0.383827984 ... 0 0 0]
 [0.770527065 0.596582174 0.383827984 ... 0 0 0]
 ...
 [-1.67209649 -1.80200219 2.2456429 ... 0 0 1]
 [-0.158966869 -3.8008225 0.383827984 ... 0 0 1]
 [0.0139622306 -0.60271 0.0114649916 ... 0 0 1]], [[1]
 [1]
 [1]
 ...
 [0]
 [0]
 [1]], [[7.51345396]
 [7.9127593]
 [8.41891289]
 ...
 [2.48917842]
 [2.69325566]
 [5.66872311]])
([[0.684062481 -0.0983648449 0.0114649916 ... 0 0 0]
 [0.597597957 0.596582174 -0.360898 ... 1 0 0]
 [-0.00765390694 0.196818128 -1.47798693 ... 0 0 0]
 ...
 [-0.0725023225 0.596582174 -0.733261 ... 0 0 0]
 [0.0139622306 0.196818128 0.0114649916 ... 0 0 0]
 [-0.00765390694 -0.202945948 -1.10562396 ... 0 1 0]], [[1]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]], [[7.80346632]
 [5.37603807]
 [-1.70303845]
 ...
 [5.470644]
 [4.51159811]
 [5.82504511]])
 [-1.86664176 -1.40223813 -0.360898 ... 0 0 0]
 [0.727294743 0.596582174 -0.360898 ... 0 0 0]
 ...
 [-0.202199146 -0.202945948 -0.73326

In [10]:
# %tensorboard --logdir logs/fit

In [7]:
def get_y0_y1(self, x_train, t_train, L=1):
    y_infer = self.q_y_xt(tf.concat([x_train, t_train],-1))
    print(y_infer)
    # use inferred y
    xyt = tf.concat([x_train, y_infer, t_train], -1)  # TODO take mean?
    z_infer = self.q_z_xyt(xyt)
    # Manually input zeros and ones
    y0 = self.p_y_tz(tf.concat([t_train,tf.zeros_like(z_infer)],-1)).mean()  # TODO take mean?
    y1 = self.p_y_tz(tf.concat([t_train,tf.ones_like(z_infer)],-1)).mean()  # TODO take mean?
    return y0,y1

x_test = np.loadtxt("datasets/x.txt")
t_test = np.loadtxt("datasets/t.txt")
get_y0_y1(
    model,
    tf.cast(x_test,tf.float32), 
    tf.cast(t_test.reshape(-1,1),tf.float32)
)

tfp.distributions._TensorCoercible("tensor_coercible", batch_shape=[672, 1], event_shape=[], dtype=float32)


(<tf.Tensor: shape=(672, 1), dtype=float32, numpy=
 array([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.

In [8]:
[key for key in data]

['x', 't', 'y', 'mu_0', 'mu_1', 'ycf', 'y_scaler', 'ys']

In [9]:
tfkl.Dense(tfpl.IndependentNormal.params_size(1))

<keras.layers.core.dense.Dense at 0x1445e76d0>