In [1]:
import time
import warnings
warnings.filterwarnings("ignore")
warnings.simplefilter('ignore')

from collections import OrderedDict
import functools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns
# use ggplot styles for graphs
plt.style.use('ggplot')

import pickle
import arviz as az

In [2]:
import ot
import tensorflow as tf
import tensorflow_probability as tfp

# Restrict the TensorFlow logger to ERROR-level messages so that warnings
# do not clutter the cell outputs.
tf_logger = tf.get_logger()
tf_logger.setLevel('ERROR')

# Short aliases for the TFP / Keras namespaces used throughout the notebook.
tfd, tfb, tfk = tfp.distributions, tfp.bijectors, tf.keras

In [3]:
# import probabilistic models
from bayes_vi.models import Model

# import utils
from bayes_vi.utils.datasets import make_dataset_from_df

In [4]:
# mcmc imports
from bayes_vi.inference.mcmc import MCMC
from bayes_vi.inference.mcmc.transition_kernels import HamiltonianMonteCarlo, NoUTurnSampler, RandomWalkMetropolis
from bayes_vi.inference.mcmc.stepsize_adaptation_kernels import SimpleStepSizeAdaptation, DualAveragingStepSizeAdaptation

In [5]:
# vi imports 
from bayes_vi.inference.vi import VI

from bayes_vi.inference.vi.surrogate_posteriors import ADVI, NormalizingFlow
from bayes_vi.utils import to_ordered_dict
from bayes_vi.inference.vi.flow_bijectors import HamiltonianFlow, AffineFlow, make_energy_fn, make_scale_fn, make_shift_fn
from bayes_vi.utils.leapfrog_integrator import LeapfrogIntegrator

# Experiment 1 - Univariate Gaussian Model

## HMC Config

In [6]:
# MCMC sampling parameters; passed to `mcmc.fit` as `num_chains`,
# `num_samples` and `num_burnin_steps` respectively.
NUM_CHAINS, NUM_SAMPLES, NUM_BURNIN_STEPS = 4, 2500, 10000

In [7]:
# Dual-averaging step size adaptation; the number of adaptation steps is set
# to 80% of the burn-in steps.
num_adaptation_steps = int(NUM_BURNIN_STEPS*0.8)
stepsize_adaptation_kernel = DualAveragingStepSizeAdaptation(
    num_adaptation_steps=num_adaptation_steps
)

# No-U-Turn sampler configuration, wired to the adaptation kernel above.
nuts_config = dict(
    step_size=0.01,
    max_tree_depth=5,
    stepsize_adaptation_kernel=stepsize_adaptation_kernel,
)
kernel = NoUTurnSampler(**nuts_config)

## VI Config

In [8]:
# VI optimisation parameters; passed to `vi.fit` as `num_steps` and
# `sample_size` respectively.
NUM_STEPS, SAMPLE_SIZE = 5000, 10

In [9]:
# Optimizer class (not an instance): a fresh optimizer is instantiated with
# the surrogate-specific learning rate for every VI fit.
optimizer = tf.optimizers.Adam

# Initial learning rate per surrogate posterior, keyed by surrogate name.
init_lr = {}
init_lr.update(dict.fromkeys(['meanfield_advi', 'affine_flow', 'maf'], 1e-2))
init_lr.update(dict.fromkeys(['cnf', 'hnf-1', 'hnf-2', 'hnf-5'], 1e-3))

In [10]:
def _percentile_stat(q):
    """Return a one-argument function computing the `q`-th percentile of its input."""
    return lambda x: np.percentile(x, q)


# Summary statistics handed to `az.summary(..., stat_funcs=func_dict)`.
# NOTE: the entry labelled "mode" is the 50th percentile (i.e. the median),
# and the "hdi_*" entries are plain 3rd / 97th percentiles rather than
# highest-density-interval bounds. The labels are kept unchanged because they
# become the column names of the resulting summary.
func_dict = {"mean": np.mean, "stddev": np.std}
func_dict.update(
    (label, _percentile_stat(q))
    for label, q in [
        ("min", 0),
        ("hdi_3%", 3),
        ("mode", 50),
        ("hdi_97%", 97),
        ("max", 100),
    ]
)

