<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_8_Variational_Inference_PtI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Running Up That Hill, or: Intro to Variational (Bayesian) Inference
$\newcommand{\data}{\text{Data}}$
$\newcommand{\E}{\mathbb{E}}$
Recall that in Bayesian inference, we seek to model the uncertainty in our estimates through a _posterior_ distribution. The posterior is derived from [Bayes' Theorem](https://en.wikipedia.org/wiki/Bayes%27_theorem) as,
$$\Pr(\theta | \data) = \frac{\Pr(\data | \theta) \Pr(\theta)}{\Pr(\data)},$$
where $\Pr(\theta | \data)$ is the [_posterior_ probability](https://en.wikipedia.org/wiki/Posterior_probability) for $\theta$ and reflects our uncertainty in the values that $\theta$ may take on, $\Pr(\data | \theta)$ is our likelihood, $\Pr(\theta)$ is a [_prior_ probability](https://en.wikipedia.org/wiki/Prior_probability) (or _prior_) over $\theta$ and $\Pr(\data)$ is a [_marginal_ probability/likelihood](https://en.wikipedia.org/wiki/Marginal_likelihood) of the data.

Last week, we explored this concept by "brute forcing" the posterior distribution for a simple exercise (e.g., calculating the posterior probability that an individual is sick, given a positive test), as well as through a result for Exponential Families that leveraged conjugacy (e.g., the posterior probability for a coin to land on "heads").
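
As a quick reminder of the brute-force case, here is a minimal sketch of Bayes' theorem for the diagnostic-test exercise. The prevalence, sensitivity, and false-positive rate below are hypothetical numbers chosen only for illustration; they are not taken from last week's lab.

```python
import jax.numpy as jnp

# Hypothetical numbers, chosen only for illustration.
prior = jnp.array([0.01, 0.99])      # Pr(sick), Pr(healthy)
lik_pos = jnp.array([0.95, 0.05])    # Pr(test + | sick), Pr(test + | healthy)

joint = lik_pos * prior              # numerator of Bayes' theorem
evidence = joint.sum()               # Pr(test +), the marginal likelihood
posterior = joint / evidence         # Pr(sick | test +), Pr(healthy | test +)

print(posterior)
```

With only two possible values of $\theta$, the marginal likelihood is a two-term sum, which is why brute force works here and stops working once $\theta$ is continuous or high-dimensional.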

*What happens if our model does not have a simple or conjugate form?*

## Exact Approximate Inference 🤔

Rather than performing inference under an intractable exact posterior $\Pr(\theta | \data)$, we seek to perform inference using a surrogate distribution $Q(\theta | \data)$ that is simpler. But how do we identify a good surrogate, or even quantify how good a proposed surrogate distribution $Q$ is?

Recall, [KL-Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) captures a notion of "[statistical distance](https://en.wikipedia.org/wiki/Statistical_distance)" between parameterized distribution functions, whose definition for discrete variables is,
$$
\begin{align*}
D_{KL}(q || p) &= \sum_{x \in \mathcal{X}} q(x) \log \frac{q(x)}{p(x)} = - \sum_{x \in \mathcal{X}} q(x) \log \frac{p(x)}{q(x)} \\
  &= -\mathbb{E}_{x \sim q}\left[\log \frac{p(x)}{q(x)} \right].
\end{align*}$$

For continuous $x \in \mathbb{R}$, we have,
$$\begin{align*}
D_{KL}(q || p) &= \int_{-\infty}^\infty q(x) \log \frac{q(x)}{p(x)}dx = -\int_{-\infty}^\infty q(x) \log \frac{p(x)}{q(x)}dx \\
  &= -\mathbb{E}_{x \sim q}\left[\log \frac{p(x)}{q(x)} \right].
\end{align*}$$
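
To make both definitions concrete, the sketch below evaluates the discrete sum directly and estimates the continuous expectation by sampling $x \sim q$. The helper names (`kl_discrete`, `kl_normal_mc`, `kl_normal_exact`) and the example distributions are assumptions made for illustration. The closed form for two univariate normals is a standard result used here only as a check.

```python
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.stats import norm

def kl_discrete(q, p):
    """D_KL(q || p) for two probability vectors on the same support."""
    return jnp.sum(q * (jnp.log(q) - jnp.log(p)))

def kl_normal_mc(key, mu_q, sd_q, mu_p, sd_p, n_samples=100_000):
    """Monte Carlo estimate of D_KL(q || p) for univariate normals."""
    x = mu_q + sd_q * rdm.normal(key, (n_samples,))  # draws x ~ q
    return jnp.mean(norm.logpdf(x, mu_q, sd_q) - norm.logpdf(x, mu_p, sd_p))

def kl_normal_exact(mu_q, sd_q, mu_p, sd_p):
    """Closed-form D_KL(q || p) for univariate normals."""
    return jnp.log(sd_p / sd_q) + (sd_q**2 + (mu_q - mu_p) ** 2) / (2 * sd_p**2) - 0.5

# Hypothetical discrete distributions on three outcomes.
q = jnp.array([0.5, 0.3, 0.2])
p = jnp.array([0.4, 0.4, 0.2])
kl_qp = kl_discrete(q, p)  # positive
kl_pq = kl_discrete(p, q)  # generally differs from kl_qp
kl_qq = kl_discrete(q, q)  # zero

# q = N(0, 1), p = N(1, 2^2)
kl_mc = kl_normal_mc(rdm.PRNGKey(0), 0.0, 1.0, 1.0, 2.0)
kl_ex = kl_normal_exact(0.0, 1.0, 1.0, 2.0)
```

Note that `kl_qp` and `kl_pq` differ: KL-divergence is not symmetric, so it is a "distance" only in a loose sense.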

We can leverage this concept to measure how good a proposal surrogate $Q$ is compared to the true posterior by,
$$D_{KL}(Q(\theta | \data) || \Pr(\theta | \data))$$
However, we often don't know the functional form of $\Pr(\theta | \data)$, let alone how to compute it in intractable settings! Where does that leave us?

$$\newcommand{\ELBO}{\text{ELBO}}\begin{align*}
D_{KL}(Q(\theta | \data) || \Pr(\theta | \data)) &= \E_Q\left[ \log \frac{Q(\theta | \data)}{\Pr(\theta | \data) }\right] \\
  &= \E_Q\left[ \log \frac{Q(\theta | \data)\Pr(\data)}{\Pr(\data | \theta) \Pr(\theta)}\right] \\
  &= \E_Q\left[ \log \frac{Q(\theta | \data)}{\Pr(\data | \theta) \Pr(\theta)}\right] + \E_Q[\log \Pr(\data) ]\\
  &= \E_Q\left[ \log \frac{Q(\theta | \data)}{\Pr(\data | \theta) \Pr(\theta)}\right] + \log \Pr(\data) \\
  &= \underbrace{\E_Q[ \log Q(\theta | \data)] - \E_Q[\log \Pr(\data | \theta)] - \E_Q[\log \Pr(\theta)]}_{-\ELBO} + \log \Pr(\data) \geq 0 \Rightarrow\\
-\ELBO \geq - \log \Pr(\data) \iff \ELBO \leq \log \Pr(\data).
\end{align*}$$

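The derivation above shows that the [*evidence lower bound*](https://en.wikipedia.org/wiki/Evidence_lower_bound), or $\ELBO$, never exceeds $\log \Pr(\data)$, and that the gap between the two is exactly the KL-divergence between $Q$ and $\Pr(\theta | \data)$. Because $\log \Pr(\data)$ does not depend on $Q$, maximizing the $\ELBO$ over $Q$ is equivalent to minimizing that KL-divergence.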

Given this, a helpful representation of the $\ELBO$ is to re-write it as,
$$\begin{align*}
\ELBO &:= -\E_Q[ \log Q(\theta | \data)] + \E_Q[\log \Pr(\data | \theta)] + \E_Q[\log \Pr(\theta)] \\
  &= \E_Q[\log \Pr(\data | \theta)] - \E_Q\left[ \log \frac{Q(\theta | \data)}{\Pr(\theta)}\right]\\
  &= \E_Q[\log \Pr(\data | \theta)] - D_{KL}(Q(\theta | \data) || \Pr(\theta)).
\end{align*}$$
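
The bound $\ELBO \leq \log \Pr(\data)$ can be checked numerically in a model where everything is available in closed form. The sketch below assumes a toy model that is not part of the derivation above: $y_i | \theta \sim N(\theta, 1)$ with prior $\theta \sim N(0, 1)$ and surrogate $Q(\theta) = N(m, s^2)$. The data values and function names are hypothetical. The $\ELBO$ is computed as the expected log-likelihood minus the KL-divergence from $Q$ to the prior.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def elbo(y, m, s2):
    """ELBO for y_i | theta ~ N(theta, 1), theta ~ N(0, 1), Q = N(m, s2)."""
    n = y.shape[0]
    # E_Q[log Pr(Data | theta)], using E_Q[(y_i - theta)^2] = (y_i - m)^2 + s2
    exp_loglik = -0.5 * n * jnp.log(2 * jnp.pi) - 0.5 * jnp.sum((y - m) ** 2 + s2)
    # D_KL(Q || prior) between N(m, s2) and N(0, 1)
    kl = 0.5 * (s2 + m**2 - 1.0 - jnp.log(s2))
    return exp_loglik - kl

def log_evidence(y):
    """Exact log Pr(Data): marginally, y ~ N(0, I + 1 1^T) in this toy model."""
    n = y.shape[0]
    cov = jnp.eye(n) + jnp.ones((n, n))
    return multivariate_normal.logpdf(y, jnp.zeros(n), cov)

y = jnp.array([0.3, -0.5, 1.2, 0.8, 0.1])  # hypothetical data
n = y.shape[0]

# Exact (conjugate) posterior for this toy model: N(sum(y) / (n + 1), 1 / (n + 1))
m_post, s2_post = y.sum() / (n + 1), 1.0 / (n + 1)

log_ev = log_evidence(y)
elbo_exact = elbo(y, m_post, s2_post)  # Q equals the posterior: bound is tight
elbo_off = elbo(y, 0.0, 1.0)           # Q equals the prior: strictly below log_ev
```

When $Q$ is the exact posterior the KL-divergence to the posterior vanishes and the $\ELBO$ matches $\log \Pr(\data)$; any other choice of $(m, s^2)$ gives a smaller value.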

Fantastic! We've specified how to evaluate our objective (i.e. $\ELBO$), but we haven't yet specified how to maximize it. Traditional calculus fails us, because $Q$ is a *function*, and not a fixed variable/parameter.

*How do we proceed?*

## Variational Inference
$\newcommand{\indep}{\perp \!\!\!\! \perp}$
NB: Calculus of Variations,

Before we leverage the algorithmic toolkit of calculus of variations, we need to specify the *structural* form of $Q$. Namely, how does $Q$ factorize with respect to our model variables? Typically a *mean-field approximation* is used, which implies that
$$Q(\theta) = \prod_{j=1}^p Q_j(\theta_j),$$ or intuitively, that $\theta_j \indep \theta_{j'}$ for all $j \neq j'$ under $Q$.

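To identify the optimal $Q_j^*$ we leverage concepts from calculus of variations. Writing $\E_Q[\cdot | \theta_j]$ for the expectation over every $\theta_{j'}$ with $j' \neq j$ under $Q$ while holding $\theta_j$ fixed, the optimal factor satisfies, up to an additive constant that does not depend on $\theta_j$,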
$$\begin{align*}
\log Q_j^*(\theta_j) &= \E_Q\left[\log \Pr(\data | \theta) | \theta_j\right] + \E_Q\left[\log \Pr(\theta) | \theta_j \right].
\end{align*}$$

But to identify $Q_j^*$ we need to compute expectations under the _other_ approximate posteriors $Q_{j'}^*$, which suggests a cyclic algorithm in which we update one factor at a time while holding the rest fixed. This is coordinate ascent variational inference, or CAVI for short.

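Together, the CAVI algorithm looks like this: initialize each factor $Q_j$; then cycle through $j = 1, \dots, p$, replacing $Q_j$ with the $Q_j^*$ computed from the current values of the other factors; and repeat these sweeps until the $\ELBO$ stops improving.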

## Example: Bayesian Linear Regression
$\newcommand{\bX}{\mathbf{X}}\newcommand{\by}{\mathbf{y}}\newcommand{\bI}{\mathbf{I}}$
Let's say we observe $y_i, x_i$ pairs where
$$\begin{align*}
y_i | x_i, \beta &\sim N(x_i^T \beta, \sigma^2) \\
\beta_j &\sim N(0, \sigma^2_b).
\end{align*}$$

If we observe $n$ pairs of $\{y_i, x_i\}$, we can re-write the above as
$$\begin{align*}
\by | \bX, \beta &\sim N(\bX\beta, \bI_n \sigma^2) \\
\beta &\sim N(0, \bI_p \sigma^2_b).
\end{align*}$$

Here, we'd like to identify an approximate posterior $Q(\beta)$, rather than its exact posterior $\Pr(\beta | \by)$. Let's take a mean-field approach and assume that $Q$ factorizes as $$Q(\beta) = \prod_{j=1}^p Q_j(\beta_j).$$

Next, we need to identify the optimal $Q_j^*$ for each $j$. Namely,
$$\begin{align*}
\log Q_j^*(\beta_j) &= \E_Q\left[\log \Pr(\by | \beta) | \beta_j\right] + \E_Q\left[\log \Pr(\beta) | \beta_j \right] \\
  &= \E_Q\left[\log N(\by | \bX \beta, \bI_n \sigma^2) | \beta_j\right]
  + \E_Q\left[\log N(\beta | 0, \bI_p \sigma^2_b) | \beta_j \right] \\
  &= \E_Q\left[\log N(\by | \bX \beta, \bI_n \sigma^2) | \beta_j\right]
  + \E_Q\left[-\frac{1}{2 \sigma^2_b}\sum_{j'=1}^p \beta_{j'}^2 \ | \ \beta_j \right] \\
  &= \E_Q\left[\log N(\by | \bX \beta, \bI_n \sigma^2) | \beta_j\right]
  + \E_Q\left[-\frac{1}{2 \sigma^2_b} \beta_j^2 \ | \ \beta_j \right] \\
  &= \E_Q\left[\log N(\by | \bX \beta, \bI_n \sigma^2) | \beta_j\right]
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= \E_Q\left[-\frac{1}{2 \sigma^2}(\by - \bX \beta)^T(\by - \bX \beta) \ | \ \beta_j\right]
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
\end{align*}$$
$\newcommand{\resid}{\mathbf{r}}$
To proceed, let's first define a residual $\resid_j := \by - \sum_{j' \neq j} \bX_{j'}\beta_{j'}$ that reflects residualizing $\by$ by everything _except_ the $j$th term. Continuing we have,
$$\begin{align*}
\log Q_j^*(\beta_j) &=
  \E_Q\left[-\frac{1}{2 \sigma^2}(\by - \bX \beta)^T(\by - \bX \beta) \ | \ \beta_j\right]
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= \E_Q\left[-\frac{1}{2 \sigma^2}(\resid_j - \bX_j \beta_j)^T(\resid_j - \bX_j \beta_j) \ | \ \beta_j\right]
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= -\frac{1}{2 \sigma^2}\E_Q\left[\resid_j^T\resid_j - 2\resid_j^T\bX_j \beta_j + \beta_j^2 \bX_j^T \bX_j \ | \ \beta_j\right]
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= -\frac{1}{2 \sigma^2}\E_Q\left[ - 2\resid_j^T\bX_j \beta_j + \beta_j^2 \bX_j^T \bX_j \ | \ \beta_j\right]
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= \frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j \beta_j -\frac{1}{2\sigma^2}\beta_j^2 \bX_j^T \bX_j
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= \frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j\beta_j -\frac{1}{2}\beta_j^2 \underbrace{\left(\frac{\bX_j^T \bX_j}{\sigma^2}
  + \frac{1}{\sigma^2_b}\right)}_{1 / \widetilde{\sigma}^2_j} + O(1) \\
  &= \frac{1}{\widetilde{\sigma}^2_j}\underbrace{\frac{\widetilde{\sigma}^2_j}{\sigma^2}\E_Q[\resid_j^T]\bX_j}_{\tilde{\mu}_j}\beta_j -\frac{\beta_j^2}{2\widetilde{\sigma}^2_j} + O(1) \\
  &= \frac{\tilde{\mu}_j\beta_j}{\widetilde{\sigma}^2_j} -\frac{\beta_j^2}{2\widetilde{\sigma}^2_j}
  - \frac{\tilde{\mu}_j^2}{2\widetilde{\sigma}^2_j} + O(1) \\
  &= -\frac{(\beta_j - \tilde{\mu}_j)^2}{2 \widetilde{\sigma}^2_j} + O(1) \Rightarrow\\
Q_j^*(\beta_j) = N(\beta_j | \tilde{\mu}_j, \widetilde{\sigma}^2_j).
\end{align*}$$
This implies that the _optimal_ $Q_j^*(\beta_j)$ is a normal distribution with variational parameters $\tilde{\mu}_j, \widetilde{\sigma}^2_j$. Together we have,
$$Q(\beta) = \prod_j Q_j(\beta_j) = \prod_j N(\beta_j | \tilde{\mu}_j, \widetilde{\sigma}^2_j).$$
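
The updates for $\tilde{\mu}_j$ and $\widetilde{\sigma}^2_j$ translate almost line for line into code. The following is a minimal sketch rather than a reference implementation: the function name `cavi_sweep`, the simulated data, and the fixed number of sweeps are all assumptions for illustration, and $\sigma^2$ and $\sigma^2_b$ are treated as known. It uses $\E_Q[\resid_j] = \by - \sum_{j' \neq j} \bX_{j'} \tilde{\mu}_{j'}$, which follows from the definition of $\resid_j$ and the fact that $\beta_{j'}$ has mean $\tilde{\mu}_{j'}$ under $Q$.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

@jax.jit
def cavi_sweep(mu, X, y, sigma2, sigma2_b):
    """One pass of coordinate updates over all j; returns updated means and variances."""
    p = X.shape[1]
    # 1 / var_j = X_j^T X_j / sigma2 + 1 / sigma2_b (does not depend on the other factors)
    var = 1.0 / (jnp.sum(X**2, axis=0) / sigma2 + 1.0 / sigma2_b)

    def update(j, mu):
        # E_Q[r_j] = y - sum_{j' != j} X_j' mu_j' = y - X mu + X_j mu_j
        resid_j = y - X @ mu + X[:, j] * mu[j]
        mu_j = var[j] / sigma2 * (resid_j @ X[:, j])
        return mu.at[j].set(mu_j)

    # Sequential updates: each coordinate sees the latest values of the others.
    mu = jax.lax.fori_loop(0, p, update, mu)
    return mu, var

# Simulated data (hypothetical sizes and seeds).
key_x, key_b, key_e = rdm.split(rdm.PRNGKey(0), 3)
n, p, sigma2, sigma2_b = 200, 5, 1.0, 1.0
X = rdm.normal(key_x, (n, p))
beta = rdm.normal(key_b, (p,)) * jnp.sqrt(sigma2_b)
y = X @ beta + rdm.normal(key_e, (n,)) * jnp.sqrt(sigma2)

mu = jnp.zeros(p)
for _ in range(50):  # fixed number of sweeps; a real run would monitor the ELBO
    mu, var = cavi_sweep(mu, X, y, sigma2, sigma2_b)
```

The variances are computed once because they do not depend on the other factors; only the means need to be iterated.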

Comparing this to the _true_ posterior (which is tractable in this case) we have,
$$\begin{align*}
\Pr(\beta | \by) &= N(\beta | \widetilde{\beta}, \widetilde{\Sigma}) \\
\widetilde{\beta} &= \frac{1}{\sigma^2}\widetilde{\Sigma}\bX^T\by \\
\widetilde{\Sigma} &= \left(\frac{1}{\sigma^2}\bX^T\bX + \frac{1}{\sigma^2_b}\bI_p\right)^{-1}
\end{align*}$$
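
For reference, the exact posterior can be computed directly. The sketch below uses hypothetical simulated data with deliberately correlated columns; the function name `exact_posterior` is an assumption for illustration. It also evaluates the mean-field variances $\widetilde{\sigma}^2_j$, which are the reciprocals of the diagonal of $\widetilde{\Sigma}^{-1}$, so they can be set against the exact marginal variances on the diagonal of $\widetilde{\Sigma}$.

```python
import jax.numpy as jnp
import jax.random as rdm

def exact_posterior(X, y, sigma2, sigma2_b):
    """Exact Gaussian posterior for beta with known sigma2 and sigma2_b."""
    p = X.shape[1]
    precision = X.T @ X / sigma2 + jnp.eye(p) / sigma2_b  # inverse of Sigma
    Sigma = jnp.linalg.inv(precision)
    beta_tilde = Sigma @ X.T @ y / sigma2
    return beta_tilde, Sigma, precision

# Simulated data with correlated columns (hypothetical sizes and seeds).
key_s, key_z, key_b, key_e = rdm.split(rdm.PRNGKey(1), 4)
n, p, sigma2, sigma2_b = 100, 3, 1.0, 1.0
shared = rdm.normal(key_s, (n, 1))
X = 0.8 * shared + 0.6 * rdm.normal(key_z, (n, p))
beta = rdm.normal(key_b, (p,))
y = X @ beta + rdm.normal(key_e, (n,))

beta_tilde, Sigma, precision = exact_posterior(X, y, sigma2, sigma2_b)

# Mean-field variances from the derivation: 1 / (X_j^T X_j / sigma2 + 1 / sigma2_b)
mf_var = 1.0 / jnp.diag(precision)
exact_var = jnp.diag(Sigma)
print(mf_var, exact_var)
```

When the columns of $\bX$ are correlated, the mean-field variances come out no larger than the exact marginal variances, because the factorized $Q$ ignores posterior correlation between coefficients.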


In [None]:
# Let's code up the CAVI algorithm for Bayesian linear regression

import jax
import jax.numpy as jnp
import jax.random as rdm