In [1]:
import time
import warnings
warnings.filterwarnings("ignore")
warnings.simplefilter('ignore')

from collections import OrderedDict
import functools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns
# use ggplot styles for graphs
plt.style.use('ggplot')

import pickle
import arviz as az

In [2]:
import ot
import tensorflow as tf
import tensorflow_probability as tfp

# set tf logger to log level ERROR to avoid warnings
tf.get_logger().setLevel('ERROR')

tfd = tfp.distributions
tfb = tfp.bijectors
tfk = tf.keras

In [3]:
# import probabilistic models
from bayes_vi.models import Model

# import utils
from bayes_vi.utils.datasets import make_dataset_from_df

In [4]:
# mcmc imports
from bayes_vi.inference.mcmc import MCMC
from bayes_vi.inference.mcmc.transition_kernels import HamiltonianMonteCarlo, NoUTurnSampler, RandomWalkMetropolis
from bayes_vi.inference.mcmc.stepsize_adaptation_kernels import SimpleStepSizeAdaptation, DualAveragingStepSizeAdaptation

In [5]:
# vi imports 
from bayes_vi.inference.vi import VI

from bayes_vi.inference.vi.surrogate_posteriors import ADVI, NormalizingFlow
from bayes_vi.utils import to_ordered_dict
from bayes_vi.inference.vi.flow_bijectors import HamiltonianFlow, AffineFlow, make_energy_fn, make_scale_fn, make_shift_fn
from bayes_vi.utils.leapfrog_integrator import LeapfrogIntegrator

# Experiment 2 - Simple Linear Regression

## MCMC Config

In [6]:
# sampling params for NUTS: number of parallel chains, draws kept per chain,
# and burn-in steps run before those draws (passed to MCMC.fit below)
NUM_CHAINS = 5
NUM_SAMPLES = 1_000
NUM_BURNIN_STEPS = 10_000

In [7]:
# NUTS transition kernel with dual-averaging step-size adaptation; the number of
# adaptation steps is 80% of the burn-in length.
stepsize_adaptation_kernel = DualAveragingStepSizeAdaptation(
    num_adaptation_steps=int(NUM_BURNIN_STEPS * 0.8)
)

kernel = NoUTurnSampler(
    stepsize_adaptation_kernel=stepsize_adaptation_kernel,
    max_tree_depth=5,
    step_size=0.01,  # initial step size handed to the adaptation scheme
)

## VI Config

In [8]:
# VI optimisation params: optimiser steps per surrogate and the `sample_size`
# passed to VI.fit (presumably Monte Carlo draws per ELBO estimate)
NUM_STEPS = 10_000
SAMPLE_SIZE = 50

In [9]:
# All surrogates are trained with Adam (the class; instantiated per run with its
# learning rate). Initial learning rates are looked up by surrogate name.
optimizer = tf.optimizers.Adam

# ADVI, affine flow and MAF share one rate, the continuous flow uses a smaller
# one, and every Hamiltonian flow depth uses the same rate.
init_lr = {surrogate: 5e-3 for surrogate in ('meanfield advi', 'affine flow', 'maf')}
init_lr['cnf'] = 1e-4
init_lr.update({'hnf({})'.format(depth): 1e-3 for depth in (1, 2, 3)})

In [10]:
def _percentile_stat(q):
    """Return a one-argument callable giving the q-th percentile of its input."""
    return lambda values: np.percentile(values, q)


# Extra summary statistics passed to `az.summary(..., stat_funcs=func_dict)`.
# The keys become the summary's column names (and end up in the pickled
# results), so they are kept exactly as they were.
# NOTE(review): despite the labels, "mode" is the median (50th percentile) and
# "hdi 3%" / "hdi 97%" are equal-tailed percentiles, not highest-density bounds.
func_dict = {
    "mean": np.mean,
    "sd": np.std,
    "min": _percentile_stat(0),
    "hdi 3%": _percentile_stat(3),
    "mode": _percentile_stat(50),
    "hdi 97%": _percentile_stat(97),
    "max": _percentile_stat(100),
}

