In [51]:
import jax
import jax.numpy as jnp
from flax import nnx 
from jax import lax
import sys 
sys.path.append('/Users/hajunhyeon/Documents/GitHub/smcs') 
sys.path.append("C:/Users/user/Documents/GitHub/smcs")
import gaussian

n = 1000
d = 2
lamb = 0.1
rho0 = 0.2 
M= jnp.eye(d) 

rngs = nnx.Rngs(23)
x = gaussian.generate_data(rngs, n, d, lamb, rho0)

mu, rho = gaussian.posterior_params(x, lamb, rho0)
log_Z_true = gaussian.log_Z(x, lamb, rho0)
print(mu, rho)
print(log_Z_true)

[-5.0189567  2.5775943] 100.2
-989.6046


In [75]:
from typing import NamedTuple
from type_alias import Batch, KeyArray
from jax.scipy.special import logsumexp 

def likel_energy_fn(phi: jax.Array, batch: Batch) -> float:
    return (n / len(batch)) * gaussian.neg_log_likel(phi, batch, lamb)


def prior_energy_fn(theta: jax.Array) -> float:
    return gaussian.neg_log_prior(theta, rho0)


def energy_fn(theta: jax.Array, batch: Batch) -> float:
    return (n / len(batch)) * gaussian.neg_log_likel(
        theta, batch, lamb
    ) + gaussian.neg_log_prior(theta, rho0)  

def leapfrog_update(x, p, eta, batch, M):
    def U(theta, batch):
        return energy_fn(theta, batch)
    
    energy, grad = jax.value_and_grad(U)(x, batch)
    p_half = p - (eta/2) * grad
    x_new = x + eta * (M @ p_half)
    energy_new, grad_new = jax.value_and_grad(U)(x_new, batch)
    p_new = p_half - (eta/2) * grad_new
    return x_new, p_new #, energy_new

def multiple_leapfrog_updates(x, p, eta, batch, M, L):
    def body_fn(_, state):
        x_curr, p_curr = state
        x_next, p_next = leapfrog_update(x_curr, p_curr, eta, batch, M) # , energy
        return (x_next, p_next) #, energy)
    init_state = (x, p)
    x_final, p_final = lax.fori_loop(0, L, body_fn, init_state) # , energy_final
    return x_final, p_final #, energy_final


class Particle(NamedTuple): 
    theta: jax.Array 
    momentum: jax.Array
    log_gamma_0: float 
    log_trans: float 
    log_gamma_k: float  

def init_particle(key: KeyArray) -> Particle:
    theta = jax.random.normal(key, shape=(d,)) / jnp.sqrt(rho0) 
    momentum= jax.random.normal(key, shape=(d,)) 

    lp_theta = gaussian.log_prob(theta, 0.0, rho0)
    lp_momentum = gaussian.log_prob(momentum, 0.0, 1.0) 
    return Particle(theta, momentum, lp_theta + lp_momentum, 0.0, 0.0)


def init_particles(rngs: nnx.Rngs, num_particles: int) -> Particle:
    keys = jax.random.split(rngs(), num_particles)
    return nnx.vmap(init_particle, in_axes=0)(keys)

@nnx.vmap(in_axes= (0, 0, None, None, None, None)) 
def forward_particle(
    key: KeyArray, 
    particle: Particle, 
    batch: Batch, 
    step_size: float, 
    damper: float,
    leapfrog_updates: int
) -> Particle: 
    std= jnp.sqrt(1-damper**2)
    noise= jax.random.normal(key, shape= particle.theta.shape) 

    momentum_tilda= particle.momentum* damper+ std*noise 

    mu_q= particle.momentum 
    mu_p= momentum_tilda 

    # logF= -0.5* jnp.sum(((particle.momentum-damper*mu_p)/std)**2) 
    # logB= -0.5* jnp.sum(((momentum_tilda- damper*mu_q)/std)**2) 
    logF= gaussian.log_prob(particle.momentum, damper* mu_p, (1/std)**2) 
    logB= gaussian.log_prob(momentum_tilda, damper*mu_q, (1/std)**2)

    # run leapfrog updates N times: 
    # theta_new, p_new= leapfrog_update(theta,p_tilda) 

    theta_new, momentum_new= multiple_leapfrog_updates(
        x= particle.theta, 
        p= momentum_tilda, 
        eta= step_size, 
        batch= batch, 
        M= M, 
        L= leapfrog_updates
    ) 
    grad_fn= nnx.value_and_grad(energy_fn) 
    energy_new, _= grad_fn(theta_new, batch) 

    return Particle( 
        theta= theta_new, 
        momentum= momentum_new, 
        log_gamma_0= particle.log_gamma_0, 
        log_trans= particle.log_trans + logF- logB, 
        log_gamma_k= -energy_new
    ) 

def forward_particles( 
    rngs: nnx.Rngs, 
    particles: Particle, 
    batch: Batch, 
    step_size: float, 
    damper: float, 
    leapfrog_updates: float
) -> Particle: 
    keys= jax.random.split(rngs(), len(particles.theta)) 
    return forward_particle(keys, particles, batch, step_size, damper, leapfrog_updates) 

# def resample_if_needed(rngs: nnx.Rngs, particles: Particle, thres: float) -> Particle: 
#     log_w= gaussian.log_prob(particles.momentum, 0.0, 1.0)+ particles.log_gamma_k+ particles.log_trans- particles.log_gamma_0 
#     num_particles= len(log_w) 
#     ess= jnp.exp(2* logsumexp(log_w)- logsumexp(2*log_w)) 

