In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, random
import numpy as np
import matplotlib.pyplot as plt
import time

In [2]:
class HamiltonianMonteCarlo:
    """Hamiltonian Monte Carlo sampler with a fixed Gaussian kinetic energy.

    The sampler targets the density proportional to ``exp(-U(q))`` and uses
    the kinetic energy ``K(p) = 0.5 * p^T M^{-1} p``.

    Args:
        U: Callable mapping a position vector to the scalar potential energy.
        grad_U: Callable mapping a position vector to the gradient of ``U``.
        epsilon: Leapfrog step size.
        L: Number of leapfrog steps per proposal.
        M: Square mass matrix. It must be symmetric positive definite for the
            momentum draw (Cholesky factor) to be well defined.

    After :meth:`sample` has run, the attributes ``samples``, ``time``,
    ``acceptance_rate``, ``mean`` and ``covariance`` hold the results.
    """

    def __init__(self, U, grad_U, epsilon, L, M):
        self.U = U  # Potential energy function
        self.grad_U = grad_U  # Gradient of the potential energy function
        self.epsilon = epsilon  # Step size
        self.L = L  # Number of leapfrog steps
        self.M = M  # Mass matrix
        self.M_inv = jnp.linalg.inv(M)  # Inverse of the mass matrix
        # Lower Cholesky factor C with M = C C^T, so that C @ z ~ N(0, M) for
        # z ~ N(0, I). For a diagonal M this equals scaling z by sqrt(diag(M)).
        self._M_chol = jnp.linalg.cholesky(
            jnp.asarray(M, dtype=self.M_inv.dtype)
        )
        self.samples = None  # Placeholder for sampled points
        self.time = None  # Placeholder for sampling time
        self.acceptance_rate = None  # Placeholder for acceptance rate
        self.mean = None  # Placeholder for mean of samples
        self.covariance = None  # Placeholder for covariance of samples

    def leapfrog(self, q, p):
        """Integrate Hamilton's equations for ``L`` leapfrog steps.

        Args:
            q: Starting position.
            p: Starting momentum.

        Returns:
            Tuple ``(q, p)`` with the position and momentum at the end of the
            trajectory.
        """
        # Opening half step for the momentum.
        p = p - 0.5 * self.epsilon * self.grad_U(q)
        for step in range(self.L):
            # Full step for the position.
            q = q + self.epsilon * jnp.dot(self.M_inv, p)
            # Full momentum step, except after the last position update where
            # the closing half step below takes its place.
            if step < self.L - 1:
                p = p - self.epsilon * self.grad_U(q)
        # Closing half step for the momentum.
        p = p - 0.5 * self.epsilon * self.grad_U(q)
        return q, p

    def _kinetic(self, p):
        """Return the kinetic energy ``0.5 * p^T M^{-1} p``."""
        return 0.5 * jnp.dot(p, jnp.dot(self.M_inv, p))

    def sample(self, key, num_samples):
        """Run the chain and store the draws and summary statistics on ``self``.

        The chain starts at the origin. Every iteration appends one state to
        the chain: the proposal if it is accepted, otherwise the unchanged
        current state.

        Args:
            key: JAX PRNG key used for all random draws.
            num_samples: Number of chain states to generate (at least 1).

        Raises:
            ValueError: If ``num_samples`` is smaller than 1.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")

        start_time = time.perf_counter()
        samples = []
        accepted = 0

        q = jnp.zeros(self.M.shape[0])  # Initial position
        for _ in range(num_samples):
            # Independent keys for the momentum draw and the accept/reject
            # draw; reusing a single key would correlate the two.
            key, momentum_key, accept_key = random.split(key, 3)
            z = random.normal(momentum_key, shape=q.shape)
            current_p = jnp.dot(self._M_chol, z)  # p ~ N(0, M)
            current_q = q

            proposed_q, proposed_p = self.leapfrog(current_q, current_p)

            current_U = self.U(current_q)
            current_K = self._kinetic(current_p)
            proposed_U = self.U(proposed_q)
            proposed_K = self._kinetic(proposed_p)

            acceptance_prob = jnp.exp(
                current_U - proposed_U + current_K - proposed_K
            )
            if random.uniform(accept_key) < acceptance_prob:
                q = proposed_q
                accepted += 1
            else:
                # Stay at the current state; the rejected proposal must not
                # become the starting point of the next trajectory.
                q = current_q
            samples.append(q)

        end_time = time.perf_counter()
        self.samples = jnp.array(samples)
        self.time = end_time - start_time
        self.acceptance_rate = accepted / num_samples
        self.mean = jnp.mean(self.samples, axis=0)
        self.covariance = jnp.cov(self.samples, rowvar=False)
