**Vanilla Stein**

Testing vanilla Stein variational gradient descent (SVGD) in high dimensions.

Low-dimensional testing, adapted from https://colab.research.google.com/github/activatedgeek/stein-gradient/blob/master/Stein.ipynb#scrollTo=VE8ANKLgy1PH

In [1]:
import math
import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.optim as optim
import altair as alt

# Select Altair's default data transformer with max_rows=None — presumably to
# lift the row limit so the dense density grid can be plotted; confirm against
# the Altair documentation.
alt.data_transformers.enable('default', max_rows=None)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
def get_density_chart(P, d=7.0, step=0.1):
    """Build an Altair chart of the density of ``P`` on a square 2-D grid.

    Args:
        P: Object exposing ``log_prob``. It is called with a tensor of shape
            ``(n, n, 2)`` holding the grid coordinates, and the result is
            exponentiated to obtain the density.
        d: Half-width of the grid; each axis is ``torch.arange(-d, d, step)``.
        step: Spacing between neighbouring grid points along each axis.

    Returns:
        An ``alt.Chart`` with one point per grid node, positioned by ``x`` and
        ``y`` and coloured by the density ``p`` using the viridis scheme.
    """
    axis = torch.arange(-d, d, step)
    # Ask for 'ij' indexing explicitly. It is what the call used implicitly
    # before, and leaving the argument out makes recent PyTorch releases emit
    # a UserWarning about the default changing in the future.
    xv, yv = torch.meshgrid(axis, axis, indexing='ij')
    pos_xy = torch.stack((xv, yv), dim=-1)

    # Evaluate on `device`, then move back to the CPU for pandas/Altair.
    p_xy = P.log_prob(pos_xy.to(device)).exp().unsqueeze(-1).cpu()

    grid = torch.cat([pos_xy, p_xy], dim=-1).numpy()
    df = pd.DataFrame({
        'x': grid[:, :, 0].ravel(),
        'y': grid[:, :, 1].ravel(),
        'p': grid[:, :, 2].ravel(),
    })

    chart = alt.Chart(df).mark_point().encode(
        x='x:Q',
        y='y:Q',
        color=alt.Color('p:Q', scale=alt.Scale(scheme='viridis')),
        tooltip=['x', 'y', 'p'],
    )

    return chart


def get_particles_chart(X):
    """Return an Altair scatter chart of 2-D particles drawn as red circles.

    ``X`` is indexed as ``X[:, 0]`` and ``X[:, 1]``: the first column becomes
    the x coordinate and the second column the y coordinate.
    """
    coords = pd.DataFrame({'x': X[:, 0], 'y': X[:, 1]})
    base = alt.Chart(coords)
    return base.mark_circle(color='red').encode(x='x:Q', y='y:Q')

In [5]:
class RBF(torch.nn.Module):
    """Gaussian (RBF) kernel ``exp(-gamma * ||x - y||^2)``.

    With ``sigma=None`` the bandwidth is re-estimated on every call using the
    median heuristic; otherwise the supplied ``sigma`` is used as-is.
    """

    def __init__(self, sigma=None):
        super().__init__()
        # Fixed bandwidth, or None to derive it from the data in forward().
        self.sigma = sigma

    @staticmethod
    def _median_sigma(sq_dists, num_points):
        """Bandwidth from the median of the pairwise squared distances."""
        # torch.median returns the lower of the two middle values for an even
        # element count, so the true median is taken with NumPy instead.
        median = np.median(sq_dists.detach().cpu().numpy())
        h = median / (2 * np.log(num_points + 1))
        return np.sqrt(h).item()

    def forward(self, X, Y):
        """Return the kernel matrix between the rows of ``X`` and ``Y``.

        Args:
            X: 2-D tensor with one point per row.
            Y: 2-D tensor with one point per row and the same number of
                columns as ``X``.

        Returns:
            Tensor of shape ``(X.size(0), Y.size(0))`` of kernel values.
        """
        gram_xx = X @ X.t()
        cross = X @ Y.t()
        gram_yy = Y @ Y.t()

        # ||x||^2 - 2 x.y + ||y||^2, with the squared norms read off the
        # diagonals of the two Gram matrices.
        sq_dists = -2 * cross + gram_xx.diag()[:, None] + gram_yy.diag()[None, :]

        sigma = self.sigma
        if sigma is None:
            sigma = self._median_sigma(sq_dists, X.size(0))

        # The small constant keeps the division finite when sigma is 0.
        gamma = 1.0 / (1e-8 + 2 * sigma ** 2)
        return torch.exp(-gamma * sq_dists)