In [11]:
def get_continuous_flow_bijector(unconstrained_event_dims):
    """Build an FFJORD continuous normalizing flow over the flat unconstrained space.

    The state's time derivative is a Keras MLP (two tanh layers of width 128 and a
    linear output of size `unconstrained_event_dims`). It receives the current
    time concatenated in front of the state.

    Args:
        unconstrained_event_dims: int, size of the flat unconstrained event.

    Returns:
        A `tfb.FFJORD` bijector integrated with Dormand-Prince (first step 0.1) and
        using the Hutchinson trace estimator (`trace_jacobian_hutchinson`).
    """
    state_fn = tfk.Sequential()
    state_fn.add(tfk.layers.Dense(128, activation=tfk.activations.tanh))
    state_fn.add(tfk.layers.Dense(128, activation=tfk.activations.tanh))
    state_fn.add(tfk.layers.Dense(unconstrained_event_dims))
    # input is [time, state], hence one extra feature
    state_fn.build((None, unconstrained_event_dims + 1))

    def state_time_derivative_fn(t, state):
        # Broadcast the scalar time to a (..., 1) column with the same (dynamic)
        # batch shape as `state`. The static `state.shape[0]` is None when the
        # batch size is unknown while tracing, and only covers rank-2 states.
        time_column = tf.ones_like(state[..., :1]) * tf.cast(t, state.dtype)
        return state_fn(tf.concat([time_column, state], axis=-1))

    return tfb.FFJORD(state_time_derivative_fn,
                      ode_solve_fn=tfp.math.ode.DormandPrince(first_step_size=0.1).solve,
                      trace_augmentation_fn=tfb.ffjord.trace_jacobian_hutchinson)
    
def get_hamiltonian_flow_bijector(unconstrained_event_dims, num_flows):
    """Chain `num_flows` HamiltonianFlow bijectors into a single bijector.

    Each flow gets its own leapfrog integrator and its own trainable step-size
    variable (initialised to 0.1). Each uses 2 integration steps and
    hidden layers [128, 128].

    Args:
        unconstrained_event_dims: int, size of the flat unconstrained event.
        num_flows: number of HamiltonianFlow bijectors to chain.

    Returns:
        A `tfb.Chain` of the constructed flows.
    """
    flows = []
    for _ in range(num_flows):
        flows.append(HamiltonianFlow(
            event_dims=unconstrained_event_dims,
            symplectic_integrator=LeapfrogIntegrator(),
            step_sizes=tf.Variable(0.1),
            num_integration_steps=2,
            hidden_layers=[128, 128],
        ))
    return tfb.Chain(flows)

def get_masked_autoregressive_flow_bijector(unconstrained_event_dims):
    """Build a masked autoregressive flow (MAF) over the flat unconstrained space.

    The autoregressive network has two tanh hidden layers of width 128. Its event
    shape is fixed to `[unconstrained_event_dims]`, so a mismatching input fails
    immediately instead of the shape being silently inferred from the first call.

    Args:
        unconstrained_event_dims: int, size of the flat unconstrained event.

    Returns:
        A `tfb.MaskedAutoregressiveFlow` bijector.
    """
    return tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
            params=2,  # a shift and a log-scale per event dimension
            event_shape=[unconstrained_event_dims],
            hidden_units=[128, 128],
            activation='tanh',
        )
    )

def get_affine_flow_bijector(unconstrained_event_dims):
    """Build the project's AffineFlow bijector for a flat event of the given size.

    Args:
        unconstrained_event_dims: int, size of the flat unconstrained event.

    Returns:
        An `AffineFlow` instance.
    """
    affine_flow = AffineFlow(unconstrained_event_dims)
    return affine_flow

def get_posterior_lift_distribution(unconstrained_event_dims):
    """Return a conditional diagonal-Gaussian lift q -> N(shift(q), diag(scale(q))).

    `shift` and `scale` are networks built by `make_shift_fn` / `make_scale_fn`
    with hidden layers [128, 128]. They are created once here and shared by every
    call of the returned function.

    Args:
        unconstrained_event_dims: int, size of the flat unconstrained event.

    Returns:
        A callable mapping `q` to a `tfd.MultivariateNormalDiag`.
    """
    scale_fn = make_scale_fn(unconstrained_event_dims, hidden_layers=[128, 128])
    shift_fn = make_shift_fn(unconstrained_event_dims, hidden_layers=[128, 128])

    def lift_distribution(q):
        return tfd.MultivariateNormalDiag(loc=shift_fn(q), scale_diag=scale_fn(q))

    return lift_distribution

