# Variational Inference

Variational inference is a family of methods for approximating a posterior distribution whose normalization factor $Z(y)$ is unknown:

$$ 
\begin{align*}
\rho_{\rm post}(\theta | y) = \frac{\rho(\theta, y)}{Z(y)} =  \frac{\rho_{\rm prior}(\theta) \rho(y | \theta)}{Z(y)} \propto e^{-\Phi_R(\theta; y)}
\end{align*}
$$

We consider the case where the likelihood $\rho(y | \theta)$ and the prior $\rho_{\rm prior}(\theta)$ are easy to evaluate, so that $\Phi_R(\theta; y)$ is available up to an additive constant.

## Basic variational inference algorithm

The basic idea of variational inference is to find a simpler distribution $q_{\lambda}(\theta)$, parameterized by $\lambda$, that approximates the posterior $\rho_{\rm post}(\theta | y)$.

The Kullback-Leibler ($KL$) divergence is widely used to measure the discrepancy between these distributions:
$$
\begin{align*}
KL\Bigl[q_{\lambda}(\theta) \Vert  \rho_{\rm post}(\theta | y)\Bigr] &= \int q_{\lambda}(\theta)  \log \frac{q_{\lambda}(\theta)}{\rho_{\rm post}(\theta | y)} d\theta \\
&= \mathbb{E}_{\theta \sim q_{\lambda}(\theta)}  \Bigl[ \log \frac{q_{\lambda}(\theta)}{\rho_{\rm post}(\theta | y)} \Bigr]
\end{align*}
$$

The goal is to obtain an optimal $\lambda$ that minimizes the $KL$ divergence. A natural approach is gradient descent:

$$
\begin{align*}
\nabla_{\lambda} KL\Bigl[q_{\lambda}(\theta) \Vert  \rho_{\rm post}(\theta | y)\Bigr] &= \nabla_{\lambda} \int q_{\lambda}(\theta)  \log \frac{q_{\lambda}(\theta)}{\rho_{\rm post}(\theta | y)} d\theta \\
&=  \int \nabla_{\lambda} q_{\lambda}(\theta)  \Bigl( \log q_{\lambda}(\theta) - \log \rho_{\rm prior}(\theta) - \log \rho (y|\theta)\Bigr) d\theta \\
&=  \mathbb{E}_{\theta \sim q_{\lambda}(\theta)} \Bigl[ \nabla_{\lambda} \log q_{\lambda}(\theta)  \Bigl( \log q_{\lambda}(\theta) - \log \rho_{\rm prior}(\theta) - \log \rho (y|\theta)\Bigr) \Bigr]
\end{align*}
$$

Here we use the fact, obtained by differentiating $\int q_{\lambda}(\theta) d\theta = 1$, that
$$
\begin{align*}
\int \nabla_{\lambda} q_{\lambda}(\theta) d\theta  = 0
\end{align*}
$$
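This removes both the term $\int q_{\lambda} \nabla_{\lambda}\log q_{\lambda}\, d\theta$ produced by differentiating $\log q_{\lambda}$ and the constant $\log Z(y)$ contained in $\log \rho_{\rm post}(\theta | y)$; the last line uses $\nabla_{\lambda} q_{\lambda} = q_{\lambda}\nabla_{\lambda}\log q_{\lambda}$. Hence the gradient does not depend on the unknown normalization factor $Z(y)$, and the expectation can be approximated by Monte Carlo sampling from $q_{\lambda}$.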
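As a minimal sketch of this score-function estimator, the code below assumes a hypothetical one-dimensional problem with $\Phi_R(\theta; y) = (\theta - 2)^2 / (2 \cdot 0.25)$, so the exact posterior is $\mathcal{N}(2, 0.25)$, and a Gaussian family $q_\lambda = \mathcal{N}(\mu, \sigma^2)$ with $\lambda = (\mu, \log\sigma)$. Only the unnormalized $\log \rho(\theta, y)$ is evaluated; the sample size, step size and iteration count are illustrative choices, not tuned values.

```python
import numpy as np

def log_joint(theta):
    # Hypothetical unnormalized log rho(theta, y) = -Phi_R(theta; y) with
    # Phi_R = (theta - 2)^2 / (2 * 0.25); the exact posterior is N(2, 0.25).
    return -0.5 * (theta - 2.0) ** 2 / 0.25

def log_q(theta, mu, log_sigma):
    sigma = np.exp(log_sigma)
    return -0.5 * np.log(2 * np.pi) - log_sigma - 0.5 * ((theta - mu) / sigma) ** 2

def score_function_grad(mu, log_sigma, rng, n_samples=2000):
    """Monte Carlo estimate of grad_lambda KL[q_lambda || rho_post], lambda = (mu, log_sigma)."""
    sigma = np.exp(log_sigma)
    theta = mu + sigma * rng.standard_normal(n_samples)    # theta ~ q_lambda
    # log q - log rho(theta, y): the unknown log Z(y) is never needed
    f = log_q(theta, mu, log_sigma) - log_joint(theta)
    z = (theta - mu) / sigma
    score_mu = z / sigma              # d/dmu log q_lambda(theta)
    score_log_sigma = z ** 2 - 1.0    # d/dlog_sigma log q_lambda(theta)
    return np.array([np.mean(score_mu * f), np.mean(score_log_sigma * f)])

rng = np.random.default_rng(0)
params = np.array([0.0, 0.0])   # initial (mu, log_sigma)
lr = 0.05
for _ in range(500):
    params -= lr * score_function_grad(params[0], params[1], rng)

mu_hat, sigma_hat = params[0], np.exp(params[1])
print(mu_hat, sigma_hat)   # should approach 2.0 and 0.5
```

This estimator has high variance in general; subtracting a baseline from $\log q_\lambda - \log \rho(\theta, y)$ is a common variance-reduction device, omitted here to stay literal to the formula.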


### Evidence lower bound
The $KL$ divergence can be rewritten as
$$
\begin{align*}
KL\Bigl[q_{\lambda}(\theta) \Vert  \rho_{\rm post}(\theta | y)\Bigr] &= \mathbb{E}_{\theta \sim q_{\lambda}(\theta)}  \Bigl[ \log \frac{q_{\lambda}(\theta)}{\rho_{\rm post}(\theta | y)} \Bigr] \\
&= \log Z(y) - \mathbb{E}_{\theta \sim q_{\lambda}(\theta)}  \Bigl[ \log \frac{\rho(\theta, y)}{q_{\lambda}(\theta)} \Bigr] 
\end{align*}
$$

The evidence lower bound $ELBO(\lambda)$ is defined as 
$$
\begin{align*}
ELBO(\lambda) = \mathbb{E}_{\theta \sim q_{\lambda}(\theta)}  \Bigl[ \log \frac{\rho(\theta, y)}{q_{\lambda}(\theta)} \Bigr] 
\end{align*}
$$

Since $\log Z(y)$ does not depend on $\lambda$, minimizing the $KL$ divergence is equivalent to maximizing $ELBO(\lambda)$, and their gradients with respect to $\lambda$ differ only in sign.

Since the $KL$ divergence is nonnegative and $Z(y) = \rho(y)$ is the evidence, we have
$$
\begin{align*}
\log \rho(y) \geq ELBO(\lambda) = \mathbb{E}_{\theta \sim q_{\lambda}(\theta)}  \Bigl[ \log \frac{\rho(\theta, y)}{q_{\lambda}(\theta)} \Bigr] 
\end{align*}
$$
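The bound can be checked on a hypothetical conjugate model, $\theta \sim \mathcal{N}(0, 1)$ and $y | \theta \sim \mathcal{N}(\theta, 0.5)$ with observed $y = 1$, where $\log \rho(y)$ and the posterior are available in closed form. In this sketch, a Monte Carlo estimate of $ELBO(\lambda)$ equals $\log \rho(y)$ when $q_\lambda$ is the exact posterior (the gap is the $KL$ divergence, which then vanishes) and falls below it for any other choice, here the prior.

```python
import numpy as np

def log_normal(x, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

# Hypothetical conjugate model: theta ~ N(0, 1), y | theta ~ N(theta, noise_var)
noise_var, y = 0.5, 1.0

def log_joint(theta):
    # log rho(theta, y) = log rho_prior(theta) + log rho(y | theta)
    return log_normal(theta, 0.0, 1.0) + log_normal(y, theta, noise_var)

def elbo(m, v, rng, n_samples=100_000):
    """Monte Carlo estimate of ELBO(lambda) for q_lambda = N(m, v)."""
    theta = m + np.sqrt(v) * rng.standard_normal(n_samples)
    return np.mean(log_joint(theta) - log_normal(theta, m, v))

# Closed-form evidence and posterior for this linear-Gaussian model
log_evidence = log_normal(y, 0.0, 1.0 + noise_var)
v_post = 1.0 / (1.0 + 1.0 / noise_var)
m_post = v_post * y / noise_var

rng = np.random.default_rng(1)
print(log_evidence, elbo(m_post, v_post, rng), elbo(0.0, 1.0, rng))
```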

## Mean field approximation

We assume that the variational family factorizes over blocks of parameters $\theta = (\theta_1, \ldots, \theta_m)$:

$$
\begin{align*}
q_{\lambda}(\theta) = \prod_{i=1}^{m}q_{\lambda_i}(\theta_i)
\end{align*}
$$

The $KL$ divergence becomes
$$
\begin{align*}
KL\Bigl[q_{\lambda}(\theta) \Vert  \rho_{\rm post}(\theta | y)\Bigr] 
&= \mathbb{E}_{\theta \sim q_{\lambda}(\theta)}  \Bigl[ \sum_{i=1}^{m} \log q_{\lambda_i}(\theta_i) - \log \rho_{\rm post}(\theta | y) \Bigr]
\end{align*}
$$

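This can be minimized by coordinate descent, namely by minimizing over each $\lambda_i$ in turn while the others are held fixed. Fixing an index $i_0$ and writing $\theta_{-i_0}$ for the remaining components, the $KL$ divergence can be rewritten as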

$$
\begin{align*}
KL\Bigl[q_{\lambda}(\theta) \Vert  \rho_{\rm post}(\theta | y)\Bigr] 
&= \mathbb{E}_{\theta \sim q_{\lambda}(\theta)}  \Bigl[ \sum_{i=1}^{m} \log q_{\lambda_i}(\theta_i) - \log \rho_{\rm post}(\theta | y) \Bigr]\\
&= \sum_{i=1}^{m} \mathbb{E}_{\theta_i \sim q_{\lambda_i}(\theta_i)}\log q_{\lambda_i}(\theta_i) - \mathbb{E}_{\theta \sim q_{\lambda}(\theta)} \log \Bigl(\rho_{\rm post}(\theta_{-i_0} | y)\rho_{\rm post}(\theta_{i_0} | \theta_{-i_0} , y) \Bigr) \\
&=  \mathbb{E}_{\theta_{i_0} \sim q_{\lambda_{i_0}}(\theta_{i_0})}\Bigl[ \log q_{\lambda_{i_0}}(\theta_{i_0}) - 
\mathbb{E}_{\theta_{-i_0} \sim q_{\lambda_{-i_0}}(\theta_{-i_0})} \log \rho_{\rm post}(\theta_{i_0} | \theta_{-i_0} , y)
\Bigr] + C
\end{align*}
$$
where $C$ collects the terms that do not depend on $\lambda_{i_0}$. Define
$$
h_{i_0}(\cdot) = \exp\Bigl( \mathbb{E}_{\theta_{-i_0} \sim q_{\lambda_{-i_0}}(\theta_{-i_0})} \log \rho_{\rm post}(\cdot | \theta_{-i_0} , y) \Bigr)
$$

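The bracketed term equals $KL\bigl[q_{\lambda_{i_0}} \Vert h_{i_0}/\int h_{i_0} d\theta_{i_0}\bigr] - \log \int h_{i_0} d\theta_{i_0}$, so when the family for $q_{\lambda_{i_0}}$ is flexible enough, the optimal factor satisfies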
$$
q_{\lambda_{i_0}}(\cdot) \propto h_{i_0}(\cdot)
$$
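To make the update concrete, the sketch below assumes a hypothetical two-dimensional Gaussian posterior $\rho_{\rm post} = \mathcal{N}(\mu, \Lambda^{-1})$ with precision matrix $\Lambda$. In that case $h_{i_0}$ is Gaussian with precision $\Lambda_{i_0 i_0}$ and mean $\mu_{i_0} - \Lambda_{i_0 i_0}^{-1}\Lambda_{i_0 j}(\mathbb{E}_q[\theta_j] - \mu_j)$ for the other index $j$, so each coordinate step only updates a mean. This is the standard textbook example of coordinate ascent variational inference, not a general implementation.

```python
import numpy as np

# Hypothetical target: rho_post = N(mu, Lam^{-1}) in two dimensions, Lam = precision matrix
mu = np.array([1.0, -1.0])
Lam = np.array([[2.0, 0.9],
                [0.9, 1.0]])

def cavi_gaussian(mu, Lam, n_sweeps=100):
    """Coordinate updates: set each Gaussian factor q_i proportional to h_i in turn."""
    m = np.zeros(2)                  # current means of q_1 and q_2
    for _ in range(n_sweeps):
        for i in range(2):
            j = 1 - i
            # h_i is Gaussian with precision Lam[i, i] and this mean
            m[i] = mu[i] - Lam[i, j] / Lam[i, i] * (m[j] - mu[j])
    var = 1.0 / np.diag(Lam)         # factor variances Lam_ii^{-1}
    return m, var

m, var = cavi_gaussian(mu, Lam)
print(m, var, np.diag(np.linalg.inv(Lam)))
```

The factor means converge to $\mu$, while the factor variances $\Lambda_{ii}^{-1}$ are smaller than the true marginal variances $(\Lambda^{-1})_{ii}$, the familiar tendency of the mean field approximation to underestimate posterior spread.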

## Stein variational gradient descent [2]

### Preliminary
Consider a functional $\mathcal{E}$ on the density space $\mathcal{P}(\Omega) = \{\rho \in \mathcal{C}^{\infty}(\Omega) \cap L^1(\Omega), \rho > 0 \textrm{ a.e.}, \int \rho dx = 1\}$.
The tangent space of $\mathcal{P}$ at $\rho$ is

$$
\begin{align*}
T_{\rho}\mathcal{P} = \bigl\{\frac{d}{dt}\rho(t)\Big|_{t=0} : \rho(t) \textrm{ is a curve in } \mathcal{P}, \rho(0) = \rho \bigr\} = \bigl\{\sigma \in \mathcal{C}^{\infty}(\Omega), \int \sigma dx = 0 \bigr\}
\end{align*}
$$

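To define the functional derivative $\frac{\delta \mathcal{E}}{\delta \rho}$ of $\mathcal{E}(\rho)$ at $\rho$, we introduce the cotangent space $T_{\rho}^{*}\mathcal{P}$, the space of linear functionals on $T_{\rho}\mathcal{P}$. The functional derivative is an element of $T_{\rho}^{*}\mathcal{P}$: it acts on $\sigma \in T_{\rho}\mathcal{P}$ as $\frac{\delta \mathcal{E}}{\delta \rho}\sigma = \frac{d}{d\epsilon}\mathcal{E}(\rho + \epsilon\sigma)\big|_{\epsilon=0}$, and it is also linear in the functional:

$$
\frac{\delta (\mathcal{E}_1 + \mathcal{E}_2)}{\delta \rho} \sigma = 
\frac{\delta \mathcal{E}_1 }{\delta \rho} \sigma + 
\frac{\delta \mathcal{E}_2}{\delta \rho}\sigma
$$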

The metric tensor at the point $\rho$ is defined as 

$$
G(\rho) : T_{\rho}\mathcal{P} \rightarrow T_{\rho}^{*}\mathcal{P}
$$

This induces the inner product on $T_{\rho}\mathcal{P}$

$$
g_{\rho}(\sigma_1, \sigma_2) = \int \sigma_1 G(\rho) \sigma_2  dx  = \int \Phi_1 G(\rho)^{-1} \Phi_2  dx
$$
where $\sigma_i \in T_{\rho}\mathcal{P}$ and $\Phi_i = G(\rho)\sigma_i \in T_{\rho}^{*}\mathcal{P}$. This can be used to define the steepest descent direction in density function space.


The corresponding distance between two densities $\rho^1$ and $\rho^2$ is defined as

$$
\begin{align*}
&{\rm dist}(\rho^1, \rho^2)^2 = \min_{\sigma_t \in T_{\rho_t}\mathcal{P}} \int_{t=0}^{t=1} g_{\rho_t}(\sigma_t, \sigma_t) dt \\
&s.t. \qquad \frac{\partial \rho_t}{\partial t} =  \sigma_t \qquad \rho_0 = \rho^1 \quad \rho_1 = \rho^2
\end{align*}
$$



### Stein variational gradient descent 
Define the KL divergence functional on the density space $\mathcal{P}(\Omega)$
$$
\begin{align*}
\mathcal{E}(\rho_t) = KL\Bigl[\rho_{t} \Big\Vert  \rho_{\rm post}(\theta | y)\Bigr] = \int\rho_t\log \rho_t d\theta - \int\rho_t\log \rho_{\rm post}(\theta | y)d\theta
\end{align*}
$$

Its functional derivative is

$$
\begin{align*}
\frac{\delta \mathcal{E}}{\delta \rho_t} \sigma 
&= \frac{d}{d\epsilon}\mathcal{E}(\rho_t + \epsilon\sigma)\Big|_{\epsilon=0} = \int \sigma \log \rho_t d\theta + \int \rho_t \frac{\sigma}{\rho_t} d\theta - \int \sigma \log \rho_{\rm post}(\theta | y)d\theta \\
&= \int \sigma \Bigl(\log \rho_t - \log \rho_{\rm post}(\theta | y) + 1 \Bigr) d\theta
\end{align*}
$$

By the Riesz representation theorem (with respect to the $L^2$ pairing), we identify

$$
\begin{align*}
\frac{\delta \mathcal{E}}{\delta \rho_t} 
= \log \rho_t - \log \rho_{\rm post}(\theta | y) + 1
\end{align*}
$$

which is determined only up to an additive constant, since $\int \sigma d\theta = 0$ for every $\sigma \in T_{\rho_t}\mathcal{P}$.


Given a metric tensor $G(\rho)$, the steepest descent flow for minimizing $\mathcal{E}$ is

$$\frac{\partial \rho_t}{\partial t} = - G(\rho_t)^{-1} \frac{\delta \mathcal{E}}{\delta \rho_t}$$

since along this flow

$$\frac{\partial \mathcal{E}}{\partial t} = \frac{\delta \mathcal{E}}{\delta \rho_t} \frac{\partial \rho_t}{\partial t} = - \frac{\delta \mathcal{E}}{\delta \rho_t} G(\rho_t)^{-1} \frac{\delta \mathcal{E}}{\delta \rho_t} = -g_{\rho_t}\bigl(G(\rho_t)^{-1} \frac{\delta \mathcal{E}}{\delta \rho_t}, G(\rho_t)^{-1} \frac{\delta \mathcal{E}}{\delta \rho_t}\bigr) \leq 0 $$



#### 2-Wasserstein metric 

$$
G(\rho)^{-1} \Phi = - \nabla\cdot(\rho\nabla\Phi), \qquad  \Phi \in T_{\rho}^{*}\mathcal{P}
$$

The distance between $\rho^1$ and $\rho^2$ becomes

$$
\begin{align*}
{\rm dist}(\rho^1, \rho^2)^2 &= \min_{\Phi_t \in T_{\rho_t}^{*}\mathcal{P}} \int_{t=0}^{t=1} g_{\rho_t}(G(\rho_t)^{-1} \Phi_t, G(\rho_t)^{-1} \Phi_t) dt \\
&= \min_{\Phi_t \in T_{\rho_t}^{*}\mathcal{P}} \int_{t=0}^{t=1} \int \rho_t \nabla \Phi_t \cdot \nabla \Phi_t d\theta dt\\
&s.t. \qquad \frac{\partial \rho_t}{\partial t} + \nabla\cdot(\rho_t\nabla\Phi_t) = 0  \qquad \rho_0 = \rho^1 \quad \rho_1 = \rho^2
\end{align*}
$$

which is the Benamou-Brenier (dynamic) formulation of the 2-Wasserstein distance.

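The density evolution equation is the Fokker-Planck equation of the overdamped Langevin dynamics $d\theta_t = \nabla_{\theta}\log \rho(\theta_t, y)dt + \sqrt{2}dW_t$: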

$$
\begin{align*}
\frac{\partial \rho_t(\theta)}{\partial t} 
&= -\nabla_{\theta} \cdot \Bigl[\rho_t \cdot \nabla_{\theta}\bigl(\log \rho_{\rm post}(\theta | y) - \log \rho_t - 1\bigr) \Bigr]\\
&= -\nabla_{\theta} \cdot \Bigl[\rho_t \nabla_{\theta}\log \rho(\theta,y)\Bigr] + \nabla_{\theta}\cdot\Bigl[ \nabla_{\theta} \rho_t \Bigr]
\end{align*}
$$
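A particle version of this Wasserstein gradient flow is the unadjusted Langevin algorithm, the Euler-Maruyama discretization of $d\theta_t = \nabla_{\theta}\log \rho(\theta_t, y)dt + \sqrt{2}dW_t$. The sketch assumes a hypothetical target $\rho_{\rm post} = \mathcal{N}(2, 0.25)$; the ensemble size and step size are illustrative, and the time discretization introduces a small bias of order of the step size in the stationary distribution.

```python
import numpy as np

def grad_log_joint(theta):
    # Hypothetical target: grad log rho(theta, y) = -grad Phi_R(theta; y) with
    # Phi_R = (theta - 2)^2 / (2 * 0.25), i.e. rho_post = N(2, 0.25)
    return -(theta - 2.0) / 0.25

def langevin_particles(theta0, step, n_steps, rng):
    """Euler-Maruyama for d theta = grad log rho(theta, y) dt + sqrt(2) dW, per particle."""
    theta = theta0.copy()
    for _ in range(n_steps):
        noise = rng.standard_normal(theta.shape)
        theta = theta + step * grad_log_joint(theta) + np.sqrt(2.0 * step) * noise
    return theta

rng = np.random.default_rng(2)
theta = langevin_particles(rng.standard_normal(5000), step=0.01, n_steps=1000, rng=rng)
print(theta.mean(), theta.var())   # close to 2 and 0.25, up to the O(step) bias
```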

#### Stein metric

$$
G(\rho)^{-1} \Phi = - \nabla\cdot\Bigl(\rho\int \kappa(\theta,\theta')\rho(\theta')\nabla_{\theta'}\Phi(\theta')d\theta' \Bigr), \qquad  \Phi \in T_{\rho}^{*}\mathcal{P}
$$

The distance between $\rho^1$ and $\rho^2$ becomes

$$
\begin{align*}
{\rm dist}(\rho^1, \rho^2)^2 &= \min_{\Phi_t \in T_{\rho_t}^{*}\mathcal{P}} \int_{t=0}^{t=1} g_{\rho_t}(G(\rho_t)^{-1} \Phi_t, G(\rho_t)^{-1} \Phi_t) dt \\
&= \min_{\Phi_t \in T_{\rho_t}^{*}\mathcal{P}} \int_{t=0}^{t=1} \int \int \kappa(\theta, \theta') \rho_t(\theta)\rho_t(\theta') \nabla_{\theta'}\Phi_t(\theta') \cdot \nabla_{\theta} \Phi_t(\theta) d\theta d\theta' dt\\
&s.t. \qquad \frac{\partial \rho_t}{\partial t} + \nabla\cdot\Bigl(\rho_t\int \kappa(\theta,\theta')\rho_t(\theta')\nabla_{\theta'}\Phi_t(\theta')d\theta' \Bigr) = 0 \qquad \rho_0 = \rho^1 \quad \rho_1 = \rho^2
\end{align*}
$$

The density evolution equation becomes

$$
\begin{align*}
\frac{\partial \rho_t(\theta)}{\partial t} 
&= -\nabla_{\theta}\cdot\Bigl(\rho_t(\theta)\int \kappa(\theta,\theta')\rho_t(\theta')\nabla_{\theta'}\bigl(\log \rho_{\rm post}(\theta' | y) - \log \rho_t(\theta') - 1\bigr)d\theta' \Bigr) \\
&= -\nabla_{\theta}\cdot\Bigl(\rho_t(\theta)\int \kappa(\theta,\theta')\rho_t(\theta')\nabla_{\theta'}\bigl(\log \rho(\theta' , y) - \log \rho_t(\theta')\bigr)d\theta' \Bigr) \\
&= -\nabla_{\theta}\cdot\Bigl(\rho_t(\theta) F(\theta) \Bigr)
\end{align*}
$$

where the drift (there is no diffusion term) follows from $\nabla_{\theta'}\log \rho(\theta', y) = -\nabla_{\theta'}\Phi_R(\theta';y)$, an integration by parts in $\theta'$, and finally replacing $\rho_t$ by the empirical density $\frac{1}{J}\sum_{j=1}^{J}\delta(\theta - \theta^j)$ of an ensemble $\{\theta^j\}_{j=1}^{J}$:

$$
\begin{align*}
F(\theta) &= \int \bigl(-\rho_t(\theta')\nabla_{\theta'} \Phi_R(\theta';y) - \nabla_{\theta'} \rho_t(\theta')\bigr)  \kappa(\theta, \theta')d\theta'\\
          &= \int \Bigl(-\rho_t(\theta')\nabla_{\theta'} \Phi_R(\theta';y)\kappa(\theta, \theta') + \nabla_{\theta'}\kappa(\theta, \theta')  \rho_t(\theta')\Bigr)  d\theta'\\
          &\approx  \frac{1}{J}\sum_{j=1}^{J}\Bigl(-\nabla \Phi_R(\theta^j;y)\kappa(\theta, \theta^j) + \nabla_{\theta^j}\kappa(\theta, \theta^j)\Bigr)
\end{align*}
$$
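The particle approximation of $F$ gives the SVGD update $\theta^i \leftarrow \theta^i + \epsilon F(\theta^i)$. The sketch below is one possible implementation under stated assumptions: a hypothetical one-dimensional target $\rho_{\rm post} = \mathcal{N}(2, 0.25)$, the RBF kernel $\kappa(\theta, \theta') = \exp(-|\theta - \theta'|^2 / h)$ with the median-heuristic bandwidth used in [2], and a fixed step size $\epsilon$ (practical codes often use adaptive step sizes instead).

```python
import numpy as np

def grad_phi_r(theta):
    # Hypothetical potential Phi_R(theta; y) = (theta - 2)^2 / (2 * 0.25), so rho_post = N(2, 0.25)
    return (theta - 2.0) / 0.25

def svgd_step(theta, step):
    """One SVGD update theta^i <- theta^i + step * F(theta^i) for 1D particles."""
    J = theta.shape[0]
    diff = theta[:, None] - theta[None, :]      # diff[i, j] = theta^i - theta^j
    sq = diff ** 2
    # Median-heuristic bandwidth h = med^2 / log J (an assumed, common choice)
    h = max(np.median(sq) / np.log(J), 1e-8)
    K = np.exp(-sq / h)                         # K[i, j] = kappa(theta^i, theta^j)
    grad_K = (2.0 / h) * diff * K               # gradient wrt theta^j of kappa(theta^i, theta^j)
    drive = K @ (-grad_phi_r(theta))            # sum_j -grad Phi_R(theta^j) kappa(theta^i, theta^j)
    repulse = grad_K.sum(axis=1)                # sum_j grad_{theta^j} kappa(theta^i, theta^j)
    return theta + step * (drive + repulse) / J

rng = np.random.default_rng(3)
theta = rng.standard_normal(100)                # initial ensemble
for _ in range(2000):
    theta = svgd_step(theta, step=0.02)
print(theta.mean(), theta.std())                # should be near 2 and 0.5
```

The kernel-weighted gradient term pulls particles toward high posterior density, while the kernel-gradient term pushes them apart, which keeps the ensemble from collapsing to the mode.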


#### Fisher-Rao metric

$$
G(\rho)^{-1} \Phi = (\Phi - \mathbb{E}_\rho[\Phi])\rho, \qquad  \Phi \in T_{\rho}^{*}\mathcal{P}
$$

The distance between $\rho^1$ and $\rho^2$ becomes

$$
\begin{align*}
{\rm dist}(\rho^1, \rho^2)^2 &= \min_{\Phi_t \in T_{\rho_t}^{*}\mathcal{P}} \int_{t=0}^{t=1} g_{\rho_t}(G(\rho_t)^{-1} \Phi_t, G(\rho_t)^{-1} \Phi_t) dt \\
&= \min_{\Phi_t \in T_{\rho_t}^{*}\mathcal{P}} \int_{t=0}^{t=1} \int \rho_t \Phi_t (\Phi_t - \mathbb{E}_{\rho_t}[\Phi_t]) d\theta dt\\
&s.t. \qquad \frac{\partial \rho_t}{\partial t} - \rho_t (\Phi_t - \mathbb{E}_{\rho_t}[\Phi_t]) = 0 \qquad \rho_0 = \rho^1 \quad \rho_1 = \rho^2
\end{align*}
$$

The density evolution equation becomes

$$
\begin{align*}
\frac{\partial \rho_t(\theta)}{\partial t} 
&= \rho_t \Bigl(\log \rho_{\rm post}(\theta | y) - \log \rho_t(\theta) - \mathbb{E}_{\rho_t}[\log \rho_{\rm post}(\theta | y) - \log \rho_t(\theta)]\Bigr) \\
&= \rho_t \Bigl(\log \rho(\theta, y) - \log \rho_t(\theta) - \mathbb{E}_{\rho_t}[\log \rho(\theta, y) - \log \rho_t(\theta)]\Bigr)
\end{align*}
$$

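Dividing by $\rho_t$, then multiplying by the integrating factor $e^t$ and integrating from $0$ to $t$, we have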

$$
\begin{align*}
\frac{\partial \log \rho_t(\theta)}{\partial t} 
&= \log \rho(\theta, y) - \log \rho_t(\theta) - \mathbb{E}_{\rho_t}[\log \rho(\theta, y) - \log \rho_t(\theta)]\\
\frac{\partial e^t \log \rho_t(\theta)}{\partial t} 
&= e^t\log \rho(\theta, y) - e^t\mathbb{E}_{\rho_t}[\log \rho(\theta, y) - \log \rho_t(\theta)]\\
\log \rho_t(\theta)
&= (1-e^{-t})\log \rho(\theta, y) + e^{-t}\log\rho_0(\theta) - C_t \qquad C_t = \int_0^{t} e^{\tau-t}\mathbb{E}_{\rho_\tau}[\log \rho(\theta, y) - \log \rho_\tau(\theta)] d\tau
\end{align*}
$$

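Therefore, the flow has a closed-form solution that geometrically interpolates between the initial density $\rho_0$ and the posterior, reaching $\rho_{\rm post}(\theta | y)$ as $t \to \infty$:

$$
\begin{align*}
\rho_t(\theta) \propto \rho_0(\theta)^{e^{-t}}\rho(\theta, y)^{1 - e^{-t}}
\end{align*}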
$$
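The closed form can be checked numerically. The sketch assumes a hypothetical one-dimensional example on a grid, with $\rho(\theta, y) \propto \exp(-(\theta-2)^2/(2\cdot 0.25))$ and $\rho_0 = \mathcal{N}(0, 1)$; it integrates the equation for $\partial_t \log \rho_t$ with explicit Euler steps and compares the result at $t = 2$ with the normalized $\rho_0^{e^{-t}}\rho(\theta,y)^{1-e^{-t}}$.

```python
import numpy as np

# Hypothetical 1D example on a grid: unnormalized rho(theta, y) and initial rho_0 = N(0, 1)
theta = np.linspace(-6.0, 8.0, 1401)
dx = theta[1] - theta[0]
log_joint = -0.5 * (theta - 2.0) ** 2 / 0.25
log_rho0 = -0.5 * theta ** 2 - 0.5 * np.log(2 * np.pi)

def normalize(log_p):
    p = np.exp(log_p - log_p.max())
    return p / (p.sum() * dx)

def fisher_rao_flow(log_rho, t_end, dt=1e-3):
    """Explicit Euler on d log rho / dt = g - E_rho[g], with g = log rho(theta, y) - log rho."""
    for _ in range(int(round(t_end / dt))):
        rho = normalize(log_rho)
        g = log_joint - log_rho
        log_rho = log_rho + dt * (g - np.sum(rho * g) * dx)
    return normalize(log_rho)

t = 2.0
rho_numeric = fisher_rao_flow(log_rho0, t)
rho_closed = normalize(np.exp(-t) * log_rho0 + (1 - np.exp(-t)) * log_joint)
print(np.max(np.abs(rho_numeric - rho_closed)))   # small: only O(dt) time-stepping error
```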


### Connection to Langevin dynamics

Consider the following [initial value Ito process](Langevin.ipynb)

$$
\begin{align*}
d\theta_t = F(t, \theta_t) dt \qquad (\sigma = 0)
\end{align*}
$$
where 
$$
\begin{align*}
F &= -A_t\nabla_{\theta} \Phi_R(\theta;y) - A_t\nabla_{\theta}\log \rho_t(\theta) \qquad A_t = \rho_t(\theta) I \\
  &= -\rho_t(\theta)\nabla_{\theta} \Phi_R(\theta;y) - \nabla_{\theta} \rho_t(\theta) 
\end{align*}
$$

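Since the density of this process satisfies $\frac{\partial \rho_t}{\partial t} = -\nabla_{\theta}\cdot(\rho_t F)$, the $KL$ divergence evolves as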

$$
\begin{align*}
\frac{\partial}{\partial t}KL\Bigl[\rho_{t}(\theta) \Vert  \rho_{\rm post}(\theta | y)\Bigr]
&= -\int \rho_t(\theta)^2 \bigl(\nabla_{\theta} \Phi_R(\theta; y) + \nabla_{\theta} \log \rho_t(\theta)\bigr)^T \bigl(\nabla_{\theta} \Phi_R(\theta; y) + \nabla_{\theta} \log \rho_t(\theta)\bigr) d\theta  = -\int F^T F d\theta\\
\end{align*}
$$

In the implementation, an ensemble of particles $\{\theta^j\}_{j=1}^{J}$ with the associated empirical (Dirac delta) density is used, and the drift is projected onto the reproducing kernel Hilbert space (RKHS) with kernel $\kappa$:

$$
\begin{align*}
\tilde F(\theta) &= \int F(\theta')\kappa(\theta, \theta') d\theta' \\
          &= \int \bigl(-\rho_t(\theta')\nabla_{\theta'} \Phi_R(\theta';y) - \nabla_{\theta'} \rho_t(\theta')\bigr)  \kappa(\theta, \theta')d\theta'\\
          &= \int \Bigl(-\rho_t(\theta')\nabla_{\theta'} \Phi_R(\theta';y)\kappa(\theta, \theta') + \nabla_{\theta'}\kappa(\theta, \theta')  \rho_t(\theta')\Bigr)  d\theta'\\
          &\approx  \frac{1}{J}\sum_{j=1}^{J}\Bigl(-\nabla \Phi_R(\theta^j;y)\kappa(\theta, \theta^j) + \nabla_{\theta^j}\kappa(\theta, \theta^j)\Bigr)
\end{align*}
$$

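where the third line integrates by parts in $\theta'$, moving the derivative from $\rho_t$ onto the kernel, and the last line substitutes $\rho_t = \frac{1}{J}\sum_{j=1}^{J} \delta(\theta - \theta^j)$. The integration by parts is what makes the particle approximation usable, since $\nabla \rho_t$ cannot be evaluated pointwise for an empirical density.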

## Gaussian variational inference [3]

Consider the Gaussian approximation $q(\theta) = \mathcal{N}(m, C)$ of a density $\rho$.


The right KL Gaussian approximation minimizes

$$
\begin{align*}
KL(\rho(\theta)\Vert q(\theta)) = \int \rho(\theta)\log \rho(\theta) d\theta - \mathbb{E}_{\rho}[-\frac{1}{2}(\theta - m)^TC^{-1}(\theta - m)] + \frac{1}{2}\log \det C + \textrm{const}
\end{align*}
$$

Setting the derivatives with respect to $m$ and $C$ to zero shows that the right KL Gaussian approximation matches the first two moments:

$$
\begin{align*}
m = \mathbb{E}_{\rho}[\theta] \qquad C =\mathbb{E}_{\rho}[(\theta - m)(\theta - m)^T]
\end{align*}
$$  
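Moment matching can be seen numerically on a hypothetical two-component Gaussian mixture. The sketch evaluates $KL(\rho \Vert q)$ by quadrature on a grid (an assumption that only works in low dimension) and computes $m$ and $C$ as the mean and variance of $\rho$; perturbing either parameter increases the divergence.

```python
import numpy as np

# Hypothetical target: a two-component Gaussian mixture evaluated on a grid
theta = np.linspace(-10.0, 10.0, 4001)
dx = theta[1] - theta[0]

def normal_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

rho = 0.3 * normal_pdf(theta, -2.0, 0.5) + 0.7 * normal_pdf(theta, 1.0, 1.0)

def right_kl(m, C):
    """KL(rho || N(m, C)) by quadrature on the grid."""
    return np.sum(rho * (np.log(rho) - np.log(normal_pdf(theta, m, C)))) * dx

# Moment matching: m = E_rho[theta], C = E_rho[(theta - m)^2]
m_star = np.sum(rho * theta) * dx
C_star = np.sum(rho * (theta - m_star) ** 2) * dx
print(m_star, C_star, right_kl(m_star, C_star))
```

For this mixture the moments are $m = 0.3\cdot(-2) + 0.7\cdot 1 = 0.1$ and $C = 0.3\,(0.5 + 4) + 0.7\,(1 + 1) - 0.1^2 = 2.74$, so the single Gaussian spreads over both modes rather than locking onto one.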


The left KL Gaussian approximation minimizes

$$
\begin{align*}
KL(q(\theta) \Vert \rho(\theta|y)) = \int q(\theta)\log q(\theta) - q(\theta)\log \rho(\theta|y) d\theta
\end{align*}
$$
More specifically, we take $\rho(\theta|y) \propto \rho(y|\theta)\rho_{\rm prior}(\theta)$ with a Gaussian prior $\rho_{\rm prior}(\theta) = \mathcal{N}(m_{\rm prior}, C_{\rm prior})$, and seek

$$ \min_{m,C} \mathbb{E}_q[\log q - \log \rho_{\rm prior} - \log \rho(y|\theta)] \qquad q = \mathcal{N}(m, C)$$

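Setting the gradients with respect to $m$ and $C$ to zero, and using Bonnet's and Price's theorems $\nabla_m \mathbb{E}_q[f] = \mathbb{E}_q[\nabla_\theta f]$ and $\nabla_C \mathbb{E}_q[f] = \frac{1}{2}\mathbb{E}_q[\nabla_\theta^2 f]$, gives the fixed point equations: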

$$
\begin{align*}
&m = m_{\rm prior} + C_{\rm prior} \nabla_{m} \mathbb{E}_q[\log \rho(y|\theta)] =  m_{\rm prior} + C_{\rm prior} \mathbb{E}_q[\nabla_{\theta}\log \rho(y|\theta)]\\
&C^{-1} = C_{\rm prior}^{-1}  - 2\nabla_{C}\mathbb{E}_q[\log \rho(y|\theta)] = C_{\rm prior}^{-1}  - \mathbb{E}_q[\nabla_{\theta}^2\log \rho(y|\theta)]
\end{align*}
$$
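A one-dimensional sketch of these equations, assuming Gauss-Hermite quadrature for the expectations under $q$ and hypothetical user-supplied derivatives of $\log \rho(y|\theta)$. Iterating the mean equation literally can diverge when $C_{\rm prior}$ times the likelihood curvature is large, so the sketch iterates a preconditioned form, $m \leftarrow m + C\bigl(\mathbb{E}_q[\nabla_\theta \log\rho(y|\theta)] - C_{\rm prior}^{-1}(m - m_{\rm prior})\bigr)$, which has the same fixed points. It also assumes $C_{\rm prior}^{-1} - \mathbb{E}_q[\nabla_\theta^2 \log\rho(y|\theta)] > 0$ so that $C$ stays positive.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def gaussian_vi_1d(m_prior, C_prior, log_lik_grad, log_lik_hess, n_iter=50, n_quad=20):
    """Iterate toward the left-KL fixed point (m, C) for q = N(m, C) in 1D.

    log_lik_grad / log_lik_hess: hypothetical user-supplied first and second
    derivatives of log rho(y | theta) with respect to theta.
    """
    x, w = hermegauss(n_quad)              # Gauss-Hermite rule for weight exp(-x^2 / 2)
    w = w / np.sqrt(2.0 * np.pi)           # so that sum(w * f(x)) ~ E_{z ~ N(0,1)}[f(z)]
    m, C = m_prior, C_prior
    for _ in range(n_iter):
        theta = m + np.sqrt(C) * x         # quadrature nodes under q = N(m, C)
        e_grad = np.sum(w * log_lik_grad(theta))    # E_q[grad log rho(y|theta)]
        e_hess = np.sum(w * log_lik_hess(theta))    # E_q[hess log rho(y|theta)]
        C = 1.0 / (1.0 / C_prior - e_hess)          # covariance fixed-point equation
        # Preconditioned step; at a fixed point it reduces to m = m_prior + C_prior * e_grad
        m = m + C * (e_grad - (m - m_prior) / C_prior)
    return m, C

# Linear-Gaussian check: y = a * theta + noise, where the exact posterior is Gaussian
a, noise_var, y = 2.0, 0.5, 1.3
m, C = gaussian_vi_1d(
    0.0, 1.0,
    log_lik_grad=lambda th: a * (y - a * th) / noise_var,
    log_lik_hess=lambda th: np.full_like(th, -a ** 2 / noise_var),
)
C_exact = 1.0 / (1.0 + a ** 2 / noise_var)
m_exact = C_exact * a * y / noise_var
print(m, m_exact, C, C_exact)
```

For a linear-Gaussian likelihood the fixed point coincides with the exact posterior, which is what the usage example compares against; for nonlinear likelihoods the same loop gives a Gaussian approximation whose quality depends on how non-Gaussian the posterior is.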




# References
1. [Lecture 5: Variational Inference (Stanford Canvas)](https://canvas.stanford.edu/files/1780120/download?download_frd=1&verifier=MWyibVq7L4EmRgunWLV7pS7CekAI9MLuTJIHxuCV;Lecture+5.pdf;application/pdf)
2. [Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm](https://proceedings.neurips.cc/paper/2016/file/b3ba8f1bee1238a2f37603d90b58898d-Paper.pdf)
3. [The recursive variational Gaussian approximation (R-VGA)](https://link.springer.com/article/10.1007/s11222-021-10068-w)