# Stein Variational Inference

Stein variational gradient descent and its variants are nonparametric variational inference methods in which the variational distribution is represented by an ensemble of particles.

## Stein variational gradient descent [1]

To start this section, readers should be familiar with the [probability density function space and its associated metric tensor](PDFSpace.ipynb). We define the KL divergence functional on the density space $\mathcal{P}(\Omega)$
$$
\begin{align*}
\mathcal{E}(\rho_t) = KL\Bigl[\rho_{t} \Big\Vert  \rho_{\rm post}(\theta | y)\Bigr] = \int\rho_t\log \rho_t d\theta - \int\rho_t\log \rho_{\rm post}(\theta | y)d\theta
\end{align*}
$$

Its functional derivative is

$$
\begin{align*}
\frac{\delta \mathcal{E}}{\delta \rho_t} 
=  \log \rho_t + 1  - \log \rho_{\rm post}(\theta | y)
\end{align*}
$$
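As a sanity check on this functional derivative, the sketch below discretizes $\mathcal{E}$ on a 1D grid and compares a central finite difference of $\mathcal{E}$ along a perturbation $\eta$ with the pairing $\int \frac{\delta \mathcal{E}}{\delta \rho_t}\eta\, d\theta$. The Gaussian choices for $\rho_t$ and $\rho_{\rm post}$, the grid, and the perturbation $\eta = \rho_t \sin\theta$ are illustrative assumptions. The perturbation deliberately does not integrate to zero, so the $+1$ term is visible.

```python
import numpy as np

# Hypothetical 1D test problem on a uniform grid; integrals are Riemann sums.
x = np.linspace(-8.0, 8.0, 4001)
dx = x[1] - x[0]

def normalize(p):
    return p / (p.sum() * dx)

# Assumed densities: rho_t ~ N(1, 1.5^2) and rho_post ~ N(0, 1).
rho = normalize(np.exp(-0.5 * ((x - 1.0) / 1.5) ** 2))
rho_post = normalize(np.exp(-0.5 * x**2))

def energy(p):
    """Discretized E(p) = int p log p - int p log rho_post."""
    return np.sum(p * (np.log(p) - np.log(rho_post))) * dx

# Functional derivative stated in the text: log rho + 1 - log rho_post.
dE = np.log(rho) + 1.0 - np.log(rho_post)

# Direction eta = rho * sin(x) keeps rho + eps * eta positive for |eps| < 1.
eta = rho * np.sin(x)
eps = 1e-4

# Central finite difference of E along eta vs. the pairing int dE * eta.
fd = (energy(rho + eps * eta) - energy(rho - eps * eta)) / (2 * eps)
analytic = np.sum(dE * eta) * dx
# Dropping the +1 is wrong here because int eta != 0.
without_one = np.sum((dE - 1.0) * eta) * dx
print(fd, analytic, without_one)
```

For mass-preserving perturbations ($\int \eta\, d\theta = 0$) the constant $+1$ drops out, which is why it has no effect once the gradient $\nabla_{\theta'}$ is taken in the flow below.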



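We define the **Stein metric** $G(\rho)$, with a positive definite kernel $\kappa(\theta,\theta')$, through the action of its inverse on a cotangent vector $\Phi$: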

$$
G(\rho)^{-1} \Phi = - \nabla\cdot\Bigl(\rho\int \kappa(\theta,\theta')\rho(\theta')\nabla_{\theta'}\Phi(\theta')d\theta' \Bigr), \qquad  \Phi \in T_{\rho}^{*}\mathcal{P}
$$

The squared distance between $\rho^A$ and $\rho^B$ induced by this metric is

$$
\begin{align*}
{\rm dist}(\rho^A, \rho^B)^2 &= \min_{\rho_t,\ \Phi_t \in T_{\rho_t}^{*}\mathcal{P}} \int_{0}^{1} g_{\rho_t}\bigl(G(\rho_t)^{-1} \Phi_t, G(\rho_t)^{-1} \Phi_t\bigr) dt \\
&= \min_{\rho_t,\ \Phi_t \in T_{\rho_t}^{*}\mathcal{P}} \int_{0}^{1} \int \int \kappa(\theta, \theta') \rho_t(\theta)\rho_t(\theta') \nabla_{\theta'}\Phi_t(\theta')^T \nabla_{\theta} \Phi_t(\theta) d\theta d\theta' dt\\
&\text{s.t.} \qquad \frac{\partial \rho_t}{\partial t} + \nabla\cdot\Bigl(\rho_t\int \kappa(\theta,\theta')\rho_t(\theta')\nabla_{\theta'}\Phi_t(\theta')d\theta' \Bigr) = 0, \qquad \rho_0 = \rho^A, \quad \rho_1 = \rho^B
\end{align*}
$$
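Assuming the convention $g_{\rho}(\sigma_1, \sigma_2) = \int \sigma_1 G(\rho) \sigma_2\, d\theta$, the integrand of the distance is $\int \Phi\, G(\rho)^{-1}\Phi\, d\theta$, and one integration by parts turns it into the double integral. The hypothetical 1D check below evaluates both forms on a grid; the Gaussian $\rho$, the RBF kernel with bandwidth $h = 1$, and the test function $\Phi(\theta) = \sin\theta + 0.3\theta^2$ are arbitrary illustrative choices.

```python
import numpy as np

# Hypothetical 1D setup: grid quadrature, rho = N(0, 1), RBF kernel with h = 1.
x = np.linspace(-8.0, 8.0, 801)
dx = x[1] - x[0]
rho = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)
Phi = np.sin(x) + 0.3 * x**2                     # arbitrary cotangent vector
h = 1.0
K = np.exp(-((x[:, None] - x[None, :]) ** 2) / (2 * h**2))  # kappa(theta, theta')

dPhi = np.gradient(Phi, dx)

# Velocity v(theta) = int kappa(theta, theta') rho(theta') grad Phi(theta') dtheta'.
v = K @ (rho * dPhi) * dx
# Tangent vector sigma = G(rho)^{-1} Phi = -d/dtheta (rho v).
sigma = -np.gradient(rho * v, dx)

# int Phi * G^{-1} Phi  vs.  the kernel double-integral form of the metric.
lhs = np.sum(Phi * sigma) * dx
rhs = (rho * dPhi) @ K @ (rho * dPhi) * dx**2
print(lhs, rhs)  # agree; rhs > 0 because kappa is positive definite
```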

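For the gradient flow of the KL divergence under the Stein metric, $\frac{\partial \rho_t}{\partial t} = -G(\rho_t)^{-1}\frac{\delta \mathcal{E}}{\delta \rho_t}$, the density evolution equation becomes

$$
\begin{align*}
\frac{\partial \rho_t(\theta)}{\partial t} 
&= \nabla_{\theta}\cdot\Bigl(\rho_t\int \kappa(\theta,\theta')\rho_t(\theta')\nabla_{\theta'}\bigl(\log \rho_t(\theta') + 1 - \log \rho_{\rm post}(\theta' | y)\bigr)d\theta' \Bigr) \\
&= -\nabla_{\theta}\cdot\Bigl(\rho_t\int \kappa(\theta,\theta')\rho_t(\theta')\nabla_{\theta'}\bigl(\log \rho(\theta' , y) - \log \rho_t(\theta')\bigr)d\theta' \Bigr) \\
&= -\nabla_{\theta}\cdot\Bigl(\rho_t F(\theta) \Bigr)
\end{align*}
$$

where the second line uses $\nabla_{\theta}\log \rho_{\rm post}(\theta | y) = \nabla_{\theta}\log \rho(\theta, y)$, since the evidence $\rho(y)$ does not depend on $\theta$.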

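Here $\Phi_R(\theta;y) = -\log \rho(\theta, y)$ up to a $\theta$-independent constant, so that $\rho_{\rm post}(\theta | y) \propto e^{-\Phi_R(\theta;y)}$, and the drift (there is no diffusion term) is

$$
\begin{align*}
F(\theta) &= \int \bigl(-\rho_t(\theta')\nabla_{\theta'} \Phi_R(\theta';y) - \nabla_{\theta'} \rho_t(\theta')\bigr)  \kappa(\theta, \theta')d\theta'\\
          &= \int \bigl(-\rho_t(\theta')\nabla_{\theta'} \Phi_R(\theta';y)\kappa(\theta, \theta') + \nabla_{\theta'}\kappa(\theta, \theta')  \rho_t(\theta')\bigr)  d\theta'\\
          &\approx  \frac{1}{J}\sum_{j=1}^{J}\Bigl(-\nabla_{\theta^j} \Phi_R(\theta^j;y)\kappa(\theta, \theta^j) + \nabla_{\theta^j}\kappa(\theta, \theta^j)\Bigr)
\end{align*}
$$

The second line integrates by parts in $\theta'$, so the kernel is differentiated in its second argument; the last line uses a particle approximation of $\rho_t$.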
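The integration by parts in $\theta'$ is where the sign of the repulsive term comes from. The sketch below evaluates the first two lines of $F$ on a 1D grid and contrasts them with a variant that wrongly differentiates $\kappa$ in its first argument. The choices $\rho_t = \mathcal{N}(1, 1.5^2)$, $\Phi_R(\theta) = \theta^2/2$, and an RBF kernel with $h = 0.7$ are assumptions for illustration only.

```python
import numpy as np

# Hypothetical 1D setup (illustrative choices, not from the text):
# rho_t = N(1, 1.5^2), Phi_R(theta) = theta^2 / 2, RBF kernel with h = 0.7.
x = np.linspace(-10.0, 10.0, 1001)
dx = x[1] - x[0]
rho = np.exp(-0.5 * ((x - 1.0) / 1.5) ** 2) / (1.5 * np.sqrt(2 * np.pi))
grad_Phi_R = x                              # gradient of theta^2 / 2
h = 0.7
diff = x[:, None] - x[None, :]              # diff[a, b] = theta_a - theta'_b
K = np.exp(-diff**2 / (2 * h**2))           # kappa(theta, theta')
dK_dtheta2 = K * diff / h**2                # derivative of kappa in theta'

# Line 1: int (-rho' grad Phi_R' - grad rho') kappa(theta, theta') dtheta'
F_line1 = K @ (-rho * grad_Phi_R - np.gradient(rho, dx)) * dx
# Line 2: after integrating by parts in theta' (boundary terms vanish)
F_line2 = K @ (-rho * grad_Phi_R) * dx + dK_dtheta2 @ rho * dx
# Differentiating kappa in theta instead (= -dK_dtheta2) flips the repulsion
F_wrong = K @ (-rho * grad_Phi_R) * dx - dK_dtheta2 @ rho * dx

print(np.max(np.abs(F_line1 - F_line2)))    # small: discretization error only
print(np.max(np.abs(F_line1 - F_wrong)))    # clearly nonzero: the sign matters
```

For a radial kernel $\nabla_{\theta'}\kappa(\theta, \theta') = -\nabla_{\theta}\kappa(\theta, \theta')$, so mixing up the two arguments turns the repulsive term into an attractive one.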



### Connection to Langevin dynamics

Consider the following [initial value Ito process](Langevin.ipynb)

$$
\begin{align*}
d\theta_t = F(t, \theta_t) dt \qquad (\sigma = 0)
\end{align*}
$$
where 
$$
\begin{align*}
F &= -A_t\nabla_{\theta} \Phi_R(\theta;y) - A_t\nabla_{\theta}\log \rho_t(\theta) \qquad A_t = \rho_t(\theta) I \\
  &= -\rho_t(\theta)\nabla_{\theta} \Phi_R(\theta;y) - \nabla_{\theta} \rho_t(\theta) 
\end{align*}
$$

Since $\rho_t$ satisfies the continuity equation $\frac{\partial \rho_t}{\partial t} = -\nabla_{\theta}\cdot(\rho_t F)$, integrating by parts shows that the KL divergence decays at the rate

$$
\begin{align*}
\frac{d}{dt}KL\Bigl[\rho_{t}(\theta) \Vert  \rho_{\rm post}(\theta | y)\Bigr]
&= -\int \rho_t(\theta)^2 \bigl(\nabla_{\theta} \Phi_R(\theta; y) + \nabla_{\theta} \log \rho_t(\theta)\bigr)^T \bigl(\nabla_{\theta} \Phi_R(\theta; y) + \nabla_{\theta} \log \rho_t(\theta)\bigr) d\theta  = -\int F^T F d\theta
\end{align*}
$$
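The decay identity can be checked on a grid. The sketch below assumes a standard-normal posterior ($\Phi_R(\theta) = \theta^2/2$) and $\rho_t = \mathcal{N}(1, 1.5^2)$, both illustrative; it builds $F = -\rho_t\nabla\Phi_R - \nabla\rho_t$, forms $\partial_t \rho_t$ from the continuity equation, and compares $\frac{d}{dt}KL$ with $-\int F^T F\, d\theta$.

```python
import numpy as np

# Hypothetical 1D check: Phi_R(theta) = theta^2 / 2 (standard-normal posterior)
# and a Gaussian rho_t = N(1, 1.5^2), both evaluated on a uniform grid.
x = np.linspace(-8.0, 8.0, 1601)
dx = x[1] - x[0]
Phi_R = 0.5 * x**2
log_rho = -0.5 * ((x - 1.0) / 1.5) ** 2 - np.log(1.5 * np.sqrt(2 * np.pi))
rho = np.exp(log_rho)

# Drift F = -rho * grad Phi_R - grad rho  (A_t = rho_t I, sigma = 0).
F = -rho * np.gradient(Phi_R, dx) - np.gradient(rho, dx)

# Continuity equation for d theta = F dt:  d rho / dt = -d/dtheta (rho F).
drho_dt = -np.gradient(rho * F, dx)

# dKL/dt = int (log rho - log rho_post) drho_dt, with log rho_post = -Phi_R + const;
# the constant drops out because drho_dt integrates to zero.
dkl_dt = np.sum((log_rho + Phi_R) * drho_dt) * dx
minus_int_F2 = -np.sum(F**2) * dx
print(dkl_dt, minus_int_F2)  # agree up to O(dx^2) discretization error
```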

For implementation, $\rho_t$ is represented by an ensemble of particles $\{\theta^j\}_{j=1}^{J}$ through its empirical (Dirac delta) density, and the drift direction is projected into the reproducing kernel Hilbert space (RKHS) with kernel $\kappa$ by integrating it against $\kappa$:

$$
\begin{align*}
\widetilde{F}(\theta) &= \int F(\theta')\kappa(\theta, \theta') d\theta' \\
          &= \int \bigl(-\rho_t(\theta')\nabla_{\theta'} \Phi_R(\theta';y) - \nabla_{\theta'} \rho_t(\theta')\bigr)  \kappa(\theta, \theta')d\theta'\\
          &= \int \bigl(-\rho_t(\theta')\nabla_{\theta'} \Phi_R(\theta';y)\kappa(\theta, \theta') + \nabla_{\theta'}\kappa(\theta, \theta')  \rho_t(\theta')\bigr)  d\theta'\\
          &\approx  \frac{1}{J}\sum_{j=1}^{J}\Bigl(-\nabla_{\theta^j} \Phi_R(\theta^j;y)\kappa(\theta, \theta^j) + \nabla_{\theta^j}\kappa(\theta, \theta^j)\Bigr)
\end{align*}
$$

where the third line integrates by parts in $\theta'$ and the last line uses $\rho_t = \frac{1}{J}\sum_{j=1}^{J} \delta(\theta - \theta^j)$. The kernel-smoothed drift $\widetilde{F}$ coincides with the Stein variational gradient descent drift derived above.
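Putting the particle formula to work, the following is a minimal SVGD sketch that moves each particle with forward Euler, $\theta^i \leftarrow \theta^i + \Delta t\, F(\theta^i)$. Several choices are assumptions not fixed by the text: the function name `svgd` and its arguments are hypothetical, the kernel is an RBF kernel whose bandwidth is reset every step by a median heuristic similar to the one used in [1], and the step size, step count, and 2D Gaussian target are illustrative.

```python
import numpy as np

def svgd(grad_phi_r, theta0, dt=0.05, n_steps=3000):
    """Forward-Euler SVGD sketch: theta^i <- theta^i + dt * F(theta^i).

    grad_phi_r: callable mapping a (J, d) array of particles to the (J, d)
                array of gradients grad Phi_R(theta^j; y) (user-supplied)
    theta0:     (J, d) array of initial particles, J >= 2
    Assumed kernel: kappa(a, b) = exp(-|a - b|^2 / (2 h^2)), with h^2 reset
    every step by the median heuristic h^2 = med^2 / (2 log(J + 1)).
    """
    theta = np.array(theta0, dtype=float)
    J = theta.shape[0]
    iu = np.triu_indices(J, k=1)                      # distinct particle pairs
    for _ in range(n_steps):
        diff = theta[:, None, :] - theta[None, :, :]  # diff[i, j] = theta^i - theta^j
        sq = np.sum(diff**2, axis=-1)
        h2 = np.median(sq[iu]) / (2.0 * np.log(J + 1))
        K = np.exp(-sq / (2.0 * h2))                  # K[i, j] = kappa(theta^i, theta^j)
        # Attraction: -(1/J) sum_j kappa(theta^i, theta^j) grad Phi_R(theta^j)
        attract = -K @ grad_phi_r(theta) / J
        # Repulsion: (1/J) sum_j grad_{theta^j} kappa(theta^i, theta^j),
        # where grad_{theta^j} kappa = kappa * (theta^i - theta^j) / h^2
        repulse = np.einsum("ij,ijd->id", K, diff) / (J * h2)
        theta = theta + dt * (attract + repulse)
    return theta


# Illustrative target: Gaussian posterior N(m, S), so grad Phi_R = S^{-1}(theta - m).
m = np.array([1.0, -1.0])
S = np.array([[1.0, 0.5], [0.5, 2.0]])
S_inv = np.linalg.inv(S)
rng = np.random.default_rng(0)
theta0 = rng.standard_normal((100, 2))                # particles start from N(0, I)
particles = svgd(lambda th: (th - m) @ S_inv.T, theta0)
print(particles.mean(axis=0))                         # should be close to m
print(np.cov(particles, rowvar=False))                # should roughly resemble S
```

The repulsive term sums to zero over all particles (the kernel is symmetric and `diff` is antisymmetric), so only the attractive term moves the ensemble mean, while the repulsion keeps the particles spread out. A fixed Euler step keeps the sketch simple; it is not tuned.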

## Stein variational Newton [2]

TODO

# References
1. [Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm](https://proceedings.neurips.cc/paper/2016/file/b3ba8f1bee1238a2f37603d90b58898d-Paper.pdf)
2. [A Stein Variational Newton method](https://proceedings.neurips.cc/paper/2018/file/fdaa09fc5ed18d3226b3a1a00f1bc48c-Paper.pdf)