In [11]:
def get_continuous_flow_bijector(unconstrained_event_dims):
    """Build an FFJORD (continuous normalizing flow) bijector.

    The state time derivative is modelled by a small MLP (two tanh layers of
    128 units followed by a linear output layer) that receives the state with
    the integration time ``t`` prepended as one extra input feature.

    Args:
        unconstrained_event_dims: size of the last axis of the state; also
            the output width of the MLP.

    Returns:
        A ``tfb.FFJORD`` bijector using the Dormand-Prince ODE solver
        (``first_step_size=0.1``) and the Hutchinson trace estimator.
    """
    state_fn = tfk.Sequential([
        tfk.layers.Dense(128, activation=tfk.activations.tanh),
        tfk.layers.Dense(128, activation=tfk.activations.tanh),
        tfk.layers.Dense(unconstrained_event_dims),
    ])
    # +1 input feature for the time column that is concatenated to the state.
    state_fn.build((None, unconstrained_event_dims + 1))

    def state_time_derivative_fn(t, state):
        # Size the time column from the dynamic shape: the static
        # `state.shape[0]` is None when the function is traced with an unknown
        # batch dimension, which makes `tf.fill` fail.
        time_column = tf.fill([tf.shape(state)[0], 1], t)
        return state_fn(tf.concat([time_column, state], axis=-1))

    return tfb.FFJORD(
        state_time_derivative_fn,
        ode_solve_fn=tfp.math.ode.DormandPrince(first_step_size=0.1).solve,
        trace_augmentation_fn=tfb.ffjord.trace_jacobian_hutchinson,
    )
    
def get_hamiltonian_flow_bijector(unconstrained_event_dims, num_flows):
    """Chain `num_flows` identically configured HamiltonianFlow bijectors.

    Every flow uses a `LeapfrogIntegrator`, `step_sizes=0.1`,
    `num_integration_steps=5` and `hidden_layers=[128, 128]`.

    Args:
        unconstrained_event_dims: passed to each flow as `event_dims`.
        num_flows: number of HamiltonianFlow bijectors in the chain.

    Returns:
        A `tfb.Chain` wrapping the list of flows.
    """
    flows = []
    for _ in range(num_flows):
        flow = HamiltonianFlow(
            event_dims=unconstrained_event_dims,
            symplectic_integrator=LeapfrogIntegrator(),
            step_sizes=0.1,
            num_integration_steps=5,
            hidden_layers=[128, 128],
        )
        flows.append(flow)
    return tfb.Chain(flows)

def get_masked_autoregressive_flow_bijector(unconstrained_event_dims):
    """Build a masked autoregressive flow bijector.

    The shift and log-scale are produced by an `AutoregressiveNetwork` with
    `params=2`, hidden units `[32, 32]` and ReLU activations.

    Args:
        unconstrained_event_dims: not used by this builder; the parameter
            mirrors the signature of the other bijector builders.

    Returns:
        A `tfb.MaskedAutoregressiveFlow` bijector.
    """
    shift_and_log_scale_network = tfb.AutoregressiveNetwork(
        params=2,
        hidden_units=[32, 32],
        activation='relu',
    )
    return tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=shift_and_log_scale_network
    )

def get_affine_flow_bijector(unconstrained_event_dims):
    """Return an `AffineFlow` bijector constructed from `unconstrained_event_dims`."""
    affine_flow = AffineFlow(unconstrained_event_dims)
    return affine_flow


## Define Model

In [12]:
# Priors over the two model parameters, in the order (loc, scale).
priors = OrderedDict([
    ('loc', tfd.Normal(loc=0.1, scale=10.)),
    ('scale', tfd.HalfNormal(scale=10.)),
])


def likelihood(loc, scale):
    """Normal observation distribution parameterised by `loc` and `scale`."""
    return tfd.Normal(loc=loc, scale=scale)

In [13]:
# Combine the priors and the likelihood into the probabilistic model that is
# shared by the MCMC and VI runs below.
model = Model(likelihood=likelihood, priors=priors)

## Generate Datasets

In [14]:
# Number of synthetic datasets to generate and number of observations in each.
num_datasets, num_datapoints = 10, 100

In [15]:
# Seed TensorFlow's global random generator so that the synthetic datasets
# (and the "true" parameters they were drawn with) are the same every time
# this cell is executed; without it every run evaluates on different data.
DATA_SEED = 0
tf.random.set_seed(DATA_SEED)

# Both dicts are keyed by 'dataset_<i>' for i = 1..num_datasets.
true_params = {}
datasets = {}

for i in range(1, num_datasets+1):
    dataset_name = 'dataset_{}'.format(i)

    # One joint draw yields the parameters and `num_datapoints` observations.
    # Unpacking relies on the sample's values being ordered (loc, scale, y).
    dist = model.get_joint_distribution(targets=tf.ones(shape=(num_datapoints,)))
    loc, scale, y = dist.sample().values()

    # Keep the generating parameters so results can be compared to them later.
    true_params[dataset_name] = to_ordered_dict(
        model.param_names, [loc.numpy(), scale.numpy()]
    )
    # The model has no features here; the sampled observations are the targets.
    datasets[dataset_name] = make_dataset_from_df(
        pd.DataFrame({'y': y}), target_names=['y'], format_features_as='dict'
    )

