# To do:
## - adaptive prior

In [1]:
from __future__ import absolute_import

import shutil
from tqdm import tqdm
from datetime import datetime

from six import iteritems

import os
from os.path import join as path_join

import numpy as np
from numpy import argmin, savez, asscalar, repeat, prod
from numpy import save as save_array
from numpy import pi as pi_const

from scipy.stats import norm as standard_gaussian

from keras.backend import clear_session, shape, random_normal, sqrt
from keras.models import Model, Input, load_model
from keras.layers import Layer, Lambda, Dense, Multiply, Add, Concatenate, Dot, Reshape
from keras.layers.advanced_activations import PReLU
from keras.losses import binary_crossentropy
from keras.optimizers import Adam

import keras.backend as kernel
from keras.backend.tensorflow_backend import set_session

import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors

sess = tf.Session()
set_session(sess)

seed = 1234
np.random.seed(seed)

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# generating data for experiment

sampling_size = 10000
test_size = 100

data_dim = 3
regressor_dim = 2
latent_dim = 2

latent_mean = 0 * np.ones(latent_dim)
latent_cov = 1 * np.eye(latent_dim)
latent_vars = np.random.multivariate_normal(mean=latent_mean, cov=latent_cov, size=sampling_size)

regressor_mean = 5 * np.ones(regressor_dim)
regressor_cov = 2 * np.eye(regressor_dim)
regressor_vars = np.random.multivariate_normal(mean=regressor_mean, cov=regressor_cov, size=sampling_size)
squared_regressors = np.square(regressor_vars)

mults_over_squared_regressor = np.array([[1,0,1],
                                         [1,1,0]])
mults_over_latent = np.array([[0,1,1],
                             [1,0,1]])

noise = np.random.multivariate_normal(mean=np.zeros(data_dim), cov=np.eye(data_dim), size=sampling_size)

data_y = np.dot(squared_regressors, mults_over_squared_regressor) + np.dot(latent_vars, mults_over_latent) + noise
data = {'X': regressor_vars, 'y': data_y}

train_X, train_y = data['X'][:sampling_size-test_size], data['y'][:sampling_size-test_size]
test_X, test_y = data['X'][sampling_size-test_size:], data['y'][sampling_size-test_size:]

### Adversarial Variational Bayes

In [32]:
# parameters of the model
model_name = 'avb'

X_dim = data['X'].shape[1]
y_dim = data['y'].shape[1]

latent_dim = 2
noise_dim = y_dim

# training params
net_depth = 2
net_width = 250

batch_size = 50
epochs = 10

schedule = {'iter_discr': 1, 'iter_encdec': 1}
optimiser_params = {'encdec': {'lr': 0.001}, 'disc': {'lr': 0.001}}

# preparation of output folders
temp_dir = os.path.join('output', 'temp')
experiment_dir = os.path.join('output', model_name)
models_dir = os.path.join(experiment_dir, 'models')
if not os.path.exists(experiment_dir):
    os.makedirs(experiment_dir)
if os.path.exists(temp_dir):
    shutil.rmtree(temp_dir)
os.makedirs(temp_dir)

In [33]:
def data_prior_sampler(inputs, noise_dim, **kwargs):
    # standard normal
    seed = kwargs.get('seed')
    samples_isotropic = random_normal(shape=(shape(inputs)[0], noise_dim), mean=0, stddev=1, seed=seed)
    return samples_isotropic


def latent_prior_sampler(inputs, latent_dim, **kwargs):
    # standard normal
    seed = kwargs.get('seed')
    samples_isotropic = random_normal(shape=(shape(inputs)[0], latent_dim), mean=0, stddev=1, seed=seed)
    return samples_isotropic
    
    
def normal_log_probs(args):
    mu, std, x = args 
    MultiNorm = tf.contrib.distributions.MultivariateNormalDiag(loc=mu, scale_diag=std, name='dec_normal')
    log_px = MultiNorm.log_prob(x, name='dec_normal_log_px')
    
    return log_px