## Define Model

In [12]:
# define priors: Normal(0, 10) on both regression coefficients and
# HalfNormal(10) on the observation noise scale (insertion order is kept)
priors = OrderedDict([
    ('beta_0', tfd.Normal(loc=0., scale=10.)),
    ('beta_1', tfd.Normal(loc=0., scale=10.)),
    ('scale', tfd.HalfNormal(scale=10.)),
])


# define likelihood
def likelihood(beta_0, beta_1, scale, features):
    """Gaussian likelihood of simple linear regression: y ~ Normal(beta_0 + beta_1 * x, scale).

    `features` must provide the covariate under the key 'x'.
    """
    mean = beta_0 + beta_1 * features['x']
    return tfd.Normal(loc=mean, scale=scale)

In [13]:
model = Model(likelihood=likelihood, priors=priors)  # Bayesian simple linear regression model

## Generate Datasets

In [14]:
# synthetic data: how many datasets to generate and how many points each holds
num_datasets, num_datapoints = 10, 100

In [15]:
# Seed TensorFlow's global RNG so the synthetic datasets (and the TF sampling in
# the following cells) are reproducible across Restart & Run All.
DATA_SEED = 42
tf.random.set_seed(DATA_SEED)

true_params = {}  # dataset name -> parameter values that generated the dataset
datasets = {}     # dataset name -> dataset built by make_dataset_from_df

for i in range(1, num_datasets + 1):
    dataset_name = 'dataset {}'.format(i)

    # covariate x ~ Normal(0, 10), one value per data point
    x = tf.random.normal(shape=(num_datapoints, ), mean=0., stddev=10.0, dtype=tf.float32)
    dist = model.get_joint_distribution(features={'x': x})
    # one joint draw unpacked positionally as (beta_0, beta_1, scale, y)
    # (relies on the sample's key order: priors first, then the target)
    beta_0, beta_1, scale, y = dist.sample().values()

    true_params[dataset_name] = to_ordered_dict(
        model.param_names, [beta_0.numpy(), beta_1.numpy(), scale.numpy()]
    )
    datasets[dataset_name] = make_dataset_from_df(
        pd.DataFrame({'y': y.numpy(), 'x': x.numpy()}), target_names=['y'], feature_names=['x'], format_features_as='dict'
    )

## Run Inferences and Collect Results

In [17]:
results = dict()  # dataset name -> {'data': ..., 'mcmc': ..., 'vi': ...}

In [18]:
def _wasserstein2_1d(x, y):
    """Exact 2-Wasserstein distance between two 1-D empirical distributions.

    Both samples get uniform weights. In 1-D the optimal coupling is the monotone
    (quantile) coupling:
        W2^2 = integral_0^1 (F_x^{-1}(u) - F_y^{-1}(u))^2 du.
    Both empirical quantile functions are piecewise constant, so the integral is
    evaluated exactly on the union of their breakpoints.

    The result equals sqrt(ot.emd2(ot.unif(n), ot.unif(m), ot.dist(x, y))) with
    POT's default squared-Euclidean cost. That route builds a dense n x m cost
    matrix (5000 x 5000 per parameter here). Its network-simplex solver also stops
    at its default iteration cap, and the non-optimality warning would be hidden
    by the notebook's global warnings filter.

    Args:
        x, y: non-empty array-likes of samples (flattened before use).

    Returns:
        numpy float64 W2 distance.

    Raises:
        ValueError: if either sample is empty.
    """
    x_sorted = np.sort(np.asarray(x, dtype=np.float64).ravel())
    y_sorted = np.sort(np.asarray(y, dtype=np.float64).ravel())
    n, m = x_sorted.size, y_sorted.size
    if n == 0 or m == 0:
        raise ValueError('both samples must be non-empty')
    # right end points of the constant pieces of both quantile functions on (0, 1]
    breaks = np.union1d(np.arange(1, n + 1) / n, np.arange(1, m + 1) / m)
    lefts = np.concatenate(([0.0], breaks[:-1]))
    widths = breaks - lefts
    # evaluate the quantile functions at interval midpoints, which avoids rounding at breakpoints
    mids = 0.5 * (lefts + breaks)
    x_idx = np.minimum((mids * n).astype(np.int64), n - 1)
    y_idx = np.minimum((mids * m).astype(np.int64), m - 1)
    return np.sqrt(np.sum(widths * (x_sorted[x_idx] - y_sorted[y_idx]) ** 2))