#     idxs= jax.random.categorical(rngs(), log_w, shape= (num_particles, )) 
#     resampled_particles= Particle( 
#         theta= jnp.take(particles.theta, idxs, axis=0), 
#         momentum= jnp.take(particles.momentum, idxs, axis=0), 
#         log_gamma_0= jnp.take(particles.log_gamma_k, idxs), 
#         log_trans= jnp.zeros_like(log_w), 
#         log_gamma_k= jnp.zeros_like(log_w),
#     ) 
#     log_Z_ratio_est= logsumexp(log_w)- jnp.log(num_particles) 

#     return jax.lax.cond(
#         ess<thres*num_particles, 
#         lambda _: (resampled_particles, log_Z_ratio_est, 1),
#         lambda _: (particles, 0.0, 0),
#         operand=None,
#     )

In [86]:
import optax
from typing import Optional, Tuple
from utils import get_sliding_batch_start_idxs

def run_uha( # returns Particle, log_Z_est
    rngs: nnx.Rngs, 
    num_particles: int, 
    batch_size: int, 
    overlap: int, 
    num_cycles: int, 
    damper: float, 
    leapfrog_updates: int, 
    # init_step_size: float, 
    # final_step_size: Optional[float] = 1.0e-6,  
    step_size: float, 
    # resample_thres: Optional[float]= 0.5, 
) -> Tuple[Particle, float]: #, int]: 
    
    start_idxs= jnp.concatenate(
        [
            get_sliding_batch_start_idxs(n, batch_size, overlap) 
            for _ in range(num_cycles)
        ]
    ) 
    num_batches= len(start_idxs) 
    # step_size_fn= optax.cosine_decay_schedule(
    #     init_step_size, num_batches, alpha= final_step_size/ init_step_size
    # ) 
    particles= init_particles(rngs, num_particles)  
    log_Z_est= 0.0 
    # resample_cnt= 0 

    @nnx.scan(in_axes= (nnx.Carry, 0), out_axes= nnx.Carry)
    def step(carry, start_idx): 
        k, rngs, particles, log_Z_est = carry #, resample_cnt = carry 
        start_indices= (start_idx, 0) 
        slice_sizes= (batch_size, d) 
        batch= jax.lax.dynamic_slice(x, start_indices, slice_sizes) 
        # step_size= step_size_fn(k) 

        particles= forward_particles(rngs, particles, batch, step_size, damper, leapfrog_updates) 

        log_w = particles.log_trans - particles.log_gamma_0 # particles.log_gamma_k + 
        num_particles = len(log_w)
        log_Z_est= logsumexp(log_w)- jnp.log(num_particles)  

        # particles, log_Z_ratio_est, resampled = resample_if_needed(
        #     rngs, particles, resample_thres
        # )

        return (
            k+1, 
            rngs, 
            particles, 
            log_Z_est,  #+ log_Z_ratio_est, 
            # resample_cnt+ resampled
        )
    
    _, rngs, particles, log_Z_est= step(
        (0, rngs, particles, 0.0), start_idxs
    ) 

    particles= forward_particles(rngs, particles, x, step_size , damper, leapfrog_updates) # step_size_fn(num_batches-1)
    
    log_prob_many= jax.vmap(gaussian.log_prob, in_axes=(0, None, None), out_axes=0) 

    log_w= log_prob_many(particles.momentum, jnp.zeros(d), 1.0)+ particles.log_gamma_k+ particles.log_trans- particles.log_gamma_0 
    
    log_Z_est+= logsumexp(log_w) - jnp.log(num_particles) 

    return particles, log_Z_est #, resample_cnt
 

In [99]:
num_particles = 1000
batch_size = 100
overlap = 50
num_cycles = 20
# init_step_size = 1.0e-2
# final_step_size = 1.0e-4 
step_size= 1e-3
damper= 0.95
leapfrog_updates= 5

# resample_thres= 0.5


particles, log_Z_est= run_uha(
    rngs,
    num_particles,
    batch_size,
    overlap,
    num_cycles, 
    damper, 
    leapfrog_updates, 
    # init_step_size,
    # final_step_size,  
    step_size, 
    # resample_thres,
)

log_prob_many= jax.vmap(gaussian.log_prob, in_axes=(0, None, None), out_axes=0) 

log_w= log_prob_many(particles.momentum, jnp.zeros(d), 1.0)+ particles.log_gamma_k+ particles.log_trans- particles.log_gamma_0
nw = jnp.exp(log_w - logsumexp(log_w))
est_mean = jnp.sum(nw[..., None] * particles.theta, axis=0)
est_var = jnp.sum(nw[..., None] * (particles.theta - est_mean) ** 2, axis=0)

num_batches = num_cycles * ((n - batch_size) // (batch_size - overlap) + 1) 
print(mu, est_mean)
print(rho, 1.0 / est_var)
print(log_Z_true, log_Z_est) 

[-5.0189567  2.5775943] [-4.8181405  2.4706914]
100.2 [ 559.2851 1540.1223]
-989.6046 -1038.6626


In [91]:
logsumexp(log_w) - jnp.log(num_particles) 

Array(-994.6008, dtype=float32)