class AVBDataIterator(object):
    def __init__(self, X_dim, y_dim, latent_dim):
        self.X_dim = X_dim
        self.y_dim = y_dim
        self.latent_dim = latent_dim
    
    def iter(self, var1, var2, **kwargs):
        mode = kwargs.get('mode')
        batch_size = kwargs.get('batch_size')
        
        var1_size, var2_size = var1.shape[0], var2.shape[0]
        if (var1_size != var2_size):
            raise AttributeError("Data inputs are not same size!")
            
        n_batches = var1_size / batch_size
        if n_batches - int(n_batches) > 0:
            raise AttributeError("Data input should be divisible by batch size!")
        
        iterator = getattr(self, 'iter_data_{}'.format(mode))
        return iterator(var1, var2, int(n_batches), **kwargs), int(n_batches)
    
    def iter_data_training(self, X, y, n_batches, **kwargs):
        shuffle = kwargs.get('shuffle', True)
        data_size = X.shape[0]
        while True:
            indices_new_order = np.arange(data_size)
            if shuffle:
                np.random.shuffle(indices_new_order)
            
            for batch_indices in np.split(indices_new_order, n_batches):
                yield [X[batch_indices].astype(np.float32), y[batch_indices].astype(np.float32)]

    def iter_data_inference(self, X, y, n_batches, **kwargs):
        data_size = X.shape[0]
        while True:
            for batch_indices in np.split(np.arange(data_size), n_batches):
                yield [X[batch_indices].astype(np.float32), y[batch_indices].astype(np.float32)]

    def iter_data_generation(self, X, latent_samples, n_batches, **kwargs):
        data_size = X.shape[0]
        while True:
            for batch_indices in np.split(np.arange(data_size), n_batches):
                yield [X[batch_indices].astype(np.float32), latent_samples[batch_indices].astype(np.float32)]

                
class FreezableModel(Model):
    def __init__(self, inputs, outputs, name='freezable'):
        super(FreezableModel, self).__init__(inputs=inputs, outputs=outputs, name=name)
        self._trainable_layers = None

    def _crawl_trainable_layers(self, freezable_layers_prefix, deep_freeze=True):
        if not deep_freeze:
            trainable_layers = [layer for layer in self.layers if layer.trainable]
            
        else:
            def recursive_model_crawl(current_layer):
                deeper_layers = []
                if isinstance(current_layer, Model):
                    for l in current_layer.layers:
                        if l.trainable:
                            deeper_layers += recursive_model_crawl(l)
                if current_layer.trainable and (current_layer.name.split('_')[0] in freezable_layers_prefix):
                    deeper_layers.append(current_layer)
                return deeper_layers
            trainable_layers = sum([recursive_model_crawl(layer) for layer in self.layers if layer.trainable], [])

        return trainable_layers

    def get_trainable_layers(self, freezable_layers_prefix=None, deep_crawl=True):
        if self._trainable_layers is None:
            self._trainable_layers = self._crawl_trainable_layers(freezable_layers_prefix, deep_crawl)
            return self._trainable_layers
        else:
            return self._trainable_layers

    def freeze(self, freezable_layers_prefix=None, deep_freeze=True):
        if self._trainable_layers is None:
            self._trainable_layers = self._crawl_trainable_layers(freezable_layers_prefix, deep_freeze)
            
        for layer in self._trainable_layers:
            layer.trainable = False

    def unfreeze(self, unfreezable_layers_prefix=None, deep_unfreeze=True):
        if self._trainable_layers is None:
            self._trainable_layers = self._crawl_trainable_layers(unfreezable_layers_prefix, deep_unfreeze)
            
        for layer in self._trainable_layers:
            layer.trainable = True

