<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_9_Variational_Inference_PtII.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Out of touch, or: Non-Conjugate Variational Inference
Last week, we discussed how to perform Bayesian inference when our exact posterior is computationally intractable. Specifically, Bayesian variational inference seeks to identify _approximating_ or _surrogate_ distributions $Q$ that are "close" in a KL-sense to the true posterior distribution, where the KL divergence is given by,
$$\newcommand{\data}{\text{Data}}\newcommand{\E}{\mathbb{E}}\newcommand{\ELBO}{\text{ELBO}}
\begin{align*}D_{KL}(Q(\theta | \data) || \Pr(\theta | \data)) &= \E_Q\left[ \log \frac{Q(\theta | \data)}{\Pr(\theta | \data) }\right]\\
&= -\ELBO(\theta) + \log \Pr(\data)\\
\ELBO(\theta) &:= -\E_Q[ \log Q(\theta | \data)] + \E_Q[\log \Pr(\data | \theta)] + \E_Q[\log \Pr(\theta)] \\
  &= \E_Q[\log \Pr(\data | \theta)] - \E_Q\left[ \log \frac{Q(\theta | \data)}{\Pr(\theta)}\right].
\end{align*}$$

Rather than evaluate $D_{KL}(Q(\theta | \data) || \Pr(\theta | \data))$, variational inference (often) focuses on maximizing (and evaluating) the $\ELBO$ term, which provides a lower bound on the marginal likelihood $\Pr(\data)$.
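
The decomposition $\log \Pr(\data) = \ELBO + D_{KL}(Q || \Pr(\theta | \data))$ can be checked numerically whenever the exact posterior is available. Below is a minimal sketch in JAX, assuming a toy conjugate model that is not part of this lab ($y_i \sim N(\theta, \sigma^2)$ with prior $\theta \sim N(0, \tau^2)$) and an arbitrary normal surrogate $Q = N(m, s^2)$. The names `kl_normal`, `noise_var`, `prior_var`, `m`, and `s_sq` are illustrative.

```python
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.stats import multivariate_normal


def kl_normal(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) ) for univariate normals with variances v1, v2."""
    return ((m1 - m2) ** 2 + v1) / (2 * v2) + 0.5 * (jnp.log(v2) - jnp.log(v1) - 1.0)


# toy data (assumed model: y_i ~ N(theta, noise_var), theta ~ N(0, prior_var))
n, noise_var, prior_var = 5, 0.8, 2.0
y = 1.0 + jnp.sqrt(noise_var) * rdm.normal(rdm.PRNGKey(0), shape=(n,))

# exact posterior over theta, available here because the model is conjugate
post_var = 1.0 / (n / noise_var + 1.0 / prior_var)
post_mean = post_var * jnp.sum(y) / noise_var

# exact log marginal likelihood: y ~ N(0, noise_var * I + prior_var * 11^T)
cov = noise_var * jnp.eye(n) + prior_var * jnp.ones((n, n))
log_evidence = multivariate_normal.logpdf(y, jnp.zeros(n), cov)

# an arbitrary (non-optimal) surrogate Q = N(m, s_sq)
m, s_sq = 0.3, 0.5
exp_loglik = (-0.5 * n * jnp.log(2 * jnp.pi * noise_var)
              - (jnp.sum((y - m) ** 2) + n * s_sq) / (2 * noise_var))
elbo = exp_loglik - kl_normal(m, s_sq, 0.0, prior_var)

# gap between the ELBO and the log evidence is the KL to the true posterior
gap = kl_normal(m, s_sq, post_mean, post_var)
print(f"ELBO + KL = {elbo + gap}, log evidence = {log_evidence}")
```

Because the KL term is non-negative, the ELBO never exceeds the log evidence, and the two agree only when $Q$ equals the exact posterior.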

Before proceeding with optimization, we are required to specify structural independencies across latent variables $\theta_j$, to provide intermediate surrogates $Q_j$. A common factorization is the mean-field, given by,
$$\newcommand{\indep}{\perp \!\!\!\! \perp}Q(\theta) = \prod_{j=1}^p Q_j(\theta_j),$$ or, intuitively that each $\theta_j \indep \theta_{j'}$ for $j \neq j'$ under $Q$. There are certainly other options for how to factor $Q$ over latent variables (e.g., *structured* mean-field, etc), and trade-offs can sometimes be made over model/computational complexity and downstream accuracy, but often the simplest place to begin is the mean field.

Given a factorization for $Q$, coordinate ascent variational inference (CAVI) seeks to identify the optimal $Q_j^*$, which satisfies, up to an additive constant,
$$\begin{align*}
\log Q_j^*(\theta_j) &= \E_Q\left[\log \Pr(\data | \theta) | \theta_j\right] + \E_Q\left[\log \Pr(\theta) | \theta_j \right] + O(1).
\end{align*}$$
Here, we condition on $\theta_j$, and compute expectations with respect to $Q$ for _other_ variables $\theta_{j'}$.

Our derivation of the variational linear regression model seemed to identify $Q_j$ from "thin air". Is there a systematic means to identify the functional form of $Q_j$?

## Conditional conjugacy and Exponential Families
Let's suppose that our prior distribution for $\theta_j$ is in the exponential family, $\Pr(\theta_j) \propto \exp(\lambda_j \cdot T_j(\theta_j))$ where $\lambda_j$ are the _natural_ parameters, $T_j(\theta_j)$ are the sufficient statistics, and assuming some constant base measure.

$$\begin{align*}
\log Q_j(\theta_j) &= \E_Q\left[\log \Pr(\data, \theta) | \theta_j\right] + O(1) \\
&= \E_Q\left[\log \Pr(\data | \theta) | \theta_j\right] + \log \Pr(\theta_j) + O(1) \\
&= \E_Q[\eta_j(\theta_{\neg j}, \data) \cdot T_j(\theta_j)] + \lambda_j \cdot T_j(\theta_j) + O(1)\\
&= \E_Q[\eta_j(\theta_{\neg j}, \data)] \cdot T_j(\theta_j) + \lambda_j \cdot T_j(\theta_j) + O(1)\\
&= \underbrace{(\E_Q[\eta_j(\theta_{\neg j}, \data)] + \lambda_j)}_{\widetilde{\lambda}_j} \cdot T_j(\theta_j) + O(1) \Rightarrow\\
Q_j(\theta_j) &\propto \exp\left(\widetilde{\lambda}_j \cdot T_j(\theta_j)\right),
\end{align*}$$
where the third line uses conditional conjugacy: as a function of $\theta_j$, the log likelihood is assumed to be linear in the same sufficient statistics, $\log \Pr(\data | \theta) = \eta_j(\theta_{\neg j}, \data) \cdot T_j(\theta_j) + O(1)$. Here $\eta_j(\theta_{\neg j}, \data)$ is some function of the $\data$ and the _other_ parameters $\theta_{\neg j}$, and the expectation is taken over $\theta_{\neg j}$ under $Q$. The optimal $Q_j$ therefore stays in the same exponential family as the prior, with updated natural parameters $\widetilde{\lambda}_j$.

## Example: Normal Regression Revisited
$\newcommand{\bX}{\mathbf{X}}\newcommand{\by}{\mathbf{y}}\newcommand{\bI}{\mathbf{I}}$
Recall our Bayesian linear regression problem,
$$\begin{align*}
\by | \bX, \beta &\sim N(\bX\beta, \bI_n \sigma^2) \\
\beta &\sim N(0, \bI_p \sigma^2_b).
\end{align*}$$
We sought to identify $Q(\beta) = \prod_j Q_j(\beta_j)$. We can re-write our CAVI update as,
$$\newcommand{\resid}{\mathbf{r}}\begin{align*}
\log Q_j^*(\beta_j) &= \frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j \beta_j -\frac{1}{2\sigma^2}\beta_j^2 \bX_j^T \bX_j
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= \left[\frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j,  -\frac{1}{2\sigma^2} \bX_j^T \bX_j\right]\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix}
  -\frac{\beta_j^2}{2 \sigma^2_b} + O(1) \\
  &= \underbrace{\left[\frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j,  -\frac{1}{2\sigma^2} \bX_j^T \bX_j\right]}_{\E_Q[\eta_j(\beta_{\neg j}, \data)]}\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix}
  + \underbrace{\left[0, -\frac{1}{2 \sigma^2_b}\right]}_{\lambda_j}\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix} + O(1) \\
  &= \left[\frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j + 0,
  -\frac{1}{2\sigma^2} \bX_j^T \bX_j -\frac{1}{2 \sigma^2_b} \right]\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix}
   + O(1) \\
   &= \left[\frac{1}{\sigma^2}\E_Q[\resid_j^T]\bX_j,
  -\frac{1}{2} \left(\frac{\bX_j^T \bX_j}{\sigma^2} + \frac{1}{ \sigma^2_b}\right) \right]\begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix}
   + O(1),
\end{align*}$$
where $\resid_j := \by - \sum_{k \neq j} \bX_k \beta_k$ is the residual after removing the effects of all predictors other than $j$.
Recall the natural parameters for a Normal distribution $N(\mu, \sigma^2)$ are given by $\lambda = [\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}]$. We recognize the functional form above as
$$\begin{align*}
\log Q_j^*(\beta_j) &= \left[\widetilde{\mu} \cdot \frac{1}{\widetilde{\sigma}^2}, -\frac{1}{2} \cdot \frac{1}{\widetilde{\sigma}^2}\right]\cdot \begin{bmatrix} \beta_j \\ \beta_j^2 \end{bmatrix} + O(1) \Rightarrow\\
Q_j^*(\beta_j) &:= N(\beta_j | \widetilde{\mu}, \widetilde{\sigma}^2),
\end{align*}$$
where matching terms gives $\widetilde{\sigma}^2 = \left(\frac{\bX_j^T \bX_j}{\sigma^2} + \frac{1}{\sigma^2_b}\right)^{-1}$ and $\widetilde{\mu} = \frac{\widetilde{\sigma}^2}{\sigma^2}\E_Q[\resid_j^T]\bX_j$.
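
Moving between natural parameters and the usual mean/variance parameterization follows directly from $\lambda = [\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}]$. The sketch below shows the conversion in JAX and how conjugate updating amounts to adding natural parameters. The helper names and the numeric likelihood contribution are hypothetical values chosen for illustration only.

```python
import jax.numpy as jnp


def normal_mean_var_to_nat(mean, var):
    """Natural parameters [mu / sigma^2, -1 / (2 sigma^2)] of N(mean, var)."""
    return mean / var, -0.5 / var


def normal_nat_to_mean_var(nat_1, nat_2):
    """Invert the map above: sigma^2 = -1 / (2 nat_2), mu = nat_1 * sigma^2."""
    var = -0.5 / nat_2
    return nat_1 * var, var


# prior N(0, 0.1) in natural form
prior_nat_1, prior_nat_2 = normal_mean_var_to_nat(0.0, 0.1)

# hypothetical expected-likelihood contribution for one coordinate:
# E_Q[r_j]^T X_j / sigma^2 = 12 and -0.5 * X_j^T X_j / sigma^2 = -312.5
lik_nat_1, lik_nat_2 = 12.0, -312.5

# conjugate update: natural parameters simply add
post_mean, post_var = normal_nat_to_mean_var(prior_nat_1 + lik_nat_1,
                                             prior_nat_2 + lik_nat_2)
print(post_mean, post_var)
```

Working in natural parameters keeps the update a plain sum, and the mean and variance are recovered only when they are needed, for example to recompute residuals.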

In [None]:
# Let's code up the CAVI algorithm for bayesian linear regression
# But this time using the _natural_ parameter form

import jax
import jax.numpy as jnp
import jax.random as rdm

MAX_ITER = 10

N = 500
P = 250
sigma_sq = 0.8
sigma_sq_b = 0.1

seed = 0
key = rdm.PRNGKey(seed)
key, x_key, b_key, y_key = rdm.split(key, 4)

X = rdm.normal(x_key, shape=(N, P))
beta = jnp.sqrt(sigma_sq_b) * rdm.normal(b_key, shape=(P,))
y = X @ beta + jnp.sqrt(sigma_sq) * rdm.normal(y_key, shape=(N,))

prior_nat_1 = 0.
prior_nat_2 = -0.5 / sigma_sq_b
post_nat_1 = jnp.ones((P,)) * prior_nat_1
post_nat_2 = jnp.ones((P,)) * prior_nat_2

#NB: in jax to update an array position j, we need the `set` function which
# looks like post_means = post_means.at[j].set(new_value)
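for _iter in range(MAX_ITER):
  post_means = post_nat_1 / (-2 * post_nat_2)
  resid = y - X @ post_means
  for j in range(P):
    Xj = X[:,j]
    # one possible update, following the natural-parameter derivation above
    # add back the j-th effect to obtain the expected partial residual E_Q[r_j]
    resid_j = resid + Xj * post_means[j]
    # natural parameters = expected likelihood contribution + prior natural parameters
    nat_1_j = (resid_j @ Xj) / sigma_sq + prior_nat_1
    nat_2_j = -0.5 * (Xj @ Xj) / sigma_sq + prior_nat_2
    post_nat_1 = post_nat_1.at[j].set(nat_1_j)
    post_nat_2 = post_nat_2.at[j].set(nat_2_j)
    # refresh the posterior mean and the full residual for the next coordinate
    new_mean_j = nat_1_j / (-2 * nat_2_j)
    post_means = post_means.at[j].set(new_mean_j)
    resid = resid_j - Xj * new_mean_j

  # typically we would evaluate the ELBO here, but that is left as HW exercise...
  # sanity check: mean squared residual should decrease and then stabilize
  print(f"MSE[{_iter}] = {jnp.mean(resid ** 2)}")
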

## Non-conjugate Variational Inference
The above CAVI derivations assume that our surrogate models are the result of conditional conjugacy between the expected log likelihood and the prior. What happens if we are unable to derive updates under those assumptions? In other words,

> _How can we perform variational inference when conditional conjugacy doesn't apply?_

Just like in CAVI, we begin by assuming our surrogate posterior factorizes in some form. For now, let's assume a mean-field,
$$Q(\theta) = \prod_{j=1}^p Q_j(\theta_j).$$

However, unlike before, where we identified $Q_j^*$ under conditional conjugacy, we are going to _assume_ that $Q_j(\theta_j)$ is in the _same_ exponential family as its corresponding prior $\Pr(\theta_j)$.

## Poisson regression revisited
$$\begin{align*}
y_i | x_i &\sim \text{Poi}(\lambda_i)\\
\lambda_i &:= \exp(x_i^T \beta)\\
\beta &\sim N(0, \sigma^2_b),
\end{align*}$$ where $\text{Poi}(k | \lambda) := \frac{\lambda^k \exp(-\lambda)}{k!}$ is the [PMF](https://en.wikipedia.org/wiki/Probability_mass_function) of the [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution).
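
To experiment with this model it helps to have data simulated from it. The following is a minimal sketch in JAX. The function name `simulate_poisson_regression` and the standard-normal design matrix are assumptions made for illustration, mirroring the simulation used for linear regression earlier.

```python
import jax.numpy as jnp
import jax.random as rdm


def simulate_poisson_regression(key, n, p, prior_var):
    """Draw (X, beta, y) from the Poisson regression model with a N(0, prior_var) prior."""
    x_key, b_key, y_key = rdm.split(key, 3)
    X = rdm.normal(x_key, shape=(n, p))                         # assumed design
    beta = jnp.sqrt(prior_var) * rdm.normal(b_key, shape=(p,))  # beta_j ~ N(0, prior_var)
    rate = jnp.exp(X @ beta)                                    # lambda_i = exp(x_i^T beta)
    y = rdm.poisson(y_key, rate)                                # y_i ~ Poi(lambda_i)
    return X, beta, y


X_sim, beta_sim, y_sim = simulate_poisson_regression(rdm.PRNGKey(0), n=100, p=5, prior_var=0.1)
print(X_sim.shape, beta_sim.shape, y_sim[:10])
```

Keeping `prior_var` small keeps the rates $\exp(x_i^T\beta)$ moderate, which avoids extremely large simulated counts.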

We'd like to identify an approximating distribution $Q$ that is close in a KL-sense to $\Pr(\beta | \mathbf{y})$. Let's assume that $Q$ factorizes as $Q(\beta) = \prod_{j=1}^p Q_j(\beta_j)$, and furthermore that $Q_j := N(\mu_j, \sigma^2_j)$.

Recall, the ELBO is given by $$\text{ELBO} := E_Q[\log \ell(\beta | \mathbf{y})] - D_{KL}(Q(\beta) || \Pr(\beta)),$$
where $\log \ell(\beta | \mathbf{y}) = \sum_{i=1}^n \log \text{Poi}(y_i | \exp(x_i^T \beta)).$

In order to evaluate, let alone optimize, the ELBO we need to figure out what $E_Q[\log \ell(\beta | \mathbf{y})]$ is. Below we absorb the term $-\sum_{i=1}^n \log y_i!$ into $O(1)$, since it does not depend on $\beta$, and we use the fact that $E[\exp(z)] = \exp(m + v/2)$ when $z \sim N(m, v)$.
$$\begin{align*}
E_Q[\log \ell(\beta | \mathbf{y})] &= E_Q\left[\sum_{i=1}^n \log \text{Poi}(y_i | \exp(x_i^T \beta))\right] \\
&= \sum_{i=1}^n E_Q\left[\log \text{Poi}(y_i | \exp(x_i^T \beta))\right]\\
&= \sum_{i=1}^n E_Q\left[y_i\cdot (x_i^T\beta) - \exp(x_i^T\beta)\right] + O(1)\\
&= \sum_{i=1}^n y_i \cdot x_i^T E_Q[\beta] - \sum_{i=1}^n E_Q[\exp(x_i^T\beta)] + O(1)\\
&= \sum_{i=1}^n y_i \cdot x_i^T E_Q[\beta] - \sum_{i=1}^n \exp\left(x_i^T E_Q[\beta] + \frac{1}{2}x_i^T V_Q[\beta] x_i \right) + O(1)\\
&= \mathbf{y}^T X E_Q[\beta] - \sum_{i=1}^n \exp\left(x_i^T E_Q[\beta] + \frac{1}{2}x_i^T V_Q[\beta] x_i \right) + O(1) \\
&= \mathbf{y}^T X \mathbf{\mu} - \sum_{i=1}^n \exp\left(x_i^T \mathbf{\mu} + \frac{1}{2}x_i^T \text{diag}(\mathbf{\sigma}^2) x_i \right) + O(1).
\end{align*}$$
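
The step that replaces $E_Q[\exp(x_i^T\beta)]$ with $\exp(x_i^T \mu + \frac{1}{2}x_i^T \text{diag}(\sigma^2) x_i)$ relies on $x_i^T\beta$ being normal under $Q$, so that its exponential is log-normal. A quick Monte Carlo check in JAX, using arbitrary illustrative values for `x`, `mu`, and `sigma_sq`, might look like this:

```python
import jax.numpy as jnp
import jax.random as rdm

# arbitrary illustrative values for one covariate vector and the variational parameters
x = jnp.array([0.5, -1.0, 0.25])
mu = jnp.array([0.2, 0.1, -0.3])
sigma_sq = jnp.array([0.05, 0.1, 0.2])

# closed form: exp(x^T mu + 0.5 * sum_j x_j^2 sigma_j^2)
closed_form = jnp.exp(x @ mu + 0.5 * jnp.sum(x ** 2 * sigma_sq))

# Monte Carlo: draw beta ~ Q = prod_j N(mu_j, sigma_j^2) and average exp(x^T beta)
num_draws = 200_000
draws = mu + jnp.sqrt(sigma_sq) * rdm.normal(rdm.PRNGKey(0), shape=(num_draws, x.shape[0]))
monte_carlo = jnp.mean(jnp.exp(draws @ x))

print(f"closed form = {closed_form}, Monte Carlo = {monte_carlo}")
```

The two numbers should agree up to Monte Carlo error, which shrinks at rate $1/\sqrt{\text{num\_draws}}$.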

Recall that
$$\begin{align*}
D_{KL}(Q(\beta) || \Pr(\beta)) &= \sum_{j=1}^p D_{KL}(Q(\beta_j) || \Pr(\beta_j)) \\
&= \sum_{j=1}^p D_{KL}(N(\mu_j, \sigma^2_j) || N(0, \sigma^2_b)) \\
&= \sum_{j=1}^p \left[\frac{(\mu_j^2 + \sigma^2_j)}{2 \sigma^2_b} + \frac{1}{2}\left[\log \sigma^2_b - \log \sigma^2_j - 1\right]\right].
\end{align*}$$
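
The KL divergence between univariate normals has the standard closed form $D_{KL}(N(\mu_j, \sigma^2_j) || N(0, \sigma^2_b)) = \frac{\mu_j^2 + \sigma^2_j}{2\sigma^2_b} + \frac{1}{2}\left[\log \sigma^2_b - \log \sigma^2_j - 1\right]$, and it is easy to get a sign wrong in the log terms. A small sketch that checks the closed form against a Monte Carlo estimate of $E_Q[\log Q - \log \Pr]$ follows. The function name `kl_normal_to_prior` and the test values are illustrative assumptions.

```python
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.stats import norm


def kl_normal_to_prior(mu, sigma_sq, prior_var):
    """Sum over coordinates of KL( N(mu_j, sigma_sq_j) || N(0, prior_var) )."""
    per_coord = ((mu ** 2 + sigma_sq) / (2 * prior_var)
                 + 0.5 * (jnp.log(prior_var) - jnp.log(sigma_sq) - 1.0))
    return jnp.sum(per_coord)


mu = jnp.array([0.3, -0.2])
sigma_sq = jnp.array([0.05, 0.2])
prior_var = 0.1
closed_form = kl_normal_to_prior(mu, sigma_sq, prior_var)

# Monte Carlo estimate of E_Q[log Q(beta) - log Pr(beta)] with draws from Q
draws = mu + jnp.sqrt(sigma_sq) * rdm.normal(rdm.PRNGKey(1), shape=(200_000, 2))
log_q = norm.logpdf(draws, loc=mu, scale=jnp.sqrt(sigma_sq))
log_p = norm.logpdf(draws, loc=0.0, scale=jnp.sqrt(prior_var))
monte_carlo = jnp.mean(jnp.sum(log_q - log_p, axis=1))

print(f"closed form = {closed_form}, Monte Carlo = {monte_carlo}")
```

A useful sanity check is that the divergence is exactly zero when $Q_j$ equals the prior, i.e. $\mu_j = 0$ and $\sigma^2_j = \sigma^2_b$.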

Putting these pieces together, we have, up to an additive constant that does not depend on the variational parameters,
$$\begin{align*}
\text{ELBO} &:= \mathbf{y}^T X \mathbf{\mu}
  - \sum_{i=1}^n \exp\left(x_i^T \mathbf{\mu} + \frac{1}{2}x_i^T \text{diag}(\mathbf{\sigma}^2) x_i \right)
  - \sum_{j=1}^p \left[\frac{(\mu_j^2 + \sigma^2_j)}{2 \sigma^2_b} + \frac{1}{2}\left[\log \sigma^2_b - \log \sigma^2_j - 1\right]\right] \\
&= \mathbf{y}^T X \mathbf{\mu}
  - \sum_{i=1}^n \exp\left(x_i^T \mathbf{\mu} + \frac{1}{2}\sum_{j=1}^p x_{ij}^2 \sigma^2_j\right)
  - \sum_{j=1}^p \left[\frac{(\mu_j^2 + \sigma^2_j)}{2 \sigma^2_b} + \frac{1}{2}\left[\log \sigma^2_b - \log \sigma^2_j - 1\right]\right].
\end{align*}$$

Now, let's figure out our gradients over variational parameters $\newcommand{\bmu}{\mathbf{\mu}}\newcommand{\bsigma}{\mathbf{\sigma}}\bmu, \bsigma^2$.
$$\begin{align*}
\nabla_\bmu \text{ELBO} &= \nabla_\bmu \mathbf{y}^T X\bmu
  - \sum_{i=1}^n \nabla_\bmu \exp\left(x_i^T \mathbf{\mu} + \frac{1}{2}\sum_{j=1}^p x_{ij}^2 \sigma^2_j\right)
  - \sum_{j=1}^p \nabla_\bmu \left[\frac{(\mu_j^2 + \sigma^2_j)}{2 \sigma^2_b} + \frac{1}{2}\left[\log \sigma^2_b - \log \sigma^2_j - 1\right]\right] \\
&= X^T \mathbf{y} - \sum_{i=1}^n \exp\left(x_i^T \mathbf{\mu} + \frac{1}{2}\sum_{j=1}^p x_{ij}^2 \sigma^2_j\right) x_i
- \bmu / \sigma^2_b\\
\nabla_{\bsigma^2} \text{ELBO} &= \nabla_{\bsigma^2}  \mathbf{y}^T X\bmu
  - \sum_{i=1}^n \nabla_{\bsigma^2}  \exp\left(x_i^T \mathbf{\mu} + \frac{1}{2}\sum_{j=1}^p x_{ij}^2 \sigma^2_j\right)
  - \sum_{j=1}^p \nabla_{\bsigma^2}  \left[\frac{(\mu_j^2 + \sigma^2_j)}{2 \sigma^2_b} + \frac{1}{2}\left[\log \sigma^2_b - \log \sigma^2_j - 1\right]\right] \\
&= - \frac{1}{2}\sum_{i=1}^n \exp\left(x_i^T \mathbf{\mu} + \frac{1}{2}\sum_{j=1}^p x_{ij}^2 \sigma^2_j\right) (x_i \circ x_i)
  - \frac{1}{2 \sigma^2_b} + \frac{1}{2 \bsigma^2},
\end{align*}$$
where $\circ$ denotes the elementwise product and the last two terms are applied elementwise.
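
Hand-derived gradients are easy to get subtly wrong, so it is worth comparing them against automatic differentiation. The sketch below writes the closed-form ELBO (dropping the constant $-\sum_i \log y_i!$) and the two gradient expressions in JAX, then checks them against `jax.grad`. The function names `elbo_closed_form` and `manual_gradients`, and the arbitrary test inputs, are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def elbo_closed_form(mu, sigma_sq, prior_var, y, X):
    """ELBO for mean-field Poisson regression, up to the constant -sum(log y_i!)."""
    log_rate = X @ mu + 0.5 * (X ** 2) @ sigma_sq
    exp_loglik = y @ (X @ mu) - jnp.sum(jnp.exp(log_rate))
    kl = jnp.sum((mu ** 2 + sigma_sq) / (2 * prior_var)
                 + 0.5 * (jnp.log(prior_var) - jnp.log(sigma_sq) - 1.0))
    return exp_loglik - kl


def manual_gradients(mu, sigma_sq, prior_var, y, X):
    """Closed-form gradients with respect to mu and sigma_sq."""
    expected_rate = jnp.exp(X @ mu + 0.5 * (X ** 2) @ sigma_sq)
    grad_mu = X.T @ y - X.T @ expected_rate - mu / prior_var
    grad_sigma_sq = -0.5 * (X ** 2).T @ expected_rate - 0.5 / prior_var + 0.5 / sigma_sq
    return grad_mu, grad_sigma_sq


# arbitrary test inputs (not fitted values)
x_key, y_key = rdm.split(rdm.PRNGKey(0))
n, p, prior_var = 50, 4, 0.1
X = rdm.normal(x_key, shape=(n, p))
y = rdm.poisson(y_key, 2.0, shape=(n,)).astype(jnp.float32)
mu = jnp.array([0.1, -0.2, 0.05, 0.0])
sigma_sq = jnp.array([0.02, 0.05, 0.01, 0.03])

# differentiate with respect to the first two arguments (mu, sigma_sq)
auto_mu, auto_sigma_sq = jax.grad(elbo_closed_form, argnums=(0, 1))(mu, sigma_sq, prior_var, y, X)
man_mu, man_sigma_sq = manual_gradients(mu, sigma_sq, prior_var, y, X)

print(jnp.max(jnp.abs(auto_mu - man_mu)), jnp.max(jnp.abs(auto_sigma_sq - man_sigma_sq)))
```

When a closed-form ELBO is available but its gradients are tedious, `jax.grad` alone is a reasonable substitute for the manual expressions.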

Despite this model not exhibiting conjugacy, we were able to derive a closed form for the ELBO as well as its gradients with respect to the variational parameters, which is uncommon!

In [None]:
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats



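def elbo(mu, sigma_sq, prior_var, y, X):
  """
  Closed-form ELBO for mean-field Poisson regression, up to the constant
  -sum(log(y!)) that does not depend on the variational parameters
  """
  # E_Q[x_i^T beta] + 0.5 * V_Q[x_i^T beta] for every observation
  log_rate = X @ mu + 0.5 * (X ** 2) @ sigma_sq
  exp_loglik = y @ (X @ mu) - jnp.sum(jnp.exp(log_rate))
  # KL( N(mu_j, sigma_sq_j) || N(0, prior_var) ) summed over coordinates
  kl = jnp.sum((mu ** 2 + sigma_sq) / (2 * prior_var)
               + 0.5 * (jnp.log(prior_var) - jnp.log(sigma_sq) - 1.))
  return exp_loglik - kl

def compute_gradients(mu, sigma_sq, prior_var, y, X):
  """
  Closed-form gradients of the ELBO wrt the variational means and variances
  """
  expected_rate = jnp.exp(X @ mu + 0.5 * (X ** 2) @ sigma_sq)
  grad_mu = X.T @ (y - expected_rate) - mu / prior_var
  grad_sigma_sq = -0.5 * (X ** 2).T @ expected_rate - 0.5 / prior_var + 0.5 / sigma_sq
  return grad_mu, grad_sigma_sq


def poiss_reg(y, X, prior_var=1e-3, step_size = 1.0, max_iter=100, tol=1e-3):
  """
  Fit mean-field variational Poisson regression by gradient ascent on the ELBO.
  Returns the variational means and variances.
  NB: step_size likely needs tuning to the scale of the data; a large value can overshoot
  """
  n, p = X.shape

  # fake bookkeeping (named cur_elbo so we don't shadow the elbo function)
  cur_elbo = -100000
  delta = 10000

  mu = jnp.zeros(p)
  sigma_sq = jnp.ones(p) * prior_var

  for epoch in range(max_iter):

    # compute gradients using our function
    grad_mu, grad_sigma_sq = compute_gradients(mu, sigma_sq, prior_var, y, X)

    # use gradient ascent, since we are maximizing the ELBO
    mu = mu + step_size * grad_mu
    # NB: the floor keeping variances positive is an added safeguard, not part of the derivation
    sigma_sq = jnp.maximum(sigma_sq + step_size * grad_sigma_sq, 1e-8)

    # evaluate ELBO
    newelbo = elbo(mu, sigma_sq, prior_var, y, X)

    # take delta and check if we can stop
    delta = jnp.abs(newelbo - cur_elbo)
    print(f"ELBO[{epoch}] = {newelbo}")
    if delta < tol:
      break

    # replace old value
    cur_elbo = newelbo

  return mu, sigma_sq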