## Test HMC sampler

In [None]:
import torch
import hamiltorch
import matplotlib.pyplot as plt

In [None]:
def log_prob_func(x: torch.Tensor) -> torch.Tensor:
    """Return the unnormalized log-density of the 2D test target at ``x``.

    The two coordinates are read from the last axis of ``x``; any leading
    axes are preserved in the result, so a single point or a batch of
    points can be evaluated with one call.

    The value is ``sin(pi * x1) - 2 * (x1**2 + x2**2 - 2)**2``, floored
    at ``-1e9``.
    """
    first, second = x[..., 0], x[..., 1]
    # Signed distance (in squared-radius units) from the ring x1^2 + x2^2 = 2.
    radial_offset = first**2 + second**2 - 2.0
    ring_penalty = 2.0 * radial_offset**2
    modulation = torch.sin(torch.pi * first)
    # Lower bound only; presumably keeps far-away points at a finite value.
    return torch.clamp(modulation - ring_penalty, min=-1.00e09)

In [None]:
# Seed hamiltorch's random number generation before drawing any samples.
hamiltorch.set_random_seed(123)

# Dimensionality of the target and sampler settings.
ndim = 2
nsamp = 10_000
nsteps_per_samp = 5
stepsize = 0.25

# Initial position of the chain: the origin.
x0 = torch.zeros(ndim)

sampler_kwargs = dict(
    log_prob_func=log_prob_func,
    params_init=x0,
    num_samples=nsamp,
    step_size=stepsize,
    num_steps_per_sample=nsteps_per_samp,
)
samples = hamiltorch.sample(**sampler_kwargs)

# Stack the returned sequence of samples row-wise into a single tensor.
x = torch.vstack(samples)

In [None]:
bins = 64
xmax = 3.0

# Evaluate the (unnormalized) PDF at the centers of a regular bins x bins
# grid spanning [-xmax, xmax] along each axis.
grid_edges = 2 * [torch.linspace(-xmax, xmax, bins + 1)]
grid_coords = [0.5 * (e[:-1] + e[1:]) for e in grid_edges]
# With "ij" indexing the flattened points vary slowest along grid_coords[0].
grid_points = torch.stack(
    [c.ravel() for c in torch.meshgrid(*grid_coords, indexing="ij")],
    dim=-1,
)
# log_prob_func reads coordinates from the last axis, so the whole grid can
# be evaluated in one batched call instead of one Python call per point.
grid_values = torch.exp(log_prob_func(grid_points))
grid_values = grid_values.reshape((bins, bins))

# Left panel: 2D histogram of the samples in x.
# Right panel: the target density evaluated on the same grid.
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(4.5, 2))
axs[0].hist2d(x[:, 0], x[:, 1], bins=grid_edges)
# grid_values[i, j] belongs to (grid_coords[0][i], grid_coords[1][j]);
# pcolormesh expects rows along y, hence the transpose.
axs[1].pcolormesh(grid_coords[0], grid_coords[1], grid_values.T)
plt.show()