KKK = RBF()

In [6]:
KKK

RBF()

In [7]:
class SVGD:
    """Stein variational gradient descent over a set of particles.

    Attributes are stored exactly as passed in:
        P: target exposing ``log_prob``.
        K: kernel callable taking two particle tensors.
        optim: optimizer whose ``zero_grad``/``step`` methods drive updates.
    """

    def __init__(self, P, K, optimizer):
        self.P, self.K, self.optim = P, K, optimizer

    def phi(self, X):
        """Return the SVGD update direction for the particles ``X``.

        The result has the same shape as ``X``: the kernel-weighted score plus
        the negated kernel gradient, divided by the number of particles.
        """
        particles = X.detach().requires_grad_(True)

        # Score: gradient of the summed log-density w.r.t. every particle.
        score = autograd.grad(self.P.log_prob(particles).sum(), particles)[0]

        # The second argument is detached so the kernel is differentiated
        # with respect to its first argument only.
        kernel = self.K(particles, particles.detach())
        repulsion = -autograd.grad(kernel.sum(), particles)[0]

        drive = kernel.detach().matmul(score)
        return (drive + repulsion) / particles.size(0)

    def step(self, X):
        """Apply one optimizer update to ``X`` in place."""
        self.optim.zero_grad()
        # The negated direction is stored as the gradient, so an optimizer
        # that descends along X.grad moves the particles along phi.
        X.grad = -self.phi(X)
        self.optim.step()

In [8]:
# Target distribution: a correlated 2-D Gaussian placed on `device`.
gauss = torch.distributions.MultivariateNormal(
    loc=torch.Tensor([-0.6871, 0.8010]).to(device),
    covariance_matrix=5 * torch.Tensor([
        [0.2260, 0.1652],
        [0.1652, 0.6779],
    ]).to(device),
)

# Initial particles: n standard-normal draws scaled by 3, one row per particle.
n = 10
X_init = torch.randn(n, *gauss.event_shape).mul(3).to(device)

In [9]:
# Evaluate the target density once on a grid over [-7, 7) with spacing 0.1;
# the resulting chart is reused as the background of the particle plots.
gauss_chart = get_density_chart(gauss, d=7.0, step=0.1)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [10]:
# Overlay the initial particles (red circles) on the target density.
gauss_chart + get_particles_chart(X_init.cpu().numpy())

  for col_name, dtype in df.dtypes.iteritems():


In [11]:
# Work on a copy so X_init stays untouched for the "before" plot.
X = X_init.clone()
# Pass the RBF kernel instance `KKK` created earlier in this notebook: the
# name `K` used here previously is never defined, so a fresh top-to-bottom
# run stopped with a NameError.
svgd = SVGD(gauss, KKK, optim.Adam([X], lr=1e-1))
for _ in range(1000):
    svgd.step(X)

In [12]:
# Overlay the particles after the SVGD updates (red circles) on the target density.
gauss_chart + get_particles_chart(X.cpu().numpy())

  for col_name, dtype in df.dtypes.iteritems():


