In [None]:
import os

# Pin JAX to a single GPU. This must execute before JAX initialises its GPU
# backend (i.e. before the `import jax` cell below). `setdefault` respects a
# device selection already exported in the environment instead of overriding it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '7')

In [None]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
import blackjax

# Large fonts and a light grid for every figure in the notebook. `set_theme` is
# the preferred name; seaborn documents `sns.set` as an alias that may be removed.
sns.set_theme(font_scale=2, style='whitegrid')

In [None]:
from sklearn.datasets import make_moons

# Two interleaving half-circles: 100 points in total with Gaussian noise (std 0.1).
# x_all holds the 2-D inputs, y_all the 0/1 class labels. A fixed `random_state`
# makes the dataset, and therefore every downstream figure, reproducible under
# Restart & Run All.
DATA_SEED = 137
x_all, y_all = make_moons(n_samples=100, noise=.1, random_state=DATA_SEED)

In [None]:
fig, ax = plt.subplots(figsize=(7,7))

# One scatter call per class so each class gets its own colour and legend entry.
for cls, label in ((0, 'Class 0'), (1, 'Class 1')):
    in_class = y_all == cls
    ax.scatter(x_all[in_class, 0], x_all[in_class, 1], label=label)

ax.set(xlabel='x', ylabel='y', title='Two-moons training data')
ax.legend()

fig.tight_layout()
# `plt.show()` rather than `fig.show()`: the latter emits a warning under
# Jupyter's non-GUI inline backend.
plt.show()

In [None]:
from flax import linen as nn

class MLP(nn.Module):
    """Fully connected classifier: two ReLU hidden layers followed by a linear head.

    Attributes:
        n_classes: width of the output layer (one unnormalised logit per class).
        H: width of each hidden layer.
    """
    n_classes: int
    H: int = 100

    @nn.compact
    def __call__(self, x):
        """Map `x` (features on the last axis) to per-class logits."""
        # flax names submodules by construction order, so this order must stay
        # fixed to keep parameter names and initialisation stable.
        stack = [
            nn.Dense(features=self.H), nn.relu,
            nn.Dense(features=self.H), nn.relu,
            nn.Dense(features=self.n_classes),
        ]
        return nn.Sequential(stack)(x)

# Binary classifier for the two-moons data.
model = MLP(n_classes=2)

In [None]:
# Base key for the notebook; split off a dedicated key for parameter init.
rng_key = jax.random.PRNGKey(137)
rng_key, init_params_key = jax.random.split(rng_key)
# A dummy 2-feature input lets flax infer the shapes of every layer.
init_params = jax.jit(model.init)(init_params_key, jnp.ones(2))

# Total count of scalar parameters across all leaves of the parameter pytree.
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(init_params))

In [None]:
from functools import partial
import distrax


def logprior_fn(params):
    """Log density of an isotropic standard-normal prior over all weights.

    Args:
        params: flax parameter pytree (same structure as ``init_params``).

    Returns:
        Scalar: the sum over every scalar parameter theta of log N(theta; 0, 1).
    """
    # Flatten the *argument*, not a global: the prior must depend on the sampled
    # position, otherwise it is a constant and contributes no gradient to HMC.
    leaves = jax.tree_util.tree_leaves(params)
    flat_params = jnp.concatenate([jnp.ravel(p) for p in leaves])
    log_prior = distrax.Normal(0., 1.).log_prob(flat_params)
    return jnp.sum(log_prior)

def loglikelihood_fn(params):
    """Log-likelihood of the whole training set under the model.

    Returns the sum over all points of log p(y_i | x_i, params), where the class
    distribution is a Categorical parameterised by the MLP's logits.
    """
    logits = model.apply(params, x_all)
    log_lik = distrax.Categorical(logits=logits).log_prob(y_all)
    return jnp.sum(log_lik)

def logprob_fn(params):
    """Unnormalised log posterior: log-likelihood plus log-prior."""
    return loglikelihood_fn(params) + logprior_fn(params)

In [None]:
# Fixed (non-adapted) HMC hyperparameters.
step_size = 1e-4
num_integration_steps = 100
# Identity metric passed as its diagonal. Blackjax treats a 1-D inverse mass
# matrix as diagonal, so this equals `jnp.eye(n_params)` without materialising
# a dense n_params x n_params matrix (~10k x 10k here).
inverse_mass_matrix = jnp.ones(n_params)

hmc = blackjax.hmc(logprob_fn, step_size, inverse_mass_matrix, num_integration_steps)
hmc_state = hmc.init(init_params)
# Compile the transition kernel once; each call advances the chain by one step.
hmc_kernel = jax.jit(hmc.step)

In [None]:
# 100 x 100 evaluation grid for the decision-surface plot, covering the data's
# bounding box padded by 1 unit on every side.
pad = 1.
lower = x_all.min(axis=0) - pad
upper = x_all.max(axis=0) + pad
_x = np.linspace(lower[0], upper[0], 100)
_y = np.linspace(lower[1], upper[1], 100)
# meshgrid's default 'xy' indexing gives grid[i, j] == (_x[j], _y[i]).
grid = np.stack(np.meshgrid(_x, _y), axis=-1)
grid.shape

In [None]:
burn_in = 1000
n_samples = 500
# Derive the sampler key with `fold_in` instead of re-creating PRNGKey(137):
# splitting that same key again would make the first HMC key identical to
# `init_params_key`, reusing randomness already consumed by initialisation.
rng_key = jax.random.fold_in(jax.random.PRNGKey(137), 1)

# Compile the forward pass once rather than dispatching it op-by-op per draw.
predict_fn = jax.jit(model.apply)

# NOTE: the chain continues from the current `hmc_state`; re-running this cell
# without re-running the HMC setup cell extends the existing chain.
sample_logits = []
for step in tqdm(range(burn_in + n_samples)):
    rng_key, sample_rng_key = jax.random.split(rng_key)
    hmc_state, info = hmc_kernel(sample_rng_key, hmc_state)
    if step >= burn_in:
        # Logits over the evaluation grid for this posterior draw.
        sample_logits.append(np.asarray(predict_fn(hmc_state.position, grid)))

# Shape: grid.shape[:-1] + (n_classes, n_samples); draws are on the last axis.
sample_logits = np.stack(sample_logits, axis=-1)

In [None]:
# Posterior predictive P(class 1 | x). Apply softmax to each draw's logits over
# the class axis (-2), then average the probabilities over draws (last axis).
# Averaging raw logits would not yield a probability, so a [0, 1] colour scale
# would be meaningless.
sample_probs = np.asarray(jax.nn.softmax(sample_logits, axis=-2))
p_class1 = sample_probs.mean(axis=-1)[..., 1]

fig, ax = plt.subplots(figsize=(7,7))

contour = ax.contourf(grid[..., 0], grid[..., 1], p_class1,
                      levels=np.linspace(0., 1., 11), vmin=0., vmax=1.)
fig.colorbar(contour, ax=ax, label='P(class 1)')

ax.scatter(x_all[y_all == 0, 0], x_all[y_all == 0, 1], label='Class 0')
ax.scatter(x_all[y_all == 1, 0], x_all[y_all == 1, 1], label='Class 1')

ax.set(xlabel='x', ylabel='y', title='HMC posterior predictive')
ax.legend()

fig.tight_layout()
# `plt.show()` rather than `fig.show()`, which warns under the inline backend.
plt.show()