In [None]:
import os

# Pin this notebook to a single GPU. This has to happen before JAX touches the
# GPU, which is why it lives in the very first cell. `setdefault` lets a value
# exported by the launching shell (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter lab`)
# take precedence instead of being silently overwritten by the default '7'.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '7')

In [None]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
import blackjax

# Global plot styling for every figure below: white grid, enlarged fonts.
sns.set_theme(style='whitegrid', font_scale=2)

In [None]:
_raw_data = pd.read_csv('snelson.csv')

# Column vectors of shape (N, 1): the model is initialized with a
# (batch, 1) input, so features live on a trailing axis.
x_all = _raw_data['x'].to_numpy().reshape(-1, 1)
y_all = _raw_data['y'].to_numpy().reshape(-1, 1)

# z-score each column: zero mean and unit standard deviation.
x_all = (x_all - x_all.mean(axis=0, keepdims=True)) / x_all.std(axis=0, keepdims=True)
y_all = (y_all - y_all.mean(axis=0, keepdims=True)) / y_all.std(axis=0, keepdims=True)
x_all.shape, y_all.shape

In [None]:
# Standardized training data, before any modelling.
fig, ax = plt.subplots(figsize=(9, 6))

ax.scatter(x_all, y_all)

ax.set(xlabel='x', ylabel='y', title='Standardized training data')

fig.tight_layout()
# plt.show() is the notebook-friendly way to render; Figure.show() targets GUI
# backends and can emit a "non-GUI backend" warning under the inline backend.
plt.show()

In [None]:
from flax import linen as nn

class MLP(nn.Module):
    """Fully connected ReLU network.

    Two hidden Dense layers of width ``H``, each followed by ReLU, then a
    linear Dense read-out with ``out_size`` features.

    Attributes:
        out_size: number of output features.
        H: width of each hidden layer.
    """
    out_size: int
    H: int = 100

    @nn.compact
    def __call__(self, x):
        # Collect Dense -> ReLU hidden blocks, then the linear read-out, and
        # apply them in order through a single nn.Sequential.
        layers = []
        for _ in range(2):
            layers.extend([nn.Dense(features=self.H), nn.relu])
        layers.append(nn.Dense(features=self.out_size))
        return nn.Sequential(layers)(x)


model = MLP(out_size=1)

In [None]:
# Root seed for the notebook; carve off a dedicated key for weight init.
rng_key = jax.random.PRNGKey(137)
rng_key, init_params_key = jax.random.split(rng_key)

# Build the parameter pytree by tracing the model on a dummy (1, 1) input.
init_params = jax.jit(model.init)(init_params_key, jnp.ones((1, 1)))

# Total count of scalar parameters over every leaf of the pytree.
n_params = sum(int(np.size(leaf)) for leaf in jax.tree_util.tree_leaves(init_params))

In [None]:
from functools import partial
import distrax


def logprior_fn(params, scale=1.):
    """Log density of an isotropic Gaussian prior N(0, scale^2) on all weights.

    Args:
        params: Flax parameter pytree to evaluate (same structure as the
            model's initial parameters).
        scale: prior standard deviation shared by every parameter.

    Returns:
        Scalar: sum of the per-parameter log densities.
    """
    # Evaluate the prior at the *given* params. Flattening a fixed, global
    # parameter set here would make the prior a constant, and the sampler
    # would then effectively target the likelihood alone.
    leaves = jax.tree_util.tree_leaves(params)
    flat_params = jnp.concatenate([jnp.ravel(p) for p in leaves])
    log_prob = distrax.Normal(0., scale).log_prob(flat_params)
    return jnp.sum(log_prob)


def loglikelihood_fn(params, noise_scale=.5):
    """Gaussian log-likelihood of the training targets under the MLP.

    Models each target as ``y ~ N(model.apply(params, x), noise_scale^2)``
    over the full training set ``(x_all, y_all)``.

    Args:
        params: Flax parameter pytree for ``model``.
        noise_scale: observation-noise standard deviation.

    Returns:
        Scalar: sum of the per-point log-likelihoods.
    """
    f = model.apply(params, x_all)
    log_prob = distrax.Normal(f, noise_scale).log_prob(y_all)
    return jnp.sum(log_prob)