## Run Inferences and Collect Results

In [16]:
# Container for all experiment results, keyed by dataset name.
# Re-running this cell discards any results collected so far.
results = dict()

In [17]:
# Builders for every surrogate posterior under comparison, keyed by the name
# under which its results are stored. Each builder takes the flat
# unconstrained event dimensionality and returns a freshly constructed
# surrogate, so every dataset is fitted with newly built surrogates.
surrogate_builders = OrderedDict([
    ('meanfield_advi', lambda dims: ADVI(model, mean_field=True)),
    ('affine_flow', lambda dims: NormalizingFlow(
        model, flow_bijector=get_affine_flow_bijector(dims))),
    ('maf', lambda dims: NormalizingFlow(
        model, flow_bijector=get_masked_autoregressive_flow_bijector(dims))),
    ('cnf', lambda dims: NormalizingFlow(
        model, flow_bijector=get_continuous_flow_bijector(dims))),
    ('hnf-1', lambda dims: NormalizingFlow(
        model, flow_bijector=get_hamiltonian_flow_bijector(dims, num_flows=1), extra_ndims=dims)),
    ('hnf-2', lambda dims: NormalizingFlow(
        model, flow_bijector=get_hamiltonian_flow_bijector(dims, num_flows=2), extra_ndims=dims)),
    ('hnf-5', lambda dims: NormalizingFlow(
        model, flow_bijector=get_hamiltonian_flow_bijector(dims, num_flows=5), extra_ndims=dims)),
])

# Fail fast: a surrogate without an entry in `init_lr` would otherwise only
# raise a KeyError when its own fit starts, i.e. after all preceding MCMC and
# VI runs have already been computed.
missing_lr = [
    surrogate_name for surrogate_name in surrogate_builders
    if surrogate_name not in init_lr
]
if missing_lr:
    raise KeyError('no initial learning rate in init_lr for: {}'.format(', '.join(missing_lr)))