In [34]:
class StandardEncoder(object):
    """
    An Encoder model is trained to parametrise an arbitrary posterior approximate distribution given some 
    input x, i.e. q(z|X,y). The model takes as input concatenated data samples and arbitrary noise and produces
    a latent encoding:
    
      X, y                              Input
     - - - - - - - - -   
       |       Noise: N(0,I)                      
       |         |                        
       ----------- <-- concatenation    
            |                           Encoder model
       -----------
       | Encoder |                      
       -----------
            |
        Latent space                    Output
    
    """
    def __init__(self, X_dim, y_dim, noise_dim, latent_dim, name='Encoder'):
        print("Initialising {} with {}-dim data input, {}-dim noise input and {}-dim latent output".format(name, data_dim, noise_dim, latent_dim))
        
        self.name = name
        self.X_dim = X_dim
        self.y_dim = y_dim
        self.latent_dim = latent_dim
        self.noise_dim = noise_dim
                
        self.input_X = Input(shape=(self.X_dim,), name='enc_data_input_X')
        self.input_y = Input(shape=(self.y_dim,), name='enc_data_input_y')
        
        self.data_prior_sampler = Lambda(data_prior_sampler, name='enc_data_prior_sampler')
        self.data_prior_sampler.arguments = {'noise_dim': self.noise_dim, 'seed': seed}
        
        # internal model
        input_X = Input(shape=(self.X_dim,), name='enc_internal_data_input_X')
        input_y = Input(shape=(self.y_dim,), name='enc_internal_data_input_y')       
        noise_input = Input(shape=(self.noise_dim,), name='enc_internal_noise_input')
        input_concat = Concatenate(axis=1, name='enc_internal_input_concat')([input_X, input_y, noise_input])
        
        encoder_body = Dense(net_width, name='enc_body' + '_0')(input_concat)
        encoder_body_a = PReLU()(encoder_body)
        for i in range(1, net_depth):
            encoder_body = Dense(net_width, name='enc_body' + '_{}'.format(i))(encoder_body_a)
            encoder_body_a = PReLU()(encoder_body)
        
        latent_factors = Dense(self.latent_dim, name='enc_latent')(encoder_body_a)
        
        encoder_body_model = Model(inputs=[input_X, input_y, noise_input],
                                   outputs=latent_factors,
                                   name='enc_internal_model')
        # ---
        
        data_input_concat = Concatenate(axis=1, name='enc_data_input')([self.input_X, self.input_y])
        sampled_noise = self.data_prior_sampler(data_input_concat)
        self.encoder_model = Model(inputs=[self.input_X, self.input_y],
                                   outputs=encoder_body_model([self.input_X, self.input_y, sampled_noise]),
                                   name='enc_model')

    def __call__(self, *args, **kwargs):
        is_learninig = kwargs.get('is_learning', True)
        if is_learninig:
            return self.encoder_model(args[0])
        else:
            return self.encoder_model(args[0])

class Discriminator(object):
    """
    Discriminator model is adversarially trained against the encoder in order to account 
    for a D_KL(q(z|X,y) || p(z)) term in the variational loss. The discriminator
    architecture takes as input samples from the joint probability distribution of the data `X,y` and a approximate
    posterior `z` and from the joint of the data and the prior over `z`:
    
             -----------
       ----> | Encoder |
       |     -----------
       |         |
       |    Approx. posterior --> | 
       |                          |---> (X,y,z') --|
       -------------------------> |                |
       |                                           |     -----------------
      X,y                                          | --> | Discriminator | --> T(X,y,z) regression output
       |                                           |     -----------------
       -------------------------> |                |
                                  |---> (X,y,z)  --|
       Prior p(z): N(0,I) ------> |
       
    """
    def __init__(self, X_dim, y_dim, latent_dim, name='Discriminator'):
        print("Initialising {} with {}-dim data input and {}-dim prior/latent input.".format(name, X_dim, y_dim, latent_dim))
        
        self.X_dim = X_dim
        self.y_dim = y_dim
        self.latent_dim = latent_dim

        self.discriminator_from_prior_model = None
        self.discriminator_from_posterior_model = None
        
        self.input_X = Input(shape=(self.X_dim,), name='disc_data_input_X')
        self.input_y = Input(shape=(self.y_dim,), name='disc_data_input_y')
        
        self.latent_input = Input(shape=(self.latent_dim,), name='disc_latent_input')
        
        self.latent_prior_sampler = Lambda(latent_prior_sampler, name='disc_latent_prior_sampler')
        self.latent_prior_sampler.arguments = {'latent_dim': self.latent_dim, 'seed': seed}
        
        # internal model
        input_X = Input(shape=(self.X_dim,), name='disc_internal_data_input_X')
        input_y = Input(shape=(self.y_dim,), name='disc_internal_data_input_y')       
        latent_input = Input(shape=(self.latent_dim,), name='disc_internal_latent_input')
        
        data_input = Concatenate(axis=1, name='disc_internal_data_input')([input_X, input_y])
        discriminator_body_data = Dense(net_width, name='disc_body_data' + '_0')(data_input)
        discriminator_body_data_a = PReLU()(discriminator_body_data)
        for i in range(1, net_depth):
            discriminator_body_data = Dense(net_width, name='disc_body_data' + '_{}'.format(i))(discriminator_body_data_a)
            discriminator_body_data_a = PReLU()(discriminator_body_data)

        discriminator_body_latent = Dense(net_width, name='disc_body_latent' + '_0')(latent_input)
        discriminator_body_latent_a = PReLU()(discriminator_body_latent)
        for i in range(1, net_depth):
            discriminator_body_latent = Dense(net_width, name='disc_body_latent' + '_{}'.format(i))(discriminator_body_latent_a)
            discriminator_body_latent_a = PReLU()(discriminator_body_latent)
            
        discriminator_output = Dot(axes=1, name='disc_output_dot')([discriminator_body_data_a, discriminator_body_latent_a])
        discriminator_model = Model(inputs=[input_X, input_y, latent_input], 
                                    outputs=discriminator_output,
                                    name='disc_internal_model')
        # ---
        
        data_input_concat = Concatenate(axis=1, name='disc_data_input')([self.input_X, self.input_y])
        prior_distribution = self.latent_prior_sampler(data_input_concat)
        from_prior_output = discriminator_model([self.input_X, self.input_y, prior_distribution])
        self.discriminator_from_prior_model = Model(inputs=[self.input_X, self.input_y],
                                                    outputs=from_prior_output,
                                                    name='disc_model_from_prior')
        
        from_posterior_output = discriminator_model([self.input_X, self.input_y, self.latent_input])
        self.discriminator_from_posterior_model = Model(inputs=[self.input_X, self.input_y, self.latent_input],
                                                        outputs=from_posterior_output,
                                                        name='disc_model_from_posterior')

    def __call__(self, *args, **kwargs):
        from_posterior = kwargs.get('from_posterior', False)
        if from_posterior:
            return self.discriminator_from_posterior_model(args[0])
        else:
            return self.discriminator_from_prior_model(args[0])


