In [None]:
%pip install --upgrade pip 
%pip install --upgrade jax 
%pip install "flax[all]"

In [5]:
# Caution(0121): JAX worked in python version 3.9 -> Try new setting in desktop

import jax
import jax.numpy as jnp 

import numpy as np 

x_jnp= jnp.arange(10) 
x_np= np.arange(10) 

print(x_jnp) 
print(x_np)

[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]


In [7]:
def selu(x, alpha=1.67, lambda_=1.05): 
    return lambda_*jnp.where(x>0, x, alpha*jnp.exp(x)-alpha) 

x= jnp.arange(1000000) 
%timeit selu(x).block_until_ready() 

selu_jit= jax.jit(selu) 

selu_jit(x).block_until_ready() 
%timeit selu_jit(x).block_until_ready()

1.32 ms ± 55.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
383 μs ± 449 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
f= lambda x: x**3+2*x**2-2*x+1 

dfdx= jax.grad(f) 
d2fdx= jax.grad(dfdx) 
d3fdx= jax.grad(d2fdx) 

print(dfdx(1.)) 
print(d2fdx(1.)) 
print(d3fdx(1.))

5.0
10.0
6.0


In [31]:
## Condition Settings 
import jax
import jax.numpy as jnp 
import jax.random as random 
from jax.scipy.special import logsumexp 
from jax.scipy.stats import norm , t 

# Unnomralized target gamma(x), intitial state proposal pi_0(x), number steps K, stepsize eta, annealing schedule (beta_K), damping coefficient h, mass matrix M, score model

# Algorithm1 (AIS) 

def unadjusted_langevin_ais(log_target_fn, log_initial_fn, K, step_size, beta_schedule, rng_key): 
    rng_key, subkey= random.split(rng_key) 
    x= random.normal(subkey, shape=(log_initial_fn.ndim,)) 
    log_w= -log_initial_fn(x)  
    x_samples= [x]


    for k in range(1, K+1): 
        rng_key, subkey= random.split(rng_key) 
        beta_k= beta_schedule[k] 
        beta_k_minus_1= beta_schedule[k-1]   

        def log_intermediate_fn(z): 
            return beta_k*log_target_fn(z)+ (1-beta_k)*log_initial_fn(z) 
        def log_previous_fn(z): 
            return (beta_k_minus_1*log_target_fn(z)+(1-beta_k_minus_1)*log_initial_fn(z)) 
        
        grad_x= jax.grad(log_intermediate_fn)(x) 
        noise= random.normal(subkey, shape= x.shape) 

        x_new= x+ step_size*grad_x+jnp.sqrt(2.0*step_size)*noise 

        grad_x_new= jax.grad(log_intermediate_fn)(x_new)

        def log_normal_density(z, mean, var): 
            return -0.5*jnp.log(2.0*jnp.pi*var)-0.5*jnp.sum((z-mean)**2)/var 
        
        log_F=log_normal_density(
            x_new, 
            mean= x+step_size*grad_x, 
            var= 2.0*step_size
        ) 

        log_B= log_normal_density( 
            x, 
            mean= x_new+step_size*grad_x_new, 
            var= 2.0*step_size
        ) 

        log_w= log_w+(log_B-log_F) 

        x=x_new
        x_samples.append(x) 
    
    log_w= log_w+log_target_fn(x) 

    return x_samples, log_w

def make_gaussian_mixture_logpdf(d, num_components=8, var=1.0): 
    rng= jax.random.PRNGKey(12345) 
    means = random.normal(rng, shape=(num_components, d)) + 3.0

    def logpdf(x): 
        x_reshaped= x[None, :] 
        diffs= x_reshaped - means 

        sq_maha= jnp.sum(diffs**2, axis=1) 
        log_probs= -0.5*d*jnp.log(2*jnp.pi*var)-0.5*sq_maha/var 
        return logsumexp(log_probs)- jnp.log(num_components) 
    
    logpdf.ndim=d
    return logpdf 

def make_student_t_logpdf(d, df=3): 
    def logpdf(x): 
        return jnp.sum(t.logpdf(x, df=df)) 
    logpdf.ndim=d 
    return logpdf 

def make_std_normal_logpdf(d, var=9.0): 
    def logpdf(x): 
        return -0.5*d*jnp.log(2*jnp.pi*var)-0.5*jnp.sum(x**2)/var
    logpdf.ndim=d 
    return logpdf

