In [1]:
import time
import warnings
warnings.filterwarnings("ignore")
warnings.simplefilter('ignore')

from collections import OrderedDict
import functools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns
# use ggplot styles for graphs
plt.style.use('ggplot')

import pickle
import arviz as az

In [2]:
import ot
import tensorflow as tf
import tensorflow_probability as tfp

# set tf logger to log level ERROR to avoid warnings
tf.get_logger().setLevel('ERROR')

# Short aliases for the TFP / Keras namespaces used throughout the notebook.
tfd = tfp.distributions
tfb = tfp.bijectors
tfk = tf.keras

# Seed the global RNGs so the stochastic cells below (synthetic dataset
# generation, NUTS, VI optimisation) are repeatable on a fresh kernel.
# NOTE(review): code inside bayes_vi that draws its own explicit seeds, or
# non-deterministic GPU kernels, may still introduce run-to-run variation.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

In [3]:
# import probabilistic models
from bayes_vi.models import Model

# import utils
from bayes_vi.utils.datasets import make_dataset_from_df

In [4]:
# mcmc imports
from bayes_vi.inference.mcmc import MCMC
from bayes_vi.inference.mcmc.transition_kernels import HamiltonianMonteCarlo, NoUTurnSampler, RandomWalkMetropolis
from bayes_vi.inference.mcmc.stepsize_adaptation_kernels import SimpleStepSizeAdaptation, DualAveragingStepSizeAdaptation

In [5]:
# vi imports 
from bayes_vi.inference.vi import VI

from bayes_vi.inference.vi.surrogate_posteriors import ADVI, NormalizingFlow
from bayes_vi.utils import to_ordered_dict
from bayes_vi.inference.vi.flow_bijectors import HamiltonianFlow, AffineFlow, make_energy_fn, make_scale_fn, make_shift_fn
from bayes_vi.utils.leapfrog_integrator import LeapfrogIntegrator

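# Experiment 1 - Univariate Gaussian Model

A Gaussian with unknown location and scale is fitted to 10 synthetic datasets drawn from the model itself. For each dataset, NUTS provides a reference posterior. The variational surrogates (mean-field ADVI, affine flow, MAF, CNF and Hamiltonian normalizing flows with 1, 2, 3 and 5 flows) are compared against it using the 2-Wasserstein distance and the absolute errors of the posterior mean and standard deviation.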

## HMC Config

In [6]:
# NUTS sampling budget: parallel chains, kept draws per chain, and burn-in
# steps (all passed straight to `mcmc.fit`).
NUM_CHAINS, NUM_SAMPLES, NUM_BURNIN_STEPS = 5, 1000, 10000

In [7]:
# NUTS transition kernel whose step size is tuned by dual averaging for
# int(0.8 * NUM_BURNIN_STEPS) steps, i.e. 80% of the burn-in budget.
stepsize_adaptation_kernel = DualAveragingStepSizeAdaptation(
    num_adaptation_steps=int(0.8 * NUM_BURNIN_STEPS)
)

kernel = NoUTurnSampler(
    stepsize_adaptation_kernel=stepsize_adaptation_kernel,
    max_tree_depth=5,
    step_size=0.01,
)

## VI Config

In [8]:
# VI budget: optimisation steps, and the `sample_size` handed to `vi.fit`
# (presumably Monte-Carlo draws per step for the objective estimate).
NUM_STEPS, SAMPLE_SIZE = 10000, 50

In [9]:
# Optimizer class; it is instantiated once per surrogate with the rate below.
optimizer = tf.optimizers.Adam

# Initial learning rate per surrogate posterior. The keys must match the
# names of the surrogates fitted in the experiment loop.
init_lr = dict(
    [
        ('meanfield advi', 5e-3),
        ('affine flow', 5e-3),
        ('maf', 5e-3),
        ('cnf', 1e-4),
    ]
    + [('hnf({})'.format(num_flows), 1e-3) for num_flows in (1, 2, 3, 5)]
)

In [10]:
# Custom posterior statistics passed to `az.summary(stat_funcs=...)`.
# Each function receives the draws of one parameter and returns a scalar.
# The "hdi 3%" / "hdi 97%" labels follow ArviZ's naming for a 94% HDI.
HDI_PROB = 0.94


def _hdi_bounds(samples, hdi_prob=HDI_PROB):
    """Return (lower, upper) of the narrowest interval holding `hdi_prob` of the draws.

    Same sliding-window construction as ArviZ's HDI: sort the draws and pick
    the shortest window spanning floor(hdi_prob * n) order statistics. Unlike
    equal-tailed percentiles, this is the highest-density interval even for
    skewed posteriors such as the scale parameter's.
    """
    sorted_samples = np.sort(np.ravel(samples))
    n = sorted_samples.size
    if n == 0:
        return np.nan, np.nan
    n_inside = int(np.floor(hdi_prob * n))
    n_windows = n - n_inside
    widths = sorted_samples[n_inside:] - sorted_samples[:n_windows]
    start = int(np.argmin(widths))
    return sorted_samples[start], sorted_samples[start + n_inside]


def _histogram_mode(samples):
    """Estimate the mode of continuous draws as the centre of the fullest histogram bin.

    The previous implementation returned the 50th percentile (the median)
    under the "mode" label.
    """
    counts, edges = np.histogram(np.ravel(samples), bins="auto")
    peak = int(np.argmax(counts))
    return 0.5 * (edges[peak] + edges[peak + 1])


func_dict = {
    "mean": np.mean,
    "sd": np.std,
    "min": np.min,
    "hdi 3%": lambda x: _hdi_bounds(x)[0],
    "mode": _histogram_mode,
    "hdi 97%": lambda x: _hdi_bounds(x)[1],
    "max": np.max,
}

In [11]:
def get_continuous_flow_bijector(unconstrained_event_dims):
    """Build an FFJORD continuous normalizing flow on the flat unconstrained space.

    The ODE vector field is a Keras MLP (two tanh layers of 128 units) that takes
    the integration time ``t`` concatenated with the current state and returns a
    derivative of size ``unconstrained_event_dims``.

    Args:
        unconstrained_event_dims: size of the flat unconstrained parameter vector.

    Returns:
        A ``tfb.FFJORD`` bijector solved with Dormand-Prince and using the
        Hutchinson trace estimator.
    """
    state_fn = tfk.Sequential()
    state_fn.add(tfk.layers.Dense(128, activation=tfk.activations.tanh))
    state_fn.add(tfk.layers.Dense(128, activation=tfk.activations.tanh))
    state_fn.add(tfk.layers.Dense(unconstrained_event_dims))
    state_fn.build((None, unconstrained_event_dims + 1))

    def state_time_derivative_fn(t, state):
        # Broadcast the scalar time to a (..., 1) column from the *dynamic* shape
        # of `state`. The previous `tf.fill((state.shape[0], 1), t)` used the static
        # batch size, which is None when tracing with an unknown batch and which
        # ignored any extra leading dims.
        time_column = tf.ones_like(state[..., :1]) * tf.cast(t, state.dtype)
        return state_fn(tf.concat([time_column, state], axis=-1))

    return tfb.FFJORD(
        state_time_derivative_fn,
        ode_solve_fn=tfp.math.ode.DormandPrince(first_step_size=0.1).solve,
        trace_augmentation_fn=tfb.ffjord.trace_jacobian_hutchinson,
    )
    
def get_hamiltonian_flow_bijector(unconstrained_event_dims, num_flows):
    """Stack ``num_flows`` HamiltonianFlow bijectors into a single ``tfb.Chain``.

    Every flow is a fresh instance with its own LeapfrogIntegrator and its own
    ``tf.Variable`` step size (initialised to 0.1). Each performs 2 integration
    steps and uses hidden layers of [128, 128].
    """
    flows = []
    for _ in range(num_flows):
        flow = HamiltonianFlow(
            event_dims=unconstrained_event_dims,
            symplectic_integrator=LeapfrogIntegrator(),
            step_sizes=tf.Variable(0.1),
            num_integration_steps=2,
            hidden_layers=[128, 128],
        )
        flows.append(flow)
    return tfb.Chain(flows)

def get_masked_autoregressive_flow_bijector(unconstrained_event_dims):
    """Build a single masked autoregressive flow (MAF) layer.

    The shift/log-scale network is a ``tfb.AutoregressiveNetwork`` with two tanh
    hidden layers of 128 units.
    NOTE(review): ``unconstrained_event_dims`` is not used here; the network is
    created without ``event_shape``, so TFP infers it when the layer is first built.
    """
    shift_and_log_scale_network = tfb.AutoregressiveNetwork(
        params=2,
        hidden_units=[128, 128],
        activation='tanh',
    )
    return tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=shift_and_log_scale_network)

def get_affine_flow_bijector(unconstrained_event_dims):
    """Return the project's ``AffineFlow`` bijector for a flat vector of size ``unconstrained_event_dims``."""
    affine_flow = AffineFlow(unconstrained_event_dims)
    return affine_flow


def get_posterior_lift_distribution(unconstrained_event_dims):
    """Return a function mapping ``q`` to a diagonal Gaussian over the lifted dimensions.

    The Gaussian's ``loc`` and ``scale_diag`` are produced from ``q`` by networks
    built with ``make_shift_fn`` / ``make_scale_fn`` (hidden layers [128, 128]).
    It is used as ``posterior_lift_distribution`` of the HNF surrogates.
    """
    scale_fn = make_scale_fn(unconstrained_event_dims, hidden_layers=[128, 128])
    shift_fn = make_shift_fn(unconstrained_event_dims, hidden_layers=[128, 128])

    def lift_distribution(q):
        return tfd.MultivariateNormalDiag(loc=shift_fn(q), scale_diag=scale_fn(q))

    return lift_distribution

## Define Model

In [12]:
# Priors over the Gaussian's location and its (positive) scale.
priors = OrderedDict([
    ('loc', tfd.Normal(loc=0., scale=10.)),
    ('scale', tfd.HalfNormal(scale=10.)),
])


def likelihood(loc, scale):
    """Observation model: each data point is Normal(loc, scale)."""
    return tfd.Normal(loc=loc, scale=scale)

In [13]:
# Probabilistic model assembled from the priors and the likelihood above.
model = Model(likelihood=likelihood, priors=priors)

## Generate Datasets

In [14]:
# Number of synthetic datasets, and observations per dataset.
num_datasets, num_datapoints = 10, 100

In [15]:
# Draw `num_datasets` synthetic datasets from the model's joint distribution.
# Each draw yields parameter values (kept as ground truth) and
# `num_datapoints` observations of `y`.
true_params = {}
datasets = {}

for i in range(1, num_datasets + 1):
    dataset_name = 'dataset {}'.format(i)
    dist = model.get_joint_distribution(num_samples=num_datapoints)
    draw = dist.sample()

    # Look the sampled values up by name instead of unpacking `.values()`
    # positionally, which silently depends on the sample dict's key order.
    # NOTE(review): assumes the observation entry is keyed 'y', matching the
    # posterior-predictive samples used later in the notebook.
    true_params[dataset_name] = to_ordered_dict(
        model.param_names, [draw[param_name].numpy() for param_name in model.param_names]
    )
    datasets[dataset_name] = make_dataset_from_df(
        pd.DataFrame({'y': draw['y'].numpy()}), target_names=['y'], format_features_as='dict'
    )

## Run Inferences and Collect Results

In [16]:
# results[dataset name] -> {'data': ..., 'mcmc': ..., 'vi': ...}, filled by the experiment loop.
results = dict()

In [17]:
def _w2_1d(samples_a, samples_b):
    """Exact 2-Wasserstein distance between two univariate empirical distributions.

    In 1-D the optimal coupling is the monotone (quantile) matching, so
    ``ot.wasserstein_1d`` solves it from sorted samples in O(n log n). The
    previous approach built an n x m cost matrix and ran the network-simplex
    LP (``ot.emd2``). That LP has a default iteration cap, and its
    early-termination warning is hidden by the notebook-wide
    ``warnings.filterwarnings("ignore")``. ``ot.wasserstein_1d`` returns
    W_p ** p, hence the square root.
    """
    a = np.asarray(samples_a, dtype=np.float64).ravel()
    b = np.asarray(samples_b, dtype=np.float64).ravel()
    return float(np.sqrt(ot.wasserstein_1d(a, b, p=2)))


def _run_mcmc(model, dataset):
    """Run NUTS on ``dataset`` and collect samples, traces, diagnostics and timing.

    Uses the notebook-level ``kernel``, ``NUM_CHAINS``, ``NUM_SAMPLES``,
    ``NUM_BURNIN_STEPS`` and ``func_dict``. Returns the dict stored under
    ``results[name]['mcmc']``.
    """
    print('Start MCMC...')

    mcmc = MCMC(model=model, dataset=dataset, transition_kernel=kernel)

    print('Run NUTS:')
    start = time.time()
    mcmc_result = mcmc.fit(
        num_chains=NUM_CHAINS,
        num_samples=NUM_SAMPLES,
        num_burnin_steps=NUM_BURNIN_STEPS,
        progress_bar=True,
    )
    run_time = round(time.time() - start, 2)
    print('Finished in {}s\n'.format(run_time))

    post_pred_dist = model.get_posterior_predictive_distribution(
        posterior_samples=mcmc_result.samples,
    )
    posterior_predictive_samples = post_pred_dist.sample(1000)['y'].numpy()

    def swap_leading_axes(v):
        # Swap the two leading axes to get ArviZ's (chain, draw, ...) layout.
        return np.swapaxes(v.numpy(), 0, 1)

    def map_trace(v):
        # Traces of rank < 2 are converted to numpy without swapping.
        return swap_leading_axes(v) if tf.rank(v) >= 2 else v.numpy()

    posterior_samples = to_ordered_dict(
        model.param_names, tf.nest.map_structure(swap_leading_axes, mcmc_result.samples)
    )
    traces = {
        k: tf.nest.map_structure(map_trace, v) if tf.nest.is_nested(v) else map_trace(v)
        for k, v in mcmc_result.trace.items()
    }
    summary = az.summary(posterior_samples, round_to=2, kind='diagnostics', stat_funcs=func_dict, extend=True)
    # The custom stat_funcs columns come last; move them in front of the diagnostics.
    num_stats = len(func_dict)
    cols = summary.columns.tolist()
    summary = summary[cols[-num_stats:] + cols[:-num_stats]]

    print('Finished MCMC! \n')
    return {
        'posterior samples': posterior_samples,
        'traces': traces,
        'posterior predictive samples': posterior_predictive_samples,
        'acceptance ratios per chain': mcmc_result.accept_ratios.numpy(),
        'summary': summary,
        'run-time': run_time,
    }


def _build_surrogate_posteriors(model, unconstrained_event_dims):
    """Instantiate fresh (untrained) surrogate posteriors, keyed like ``init_lr``."""
    surrogates = {
        'meanfield advi': ADVI(model, mean_field=True),
        'affine flow': NormalizingFlow(model, flow_bijector=get_affine_flow_bijector(unconstrained_event_dims)),
        'maf': NormalizingFlow(model, flow_bijector=get_masked_autoregressive_flow_bijector(unconstrained_event_dims)),
        'cnf': NormalizingFlow(model, flow_bijector=get_continuous_flow_bijector(unconstrained_event_dims)),
    }
    # Hamiltonian normalizing flows with 1, 2, 3 and 5 stacked flows, each on a
    # space augmented by `extra_ndims` auxiliary dimensions.
    for num_flows in (1, 2, 3, 5):
        surrogates['hnf({})'.format(num_flows)] = NormalizingFlow(
            model,
            flow_bijector=get_hamiltonian_flow_bijector(unconstrained_event_dims, num_flows=num_flows),
            extra_ndims=unconstrained_event_dims,
            posterior_lift_distribution=get_posterior_lift_distribution(unconstrained_event_dims),
        )
    return surrogates


def _fit_and_evaluate_surrogate(model, dataset, surrogate_name, surrogate_posterior, flat_mcmc_samples):
    """Fit one surrogate with VI and compare its samples with the MCMC reference.

    ``flat_mcmc_samples`` maps each parameter to its MCMC draws flattened to 1-D.
    Returns the dict stored under ``results[name]['vi'][surrogate_name]``.
    """
    vi = VI(model, dataset, surrogate_posterior)

    print('Run {}:'.format(surrogate_name))
    start = time.time()
    approx_posterior, losses = vi.fit(
        optimizer=optimizer(init_lr[surrogate_name]),
        num_steps=NUM_STEPS,
        sample_size=SAMPLE_SIZE,
        progress_bar=True,
    )
    run_time = round(time.time() - start, 2)
    print('Finished in {}s\n'.format(run_time))

    print('computing error metrics...')
    start = time.time()

    post_pred_dist = model.get_posterior_predictive_distribution(posterior_distribution=approx_posterior)
    posterior_predictive_samples = post_pred_dist.sample(1000)['y'].numpy()

    # Add a leading axis so ArviZ treats the VI draws as a single chain.
    posterior_samples = tf.nest.map_structure(
        lambda v: np.expand_dims(v.numpy(), 0), approx_posterior.sample(5000)
    )
    summary = az.summary(posterior_samples, stat_funcs=func_dict, extend=False)

    # Per-parameter 2-Wasserstein distance and absolute mean / sd errors w.r.t. MCMC.
    flat_vi_samples = tf.nest.map_structure(np.ravel, posterior_samples)
    W2 = tf.nest.map_structure(_w2_1d, flat_mcmc_samples, flat_vi_samples)
    posterior_mean_error = tf.nest.map_structure(
        lambda x, y: np.abs(x.mean() - y.mean()), flat_vi_samples, flat_mcmc_samples
    )
    posterior_sd_error = tf.nest.map_structure(
        lambda x, y: np.abs(x.std() - y.std()), flat_vi_samples, flat_mcmc_samples
    )
    comp_time = round(time.time() - start, 2)
    print('Finished computing error metrics in {}s\n\n'.format(comp_time))

    return {
        'losses': losses,
        'final loss': losses[-100:].mean(),
        'learning rate': init_lr[surrogate_name],
        'posterior samples': posterior_samples,
        'posterior predictive samples': posterior_predictive_samples,
        'summary': summary,
        'W2': W2,
        'mean error': posterior_mean_error,
        'sd error': posterior_sd_error,
        'run-time': run_time,
    }


unconstrained_event_dims = model.flat_unconstrained_param_event_ndims

for name, dataset in datasets.items():
    print('Starting with {}\n'.format(name))
    # Materialise the whole dataset as one batch to store the raw targets.
    targets = next(iter(dataset.batch(dataset.cardinality()).take(1)))[1]

    results[name] = {
        'data': {
            'targets': targets.numpy(),
            'true params': true_params[name],
        }}

    # MCMC reference posterior
    mcmc_results = _run_mcmc(model, dataset)
    results[name]['mcmc'] = mcmc_results

    # VI: fresh surrogates per dataset, all compared against the same MCMC draws,
    # so those are flattened once instead of once per surrogate.
    surrogate_posteriors = _build_surrogate_posteriors(model, unconstrained_event_dims)
    flat_mcmc_samples = tf.nest.map_structure(np.ravel, mcmc_results['posterior samples'])

    vi_results = {}
    print('Start VI...')
    for surrogate_name, surrogate_posterior in surrogate_posteriors.items():
        vi_results[surrogate_name] = _fit_and_evaluate_surrogate(
            model, dataset, surrogate_name, surrogate_posterior, flat_mcmc_samples
        )
    print('Finished VI! \n\n\n\n')
    results[name]['vi'] = vi_results
    

Starting with dataset 1

Start MCMC...
Run NUTS:


Finished in 36.4s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 9.82s

computing error metrics...




Finished computing error metrics in 7.15s


Run affine flow:


Finished in 12.17s

computing error metrics...




Finished computing error metrics in 7.13s


Run maf:


Finished in 19.72s

computing error metrics...




Finished computing error metrics in 8.48s


Run cnf:


Finished in 791.74s

computing error metrics...




Finished computing error metrics in 7.25s


Run hnf(1):


Finished in 49.26s

computing error metrics...




Finished computing error metrics in 8.2s


Run hnf(2):


Finished in 88.9s

computing error metrics...




Finished computing error metrics in 8.29s


Run hnf(3):


Finished in 119.32s

computing error metrics...




Finished computing error metrics in 7.65s


Run hnf(5):


Finished in 197.03s

computing error metrics...




Finished computing error metrics in 7.24s


Finished VI! 




Starting with dataset 2

Start MCMC...
Run NUTS:


Finished in 34.7s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 11.1s

computing error metrics...




Finished computing error metrics in 7.28s


Run affine flow:


Finished in 10.63s

computing error metrics...




Finished computing error metrics in 7.2s


Run maf:


Finished in 20.63s

computing error metrics...




Finished computing error metrics in 7.91s


Run cnf:


Finished in 750.33s

computing error metrics...




Finished computing error metrics in 7.35s


Run hnf(1):


Finished in 52.33s

computing error metrics...




Finished computing error metrics in 8.02s


Run hnf(2):


Finished in 90.57s

computing error metrics...




Finished computing error metrics in 7.48s


Run hnf(3):


Finished in 126.47s

computing error metrics...




Finished computing error metrics in 7.46s


Run hnf(5):


Finished in 195.06s

computing error metrics...




Finished computing error metrics in 7.93s


Finished VI! 




Starting with dataset 3

Start MCMC...
Run NUTS:


Finished in 41.37s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 9.26s

computing error metrics...




Finished computing error metrics in 7.54s


Run affine flow:


Finished in 10.74s

computing error metrics...




Finished computing error metrics in 8.08s


Run maf:


Finished in 21.1s

computing error metrics...




Finished computing error metrics in 7.5s


Run cnf:


Finished in 977.3s

computing error metrics...




Finished computing error metrics in 8.91s


Run hnf(1):


Finished in 61.82s

computing error metrics...




Finished computing error metrics in 9.18s


Run hnf(2):


Finished in 101.92s

computing error metrics...




Finished computing error metrics in 9.25s


Run hnf(3):


Finished in 143.62s

computing error metrics...




Finished computing error metrics in 10.02s


Run hnf(5):


Finished in 223.55s

computing error metrics...




Finished computing error metrics in 8.5s


Finished VI! 




Starting with dataset 4

Start MCMC...
Run NUTS:


Finished in 42.81s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 11.89s

computing error metrics...




Finished computing error metrics in 9.37s


Run affine flow:


Finished in 13.34s

computing error metrics...




Finished computing error metrics in 9.27s


Run maf:


Finished in 24.2s

computing error metrics...




Finished computing error metrics in 8.49s


Run cnf:


Finished in 859.33s

computing error metrics...




Finished computing error metrics in 9.79s


Run hnf(1):


Finished in 60.83s

computing error metrics...




Finished computing error metrics in 9.43s


Run hnf(2):


Finished in 83.34s

computing error metrics...




Finished computing error metrics in 7.13s


Run hnf(3):


Finished in 117.38s

computing error metrics...




Finished computing error metrics in 7.93s


Run hnf(5):


Finished in 213.02s

computing error metrics...




Finished computing error metrics in 9.77s


Finished VI! 




Starting with dataset 5

Start MCMC...
Run NUTS:


Finished in 65.74s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 10.74s

computing error metrics...




Finished computing error metrics in 9.22s


Run affine flow:


Finished in 14.01s

computing error metrics...




Finished computing error metrics in 9.09s


Run maf:


Finished in 23.87s

computing error metrics...




Finished computing error metrics in 8.71s


Run cnf:


Finished in 747.06s

computing error metrics...




Finished computing error metrics in 9.99s


Run hnf(1):


Finished in 60.97s

computing error metrics...




Finished computing error metrics in 9.88s


Run hnf(2):


Finished in 102.04s

computing error metrics...




Finished computing error metrics in 9.71s


Run hnf(3):


Finished in 141.51s

computing error metrics...




Finished computing error metrics in 8.69s


Run hnf(5):


Finished in 222.08s

computing error metrics...




Finished computing error metrics in 9.12s


Finished VI! 




Starting with dataset 6

Start MCMC...
Run NUTS:


Finished in 46.62s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 11.48s

computing error metrics...




Finished computing error metrics in 8.94s


Run affine flow:


Finished in 13.26s

computing error metrics...




Finished computing error metrics in 9.89s


Run maf:


Finished in 24.15s

computing error metrics...




Finished computing error metrics in 9.31s


Run cnf:


Finished in 939.81s

computing error metrics...




Finished computing error metrics in 9.17s


Run hnf(1):


Finished in 57.2s

computing error metrics...




Finished computing error metrics in 7.63s


Run hnf(2):


Finished in 94.89s

computing error metrics...




Finished computing error metrics in 8.92s


Run hnf(3):


Finished in 134.98s

computing error metrics...




Finished computing error metrics in 8.96s


Run hnf(5):


Finished in 209.0s

computing error metrics...




Finished computing error metrics in 8.41s


Finished VI! 




Starting with dataset 7

Start MCMC...
Run NUTS:


Finished in 55.34s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 8.72s

computing error metrics...




Finished computing error metrics in 7.38s


Run affine flow:


Finished in 10.7s

computing error metrics...




Finished computing error metrics in 7.17s


Run maf:


Finished in 22.44s

computing error metrics...




Finished computing error metrics in 8.17s


Run cnf:


Finished in 544.83s

computing error metrics...




Finished computing error metrics in 9.94s


Run hnf(1):


Finished in 55.65s

computing error metrics...




Finished computing error metrics in 8.07s


Run hnf(2):


Finished in 96.28s

computing error metrics...




Finished computing error metrics in 9.53s


Run hnf(3):


Finished in 138.38s

computing error metrics...




Finished computing error metrics in 7.56s


Run hnf(5):


Finished in 212.9s

computing error metrics...




Finished computing error metrics in 9.01s


Finished VI! 




Starting with dataset 8

Start MCMC...
Run NUTS:


Finished in 39.9s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 9.52s

computing error metrics...




Finished computing error metrics in 7.84s


Run affine flow:


Finished in 14.12s

computing error metrics...




Finished computing error metrics in 7.3s


Run maf:


Finished in 22.51s

computing error metrics...




Finished computing error metrics in 8.29s


Run cnf:


Finished in 772.29s

computing error metrics...




Finished computing error metrics in 8.77s


Run hnf(1):


Finished in 56.41s

computing error metrics...




Finished computing error metrics in 8.64s


Run hnf(2):


Finished in 94.39s

computing error metrics...




Finished computing error metrics in 8.34s


Run hnf(3):


Finished in 138.77s

computing error metrics...




Finished computing error metrics in 10.08s


Run hnf(5):


Finished in 212.93s

computing error metrics...




Finished computing error metrics in 7.6s


Finished VI! 




Starting with dataset 9

Start MCMC...
Run NUTS:


Finished in 37.62s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 8.56s

computing error metrics...




Finished computing error metrics in 9.65s


Run affine flow:


Finished in 18.98s

computing error metrics...




Finished computing error metrics in 7.53s


Run maf:


Finished in 22.49s

computing error metrics...




Finished computing error metrics in 8.84s


Run cnf:


Finished in 1032.11s

computing error metrics...




Finished computing error metrics in 7.94s


Run hnf(1):


Finished in 58.41s

computing error metrics...




Finished computing error metrics in 7.89s


Run hnf(2):


Finished in 96.31s

computing error metrics...




Finished computing error metrics in 8.59s


Run hnf(3):


Finished in 136.03s

computing error metrics...




Finished computing error metrics in 10.22s


Run hnf(5):


Finished in 220.75s

computing error metrics...




Finished computing error metrics in 8.2s


Finished VI! 




Starting with dataset 10

Start MCMC...
Run NUTS:


Finished in 49.15s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 9.45s

computing error metrics...




Finished computing error metrics in 9.77s


Run affine flow:


Finished in 14.5s

computing error metrics...




Finished computing error metrics in 8.45s


Run maf:


Finished in 22.63s

computing error metrics...




Finished computing error metrics in 8.49s


Run cnf:


Finished in 842.84s

computing error metrics...




Finished computing error metrics in 7.51s


Run hnf(1):


Finished in 67.22s

computing error metrics...




Finished computing error metrics in 8.4s


Run hnf(2):


Finished in 97.75s

computing error metrics...




Finished computing error metrics in 9.05s


Run hnf(3):


Finished in 140.72s

computing error metrics...




Finished computing error metrics in 9.05s


Run hnf(5):


Finished in 218.07s

computing error metrics...




Finished computing error metrics in 7.75s


Finished VI! 






In [18]:
# Persist all results. Create the output directory first so that a missing
# folder cannot throw away the whole (multi-hour) run at the last step.
from pathlib import Path

RESULTS_PATH = Path('./univariate_gaussian_experiment/univariate_gaussian_exp_results.pickle')
RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(RESULTS_PATH, 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)