In [None]:
import matplotlib.pyplot as pp
import tensorflow as tf
import tensorflow_probability as tp

from functools import partial

In [None]:
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'

In [None]:
# Apply matplotlib's built-in 'ggplot' style sheet to every figure drawn after this cell.
pp.style.use(style='ggplot')

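# A/B testing

Simulate `total_count` Bernoulli trials with a known success probability, then recover the posterior distribution of that probability (uniform prior) with Hamiltonian Monte Carlo and compare it against the true value.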

In [None]:
def unnormalized_log_probability(success_count, total_count, success_probability):
    """Unnormalized log posterior of a Bernoulli success probability.

    The model is a default-constructed ``Uniform`` prior (TFP defaults:
    low=0, high=1) on ``success_probability``. The data are
    ``success_count`` successes and ``total_count - success_count``
    failures, each an independent Bernoulli trial with that probability.
    The evidence (normalizing constant) is omitted, which is all an MCMC
    target needs.

    Args:
        success_count: number of observed successes (scalar; cast to float32).
        total_count: number of observed trials (scalar).
        success_probability: the parameter being sampled. Assumed to be
            float32, since the counts are cast to float32 before multiplying.

    Returns:
        log prior(success_probability) + log likelihood(data | success_probability).
    """
    prior = tp.distributions.Uniform()
    trial = tp.distributions.Bernoulli(probs=success_probability)

    log_prior = prior.log_prob(success_probability)
    # Every success contributes log p and every failure contributes log(1 - p).
    success_term = tf.cast(success_count, tf.float32) * trial.log_prob(1)
    failure_term = tf.cast(total_count - success_count, tf.float32) * trial.log_prob(0)

    return log_prior + success_term + failure_term

# Ground truth for the simulation: each of `total_count` trials succeeds
# independently with probability `probability`.
probability = 0.05
total_count = 100000
# Op-level seed so that "Restart & Run All" reproduces the same simulated data
# and the same Markov chain.
seed = 42

# NOTE: `success_count` is a graph tensor, not a fixed number. Each
# `session.run` that evaluates it draws a fresh dataset. Within a single run,
# the initial state and the target log-prob below see the same draw.
success_count = tf.reduce_sum(
    tp.distributions.Bernoulli(probs=probability).sample(sample_shape=total_count, seed=seed)
)

# Start the chain at the observed success rate, kept strictly inside (0, 1):
# the Sigmoid bijector below sends exactly 0 or 1 to -inf / +inf.
initial_state = [
    tf.clip_by_value(
        tf.cast(success_count, tf.float32) / tf.cast(total_count, tf.float32),
        1e-6,
        1.0 - 1e-6,
    ),
]

# HMC moves in an unconstrained space, and the bijector maps that space onto
# the parameter's support. A success probability lives in (0, 1), so use
# Sigmoid. With Identity, proposals could leave [0, 1]: there the Uniform
# prior's log-prob is -inf and the Bernoulli log-prob is not finite.
# TransformedTransitionKernel adds the Jacobian correction, so the samples
# are still drawn from the posterior over the probability itself.
unconstraining_bijectors = [
    tp.bijectors.Sigmoid(),
]

# AUTO_REUSE lets this cell be re-run without a "variable already exists" error.
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    step_size = tf.get_variable(
        name='step_size',
        initializer=tf.constant(0.5, dtype=tf.float32),
        trainable=False,
        use_resource=True,
    )

transition_kernel = tp.mcmc.TransformedTransitionKernel(
    inner_kernel=tp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=partial(unnormalized_log_probability, success_count, total_count),
        num_leapfrog_steps=2,
        step_size=step_size,
        # NOTE(review): depending on the TFP version, calling this policy with
        # no arguments may keep adapting the step size after burn-in, which
        # weakens the chain's stationarity guarantees. Newer TFP releases
        # accept `num_adaptation_steps` (e.g. int(0.8 * num_burnin_steps)).
        # Confirm against the installed version before changing it.
        step_size_update_fn=tp.mcmc.make_simple_step_size_update_policy(),
        state_gradients_are_stopped=True,
        seed=seed,
    ),
    bijector=unconstraining_bijectors,
)

# `kernel` holds the chain's kernel *results*, not the transition kernel.
# Downstream code reads `kernel.inner_results.is_accepted` from it.
[posterior_probability], kernel = tp.mcmc.sample_chain(
    num_results=100000,
    num_burnin_steps=10000,
    current_state=initial_state,
    kernel=transition_kernel,
)

In [None]:
# Evaluate the graph built above in a TF1 session. The session is not used as
# a context manager, so it stays open and available afterwards.
session = tf.Session()

# Initialize graph variables (including `step_size`) and any local variables
# before sampling.
session.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

# A single run draws the data, burns in, and samples the chain. The trailing
# underscore marks the evaluated NumPy values of the corresponding tensors.
posterior_probability_, kernel_ = session.run([posterior_probability, kernel])

# Mean of the per-step HMC acceptance flags over the returned samples.
print('Acceptance rate: {}'.format(kernel_.inner_results.is_accepted.mean()))

In [None]:
# Posterior histogram of the success probability, with the true simulation
# value marked by a dashed vertical line for comparison.
fig, ax = pp.subplots(figsize=(12, 6))
ax.axvline(x=probability, color='black', linestyle='--', lw=1, label='True probability')
ax.hist(posterior_probability_, bins=25, density=True, label='Posterior samples')
ax.set(
    title='Posterior distribution of the success probability',
    xlabel='Success probability',
    ylabel='Density',
)
ax.legend();