In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.mcmc.transformed_kernel import make_transformed_log_prob

# Short aliases for the TFP sub-packages used throughout the notebook.
tfb = tfp.bijectors
tfd = tfp.distributions
# Floating-point dtype used for the surrogate's trainable variables.
dtype = tf.float32

# Seed the global RNGs so Restart & Run All reproduces the results: TF random
# ops that are not given an explicit seed (e.g. the surrogate's tf.random.normal
# initialisation) and NumPy's legacy global generator.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

In [2]:
# Ground-truth parameters of the synthetic data-generating process; the fitted
# posterior over (mu, sigma) should concentrate near these values.
mu = 12
sigma = 2.2
# A seeded Generator makes the observations identical on every Restart & Run All.
rng = np.random.default_rng(seed=42)
# Cast to float32 so the observations match the model's float32 dtype instead
# of being implicitly converted from float64 on every log_prob call.
data = rng.normal(mu, sigma, size=200).astype(np.float32)

In [3]:
# Generative model, in sampling order:
#   e   ~ Exponential(rate=0.1)            -- prior on the noise scale (sigma)
#   n   ~ Normal(0, 10)                    -- prior on the mean (mu)
#   obs ~ Normal(n, e), len(data) iid draws -- likelihood of the observations
# JointDistributionSequential passes previously sampled components to a lambda
# in reverse order, hence the signature `lambda n, e`.
#
# The likelihood is wrapped in tfd.Sample so that all observations form ONE
# event: log_prob sums over the data points and returns a single value per
# draw of (e, n). With a bare scalar Normal, log_prob(data) would instead
# broadcast to len(data) values, the priors' log-probs would be added to each
# of them, and a downstream mean (such as a Monte-Carlo loss average) would
# weight the likelihood by 1/N instead of summing it.
model = tfd.JointDistributionSequential([
    tfd.Exponential(0.1, name='e'),  # sigma
    tfd.Normal(loc=0., scale=10., name='n'),  # mu
    lambda n, e: tfd.Sample(
        tfd.Normal(loc=n, scale=e),
        sample_shape=len(data),
        name='obs',
    ),
])

In [4]:
def joint_log_prob(*latents):
    """Joint log-density of the model with the observations fixed to `data`.

    Args:
      *latents: values for the model's latent components in model order,
        i.e. (e, n) -- the sigma and mu draws.

    Returns:
      `model.log_prob` evaluated at (*latents, data), i.e. the unnormalised
      log posterior of the latents.
    """
    return model.log_prob((*latents, data))


# One bijector per latent, in model order, mapping unconstrained reals onto the
# latent's support: Exp for the positive scale `e` (sigma), Identity for the
# real-valued location `n` (mu).
unconstraining_bijectors = [tfb.Exp(), tfb.Identity()]

# Log-density over the unconstrained space: `joint_log_prob` is evaluated at the
# bijectors' forward images of its inputs, so the surrogate below can be an
# unconstrained Normal.
target_log_prob = make_transformed_log_prob(
    joint_log_prob,
    unconstraining_bijectors,
    direction='forward',
    enable_bijector_caching=False,
)

In [5]:
# Mean-field surrogate: one independent Normal per latent, defined on the same
# unconstrained space as `target_log_prob`. Each Normal has a trainable `loc`
# (initialised from N(0, 1)) and a trainable `scale` stored through a Softplus
# bijector, which keeps it positive; the scale starts at 0.02.
# A single prior draw is taken only to read each latent's shape; the last
# component (the observations) is dropped.
parameters = model.sample(1)
parameters.pop()
dists = []
for i, parameter in enumerate(parameters):
    # Index 0 removes the leading sample dimension of size 1 added by sample(1).
    shape = parameter[0].shape
    loc = tf.Variable(
        tf.random.normal(shape, dtype=dtype),
        name='meanfield_%s_loc' % i,
        dtype=dtype)
    scale = tfp.util.TransformedVariable(
        tf.fill(shape, value=tf.constant(0.02, dtype)),
        tfb.Softplus(),
        name='meanfield_%s_scale' % i,
    )

    # Reinterpret every batch dim of the per-latent Normal as an event dim, so
    # each surrogate component has the same event shape as its model latent.
    # Passed explicitly rather than relying on the default (None), whose
    # meaning has differed across TFP releases (older versions kept the
    # leading batch dim as batch); for scalar latents this is 0.
    approx_parameter = tfd.Independent(
        tfd.Normal(loc=loc, scale=scale),
        reinterpreted_batch_ndims=len(shape),
    )
    dists.append(approx_parameter)

meanfield_advi = tfd.JointDistributionSequential(dists)

In [6]:
# Total number of optimisation steps.
num_steps = 10_000
# Width of the text progress bar, in characters.
num_cols = 20
# Steps per progress-bar character. Clamped to at least 1 so that
# num_steps < num_cols does not lead to a modulo-by-zero in trace_fn.
it_break = max(1, num_steps // num_cols)

def trace_fn(traceable_quantities):
    """Draw a text progress bar and record the loss at each optimisation step.

    Every `it_break` steps this prints a `num_cols`-wide bar of the form
    `|>>>...|`, prefixed with a carriage return so each print overwrites the
    previous one. It uses `tf.cond`/`tf.print` rather than Python control flow
    so it behaves the same when traced into a graph.

    Args:
      traceable_quantities: the per-step object passed to `trace_fn` by
        `tfp.vi.fit_surrogate_posterior`; only `.step` and `.loss` are read.

    Returns:
      The loss at this step (the value recorded in the returned loss trace).
    """
    step = traceable_quantities.step + 1
    # Clamp so that a num_steps that is not a multiple of num_cols cannot ask
    # tf.repeat for more than num_cols '>' or a negative number of '.'.
    n_done = tf.minimum(step // it_break, num_cols)

    def _print_bar():
        # Build "\r|" + ">" * n_done + "." * (num_cols - n_done) + "|".
        bar = tf.strings.reduce_join([
            "\r|",
            tf.strings.reduce_join(tf.repeat(">", n_done, axis=0)),
            tf.strings.reduce_join(tf.repeat(".", num_cols - n_done, axis=0)),
            "|",
        ])
        return tf.print(bar, end="")

    tf.cond(
        tf.math.mod(step, it_break) == 0,
        _print_bar,
        lambda: tf.no_op(),
    )
    return traceable_quantities.loss

# Adam optimiser for the surrogate's trainable variables. Note: the surrogate's
# variables live at module level, so re-running this cell continues
# optimisation from their current values rather than from a fresh init.
opt = tf.optimizers.Adam(learning_rate=0.5)

def run_approximation():
    """Fit `meanfield_advi` to `target_log_prob` with `tfp.vi.fit_surrogate_posterior`.

    Runs `num_steps` steps using the module-level optimiser `opt`, and calls
    `trace_fn` at every step (progress bar and loss recording).

    Returns:
      The trace produced by `trace_fn`, i.e. the per-step losses.
    """
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob,
        surrogate_posterior=meanfield_advi,
        trace_fn=trace_fn,
        num_steps=num_steps,
        optimizer=opt,
    )

loss_ = run_approximation()

|>>>>>>>>>>>>>>>>>>>>|