In [8]:
import warnings
warnings.simplefilter(action="ignore")

In [9]:
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
matplotlib.use("Agg")
from matplotlib import figure  # pylint: disable=g-import-not-at-top
from matplotlib.backends import backend_agg
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
try:
  import seaborn as sns  # pylint: disable=g-import-not-at-top
  HAS_SEABORN = True
except ImportError:
  HAS_SEABORN = False

tfd = tfp.distributions

In [15]:
A =tfd.Normal( loc=[1., -1],scale=[1, 2.])
B =tfd.Normal( loc=[1., -1],scale=[1, 2.])
C = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
      loc=[-1., 1],       # One for each component.
      scale=[0.1, 0.5]))  # And same here.
D = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
      loc=[-1., 1],       # One for each component.
      scale=[0.1, 0.5]))  # And same here.

In [39]:
tf.distributions.kl_divergence(A, B)

<tf.Tensor 'KullbackLeibler_4/kl_normal_normal/add:0' shape=(2,) dtype=float32>

In [24]:
mu_p = 1.
sigma_p = 1.
mu_q = 0.
sigma_q = 2.
p = tf.distributions.Normal(loc=mu_p, scale=sigma_p)
q = tf.distributions.Normal(loc=mu_q, scale=sigma_q)

In [40]:
mc_samples = 2

In [41]:
sess = tf.InteractiveSession()

In [42]:
x_p = p.sample(sample_shape=(mc_samples,)).eval()
x_q = q.sample(sample_shape=(mc_samples,)).eval()

In [43]:
mixture_prob = lambda x, p, q: .5 * (p.prob(x) + q.prob(x))

In [44]:
def kl_mixture(x, p1, p2): 
    return p1.log_prob(x) - tf.log(mixture_prob(x, p1, p2))

In [45]:
kl_mixture(x_p, p ,q )

<tf.Tensor 'sub_1:0' shape=(2,) dtype=float32>

In [14]:
def make_mixture_posterior_fn(mixture_components):
  """Creates the mixture of Gaussians posterior distribution.
  Args:
    mixture_components: Number of elements of the mixture.
  Returns:
    random_posterior_fn: A callable like tfp.layers.default_multivariate_normal_fn which returns an instance of a posterior distribution.
  """
  if mixture_components == 1:
    # See the module docstring for why we don't learn the parameters here.
    # default_mvn_fn returns a standard non-trainable MVN
    return tfp.layers.default_multivariate_normal_fn

  def _random_posterior_fn(dtype, shape, name, trainable, add_variable_fn):
    # signature matches tfp.layers.default_multivariate_normal_fn
#    loc = add_variable_fn(
#       name=name + "/loc", shape=[mixture_components,] + list(shape),  trainable=trainable)
#    raw_scale_diag = add_variable_fn(
#       name=name+"/scale", shape=[mixture_components,] + list(shape), trainable=trainable)
    loc = add_variable_fn(
        name=name + "/loc", shape=list(shape) + [mixture_components,],  trainable=trainable)
    raw_scale_diag = add_variable_fn(
        name=name+"/scale", shape=list(shape) + [mixture_components,], trainable=trainable)
    mixture_logits = add_variable_fn(
        name=name+"/logit", shape=[mixture_components], trainable=trainable)

    return tfd.MixtureSameFamily(
        components_distribution=tfd.MultivariateNormalDiag(
            loc=loc,
            scale_diag=tf.nn.softplus(raw_scale_diag)),
        mixture_distribution=tfd.Categorical(logits=mixture_logits),
        name="posterior")
  return _random_posterior_fn