- Gaussian Mixture Model: SVGD v GSVGD vs HMC vs IAF flow
- Bayesian logistic regression: HMC v IAF Flow v
- Bayesian neural network
- Variational GP regression 

In [1]:
%load_ext autoreload
%autoreload 2

import math
import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.optim as optim
from torch.distributions import MultivariateNormal
from pymanopt.manifolds import Grassmann
import altair as alt
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import sys
sys.path.append("..")

alt.data_transformers.enable('default', max_rows=None)

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

In [2]:
def get_density_chart(P, d=7.0, step=0.1):
    xv, yv = torch.meshgrid([
      torch.arange(-d, d, step), 
      torch.arange(-d, d, step)
    ])
    
    pos_xy = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), dim=-1)
    p_xy = P.log_prob(pos_xy.to(device)).exp().unsqueeze(-1).cpu()
    print(p_xy.shape, pos_xy.shape)
    df = torch.cat([pos_xy, p_xy], dim=-1).numpy()
    df = pd.DataFrame({
      'x': df[:, :, 0].ravel(),
      'y': df[:, :, 1].ravel(),
      'p': df[:, :, 2].ravel(),
    })

    chart = alt.Chart(df).mark_point().encode(
    x='x:Q',
    y='y:Q',
    color=alt.Color('p:Q', scale=alt.Scale(scheme='viridis')),
    tooltip=['x','y','p']
    )

    return chart


def get_particles_chart(X):
    df = pd.DataFrame({
      'x': X[:, 0],
      'y': X[:, 1],
    })

    chart = alt.Chart(df).mark_circle(color='red').encode(
    x='x:Q',
    y='y:Q'
    )

    return chart

In [3]:
device

device(type='cpu')

In [4]:
from src.svgd import SVGD
from src.kernel import RBF

## Gaussian Mixture Model

$$q(x) = \frac{1}{3}\mathcal{N}(x;(-2, 0)^\intercal, I) + \frac{2}{3}\mathcal{N}(x;(2, 0)^\intercal, I)$$

Calculate $\mathbb{E}_q[X]$ using

HMC: 2 component mixture

In [5]:
from experiments.utils import GaussianMixture

In [6]:
repetitions = 1

In [7]:
mean = torch.Tensor([0,0]).to(device)
covariance = torch.Tensor([[1., 0.999], [0.999, 1.]]).to(device)

In [8]:
mean.shape, covariance.shape

(torch.Size([2]), torch.Size([2, 2]))

$P_0X$

In [9]:
distribution = torch.distributions.MultivariateNormal(mean, covariance)

In [10]:
## SVGD
pred_SVGD = []
for r in range(repetitions):
    
    num_particles = 100
    lr = 1e-2
    epochs = 500
    
    # sample from variational density
    x_init = torch.randn(num_particles, *distribution.event_shape)
    x = x_init.clone().to(device)
    kernel = RBF()
    svgd = SVGD(distribution, kernel, optim.Adam([x], lr=lr))
    for _ in tqdm(range(epochs)):
        svgd.step(x)
        
    pred_SVGD.append(x.mean(axis=0))

  0%|          | 0/500 [00:00<?, ?it/s]

In [11]:
fig = get_density_chart(distribution, d=5.0, step=0.1)

(fig + get_particles_chart(x_init.cpu().numpy())) | (fig + get_particles_chart(x.cpu().numpy()))

torch.Size([100, 100, 1]) torch.Size([100, 100, 2])


In [14]:
import sys
from experiments.utils import GaussianMixture
from src.kernel import RBF
from src.gsvgd import GSVGD

# P = torch.Tensor([[1], [1]]).requires_grad_(True)
# x = torch.Tensor([[1, 0], [0, 1], [1, 2]]).requires_grad_(True)
# gauss = torch.distributions.MultivariateNormal(
#     loc = torch.Tensor([-0.6871,0.8010]),
#     covariance_matrix = 5 * torch.Tensor([[0.2260,0.1652],[0.1652,0.6779]])
# )

## Setup target parameters
num_particles = 100

k = RBF()
manifold = Grassmann(2, 1)

x_init = torch.randn(num_particles, *distribution.event_shape)
x = x_init.clone()
A = torch.Tensor([[1], [0]]).requires_grad_(True)
# A = torch.Tensor([[1, 1], [0, 1]]).requires_grad_(True)

gsvgd = GSVGD(
    target=distribution,
    kernel=k,
    manifold=manifold,
    optimizer=optim.Adam([x], lr=1e-2),
)
A = gsvgd.fit(x, A, epochs=5000)

100%|██████████| 5000/5000 [02:36<00:00, 31.89it/s]


In [13]:
fig = get_density_chart(distribution, d=5.0, step=0.1)

(fig + get_particles_chart(x_init.cpu().numpy())) | (fig + get_particles_chart(x.cpu().numpy()))

torch.Size([100, 100, 1]) torch.Size([100, 100, 2])


In [22]:
x.shape

torch.Size([100, 2])