# Kernelized Stein Gradient

> Notebook accompanying the original post [here](https://www.sanyamkapoor.com/machine-learning/stein-gradient).

## Install Dependencies

We will use PyTorch for all our differentiation needs and `plotly` for plotting.

In [0]:
# Quote the specifiers: an unquoted ">" is parsed by the shell as output redirection.
! pip install "plotly>=3.6.0" "numpy>=1.16" "torch>=1.0"

In [0]:
import plotly.graph_objs as go
import plotly.offline as py
from plotly.colors import DEFAULT_PLOTLY_COLORS

import math
import numpy as np
import torch

# Make plotly.js available to offline figures rendered in this notebook.
py.init_notebook_mode(connected=True)

# Prefer the GPU when CUDA is available; tensors below are moved here via .to(device).
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Drawing Utilities

In [0]:
def configure_plotly_browser_state():
    """Load require.js and plotly.js into the current output cell.

    @NOTE: Run this in each cell before plotting in Google Colab
    """
    # Import explicitly: `display` is only injected as a builtin inside an
    # IPython kernel, so relying on the implicit name is fragile.
    from IPython.display import HTML, display

    display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              plotly: 'https://cdn.plot.ly/plotly-1.5.1.min.js?noext',
            },
          });
        </script>
        '''))
    py.init_notebook_mode(connected=True)

  
def get_pos_xy(d=10.0, step=0.1):
    """Return a square evaluation grid covering ``[-d, d)`` on both axes.

    Returns:
        xv, yv: coordinate grids produced by ``torch.meshgrid``.
        pos_xy: the same grid with both coordinates stacked on the last
            axis, so ``pos_xy[..., 0] == xv`` and ``pos_xy[..., 1] == yv``.
    """
    xs = torch.arange(-d, d, step)
    ys = torch.arange(-d, d, step)
    xv, yv = torch.meshgrid([xs, ys])
    return xv, yv, torch.stack((xv, yv), dim=-1)


def get_scatter_trace(X, color, name='', mode='markers'):
    """Build a ``go.Scatter`` trace from the first two columns of ``X``.

    Args:
        X: array-like indexable as ``X[:, 0]`` and ``X[:, 1]``.
        color: marker colour.
        name: legend name of the trace.
        mode: plotly drawing mode, e.g. ``'markers'``.
    """
    marker_style = dict(size=5, color=color, showscale=False)
    line_style = dict(width=1, color='rgba(255,0,0,0.5)')
    xs, ys = X[:, 0], X[:, 1]
    return go.Scatter(name=name, x=xs, y=ys, mode=mode,
                      line=line_style, marker=marker_style)


def plot_particles(P, X, name):
    """Overlay particles on a heatmap of the density of ``P``.

    Args:
        P: distribution exposing ``log_prob``; it is evaluated on the
            ``get_pos_xy()`` grid after the grid is moved to the global
            ``device``.
        X: particle positions; columns 0 and 1 are plotted.
        name: trace name used for the density heatmap.

    Returns:
        A ``go.Figure`` holding the heatmap and the particle scatter trace.
    """
    xv, yv, pos_xy = get_pos_xy()

    # Plotting never needs autograd. Without this, ``.numpy()`` raises when the
    # distribution's parameters or the particles carry gradient history
    # (``X.cpu()`` is a no-op for CPU tensors, hence the explicit detach).
    with torch.no_grad():
        density = P.log_prob(pos_xy.to(device)).exp().cpu().numpy().flatten()
        particles = X.detach().cpu().numpy()

    orig_trace = go.Heatmap(
        name=name,
        x=xv.numpy().flatten(),
        y=yv.numpy().flatten(),
        z=density,
        colorscale='Viridis',
        showscale=False,
    )

    layout = go.Layout(
        autosize=False,
        width=1000,
        height=700
    )

    fig = go.Figure(layout=layout,
                    data=[orig_trace, get_scatter_trace(particles, 'red', name='Stein Particles')])
    return fig

## RBF Kernel

In these experiments, we use the *RBF* (radial basis function) kernel: the exponential of the negative squared Euclidean distance between two vectors, scaled by a bandwidth parameter $\sigma$.

$$
k_{rbf}(\mathbf{x}, \mathbf{x}^\prime) = \exp\left(-\frac{1}{2\sigma^2}\lVert\mathbf{x}-\mathbf{x}^\prime\rVert^2\right)
$$

A vectorized version of the kernel is given below. A few notes on the implementation follow.

In [0]:
class RBF(torch.nn.Module):
    """Gaussian (RBF) kernel ``k(x, y) = exp(-||x - y||^2 / (2 * sigma^2))``.

    Args:
        sigma: fixed bandwidth. When ``None`` (the default) the bandwidth is
            re-estimated on every call with the median heuristic.
    """

    def __init__(self, sigma=None):
        super(RBF, self).__init__()

        self.sigma = sigma

    def forward(self, X, Y):
        """Return the kernel matrix between the rows of ``X`` and ``Y``.

        Args:
            X: 2-D tensor with one point per row (``n`` rows).
            Y: 2-D tensor with the same number of columns (``m`` rows).

        Returns:
            ``(n, m)`` tensor whose ``[i, j]`` entry is ``k(X[i], Y[j])``.
            Gradients reach ``X``/``Y`` only through the squared distances;
            the median-heuristic bandwidth is computed outside autograd.
        """
        XX = X.matmul(X.t())
        XY = X.matmul(Y.t())
        YY = Y.matmul(Y.t())

        # ||x||^2 - 2 x.y + ||y||^2. Expanding the square can cancel to a tiny
        # negative number for (nearly) coincident points, which would push the
        # kernel above 1 (and could make the median negative, i.e. a NaN
        # bandwidth). True squared distances are never negative, so clamp.
        dnorm2 = (-2 * XY + XX.diag().unsqueeze(1) + YY.diag().unsqueeze(0)).clamp(min=0)

        # Median heuristic. numpy is used because torch.median returns the
        # lower middle element for an even count instead of the mean.
        if self.sigma is None:
            np_dnorm2 = dnorm2.detach().cpu().numpy()
            # sigma^2 = median(squared distances) / (2 log(n + 1)); the +1
            # keeps the logarithm positive when n == 1.
            h = np.median(np_dnorm2) / (2 * np.log(X.size(0) + 1))
            sigma = np.sqrt(h).item()
        else:
            sigma = self.sigma

        # 1e-8 guards against division by zero when sigma == 0.
        gamma = 1.0 / (1e-8 + 2 * sigma ** 2)
        K_XY = (-gamma * dnorm2).exp()

        return K_XY


# Let us initialize a reusable instance right away.
K = RBF()

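Selecting the bandwidth parameter $\sigma$ can be a painful task in itself. A popular heuristic in the literature is the *median* heuristic, where we choose the bandwidth to be

$$
\sigma^2 = \frac{\mathrm{med}^2}{2 \log{(n + 1)}}
$$

where $\mathrm{med}$ is the median of the distances between all pairs of particles and $n$ is the number of particles. In code, we take the median of the squared distances directly. The $+1$ keeps the denominator nonzero when $n = 1$.

With this choice, a pair of particles at the median distance has kernel value $\exp(-\log(n+1)) = 1/(n+1)$. The kernel sum $\sum_j k(x_i, x_j)$ is therefore roughly $n/(n+1) \approx 1$, and all pairs contribute to the kernel gradient during simulation of the ODE.

Note that we use the `numpy` median function rather than PyTorch's. When the number of elements is even, the PyTorch median function returns the lower of the two central elements rather than their mean.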

## Stein Variational Gradient Descent

We now simulate the following ODE for each particle $x_i$ in the system.

$$
\dot{x}_i = \frac{1}{n} \sum_{j = 1}^n \left[ k(x_j, x_i) \nabla_{x_j} \log{p(x_j)} + \nabla_{x_j} k(x_j, x_i)  \right]
$$

For stability reasons, we use an Adagrad-style adaptive step size during the simulation. Each update is divided by the square root of an exponentially decaying average of its past squared values, with decay rate `rho`. This can be replaced by any other adaptive step-size technique from gradient descent, such as Adam. For our purposes, it works well enough, as we will see in the results below.

This is encapsulated in the `step` function below.

In [0]:
class SVGD:
    """Stein Variational Gradient Descent with an adaptive per-coordinate step.

    Args:
        P: target distribution; only ``P.log_prob`` is used, and it must be
            differentiable with respect to its input.
        K: kernel callable; ``K(X, Y)`` must return the ``(len(X), len(Y))``
            kernel matrix and be differentiable with respect to ``X``.
        eta: step size.
        rho: decay factor of the running average of squared updates.
    """

    def __init__(self, P, K, eta=1e-2, rho=0.9):
        self.P = P
        self.K = K
        self.eta = eta
        self.rho = rho

        self._phi_est = None
        self.reset()

    def reset(self):
        """Forget the running average of squared updates."""
        self._phi_est = None

    def step(self, X):
        """Move the particles ``X`` one SVGD step towards ``P``.

        Args:
            X: 2-D tensor with one particle per row.

        Returns:
            The updated particles as a new tensor that does not require grad.
        """
        X = X.detach().requires_grad_(True)

        # Score function grad_x log p(x), evaluated at every particle.
        log_prob = self.P.log_prob(X)
        score_func = torch.autograd.grad(log_prob, X,
                                         grad_outputs=torch.ones_like(log_prob))[0]

        # Kernel matrix with the second argument held fixed, so grad_K[i] is
        # -sum_j d k(x_i, x_j) / d x_i. For a translation-invariant kernel
        # such as RBF this equals sum_j grad_{x_j} k(x_j, x_i), the
        # repulsive term of SVGD.
        K_XX = self.K(X, X.detach())
        grad_K = -torch.autograd.grad(K_XX, X,
                                      grad_outputs=torch.ones_like(K_XX))[0]

        # phi[i] = (sum_j K[i, j] * score[j] + grad_K[i]) / n
        phi = (K_XX.detach().matmul(score_func) + grad_K) / X.size(0)

        # The running average is kept per particle and per coordinate. If the
        # particle set changed shape, start over instead of failing to
        # broadcast against the stale estimate.
        if self._phi_est is not None and self._phi_est.shape != phi.shape:
            self.reset()

        if self._phi_est is None:
            self._phi_est = phi ** 2
        else:
            self._phi_est = self.rho * self._phi_est + (1 - self.rho) * phi ** 2

        # Divide by the root of the decayed average of squared updates.
        grad = phi / (1e-8 + self._phi_est.sqrt())

        return X.detach() + self.eta * grad

# Experiments

## Unimodal Gaussian

We will first run this on a Unimodal Gaussian. We initialize the particles in an overdispersed manner and see how they converge around the typical set of the distribution.

**NOTE**: Try increasing the number of particles $n$ and different initializations to see how the particles distribute themselves.

In [0]:
# Target: a correlated 2-D Gaussian whose covariance is scaled up 5x.
gauss = torch.distributions.MultivariateNormal(
    torch.Tensor([-0.6871, 0.8010]).to(device),
    covariance_matrix=5 * torch.Tensor([[0.2260, 0.1652],
                                        [0.1652, 0.6779]]).to(device),
)

svgd = SVGD(gauss, K, eta=1e-2, rho=0.9)

# Overdispersed start: n particles drawn from N(0, 25 I).
n = 100
X = (5 * torch.randn(n, *gauss.event_shape)).to(device)

Let us see what this overdispersed initialization looks like. Note that initializations much farther away from the typical set of the distribution may take longer to converge.

In [0]:
# Particles before any SVGD updates.
configure_plotly_browser_state()
fig = plot_particles(gauss, X, name='Normal')
py.iplot(fig, config={'showLink': False})

In [0]:
# Simulate the particle ODE: 1000 SVGD updates, each returning new particles.
for _ in range(1000):
    X = svgd.step(X)

In [0]:
# Particles after the SVGD run.
configure_plotly_browser_state()
fig = plot_particles(gauss, X, name='Normal')
py.iplot(fig, config={'showLink': False})

## Mixture of Gaussians

The exact same simulation, without any manual fine-tuning, also works for a multimodal distribution. We will first create a generic PyTorch distribution that lets us build several kinds of mixtures of Gaussians.

In [0]:
class MoG(torch.distributions.Distribution):
    """Equally weighted mixture of multivariate Gaussians.

    Args:
        loc: component means, one row per component.
        covariance_matrix: component covariance matrices, indexed like ``loc``.
    """

    def __init__(self, loc, covariance_matrix):
        self.num_components = loc.size(0)
        self.loc = loc
        self.covariance_matrix = covariance_matrix

        self.dists = [
            torch.distributions.MultivariateNormal(mu, covariance_matrix=sigma)
            for mu, sigma in zip(loc, covariance_matrix)
        ]

        super(MoG, self).__init__(torch.Size([]), torch.Size([loc.size(-1)]))

    @property
    def arg_constraints(self):
        # Report constraints only for the parameters this class stores.
        # MultivariateNormal also lists ``precision_matrix``/``scale_tril``.
        # With argument validation enabled, Distribution.__init__ reads every
        # listed parameter off ``self``, which fails for those two.
        return {
            param: constraint
            for param, constraint in self.dists[0].arg_constraints.items()
            if param in ('loc', 'covariance_matrix')
        }

    @property
    def support(self):
        return self.dists[0].support

    @property
    def has_rsample(self):
        return False

    def log_prob(self, value):
        """Log density ``log((1/K) * sum_k N(value | loc[k], cov[k]))``."""
        component_log_probs = torch.stack(
            [p.log_prob(value) for p in self.dists], dim=-1)
        # Subtract log K: each of the K components has weight 1/K.
        return component_log_probs.logsumexp(dim=-1) - math.log(self.num_components)

    def enumerate_support(self):
        return self.dists[0].enumerate_support()

### Mixture of Two Gaussians

Here we create a mixture of two Gaussians whose means are placed symmetrically at $(\pm 5, 0)$ and whose covariance matrix is $\begin{pmatrix}0.5 & 0 \\ 0 & 0.5\end{pmatrix}$.

In [0]:
class MoG2(MoG):
    """Two Gaussians with means (-5, 0) and (5, 0) and covariance diag(0.5, 0.5)."""

    def __init__(self, device=None):
        means = torch.Tensor([[-5.0, 0.0], [5.0, 0.0]]).to(device)
        # The same diagonal covariance, stacked once per component.
        covs = torch.diag(torch.Tensor([0.5, 0.5]))[None].repeat(2, 1, 1).to(device)

        super(MoG2, self).__init__(means, covs)

In [0]:
mog2 = MoG2(device=device)
svgd = SVGD(mog2, K, eta=1e-2, rho=0.9)

# Same overdispersed start as before: n particles drawn from N(0, 25 I).
n = 100
X = (5 * torch.randn(n, *mog2.event_shape)).to(device)

In [0]:
# Particles before any SVGD updates.
configure_plotly_browser_state()
fig = plot_particles(mog2, X, name='Mixture of Two Gaussians')
py.iplot(fig, config={'showLink': False})

In [0]:
# Simulate the particle ODE: 1000 SVGD updates, each returning new particles.
for _ in range(1000):
    X = svgd.step(X)

In [0]:
# Particles after the SVGD run.
configure_plotly_browser_state()
fig = plot_particles(mog2, X, name='Mixture of Two Gaussians')
py.iplot(fig, config={'showLink': False})

### Mixture of Six Gaussians

Here we create a mixture of six Gaussians whose means are spaced evenly (every $60^\circ$) around a circle of radius $5$ centered at the origin, each with covariance matrix $\begin{pmatrix}0.5 & 0 \\ 0 & 0.5\end{pmatrix}$.

In [0]:
class MoG6(MoG):
    """Six Gaussians with means on a circle of radius 5, 60 degrees apart."""

    def __init__(self, device=None):
        rows = []
        for i in range(1, 7):
            # Mean i sits at angle i * pi / 3: (5 sin(theta), 5 cos(theta)).
            theta = torch.tensor(i * math.pi / 3.0)
            rows.append(5.0 * torch.Tensor([[theta.sin(), theta.cos()]]))

        loc = torch.cat(rows, dim=0).to(device)
        cov = torch.diag(torch.Tensor([0.5, 0.5]))[None].to(device).repeat(6, 1, 1)

        super(MoG6, self).__init__(loc, cov)

In [0]:
mog6 = MoG6(device=device)
svgd = SVGD(mog6, K, eta=1e-2, rho=0.9)

# Same overdispersed start as before: n particles drawn from N(0, 25 I).
n = 100
X = (5 * torch.randn(n, *mog6.event_shape)).to(device)

In [0]:
# Particles before any SVGD updates.
configure_plotly_browser_state()
fig = plot_particles(mog6, X, name='Mixture of Six Gaussians')
py.iplot(fig, config={'showLink': False})

In [0]:
# Simulate the particle ODE: 1000 SVGD updates, each returning new particles.
for _ in range(1000):
    X = svgd.step(X)

In [0]:
# Particles after the SVGD run.
configure_plotly_browser_state()
fig = plot_particles(mog6, X, name='Mixture of Six Gaussians')
py.iplot(fig, config={'showLink': False})