In [None]:
%pip install --upgrade pip 
%pip install --upgrade jax 
%pip install "flax[all]"

In [None]:
import jax
import jax.numpy as jnp 
from jax import random
import equinox as eqx
import optax
from jax.random import PRNGKey, split
from typing import Optional, Callable, List
import numpyro.distributions as dist
import matplotlib.pyplot as plt
from tqdm import trange
import math 

In [None]:
from ULA_util import (
    MultivariateNormalDiag,
    GaussianMixture,
    AnnealingSchedule,
    StepSizeMLP,
    ResidualBlock,
    ScoreNetwork,
    UnadjustedLangevin,
    sigmoid 
)

# Notebook-level seed key and a one-dimensional unit-variance Gaussian
# reference distribution centred at zero.
key = PRNGKey(0)
pi_0 = MultivariateNormalDiag(loc=jnp.zeros(1), scale_diag=jnp.ones(1))


In [None]:
class UnadjustedLangevin(eqx.Module):
    """Annealed unadjusted Langevin sampler with a log-normalizer estimator.

    A particle is drawn from ``pi_0`` and moved through ``n_steps`` Langevin
    transitions. Transition ``k`` follows the gradient of the log-density
    ``beta_k * gamma(x) + (1 - beta_k) * pi_0.log_prob(x)``, where ``beta_k``
    is read from ``schedule.compute_betas()``. Along the way a log importance
    weight is accumulated; ``compute_log_Z`` averages many such weights.

    NOTE(review): the name ``UnadjustedLangevin`` is also imported from
    ``ULA_util`` earlier in this notebook; this definition rebinds it.
    """

    # Initial (reference) distribution; must provide ``sample(key)`` and ``log_prob(x)``.
    pi_0: MultivariateNormalDiag
    # Unnormalized target log-density, called as ``gamma(x)``.
    gamma: Callable[[jnp.ndarray], float]

    # Provides ``compute_betas()``; indexed with k = 1..n_steps below.
    # assumes it returns at least n_steps + 1 entries — TODO confirm
    schedule: AnnealingSchedule
    # Called as ``stepsize_model(k)`` to obtain the step size of transition k.
    # presumably returns a positive scalar — TODO confirm
    stepsize_model: StepSizeMLP
    # Number of Langevin transitions per trajectory.
    n_steps: int

    def gamma_k(self, k: int, x: jnp.ndarray) -> float:
        """Return the annealed log-density of step ``k`` evaluated at ``x``.

        Args:
            k: Index into the array returned by ``schedule.compute_betas()``.
            x: Point at which to evaluate.

        Returns:
            ``beta_k * gamma(x) + (1 - beta_k) * pi_0.log_prob(x)``.
        """
        betas = self.schedule.compute_betas()
        beta_k = betas[k]
        return beta_k*self.gamma(x) + (1.0 - beta_k)*self.pi_0.log_prob(x)

    def forward_kernel(self, k: int, x_prev: jnp.ndarray):
        """Return the Gaussian parameters of the Langevin transition from ``x_prev``.

        Args:
            k: Transition index, forwarded to ``stepsize_model`` and ``gamma_k``.
            x_prev: Current particle position.

        Returns:
            ``(mean, stdev)`` with ``mean = x_prev + delta_k * grad gamma_k(x_prev)``
            and ``stdev = sqrt(2 * delta_k)``, where ``delta_k = stepsize_model(k)``.
        """
        delta_k = self.stepsize_model(k)
        # Differentiate with respect to the position argument only (argnums=1).
        grad_log = jax.grad(self.gamma_k, argnums=1)(k, x_prev)
        mean = x_prev + delta_k * grad_log
        stdev = jnp.sqrt(2.0*delta_k)
        return mean, stdev

    def _gaussian_log_prob(self, x: jnp.ndarray, mean: jnp.ndarray, stdev) -> float:
        """Log-density at ``x`` of a Gaussian with the given mean and a shared stdev.

        The normalizer uses ``x.size`` as the dimension, i.e. the same ``stdev``
        is applied to every coordinate.
        """
        d = x.size
        log_det = d*jnp.log(stdev**2)
        log_norm = -0.5*d*jnp.log(2*jnp.pi) - 0.5*log_det
        diff = x - mean
        quad = 0.5*jnp.sum((diff**2)/(stdev**2))
        return log_norm - quad

    def log_prob_F_k(self, k: int, x_curr: jnp.ndarray, x_prev: jnp.ndarray) -> float:
        """Log-density of moving from ``x_prev`` to ``x_curr`` under transition ``k``.

        Args:
            k: Transition index.
            x_curr: Destination point whose density is evaluated.
            x_prev: Point the transition starts from.

        Returns:
            Gaussian log-density of ``x_curr`` under ``forward_kernel(k, x_prev)``.
        """
        mean, stdev = self.forward_kernel(k, x_prev)
        return self._gaussian_log_prob(x_curr, mean, stdev)

    def get_log_weight(self, key: jax.random.PRNGKey) -> float:
        """Simulate one trajectory and return its log importance weight.

        The weight is
        ``-log pi_0(x_0) + sum_k [log F_k(x_{k-1} | x_k) - log F_k(x_k | x_{k-1})]
        + gamma(x_n)``, where the first term in the sum evaluates the same
        kernel with its arguments swapped.

        Args:
            key: PRNG key for this trajectory; it is split internally and must
                not be reused by the caller.

        Returns:
            Scalar log-weight of the simulated trajectory.
        """
        # Split before the first draw: the key consumed by ``pi_0.sample`` must
        # not also be the one that is split for the transition noise.
        key, init_key = jax.random.split(key)
        x_prev = self.pi_0.sample(init_key)
        log_w = -self.pi_0.log_prob(x_prev)

        for k in range(1, self.n_steps+1):
            key, subkey = jax.random.split(key)
            mean, stdev = self.forward_kernel(k, x_prev)
            x_k = mean + stdev*jax.random.normal(subkey, shape=x_prev.shape)
            # Forward density reuses the kernel parameters computed just above.
            log_F = self._gaussian_log_prob(x_k, mean, stdev)
            # Reverse move: same kernel, evaluated from x_k back to x_prev.
            log_B = self.log_prob_F_k(k, x_prev, x_k)
            log_w += (log_B - log_F)
            x_prev = x_k

        log_w += self.gamma(x_prev)
        return log_w

    def compute_log_Z(self, key: jax.random.PRNGKey, n_samples=512) -> float:
        """Estimate log Z as the log of the mean importance weight.

        Args:
            key: PRNG key, split into one key per trajectory.
            n_samples: Number of independent trajectories to simulate.

        Returns:
            ``log(mean(exp(log_w)))`` over the simulated trajectories.
        """
        keys = jax.random.split(key, n_samples)
        log_ws = jax.vmap(self.get_log_weight)(keys)
        # Subtract the maximum before exponentiating to avoid overflow.
        max_lw = jnp.max(log_ws)
        return max_lw + jnp.log(jnp.mean(jnp.exp(log_ws - max_lw)))