def logprob_fn(params):
    """Unnormalized log posterior (log-likelihood + log-prior); HMC target."""
    return loglikelihood_fn(params) + logprior_fn(params)

In [None]:
# Evaluation grid of 200 points reaching one (standardized) unit past the data
# on each side, shaped (200, 1) to match the model's input layout.
x_test = np.linspace(x_all.min() - 1., x_all.max() + 1., 200)[:, None]
x_test.shape

In [None]:
# HMC settings. The step size is very small; TODO(review): tune it (e.g. with
# blackjax's window adaptation) and check the acceptance rate of the chain.
HMC_STEP_SIZE = 1e-4
HMC_NUM_INTEGRATION_STEPS = 200

# Identity metric passed as a 1-D array: blackjax treats a 1-D inverse mass
# matrix as diagonal. This is the same metric as jnp.eye(n_params) but avoids
# materializing a dense n_params x n_params matrix (~10k x 10k for this MLP)
# and the O(n_params^2) matrix-vector products inside every leapfrog step.
hmc = blackjax.hmc(logprob_fn, HMC_STEP_SIZE, jnp.ones(n_params), HMC_NUM_INTEGRATION_STEPS)
hmc_state = hmc.init(init_params)
hmc_kernel = jax.jit(hmc.step)

In [None]:
burn_in = 200
n_samples = 100

# Derive the sampling key stream with fold_in rather than re-creating
# PRNGKey(137): splitting a fresh PRNGKey(137) reproduces the split used for
# weight initialization, so the first HMC step would reuse init_params_key.
rng_key = jax.random.fold_in(jax.random.PRNGKey(137), 1)

# Restart the chain from the initial parameters so that re-running this cell
# reproduces the same draws instead of silently extending the previous chain.
hmc_state = hmc.init(init_params)

for _ in tqdm(range(burn_in), desc='burn-in'):
    rng_key, sample_rng_key = jax.random.split(rng_key)
    hmc_state, info = hmc_kernel(sample_rng_key, hmc_state)

sample_f = []
accepted = []
for _ in tqdm(range(n_samples), desc='sampling'):
    rng_key, sample_rng_key = jax.random.split(rng_key)
    hmc_state, info = hmc_kernel(sample_rng_key, hmc_state)
    accepted.append(info.is_accepted)
    # Store function values on the test grid, not the full parameter vector.
    sample_f.append(model.apply(hmc_state.position, x_test))

# Shape (len(x_test), 1, n_samples): draws stacked along the last axis.
sample_f = np.stack(sample_f, axis=-1)
# Fraction of post-burn-in proposals accepted; values near 0 or 1 suggest the
# step size needs retuning.
acceptance_rate = float(np.mean(accepted))
sample_f.shape, acceptance_rate

In [None]:
# Posterior predictive summary over the HMC draws: mean function and a
# +/-2 std band. The band reflects only the spread of f across samples; the
# likelihood's observation-noise term is not added.
fig, ax = plt.subplots(figsize=(9, 6))

y_test = np.mean(sample_f, axis=-1)[..., 0]
y_test_std = np.std(sample_f, axis=-1)[..., 0]
ax.plot(x_test[..., 0], y_test, c='red', label='posterior mean')
ax.fill_between(x_test[..., 0], y_test - 2 * y_test_std, y_test + 2 * y_test_std,
                color='red', alpha=.2, label=r'$\pm 2$ std')

ax.scatter(x_all, y_all, label='data')

ax.set(xlabel='x', ylabel='y', title='HMC posterior predictive')
ax.legend()

fig.tight_layout()
# plt.show() rather than Figure.show(), which targets GUI backends.
plt.show()