# Beta-Bernoulli model
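We fit a simple Beta-Bernoulli model to simulated binary draws using variational inference in Edward, then check the fit with a posterior predictive check.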

In [69]:
import numpy as np
import tensorflow as tf
import edward as ed

## Configuration

In [113]:
# Seed NumPy's global random generator so the simulated draws are reproducible.
# NOTE(review): this seeds NumPy only; TensorFlow random ops (such as the
# tf.random_normal initialisers of the variational parameters) are not
# covered -- confirm whether a TensorFlow/Edward seed is also wanted.
np.random.seed(37)

## Sample

In [114]:
# Simulate N binary observations with success probability p and compute the
# maximum-likelihood estimate of p, which is the fraction of ones in the sample.
N = 50   # number of simulated observations
p = .8   # true success probability used to generate the data
# np.random.choice draws from NumPy's global RNG, so the result depends on
# whatever seed that generator was given.
draws = np.random.choice([0, 1], size=N, p=[1-p, p])
# ndarray.mean() reduces inside NumPy and always yields a float, whereas the
# builtin sum(draws) / N iterates element by element in Python and depends on
# true-division semantics.
p_ml = draws.mean()
print(p_ml)

0.76


## Model

In [115]:
from edward.models import Beta, Bernoulli

# Prior on the success probability: theta ~ Beta(1, 1).
theta = Beta(a=1., b=1.)
# Likelihood: N Bernoulli variables that all share theta. Multiplying theta by
# a length-N vector of ones gives one probability entry per observation.
success_prob = tf.ones([N]) * theta
x = Bernoulli(p=success_prob)

## Variational distribution

In [116]:
def _positive_scalar():
    """Return softplus applied to a new scalar tf.Variable drawn from tf.random_normal."""
    unconstrained = tf.Variable(tf.random_normal([]))
    return tf.nn.softplus(unconstrained)


# Variational approximation for theta: a Beta distribution whose two
# parameters are free scalars mapped through softplus.
qtheta_a = _positive_scalar()
qtheta_b = _positive_scalar()
qtheta = Beta(a=qtheta_a, b=qtheta_b)

## Inference

In [117]:
# Pair the latent variable with its variational approximation, and bind the
# observed node x to the simulated draws.
latent_to_approx = {theta: qtheta}
observed = {x: draws}
inference = ed.KLqp(latent_to_approx, data=observed)
# Run 3000 iterations with n_samples=20.
inference.run(n_samples=20, n_iter=3000)

Iteration    1 [  0%]: Loss = 43.153
Iteration  300 [ 10%]: Loss = 32.827
Iteration  600 [ 20%]: Loss = 30.702
Iteration  900 [ 30%]: Loss = 30.272
Iteration 1200 [ 40%]: Loss = 30.169
Iteration 1500 [ 50%]: Loss = 30.640
Iteration 1800 [ 60%]: Loss = 30.149
Iteration 2100 [ 70%]: Loss = 29.725
Iteration 2400 [ 80%]: Loss = 30.173
Iteration 2700 [ 90%]: Loss = 30.190
Iteration 3000 [100%]: Loss = 30.098


## Criticism

In [118]:
# Build a copy of x in which the prior theta is swapped for the fitted qtheta.
prior_to_posterior = {theta: qtheta}
x_post = ed.copy(x, prior_to_posterior)

In [121]:
# Evaluate the mean of the fitted variational distribution qtheta; as the
# last expression of the cell, its value is what the notebook displays.
qtheta.mean().eval()

0.74353731

In [141]:
# Evaluate the standard deviation of the fitted variational distribution
# qtheta; as the last expression of the cell, its value is displayed.
qtheta.std().eval()

0.10805982

In [143]:
def _sample_mean(xs, zs):
    """Statistic passed to ed.ppc: mean of the values bound to x_post, cast to float32."""
    return tf.reduce_mean(tf.cast(xs[x_post], tf.float32))


# Posterior predictive check on the sample mean, with x_post bound to the
# observed draws.
ppc_result = ed.ppc(_sample_mean, data={x_post: draws})
print(ppc_result)

[array([ 0.63999999,  0.44      ,  0.94      ,  0.51999998,  0.51999998,
        0.56      ,  0.81999999,  0.80000001,  0.68000001,  0.81999999,
        0.57999998,  0.86000001,  0.72000003,  0.80000001,  0.57999998,
        0.75999999,  0.77999997,  0.92000002,  0.83999997,  0.75999999,
        0.60000002,  0.72000003,  0.66000003,  0.80000001,  0.83999997,
        0.63999999,  0.88      ,  0.83999997,  0.57999998,  0.92000002,
        0.89999998,  0.86000001,  0.81999999,  0.66000003,  0.44      ,
        0.72000003,  0.77999997,  0.66000003,  0.92000002,  0.68000001,
        0.63999999,  0.77999997,  0.69999999,  0.86000001,  0.92000002,
        0.5       ,  0.75999999,  0.86000001,  0.77999997,  0.80000001,
        0.72000003,  0.72000003,  0.75999999,  0.56      ,  0.54000002,
        0.77999997,  0.74000001,  0.51999998,  0.68000001,  0.54000002,
        0.89999998,  0.77999997,  0.72000003,  0.83999997,  0.68000001,
        0.95999998,  0.47999999,  0.75999999,  0.68000001,  0.9