In [13]:
class MoG(torch.distributions.Distribution):
    """Equally weighted mixture of multivariate Gaussians.

    Args:
        loc: Tensor of component means, iterated along dim 0 (one mean per
            component); its last dimension is the event size.
        covariance_matrix: Component covariance matrices, iterated along
            dim 0 in step with ``loc``.
    """

    def __init__(self, loc, covariance_matrix):
        self.num_components = loc.size(0)
        self.loc = loc
        self.covariance_matrix = covariance_matrix

        self.dists = [
            torch.distributions.MultivariateNormal(mu, covariance_matrix=sigma)
            for mu, sigma in zip(loc, covariance_matrix)
        ]

        # Called last on purpose: the base constructor may validate the
        # parameters named in `arg_constraints`, which reads the attributes
        # (and `self.dists`) assigned above.
        super().__init__(torch.Size([]), torch.Size([loc.size(-1)]))

    @property
    def arg_constraints(self):
        # Advertise only the parameters this class actually stores. The
        # component distribution also lists `precision_matrix` and
        # `scale_tril`, which MoG does not have, so forwarding its full dict
        # makes argument validation in the base constructor fail on the
        # missing attributes.
        component = self.dists[0].arg_constraints
        return {name: component[name] for name in ('loc', 'covariance_matrix')}

    @property
    def support(self):
        return self.dists[0].support

    @property
    def has_rsample(self):
        return False

    def log_prob(self, value):
        """Return the log-density of the uniform mixture at ``value``.

        Every component carries weight ``1 / num_components``. Without the
        ``- log(num_components)`` term the result is the log of the *sum* of
        the component densities, which integrates to ``num_components``
        instead of 1. The correction is a constant, so gradients with respect
        to ``value`` are unaffected.
        """
        component_lp = torch.stack(
            [p.log_prob(value) for p in self.dists], dim=-1)
        return component_lp.logsumexp(dim=-1) - math.log(self.num_components)

    def enumerate_support(self, expand=True):
        # Delegates to the first component, as before; `expand` mirrors the
        # base-class signature and is passed straight through.
        return self.dists[0].enumerate_support(expand)

In [None]:
# Modified version of the training-data generator in FFJORD (https://github.com/rtqichen/ffjord/blob/master/lib/toy_data.py),
# extended to accommodate high-dimensional, multimodal multivariate normal data.

import numpy as np
import sklearn
import sklearn.datasets
from sklearn.utils import shuffle as util_shuffle
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Dataset iterator
def inf_train_gen(data, rng=None, batch_size=200, n_dim=20):
    """Draw one batch of synthetic training points.

    Only the ``"8gaussians"`` dataset is handled in this part of the
    function: eight Gaussian blobs whose centres sit on a circle of radius 4
    in the first two coordinates and at 0 in every other coordinate. Each
    point is a centre plus noise of standard deviation 0.5 in all ``n_dim``
    coordinates, and the whole batch is finally divided by 1.414.

    Args:
        data: Dataset name.
        rng: ``np.random.RandomState`` used for all draws; a fresh, unseeded
            one is created when ``None``.
        batch_size: Number of points to generate.
        n_dim: Dimensionality of each point; must be at least 2 for
            ``"8gaussians"``.

    Returns:
        For ``"8gaussians"``, a float32 array of shape
        ``(batch_size, n_dim)``. For any other ``data`` value this part of
        the function falls through without returning a batch.

    Raises:
        ValueError: If ``n_dim`` is smaller than 2 for ``"8gaussians"``.
    """
    if rng is None:
        rng = np.random.RandomState()

    if data == "8gaussians":
        if n_dim < 2:
            # The centres occupy the first two coordinates; fail with a clear
            # message instead of a broadcasting error further down.
            raise ValueError(
                f"n_dim must be at least 2 for '8gaussians', got {n_dim}")

        scale = 4.
        diag = 1. / np.sqrt(2)
        ring = [
            (1, 0), (-1, 0), (0, 1), (0, -1),
            (diag, diag), (diag, -diag), (-diag, diag), (-diag, -diag),
        ]
        # Centres live in the first two coordinates; the rest stay at zero.
        centers = np.zeros((len(ring), n_dim))
        centers[:, :2] = [(scale * x, scale * y) for x, y in ring]

        # Preallocating keeps the (batch_size, n_dim) shape even for an empty
        # batch.
        points = np.empty((batch_size, n_dim))
        for row in points:
            # Draw noise first, then the component index, for every point in
            # turn: this order fixes the random stream, so a seeded `rng`
            # keeps producing the same batch.
            noise = rng.randn(n_dim) * 0.5
            idx = rng.randint(8)
            row[:] = noise + centers[idx]

        dataset = points.astype("float32")
        dataset /= 1.414
        return dataset

    # ... (other data generation cases remain unchanged)

