# Variational Inference

https://ermongroup.github.io/cs228-notes/inference/variational/

https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf

https://arxiv.org/pdf/1601.00670.pdf

## Idea
Pick an approximating distribution $q(x)$ from a *tractable* family, and choose it to be as close as possible to the true posterior $p^*(x)=p(x|D)$.

Note: a cheap alternative is a single Gaussian approximation (e.g. the Laplace approximation). However, this only works well when the true posterior $p^*(x)$ is itself close to Gaussian.
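As an illustration of the Gaussian alternative, here is a minimal sketch of the Laplace approximation: center a Gaussian at the mode of the unnormalized log posterior and set its variance from the curvature there. The coin-flip model (uniform prior, 7 heads in 10 flips) and the helper names `log_p_tilde` and `laplace_approx` are hypothetical, chosen only for illustration. The exact posterior here is Beta(8, 4), so the approximation is reasonable but not exact (its mean is the mode 0.7, not the posterior mean 2/3).

```python
import numpy as np

def log_p_tilde(theta, k=7, n=10):
    # Hypothetical toy model: unnormalized log posterior of a coin's bias theta
    # under a uniform prior after observing k heads in n flips.
    return k * np.log(theta) + (n - k) * np.log(1.0 - theta)

def laplace_approx(log_f, lo=1e-6, hi=1 - 1e-6, num=200_001, h=1e-5):
    # 1. Locate the mode on a dense grid (crude but dependency-free).
    grid = np.linspace(lo, hi, num)
    mode = grid[np.argmax(log_f(grid))]
    # 2. Second derivative of log f at the mode via a central finite difference.
    d2 = (log_f(mode + h) - 2.0 * log_f(mode) + log_f(mode - h)) / h**2
    # 3. Gaussian with mean = mode and variance = -1 / curvature.
    return mode, -1.0 / d2

mean, var = laplace_approx(log_p_tilde)
print(mean, var)  # close to 0.7 and 0.7 * 0.3 / 10 = 0.021
```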


## VI
Throughout the derivation, remember that $p^*(x)$ is intractable to compute.

First thought: minimize the forward KL divergence of $p^*$ from $q$:
$$KL(p^*||q)=\sum_x p^*(x) \log \frac {p^*(x)} {q(x)}$$

However, this is intractable, since it is an expectation with respect to $p^*$. So let's use the reverse KL instead:
$$KL(q||p^*)=\sum_x q(x) \log \frac {q(x)} {p^*(x)}$$
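The two directions of the KL divergence are genuinely different objectives, not the same quantity written two ways. A minimal sketch on hypothetical four-state distributions (a bimodal stand-in for $p^*$ and a $q$ concentrated on one mode) shows that they give different values:

```python
import numpy as np

def kl(a, b):
    # Discrete KL(a || b) = sum_x a(x) log(a(x) / b(x));
    # terms with a(x) = 0 contribute 0 by convention.
    a, b = np.asarray(a, float), np.asarray(b, float)
    mask = a > 0
    return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

p_star = np.array([0.45, 0.05, 0.45, 0.05])  # hypothetical bimodal "true" posterior
q = np.array([0.70, 0.10, 0.10, 0.10])       # hypothetical q focused on one mode

forward = kl(p_star, q)  # KL(p* || q): expectation under p*
reverse = kl(q, p_star)  # KL(q || p*): expectation under q
print(forward, reverse)  # roughly 0.41 and 0.30: the two directions disagree
```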

Now the expectation is with respect to $q(x)$, which we chose from a tractable family, so it is tractable to compute. However, the integrand still contains $p^*(x)$. Note the following:

$$
\begin{align}
p^*(x) &= p(x|D) \\
       &= \frac {p(x,D)} {p(D)}
\end{align}
$$

Here, the normalization constant $Z=p(D)$ is usually also intractable to compute, so we can't evaluate $p^*(x)$ directly. What we *can* usually compute is the unnormalized distribution $\tilde{p}(x)=p(x,D)=p^*(x)Z$.
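A minimal sketch of the relationship $\tilde{p}(x)=p(x,D)=p^*(x)Z$ on a hypothetical model with a four-state hidden variable. The prior and likelihood values are made up for illustration. Computing $Z$ is a trivial sum here only because $x$ has four states; in realistic models this sum (or integral) runs over exponentially many configurations, which is what makes it intractable.

```python
import numpy as np

# Hypothetical toy model: hidden state x in {0, 1, 2, 3}.
prior = np.array([0.4, 0.3, 0.2, 0.1])       # p(x)
likelihood = np.array([0.1, 0.5, 0.8, 0.3])  # p(D | x) for one fixed dataset D

p_tilde = prior * likelihood  # unnormalized p(x, D) = p(D | x) p(x): cheap, pointwise
Z = p_tilde.sum()             # p(D) by brute force: feasible only for tiny state spaces
p_star = p_tilde / Z          # exact posterior p(x | D)

print(Z)       # approximately 0.04 + 0.15 + 0.16 + 0.03 = 0.38
print(p_star)
```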

We then have the following optimization objective:
$$
\begin{align}
J(q) &= KL(q||\tilde{p}) \\
     &= \sum_x q(x) \log \frac {q(x)} {\tilde{p}(x)} \\
     &= \sum_x q(x) \log \frac {q(x)} {p^*(x)Z} \\
     &= \sum_x q(x) \log \frac {q(x)} {p^*(x)} - \log Z \\
     &= KL(q||p^*) - \log Z
\end{align}
$$
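The identity $J(q)=KL(q||p^*)-\log Z$ can be checked numerically. In this minimal sketch, `J` only ever touches $\tilde{p}$; $Z$ and $p^*$ are computed solely to verify the identity, which is possible because the hypothetical four-state model is tiny. The random $q$ stands in for an arbitrary approximation.

```python
import numpy as np

def kl(a, b):
    # Discrete KL(a || b); assumes b > 0 wherever a > 0.
    mask = a > 0
    return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

def J(q, p_tilde):
    # J(q) = sum_x q(x) log(q(x) / p_tilde(x)); needs only the unnormalized p_tilde.
    mask = q > 0
    return float(np.sum(q[mask] * np.log(q[mask] / p_tilde[mask])))

# Hypothetical four-state model: p_tilde(x) = p(x) * p(D | x).
p_tilde = np.array([0.4, 0.3, 0.2, 0.1]) * np.array([0.1, 0.5, 0.8, 0.3])
Z = p_tilde.sum()      # computed only to check the identity
p_star = p_tilde / Z

rng = np.random.default_rng(0)
q = rng.dirichlet(np.ones(4))  # an arbitrary approximation q(x)

lhs = J(q, p_tilde)
rhs = kl(q, p_star) - np.log(Z)
print(lhs, rhs)  # equal up to floating-point error
```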

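Note that we are now optimizing our previous reverse KL divergence minus the constant $\log Z$, which does not depend on $q$. Minimizing $J(q)$ is therefore equivalent to minimizing $KL(q||p^*)$, and we never need to evaluate $Z$.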

Additionally, recall that $\log Z=\log p(D)$. Since the KL divergence is always non-negative, $J(q)=KL(q||p^*)-\log Z \ge -\log p(D)$. So the objective value we achieve is an upper bound on the negative log likelihood (NLL) of the data. Equivalently, $-J(q)$ is a lower bound on the log likelihood of the data, known as the evidence lower bound (ELBO). The bound is tight exactly when $q=p^*$, since only then is the KL term zero.
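A minimal sketch of the full procedure: minimize $J(q)$ by plain gradient descent over a softmax-parameterized $q$ and track the ELBO $-J(q)$. The four-state model, the step size 0.5, and the 2000 iterations are hypothetical choices for illustration. Because this family contains every distribution over four states, the optimum is $q=p^*$ and the ELBO closes the gap to $\log p(D)$; with a more restricted family, the remaining gap would be the smallest achievable $KL(q||p^*)$.

```python
import numpy as np

def softmax(phi):
    e = np.exp(phi - phi.max())
    return e / e.sum()

def J(q, p_tilde):
    # J(q) = sum_x q(x) log(q(x) / p_tilde(x)); softmax keeps every q(x) > 0.
    return float(np.sum(q * (np.log(q) - np.log(p_tilde))))

# Hypothetical four-state model: unnormalized p(x, D) = p(x) * p(D | x).
p_tilde = np.array([0.4, 0.3, 0.2, 0.1]) * np.array([0.1, 0.5, 0.8, 0.3])
log_evidence = np.log(p_tilde.sum())  # log p(D), computable only because the model is tiny

phi = np.zeros(4)  # logits of q; q starts uniform
elbos = []
for _ in range(2000):
    q = softmax(phi)
    elbos.append(-J(q, p_tilde))     # ELBO = -J(q) <= log p(D)
    g = np.log(q) - np.log(p_tilde)  # dJ/dq, dropping a constant that the softmax cancels
    grad = q * (g - np.dot(q, g))    # chain rule through the softmax
    phi -= 0.5 * grad                # gradient descent step on J

print(max(elbos) <= log_evidence + 1e-12)  # the bound is never exceeded
print(log_evidence - elbos[-1])            # remaining gap = KL(q || p*), near 0 here
```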