class Decoder(object):
    """
    A Decoder model has inputs comprising of a latent encoding given by an Encoder model, a prior sampler 
    or other custom input and the raw Encoder data input, which is needed to estimate the reconstructed 
    data log likelihood. It can be visualised as:
     
       y      X, Latent
       |        |
       |    -----------
       |    | Decoder |
       |    -----------
       |        |
       |      Output
       |    probability    --->  Generated data
       |        |
       ---> Log Likelihood ---> -(reconstruction loss)
    
    Note that the reconstruction loss is not used when the model training ends. It serves only the purpose to 
    define a measure of loss which is optimised. 
    """
    def __init__(self, X_dim, y_dim, latent_dim, name='Decoder'):
        print("Initialising {} with {}-dim latent input and {}-dim data output.".format(name, latent_dim, data_dim))
        
        self.name = name
        self.X_dim = X_dim
        self.y_dim = y_dim
        self.latent_dim = latent_dim

        self.input_X = Input(shape=(self.X_dim,), name='dec_data_input_X')
        self.input_y = Input(shape=(self.y_dim,), name='dec_data_input_y')     
        self.latent_input = Input(shape=(self.latent_dim,), name='dec_latent_input')
        
        # internal model
        input_X = Input(shape=(self.X_dim,), name='dec_internal_data_input_X')
        input_y = Input(shape=(self.y_dim,), name='dec_internal_data_input_y')       
        latent_input = Input(shape=(self.latent_dim,), name='dec_internal_input')

        latent_X_concat = Concatenate(axis=1, name='dec_latent_X_concat')([input_X, latent_input])
        generator_body = Dense(net_width, name='dec_body' + '_0')(latent_X_concat)
        generator_body_a = PReLU()(generator_body)
        for i in range(1, net_depth):
            generator_body = Dense(net_width, name='dec_body' + '_{}'.format(i))(generator_body_a)
            generator_body_a = PReLU()(generator_body)

        mu_params = Dense(self.y_dim, name='dec_mu_params')(generator_body_a)
        std_params_raw = Dense(self.y_dim, name='dec_std_params_raw')(generator_body_a)      
        std_params = Lambda(lambda x: 1.001 + kernel.elu(x), name='dec_std_params')(std_params_raw)
        
        log_probs = Lambda(normal_log_probs, name='dec_normal_logprob')([mu_params, std_params, input_y])
        
        Generator_Model = Model(inputs=[input_X, latent_input], outputs=[mu_params, std_params], name='dec_internal_sampling')
        ll_estimator_Model = Model(inputs=[input_X, input_y, latent_input], outputs=log_probs, name='dec_internal_trainable')
        # ---
        
        self.generator = Model(inputs=[self.input_X, self.latent_input],
                               outputs=Generator_Model([self.input_X, self.latent_input]),
                               name='dec_sampling')
        
        self.ll_estimator = Model(inputs=[self.input_X, self.input_y, self.latent_input],
                                  outputs=ll_estimator_Model([self.input_X, self.input_y, self.latent_input]),
                                  name='dec_trainable')

    def __call__(self, *args, **kwargs):
        is_learninig = kwargs.get('is_learning', True)
        if is_learninig:
            return self.ll_estimator(args[0])
        else:
            return self.generator(args[0])