In [None]:
def main():
    """Train the ULA sampler on a Gaussian-mixture target and print log Z estimates.

    Builds a 20-dimensional ``GaussianMixture`` target with 8 components and a
    diagonal Gaussian initial distribution, wraps them in an
    ``UnadjustedLangevin`` module with 64 steps, and minimizes the negative
    log Z estimate with Adam. Progress is printed every 500 steps and a final
    estimate from 16384 trajectories is printed at the end.
    """
    key = jax.random.PRNGKey(0)
    dim = 20
    n_steps = 64

    # Give every consumer its own key. Reusing the root key for the target,
    # then splitting it for the model keys, then splitting it again in the
    # training loop would hand the first training step the very key that
    # initialized the step-size network.
    key, key_target, key_sch, key_step = jax.random.split(key, 4)

    pi_0 = MultivariateNormalDiag(loc=jnp.zeros(dim), scale_diag=3*jnp.ones(dim))
    target = GaussianMixture(dim, n_components=8, key=key_target)

    def gamma_fn(x):
        """Unnormalized target log-density handed to the sampler."""
        return target.log_prob(x)

    schedule = AnnealingSchedule(n_steps, key_sch)
    stepmlp  = StepSizeMLP(n_steps, key_step)

    ula = UnadjustedLangevin(
        pi_0=pi_0,
        gamma=gamma_fn,
        schedule=schedule,
        stepsize_model=stepmlp,
        n_steps=n_steps
    )

    # The optimizer state is built over the array leaves of the module.
    params = eqx.filter(ula, eqx.is_array)
    opt = optax.adam(1e-3)
    opt_state = opt.init(params)

    @eqx.filter_jit
    def loss_fn(ula: UnadjustedLangevin, key):
        """Negative log Z estimate from 128 trajectories (quantity to minimize)."""
        logZ_est = ula.compute_log_Z(key, n_samples=128)
        return -logZ_est

    @eqx.filter_jit
    def train_step(ula: UnadjustedLangevin, opt_state, key):
        """Run one Adam update; return the loss, updated module and optimizer state."""
        l, grads = eqx.filter_value_and_grad(loss_fn)(ula, key)
        updates, opt_state = opt.update(grads, opt_state, params=eqx.filter(ula, eqx.is_array))
        ula = eqx.apply_updates(ula, updates)
        return l, ula, opt_state

    steps = 2001
    for step_i in range(steps):
        # Fresh subkey per optimization step.
        key, subkey = jax.random.split(key)
        loss_val, ula, opt_state = train_step(ula, opt_state, subkey)
        if step_i % 500==0:
            print(f"step={step_i}, neg logZ={float(loss_val):.4f}")

    # Final evaluation with a much larger number of trajectories.
    key, subkey = jax.random.split(key)
    final_logZ = ula.compute_log_Z(subkey, n_samples=16384)
    print("Final logZ estimate:", float(final_logZ))

# Entry point: run the training and evaluation routine only when this code is
# executed with __name__ set to "__main__", not when it is imported.
if __name__ == "__main__":
    main()