# For every dataset: run NUTS as the reference posterior, then fit every VI
# surrogate and compare it to the MCMC samples (W2, mean / sd errors).
for name, dataset in datasets.items():
    print('Starting with {}\n'.format(name))
    # the whole dataset as a single batch
    features, targets = next(iter(dataset.batch(dataset.cardinality()).take(1)))

    results[name] = {
        'data': {
            'features': {k: v.numpy() for k, v in features.items()},
            'targets': targets.numpy(),
            'true params': true_params[name]
        }}

    ################################################################################################################################

    # MCMC
    print('Start MCMC...')

    # NOTE(review): the same `kernel` (and its step-size adaptation kernel) is reused for
    # every dataset; confirm bayes_vi does not carry adapted state from one fit to the next.
    mcmc = MCMC(model=model, dataset=dataset, transition_kernel=kernel)

    print('Run NUTS:')
    start = time.time()
    mcmc_result = mcmc.fit(
        num_chains=NUM_CHAINS,
        num_samples=NUM_SAMPLES,
        num_burnin_steps=NUM_BURNIN_STEPS,
        progress_bar=True,
    )
    run_time = round(time.time() - start, 2)
    print('Finished in {}s\n'.format(run_time))

    post_pred_dist = model.get_posterior_predictive_distribution(
        posterior_samples=mcmc_result.samples, features=features
    )
    posterior_predictive_samples = post_pred_dist.sample(100)['y'].numpy()

    # swap the first two axes (presumably (draw, chain) -> (chain, draw) for arviz; TODO confirm)
    posterior_samples = to_ordered_dict(model.param_names, tf.nest.map_structure(lambda v: np.swapaxes(v.numpy(), 0, 1), mcmc_result.samples))
    map_trace = lambda v: np.swapaxes(v.numpy(), 0, 1) if tf.rank(v) >= 2 else v
    traces = {k: map_trace(v) if not tf.nest.is_nested(v) else tf.nest.map_structure(map_trace, v) for k, v in mcmc_result.trace.items()}
    summary = az.summary(posterior_samples, round_to=2, kind='diagnostics', stat_funcs=func_dict, extend=True)
    # move the last len(func_dict) columns (the custom stat_funcs) in front of the diagnostics
    cols = summary.columns.tolist()
    summary = summary[cols[-len(func_dict):] + cols[:-len(func_dict)]]

    mcmc_results = {
        'posterior samples': posterior_samples,
        'traces': traces,
        'posterior predictive samples': posterior_predictive_samples,
        'acceptance ratios per chain': mcmc_result.accept_ratios.numpy(),
        'summary': summary,
        'run-time': run_time
    }

    results[name]['mcmc'] = mcmc_results

    print('Finished MCMC! \n')
    ################################################################################################################################

    # VI
    unconstrained_event_dims = model.flat_unconstrained_param_event_ndims

    surrogate_posteriors = {
        'meanfield advi': ADVI(model, mean_field=True),
        'affine flow': NormalizingFlow(model, flow_bijector=get_affine_flow_bijector(unconstrained_event_dims)),
        'maf': NormalizingFlow(model, flow_bijector=get_masked_autoregressive_flow_bijector(unconstrained_event_dims)),
        'cnf': NormalizingFlow(model, flow_bijector=get_continuous_flow_bijector(unconstrained_event_dims)),
        'hnf(1)': NormalizingFlow(
            model, flow_bijector=get_hamiltonian_flow_bijector(unconstrained_event_dims, num_flows=1),
            extra_ndims=unconstrained_event_dims, posterior_lift_distribution=get_posterior_lift_distribution(unconstrained_event_dims)),
        'hnf(2)': NormalizingFlow(
            model, flow_bijector=get_hamiltonian_flow_bijector(unconstrained_event_dims, num_flows=2),
            extra_ndims=unconstrained_event_dims, posterior_lift_distribution=get_posterior_lift_distribution(unconstrained_event_dims)),
        'hnf(3)': NormalizingFlow(
            model, flow_bijector=get_hamiltonian_flow_bijector(unconstrained_event_dims, num_flows=3),
            extra_ndims=unconstrained_event_dims, posterior_lift_distribution=get_posterior_lift_distribution(unconstrained_event_dims))
    }

    # MCMC reference samples flattened to 1-D per parameter; the same for every
    # surrogate, so computed once per dataset instead of once per surrogate
    flat_mcmc_samples = tf.nest.map_structure(lambda v: v.flatten(), mcmc_results['posterior samples'])

    vi_results = {}

    print('Start VI...')
    for surrogate_name, surrogate_posterior in surrogate_posteriors.items():

        vi = VI(model, dataset, surrogate_posterior)

        print('Run {}:'.format(surrogate_name))
        start = time.time()
        approx_posterior, losses = vi.fit(optimizer=optimizer(init_lr[surrogate_name]), num_steps=NUM_STEPS, sample_size=SAMPLE_SIZE, progress_bar=True)
        run_time = round(time.time() - start, 2)
        print('Finished in {}s\n'.format(run_time))

        print('computing error metrics...')
        start = time.time()

        post_pred_dist = model.get_posterior_predictive_distribution(
            posterior_distribution=approx_posterior, features=features
        )
        posterior_predictive_samples = post_pred_dist.sample(100)['y'].numpy()
        # leading singleton axis so arviz treats the draws as a single chain
        posterior_samples = tf.nest.map_structure(lambda v: np.expand_dims(v.numpy(), 0), approx_posterior.sample(5000))

        summary = az.summary(posterior_samples, stat_funcs=func_dict, extend=False)

        flat_vi_samples = tf.nest.map_structure(lambda v: v.flatten(), posterior_samples)

        # compute Wasserstein distance (exact 1-D W2 per parameter)
        W2 = tf.nest.map_structure(_wasserstein2_1d, flat_mcmc_samples, flat_vi_samples)

        # compute absolute posterior mean / sd difference to mcmc result
        posterior_mean_error = tf.nest.map_structure(lambda a, b: np.abs(a.mean() - b.mean()), flat_vi_samples, flat_mcmc_samples)
        posterior_sd_error = tf.nest.map_structure(lambda a, b: np.abs(a.std() - b.std()), flat_vi_samples, flat_mcmc_samples)
        comp_time = round(time.time() - start, 2)
        print('Finished computing error metrics in {}s\n\n'.format(comp_time))

        vi_result = {
            'losses': losses,
            'final loss': losses[-100:].mean(),  # mean over the last 100 optimisation steps
            'learning rate': init_lr[surrogate_name],
            'posterior samples': posterior_samples,
            'posterior predictive samples': posterior_predictive_samples,
            'summary': summary,
            'W2': W2,
            'mean error':  posterior_mean_error,
            'sd error': posterior_sd_error,
            'run-time': run_time,
        }
        vi_results[surrogate_name] = vi_result
    print('Finished VI! \n\n\n\n')
    results[name]['vi'] = vi_results
    