In [26]:
def run_experiment_on_target(make_log_target_fn, 
                             make_log_initial_fn, 
                             K_values=[64, 256], 
                             dims=[20, 200, 500], 
                             step_size=0.01, 
                             n_seeds=3): 
    results={} 
    for d in dims:  
        log_target_fn= make_log_target_fn(d)  
        log_initial_fn= make_log_initial_fn(d) 

        for K in K_values: 
            logZ_seeds= [] 
            for seed in range(n_seeds): 
                rng= random.PRNGKey(seed) 
                betas= jnp.linspace(0., 1., K+1) 

                _, log_w= unadjusted_langevin_ais(
                    log_target_fn, 
                    log_initial_fn, 
                    K=K, 
                    step_size= step_size, 
                    beta_schedule= betas, 
                    rng_key= rng
                )
                logZ_seeds.append(np.array(log_w)) 
                
            logZ_seeds= np.array(logZ_seeds) 
            mean_logZ= logZ_seeds.mean() 
            sem_logZ= logZ_seeds.std()/np.sqrt(n_seeds) 
            results[(d, K)]= (mean_logZ, sem_logZ) 

    return results

In [32]:
if __name__ == "__main__": 
    gm_results= run_experiment_on_target(
        make_log_target_fn  = lambda d: make_gaussian_mixture_logpdf(d, num_components=8, var=1.0),
        make_log_initial_fn = lambda d: make_std_normal_logpdf(d, var=9.0),
        K_values=[64, 256], 
        dims= [20, 200, 500], 
        step_size=0.01, 
        n_seeds=3
    ) 
    for (d, K), (mean_logZ, sem_logZ) in gm_results.items(): 
        print(f"[Gaussian mixture, d={d}, K={K}]  logZ = {mean_logZ:.3f} ± {sem_logZ:.3f}")

    t_results= run_experiment_on_target(
        make_log_target_fn= lambda dd: make_student_t_logpdf(dd, df=3),  
        make_log_initial_fn = lambda d: make_std_normal_logpdf(d, var=1.0),
        K_values=[64, 256], 
        dims= [20, 200, 500], 
        step_size=0.01, 
        n_seeds=3
    ) 
    for (d, K), (mean_logZ, sem_logZ) in t_results.items(): 
        print(f"[Student-t df=3, d={d}, K={K}]  logZ = {mean_logZ:.3f} ± {sem_logZ:.3f}")
    

[Gaussian mixture, d=20, K=64]  logZ = -68.893 ± 11.313
[Gaussian mixture, d=20, K=256]  logZ = -49.879 ± 9.933
[Gaussian mixture, d=200, K=64]  logZ = -700.062 ± 11.594
[Gaussian mixture, d=200, K=256]  logZ = -455.788 ± 6.110
[Gaussian mixture, d=500, K=64]  logZ = -1808.832 ± 35.289
[Gaussian mixture, d=500, K=256]  logZ = -1176.230 ± 15.941
[Student-t df=3, d=20, K=64]  logZ = -0.819 ± 0.593
[Student-t df=3, d=20, K=256]  logZ = -0.779 ± 0.422
[Student-t df=3, d=200, K=64]  logZ = -9.295 ± 2.183
[Student-t df=3, d=200, K=256]  logZ = -8.750 ± 0.819
[Student-t df=3, d=500, K=64]  logZ = -27.825 ± 0.926
[Student-t df=3, d=500, K=256]  logZ = -23.324 ± 0.796


이제 step-size optimization 포함해서 다시 시도 (결과가 쓰레기..) : GPU 로 다시 추후 시도

In [34]:
import jax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.stats import t
import optax
import numpy as np

##############################################################################
# Target & Initial distributions
##############################################################################

def make_student_t_logpdf(d, df=3):
    """
    i.i.d. Student-t(df) in R^d
    """
    def logpdf(x):
        return jnp.sum(t.logpdf(x, df=df))
    return logpdf

def make_std_normal_logpdf(d):
    """
    Standard normal N(0,I_d) in dimension d
    """
    def logpdf(x):
        sq = jnp.sum(x**2)
        return -0.5*d*jnp.log(2.*jnp.pi) - 0.5*sq
    return logpdf

##############################################################################
# Unadjusted Langevin AIS w/ per-step-size parameters
##############################################################################

def unadjusted_langevin_ais_per_step(
    log_target_fn,
    log_initial_fn,
    d,
    betas,            # shape (K+1,)
    step_size_params, # shape (K,)
    rng_key
):
    K = len(betas) - 1
    rng_key, subkey = random.split(rng_key)
    # x0 ~ pi0
    x = random.normal(subkey, shape=(d,))
    # log_w = - log pi0(x0)
    log_w = -log_initial_fn(x)

    def log_normal_density(z, mean, var):
        # isotropic N(mean, var I)
        diff = z - mean
        sq   = jnp.sum(diff**2)
        d_   = diff.shape[0]
        return -0.5 * d_ * jnp.log(2.*jnp.pi*var) - 0.5 * (sq/var)

    for k in range(1, K+1):
        rng_key, subkey = random.split(rng_key)
        beta_k = betas[k]

        # define gamma_k
        def log_intermediate(z):
            return beta_k*log_target_fn(z) + (1.-beta_k)*log_initial_fn(z)

        eta     = jnp.exp(step_size_params[k-1])
        grad_x  = jax.grad(log_intermediate)(x)
        noise   = random.normal(subkey, shape=x.shape)
        x_new   = x + eta*grad_x + jnp.sqrt(2.*eta)*noise

        # weight update
        log_F = log_normal_density(
            x_new,
            mean=x + eta*grad_x,
            var=2.*eta
        )
        grad_x_new = jax.grad(log_intermediate)(x_new)
        log_B = log_normal_density(
            x,
            mean=x_new + eta*grad_x_new,
            var=2.*eta
        )
        log_w += (log_B - log_F)
        x = x_new

    log_w += log_target_fn(x)
    return x, log_w

