In [1]:
# The more recent Tensorflow versions can directly compute KL terms with a built-in function.
# Let's compare this computation with the analytical form, as well as an approximation using samples. 
# We can see how many samples we need to get a good estimate as well.
# We'll just use Normal distributions for now.
import tensorflow as tf
import numpy as np

In [14]:
p_loc = 0.0
p_scale = 0.4
q_loc = 2.0
q_scale = 1.2

p = tf.distributions.Normal(loc=p_loc, scale=p_scale, name="distribution_a")
q = tf.distributions.Normal(loc=q_loc, scale=q_scale, name="distribution_b")

n_samples = tf.placeholder(dtype=tf.int32)
p_samps = p.sample(n_samples)
q_samps = q.sample(n_samples)

kld_tf = tf.distributions.kl_divergence(p, q, name="kld")
kld_approx = tf.reduce_mean(p.log_prob(p_samps) - q.log_prob(p_samps))

# TF built-in function and approximation.

In [15]:
with tf.Session() as sess:
        
    print("KL term as implemented in TF:")
    kld_tf_ = kld_tf.eval()
    print(kld_tf_)
    print("")

KL term as implemented in TF:
2.0430567



In [16]:
with tf.Session() as sess:
    
    samps_space = [1, 2, 4, 8, 16, 32, 64, 128]
    for s in samps_space:
        print("KL term as approximated with %d samples:" % s)
        for i in range(10):
            kld_approx_ = kld_approx.eval(feed_dict={n_samples:s})
            print("%d  % .4f  %.4f" % (i, kld_approx_, abs(kld_approx_ - kld_tf_) / kld_tf_))

KL term as approximated with 1 samples:
0   2.5594  0.2527
1  -0.4264  1.2087
2   1.9823  0.0297
3   1.8560  0.0915
4   2.0138  0.0143
5   2.6317  0.2881
6   2.4318  0.1903
7   2.0815  0.0188
8   0.9530  0.5335
9   2.6330  0.2887
KL term as approximated with 2 samples:
0   2.2073  0.0804
1   2.5494  0.2478
2   1.6407  0.1970
3   1.9204  0.0600
4   2.6041  0.2746
5   2.0809  0.0185
6   1.0975  0.4628
7   2.4207  0.1848
8   1.9826  0.0296
9   2.3401  0.1454
KL term as approximated with 4 samples:
0   2.3396  0.1452
1   1.1545  0.4349
2   2.4077  0.1785
3   1.8378  0.1005
4   2.4826  0.2152
5   2.3430  0.1468
6   1.4351  0.2976
7   2.0910  0.0235
8   2.0377  0.0026
9   2.1341  0.0446
KL term as approximated with 8 samples:
0   2.2290  0.0910
1   2.1102  0.0329
2   1.6686  0.1833
3   2.2177  0.0855
4   1.7791  0.1292
5   2.4837  0.2157
6   2.1225  0.0389
7   1.9360  0.0524
8   1.7870  0.1253
9   2.2139  0.0836
KL term as approximated with 16 samples:
0   2.1552  0.0549
1   2.3257  0.1384
2

# Analytical computation.

Now, we could also compute this term analytically in this case.

$p(x) = \frac{1}{\sqrt{2\pi\sigma_p^2}} e^{-\frac{(x-\mu_p)^2}{2\sigma_p^2}}$
$q(x) = \frac{1}{\sqrt{2\pi\sigma_q^2}} e^{-\frac{(x-\mu_q)^2}{2\sigma_q^2}}$

$KL(p||q) = E_p [\log (\frac{p}{q})] = \int_{-\infty}^{\infty} p(x) \log (\frac{p(x)}{q(x)}) dx = \int_{-\infty}^{\infty} p(x) (\log p(x) - \log q(x)) dx = \int_{-\infty}^{\infty} p(x) \log p(x) dx - \int_{-\infty}^{\infty} p(x) \log q(x) dx$

$= -\frac{1}{2} (1 + \log (2\pi \sigma_p^2)) - \int_{-\infty}^{\infty} p(x) \log q(x) dx$

$- \int_{-\infty}^{\infty} p(x) \log q(x) dx = - \int_{-\infty}^{\infty} p(x) ( \log \frac{1}{\sqrt{2\pi\sigma_q}} - \frac{(x-\mu_q)^2}{2\sigma_q^2}) dx = \frac{1}{2} \log (2\pi\sigma_q^2) + \int_{-\infty}^{\infty} p(x) \frac{(x-\mu_q)^2}{2\sigma_q^2}) dx$

$\int_{-\infty}^{\infty} p(x) \frac{(x-\mu_q)^2}{2\sigma_q^2} dx = \frac{1}{2\sigma_q^2} \int_{-\infty}^{\infty} p(x) (x^2 - 2\mu_qx + \mu_q^2) dx = \frac{(\sigma_p^2 + \mu_p^2) - 2\mu_q\mu_p + \mu_q^2}{2\sigma_q^2}  = \frac{\sigma_p^2 + (\mu_p - \mu_q)^2)}{2\sigma_q^2}$

Collecting up all the terms we get

$KL(p||q) = -\frac{1}{2} (1 + \log (2\pi\sigma_p^2)) + \frac{1}{2} \log (2\pi\sigma_q^2) + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2)}{2\sigma_q^2} $

$= -\frac{1}{2} (1 + \log (2\pi\sigma_p^2) - \log (2\pi\sigma_q^2)) + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2)}{2\sigma_q^2} = -\frac{1}{2} (1 + \log (\sigma_p^2) - \log (\sigma_q^2)) + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2)}{2\sigma_q^2} = \log \frac{\sigma_q}{\sigma_p} + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2)}{2\sigma_q^2} - \frac{1}{2}$

In [17]:
kld_analytical = np.log(q_scale/p_scale) + (p_scale**2 + (p_loc - q_loc)**2) / (2*q_scale**2) - 0.5

print("KL term computed analytically:")
print(kld_analytical)

KL term computed analytically:
2.043056733112554
