In [217]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from scipy.stats import norm, multivariate_normal

#### Run the HMC for the test function

In [222]:
# testing with the Banana likelihood in [-1,2]^2

def loglike(X):
    """Log-density of the banana-shaped test target (no normalising constant).

    Parameters
    ----------
    X : array-like
        Coordinates with the two components (x, y) on the last axis: shape
        (2,) for a single point, or (..., 2) for a batch of points.

    Returns
    -------
    numpy scalar or ndarray
        ``-0.25*(5*(0.2 - x))**2 - (20*(y/4 - x**4))**2`` evaluated per point;
        the result has shape ``X.shape[:-1]``.
    """
    X = np.asarray(X)
    # Index the last axis so a single point (2,) and a grid of points (N, 2)
    # go through the same code path.
    x, y = X[..., 0], X[..., 1]
    logpdf = -0.25*(5*(0.2 - x))**2 - (20*(y/4 - x**4))**2
    return logpdf

def grad_loglike(X):
    """Analytic gradient of the banana log-density at a single point.

    Parameters
    ----------
    X : sequence of two numbers
        The point (x, y), read as ``X[0]`` and ``X[1]``.

    Returns
    -------
    ndarray
        ``np.array([d/dx, d/dy])`` of
        ``-0.25*(5*(0.2 - x))**2 - (20*(y/4 - x**4))**2``.
    """
    x, y = X[0], X[1]
    # Shared factor (y/4 - x^4): the offset from the banana's ridge.
    ridge = y/4 - x**4
    d_dx = 12.5*(0.2 - x) + 3200*x**3 * ridge
    d_dy = -200*ridge
    return np.array([d_dx, d_dy])

def HMC(x0,num_samples,num_steps=30,step_size=0.01):
    """Sample from exp(loglike) with Hamiltonian Monte Carlo.

    Each iteration draws a fresh momentum, proposes a new position with
    ``evolve`` and accepts it with the Metropolis probability
    ``min(1, exp(H_state - H_new))``, where
    ``H(x, p) = -loglike(x) - p_dist.logpdf(p)``. The constant inside
    ``logpdf`` cancels in the difference, leaving the usual 0.5*p.p kinetic
    energy.

    Parameters
    ----------
    x0 : array-like of length 2
        Starting position; it is stored as the first row of the output.
    num_samples : int
        Number of HMC iterations to run.
    num_steps : int, optional
        Number of integrator steps per proposal, forwarded to ``evolve``.
    step_size : float, optional
        Integrator step size, forwarded to ``evolve``.

    Returns
    -------
    ndarray of shape (num_samples + 1, 2)
        ``x0`` followed by the chain state after every iteration (a rejected
        proposal repeats the previous state).

    Notes
    -----
    Randomness comes from NumPy's global RNG (``np.random.uniform`` and the
    default ``random_state`` of the scipy distribution).
    """
    p_dist = multivariate_normal(mean=[0,0],cov=np.eye(2)) # use to draw new momenta and compute kinetic energy

    x_state = np.asarray(x0, dtype=float)
    samples = [x_state]

    for _ in range(num_samples):
        # draw random momentum
        p_state = p_dist.rvs()

        # evolve from current state to new state using H
        x_new, p_new = evolve(x_state, p_state, num_steps, step_size)

        # compute transition probability
        H_state = -loglike(x_state) - p_dist.logpdf(p_state)
        H_new = -loglike(x_new) - p_dist.logpdf(p_new)
        log_a = H_state - H_new
        # A NaN energy (e.g. inf - inf after a diverging trajectory) must be
        # rejected; min(0.0, nan) would otherwise evaluate to 0.0 and accept.
        a = 0.0 if np.isnan(log_a) else np.exp(min(0.0, log_a))

        # do transition
        if (a > np.random.uniform(0,1)):
            x_state = x_new
        samples.append(x_state)

    return np.array(samples)


