# Variational inference

## Overview

We have seen that computing the posterior $P(\pmb{Z}|\pmb{X})$ of the latent variables $\pmb{Z}$ given the observed data $\pmb{X}$ is an important task in EM (i.e. when computing expectations under the posterior in the E step) and in all probabilistic model applications where the prior is not conjugate to the likelihood.

It may happen that it is not feasible to evaluate the posterior or compute its expectation:

- High latent space dimensionality.
- Complex form of the posterior.
- Intractable integrals with no closed-form solution for continuous variables.
- Exponentially large number of configurations for discrete settings.
- ...

In those cases, we need to rely on approximations of the posterior. There are two families of approaches:

- Stochastic approximations
- Deterministic approximations

| Family                       | Pros                             | Cons                          | Example                  |
| ---------------------------- | -------------------------------- | ----------------------------- | ------------------------ |
| Stochastic approximations    | Exact results if run long enough | Computationally expensive     | Markov chain Monte Carlo |
| Deterministic approximations | Scale easily                     | Cannot generate exact results | Variational inference    |

We will be focusing on the family of **deterministic approximations**. Let's recall the decomposition used in the E-step of the EM algorithm for i.i.d. data $\pmb{X}$ and latent variables $\pmb{Z}$:

$$
\begin{align*}
\log P(\pmb{X}) = \mathcal{L}(q) + \mathcal{KL} \left(q(\pmb{Z}) || P(\pmb{Z}|\pmb{X})\right)
\end{align*}
$$

Note that the parameter $\pmb{\theta}$ is omitted, since the parameters are now treated as random variables and are included in $\pmb{Z}$. Because $\log P(\pmb{X})$ does not depend on $q$, maximizing the lower bound $\mathcal{L}(q)$ is equivalent to minimizing the KL divergence, which reaches zero when $q$ equals the posterior. As we assume that the posterior is not tractable, we must find an alternative: **restrict $q$ to a family of tractable distributions and then minimize the KL divergence within that family**.
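The decomposition of the log evidence can be checked numerically on a small discrete model. The sketch below is only an illustration: the joint table, the evidence value of 0.3 and the helper names `lower_bound` and `kl_divergence` are assumptions made for this example, not part of the notes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model: Z takes 4 values and X is fixed at its observed value,
# so joint[z] stands for P(X = x_obs, Z = z). The numbers are arbitrary.
joint = rng.random(4) + 0.1
joint *= 0.3 / joint.sum()          # assume the evidence P(X) is 0.3

evidence = joint.sum()              # P(X) = sum_z P(X, Z = z)
posterior = joint / evidence        # P(Z | X)


def lower_bound(q, joint):
    """L(q) = sum_z q(z) * log(P(X, z) / q(z))."""
    return np.sum(q * (np.log(joint) - np.log(q)))


def kl_divergence(q, p):
    """KL(q || p) for two discrete distributions with full support."""
    return np.sum(q * (np.log(q) - np.log(p)))


# An arbitrary q(Z) that is not the posterior.
q = rng.random(4) + 0.1
q /= q.sum()

print("log P(X)     :", np.log(evidence))
print("L(q) + KL    :", lower_bound(q, joint) + kl_divergence(q, posterior))
print("L(posterior) :", lower_bound(posterior, joint))
```

The first two printed values should agree for any valid $q$, and the bound evaluated at the exact posterior should equal the log evidence, since the KL term is then zero.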

The quality of the results depends on the choice of the family $\mathcal{Q}$. The wider the family, the harder the optimization becomes (e.g. imagine that $\mathcal{Q}$ contains every possible distribution). However, if $\mathcal{Q}$ is too narrow, the posterior is unlikely to lie in the family (i.e. the smaller the family, the greater the approximation error). The chosen family must be flexible enough but still tractable.

Note that there is **no overfitting** in this scenario: a richer family simply allows us to approximate the posterior more closely.


## Summing up

The objective of a variational inference algorithm is to approximate the posterior probability $P^{*}(\pmb{Z})$:

$$
P^{*}(\pmb{Z}) = P(\pmb{Z}|\pmb{X}) = \frac{P(\pmb{X}|\pmb{Z})P(\pmb{Z})}{P(\pmb{X})}
$$

Two steps are involved:

1. Select a family of distributions $\mathcal{Q}$ (i.e. the variational family) from which the approximation $q(\pmb{Z})$ will be chosen. Example:

$$
\mathcal{Q} = \left\{ \mathcal{N}(\pmb{\mu}, \pmb{\sigma^2} \pmb{I}) \right\}
$$

2. Find the best approximation $q(\pmb{Z})$ of $P^{*}(\pmb{Z})$ by minimizing the KL divergence:

$$
q(\pmb{Z}) = \underset{q^{\prime} \in \mathcal{Q}}{\mathrm{argmin}} \ \mathcal{KL} [ q^{\prime}(\pmb{Z}) || P^{*}(\pmb{Z}) ]
$$

Note that this second step is equivalent to maximizing the lower bound $\mathcal{L}(q)$.
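The two steps can be sketched for a scalar latent variable. Everything specific here is an assumption for illustration: the two-bump `log_joint`, the integration grid, and the brute-force search over $(\mu, \sigma)$, which stands in for the gradient-based or closed-form optimization a real implementation would use.

```python
import numpy as np

# Grid used for numerical integration over a scalar latent variable Z.
z = np.linspace(-10.0, 10.0, 4001)
dz = z[1] - z[0]


def log_joint(z):
    """Hypothetical unnormalized log P(X, Z): a two-bump density, not Gaussian."""
    return np.logaddexp(np.log(0.7) - 0.5 * (z - 1.0) ** 2,
                        np.log(0.3) - 0.5 * ((z + 1.5) / 0.8) ** 2)


# Target posterior P*(Z), normalized numerically on the grid.
log_post = log_joint(z) - np.log(np.sum(np.exp(log_joint(z))) * dz)


def kl_to_posterior(mu, sigma):
    """Approximate KL[q || P*] for q = N(mu, sigma^2) by a Riemann sum."""
    log_q = -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))
    return np.sum(np.exp(log_q) * (log_q - log_post)) * dz


# Step 1: the variational family, here Gaussians on a coarse (mu, sigma) grid.
mus = np.linspace(-3.0, 3.0, 61)
sigmas = np.linspace(0.3, 3.0, 28)

# Step 2: pick the member with the smallest KL divergence to the posterior.
kls = np.array([[kl_to_posterior(m, s) for s in sigmas] for m in mus])
i, j = np.unravel_index(np.argmin(kls), kls.shape)
best_mu, best_sigma, best_kl = mus[i], sigmas[j], kls[i, j]

print(f"best q: N({best_mu:.2f}, {best_sigma:.2f}^2), KL = {best_kl:.4f}")
```

Because the target is not Gaussian, the smallest KL divergence stays strictly positive: the family is too narrow to contain the posterior, which is the approximation error discussed in the overview.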

## Mean field approximation

### Formula for lower bound maximization

The **mean field approximation** consists of selecting a family of distributions $\mathcal{Q}$ such that it factorizes with respect to disjoint groups of latent variables $\pmb{Z}_i$. Given $M$ groups:

$$
\mathcal{Q} = \left\{ q | q(\pmb{Z}) = \prod^{M}_{i=1} q_i(\pmb{Z}_i) \right\}
$$

Therefore, we assume independence between the groups of latent variables. No additional restrictions are placed on the individual factors $q_i$. The lower bound is then optimized with respect to one factor at a time, keeping the others fixed.

Let's replace the definition into the formula of the lower bound (we are using $q_i$ to denote $q_i(\pmb{Z}_i)$ for readability):

$$
\begin{align*}
\mathcal{L}(q) & = \int q \log \frac{P(\pmb{X}, \pmb{Z})}{q} d\pmb{Z} \\
               & = \int \prod_i^M q_i \left[ \log P(\pmb{X}, \pmb{Z}) - \log \prod_i^M q_i \right] d\pmb{Z} \\
               & = \int \prod_i^M q_i \log P(\pmb{X}, \pmb{Z}) d\pmb{Z} - \int \prod_i^M q_i \sum_i^M \log q_i d\pmb{Z} \\
               & = \int \prod_i^M q_i \log P(\pmb{X}, \pmb{Z}) d\pmb{Z} - \int \prod_i^M q_i \left[ \log q_1 + \ldots + \log q_k + \ldots + \log q_M \right] d\pmb{Z} \\
               & = \int \prod_i^M q_i \log P(\pmb{X}, \pmb{Z}) d\pmb{Z} - \int \prod_i^M q_i \log q_k d\pmb{Z} - \int \prod_i^M q_i \sum_{i \neq k}^M \log q_i d\pmb{Z}
\end{align*}
$$

Then, as we want to optimize with respect to a single factor $q_k$, let's rewrite each of the three terms so that $q_k$ is isolated. First:

$$
\begin{align*}
\int \prod_i^M q_i \log P(\pmb{X}, \pmb{Z}) d\pmb{Z} & = \int q_1 \ldots \int q_k \ldots \int q_M \log P(\pmb{X}, \pmb{Z}) d\pmb{Z}_1 \ldots d\pmb{Z}_k \ldots d\pmb{Z}_M \\
                                                      & = \int q_k \left\{ \int \prod_{i \neq k}^M q_i \log P(\pmb{X}, \pmb{Z}) d\pmb{Z}_{i} \right\} d\pmb{Z}_k
\end{align*}
$$

Note that by definition:

$$
\mathbb{E}_p \left[ f(x) \right] = \int p(x) f(x) dx
$$

For the distribution $\prod_{i \neq k}^M q_i$ over all groups except $\pmb{Z}_k$ and the function $\log P(\pmb{X}, \pmb{Z})$, we have:

$$
\mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] = \int \prod_{i \neq k}^M q_i \log P(\pmb{X}, \pmb{Z}) d\pmb{Z}_{i}
$$

Then:

$$
\int \prod_i^M q_i \log P(\pmb{X}, \pmb{Z}) d\pmb{Z} = \int q_k \mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] d\pmb{Z}_k
$$

For the second term, we have:

$$
\begin{align*}
- \int \prod_i^M q_i \log q_k d\pmb{Z} & = - \int q_k \log q_k d\pmb{Z}_k \int \prod_{i \neq k}^M q_i d\pmb{Z}_i  \\
                                       & = - \int q_k \log q_k d\pmb{Z}_k \\
\end{align*}
$$

Note that the second integral in the first line is equal to 1 by definition of probability distributions.

Finally, for the last term, we have:

$$
\begin{align*}
- \int \prod_i^M q_i \sum_{i \neq k}^M \log q_i d\pmb{Z} & = - \int q_k \prod_{i \neq k}^M q_i \sum_{i \neq k}^M \log q_i d\pmb{Z} \\
                                                         & = - \int q_k \left\{ \int \prod_{i \neq k}^M q_i \sum_{i \neq k}^M \log q_i d\pmb{Z}_i \right\} d\pmb{Z}_k \\
                                                         & = - \int q_k C \ d\pmb{Z}_k \\
\end{align*}
$$

Note that we can replace the term in curly brackets by a constant as it does not depend on $q_k$.

Then, we can rewrite:

$$
\begin{align*}
\mathcal{L}(q) & = \int q_k \mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] d\pmb{Z}_k - \int q_k \log q_k d\pmb{Z}_k - \int q_k C \ d\pmb{Z}_k \\
               & = \int q_k \left( \mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] - C \right) d\pmb{Z}_k - \int q_k \log q_k d\pmb{Z}_k
\end{align*}
$$
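As a sanity check, the rewritten bound can be compared with the original definition on a small discrete model with three single-variable groups and $k = 2$. The joint table, the factor sizes and the helper `random_dist` are hypothetical choices made only for this sketch.

```python
import numpy as np

rng = np.random.default_rng(1)


def random_dist(n):
    """Random discrete distribution with strictly positive entries."""
    p = rng.random(n) + 0.1
    return p / p.sum()


# Hypothetical joint table P(X = x_obs, Z1, Z2, Z3), with P(X) assumed to be 0.5.
joint = rng.random((3, 4, 2)) + 0.1
joint *= 0.5 / joint.sum()
log_joint = np.log(joint)

q1, q2, q3 = random_dist(3), random_dist(4), random_dist(2)

# Lower bound computed directly from its definition with q = q1 * q2 * q3.
q_full = np.einsum("a,b,c->abc", q1, q2, q3)
direct = np.sum(q_full * (log_joint - np.log(q_full)))

# Same bound written in terms of q_k = q2 only.
expected = np.einsum("abc,a,c->b", log_joint, q1, q3)   # E_{q1 q3}[log P(X, Z)]
C = np.sum(q1 * np.log(q1)) + np.sum(q3 * np.log(q3))   # constant w.r.t. q2
rewritten = np.sum(q2 * (expected - C)) - np.sum(q2 * np.log(q2))

print(direct, rewritten)
```

Both numbers should match up to floating-point error. Here $C$ is the sum of the negative entropies of the fixed factors, which is why it does not change when $q_2$ changes.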

Then, the lower bound is maximized over all possible distributions $q_k$, leaving $\{q_i : i \neq k\}$ fixed. That is:

$$
\begin{align*}
\hat{q}_k & = \underset{q_k}{\mathrm{argmax}} \ \mathcal{L}(q) \\
          & = \underset{q_k}{\mathrm{argmax}} \int q_k \left( \mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] - C \right) d\pmb{Z}_k - \int q_k \log q_k d\pmb{Z}_k
\end{align*}
$$


Since $\int q_k C \ d\pmb{Z}_k = C$ is an additive constant, it does not change the maximizer, so we are free to choose its value. We can therefore rewrite the first term as the logarithm of a distribution over $\pmb{Z}_k$ by using the exponential (and logarithm):

$$
\begin{align*}
\mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] - C & =
\log \exp \left\{ \mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] - C \right\} \\
             & = \log \frac{\exp \left(\mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] \right)}{\exp(C)} \\
             & = \log \hat{P}(\pmb{X}, \pmb{Z})
\end{align*}
$$

Here $C$ is chosen so that $\exp(-C)$ is the normalization factor, i.e. $\exp(C) = \int \exp \left(\mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] \right) d\pmb{Z}_k$. Then, we write:

$$
\begin{align*}
\hat{q}_k & = \underset{q_k}{\mathrm{argmax}} \ \int q_k \left( \mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] - C \right) d\pmb{Z}_k - \int q_k \log q_k d\pmb{Z}_k \\
          & = \underset{q_k}{\mathrm{argmax}} \ \int q_k \log \hat{P}(\pmb{X}, \pmb{Z}) d\pmb{Z}_k - \int q_k \log q_k d\pmb{Z}_k \\
          & = \underset{q_k}{\mathrm{argmax}} \ \int q_k \log \frac{\hat{P}(\pmb{X}, \pmb{Z})} {q_k} d\pmb{Z}_k \\
          & = \underset{q_k}{\mathrm{argmax}} \ - \mathcal{KL} \left[ q_k || \hat{P}(\pmb{X}, \pmb{Z}) \right] \\
          & = \underset{q_k}{\mathrm{argmin}} \ \mathcal{KL} \left[ q_k || \hat{P}(\pmb{X}, \pmb{Z}) \right]
\end{align*}
$$

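We have just seen that maximizing the lower bound with respect to $q_k$ is equivalent to minimizing the given KL divergence, which is zero when $q_k = \hat{P}(\pmb{X}, \pmb{Z})$. In other words, the optimal factor satisfies $\log \hat{q}_k(\pmb{Z}_k) = \mathbb{E}_{q_{i \neq k}} \left[ \log P(\pmb{X}, \pmb{Z}) \right] + \mathrm{const}$, where the constant is fixed by normalization.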
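As a worked illustration, consider the classic bivariate Gaussian case. The setup is an assumption made for this example only: two scalar latent variables $z_1, z_2$ whose joint, viewed as a function of $\pmb{Z}$, is Gaussian with mean $\pmb{\mu}$ and symmetric precision matrix $\pmb{\Lambda}$ with entries $\Lambda_{ij}$. Keeping only the terms that depend on $z_1$:

$$
\begin{align*}
\log \hat{q}_1(z_1) & = \mathbb{E}_{q_2} \left[ \log P(\pmb{X}, \pmb{Z}) \right] + \mathrm{const} \\
                    & = \mathbb{E}_{q_2} \left[ -\frac{1}{2} \Lambda_{11} (z_1 - \mu_1)^2 - \Lambda_{12} (z_1 - \mu_1)(z_2 - \mu_2) \right] + \mathrm{const} \\
                    & = -\frac{1}{2} \Lambda_{11} z_1^2 + z_1 \left( \Lambda_{11} \mu_1 - \Lambda_{12} \left( \mathbb{E}_{q_2}[z_2] - \mu_2 \right) \right) + \mathrm{const}
\end{align*}
$$

The right-hand side is quadratic in $z_1$, so the optimal factor is itself Gaussian:

$$
\hat{q}_1(z_1) = \mathcal{N}\left(z_1 | m_1, \Lambda_{11}^{-1}\right), \qquad m_1 = \mu_1 - \Lambda_{11}^{-1} \Lambda_{12} \left( \mathbb{E}_{q_2}[z_2] - \mu_2 \right)
$$

Note that the Gaussian form was not imposed: it follows from the factorization alone. The factor $\hat{q}_1$ depends on $q_2$ only through $\mathbb{E}_{q_2}[z_2]$, and symmetrically for $\hat{q}_2$, so neither factor can be computed in isolation.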

In order to obtain $q$, we can iteratively optimize each of the components of $q$ independently:

1. First, we initialize each of the factors $q_i$.
1. Cycle through the factors, replacing each $q_k$ with its new estimate while keeping the other factors $q_{i \neq k}$ fixed.
1. Repeat the previous step until convergence (convergence is guaranteed because the bound is convex with respect to each factor $q_i$).
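The iterative procedure above can be sketched for two discrete latent variables. This is a minimal illustration under stated assumptions: the joint table is made up, the function names `normalize_log`, `elbo` and `cavi` are hypothetical, and stopping when the bound improves by less than `tol` is one possible convergence criterion among others.

```python
import numpy as np


def normalize_log(log_p):
    """Turn unnormalized log-probabilities into a distribution."""
    p = np.exp(log_p - log_p.max())   # subtract the max for numerical stability
    return p / p.sum()


def elbo(log_joint, q1, q2):
    """Lower bound L(q) for the factorized q(Z1, Z2) = q1(Z1) * q2(Z2)."""
    q = np.outer(q1, q2)
    return np.sum(q * (log_joint - np.log(q)))


def cavi(log_joint, n_iters=100, tol=1e-10, seed=0):
    """Coordinate-wise mean field updates for two discrete latent variables."""
    rng = np.random.default_rng(seed)
    k1, k2 = log_joint.shape
    # Step 1: initialize the factors (randomly, with strictly positive entries).
    q1 = normalize_log(np.log(rng.random(k1) + 0.1))
    q2 = normalize_log(np.log(rng.random(k2) + 0.1))
    history = [elbo(log_joint, q1, q2)]
    for _ in range(n_iters):
        # Step 2: q_k is proportional to exp(E_{q_{i != k}}[log P(X, Z)]).
        q1 = normalize_log(log_joint @ q2)
        q2 = normalize_log(q1 @ log_joint)
        history.append(elbo(log_joint, q1, q2))
        # Step 3: stop once the bound no longer improves noticeably.
        if history[-1] - history[-2] < tol:
            break
    return q1, q2, history


# Hypothetical joint table P(X = x_obs, Z1, Z2), with P(X) assumed to be 0.5.
rng = np.random.default_rng(2)
joint = rng.random((3, 4)) + 0.05
joint *= 0.5 / joint.sum()

q1, q2, history = cavi(np.log(joint))
print("lower bound per sweep:", np.round(history, 6))
print("log P(X):", np.log(joint.sum()))
```

Each sweep should leave the bound unchanged or higher, and the bound never exceeds $\log P(\pmb{X})$. The gap that remains at convergence is the KL divergence between the factorized $q$ and the true posterior.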

## References

- Explanatory [video](https://www.youtube.com/watch?v=zQEhkNpBzS4) on deriving the mean field approximation
- Section 10.1 of Bishop's book.
- Chapter 21 from "Machine Learning: A Probabilistic Perspective" (Murphy, 2012)