In [2]:
import tensorflow_probability as tfp
import tensorflow as tf
from silence_tensorflow import auto

In [3]:
mu_true = 4.0
sigma_true = 2.0
X = tfp.distributions.Normal(loc = mu_true, scale = sigma_true)
N = 100
dataset = X.sample(N)
dataset

<tf.Tensor: shape=(100,), dtype=float32, numpy=
array([ 4.126232  ,  3.7774131 ,  3.550356  ,  2.6924682 ,  7.3140125 ,
        5.5419564 ,  2.3298979 ,  4.757136  ,  2.0871878 ,  5.2646217 ,
        3.7661717 ,  7.1011915 ,  4.757197  ,  4.145867  ,  2.9169364 ,
        4.0962176 ,  1.2646744 ,  4.0840535 ,  4.1072116 ,  2.629142  ,
        6.6190796 ,  2.1823378 ,  5.4498873 ,  4.037551  ,  5.0410676 ,
        3.8512619 ,  5.6660995 ,  6.9404554 ,  3.9612198 ,  0.02752471,
        5.0204678 ,  3.2048178 ,  1.7205889 ,  3.0492158 ,  0.37218952,
        6.520125  , 10.294703  ,  7.8183346 ,  5.5860443 ,  3.3717322 ,
        1.7206495 ,  1.1965175 ,  2.277965  ,  4.252978  ,  1.4373977 ,
        2.5857434 ,  5.701472  ,  0.7333822 ,  5.0215836 ,  4.241683  ,
        7.153178  ,  2.9367514 ,  3.8723366 ,  3.5990307 ,  5.051783  ,
        2.5304666 ,  4.1300244 ,  3.3802428 ,  4.3550277 ,  5.344529  ,
        3.836696  ,  5.878412  ,  2.747353  ,  3.7089791 ,  4.982298  ,
        4.690199

In [8]:
mu_0 = 4.2
sigma_0=0.3
sigma_fix = sigma_true
mu_N = (sigma_fix**2 * mu_0 + sigma_0**2*tf.reduce_sum(dataset))/(sigma_fix**2+N*sigma_0**2)
sigma_N = (sigma_0*sigma_fix)/tf.sqrt(sigma_fix**2+N*sigma_0**2)



In [10]:
sigma_N

<tf.Tensor: shape=(), dtype=float32, numpy=0.16641007363796234>

In [9]:
mu_N

<tf.Tensor: shape=(), dtype=float32, numpy=4.034511566162109>

In [13]:
def generative_model(mu_0, sigma_0, sigma_fix, n_samples):
    mu = yield tfp.distributions.JointDistributionCoroutine.Root(
        tfp.distributions.Normal(loc=mu_0, scale=sigma_0, name="mu")
    )
    X = yield tfp.distributions.Normal(loc=mu*tf.ones(n_samples), scale = sigma_fix, name = "X")

In [14]:
model_joint = tfp.distributions.JointDistributionCoroutineAutoBatched(lambda : generative_model(mu_0,sigma_0,sigma_fix, N))
model_joint

<tfp.distributions.JointDistributionCoroutineAutoBatched 'JointDistributionCoroutineAutoBatched' batch_shape=[] event_shape=StructTuple(
  mu=[],
  X=[100]
) dtype=StructTuple(
  mu=float32,
  X=float32
)>

In [17]:
model_joint_log_prob_fixed_data = lambda mu:model_joint.log_prob(mu, dataset)
model_joint_log_prob_fixed_data(4.0)

<tf.Tensor: shape=(), dtype=float32, numpy=-202.1978759765625>

In [19]:
mu_S = tf.Variable(mu_0, name="mu_Surrogate")
sigma_S = tfp.util.TransformedVariable(sigma_0, bijector=tfp.bijectors.Softplus(), name = "Sigma_Surrogate")
surrogate_posterior = tfp.distributions.Normal(loc = mu_S, scale=sigma_S, name = "surrogate_posterior")
surrogate_posterior

<tfp.distributions.Normal 'surrogate_posterior' batch_shape=[] event_shape=[] dtype=float32>

In [20]:
with tf.GradientTape() as g:
    samples = surrogate_posterior.sample(3)
    neg_elbo = -tf.reduce_mean(model_joint_log_prob_fixed_data(samples) - surrogate_posterior.log_prob(samples))
neg_elbo

<tf.Tensor: shape=(), dtype=float32, numpy=205.0410614013672>

In [21]:
g.gradient(neg_elbo, surrogate_posterior.trainable_variables)

(<tf.Tensor: shape=(), dtype=float32, numpy=11.262495994567871>,
 <tf.Tensor: shape=(), dtype=float32, numpy=2.9126365184783936>)

In [23]:
tfp.vi.fit_surrogate_posterior(target_log_prob_fn=model_joint_log_prob_fixed_data,
                               surrogate_posterior=surrogate_posterior,
                               optimizer=tf.optimizers.Adam(0.1),
                               num_steps = 1000,
                               sample_size = 100)

<tf.Tensor: shape=(1000,), dtype=float32, numpy=
array([204.11148, 203.45496, 203.37569, 203.3321 , 203.3441 , 203.39719,
       203.17738, 203.0641 , 203.06255, 203.08652, 203.1278 , 203.11322,
       203.1931 , 203.16518, 203.14934, 203.15816, 203.06592, 203.05722,
       203.11058, 203.12425, 203.14084, 203.09404, 203.06355, 203.07483,
       203.06967, 203.06438, 203.04323, 203.06677, 203.0772 , 203.05719,
       203.05994, 203.05266, 203.04123, 203.0645 , 203.08426, 203.083  ,
       203.11449, 203.06516, 203.04086, 203.06572, 203.05261, 203.0411 ,
       203.05774, 203.0621 , 203.06117, 203.06676, 203.04826, 203.06194,
       203.06258, 203.08759, 203.03883, 203.07043, 203.03848, 203.05812,
       203.05258, 203.05551, 203.04004, 203.06403, 203.10812, 203.07129,
       203.05295, 203.04898, 203.04665, 203.0457 , 203.07489, 203.04865,
       203.08685, 203.07578, 203.04266, 203.05125, 203.05548, 203.05266,
       203.05553, 203.04918, 203.04355, 203.06255, 203.06085, 203.08504,
  

In [24]:
mu_S

<tf.Variable 'mu_Surrogate:0' shape=() dtype=float32, numpy=4.057673454284668>

In [25]:
sigma_S

<TransformedVariable: name=Sigma_Surrogate, dtype=float32, shape=[], fn="softplus", numpy=0.1695268>

In [26]:
mu_N

<tf.Tensor: shape=(), dtype=float32, numpy=4.034511566162109>

In [27]:
sigma_N

<tf.Tensor: shape=(), dtype=float32, numpy=0.16641007363796234>