In [35]:
class AVBDiscriminatorLossLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(AVBDiscriminatorLossLayer, self).__init__(**kwargs)

    @staticmethod
    def discriminator_loss(discrim_output_prior, discrim_output_posterior, from_logits=False):
        if from_logits:
            discrim_output_posterior = kernel.sigmoid(discrim_output_posterior)
            discrim_output_prior = kernel.sigmoid(discrim_output_prior)
            
        discriminator_loss = kernel.mean(binary_crossentropy(y_pred=discrim_output_posterior,
                                                             y_true=kernel.ones_like(discrim_output_posterior))
                                       + binary_crossentropy(y_pred=discrim_output_prior,
                                                             y_true=kernel.zeros_like(discrim_output_prior)))
        return discriminator_loss

    def call(self, inputs, **kwargs):
        discrim_output_prior, discrim_output_posterior = inputs
        is_in_logits = kwargs.get('is_in_logits', True)
        loss = self.discriminator_loss(discrim_output_prior, discrim_output_posterior, from_logits=is_in_logits)
        self.add_loss(loss, inputs=inputs)
        return loss


class AVBEncoderDecoderLossLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(AVBEncoderDecoderLossLayer, self).__init__(**kwargs)

    @staticmethod
    def decoder_loss(data_log_probs, discrim_output_posterior):
        norm_factor = 1.0 / prod(kernel.int_shape(data_log_probs)[1:])
        kl_divergence = kernel.mean(discrim_output_posterior)
        reconstruction_error = -1.0 * kernel.mean(data_log_probs)
        return norm_factor * kernel.mean(kl_divergence + reconstruction_error)
        
    def call(self, inputs, **kwargs):
        decoder_output_log_probs, discrim_output_posterior = inputs
        loss = self.decoder_loss(decoder_output_log_probs, discrim_output_posterior)
        self.add_loss(loss, inputs=inputs)
        return loss

