In [1]:
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

from blackjax.mcmc.integrators import IntegratorState
import blackjax.mcmc.metrics as metrics
from blackjax.mcmc.hmc import init

from rhmcjax.rhmc.intersection_with_boundary import reflection_necessary, find_next_intersection
from rhmcjax.targets.gaussians_on_circle import GaussiansOnCircle

#### Illustration of Reflection Algorithm used in RHMC
This notebook illustrates the reflection algorithm introduced in the paper [Reflection, Refraction, and Hamiltonian Monte Carlo](https://papers.nips.cc/paper_files/paper/2015/hash/8303a79b1e19a194f1875981be5bdb6f-Abstract.html).
For this, we use the simple multimodal Gaussian density employed in the notebook `rhmc_example.ipynb`.

In [2]:
# Target density: the multimodal Gaussian mixture from `GaussiansOnCircle`,
# restricted to the unit hypercube (whose faces act as reflecting boundaries).
target = GaussiansOnCircle(
    restrict_to_unit_hypercube=True,
)
log_density_fn = target.log_prob  # callable: position -> log density

The goal is to visualize how a very large momentum combined with a large step size leads to multiple reflections at the boundary within a single integration step. Because we need the intermediate positions for the visualization, we reproduce parts of the `rhmcjax.rhmc` functions in this notebook instead of calling them directly.

In [3]:
# Identity (unit diagonal) inverse mass matrix for the 2-D target.
# A float array is used on purpose: `jnp.array([1, 1])` would have an integer
# dtype, which is a latent pitfall for the metric's arithmetic
# (e.g. integer reciprocals truncate in NumPy).
inverse_mass_matrix = jnp.ones(2)
momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
    inverse_mass_matrix
)

In [4]:
# Leapfrog coefficients: b1 scales each momentum half-kick and a2 scales the
# position drift (a2 = 1 - 2 * a1, i.e. a full drift when a1 == 0).
a1, b1 = 0, 0.5
a2 = 1 - 2 * a1

# Gradient helpers: d(kinetic energy)/d(momentum) and (log density, its gradient).
kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn)
logdensity_and_grad_fn = jax.value_and_grad(log_density_fn)

def one_step_vis(state: IntegratorState, step_size: float) -> jnp.ndarray:
    """Run one reflective leapfrog step and record every position visited.

    This mirrors a single RHMC integrator step. Unlike the library version, the
    reflection loop runs eagerly in Python (not as a traced loop), so each
    intersection point with the boundary can be stored.

    Args:
        state: Current integrator state. Only ``position``, ``momentum`` and
            ``logdensity_grad`` are read.
        step_size: Integration step size.

    Returns:
        Stacked positions, with shape ``(num_reflections + 2, *position.shape)``:
        the start position, every boundary intersection in order, and the final
        position after the remaining drift.
    """
    position, momentum, logdensity_grad = state.position, state.momentum, state.logdensity_grad

    # Half-step evolution of momentum (kick)
    momentum = jax.tree_util.tree_map(
        lambda p, g: p + b1 * step_size * g,
        momentum,
        logdensity_grad,
    )

    # Eager (untraced) loop so that intermediate positions can be recorded.
    t = 0
    trace_position = [position]
    while reflection_necessary(position, momentum, kinetic_energy_grad_fn, step_size, t):
        position, momentum, t_x, ind_boundary = find_next_intersection(
            position, momentum, kinetic_energy_grad_fn, step_size, t
        )
        t += t_x
        # Reverse the momentum component perpendicular to the boundary that was hit
        momentum = momentum.at[ind_boundary].set(-momentum[ind_boundary])
        trace_position.append(position)

    # Drift for the time remaining in the step after all reflections
    kinetic_grad = kinetic_energy_grad_fn(momentum)
    position = jax.tree_util.tree_map(
        lambda x, v: x + a2 * (step_size - t) * v,
        position,
        kinetic_grad,
    )
    trace_position.append(position)

    # The closing half-step momentum update of the full integrator is omitted on
    # purpose: only the traced positions are returned, so its result would be
    # discarded.
    return jnp.array(trace_position)

In [5]:
# Start inside the unit square with a deliberately large momentum and step size,
# so that a single integrator step hits the boundary several times.
position = jnp.array([0.25, 0.25])
momentum = jnp.array([2.5, -3.])
hmc_state = init(position, log_density_fn)
integrator_state = IntegratorState(
    hmc_state.position, momentum, hmc_state.logdensity, hmc_state.logdensity_grad
)
step_size = 1.
traced_positions = one_step_vis(integrator_state, step_size)
# (number of traced points, dimension); already a JAX array, no conversion needed
traced_positions.shape

(11, 2)


Next, we animate the successive reflections as a `.gif` using `imageio`. Each frame adds one more segment of the trajectory:

In [6]:
import numpy as np
from matplotlib.collections import LineCollection
import imageio

In [7]:
%%capture
def plot_subseq_refl(i, positions=None):
    """Render frame ``i`` of the reflection animation as an RGB image.

    The first ``i`` traced positions are drawn as a polyline. Its segments are
    coloured by their order along the full trajectory, and position ``i`` (the
    current point) is highlighted.

    Args:
        i: Number of traced positions to show (``1 <= i <= len(positions)``).
        positions: Array of traced 2-D positions, shape ``(N, 2)``. Defaults to
            the global ``traced_positions`` computed above.

    Returns:
        ``uint8`` image array of shape ``(height, width, 3)``.
    """
    if positions is None:
        positions = traced_positions
    positions = np.asarray(positions)

    fig, ax = plt.subplots(1, 1, figsize=[4.5, 4])
    try:
        # Segment k joins point k and point k + 1. LineCollection expects
        # (num_segments) x (points per segment = 2) x (x, y).
        shown = positions[:i]
        segments = np.stack([shown[:-1], shown[1:]], axis=1)

        # One colour value per segment of the *full* trajectory, so a segment
        # keeps the same colour in every frame. Pass exactly one value per drawn segment.
        color_values = np.linspace(0, 1, len(positions) - 1)
        lc = LineCollection(segments, cmap='cool', norm=plt.Normalize(0, 1))
        lc.set_array(color_values[: len(segments)])
        lc.set_linewidth(2)
        ax.add_collection(lc)
        ax.scatter(positions[i - 1, 0], positions[i - 1, 1], c='darkred', zorder=4)

        # Constant limits so that every frame shows the same region
        ax.set(xlabel='$x$', ylabel='$y$', xlim=[0, 1], ylim=[0, 1])

        # Rasterize the figure. buffer_rgba() replaces the deprecated tostring_rgb()
        # and already carries the true pixel shape (also on HiDPI canvases).
        fig.canvas.draw()
        image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure; otherwise every frame stays alive in pyplot.
        plt.close(fig)

    return image

frames = [plot_subseq_refl(i) for i in range(1, len(traced_positions) + 1)]
imageio.mimsave('../images/reflection.gif', frames, fps=1, loop=3)