##############################################################################
# 2) negative_logZ for training
##############################################################################

def negative_logZ(params, d, betas, log_target_fn, log_initial_fn, rng_key):
    """
    Runs AIS with step_size_params=params, returns -logZ
    """
    _, log_w = unadjusted_langevin_ais_per_step(
        log_target_fn  = log_target_fn,
        log_initial_fn = log_initial_fn,
        d             = d,
        betas         = betas,
        step_size_params=params,
        rng_key       = rng_key
    )
    return -log_w

##############################################################################
# 3) Training routine
##############################################################################

def train_step_sizes(
    d,
    K,
    log_target_fn,
    log_initial_fn,
    n_iters=300,
    lr=1e-2,
    seed=42
):
    """
    Trains per-step-size for dimension d, K steps.  Returns the array of step sizes.
    """
    # Initialize log step sizes
    theta_init = jnp.log(0.01)*jnp.ones(K)
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(theta_init)

    betas = jnp.linspace(0., 1., K+1)

    def loss_fn(params, rng_key):
        return negative_logZ(params, d, betas, log_target_fn, log_initial_fn, rng_key)

    @jax.jit
    def step(params, opt_state, rng_key):
        val, grads = jax.value_and_grad(loss_fn)(params, rng_key)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, val

    rng = random.PRNGKey(seed)
    params = theta_init
    for i in range(n_iters):
        rng, subkey = random.split(rng)
        params, opt_state, current_loss = step(params, opt_state, subkey)
        if i % 50 == 0:
            s = jnp.exp(params)
            print(f"[TRAIN] iter={i}, neg logZ={current_loss:.3f}, step sizes[0..3]={s[:3]}")

    return jnp.exp(params)  # the actual step sizes

##############################################################################
# 4) Final experiment procedure to replicate "paper" style results
##############################################################################

def run_experiment_student_t(
    dims = [20, 200, 500],
    Ks   = [64, 256],
    df   = 3,
    n_train_iters=300,
    n_eval_seeds=5,  # how many times we evaluate for final?
    lr=1e-2
):
    """
    For each d in dims and K in Ks:
      1) Train the step sizes by maximizing logZ
      2) Evaluate final logZ over multiple seeds
      3) Print mean +- std / sqrt(n_eval_seeds)
    """
    results = {}
    for d in dims:
        print(f"\n=== Dimension d={d} ===")
        # Build the target & initial for dimension d
        log_target_fn  = make_student_t_logpdf(d, df=df)
        log_initial_fn = make_std_normal_logpdf(d)

        for K in Ks:
            print(f">>> K={K}, training per-step-size ...")
            # 1) Train
            best_step_sizes = train_step_sizes(
                d=d, K=K,
                log_target_fn=log_target_fn,
                log_initial_fn=log_initial_fn,
                n_iters=n_train_iters,
                lr=lr,
                seed=0
            )
            print(f"Trained step sizes (first few): {best_step_sizes[:3]} ...")

            # 2) Evaluate over multiple seeds
            betas = jnp.linspace(0., 1., K+1)
            logZ_values = []
            for seed_eval in range(n_eval_seeds):
                rng = random.PRNGKey(seed_eval+100)  # offset seed
                _, logw = unadjusted_langevin_ais_per_step(
                    log_target_fn  = log_target_fn,
                    log_initial_fn = log_initial_fn,
                    d             = d,
                    betas         = betas,
                    step_size_params=jnp.log(best_step_sizes),  # we store them as log in the function
                    rng_key       = rng
                )
                logZ_values.append(np.array(logw))
            logZ_values = np.array(logZ_values)
            mean_logZ = logZ_values.mean()
            sem_logZ  = logZ_values.std() / np.sqrt(n_eval_seeds)

            print(f"[RESULT] d={d}, K={K}, mean logZ={mean_logZ:.3f} ± {sem_logZ:.3f}")
            results[(d,K)] = (mean_logZ, sem_logZ)

    return results

##############################################################################
# "MAIN"
##############################################################################

if __name__ == "__main__":
    final_results = run_experiment_student_t(
        dims=[20, 200],  # you can do [20, 200, 500]
        Ks=[64, 256],
        df=3,
        n_train_iters=300,
        n_eval_seeds=3,
        lr=1e-2
    )
    print("\nFinal results dictionary:")
    for (d,K), (m, s) in final_results.items():
        print(f"d={d}, K={K} => logZ = {m:.3f} ± {s:.3f}")