In [36]:
class AdversarialVariationalBayes(object):

    def __init__(self, X_dim, y_dim, latent_dim, noise_dim, model_name, optimiser_params):
        
        self.name = model_name
        self.X_dim = X_dim
        self.y_dim = y_dim
        self.latent_dim = latent_dim
        self.noise_dim = noise_dim
        
        # define the AVB model
        self.encoder = StandardEncoder(X_dim=X_dim, y_dim=y_dim, noise_dim=noise_dim, latent_dim=latent_dim)
        self.discriminator = Discriminator(X_dim=X_dim, y_dim=y_dim, latent_dim=latent_dim)
        self.decoder = Decoder(X_dim=X_dim, y_dim=y_dim, latent_dim=latent_dim)
        
        # define model inputs
        self.input_X = Input(shape=(self.X_dim,), name='{}_data_input_X'.format(self.name))
        self.input_y = Input(shape=(self.y_dim,), name='{}_data_input_y'.format(self.name))
        self.latent_input = Input(shape=(self.latent_dim,), name='{}_latent_prior_input'.format(self.name))
        
        # define intermediate params of the model
        posterior_approximation = self.encoder([self.input_X, self.input_y],
                                               is_learning=True)
        discriminator_output_prior = self.discriminator([self.input_X, self.input_y],
                                                        from_posterior=False)
        discriminator_output_posterior = self.discriminator([self.input_X, self.input_y, posterior_approximation],
                                                            from_posterior=True)
        reconstruction_log_likelihood = self.decoder([self.input_X, self.input_y, posterior_approximation],
                                                     is_learning=True)

        # define loss functions
        discriminator_loss = AVBDiscriminatorLossLayer(name='disc_loss')([discriminator_output_prior,
                                                                          discriminator_output_posterior],
                                                                          is_in_logits=True)
        decoder_loss = AVBEncoderDecoderLossLayer(name='dec_loss')([reconstruction_log_likelihood,
                                                                    discriminator_output_posterior])
        
        # define the trainable models
        self.avb_trainable_discriminator = FreezableModel(inputs=[self.input_X, self.input_y],
                                                          outputs=discriminator_loss,
                                                          name='freezable_discriminator')
        self.avb_trainable_encoder_decoder = FreezableModel(inputs=[self.input_X, self.input_y],
                                                            outputs=decoder_loss,
                                                            name='freezable_encoder_decoder')

        self.avb_trainable_discriminator.freeze(freezable_layers_prefix=['disc'], deep_freeze=True)
        self.avb_trainable_encoder_decoder.unfreeze(unfreezable_layers_prefix=['dec', 'enc'], deep_unfreeze=True)
        optimiser_params_encdec = optimiser_params['encdec']
        self.avb_trainable_encoder_decoder.compile(optimizer=Adam(**optimiser_params_encdec), loss=None)

        self.avb_trainable_discriminator.unfreeze()
        self.avb_trainable_encoder_decoder.freeze()
        optimiser_params_disc = optimiser_params['disc']
        self.avb_trainable_discriminator.compile(optimizer=Adam(**optimiser_params_disc), loss=None)

        # define testing models
        self.inference_model = Model(inputs=[self.input_X, self.input_y],
                                     outputs=self.encoder([self.input_X, self.input_y], is_learning=False))
        self.generative_model = Model(inputs=[self.input_X, self.latent_input],
                                      outputs=self.decoder([self.input_X, self.latent_input], is_learning=False))
        
        # make collection of models
        self.models_dict = {'avb_trainable_discriminator': None,
                            'avb_trainable_encoder_decoder': None}
        self.models_dict['avb_trainable_encoder_decoder'] = self.avb_trainable_encoder_decoder
        self.models_dict['avb_trainable_discriminator'] = self.avb_trainable_discriminator

        # define data iterator
        self.data_iterator = AVBDataIterator(X_dim=X_dim, y_dim=y_dim, latent_dim=latent_dim)
        
    def fit(self, X, y, batch_size, epochs, **kwargs):
        discriminator_repetitions = kwargs.get('discriminator_repetitions')

        data_iterator, iters_per_epoch = self.data_iterator.iter(X, y, batch_size=batch_size, mode='training', shuffle=True)
        history = {'encoderdecoder_loss': [], 'discriminator_loss': []}
        epoch_loss = np.inf

        for ep in tqdm(range(epochs)):
            epoch_loss_history_encdec = []
            epoch_loss_history_disc = []
            
            for it in range(iters_per_epoch):
                training_batch = next(data_iterator)
                loss_autoencoder = self.avb_trainable_encoder_decoder.train_on_batch(training_batch, None)
                epoch_loss_history_encdec.append(loss_autoencoder)
                
                for _ in range(discriminator_repetitions):
                    loss_discriminator = self.avb_trainable_discriminator.train_on_batch(training_batch, None)
                    epoch_loss_history_disc.append(loss_discriminator)
                    
            history['encoderdecoder_loss'].append(epoch_loss_history_encdec)
            history['discriminator_loss'].append(epoch_loss_history_disc)
            
        return history

    def infer(self, X, y, **kwargs):
        if not hasattr(self, 'data_iterator'):
            raise AttributeError("Initialise the data iterator in the child classes first!")
        sampling_size = kwargs.get('sampling_size')
        batch_size = kwargs.get('batch_size')
        
        X, y = np.repeat(X, sampling_size, axis=0), np.repeat(y, sampling_size, axis=0)
        data_iterator, n_iters = self.data_iterator.iter(X, y, batch_size=batch_size, mode='inference')
        latent_samples = self.inference_model.predict_generator(data_iterator, steps=n_iters)
        return X, latent_samples

    def generate(self, X, **kwargs):
        latent_samples = kwargs.get('latent_samples', None)
        if latent_samples is None:
            n_samples = kwargs.get('n_samples')
            # prior sampler
            latent_samples = np.random.standard_normal(size=(n_samples, self.latent_dim))
        
        batch_size = kwargs.get('batch_size')
        data_iterator, n_iters = self.data_iterator.iter(X, latent_samples, batch_size=batch_size, mode='generation')
        mu, std = self.generative_model.predict_generator(data_iterator, steps=n_iters)
        
        return_params = kwargs.get('return_params')
        if return_params:
            return mu, std
        else:
            # likelihood sampler
            if np.min(std) > 0:
                print (np.mean(mu, axis=1), np.mean(std, axis=1))
                sampled_data = np.random.normal(loc=mu, scale=std)
            else:
                print (mu, std)
            return sampled_data
            
    def reconstruct(self, X, y, batch_size, **kwargs):
        sampling_size = kwargs.get('sampling_size', 1)
        X_rep, latent_samples = self.infer(X, y, sampling_size=sampling_size, batch_size=batch_size)
        reconstructed_samples = self.generate(X_rep, latent_samples=latent_samples, batch_size=batch_size, return_params=False)
        return reconstructed_samples

