Implement step-size adaptation #41

Merged · 7 commits · Oct 26, 2021
6 changes: 3 additions & 3 deletions aehmc/algorithms.py
```diff
@@ -46,11 +46,11 @@ def dual_averaging(
     """

     def init(
-        x: TensorVariable,
+        x_init: TensorVariable,
     ) -> Tuple[TensorVariable, TensorVariable, TensorVariable]:
         step = at.as_tensor(1, "step", dtype="int32")
-        gradient_avg = at.as_tensor(0, "gradient_avg", dtype=x.dtype)
-        x_avg = at.as_tensor(0.0, "x_avg", dtype=x.dtype)
+        gradient_avg = at.as_tensor(0, "gradient_avg", dtype=x_init.dtype)
+        x_avg = at.as_tensor(0.0, "x_avg", dtype=x_init.dtype)
         return step, x_avg, gradient_avg

     def update(
```
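For context, a minimal sketch of how the renamed `x_init` argument flows through the scheme. The hyperparameter values mirror the call made in `step_size.py` below, and the error value fed to `update` is purely illustrative:

```python
import aesara.tensor as at

from aehmc import algorithms

# Dual averaging centered on mu, with the (gamma, t0, kappa) values
# used by step_size.py below.
mu = at.log(10.0)
init, update = algorithms.dual_averaging(mu, 0.05, 10, 0.75)

# `x_init` seeds the dtype of the running averages.
x_init = at.as_tensor(0.0)
step, x_avg, gradient_avg = init(x_init)

# One update with a hypothetical error signal (gradient) of 0.1.
new_state = update(at.as_tensor(0.1), step, x_init, x_avg, gradient_avg)
```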
43 changes: 26 additions & 17 deletions aehmc/hmc.py
```diff
@@ -21,9 +21,8 @@ def new_state(q: TensorVariable, logprob_fn: Callable):
 def kernel(
     srng: RandomStream,
     logprob_fn: TensorVariable,
-    step_size: TensorVariable,
     inverse_mass_matrix: TensorVariable,
-    num_integration_steps: TensorVariable,
+    num_integration_steps: int,
     divergence_threshold: int = 1000,
 ):
     """Build a HMC kernel.
@@ -35,13 +34,9 @@ def kernel(
     logprob_fn
         A function that returns the value of the log-probability density
         function of a chain at a given position.
-    step_size
-        The step size used in the symplectic integrator
     inverse_mass_matrix
         One or two-dimensional array used as the inverse mass matrix that
         defines the euclidean metric.
-    num_integration_steps
-        The number of times we apply the symplectic integrator to integrate the trajectory.
     divergence_threshold
         The difference in energy above which we say the transition is
         divergent.
@@ -64,7 +59,6 @@ def potential_fn(x):
     proposal_generator = hmc_proposal(
         symplectic_integrator,
         kinetic_energy_fn,
-        step_size,
         num_integration_steps,
         divergence_threshold,
     )
@@ -73,44 +67,53 @@ def step(
         q: TensorVariable,
         potential_energy: TensorVariable,
         potential_energy_grad: TensorVariable,
+        step_size: TensorVariable,
     ):
-        """Perform a single step of the HMC algorithm."""
+        """Perform a single step of the HMC algorithm.
+
+        Parameters
+        ----------
+        step_size
+            The step size used in the symplectic integrator
+
+        """
         p = momentum_generator(srng)
         (
             q_new,
             p_new,
             _,
             potential_energy_new,
             potential_energy_grad_new,
-        ) = proposal_generator(srng, q, p, potential_energy, potential_energy_grad)
-        return q_new, potential_energy_new, potential_energy_grad_new
+            p_accept,
+        ) = proposal_generator(
+            srng, q, p, potential_energy, potential_energy_grad, step_size
+        )
+        return q_new, potential_energy_new, potential_energy_grad_new, p_accept

     return step


 def hmc_proposal(
     integrator: Callable,
     kinetic_energy: Callable[[TensorVariable], TensorVariable],
-    step_size: TensorVariable,
     num_integration_steps: TensorVariable,
     divergence_threshold: int,
 ):
     """Builds a function that returns a HMC proposal."""

-    integrate = trajectory.static_integration(
-        integrator, step_size, num_integration_steps
-    )
+    integrate = trajectory.static_integration(integrator, num_integration_steps)

     def propose(
         srng: RandomStream,
         q: TensorVariable,
         p: TensorVariable,
         potential_energy: TensorVariable,
         potential_energy_grad: TensorVariable,
+        step_size: TensorVariable,
     ):
         """Use the HMC algorithm to propose a new state."""

         new_q, new_p, new_potential_energy, new_potential_energy_grad = integrate(
-            q, p, potential_energy, potential_energy_grad
+            q, p, potential_energy, potential_energy_grad, step_size
         )

         # flip the momentum to keep detailed balance
@@ -136,6 +139,12 @@ def propose(
             (q, p, potential_energy, potential_energy_grad),
         )

-        return final_q, final_p, final_potential_energy, final_potential_energy_grad
+        return (
+            final_q,
+            final_p,
+            final_potential_energy,
+            final_potential_energy_grad,
+            p_accept,
+        )

     return propose
```
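With these changes the step size is supplied when the chain is advanced, not when the kernel is built, which is what makes step-size adaptation possible. A hypothetical usage sketch of the new calling convention; the standard-normal `logprob_fn`, the seed, and the symbolic inputs are illustrative and not part of this diff:

```python
import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aehmc import hmc

srng = RandomStream(seed=0)

def logprob_fn(x):
    # Illustrative target: standard normal log-density (up to a constant).
    return -0.5 * at.sum(x**2)

# The kernel is built without a step size...
step = hmc.kernel(srng, logprob_fn, at.ones(1), num_integration_steps=10)

# ...which is instead passed at each step, so it can change between steps.
q = at.vector("q")
potential_energy = -logprob_fn(q)
potential_energy_grad = at.grad(potential_energy, q)
step_size = at.scalar("step_size")

q_new, pe_new, peg_new, p_accept = step(
    q, potential_energy, potential_energy_grad, step_size
)
```

Note that `step` now also returns the acceptance probability `p_accept`, which is exactly the signal the adaptation schemes in `step_size.py` below consume.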
20 changes: 16 additions & 4 deletions aehmc/nuts.py
```diff
@@ -15,7 +15,6 @@
 def kernel(
     srng: RandomStream,
     logprob_fn: Callable[[TensorVariable], TensorVariable],
-    step_size: TensorVariable,
     inverse_mass_matrix: TensorVariable,
     max_num_expansions: int = aet.as_tensor(10),
     divergence_threshold: int = 1000,
@@ -26,7 +25,7 @@ def kernel(
     Parameters
     ----------
     srng
-        RandomStream object.
+        Randomstream object.
     logprob_fn
         A function that returns the value of the log-probability density
         function of a chain at a given position.
@@ -73,16 +72,28 @@ def potential_fn(x):
         srng,
         trajectory_integrator,
         uturn_check_fn,
-        step_size,
         max_num_expansions,
     )

     def step(
         q: TensorVariable,
         potential_energy: TensorVariable,
         potential_energy_grad: TensorVariable,
+        step_size: TensorVariable,
     ):
-        """Move the chain by one step."""
+        """Move the chain by one step.
+
+        Parameters
+        ----------
+        srng
+            Randomstream object.
+        logprob_fn
+            A function that returns the value of the log-probability density
+            function of a chain at a given position.
+        step_size
+            The step size used in the symplectic integrator
+
+        """
         p = momentum_generator(srng)
         initial_state = (q, p, potential_energy, potential_energy_grad)
         initial_termination_state = new_termination_state(q, max_num_expansions)
@@ -100,6 +111,7 @@ def step(
             p,
             initial_termination_state,
             initial_energy,
+            step_size,
         )
         for key, value in updates.items():
             key.default_update = value
```
157 changes: 157 additions & 0 deletions aehmc/step_size.py
```python
from typing import Callable, Tuple

import aesara
import aesara.tensor as at
from aesara.scan.utils import until
from aesara.tensor.var import TensorVariable

from aehmc import algorithms


def dual_averaging_adaptation(
    initial_log_step_size: TensorVariable,
    target_acceptance_rate: TensorVariable = at.as_tensor(0.65),
    gamma: float = 0.05,
    t0: int = 10,
    kappa: float = 0.75,
) -> Tuple[Callable, Callable]:
    """Tune the step size to achieve a desired target acceptance rate.

    Let us note :math:`\\epsilon` the current step size, :math:`\\alpha_t` the
    Metropolis acceptance rate at time :math:`t` and :math:`\\delta` the desired
    acceptance rate. We define:

    .. math::

        H_t = \\delta - \\alpha_t

    the error at time :math:`t`. We would like to find a procedure that adapts
    the value of :math:`\\epsilon` such that
    :math:`h(x) = \\mathbb{E}\\left[H_t|\\epsilon\\right] = 0`.
    Following [1]_, the authors of [2]_ proposed the following update scheme.
    If we note :math:`x = \\log \\epsilon` we follow:

    .. math::

        x_{t+1} \\longleftarrow \\mu - \\frac{\\sqrt{t}}{\\gamma} \\frac{1}{t+t_0} \\sum_{i=1}^t H_i

        \\overline{x}_{t+1} \\longleftarrow t^{-\\kappa}\\, x_{t+1} + \\left(1 - t^{-\\kappa}\\right) \\overline{x}_t

    :math:`\\overline{x}_{t}` is guaranteed to converge to a value such that
    :math:`h(\\overline{x}_t)` converges to 0, i.e. the Metropolis acceptance
    rate converges to the desired rate.

    See reference [2]_ (section 3.2.1) for a detailed discussion.

    Parameters
    ----------
    initial_log_step_size:
        Initial value of the logarithm of the step size, used as an iterate in
        the dual averaging algorithm.
    target_acceptance_rate:
        Target acceptance rate.
    gamma
        Controls the speed of convergence of the scheme. The authors of [2]_
        recommend a value of 0.05.
    t0: float >= 0
        Free parameter that stabilizes the initial iterations of the algorithm.
        Large values may slow down convergence. Introduced in [2]_ with a
        default value of 10.
    kappa: float in (0.5, 1]
        Controls the weight of past steps in the current update. The scheme
        will quickly forget earlier steps for small values of `kappa`.
        Introduced in [2]_ with a recommended value of 0.75.

    Returns
    -------
    init
        A function that initializes the state of the dual averaging scheme.
    update
        A function that updates the state of the dual averaging scheme.

    References
    ----------
    .. [1]: Nesterov, Yurii. "Primal-dual subgradient methods for convex
            problems." Mathematical Programming 120.1 (2009): 221-259.
    .. [2]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
            adaptively setting path lengths in Hamiltonian Monte Carlo."
            Journal of Machine Learning Research 15.1 (2014): 1593-1623.
    """

    mu = at.log(10) + initial_log_step_size
    da_init, da_update = algorithms.dual_averaging(mu, gamma, t0, kappa)

    def update(
        acceptance_probability: TensorVariable,
        step: TensorVariable,
        log_step_size: TensorVariable,
        log_step_size_avg: TensorVariable,
        gradient_avg: TensorVariable,
    ) -> Tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
        gradient = target_acceptance_rate - acceptance_probability
        return da_update(gradient, step, log_step_size, log_step_size_avg, gradient_avg)

    return da_init, update
```
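A minimal sketch of driving this scheme by hand. It assumes the adaptation state is threaded in exactly the argument order of `update` above, and that the returned 4-tuple follows the same order; the acceptance probability is a made-up value standing in for the one returned by a kernel step:

```python
import aesara.tensor as at

from aehmc import step_size

initial_log_step_size = at.log(at.as_tensor(1.0))
init, update = step_size.dual_averaging_adaptation(initial_log_step_size)

# Running state: iteration counter, averaged iterate, averaged gradient.
step, log_step_size_avg, gradient_avg = init(initial_log_step_size)
log_step_size = initial_log_step_size

# One update, with a hypothetical acceptance probability from the last transition.
p_accept = at.as_tensor(0.80)
step, log_step_size, log_step_size_avg, gradient_avg = update(
    p_accept, step, log_step_size, log_step_size_avg, gradient_avg
)

# During warmup one would integrate with at.exp(log_step_size) and switch to
# at.exp(log_step_size_avg) once adaptation ends, since the averaged iterate
# is the quantity guaranteed to converge.
```

The second half of the new file, below, provides the heuristic used to pick the very first step size that seeds this scheme.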


```python
def heuristic_adaptation(
    kernel: Callable,
    reference_state: Tuple,
    initial_step_size: TensorVariable,
    target_acceptance_rate=0.65,
    max_num_iterations=100,
):
    """Find a reasonable initial step size during warmup.

    While the dual averaging scheme is guaranteed to converge to a reasonable
    value for the step size starting from any value, choosing a good first
    value can speed up the convergence. This heuristic doubles and halves the
    step size until the acceptance probability of the HMC proposal crosses the
    target value.

    Parameters
    ----------
    kernel
        A function that takes a state and a step size and returns a new state.
    reference_state
        The location (HMC state) where this first step size must be found.
        This function never advances the chain.
    initial_step_size
        The first step size used to start the search.
    target_acceptance_rate
        Once the Metropolis acceptance probability crosses this value we
        consider that we have found a "reasonable" first step size.
    max_num_iterations
        The maximum number of times we iterate on the algorithm.

    Returns
    -------
    float
        A reasonable first value for the step size.

    References
    ----------
    .. [1]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
            adaptively setting path lengths in Hamiltonian Monte Carlo."
            Journal of Machine Learning Research 15.1 (2014): 1593-1623.
    """

    def update(step_size, direction, previous_direction):
        step_size = (2.0 ** direction) * step_size
        *_, p_accept = kernel(*reference_state, step_size)
        new_direction = at.where(
            at.lt(target_acceptance_rate, p_accept), at.constant(1), at.constant(-1)
        )
        return (step_size.astype("floatX"), new_direction, direction), until(
            at.neq(direction, previous_direction)
        )

    (step_sizes, _, _), _ = aesara.scan(
        fn=update,
        outputs_info=[
            {"initial": initial_step_size},
            {"initial": at.constant(0)},
            {"initial": at.constant(0)},
        ],
        n_steps=max_num_iterations,
    )

    return step_sizes[-1]
```
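A hypothetical sketch of the search, reusing an HMC kernel built as in the example above. The reference state triple matches what `kernel(*reference_state, step_size)` expects; the target density and seed are illustrative only:

```python
import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aehmc import hmc, step_size

srng = RandomStream(seed=0)

def logprob_fn(x):
    # Illustrative target, not part of this diff.
    return -0.5 * at.sum(x**2)

kernel = hmc.kernel(srng, logprob_fn, at.ones(1), num_integration_steps=10)

# State at which the first step size is searched; the chain never moves.
q = at.vector("q")
potential_energy = -logprob_fn(q)
potential_energy_grad = at.grad(potential_energy, q)
reference_state = (q, potential_energy, potential_energy_grad)

first_step_size = step_size.heuristic_adaptation(
    kernel, reference_state, at.as_tensor(1.0)
)
```

Because the search only probes the acceptance probability at `reference_state`, doubling or halving until it crosses the target, the returned value is only a starting point; dual averaging then refines it.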
9 changes: 4 additions & 5 deletions aehmc/trajectory.py
```diff
@@ -33,12 +33,13 @@

 def static_integration(
     integrator: Callable,
-    step_size: float,
     num_integration_steps: int,
 ) -> Callable:
     """Generate a trajectory by integrating several times in one direction."""

-    def integrate(q_init, p_init, energy_init, energy_grad_init) -> IntegratorStateType:
+    def integrate(
+        q_init, p_init, energy_init, energy_grad_init, step_size
+    ) -> IntegratorStateType:
         def one_step(q, p, potential_energy, potential_energy_grad):
             new_state = integrator(
                 q, p, potential_energy, potential_energy_grad, step_size
@@ -270,7 +271,6 @@ def multiplicative_expansion(
     srng: RandomStream,
     trajectory_integrator: Callable,
     uturn_check_fn: Callable,
-    step_size: TensorVariable,
     max_num_expansions: TensorVariable,
 ):
     """Sample a trajectory and update the proposal sequentially
@@ -290,8 +290,6 @@
     and the integrated trajectory.
     uturn_check_fn
         Function used to check the U-Turn criterion.
-    step_size
-        The step size used by the symplectic integrator.
     max_num_expansions
         The maximum number of trajectory expansions until the proposal is
         returned.
@@ -306,6 +304,7 @@ def expand(
         momentum_sum,
         termination_state,
         initial_energy,
+        step_size,
     ):
         def expand_once(
             step,
```