=== Dimension d=20 ===
>>> K=64, training per-step-size ...
[TRAIN] iter=0, neg logZ=1.931, step sizes[0..3]=[0.0101005 0.0101005 0.0099005]
[TRAIN] iter=50, neg logZ=1.516, step sizes[0..3]=[0.00968404 0.01000272 0.00891441]
[TRAIN] iter=100, neg logZ=1.681, step sizes[0..3]=[0.00907268 0.01075002 0.00889927]
[TRAIN] iter=150, neg logZ=1.142, step sizes[0..3]=[0.00939275 0.01049997 0.00921382]
[TRAIN] iter=200, neg logZ=0.143, step sizes[0..3]=[0.00918134 0.00974978 0.00926989]
[TRAIN] iter=250, neg logZ=2.283, step sizes[0..3]=[0.00935059 0.00929267 0.00961471]
Trained step sizes (first few): [0.00987686 0.00864683 0.0100778 ] ...
[RESULT] d=20, K=64, mean logZ=-1.374 ± 0.334
>>> K=256, training per-step-size ...


In [2]:
import jax
print(jax.devices())

[CpuDevice(id=0)]


Now let's try UHA!

In [3]:
def leapfrog_step(x, p, step_size, log_target_fn, mass_inv=None):  
    if mass_inv is None: 
        mass_inv= jnp.eye(x.shape[0]) 
    
    def potential_energy(z): 
        return -log_target_fn(z) 
    
    gradient_x= jax.grad(potential_energy)(x) 
    p_half= p- 0.5*step_size*gradient_x 

    x_new= x+step_size*(mass_inv@ p_half) 

    gradient_x_new= jax.grad(potential_energy)(x_new) 
    p_new= p_half-0.5*step_size*gradient_x_new 

    return x_new, p_new 


def leapfrog_iterations(x, p, step_size, iterations, log_target_fn, mass_inv=None): 
    def body_fn(_, carry): 
        x_curr, p_curr= carry 
        x_next, p_next= leapfrog_step(
            x_curr, p_curr, step_size, log_target_fn, mass_inv
        ) 
        return None, (x_next, p_next) 
    _, (x_final, p_final) = jax.lax.scan(body_fn, None, (x,p), length= iterations) 
    return x_final, p_final 

In [7]:
def unadjusted_hamilton_ais_per_step(
    log_target_fn,
    log_initial_fn,
    d,
    betas,            # shape (K+1,)
    step_size_params, # shape (K,)
    h, 
    mass_inv=None, 
    rng_key=None
): 
    if rng_key is None: 
        rng_key= jax.random.PRNGKey(0) 
    if mass_inv is None: 
        mass_inv= jnp.eye(d) 
    
    K = len(betas) - 1
    rng_key, subkey = random.split(rng_key)
    # x0 ~ pi0
    x= random.normal(subkey, shape=(d,)) 
    rng_key, subkey= random.split(rng_key) 
    p= random.normal(subkey, shape=(d,))
    # log_w = - log pi0(x0)

    def log_gaussian_p(p_val, mean, cov):
        diff = p_val - mean
        sq   = jnp.sum(diff**2)
        d_   = p_val.shape[0]
        return -0.5 * d_ * jnp.log(2.*jnp.pi) - 0.5*sq # Needs adjustment if not M=I

    log_w= -log_initial_fn(x)-log_gaussian_p(p, jnp.zeors(d), jnp.eye(d)) 

    for k in range(1, K+1):
        beta_k = betas[k]

        # define gamma_k
        def log_intermediate(z):
            return beta_k*log_target_fn(z) + (1.-beta_k)*log_initial_fn(z)

        rng_key, subkey= random.split(rng_key) 
        noise= random.normal(subkey, shape=(d, )) 

        p_mean= h*p 
        p_std= jnp.sqrt(1.0-h**2) 
        p_new= p_mean+p_std*noise # Again, we assume M=I

        log_num= log_gaussian_p(p, h*p_new, jnp.eye(d)) 
        log_den= log_gaussian_p(p_new, h*p, jnp.eye(d)) # Again, we assume M=I
        log_w= log_w+ (log_num-log_den) 

        p= p_new

        eps= jnp.exp(step_size_params[k-1]) 

        n_leapfrog= 5 
        x_new, p_new= leapfrog_iterations(
            x, p, 
            stepsize= eps, 
            iterations= n_leapfrog, 
            log_target_fn= log_intermediate, 
            mass_inv= mass_inv
        )

        x, p= x_new, p_new

    log_w = log_w+ log_target_fn(x) + log_gaussian_p(p, jnp.zeros(d), jnp.eye(d)) 
    return x, p, log_w

In [19]:
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm, t