Starting with dataset 1

Start MCMC...
Run NUTS:


Finished in 203.48s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 14.86s

computing error metrics...




Finished computing error metrics in 18.44s


Run affine flow:


Finished in 14.44s

computing error metrics...




Finished computing error metrics in 18.22s


Run maf:


Finished in 29.82s

computing error metrics...




Finished computing error metrics in 17.15s


Run cnf:


Finished in 1171.02s

computing error metrics...




Finished computing error metrics in 12.16s


Run hnf(1):


Finished in 64.18s

computing error metrics...




Finished computing error metrics in 12.97s


Run hnf(2):


Finished in 104.54s

computing error metrics...




Finished computing error metrics in 12.61s


Run hnf(3):


Finished in 149.42s

computing error metrics...




Finished computing error metrics in 15.13s


Finished VI! 




Starting with dataset 2

Start MCMC...
Run NUTS:


Finished in 265.99s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 11.37s

computing error metrics...




Finished computing error metrics in 14.82s


Run affine flow:


Finished in 13.35s

computing error metrics...




Finished computing error metrics in 11.85s


Run maf:


Finished in 28.3s

computing error metrics...




Finished computing error metrics in 12.6s


Run cnf:


Finished in 1287.74s

computing error metrics...




Finished computing error metrics in 13.84s