def plot_8gaussians_projections(data, num_projections=3):
    """
    Plot projections of the multidimensional 8 Gaussians data.

    Draws ``num_projections`` scatter plots over randomly chosen pairs of
    dimensions, followed by one 3D scatter plot over a randomly chosen triple
    of dimensions, and shows the figure. The dimensions are picked with the
    global ``np.random`` state, so they differ between calls unless that state
    is seeded. ``data`` needs at least three dimensions, because the 3D panel
    samples three distinct axes.

    :param data: numpy array of shape (n_samples, n_dimensions)
    :param num_projections: number of 2D projections to plot
    """
    _, n_dim = data.shape

    n_cols = 3
    # The grid must hold every 2D panel plus the 3D panel. The previous row
    # count only covered the 2D panels, so whenever num_projections was a
    # multiple of 3 (including the default) the 3D panel's index fell outside
    # the grid and add_subplot raised.
    n_panels = num_projections + 1
    n_rows = (n_panels + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(15, 5 * n_rows))

    # 2D projections
    for i in range(num_projections):
        dim1, dim2 = np.random.choice(n_dim, 2, replace=False)
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.scatter(data[:, dim1], data[:, dim2], alpha=0.5)
        ax.set_title(f'2D Projection: Dim {dim1} vs Dim {dim2}')
        ax.set_xlabel(f'Dimension {dim1}')
        ax.set_ylabel(f'Dimension {dim2}')

    # 3D projection, placed right after the last 2D panel
    ax = fig.add_subplot(n_rows, n_cols, n_panels, projection='3d')
    dim1, dim2, dim3 = np.random.choice(n_dim, 3, replace=False)
    ax.scatter(data[:, dim1], data[:, dim2], data[:, dim3], alpha=0.5)
    ax.set_title(f'3D Projection: Dim {dim1} vs Dim {dim2} vs Dim {dim3}')
    ax.set_xlabel(f'Dimension {dim1}')
    ax.set_ylabel(f'Dimension {dim2}')
    ax.set_zlabel(f'Dimension {dim3}')

    plt.tight_layout()
    plt.show()

# Example usage
if __name__ == "__main__":
    # Fixed seed so the sampled batch is the same on every run.
    rng = np.random.RandomState(42)
    data = inf_train_gen(
        "8gaussians",
        rng=rng,
        batch_size=1000,
        n_dim=20,
    )
    plot_8gaussians_projections(data, num_projections=3)

**T and A**:

Proposing a scalable SVGD algorithm using multiple model fidelities:

Variants:

- Randomly sample points for low-fidelity likelihood updates and monitor convergence against the single-fidelity baseline.

- In amortized SVGD, replace neural network training by models of multiple fidelities instead of approximating the likelihood function directly (constraints?)

- Use some metric to intelligently decide which points receive low-fidelity likelihood updates.

- monitor divergence of purely low-fi vs purely hi-fi SVGD

- Subset simulation? Reach an EASIER posterior with low-fidelity simulations first, then switch to high-fidelity ones. This is _multi-level_ SVGD; Peherstorfer et al. probably already do this, so it is not novel. See https://proceedings.mlr.press/v145/alsup22a/alsup22a.pdf