##############################################################################
# 1) LOGPDFs: Gaussian mixture, Student-t, and Gaussian initial
##############################################################################

def make_gaussian_mixture_logpdf(d, num_components=8, var=1.0):
    """
    A d-dimensional mixture of Gaussians: means ~ N(0, 3^2 I).
    Each component has covariance var*I.
    """
    rng = jax.random.PRNGKey(1234)
    means = random.normal(rng, shape=(num_components, d))+3

    def logpdf(x):
        # shape(x) = (d,)
        diffs = x[None, :] - means  # (num_components, d)
        sq_maha = jnp.sum(diffs**2, axis=1)
        d_ = x.shape[0]
        log_comps = -0.5 * d_ * jnp.log(2.*jnp.pi*var) - 0.5*(sq_maha / var)
        return logsumexp(log_comps) - jnp.log(num_components)

    return logpdf

def make_student_t_logpdf(d, df=3):
    """
    i.i.d. Student-t(df) in R^d.
    """
    def logpdf(x):
        return jnp.sum(t.logpdf(x, df=df))
    return logpdf

def make_std_normal_logpdf(d, var=1.0):
    """
    logpdf of N(0, var I).
    """
    def logpdf(x):
        sq = jnp.sum(x**2)
        d_ = x.shape[0]
        return -0.5*d_*jnp.log(2.*jnp.pi*var) - 0.5*(sq/var)
    return logpdf

##############################################################################
# 2) GAUSSIAN MOMENTUM HELPERS: log N(p; mean, I)
##############################################################################

def log_gaussian_p(p, mean):
    diff = p - mean
    sq   = jnp.sum(diff**2)
    d_   = p.shape[0]
    return -0.5*d_*jnp.log(2.*jnp.pi) - 0.5*sq

##############################################################################
# 3) LEAPFROG integrator (M=I) for potential U(x) = -log_intermediate(x)
##############################################################################

def leapfrog_step(x, p, step_size, log_intermediate):
    """
    One leapfrog step with M=I.
    """
    def potential_energy(z):
        return - log_intermediate(z)

    gradient_x = jax.grad(potential_energy)(x)
    p_half = p - 0.5 * step_size * gradient_x
    x_new  = x + step_size * p_half

    gradient_x_new = jax.grad(potential_energy)(x_new)
    p_new      = p_half - 0.5 * step_size * gradient_x_new
    return x_new, p_new

def leapfrog_integration(x, p, step_size, n_leapfrog, log_intermediate):
    """
    Repeats 'n_leapfrog' times the leapfrog step.
    """
    for _ in range(n_leapfrog):
        x, p = leapfrog_step(x, p, step_size, log_intermediate)
    return x, p

##############################################################################
# 4) Unadjusted Hamiltonian AIS (UHA) with FIXED step size(s)
##############################################################################

def unadjusted_hamiltonian_ais_fixed(
    log_target_fn,
    log_initial_fn,
    d,
    K,
    step_size,     # scalar or array shape(K,)
    damping_coeff, # h in [0,1)
    n_leapfrog=5,
    rng_key=jax.random.PRNGKey(0)
):
    """
    UHA with partial momentum refresh p_k ~ N(h p_{k-1}, (1-h^2)I).
    We only do bridging corrections for the momentum update, skipping the
    reverse kernel for the leapfrog step => 'truly unadjusted'.

    Returns (x_final, p_final, log_w).
    """
    # unify step_size
    if isinstance(step_size, float):
        step_size = jnp.array([step_size]*K)

    # schedule
    betas = jnp.linspace(0.0, 1.0, K+1)

    # 1) sample x0, p0
    rng_key, subkey = random.split(rng_key)
    x = random.normal(subkey, shape=(d,))
    rng_key, subkey = random.split(rng_key)
    p = random.normal(subkey, shape=(d,))

    # 2) init AIS weight
    log_w = -log_initial_fn(x) - log_gaussian_p(p, jnp.zeros(d))

    for k in range(1, K+1):
        beta_k = betas[k]

        # define gamma_k
        def log_intermediate(z):
            return beta_k * log_target_fn(z) + (1 - beta_k)*log_initial_fn(z)

        # (a) partial momentum refresh
        rng_key, subkey = random.split(rng_key)
        noise = random.normal(subkey, shape=(d,))
        p_mean = damping_coeff * p
        p_new  = p_mean + jnp.sqrt(1.-damping_coeff**2)*noise

        # bridging correction for momentum update
        log_num = log_gaussian_p(p, damping_coeff*p_new)
        log_den = log_gaussian_p(p_new, damping_coeff*p)
        log_w  += (log_num - log_den)
        p = p_new

        # (b) leapfrog
        eps_k = step_size[k-1]
        x_new, p_new = leapfrog_integration(
            x, p,
            step_size   = eps_k,
            n_leapfrog  = n_leapfrog,
            log_intermediate=log_intermediate
        )
        # skip bridging correction => unadjusted

        x, p = x_new, p_new

    # 3) final correction
    log_w += log_target_fn(x) + log_gaussian_p(p, jnp.zeros(d))
    return x, p, log_w