for name, dataset in datasets.items():
    print('Starting with {}\n'.format(name))
    # Pull the whole dataset out as one single batch.
    features, targets = list(dataset.batch(dataset.cardinality()).take(1))[0]

    results[name]={
        'data': {
            'targets': targets.numpy(),
            'true_params': true_params[name],
        }}

    ################################################################################################################################

    # MCMC: serves as the reference posterior for the VI error metrics below.
    print('Start MCMC...')

    mcmc = MCMC(model=model, dataset=dataset, transition_kernel=kernel)

    print('Run NUTS:')
    start = time.time()
    mcmc_result = mcmc.fit(
        num_chains=NUM_CHAINS,
        num_samples=NUM_SAMPLES,
        num_burnin_steps=NUM_BURNIN_STEPS,
        progress_bar=True,
    )
    run_time = round(time.time() - start, 2)
    print('Finished in {}s\n'.format(run_time))

    post_pred_dist = model.get_posterior_predictive_distribution(
        posterior_samples=mcmc_result.samples,
    )
    posterior_predictive_samples = post_pred_dist.sample(1000)['y'].numpy()

    # Swap the first two axes of the samples/traces before handing them to
    # arviz (presumably draw-major -> chain-major; confirm against MCMC.fit).
    posterior_samples = OrderedDict([(k, np.swapaxes(v.numpy(), 0, 1)) for k, v in zip(model.param_names, mcmc_result.samples)])
    traces = {k: np.swapaxes(v.numpy(), 0, 1) if tf.rank(v) >= 2 else v for k,v in mcmc_result.trace.items()}
    summary = az.summary(posterior_samples, round_to=2)

    mcmc_results = {
        'posterior_samples': posterior_samples,
        'traces': traces,
        'posterior_predictive_samples': posterior_predictive_samples,
        'acceptance_ratios_per_chain': mcmc_result.accept_ratios.numpy(),
        'summary': summary,
        'run_time': run_time
    }

    results[name]['mcmc'] = mcmc_results

    print('Finished MCMC! \n')
    ################################################################################################################################

    # VI
    unconstrained_event_dims = model.flat_unconstrained_param_event_ndims

    surrogate_posteriors = OrderedDict(
        (surrogate_name, build(unconstrained_event_dims))
        for surrogate_name, build in surrogate_builders.items()
    )

    # The flattened MCMC reference samples are the same for every surrogate,
    # so they are prepared once per dataset: one column vector per parameter.
    flat_mcmc_samples = OrderedDict([(k,np.expand_dims(v.flatten(), 1)) for k, v in mcmc_results['posterior_samples'].items()])

    vi_results = {}

    print('Start VI...')
    for surrogate_name, surrogate_posterior in surrogate_posteriors.items():

        vi = VI(model, dataset, surrogate_posterior)

        print('Run {}:'.format(surrogate_name))
        start = time.time()
        approx_posterior, losses = vi.fit(optimizer=optimizer(init_lr[surrogate_name]), num_steps=NUM_STEPS, sample_size=SAMPLE_SIZE, progress_bar=True)
        run_time = round(time.time() - start, 2)
        print('Finished in {}s\n'.format(run_time))

        print('computing error metrics...')
        start = time.time()

        post_pred_dist = model.get_posterior_predictive_distribution(
            posterior_distribution=approx_posterior
        )
        posterior_predictive_samples = post_pred_dist.sample(1000)['y'].numpy()
        # Add a leading axis of size one so arviz gets a (1, n_draws, ...) array.
        posterior_samples = OrderedDict([(k, np.expand_dims(v.numpy(),0)) for k, v in approx_posterior.sample(10000).items()])

        summary = az.summary(posterior_samples, stat_funcs=func_dict, extend=False)

        # compute Wasserstein distance
        flat_vi_samples = OrderedDict([(k,np.expand_dims(v.flatten(), 1)) for k, v in posterior_samples.items()])
        W2 = OrderedDict()
        for k in flat_vi_samples.keys():
            # Pairwise cost matrix (POT's default metric is squared Euclidean).
            # Built and consumed one parameter at a time so that only a single
            # n_mcmc x n_vi matrix is alive at once.
            cost = ot.dist(flat_mcmc_samples[k], flat_vi_samples[k])
            cost_scale = cost.max()
            # Sinkhorn is run on the max-normalised cost so that `reg` is
            # relative to the cost scale ...
            normalized_transport_cost = ot.sinkhorn2(
                ot.unif(flat_mcmc_samples[k].size),
                ot.unif(flat_vi_samples[k].size),
                cost/cost_scale,
                reg=1e-3
            )
            # ... and the result is scaled back to the original units before
            # taking the square root. Without the rescaling the value is
            # divided by sqrt(cost_scale), which differs per surrogate and
            # dataset and makes the stored distances incomparable.
            W2[k] = np.sqrt(normalized_transport_cost * cost_scale)

        # compute absolute posterior mean difference to mcmc result
        posterior_mean_error = OrderedDict([(k, np.abs(v.mean() - flat_mcmc_samples[k].mean())) for k, v in flat_vi_samples.items()])
        comp_time = round(time.time() - start, 2)
        print('Finished computing error metrics in {}s\n\n'.format(comp_time))

        vi_result = {
            'losses': losses,
            # Average of the last 100 recorded loss values.
            'final_loss': losses[-100:].mean(),
            'posterior_samples': posterior_samples,
            'posterior_predictive_samples': posterior_predictive_samples,
            'summary': summary,
            'W2': W2,
            'posterior_mean_error':  posterior_mean_error,
            'run_time': run_time,
        }
        vi_results[surrogate_name] = vi_result
    print('Finished VI! \n\n\n\n')
    results[name]['vi'] = vi_results
    

Starting with dataset_1

Start MCMC...
Run NUTS:


Finished in 37.55s

Finished MCMC! 

Start VI...
Run meanfield_advi:


Finished in 5.09s

computing error metrics...




Finished computing error metrics in 68.51s


Run affine_flow:


Finished in 5.37s

computing error metrics...




Finished computing error metrics in 61.71s


Run maf:


Finished in 7.93s

computing error metrics...




Finished computing error metrics in 78.02s


Run cnf:


Finished in 447.23s

computing error metrics...




Finished computing error metrics in 73.2s


Run hnf-1:


Finished in 31.08s

computing error metrics...




Finished computing error metrics in 36.56s


Run hnf-2:


KeyError: 'hnf-2'

# Average the MAF surrogate's W2 value over all datasets, per model parameter.
maf_W2_per_dataset = [dataset_result['vi']['maf']['W2'] for dataset_result in results.values()]
loc_W2_mean = np.mean([w2['loc'] for w2 in maf_W2_per_dataset])
scale_W2_mean = np.mean([w2['scale'] for w2 in maf_W2_per_dataset])

In [None]:
# Persist the collected results and read them back from the SAME file as a
# round-trip check. (The load previously pointed at the placeholder path
# './filename.pickle', which this notebook never writes.)
RESULTS_PATH = './univariate_gaussian_exp_results.pickle'

with open(RESULTS_PATH, 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

# NOTE: unpickling can execute arbitrary code; only load files that were
# written by this notebook (or another trusted source).
with open(RESULTS_PATH, 'rb') as handle:
    b = pickle.load(handle)