Run hnf(1):


Finished in 64.88s

computing error metrics...




Finished computing error metrics in 15.64s


Run hnf(2):


Finished in 106.98s

computing error metrics...




Finished computing error metrics in 12.88s


Run hnf(3):


Finished in 153.98s

computing error metrics...




Finished computing error metrics in 14.93s


Finished VI! 




Starting with dataset 3

Start MCMC...
Run NUTS:


Finished in 244.64s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 10.19s

computing error metrics...




Finished computing error metrics in 14.14s


Run affine flow:


Finished in 14.61s

computing error metrics...




Finished computing error metrics in 11.83s


Run maf:


Finished in 27.88s

computing error metrics...




Finished computing error metrics in 15.26s


Run cnf:


Finished in 1172.1s

computing error metrics...




Finished computing error metrics in 18.52s


Run hnf(1):


Finished in 77.05s

computing error metrics...




Finished computing error metrics in 16.55s


Run hnf(2):


Finished in 123.25s

computing error metrics...




Finished computing error metrics in 17.91s


Run hnf(3):


Finished in 155.64s

computing error metrics...




Finished computing error metrics in 15.59s


Finished VI! 




Starting with dataset 4

Start MCMC...
Run NUTS:


Finished in 255.26s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 15.16s

computing error metrics...




Finished computing error metrics in 20.23s


Run affine flow:


Finished in 22.74s

computing error metrics...




Finished computing error metrics in 14.62s


Run maf:


Finished in 35.09s

computing error metrics...




Finished computing error metrics in 13.63s


Run cnf:


Finished in 1234.25s

computing error metrics...




Finished computing error metrics in 14.59s


Run hnf(1):


Finished in 63.49s

computing error metrics...




Finished computing error metrics in 12.4s


Run hnf(2):


Finished in 108.54s

computing error metrics...




Finished computing error metrics in 16.48s


Run hnf(3):


Finished in 153.2s

computing error metrics...




Finished computing error metrics in 17.11s


Finished VI! 




Starting with dataset 5

Start MCMC...
Run NUTS:


Finished in 184.29s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 20.35s

computing error metrics...




Finished computing error metrics in 15.46s


Run affine flow:


Finished in 15.86s

computing error metrics...




Finished computing error metrics in 17.93s


Run maf:


Finished in 26.99s

computing error metrics...




Finished computing error metrics in 14.75s


Run cnf:


Finished in 1368.77s

computing error metrics...




Finished computing error metrics in 13.79s


Run hnf(1):


Finished in 77.05s

computing error metrics...




Finished computing error metrics in 15.13s


Run hnf(2):


Finished in 114.17s

computing error metrics...




Finished computing error metrics in 21.96s


Run hnf(3):


Finished in 164.52s

computing error metrics...




Finished computing error metrics in 16.67s


Finished VI! 




Starting with dataset 6

Start MCMC...
Run NUTS:


Finished in 283.88s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 17.19s

computing error metrics...




Finished computing error metrics in 14.93s


Run affine flow:


Finished in 21.5s

computing error metrics...




Finished computing error metrics in 19.16s


Run maf:


Finished in 30.43s

computing error metrics...




Finished computing error metrics in 15.67s


Run cnf:


Finished in 1146.43s