In [37]:
# initialization of AVB model
model = AdversarialVariationalBayes(X_dim=X_dim, y_dim=y_dim, latent_dim=latent_dim, noise_dim=noise_dim,
                                    model_name=model_name, optimiser_params=optimiser_params)

# training of AVB model
training_starttime = datetime.now().isoformat()
loss_history = model.fit(train_X, train_y, batch_size=batch_size, epochs=epochs,
                         discriminator_repetitions=schedule['iter_discr'])

print('Training finished.')

Initialising Encoder with 3-dim data input, 3-dim noise input and 2-dim latent output
Initialising Discriminator with 2-dim data input and 3-dim prior/latent input.
Initialising Decoder with 2-dim latent input and 3-dim data output.


100%|██████████████████████████████████████████| 10/10 [00:29<00:00,  2.92s/it]


Training finished.


In [38]:
# Reconstruction test
y_preds = []

for X_true_i, y_true_i in zip(test_X, test_y):
    x_t_i = np.array([X_true_i])
    y_t_i = np.array([y_true_i])
    
    y_recs_i = model.reconstruct(x_t_i, y_t_i, sampling_size=25, batch_size=25)
    y_pred_i = np.mean(y_recs_i[0], axis=0)
    y_preds.append(y_pred_i)
    
    """
    print (y_true_i)
    print (y_pred_i)
    print ()
    """
    
y_preds = np.array(y_preds)
rmse = np.std(y_preds - test_y, axis = 0)
std = np.std(test_y, axis=0)

print('rmse of reconstruction: ', rmse/std)

[[16.773062   5.2034454 10.893135 ]] [[1.5828562 1.7967187 1.3891052]]
[[43.298256 20.942232 22.184954]] [[1.329269   1.4779662  0.68416023]]
[[35.662838 12.452222 23.08468 ]] [[1.2868456 1.4458755 0.7948687]]
[[63.136806 46.07037  16.709578]] [[1.4535375  1.930746   0.90792745]]
[[38.403015   7.0115986 31.40671  ]] [[1.2590497 1.4696846 1.0757662]]
[[99.24041  38.764847 62.030304]] [[3.1708844 1.7633792 1.0079141]]
[[91.95512  44.41244  48.301132]] [[2.6019354 1.7126775 0.7864224]]
[[32.507835 15.250232 16.90889 ]] [[1.1695487  1.4736991  0.77822477]]
[[102.11871   60.45593   42.039185]] [[2.7344873 2.0219631 0.9993904]]
[[66.99834  25.377342 42.112823]] [[2.0470407 1.4037398 0.7468687]]
[[59.912613 31.777058 28.250654]] [[1.7002709  1.5387889  0.68181306]]
[[48.037163 26.484219 21.28293 ]] [[1.2998362 1.5420837 0.6567341]]
[[66.373024 38.46926  27.964777]] [[1.7809824 1.6495948 0.7161645]]
[[31.443792 20.906303 10.155079]] [[1.0467572 1.8149883 0.864958 ]]
[[13.846912  8.188354  5.11

In [None]:
clear_session()