In [None]:
%pip install --upgrade pip 
%pip install --upgrade jax 
%pip install "flax[all]"

In [2]:
import jax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.stats import t
import optax
import numpy as np

##############################################################################
# Target & Initial distributions
##############################################################################

def make_student_t_logpdf(d, df=3):
    """
    Build the log-density of a product of independent Student-t(df) marginals.

    `d` is accepted so the factory has the same shape as the other density
    factories in this file, but it is not used: the returned closure sums the
    per-coordinate log-pdfs of whatever array it receives.
    """
    def student_t_log_density(point):
        # Coordinates are i.i.d., so the joint log-density is the sum of the
        # univariate Student-t log-densities.
        per_coordinate = t.logpdf(point, df=df)
        return per_coordinate.sum()
    return student_t_log_density

def make_std_normal_logpdf(d):
    """
    Build the log-density of the standard normal N(0, I_d).

    The normalising constant uses the `d` given here, not the shape of the
    point passed to the returned closure.
    """
    def std_normal_log_density(x):
        log_norm_const = -0.5*d*jnp.log(2.*jnp.pi)
        squared_norm = jnp.sum(jnp.square(x))
        return log_norm_const - 0.5*squared_norm
    return std_normal_log_density

##############################################################################
# 1) Unadjusted Langevin AIS with a separate step-size parameter per step
##############################################################################

def unadjusted_langevin_ais_per_step(
    log_target_fn,
    log_initial_fn,
    d,
    betas,            # shape (K+1,)
    step_size_params, # shape (K,)
    rng_key
):
    """
    Run one annealed-importance-sampling chain with unadjusted Langevin moves.

    The chain starts from a standard-normal draw in R^d and, for k = 1..K,
    takes one Langevin step on the intermediate log-density
        beta_k * log_target_fn + (1 - beta_k) * log_initial_fn
    with step size exp(step_size_params[k-1]).

    Args:
        log_target_fn:    callable x -> scalar log-density of the target.
        log_initial_fn:   callable x -> scalar log-density of the initial
                          distribution.  NOTE: x0 is always sampled from a
                          standard normal, so this should be N(0, I_d).
        d:                dimension of the state.
        betas:            annealing schedule, shape (K+1,); betas[0] is unused
                          by the transitions.
        step_size_params: log step sizes, shape (K,) (at least K entries).
        rng_key:          JAX PRNG key.

    Returns:
        (x, log_w): the final state and the accumulated log importance weight
            -log_initial_fn(x0) + sum_k (log_B_k - log_F_k) + log_target_fn(x_K).

    The K transitions run through `jax.lax.scan` rather than a Python loop, so
    the step body is traced once instead of being unrolled K times when this
    function is jitted/differentiated.  The PRNG keys are split in the same
    order as a plain sequential loop would split them.
    """
    K = len(betas) - 1
    rng_key, subkey = random.split(rng_key)
    # x0 ~ pi0 (standard normal)
    x = random.normal(subkey, shape=(d,))
    # log_w = - log pi0(x0)
    log_w = -log_initial_fn(x)

    def log_normal_density(z, mean, var):
        # isotropic N(mean, var I)
        diff = z - mean
        sq   = jnp.sum(diff**2)
        d_   = diff.shape[0]
        return -0.5 * d_ * jnp.log(2.*jnp.pi*var) - 0.5 * (sq/var)

    def langevin_transition(carry, schedule_k):
        # One annealing step: propose with the forward Langevin kernel and
        # add log(backward kernel) - log(forward kernel) to the weight.
        key, x_prev, log_w_prev = carry
        beta_k, log_eta_k = schedule_k
        key, noise_key = random.split(key)

        # define gamma_k
        def log_intermediate(z):
            return beta_k*log_target_fn(z) + (1.-beta_k)*log_initial_fn(z)

        grad_intermediate = jax.grad(log_intermediate)

        # Step sizes are stored in log space so they stay positive.
        eta      = jnp.exp(log_eta_k)
        var      = 2.*eta
        fwd_mean = x_prev + eta*grad_intermediate(x_prev)
        noise    = random.normal(noise_key, shape=x_prev.shape)
        x_new    = fwd_mean + jnp.sqrt(var)*noise

        # weight update
        log_F = log_normal_density(x_new, mean=fwd_mean, var=var)
        bwd_mean = x_new + eta*grad_intermediate(x_new)
        log_B = log_normal_density(x_prev, mean=bwd_mean, var=var)

        return (key, x_new, log_w_prev + (log_B - log_F)), None

    if K > 0:
        # Cast the schedule to the state's dtype so the scan carry keeps a
        # fixed dtype across iterations.
        schedule = (
            jnp.asarray(betas, dtype=x.dtype)[1:],
            jnp.asarray(step_size_params, dtype=x.dtype)[:K],
        )
        (_, x, log_w), _ = jax.lax.scan(
            langevin_transition, (rng_key, x, log_w), schedule
        )

    log_w += log_target_fn(x)
    return x, log_w

##############################################################################
# 2) negative_logZ for training
##############################################################################

def negative_logZ(params, d, betas, log_target_fn, log_initial_fn, rng_key):
    """
    Training loss: the negated log importance weight of a single AIS run.

    `params` holds the log step sizes handed to
    `unadjusted_langevin_ais_per_step`; the final sample is discarded and
    only -log_w is returned.
    """
    ais_output = unadjusted_langevin_ais_per_step(
        log_target_fn, log_initial_fn, d, betas, params, rng_key
    )
    log_weight = ais_output[1]
    return -log_weight

##############################################################################
# 3) Training routine
##############################################################################

