<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_9_Variational_Inference_PtII.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Out of touch, or: Non-Conjugate Variational Inference
Last week, we discussed how to perform Bayesian inference when our exact posterior is computationally intractable. Specifically, Bayesian variational inference seeks to identify _approximating_ or _surrogate_ distributions $Q$ that are "close" to the true posterior distribution, where closeness is measured by the KL divergence,
$$\newcommand{\data}{\text{Data}}\newcommand{\E}{\mathbb{E}}\newcommand{\ELBO}{\text{ELBO}}
\begin{align*}D_{KL}(Q(\theta | \data) || \Pr(\theta | \data)) &= \E_Q\left[ \log \frac{Q(\theta | \data)}{\Pr(\theta | \data) }\right]\\
&= -\ELBO(\theta) + \log \Pr(\data)\\
\ELBO(\theta) &:= -\E_Q[ \log Q(\theta | \data)] + \E_Q[\log \Pr(\data | \theta)] + \E_Q[\log \Pr(\theta)] \\
  &= \E_Q[\log \Pr(\data | \theta)] - \E_Q\left[ \log \frac{Q(\theta | \data)}{\Pr(\theta)}\right].
\end{align*}$$

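Rather than evaluate $D_{KL}(Q(\theta | \data) || \Pr(\theta | \data))$ directly, which would require the intractable posterior, variational inference (often) focuses on maximizing (and evaluating) the $\ELBO$ term. Because the KL divergence is non-negative, the $\ELBO$ is a lower bound on the log marginal likelihood $\log \Pr(\data)$, and maximizing it over $Q$ is equivalent to minimizing the KL divergence.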
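To make the identity $\log \Pr(\data) = \ELBO + D_{KL}$ concrete, here is a minimal numerical sketch. It assumes a toy model that is not part of the lab: scalar observations $y_i \sim N(\theta, \sigma^2)$ with prior $\theta \sim N(0, \sigma^2_b)$ and a Normal surrogate $Q = N(m, s^2)$, chosen only because every term has a closed form. The function names and the data values are hypothetical.

```python
import math

def log_evidence(y, sigma2, sigma2_b):
    """Exact log Pr(Data) for y_i ~ N(theta, sigma2), theta ~ N(0, sigma2_b)."""
    n, s1, s2 = len(y), sum(y), sum(v * v for v in y)
    # Marginally y ~ N(0, sigma2 * I + sigma2_b * 1 1^T); the determinant and
    # inverse of this diagonal-plus-rank-one covariance are available in closed form.
    logdet = (n - 1) * math.log(sigma2) + math.log(sigma2 + n * sigma2_b)
    quad = (s2 - sigma2_b * s1 ** 2 / (sigma2 + n * sigma2_b)) / sigma2
    return -0.5 * (n * math.log(2 * math.pi) + logdet + quad)

def elbo(y, m, s2, sigma2, sigma2_b):
    """ELBO for Q = N(m, s2): entropy + E_Q[log likelihood] + E_Q[log prior]."""
    n = len(y)
    # E_Q[(y_i - theta)^2] = (y_i - m)^2 + s2
    exp_sq_err = sum((v - m) ** 2 + s2 for v in y)
    exp_loglik = -0.5 * n * math.log(2 * math.pi * sigma2) - exp_sq_err / (2 * sigma2)
    exp_logprior = -0.5 * math.log(2 * math.pi * sigma2_b) - (m ** 2 + s2) / (2 * sigma2_b)
    entropy = 0.5 * math.log(2 * math.pi * math.e * s2)  # equals -E_Q[log Q]
    return entropy + exp_loglik + exp_logprior

def kl_to_posterior(y, m, s2, sigma2, sigma2_b):
    """KL(Q || exact posterior); the posterior is Normal by conjugacy."""
    post_var = 1.0 / (len(y) / sigma2 + 1.0 / sigma2_b)
    post_mean = post_var * sum(y) / sigma2
    return 0.5 * (math.log(post_var / s2) + (s2 + (m - post_mean) ** 2) / post_var - 1.0)

# Hypothetical data and an arbitrary (non-optimal) surrogate.
y = [0.3, -1.2, 0.8, 2.1]
sigma2, sigma2_b = 1.5, 2.0
m, s2 = 0.1, 0.7

gap = log_evidence(y, sigma2, sigma2_b) - elbo(y, m, s2, sigma2, sigma2_b)
kl = kl_to_posterior(y, m, s2, sigma2, sigma2_b)
print("log evidence - ELBO:", gap)
print("KL(Q || posterior): ", kl)
```

The two printed numbers should agree up to floating-point error for any choice of `m` and `s2 > 0`, and the gap closes only when $Q$ matches the exact posterior.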

Before proceeding with optimization, we are required to specify structural independencies across latent variables $\theta_j$, to provide intermediate surrogates $Q_j$. A common factorization is the mean-field, given by,
$$\newcommand{\indep}{\perp \!\!\!\! \perp}Q(\theta) = \prod_{j=1}^p Q_j(\theta_j),$$ or, intuitively, that $\theta_j \indep \theta_{j'}$ for each $j \neq j'$ under $Q$. There are certainly other options for how to factor $Q$ over latent variables (e.g., *structured* mean-field), and trade-offs can sometimes be made between model/computational complexity and downstream accuracy, but often the simplest place to begin is the mean field.

Given a factorization for $Q$, coordinate ascent variational inference (CAVI) updates one factor at a time, holding the others fixed. The optimal factor $Q_j^*$ satisfies,
$$\begin{align*}
\log Q_j^*(\theta_j) &= \E_Q\left[\log \Pr(\data | \theta) | \theta_j\right] + \E_Q\left[\log \Pr(\theta) | \theta_j \right] + O(1).
\end{align*}$$
Here, we condition on $\theta_j$, and compute expectations with respect to $Q$ for _other_ variables $\theta_{j'}$; the $O(1)$ term collects everything that is constant in $\theta_j$.

Our derivation of the variational linear regression model seemed to identify $Q_j$ out of "thin air". Is there a systematic means to identify the functional form of $Q_j$?

## Conditional conjugacy and Exponential Families
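Let's suppose that our prior distribution for $\theta_j$ is in the exponential family, $\Pr(\theta_j) \propto \exp(\lambda_j \cdot T_j(\theta_j))$, where $\lambda_j$ are the _natural_ parameters, $T_j(\theta_j)$ are the sufficient statistics, and we assume some constant base measure. Suppose further that the prior factorizes across the $\theta_j$ and that the likelihood is _conditionally conjugate_: viewed as a function of $\theta_j$ alone, $\log \Pr(\data | \theta) = \eta_j(\theta_{\neg j}, \data) \cdot T_j(\theta_j) + O(1)$. Then,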

$$\begin{align*}
\log Q_j(\theta_j) &= \E_Q\left[\log \Pr(\data, \theta) | \theta_j\right] + O(1) \\
&= \E_Q\left[\log \Pr(\data | \theta) | \theta_j\right] + \log \Pr(\theta_j) + O(1) \\
&= \E_Q[\eta_j(\theta_{\neg j}, \data) \cdot T_j(\theta_j)] + \lambda_j \cdot T_j(\theta_j) + O(1)\\
&= \E_Q[\eta_j(\theta_{\neg j}, \data)] \cdot T_j(\theta_j) + \lambda_j \cdot T_j(\theta_j) + O(1)\\
&= \underbrace{(\E_Q[\eta_j(\theta_{\neg j}, \data)] + \lambda_j)}_{\widetilde{\lambda}_j} \cdot T_j(\theta_j) + O(1) \Rightarrow\\
Q_j(\theta_j) &\propto \exp\left(\widetilde{\lambda}_j \cdot T_j(\theta_j)\right),
\end{align*}$$
where $\eta_j(\theta_{\neg j}, \data)$ is some function of the $\data$ and _other_ parameters $\theta_{\neg j}$, and its expectation is taken under the remaining factors $Q_{j'}$, $j' \neq j$. In other words, $Q_j$ stays in the same exponential family as the prior, and the update simply adds the expected likelihood contribution to the prior's natural parameters.
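As a small illustration of "natural parameters add", the sketch below applies the update $\widetilde{\lambda}_j = \E_Q[\eta_j] + \lambda_j$ to a toy case that is an assumption of this example rather than part of the lab: a single Normal mean $\theta$ with known noise variance, where there are no other latent variables and the expectation is therefore trivial. It uses the standard Normal parameterization with sufficient statistics $[\theta, \theta^2]$ and natural parameters $[\mu/\sigma^2, -1/(2\sigma^2)]$. All function names are hypothetical helpers.

```python
def normal_to_natural(mu, var):
    """Map N(mu, var) to natural parameters for T(theta) = [theta, theta^2]."""
    return (mu / var, -0.5 / var)

def natural_to_normal(lam):
    """Invert normal_to_natural; requires lam[1] < 0 for a proper Normal."""
    lam1, lam2 = lam
    var = -0.5 / lam2
    return (lam1 * var, var)

def conjugate_update(lam_prior, expected_eta):
    """lambda_tilde = E_Q[eta] + lambda, elementwise over the sufficient statistics."""
    return tuple(l + e for l, e in zip(lam_prior, expected_eta))

# Toy model (assumed for illustration): y_i ~ N(theta, sigma2), theta ~ N(0, sigma2_b).
y = [0.3, -1.2, 0.8, 2.1]
sigma2, sigma2_b = 1.5, 2.0

lam_prior = normal_to_natural(0.0, sigma2_b)
# As a function of theta, the log-likelihood is
#   (sum(y) / sigma2) * theta - (n / (2 * sigma2)) * theta^2 + const,
# so eta has one entry per sufficient statistic.
eta = (sum(y) / sigma2, -0.5 * len(y) / sigma2)

lam_tilde = conjugate_update(lam_prior, eta)
post_mean, post_var = natural_to_normal(lam_tilde)
print("updated mean and variance:", post_mean, post_var)
```

In this one-parameter case the result coincides with the textbook conjugate posterior; with several latent variables, `eta` would be replaced by its expectation under the other factors.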

## Example: Normal Regression Revisited
$\newcommand{\bX}{\mathbf{X}}\newcommand{\by}{\mathbf{y}}\newcommand{\bI}{\mathbf{I}}$
Recall our Bayesian linear regression problem,
$$\begin{align*}
\by | \bX, \beta &\sim N(\bX\beta, \bI_n \sigma^2) \\
\beta &\sim N(0, \bI_p \sigma^2_b).
\end{align*}$$
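We sought to identify $Q(\beta) = \prod_j Q_j(\beta_j)$. Writing $\mathbf{r}_j := \by - \sum_{k \neq j} \bX_k \beta_k$ for the residual that leaves out predictor $j$, we can re-write our CAVI update as,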
$$\newcommand{\resid}{\mathbf{r}}\begin{align*}
\log Q_j^*(\beta_j) &= \frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j \beta_j -\frac{1}{2\sigma^2}\beta_j^2 \bX_j^T \bX_j
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= \left[\frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j,  -\frac{1}{2\sigma^2} \bX_j^T \bX_j\right]\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix}
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= \underbrace{\left[\frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j,  -\frac{1}{2\sigma^2} \bX_j^T \bX_j\right]}_{\E_Q[\eta_j(\beta_{\neg j}, \data)]}\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix}
  + \underbrace{\left[0, -\frac{1}{2 \sigma^2_b}\right]}_{\lambda_j}\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix} + O(1) \\
  &= \left[\frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j + 0,
  -\frac{1}{2\sigma^2} \bX_j^T \bX_j -\frac{1}{2 \sigma^2_b} \right]\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix}
   + O(1) \\
   &= \left[\frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j,
  -\frac{1}{2} \left(\frac{\bX_j^T \bX_j}{\sigma^2} + \frac{1}{ \sigma^2_b}\right) \right]\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix}
   + O(1).
\end{align*}$$
Recall the natural parameters for a Normal distribution $N(\mu, \sigma^2)$ are given by $\lambda = [\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}]$. We recognize the functional form above as
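$$\begin{align*}
\log Q_j^*(\beta_j) &= \left[\widetilde{\mu} \cdot \frac{1}{\widetilde{\sigma}^2}, -\frac{1}{2} \cdot \frac{1}{\widetilde{\sigma}^2}\right]\cdot \begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix} + O(1) \Rightarrow\\
Q_j^*(\beta_j) &:= N(\beta_j | \widetilde{\mu}, \widetilde{\sigma}^2),
\end{align*}$$
where matching the two natural parameters gives $\widetilde{\sigma}^2 = \left(\frac{\bX_j^T \bX_j}{\sigma^2} + \frac{1}{\sigma^2_b}\right)^{-1}$ and $\widetilde{\mu} = \frac{\widetilde{\sigma}^2}{\sigma^2}\E_Q[\resid_j^T]\bX_j$.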
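The update can be turned into a short coordinate ascent loop. The sketch below is one possible NumPy implementation, not the lab's reference code: it assumes $\sigma^2$ and $\sigma^2_b$ are known and fixed, reads $\widetilde{\sigma}^2_j = (\bX_j^T\bX_j/\sigma^2 + 1/\sigma^2_b)^{-1}$ and $\widetilde{\mu}_j = \widetilde{\sigma}^2_j \cdot \E_Q[\resid_j^T]\bX_j / \sigma^2$ off the natural parameters, and uses simulated data. The function name `cavi_linear_regression` and all settings are hypothetical.

```python
import numpy as np

def cavi_linear_regression(X, y, sigma2, sigma2_b, n_iter=100):
    """Mean-field CAVI for y ~ N(X beta, sigma2 I), beta ~ N(0, sigma2_b I).

    Returns the mean and variance of each Normal factor Q_j(beta_j).
    """
    n, p = X.shape
    mu = np.zeros(p)
    # Second natural parameter does not depend on the other factors,
    # so the variances can be computed once up front.
    var = 1.0 / (np.sum(X * X, axis=0) / sigma2 + 1.0 / sigma2_b)
    resid = y - X @ mu  # expected full residual under the current means
    for _ in range(n_iter):
        for j in range(p):
            # E_Q[r_j]: add predictor j's current contribution back in.
            r_j = resid + X[:, j] * mu[j]
            lam1 = X[:, j] @ r_j / sigma2  # first natural parameter
            mu[j] = var[j] * lam1          # mu_tilde = sigma_tilde^2 * lam1
            resid = r_j - X[:, j] * mu[j]  # refresh the full residual
    return mu, var

# Simulated data (assumed for illustration only).
rng = np.random.default_rng(0)
n, p = 200, 5
sigma2, sigma2_b = 1.0, 1.0
X = rng.normal(size=(n, p))
beta_true = rng.normal(scale=np.sqrt(sigma2_b), size=p)
y = X @ beta_true + rng.normal(scale=np.sqrt(sigma2), size=n)

mu, var = cavi_linear_regression(X, y, sigma2, sigma2_b)

# Exact posterior for comparison: precision = X^T X / sigma2 + I / sigma2_b.
precision = X.T @ X / sigma2 + np.eye(p) / sigma2_b
exact_mean = np.linalg.solve(precision, X.T @ y / sigma2)
print("max |CAVI mean - exact mean|:", np.max(np.abs(mu - exact_mean)))
```

For this Gaussian model the CAVI means should converge to the exact posterior mean, while each `var[j]` equals the reciprocal of the corresponding diagonal entry of the posterior precision, which in general is smaller than the exact marginal posterior variance.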