computing error metrics...




Finished computing error metrics in 14.07s


Run hnf(1):


Finished in 65.87s

computing error metrics...




Finished computing error metrics in 16.43s


Run hnf(2):


Finished in 111.16s

computing error metrics...




Finished computing error metrics in 15.38s


Run hnf(3):


Finished in 172.52s

computing error metrics...




Finished computing error metrics in 14.48s


Finished VI! 




Starting with dataset 7

Start MCMC...
Run NUTS:


Finished in 221.83s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 14.85s

computing error metrics...




Finished computing error metrics in 16.58s


Run affine flow:


Finished in 18.27s

computing error metrics...




Finished computing error metrics in 15.61s


Run maf:


Finished in 33.6s

computing error metrics...




Finished computing error metrics in 19.96s


Run cnf:


Finished in 1158.26s

computing error metrics...




Finished computing error metrics in 15.73s


Run hnf(1):


Finished in 65.47s

computing error metrics...




Finished computing error metrics in 13.94s


Run hnf(2):


Finished in 127.03s

computing error metrics...




Finished computing error metrics in 14.34s


Run hnf(3):


Finished in 152.58s

computing error metrics...




Finished computing error metrics in 15.9s


Finished VI! 




Starting with dataset 8

Start MCMC...
Run NUTS:


Finished in 235.78s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 14.01s

computing error metrics...




Finished computing error metrics in 13.72s


Run affine flow:


Finished in 20.18s

computing error metrics...




Finished computing error metrics in 14.85s


Run maf:


Finished in 33.14s

computing error metrics...




Finished computing error metrics in 18.07s


Run cnf:


Finished in 1175.23s

computing error metrics...




Finished computing error metrics in 13.95s


Run hnf(1):


Finished in 75.39s

computing error metrics...




Finished computing error metrics in 15.72s


Run hnf(2):


Finished in 118.27s

computing error metrics...




Finished computing error metrics in 15.34s


Run hnf(3):


Finished in 171.66s

computing error metrics...




Finished computing error metrics in 20.39s


Finished VI! 




Starting with dataset 9

Start MCMC...
Run NUTS:


Finished in 262.53s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 13.55s

computing error metrics...




Finished computing error metrics in 16.24s


Run affine flow:


Finished in 33.07s

computing error metrics...




Finished computing error metrics in 15.04s


Run maf:


Finished in 33.22s

computing error metrics...




Finished computing error metrics in 27.68s


Run cnf:


Finished in 1356.38s

computing error metrics...




Finished computing error metrics in 23.93s


Run hnf(1):


Finished in 85.4s

computing error metrics...




Finished computing error metrics in 20.61s


Run hnf(2):


Finished in 137.96s

computing error metrics...




Finished computing error metrics in 22.25s


Run hnf(3):


Finished in 168.45s

computing error metrics...




Finished computing error metrics in 17.46s


Finished VI! 




Starting with dataset 10

Start MCMC...
Run NUTS:


Finished in 222.94s

Finished MCMC! 

Start VI...
Run meanfield advi:


Finished in 13.37s

computing error metrics...




Finished computing error metrics in 14.68s


Run affine flow:


Finished in 19.2s

computing error metrics...




Finished computing error metrics in 13.47s


Run maf:


Finished in 30.76s

computing error metrics...




Finished computing error metrics in 15.52s


Run cnf:


Finished in 1114.07s

computing error metrics...




Finished computing error metrics in 12.99s


Run hnf(1):


Finished in 64.37s

computing error metrics...




Finished computing error metrics in 13.41s


Run hnf(2):


Finished in 117.57s

computing error metrics...




Finished computing error metrics in 22.64s


Run hnf(3):


Finished in 169.78s

computing error metrics...




Finished computing error metrics in 13.94s


Finished VI! 






In [19]:
from pathlib import Path

# Persist all experiment results. The output directory is created first so a
# fresh checkout does not fail with FileNotFoundError.
# NOTE: only unpickle this file from trusted sources (pickle can execute code on load).
results_path = Path('./linear_regression_experiment/linear_regression_exp_results.pickle')
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open('wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)