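NUTS vs MCLMC

This notebook compares two BlackJAX samplers on a 2D correlated Gaussian target:
- NUTS: the No-U-Turn Sampler
- MAMS: Microcanonical Langevin Monte Carlo (MCLMC) with a Metropolis–Hastings adjustment step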

In [6]:
# Imports
import jax
import jax.numpy as jnp
import blackjax
import matplotlib.pyplot as plt
from datetime import datetime

# Set seed: keep it as a single named constant so every stochastic step in
# the notebook is reproducible from this one value.
SEED = 548
key = jax.random.PRNGKey(SEED)

In [7]:
# Record the library versions this notebook was executed with, one per line,
# so results can be tied back to a specific JAX / BlackJAX release.
print(
    f"JAX version: {jax.__version__}",
    f"Blackjax version: {blackjax.__version__}",
    sep="\n",
)

JAX version: 0.8.0
Blackjax version: 1.2.5


In [8]:
# Make 2d Gaussian target density
#
# Target distribution: a 2D correlated Gaussian
#   mean       mu    = [0, 0]
#   covariance Sigma = [[1, 0.8], [0.8, 1]]   (unit variances, correlation 0.8)

TARGET_MEAN = jnp.zeros(2)
TARGET_COV = jnp.array([[1.0, 0.8],
                        [0.8, 1.0]])


def logprob_fn(x, mean=TARGET_MEAN, cov=TARGET_COV):
    """
    Compute the log density of a multivariate Gaussian N(mean, cov) at x.

    With the default arguments this is the 2D correlated target above
    (mean [0, 0], covariance [[1, 0.8], [0.8, 1]]). Callers that pass only
    ``x`` (e.g. a sampler calling ``logprob_fn(position)``) get exactly that
    target; ``mean`` and ``cov`` allow the same function to describe other
    Gaussians.

    Parameters:
    -----------
    x : array-like, shape (d,)
        Point at which to evaluate the density (d = 2 for the defaults).
    mean : array-like, shape (d,), optional
        Mean vector. Defaults to TARGET_MEAN = [0, 0].
    cov : array-like, shape (d, d), optional
        Symmetric positive-definite covariance matrix. Defaults to TARGET_COV.

    Returns:
    --------
    jax.Array (0-d)
        log p(x) = -0.5 * (x - mu)^T Sigma^{-1} (x - mu)
                   - 0.5 * log|Sigma| - (d / 2) * log(2 pi)
    """
    x = jnp.asarray(x)
    mean = jnp.asarray(mean)
    cov = jnp.asarray(cov)
    dim = cov.shape[-1]

    diff = x - mean
    # Solve Sigma y = (x - mu) instead of forming Sigma^{-1} explicitly:
    # same quantity, better-conditioned and the idiomatic linear-algebra form.
    mahalanobis_sq = diff @ jnp.linalg.solve(cov, diff)
    log_det = jnp.linalg.slogdet(cov)[1]

    # Normalizer uses the actual dimension rather than a hard-coded 2.
    return -0.5 * (mahalanobis_sq + log_det + dim * jnp.log(2 * jnp.pi))

# Test the function: at the origin (the mean) the quadratic term vanishes, so
# log p(0) = -log(2*pi) - 0.5 * log|Sigma|, with |Sigma| = 1 - 0.8**2 = 0.36.
test_point = jnp.array([0.0, 0.0])
test_logprob = logprob_fn(test_point)
expected_logprob = -jnp.log(2 * jnp.pi) - 0.5 * jnp.log(1.0 - 0.8**2)
assert jnp.allclose(test_logprob, expected_logprob, atol=1e-5), (test_logprob, expected_logprob)
print(f"Log probability at origin: {test_logprob:.4f}")

Log probability at origin: -1.3271
