In [None]:
%reload_ext autoreload
%autoreload 2
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.stats import multivariate_normal
from functools import partial

import matplotlib.pyplot as plt

## Define our sampler

**Target and Hamiltonian.** We wish to sample from a posterior $\pi(q)$ on position $q \in \mathbb{R}^d$. HMC augments state with momentum $p \in \mathbb{R}^d$ and samples from the joint density
$$
\pi(p, q) \propto \exp(-H(p,q)), \qquad H(p,q) = K(p) + U(q).
$$

**Kinetic and potential energy:**
$$
K(p) = \frac{1}{2} p^\top M^{-1} p,
$$
$$
U(q) = -\log \pi(q).
$$

**Notation.** 

We write $\mathcal{N}(x; \mu, \Sigma)$ for the Gaussian with mean
$\mu$ and covariance matrix $\Sigma$. 

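Since $K(p) = \frac{1}{2} p^\top M^{-1} p$, the conditional distribution of momentum given position is $\pi(p \mid q) = \mathcal{N}(p; 0, M)$: mean $0$, covariance $M$, where $M$ is the (positive-definite) mass matrix. This conditional does not depend on $q$, so the joint factorizes as $\pi(p, q) = \pi(p \mid q)\,\pi(q)$ and marginalizing over $p$ returns the target $\pi(q)$ unchanged.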

**Transition.** One HMC step: (1) draw $p \sim \mathcal{N}(0, M)$; (2) integrate Hamilton's equations for $(q,p)$ with a reversible, volume-preserving integrator (leapfrog) for $L$ steps of size $\Delta t$; (3) optionally apply a Metropolis correction on $H$ (omitted in this demo). The code below implements $H$, $K$, $U$, and the leapfrog integrator.

In [None]:
# for visual purposes, no metropolis correction! see the metropolis notebook.
class hmc():
    """Bare-bones Hamiltonian Monte Carlo sampler (leapfrog only).

    Each transition draws a fresh momentum p ~ N(0, M) and integrates
    Hamilton's equations with ``L`` leapfrog steps of size ``dt``.  No
    Metropolis accept/reject step is applied: the trajectory endpoint is
    always kept as the next state.

    Args:
        dim: dimension of the position (and momentum) vector.
        logpi_q: callable returning log pi(q).  It is differentiated with
            ``jax.grad``, so it must return a scalar for a single ``q``.
        mass: (dim, dim) mass matrix M, used as the momentum covariance.
        PRNGkey: JAX PRNG key; it is split on every momentum draw.
        dt: leapfrog step size.
        L: number of leapfrog steps per transition.
    """

    def __init__(self, dim, logpi_q, mass, PRNGkey, dt=0.1, L=10):
        # Basic attr.
        self.dim = dim
        self.dt = dt
        self.L = L
        self.PRNGkey = PRNGkey

        # Position-space log posterior; U(q) = -log pi(q), so the force
        # -dU/dq is simply the gradient of log pi.
        self.logpi_q = logpi_q
        self.negdV = grad(self.logpi_q)

        # Mass matrix M (momentum covariance) and its inverse (used for dq/dt).
        self.mass = jnp.asarray(mass)
        self.massinv = jnp.linalg.inv(self.mass)

    def logpi_p_q(self, p, q):
        """Return log pi(p | q) = log N(p; 0, M).

        Equals -K(p) up to an additive constant:
        -dim/2*log(2 pi) - 1/2*log|M| - (p^T M^-1 p) / 2.
        ``q`` is accepted for symmetry but not used (momentum is independent
        of position here).
        """
        # The covariance of p is M itself (not M^-1), matching the momentum
        # draw in _draw_momentum and K(p) = p^T M^-1 p / 2.
        return multivariate_normal.logpdf(p,
                                          mean=jnp.zeros(self.dim),
                                          cov=self.mass)

    # Construct the Hamiltonian
    def H(self, p, q):
        """Return H(p, q) = -log pi(p|q) - log pi(q).

        Note the argument order: momentum first, position second.  The
        Gaussian normalising constant of the momentum term is included, so
        the value differs from K(p) + U(q) by an additive constant.
        """
        return -self.logpi_p_q(p, q) - self.logpi_q(q)

    # ----------------
    # Define the leapfrog integrator
    # `self` is passed as a static argument, so attributes such as dt and
    # massinv are read when the function is traced for this instance.
    @partial(jit, static_argnums=(0,))
    def _leapfrog_incre(self, q0, p0):
        """Advance (q0, p0) by one leapfrog step; return (q, p)."""
        # update momentum by half time step
        p_halfdt = p0 + self.dt/2*self.negdV(q0)
        # update position by full time step
        q_dt = q0 + self.dt * self.massinv@p_halfdt
        # complete the momentum update with the second half step
        p_dt = p_halfdt + self.dt/2*self.negdV(q_dt)
        return q_dt, p_dt

    def _leapfrog(self, q0, p0):
        """Integrate ``L`` leapfrog steps and return the final (q, p)."""
        for _ in range(self.L):
            q0, p0 = self._leapfrog_incre(q0, p0)
        return q0, p0

    def _draw_momentum(self):
        """Split the stored PRNG key and draw p ~ N(0, M)."""
        self.PRNGkey, subkey = random.split(self.PRNGkey)
        return random.multivariate_normal(key=subkey,
                                          mean=jnp.zeros(self.dim),
                                          cov=self.mass,
                                          method='cholesky')

    def _stack_history(self, states):
        """Stack a list of (dim,) states into an (n, dim) array."""
        if not states:
            # Nothing recorded (zero steps requested): keep the 2-D shape.
            return jnp.zeros((0, self.dim))
        return jnp.stack(states)

    # Generate 1 sample
    def next_sample(self, q0):
        """Run one HMC transition from ``q0`` and return the new position."""
        p0 = self._draw_momentum()
        qf, _ = self._leapfrog(q0, p0)
        return qf

    # Generate N samples
    def next_Nsample(self, q0, N):
        """Return an (N, dim) array of chain states.

        The state is recorded *before* each transition, so row 0 is the
        initial ``q0`` and the state reached by the N-th transition is not
        included.
        """
        chain = []
        for _ in range(N):
            chain.append(q0)
            q0 = self.next_sample(q0)
        return self._stack_history(chain)

    # ----------------
    # The verbose integrator
    def _leapfrog_verbose(self, q0, p0):
        """Integrate ``L`` leapfrog steps, keeping the path.

        Returns:
            (qf, pf, qs, ps): final position and momentum, plus (L, dim)
            arrays of the positions and momenta recorded *before* each step
            (row 0 is the starting state; the final state is only returned
            as qf/pf).
        """
        q_path, p_path = [], []
        for _ in range(self.L):
            q_path.append(q0)
            p_path.append(p0)
            q0, p0 = self._leapfrog_incre(q0, p0)
        # return not only the final state but also the history
        return q0, p0, self._stack_history(q_path), self._stack_history(p_path)

    # Generate 1 sample
    def next_sample_verbose(self, q0):
        """Run one HMC transition from ``q0``; return (qf, pf, qs, ps)."""
        p0 = self._draw_momentum()
        return self._leapfrog_verbose(q0, p0)

## Define our posterior function

**Test posterior (Gaussian).** We use a 2D normal so that the sampler and trajectories are easy to see:
$$
\log \pi(q) = -\frac{d}{2}\log(2\pi) - \frac{1}{2}\log|\Sigma| - \frac{1}{2}(q - \mu)^\top \Sigma^{-1}(q - \mu).
$$
Here $\mu = (0,0)^\top$ and $\Sigma = \operatorname{diag}(0.8, 1)$. 

The function `logpi_q(q)` returns this $\log \pi(q)$.

In [None]:
def logpi_q(q):
    """Log-density of the 2D Gaussian test posterior N(0, diag(0.8, 1)) at q."""
    mu = jnp.array([0,0])
    sigma = jnp.array([[0.8,0],[0,1]])
    return multivariate_normal.logpdf(q, mu, sigma)

In [None]:
# Figure below: Log-posterior over (q1,q2) plane; elongated contours match Sigma = diag(0.8,1).
x = jnp.linspace(-3, 3, 100)
y = jnp.linspace(-3, 3, 100)
x_grid, y_grid = jnp.meshgrid(x, y)
# Pack the grid as (ny, nx, 2) so logpi_q receives one 2-vector per grid node.
grid_points = jnp.stack([x_grid, y_grid], axis=-1)
z = logpi_q(grid_points)
h = plt.contourf(x_grid, y_grid, z)
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

**Leapfrog integrator.** Hamilton's equations:
$$
\dot{q} = M^{-1} p,
$$
$$
\dot{p} = -\nabla U(q).
$$
One step from $(q_0, p_0)$ with step size $\Delta t$:
$$
p_{1/2} = p_0 + \frac{\Delta t}{2}\, (-\nabla U(q_0)),
$$
$$
q_1 = q_0 + \Delta t\, M^{-1} p_{1/2},
$$
$$
p_1 = p_{1/2} + \frac{\Delta t}{2}\, (-\nabla U(q_1)).
$$
We apply this $L$ times per HMC step; total simulated time is $L \Delta t$. 

In the code, `negdV(q)` is $-\nabla U(q)$; `_leapfrog_incre` performs one such
triple update; `_leapfrog` chains $L$ steps. (Note: `negdV` in the sampler is
$-\nabla U(q) = \nabla_q \log \pi(q)$, which for the Gaussian test posterior
equals $-\Sigma^{-1}(q - \mu)$.)

## Leapfrog integration

In [None]:
# plotting helper func.
def quiver_2dline(q, color='k'):
    """Draw a 2D path as a chain of arrows, one arrow per consecutive pair.

    q is laid out as [[x1, y1], [x2, y2], ...]; each arrow starts at point k
    and ends exactly at point k+1.
    """
    xs, ys = q.T
    tail_x, tail_y = xs[:-1], ys[:-1]
    step_x = xs[1:] - tail_x
    step_y = ys[1:] - tail_y
    # xy units/angles with scale=1 draw the arrows in data coordinates
    plt.quiver(tail_x, tail_y, step_x, step_y,
               scale_units='xy', angles='xy', scale=1, color=color)

In [None]:
# Figure below: One HMC trajectory from q0=(0,0); contour with leapfrog path. No Metropolis so endpoint always accepted.
# dt is not passed, so the sampler's default step size (0.1) is used with L = 50 steps.
sampler = hmc(dim=2,
              logpi_q=logpi_q,
              mass=jnp.eye(2),
              PRNGkey=random.PRNGKey(1),
              L = 50)
# generate 1 point starting from q0 = [0,0]
# qs/ps hold the states recorded before each leapfrog step; qf/pf are the endpoint
qf, pf, qs, ps = sampler.next_sample_verbose(q0=jnp.zeros(2))
print('first sample:', qf)
# x_grid, y_grid and z come from the log-posterior contour cell above
h = plt.contourf(x_grid, y_grid, z)
quiver_2dline(qs)
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

## Tracing the trajectories (no Metropolis correction)
- Notice the momentum resampling at each step.

**One-step transition (no Metropolis).** Each step: draw $p_0 \sim \mathcal{N}(0, M)$, then $(q', p') = \text{Leapfrog}^L(q, p_0)$. The next state is $q'$; momentum is discarded and resampled at the next step. Leapfrog conserves $H$ only approximately, so without a Metropolis accept/reject step the discretization error is never corrected and the chain samples $\pi(q)$ only approximately.

**Mass matrix $M$.** The kinetic energy is $K(p) = \frac{1}{2} p^\top M^{-1} p$. In the leapfrog update, $q$ is updated by $\Delta t\, M^{-1} p_{1/2}$, so a larger $M$ (larger inertia) yields smaller position updates per step. Scaling $M$ changes the trajectory shape and effective step size in $q$; the code uses `mass` for $M$ and `massinv` for $M^{-1}$.

**Integration length $L$.** The trajectory length in simulated time is $L \Delta t$. Larger $L$ gives longer paths and potentially larger moves in $q$ per HMC step; too large $L$ can loop back and reduce acceptance (when Metropolis is used).

In [None]:
# Figure below: Four consecutive HMC steps; each path from previous endpoint with new random momentum. Chain drifts over posterior.
sampler = hmc(dim=2,
              logpi_q=logpi_q,
              mass=jnp.eye(2),
              PRNGkey=random.PRNGKey(1),
              dt = 0.2,
              L = 8)

h = plt.contourf(x_grid, y_grid, z)

# the chain starts at [0,0]; every later step starts where the previous one ended
qf = jnp.zeros(2)
for i in range(4):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    print(f'q_{i}:', qf)
    # one colour per HMC step
    quiver_2dline(qs,color=f'C{i}')
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

In [None]:
# Figure below: H along leapfrog path for five trajectories; wiggles show non-conservation (no Metropolis).
# Reuses the sampler from the previous cell, so its PRNG state continues from there.
qf = jnp.zeros(2)
for i in range(5):
    # advance the chain by one HMC step from the current endpoint qf (first start: [0,0])
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    # H is declared as H(p, q): momentum first, position second
    h = sampler.H(ps, qs)
    plt.plot(h,color=f'C{i}')
plt.xlabel('LF integration step')
plt.ylabel('H')
plt.show()


In [None]:
# Figure below (top): Mass M=100I, five trajectories; heavier mass gives shorter, stiffer paths. (bottom): H along path.
mass = jnp.eye(2)*100
sampler = hmc(dim=2,
              logpi_q=logpi_q,
              mass=mass,
              PRNGkey=random.PRNGKey(1),
              dt = 0.2,
              L = 30)

h = plt.contourf(x_grid, y_grid, z)

# top panel: the chain starts at [0,0]; each step continues from the previous endpoint
qf = jnp.zeros(2)
for i in range(5):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    print(f'q_{i}:', qf)
    quiver_2dline(qs,color=f'C{i}')
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

# bottom panel: five further trajectories (restarted from [0,0]) for the energy trace
qf = jnp.zeros(2)
for i in range(5):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    # H is declared as H(p, q): momentum first, position second
    h = sampler.H(ps, qs)
    plt.plot(h,color=f'C{i}')
plt.xlabel('LF integration step')
plt.ylabel('H')
plt.show()

#### Integration error and Metropolis correction
- Two cases with the same total integration time, $L \Delta t = 60$.

In [None]:
# Figure below (top): Large dt, L*dt=60 fixed; fewer, coarser steps, jagged paths. (bottom): H along path; strong energy drift.
mass = jnp.eye(2)
sampler = hmc(dim=2,
              logpi_q=logpi_q,
              mass=mass,
              PRNGkey=random.PRNGKey(1),
              dt = 1,
              L = 60)

h = plt.contourf(x_grid, y_grid, z)

# top panel: the chain starts at [1,0]; each step continues from the previous endpoint
qf = jnp.array([1.0,0.0])
for i in range(5):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    print(f'q_{i}:', qf)
    quiver_2dline(qs,color=f'C{i}')
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

# bottom panel: five further trajectories (restarted from [0,0]) for the energy trace
qf = jnp.zeros(2)
for i in range(5):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    # H is declared as H(p, q): momentum first, position second
    h = sampler.H(ps, qs)
    plt.plot(h,color=f'C{i}')
plt.xlabel('LF integration step')
plt.ylabel('H')
plt.show()

In [None]:
# Figure below (top): Small dt, same L*dt=60; many small steps, smoother paths. (bottom): H along path; flatter than with dt=1.
mass = jnp.eye(2)
sampler = hmc(dim=2,
              logpi_q=logpi_q,
              mass=mass,
              PRNGkey=random.PRNGKey(1),
              dt = 0.1,
              L = 600)

h = plt.contourf(x_grid, y_grid, z)

# top panel: the chain starts at [1,0]; each step continues from the previous endpoint
qf = jnp.array([1.0,0.0])
for i in range(5):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    print(f'q_{i}:', qf)
    quiver_2dline(qs,color=f'C{i}')
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

# bottom panel: five further trajectories (restarted from [0,0]) for the energy trace
qf = jnp.zeros(2)
for i in range(5):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    # H is declared as H(p, q): momentum first, position second
    h = sampler.H(ps, qs)
    plt.plot(h,color=f'C{i}')
plt.xlabel('LF integration step')
plt.ylabel('H')
plt.show()

## NumPyro example

**Same target.** 

We sample the same target

- $\pi(q) = \mathcal{N}(q; \mu, \Sigma)$ with 

- $\mu = (0,0)^\top$ and 

- $\Sigma = \operatorname{diag}(0.8, 1)$. 

The model is defined by a single `numpyro.sample("q", ...)`; with no observed
data this is the posterior. 

NumPyro builds $\log \pi(q)$ and $\nabla_q \log \pi(q)$ from the distribution, then runs HMC (or NUTS) with leapfrog and Metropolis.

**Parameterization: Cholesky vs covariance.** 

We pass the Cholesky factor $L$ (lower triangular) with $\Sigma = L L^\top$
rather than $\Sigma$ itself. 

A mean-zero vector is drawn by
- sampling $z \sim \mathcal{N}(0, I)$, then
- setting $q = L z$.

The log-density needs two ingredients:
- the log-determinant, $\log \det \Sigma = 2 \sum_i \log L_{ii}$;
- the quadratic form $q^\top \Sigma^{-1} q = \| L^{-1} q \|^2$, which is computed
by solving $L u = q$ (forward solve) for $u$ and then taking $\|u\|^2$.

(In this section $L$ denotes the Cholesky factor, not the number of leapfrog steps.)

In high dimensions this is preferable: 

- (1) $L$ has only $d(d+1)/2$ free entries and guarantees $\Sigma$ is positive
definite; 

- (2) no explicit $d \times d$ matrix inversion; 

- (3) forward/backward solves with triangular $L$ are numerically stable and
$O(d^2)$ instead of $O(d^3)$ for matrix inversion; 

- (4) gradients for HMC stay in the $L$-space and avoid building full $\Sigma$. Using `scale_tril=L` in NumPyro follows this parameterization.

In [None]:
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMC, MCMC

# Same target: q ~ N(mu, Sigma), mu=(0,0), Sigma = diag(0.8, 1); scale_tril = Cholesky L with Sigma = L L^T
def model():
    """NumPyro model with a single latent site "q" ~ N(0, diag(0.8, 1)).

    The distribution is parameterized by the Cholesky factor of the
    covariance (``scale_tril``) rather than by the covariance itself.
    """
    target_cov = jnp.array([[0.8, 0.0], [0.0, 1.0]])
    chol = jnp.linalg.cholesky(target_cov)
    target = dist.MultivariateNormal(jnp.zeros(2), scale_tril=chol)
    return numpyro.sample("q", target)

# HMC kernel: only the step size is passed here; the number of leapfrog steps is
# not set and is left to NumPyro's defaults.
# NOTE(review): confirm against the NumPyro HMC docs how the step count is derived
# and whether the step size is adapted during warmup.
# (The original note remarks that NUTS is what is normally used in practice.)
kernel = HMC(model, step_size=0.2)
# 500 warmup iterations, then 1000 retained draws
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
rng = random.PRNGKey(0)
mcmc.run(rng)

samples = mcmc.get_samples()
q_numpyro = samples["q"]  # presumably shape (1000, 2): num_samples draws of a 2-vector — verify
mcmc.print_summary()

In [None]:
# Overlay NumPyro samples on the same Gaussian posterior contour
_mean = jnp.zeros(2)
_cov = jnp.array([[0.8, 0.0], [0.0, 1.0]])
_x = jnp.linspace(-3, 3, 100)
_y = jnp.linspace(-3, 3, 100)
_xg, _yg = jnp.meshgrid(_x, _y)
# One row per grid node: column 0 holds q1, column 1 holds q2.
_points = jnp.column_stack((_xg.ravel(), _yg.ravel()))
# Map the log-density over the rows, holding mean and covariance fixed.
_z = vmap(multivariate_normal.logpdf, in_axes=(0, None, None))(_points, _mean, _cov)
_z = _z.reshape(_xg.shape)
plt.contourf(_xg, _yg, _z)
plt.scatter(q_numpyro[:, 0], q_numpyro[:, 1], alpha=0.3, s=5, c="C1")
plt.axis("scaled")
plt.xlabel("q1")
plt.ylabel("q2")
plt.colorbar()
plt.show()

## Mixed Gaussian

**Mixture posterior.** We switch to a three-component Gaussian mixture:
$$
\pi(q) = \sum_{k=1}^{3} w_k\, \mathcal{N}(q; \mu_k, \Sigma_k), \qquad w_k = 1/3,\ \Sigma_k = 0.2\, I.
$$
Means $\mu_k$ are $(0,1)$, $(2,0)$, $(-1,-2)$. The redefined `logpi_q(q)`
returns $\log \pi(q) = \log \sum_k w_k \mathcal{N}(q;\mu_k,\Sigma_k)$; gradients
are used by the leapfrog integrator as before.


In [None]:
# Figure below: Posterior density for three-component mixture; three modes near (0,1), (2,0), (-1,-2).
def logpi_q(q):
    """Log-density of the three-component Gaussian mixture posterior.

    pi(q) = sum_k w_k N(q; mu_k, 0.2 I) with w_k = 1/3 and means
    (0,1), (2,0), (-1,-2).

    The sum is evaluated in log space (log-sum-exp) so that points far from
    every mode give a finite value and a finite gradient instead of
    log(0) = -inf.  ``q`` may be a single 2-vector or an array whose last
    axis has length 2.
    """
    # Function-scope import: logsumexp is not among the notebook's top-level imports.
    from jax.scipy.special import logsumexp

    means = ((0.0, 1.0), (2.0, 0.0), (-1.0, -2.0))
    cov = 0.2 * jnp.eye(2)
    log_weight = jnp.log(1.0 / len(means))
    # Leading axis indexes the mixture component.
    log_terms = jnp.stack(
        [multivariate_normal.logpdf(q, jnp.array(mu), cov) for mu in means],
        axis=0,
    ) + log_weight
    return logsumexp(log_terms, axis=0)

x = jnp.linspace(-5, 5, 100)
y = jnp.linspace(-5, 5, 100)
x_grid, y_grid = jnp.meshgrid(x, y)
# Pack the grid as (ny, nx, 2) so logpi_q receives one 2-vector per grid node,
# then exponentiate to plot the density rather than the log-density.
grid_points = jnp.stack([x_grid, y_grid], axis=-1)
z = jnp.exp(logpi_q(grid_points))
h = plt.contourf(x_grid, y_grid, z)
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

In [None]:
# Figure below: Fifteen HMC steps on mixture; whether chain reaches all three modes.
sampler = hmc(dim=2,
              logpi_q=logpi_q,
              mass=jnp.eye(2),
              PRNGkey=random.PRNGKey(1),
              dt = 0.2,
              L = 15)

h = plt.contourf(x_grid, y_grid, z)

# the chain starts at [0,0]; every later step starts where the previous one ended
qf = jnp.zeros(2)
for i in range(15):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    print(f'q_{i}:', qf)
    # one colour label per HMC step
    quiver_2dline(qs,color=f'C{i}')
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

In [None]:
# Figure below: Twenty HMC steps with identity mass and longer trajectories (L = 30); compare coverage and mode-hopping.
mass = jnp.array([[1,0],[0,1]])
sampler = hmc(dim=2,
              logpi_q=logpi_q,
              mass=mass,
              PRNGkey=random.PRNGKey(1),
              dt = 0.2,
              L = 30)

h = plt.contourf(x_grid, y_grid, z)

# the chain starts at [0,0]; every later step starts where the previous one ended
qf = jnp.zeros(2)
for i in range(20):
    # advance the chain by one HMC step from the current endpoint qf
    qf, pf, qs, ps = sampler.next_sample_verbose(q0=qf)
    print(f'q_{i}:', qf)
    # one colour label per HMC step
    quiver_2dline(qs,color=f'C{i}')
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()