##############################################################################
# 5) A "run_experiment_uha_on_target" function, same style as ULA
##############################################################################

def run_experiment_uha_on_target(
    make_log_target_fn,
    make_log_initial_fn,
    K_values=[64,256],
    dims=[20,200,500],
    step_size=0.01,
    damping=0.9,
    n_leapfrog=5,
    n_seeds=3
):
    """
    For each dimension, each K, run unadjusted_hamiltonian_ais_fixed multiple
    times with different seeds and average log_w => (mean_logZ, sem).
    Returns a dict.
    """
    results = {}
    for d in dims:
        # build target & initial logpdf
        log_target_fn  = make_log_target_fn(d)
        log_initial_fn = make_log_initial_fn(d)

        for K in K_values:
            logZ_list = []
            for seed in range(n_seeds):
                rng = random.PRNGKey(seed)
                _, _, lw = unadjusted_hamiltonian_ais_fixed(
                    log_target_fn  = log_target_fn,
                    log_initial_fn = log_initial_fn,
                    d             = d,
                    K             = K,
                    step_size     = step_size,
                    damping_coeff = damping,
                    n_leapfrog   = n_leapfrog,
                    rng_key      = rng
                )
                logZ_list.append(np.array(lw))

            logZ_list = np.array(logZ_list)
            mean_logZ = logZ_list.mean()
            sem_logZ  = logZ_list.std() / np.sqrt(n_seeds)
            results[(d,K)] = (mean_logZ, sem_logZ)
    return results

##############################################################################
# 6) MAIN: replicate the "same conditions" as ULA (dims=[20,200,500], Ks=[64,256])
##############################################################################

if __name__ == "__main__":
    # Example: test on Gaussian mixture => pi0 = N(0, 9I)
    # EXACTLY like the user might have done for ULA
    print("======== UHA on Gaussian Mixture, same conditions ========")
    gm_results = run_experiment_uha_on_target(
        make_log_target_fn  = lambda d: make_gaussian_mixture_logpdf(d, num_components=8, var=1.0),
        make_log_initial_fn = lambda d: make_std_normal_logpdf(d, var=9.0),
        K_values    = [64, 256],
        dims        = [20, 200, 500],
        step_size   = 0.01,   # fix step size
        damping     = 0.9,
        n_leapfrog  = 5,
        n_seeds     = 3
    )
    for (d,K), (mean_logZ, sem_logZ) in gm_results.items():
        print(f"[GMM UHA, d={d}, K={K}] logZ = {mean_logZ:.3f} ± {sem_logZ:.3f}")

    # Now test on Student-t => pi0 = N(0,I)
    print("\n======== UHA on Student-t(df=3), same conditions ========")
    t_results = run_experiment_uha_on_target(
        make_log_target_fn  = lambda dd: make_student_t_logpdf(dd, df=3),
        make_log_initial_fn = lambda d: make_std_normal_logpdf(d, var=1.0),
        K_values    = [64, 256],
        dims        = [20, 200, 500],
        step_size   = 0.01,
        damping     = 0.9,
        n_leapfrog  = 5,
        n_seeds     = 3
    )
    for (d,K), (mean_logZ, sem_logZ) in t_results.items():
        print(f"[T UHA, d={d}, K={K}]  logZ = {mean_logZ:.3f} ± {sem_logZ:.3f}")


[GMM UHA, d=20, K=64] logZ = -34.254 ± 8.038
[GMM UHA, d=20, K=256] logZ = 1.241 ± 2.851
[GMM UHA, d=200, K=64] logZ = -364.356 ± 19.768
[GMM UHA, d=200, K=256] logZ = 21.797 ± 4.549
[GMM UHA, d=500, K=64] logZ = -900.890 ± 22.825
[GMM UHA, d=500, K=256] logZ = 79.052 ± 8.423

[T UHA, d=20, K=64]  logZ = 2.043 ± 2.491
[T UHA, d=20, K=256]  logZ = 0.180 ± 1.663
[T UHA, d=200, K=64]  logZ = -26.983 ± 10.888
[T UHA, d=200, K=256]  logZ = -35.848 ± 5.678
[T UHA, d=500, K=64]  logZ = -58.248 ± 10.049
[T UHA, d=500, K=256]  logZ = -76.936 ± 11.134


In [25]:
import jax
import jax.numpy as jnp
import jax.random as random
from functools import partial

def leapfrog_step_mass(x, p, step_size, log_intermediate, mass_matrix_inv):
    """
    One leapfrog step with custom mass matrix.
    """
    def potential_energy(z):
        return -log_intermediate(z)
    
    gradient_x = jax.grad(potential_energy)(x)
    p_half = p - 0.5 * step_size * gradient_x
    x_new = x + step_size * mass_matrix_inv @ p_half
    gradient_x_new = jax.grad(potential_energy)(x_new)
    p_new = p_half - 0.5 * step_size * gradient_x_new
    
    return x_new, p_new