def evolve(x,p,num_steps,step_size,grad_fn=None):
    """Integrate Hamiltonian dynamics with the leapfrog scheme.

    The Hamiltonian is ``H(x, p) = -loglike(x) + 0.5*p.p``, so Hamilton's
    equations are ``dx/dt = p`` and ``dp/dt = +grad loglike(x)``. The momentum
    is therefore kicked *along* the log-likelihood gradient; subtracting it
    would integrate the inverted potential and push the trajectory away from
    the mode.

    Leapfrog is good at conserving energy, so the numerical error is low,
    which keeps the acceptance rate high.

    Parameters
    ----------
    x, p : ndarray
        Starting position and momentum. They are not modified; every update
        below builds a new array.
    num_steps : int
        Number of leapfrog steps.
    step_size : float
        Integration step size.
    grad_fn : callable, optional
        Gradient of the log-density, called as ``grad_fn(x)``. Defaults to
        the module-level ``grad_loglike``.

    Returns
    -------
    (x, p) : tuple
        Position and momentum after ``num_steps`` steps.
    """
    if grad_fn is None:
        grad_fn = grad_loglike
    if num_steps <= 0:
        return x, p

    half_step = step_size / 2
    # The gradient closing one step is evaluated at the same x as the one
    # opening the next, so compute it once per step and reuse it.
    grad = grad_fn(x)
    for _ in range(num_steps):
        p = p + half_step * grad
        x = x + step_size * p
        grad = grad_fn(x)
        p = p + half_step * grad
    return x, p


# Seed NumPy's global RNG so a fresh Restart & Run All reproduces the same
# starting point and the same accept/reject draws inside HMC.
np.random.seed(0)

# x0 choice below is to ensure good starting point
samples = HMC(0.02+0.01*np.random.randn(2),num_samples=10000,num_steps=25,step_size=0.001)


#### Plot function vs your samples

In [223]:
# Evaluation grid for the contour plot: 500 points along each axis,
# x in [-1, 1] and y in [-1, 2].
x, y = np.linspace(-1, 1, 500), np.linspace(-1, 2, 500)
xx, yy = np.meshgrid(x, y)
# One (x, y) pair per row, in the row-major order of the mesh.
grid = np.column_stack((xx.ravel(), yy.ravel()))


def loglike(X):
    """Log-density of the banana-shaped test target (no normalising constant).

    This definition replaces any earlier ``loglike`` in the kernel, so it
    indexes the last axis and accepts both a single point of shape (2,) and a
    batch of shape (..., 2). Indexing with ``X[:, 0]`` would raise IndexError
    for a single point.

    Parameters
    ----------
    X : array-like
        Coordinates with the two components (x, y) on the last axis.

    Returns
    -------
    numpy scalar or ndarray
        ``-0.25*(5*(0.2 - x))**2 - (20*(y/4 - x**4))**2`` evaluated per point;
        the result has shape ``X.shape[:-1]``.
    """
    X = np.asarray(X)
    x_coord, y_coord = X[..., 0], X[..., 1]
    logpdf = -0.25*(5*(0.2 - x_coord))**2 - (20*(y_coord/4 - x_coord**4))**2
    return logpdf

# meshgrid(x, y) returns arrays of shape (len(y), len(x)), which is also the
# (rows, cols) layout contourf expects for Z. Reshape to xx.shape rather than
# (len(x), len(y)); the two only coincide when the grid is square.
func = np.exp(loglike(grid).reshape(xx.shape))

print(func.shape)

fig,ax = plt.subplots(1,2,figsize= (12,4))

# Same filled contour of the target density in both panels.
for axes in ax:
    cont = axes.contourf(x,y,func,cmap='Blues_r')
    fig.colorbar(cont,ax=axes)
    axes.set_xlabel('x')
    axes.set_ylabel('y')

ax[0].set_title('exp(loglike)')
ax[1].set_title('exp(loglike) with HMC samples')

# Overlay the chain on the right-hand panel only.
ax[1].scatter(samples[:,0],samples[:,1],alpha=0.1,s=4,color='C1')
fig.tight_layout()