def train_step_sizes(
    d,
    K,
    log_target_fn,
    log_initial_fn,
    n_iters=300,
    lr=1e-2,
    seed=42
):
    """
    Fit one Langevin step size per annealing step with Adam.

    The K log step sizes all start at log(0.01) and are updated for `n_iters`
    iterations on `negative_logZ`, using a fresh PRNG key (derived from
    `seed`) at every iteration and a linear beta schedule on [0, 1].
    Progress is printed every 50 iterations.

    Returns the learned step sizes themselves (exp of the optimised
    parameters), shape (K,).
    """
    betas = jnp.linspace(0., 1., K+1)
    optimizer = optax.adam(lr)

    # Parameters live in log space so the step sizes stay positive.
    params = jnp.ones(K)*jnp.log(0.01)
    opt_state = optimizer.init(params)

    def loss_fn(log_step_sizes, rng_key):
        return negative_logZ(
            log_step_sizes, d, betas, log_target_fn, log_initial_fn, rng_key
        )

    loss_and_grad = jax.value_and_grad(loss_fn)

    @jax.jit
    def step(params, opt_state, rng_key):
        val, grads = loss_and_grad(params, rng_key)
        updates, next_opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), next_opt_state, val

    rng = random.PRNGKey(seed)
    for i in range(n_iters):
        rng, subkey = random.split(rng)
        params, opt_state, current_loss = step(params, opt_state, subkey)
        if i % 50 != 0:
            continue
        s = jnp.exp(params)
        print(f"[TRAIN] iter={i}, neg logZ={current_loss:.3f}, step sizes[0..3]={s[:3]}")

    # Convert back from log space to the actual step sizes.
    return jnp.exp(params)

##############################################################################
# 4) Final experiment procedure to replicate "paper" style results
##############################################################################

def run_experiment_student_t(
    dims = (20, 200, 500),
    Ks   = (64, 256),
    df   = 3,
    n_train_iters=300,
    n_eval_seeds=5,  # how many independent AIS runs are averaged at the end
    lr=1e-2
):
    """
    Train and evaluate per-step step sizes on i.i.d. Student-t targets.

    For each d in dims and K in Ks:
      1) Train the step sizes with `train_step_sizes` (training seed 0).
      2) Run AIS once for each evaluation seed 100 .. 100 + n_eval_seeds - 1
         and collect the log weights.
      3) Print mean +- std / sqrt(n_eval_seeds), where std is NumPy's
         default population standard deviation (ddof=0).

    Args:
        dims:          iterable of state dimensions to sweep.
        Ks:            iterable of annealing-step counts to sweep.
        df:            degrees of freedom of the Student-t target.
        n_train_iters: optimiser iterations per (d, K) configuration.
        n_eval_seeds:  number of evaluation runs per configuration.
        lr:            Adam learning rate used for training.

    Returns:
        dict mapping (d, K) -> (mean log weight, standard error).

    The sweep defaults are tuples rather than lists so the default objects
    cannot be mutated between calls; any iterable may still be passed.
    """
    results = {}
    for d in dims:
        print(f"\n=== Dimension d={d} ===")
        # Build the target & initial for dimension d
        log_target_fn  = make_student_t_logpdf(d, df=df)
        log_initial_fn = make_std_normal_logpdf(d)

        for K in Ks:
            print(f">>> K={K}, training per-step-size ...")
            # 1) Train
            best_step_sizes = train_step_sizes(
                d=d, K=K,
                log_target_fn=log_target_fn,
                log_initial_fn=log_initial_fn,
                n_iters=n_train_iters,
                lr=lr,
                seed=0
            )
            print(f"Trained step sizes (first few): {best_step_sizes[:3]} ...")

            # 2) Evaluate over multiple seeds
            betas = jnp.linspace(0., 1., K+1)
            # The AIS routine expects log step sizes; the conversion does not
            # depend on the seed, so do it once outside the loop.
            log_step_sizes = jnp.log(best_step_sizes)
            logZ_values = []
            for seed_eval in range(n_eval_seeds):
                rng = random.PRNGKey(seed_eval+100)  # offset from the training seed
                _, logw = unadjusted_langevin_ais_per_step(
                    log_target_fn  = log_target_fn,
                    log_initial_fn = log_initial_fn,
                    d             = d,
                    betas         = betas,
                    step_size_params=log_step_sizes,
                    rng_key       = rng
                )
                logZ_values.append(np.array(logw))
            logZ_values = np.array(logZ_values)
            mean_logZ = logZ_values.mean()
            sem_logZ  = logZ_values.std() / np.sqrt(n_eval_seeds)

            print(f"[RESULT] d={d}, K={K}, mean logZ={mean_logZ:.3f} ± {sem_logZ:.3f}")
            results[(d,K)] = (mean_logZ, sem_logZ)

    return results

##############################################################################
# "MAIN"
##############################################################################

if __name__ == "__main__":
    # Reduced sweep for a quicker run; use dims=[20, 200, 500] for the full one.
    experiment_config = dict(
        dims=[20, 200],
        Ks=[64, 256],
        df=3,
        n_train_iters=300,
        n_eval_seeds=3,
        lr=1e-2,
    )
    final_results = run_experiment_student_t(**experiment_config)

    # Summarise every (d, K) configuration as mean ± standard error.
    print("\nFinal results dictionary:")
    for config_key, stats in final_results.items():
        d, K = config_key
        m, s = stats
        print(f"d={d}, K={K} => logZ = {m:.3f} ± {s:.3f}")


KeyboardInterrupt: 

ImportError: _multiarray_umath failed to import

ImportError: numpy._core.umath failed to import