https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed

In [2]:
import torch

ELBO Loss: $$ \min \mathbb{E}_q \left[ \log q\left(z\middle|x\right) - \log p\left(z\right) \right] - \mathbb{E}_q \log p\left(x\middle|z\right) $$

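The first expectation is the KL divergence $D_{KL}\left(q\left(z\middle|x\right) \,\|\, p\left(z\right)\right)$ between the approximate posterior $q$ and the prior $p$.

Think: the first term wants to minimize the (expected) difference between the log-probabilities that $q$ and $p$ assign to $z$. We're controlling $q$ here; the prior $p$ stays fixed.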

In [13]:
p = torch.distributions.Normal(loc=0, scale=1)
q = torch.distributions.Normal(loc=2, scale=4)

# take a sample
z = q.rsample()
print(z)

tensor(3.2009)


Given a sample from $q$, we can evaluate the two log-probability terms inside the KL divergence:

In [17]:
log_pz = p.log_prob(z)
log_qzx = q.log_prob(z)

print(f"prob pz:  {torch.exp(log_pz):.4f}")
print(f"prob qzx: {torch.exp(log_qzx):.4f}")

kld = log_qzx - log_pz
print(f"kld (single point): {kld:.4f}")

prob pz:  0.0024
prob qzx: 0.0953
kld (single point): 3.6914


These are the probability densities of the sampled $z$ under the two distributions. If the two densities differ greatly for a given $z$, and that stays true on average over all $z$ drawn from $q$, then the expected difference in log-density is large. Thus the KL divergence will be large.
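For two Gaussians the expected value can be computed exactly, which gives a reference point for the single-sample number above. The sketch below assumes the same $p$ and $q$ as the previous cells and uses `torch.distributions.kl_divergence`, which has a closed-form rule for a pair of `Normal` distributions. The hand-written formula is the standard result for a standard-normal $p$; the variable names are only illustrative.

```python
import torch

p = torch.distributions.Normal(loc=0.0, scale=1.0)
q = torch.distributions.Normal(loc=2.0, scale=4.0)

# Closed-form KL(q || p) using PyTorch's built-in Normal/Normal rule
kl_exact = torch.distributions.kl_divergence(q, p)

# The same quantity written out by hand, valid when p is a standard normal:
# KL = -log(sigma) + (sigma^2 + mu^2) / 2 - 1/2
mu, sigma = q.loc, q.scale
kl_manual = -torch.log(sigma) + (sigma**2 + mu**2) / 2 - 0.5

print(f"kl (closed form): {kl_exact:.4f}")
print(f"kl (by hand):     {kl_manual:.4f}")
```

Both lines should agree, at roughly 8.11. The single-point value is one noisy draw around that number.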

We can move the distributions closer together and try again.

In [19]:
p = torch.distributions.Normal(loc=0, scale=1)
q = torch.distributions.Normal(loc=1, scale=2)

# plot
# ...

# take a sample
z = q.rsample()
kld = q.log_prob(z) - p.log_prob(z)
print(f"kld (single point): {kld:.4f}")

kld (single point): 7.2369


Issue: the value went up even though the distributions are now closer. That is because this is technically only the difference in the log-probabilities at a single point $z \sim q$; what we *really* want is the **expected difference** over all possible values of $z$ (as drawn from $q$)!
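One way to approximate that expectation is to average the single-point difference over many samples from $q$. This is a minimal sketch, not part of the original notebook: the helper `mc_kl`, the sample count, and the names `q_far` and `q_near` are assumptions made for illustration. The closed-form value from `torch.distributions.kl_divergence` is printed alongside as a check.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

p = torch.distributions.Normal(loc=0.0, scale=1.0)

def mc_kl(q, p, n_samples=100_000):
    """Monte Carlo estimate of KL(q || p): mean of log q(z) - log p(z) over z ~ q."""
    z = q.rsample((n_samples,))
    return (q.log_prob(z) - p.log_prob(z)).mean()

# The two choices of q used in the cells above
q_far = torch.distributions.Normal(loc=2.0, scale=4.0)
q_near = torch.distributions.Normal(loc=1.0, scale=2.0)

for name, q in [("far", q_far), ("near", q_near)]:
    estimate = mc_kl(q, p)
    exact = torch.distributions.kl_divergence(q, p)
    print(f"{name}: monte carlo {estimate:.4f}, closed form {exact:.4f}")
```

The averaged estimates should land near the closed-form values. The closer $q$ gives the smaller divergence, even though one unlucky sample can suggest the opposite.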

In [20]:
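def kl_divergence(z, mu, std):
    """Single-sample estimate of KL(q(z|x) || p(z)), evaluated elementwise at z."""
    # prior p(z): standard normal with the same shape as the posterior parameters
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    # approximate posterior q(z|x): normal with the given mean and std
    q = torch.distributions.Normal(mu, std)

    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)

    # difference in log-probabilities at the sampled z
    kld = log_qzx - log_pz

    return kld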
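As a usage sketch, the function can be applied to a batch of posterior parameters. The tensors below are made-up stand-ins for encoder outputs (a batch of 8 with 4 latent dimensions). Summing over the last dimension is one common choice that assumes independent latent dimensions; the notebook does not state how the per-dimension terms are reduced.

```python
import torch

def kl_divergence(z, mu, std):
    # Same computation as the function above, repeated so this block runs on its own
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    q = torch.distributions.Normal(mu, std)
    return q.log_prob(z) - p.log_prob(z)

torch.manual_seed(0)

# Hypothetical encoder outputs: 8 examples, 4 latent dimensions
mu = torch.randn(8, 4)
std = torch.rand(8, 4) + 0.5  # offset keeps the scale strictly positive

# One reparameterized sample of z per example
z = torch.distributions.Normal(mu, std).rsample()

kld = kl_divergence(z, mu, std)     # shape (8, 4): one term per latent dimension
kld_per_example = kld.sum(dim=-1)   # shape (8,): assumes independent dimensions

print(kld.shape, kld_per_example.shape)
```

Each entry is still a single-sample estimate, so individual values can be negative; the estimate is only non-negative in expectation.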