def leapfrog_integration_mass(x, p, step_size, n_leapfrog, log_intermediate, mass_matrix_inv):
    """
    n_leapfrog steps of leapfrog integration with mass matrix.
    """
    def body_fn(_, state):
        x, p = state
        return leapfrog_step_mass(x, p, step_size, log_intermediate, mass_matrix_inv)
    
    return jax.lax.fori_loop(0, n_leapfrog, body_fn, (x, p)) 

def compute_mass_matrix_from_samples(samples):
    """
    Compute empirical mass matrix from samples.
    Uses Welford's online algorithm for numerical stability.
    """
    mean = jnp.mean(samples, axis=0)
    centered = samples - mean[None, :]
    cov = jnp.mean(centered[:, :, None] * centered[:, None, :], axis=0)
    # Add small diagonal term for stability
    eps = 1e-5
    return cov + eps * jnp.eye(cov.shape[0])

def adapt_step_size(accept_stats, target_accept=0.65, adaptation_rate=0.1):
    """
    Adapt step size based on acceptance statistics.
    Uses Robbins-Monro algorithm.
    """
    log_step_size = jnp.log(step_size)
    new_log_step_size = log_step_size + adaptation_rate * (accept_stats - target_accept)
    return jnp.exp(new_log_step_size) 




@partial(jax.jit, static_argnums=(0, 1, 3, 5, 6))
def uha_single_chain(
    log_target_fn,
    log_initial_fn,
    rng_key,
    d,
    step_size,
    K,
    n_leapfrog,
    damping_coeff,
    mass_matrix=None,
):
    """
    Single chain of UHA with improved numerical stability.
    """
    if mass_matrix is None:
        mass_matrix = jnp.eye(d)
    mass_matrix_inv = jnp.linalg.inv(mass_matrix)
    
    # Schedule
    betas = jnp.linspace(0.0, 1.0, K+1)
    
    # Initial samples with scaled mass matrix
    key1, key2 = random.split(rng_key)
    x = random.multivariate_normal(key1, jnp.zeros(d), mass_matrix)
    p = random.multivariate_normal(key2, jnp.zeros(d), mass_matrix)
    
    # Initialize log weight with stable computation
    log_det_mass = jnp.linalg.slogdet(mass_matrix)[1]  # More stable than log(det())
    kinetic_energy = 0.5 * p @ mass_matrix_inv @ p
    log_w = -log_initial_fn(x) - kinetic_energy - 0.5 * d * jnp.log(2 * jnp.pi) - 0.5 * log_det_mass

    def scan_body(carry, beta_k):
        x, p, log_w, key = carry
        
        def log_intermediate(z):
            return jnp.where(
                beta_k > 0.999,
                log_target_fn(z),  # Avoid numerical issues near beta=1
                beta_k * log_target_fn(z) + (1 - beta_k) * log_initial_fn(z)
            )
        
        # Momentum refresh with careful scaling
        key, subkey = random.split(key)
        noise = random.multivariate_normal(subkey, jnp.zeros(d), mass_matrix)
        p_mean = damping_coeff * p
        p_new = p_mean + jnp.sqrt(1. - damping_coeff**2) * noise
        
        # Stable momentum bridging correction
        log_num = -0.5 * jnp.sum(jnp.square(mass_matrix_inv @ (p - damping_coeff * p_new)))
        log_den = -0.5 * jnp.sum(jnp.square(mass_matrix_inv @ (p_new - damping_coeff * p)))
        log_w += (log_num - log_den)
        
        # Scaled leapfrog step size
        effective_step_size = step_size / jnp.sqrt(d)
        x_new, p_new = leapfrog_integration_mass(
            x, p_new, effective_step_size, n_leapfrog, log_intermediate, mass_matrix_inv
        )
        
        return (x_new, p_new, log_w, key), None

    # Run chain
    (x, p, log_w, _), _ = jax.lax.scan(
        scan_body,
        (x, p, log_w, rng_key),
        betas[1:]
    )
    
    # Final correction with stable computation
    kinetic_energy = 0.5 * p @ mass_matrix_inv @ p
    log_w += log_target_fn(x) - kinetic_energy - 0.5 * d * jnp.log(2 * jnp.pi) - 0.5 * log_det_mass
    
    return x, p, log_w

