In [None]:
%load_ext blackcellmagic

In [None]:
import jax
import jax.numpy as jnp
from flax import nnx
import sys
sys.path.append("../")
import matplotlib.pyplot as plt
import synthetic
import pickle

# Load the pickled training set. Entries 0 and 1 of the loaded object are
# plotted below as the x and y coordinates of the training points.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("synthetic.data", "rb") as data_file:
    train_ds = pickle.load(data_file)

# Number of training points; later cells use it to scale the losses.
num_train = len(train_ds[0])
plt.scatter(train_ds[0], train_ds[1], s=15, facecolor="k", edgecolor="gray")

In [None]:
from utils import build_masks
from type_alias import Batch

# Load the pickled base parameters.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("sgd.params", "rb") as params_file:
    base_params = pickle.load(params_file)

# Coefficient passed to synthetic.neg_log_prior.
wd_coeff = 1.0e-4


def model_init_fn(rngs):
    """Build a ``synthetic.Regressor(32, 2, 2)`` using the given ``rngs``."""
    return synthetic.Regressor(32, 2, 2, rngs=rngs)


def init_loss_fn(params):
    """Return the negative log-prior of ``params`` multiplied by ``num_train``."""
    return synthetic.neg_log_prior(params, wd_coeff) * num_train


def loss_fn(model: nnx.Module, batch: Batch):
    """Evaluate the scaled negative log-posterior of ``model`` on ``batch``.

    Args:
        model: Module that maps inputs to a ``(mu, sigma)`` pair.
        batch: Pair ``(x, y)`` of inputs and targets.

    Returns:
        ``(loss, (nll, nlp))`` where ``loss = (nll + nlp) * num_train``,
        ``nll`` comes from ``synthetic.neg_log_likel`` and ``nlp`` from
        ``synthetic.neg_log_prior`` applied to the model's ``nnx.Param`` state.
    """
    inputs, targets = batch
    pred_mu, pred_sigma = model(inputs)
    data_term = synthetic.neg_log_likel(targets, pred_mu, pred_sigma)
    # Only the trainable parameters enter the prior term.
    trainable = nnx.state(model).filter(nnx.Param)
    prior_term = synthetic.neg_log_prior(trainable, wd_coeff)
    total = (data_term + prior_term) * num_train
    return total, (data_term, prior_term)


# Build the model abstractly (via nnx.eval_shape) just to obtain its graph
# definition and the freeze masks.
abstract_model = nnx.eval_shape(lambda: model_init_fn(nnx.Rngs(0)))
graphdef, _ = nnx.split(abstract_model)
masks = build_masks(abstract_model, to_freeze=["fext"])

import jax.tree_util as tu

# Momentum starts at zero with the same tree structure as the parameters.
# jax.tree_map is a deprecated alias that recent JAX releases removed, so use
# jax.tree_util.tree_map, which is available in every JAX version.
base_momentum = tu.tree_map(jnp.zeros_like, base_params)

In [None]:
import uha
import importlib

# Re-import uha so that edits to the module are picked up without a kernel restart.
importlib.reload(uha)

# NOTE(review): the run_uha call in the next cell passes damper=0.9 and
# leapfrog_updates=10 -- confirm which of the two settings takes effect.
run_uha = uha.build(
    base_params,
    base_momentum,
    graphdef,
    masks,
    train_ds,
    init_loss_fn,
    loss_fn,
    damper=0.95,
    leapfrog_updates=5,
)

In [None]:
# Sampler settings.
num_particles = 1000
batch_size = 10
overlap = 5
step_size = 1.0e-4
resample_thres = 0.5
num_cycles = 10

# Fixed seed so the run is reproducible.
rngs = nnx.Rngs(42)
# NOTE(review): damper / leapfrog_updates here differ from the values given to
# uha.build (0.95 / 5) -- confirm which of the two settings takes effect.
particles, log_Z_est, resample_cnt = run_uha(
    rngs,
    num_particles,
    batch_size,
    overlap,
    num_cycles,
    damper=0.9,
    leapfrog_updates=10,
    step_size=step_size,
    resample_thres=resample_thres,
)

# Windows of `batch_size` points taken with stride `batch_size - overlap`,
# repeated for every cycle -- presumably how uha batches the data; confirm.
num_batches = num_cycles * ((num_train - batch_size) // (batch_size - overlap) + 1)
print(f"Resampling occurred {resample_cnt}/{num_batches-1} times.")
print(log_Z_est)

In [None]:
from jax.scipy.special import logsumexp

# Evaluation grid for the predictive plot.
x_test = jnp.linspace(-1.2, 1.2, num=1000)

# Unnormalised log-weight of each particle; normalised with logsumexp below.
log_w = particles.log_gamma_k + particles.log_trans - particles.log_gamma_0
@nnx.vmap
def predict(params: nnx.State):
    """Rebuild a model from ``graphdef`` and ``params`` and evaluate it on ``x_test``.

    Wrapped in ``nnx.vmap`` so it can be applied to the stacked parameters of
    all particles in one call.
    """
    particle_model = nnx.merge(graphdef, params)
    outputs = particle_model(x_test)
    return outputs

# Per-particle predictions, stacked along axis 0.
mus, sigmas = predict(particles.params)

# Normalise the particle weights in log-space for numerical stability.
nw = jnp.exp(log_w - logsumexp(log_w))

# Weighted mixture mean over particles.
mu = jnp.sum(mus * nw[..., None], axis=0)
# Mixture standard deviation: weighted average of the per-particle sigma**2
# plus the weighted spread of the per-particle means around the mixture mean.
sigma = jnp.sqrt(
    jnp.sum(sigmas ** 2 * nw[..., None], 0)
    + jnp.sum((mus - mu) ** 2 * nw[..., None], 0)
)
plt.figure(figsize=(5, 5))
# Outer band: mean +/- 2 sigma.
plt.fill_between(
    x_test,
    mu - 2 * sigma,
    mu + 2 * sigma,
    alpha=0.2,
    facecolor="skyblue",
    edgecolor=None,
)
# Inner band: mean +/- 1 sigma.
plt.fill_between(
    x_test, mu - sigma, mu + sigma, alpha=0.4, facecolor="skyblue", edgecolor=None
)
plt.ylim(-3, 3)
plt.plot(x_test, mu, label="prediction")
plt.plot(train_ds[0], train_ds[1], "rx", label="training data", alpha=0.3)
# The labels set above are only displayed once a legend is drawn.
plt.legend()