def run_parallel_uha_experiment(
    make_log_target_fn,
    make_log_initial_fn,
    K_values=[64, 256],
    dims=[20, 200, 500],
    damping=0.9,
    n_leapfrog=5,
    n_chains=10,
    make_mass_matrix_fn=None  # Added parameter
):
    """
    Parallel UHA experiments with dimension-aware parameters.
    
    Args:
        make_mass_matrix_fn: Optional function to create custom mass matrix.
                           If None, uses default stable mass matrix.
    """
    if make_mass_matrix_fn is None:
        make_mass_matrix_fn = lambda d: jnp.diag(jnp.linspace(0.1, 2.0, d))
    
    results = {}
    
    for d in dims:
        # Dimension-dependent step size
        step_size = 0.1 / jnp.sqrt(d)
        
        log_target_fn = make_log_target_fn(d)
        log_initial_fn = make_log_initial_fn(d)
        mass_matrix = make_mass_matrix_fn(d)
        
        for K in K_values:
            parallel_uha = jax.vmap(
                lambda key: uha_single_chain(
                    log_target_fn, log_initial_fn, key, d,
                    step_size, K, n_leapfrog, damping, mass_matrix
                )
            )
            
            keys = random.split(random.PRNGKey(0), n_chains)
            _, _, log_weights = parallel_uha(keys)
            
            # Stable mean computation
            max_log_w = jnp.max(log_weights)
            shifted_weights = jnp.exp(log_weights - max_log_w)
            mean_logZ = jnp.log(jnp.mean(shifted_weights)) + max_log_w
            sem_logZ = jnp.std(log_weights) / jnp.sqrt(n_chains)
            
            results[(d, K)] = (float(mean_logZ), float(sem_logZ))
    
    return results

# Example usage with custom mass matrix
if __name__ == "__main__":
    def make_diagonal_mass_matrix(d):
        # Example: diagonal mass matrix with increasing values
        return jnp.diag(jnp.exp(jnp.linspace(0, 1, d)))
    
    print("======== Parallel UHA on Gaussian Mixture ========")
    gm_results = run_parallel_uha_experiment(
        make_log_target_fn=lambda d: make_gaussian_mixture_logpdf(d, num_components=8, var=1.0),
        make_log_initial_fn=lambda d: make_std_normal_logpdf(d, var=9.0),
        make_mass_matrix_fn=make_diagonal_mass_matrix,  # Custom mass matrix
        K_values=[64, 256],
        dims=[20, 200, 500],
        damping=0.9,
        n_leapfrog=5,
        n_chains=10  # Number of parallel chains
    )
    
    for (d, K), (mean_logZ, sem_logZ) in gm_results.items():
        print(f"[GMM UHA, d={d}, K={K}] logZ = {mean_logZ:.3f} ± {sem_logZ:.3f}")

[GMM UHA, d=20, K=64] logZ = -104.276 ± 6.549
[GMM UHA, d=20, K=256] logZ = -74.712 ± 4.601
[GMM UHA, d=200, K=64] logZ = -1452.594 ± 17.588
[GMM UHA, d=200, K=256] logZ = -1443.069 ± 17.152
[GMM UHA, d=500, K=64] logZ = -3762.187 ± 25.924
[GMM UHA, d=500, K=256] logZ = -3746.924 ± 27.925


In [None]:
@partial(jax.jit, static_argnums=(0, 1, 3, 5, 6))
def uha_adaptation_phase(
    log_target_fn,
    log_initial_fn,
    rng_key,
    d,
    init_step_size,
    K,
    n_leapfrog,
    damping_coeff,
    n_adaptation_steps=1000,
):
    """
    Run adaptation phase to tune step size and mass matrix.
    """
    # Initialize
    mass_matrix = jnp.eye(d)
    step_size = init_step_size
    samples = []
    accept_stats = []
    
    def adaptation_step(state, _):
        x, p, key, step_size, mass_matrix = state
        
        # Generate proposal
        key, subkey = random.split(key)
        x_prop, p_prop, log_w = uha_single_chain(
            log_target_fn, log_initial_fn, subkey, d,
            step_size, K, n_leapfrog, damping_coeff, mass_matrix
        )
        
        # Accept/reject step (simplified for adaptation)
        key, subkey = random.split(key)
        accept_prob = jnp.minimum(1.0, jnp.exp(log_w))
        accepted = random.bernoulli(subkey, accept_prob)
        
        # Update statistics
        x_new = jnp.where(accepted, x_prop, x)
        samples.append(x_new)
        accept_stats.append(accepted)
        
        # Adapt step size and mass matrix
        if len(samples) >= 100:  # Wait for burn-in
            step_size = adapt_step_size(jnp.mean(accept_stats[-100:]))
            mass_matrix = compute_mass_matrix_from_samples(jnp.stack(samples[-100:]))
        
        return (x_new, p_prop, key, step_size, mass_matrix), None
    
    # Run adaptation
    init_state = (jnp.zeros(d), jnp.zeros(d), rng_key, init_step_size, mass_matrix)
    final_state, _ = jax.lax.scan(adaptation_step, init_state, jnp.arange(n_adaptation_steps))
    
    return final_state[3], final_state[4]  # Return tuned step_size and mass_matrix
