Commit e87e2b8 (1 parent: 0ecb636)
WIP - step size adaptation with dual averaging
rlouf committed Oct 18, 2021

Showing 3 changed files with 73 additions and 4 deletions.
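
In outline: dual averaging (Nesterov, 2009), as used for MCMC warmup by Hoffman & Gelman (2014), treats the gap between the target acceptance rate $\delta$ and the realized acceptance probability $\alpha_t$ as a noisy gradient of the log step size, and adjusts the step size until that gap averages out to zero:

    H_t = \delta - \alpha_t, \qquad \frac{1}{t} \sum_{i=1}^{t} H_i \to 0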
6 changes: 3 additions & 3 deletions aehmc/algorithms.py

@@ -46,11 +46,11 @@ def dual_averaging(
     """
 
     def init(
-        x: TensorVariable,
+        x_init: TensorVariable,
     ) -> Tuple[TensorVariable, TensorVariable, TensorVariable]:
         step = at.as_tensor(1, "step", dtype="int32")
-        gradient_avg = at.as_tensor(0, "gradient_avg", dtype=x.dtype)
-        x_avg = at.as_tensor(0.0, "x_avg", dtype=x.dtype)
+        gradient_avg = at.as_tensor(0, "gradient_avg", dtype=x_init.dtype)
+        x_avg = at.as_tensor(0.0, "x_avg", dtype=x_init.dtype)
         return step, x_avg, gradient_avg
 
     def update(
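
For reference, and assuming algorithms.dual_averaging follows the scheme of Hoffman & Gelman (2014, Sec. 3.2), which the mu, gamma, t0 and kappa parameters suggest, the state initialized above evolves as

    x_{t+1} = \mu - \frac{\sqrt{t}}{\gamma\,(t + t_0)} \sum_{i=1}^{t} H_i, \qquad \bar{x}_{t+1} = t^{-\kappa}\, x_{t+1} + (1 - t^{-\kappa})\, \bar{x}_t

where x is the iterate (here the log step size), \bar{x} its running average (x_avg), and the H_i are the gradients accumulated in gradient_avg. Larger gamma damps deviations from mu, t0 stabilizes the earliest iterations, and kappa controls how quickly new iterates stop moving the average.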
26 changes: 26 additions & 0 deletions aehmc/step_size.py

@@ -5,6 +5,32 @@
 from aesara.scan.utils import until
 from aesara.tensor.var import TensorVariable
 
+from aehmc import algorithms
+
+
+def dual_averaging_adaptation(
+    initial_step_size: TensorVariable,
+    target_acceptance_rate: TensorVariable = at.as_tensor(0.65),
+    gamma: float = 0.05,
+    t0: int = 10,
+    kappa: float = 0.75,
+) -> Tuple[Callable, Callable]:
+
+    mu = at.log(10 * initial_step_size)
+    da_init, da_update = algorithms.dual_averaging(mu, gamma, t0, kappa)
+
+    def update(
+        acceptance_probability: TensorVariable,
+        step: TensorVariable,
+        x: TensorVariable,
+        x_avg: TensorVariable,
+        gradient_avg: TensorVariable,
+    ) -> Tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
+        gradient = target_acceptance_rate - acceptance_probability
+        return da_update(gradient, step, x, x_avg, gradient_avg)
+
+    return da_init, update
+
 
 def heuristic_adaptation(
     kernel: Callable,
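
A minimal sketch of how the returned pair can be used outside of a scan loop. The constant acceptance probability below is a placeholder rather than output from a real kernel, and the initial values are assumptions for illustration:

import aesara
import aesara.tensor as at

from aehmc.step_size import dual_averaging_adaptation

# Adapt from an initial step size of 1; the adaptation state lives in log-space.
init, update = dual_averaging_adaptation(at.as_tensor(1.0))
step, x_avg, gradient_avg = init(at.as_tensor(0.0))  # log(1.0) = 0.0

# One update driven by a placeholder acceptance probability of 0.5: below the
# 0.65 target, so the gradient is positive and the log step size is pulled down.
p_accept = at.as_tensor(0.5)
step, x, x_avg, gradient_avg = update(p_accept, step, at.as_tensor(0.0), x_avg, gradient_avg)

# exp(x_avg) is the averaged step-size estimate after this update.
print(aesara.function((), at.exp(x_avg))())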
45 changes: 44 additions & 1 deletion tests/test_step_size.py

@@ -4,7 +4,7 @@
 from aesara.tensor.random.utils import RandomStream
 
 from aehmc import hmc
-from aehmc.step_size import heuristic_adaptation
+from aehmc.step_size import dual_averaging_adaptation, heuristic_adaptation
 
 
 def test_heuristic_adaptation():
@@ -34,3 +34,46 @@ def logprob_fn(x):
     epsilon_2_val = epsilon_2.eval()
     assert epsilon_2_val > epsilon_1_val
     assert epsilon_2_val != np.inf
+
+
+def test_dual_averaging_adaptation():
+    def logprob_fn(x):
+        return -at.sum(0.5 * x**2)  # standard normal log-density, up to a constant
+
+    srng = RandomStream(seed=0)
+
+    initial_position = at.as_tensor(1.0, dtype="floatX")
+    logprob = logprob_fn(initial_position)
+    logprob_grad = aesara.grad(logprob, wrt=initial_position)
+
+    inverse_mass_matrix = at.as_tensor(1.0)
+
+    kernel = hmc.kernel(srng, logprob_fn, inverse_mass_matrix, 10)
+
+    log_step_size = at.as_tensor(0.0, dtype="floatX")
+    step_size = at.exp(log_step_size)
+    init, update = dual_averaging_adaptation(step_size)
+    step, x_avg, gradient_avg = init(log_step_size)
+
+    def one_step(q, logprob, logprob_grad, step, x_t, x_avg, gradient_avg):
+        *state, p_accept = kernel(q, logprob, logprob_grad, at.exp(x_t))
+        da_state = update(p_accept, step, x_t, x_avg, gradient_avg)
+        return *state, *da_state, p_accept
+
+    states, updates = aesara.scan(
+        fn=one_step,
+        outputs_info=[
+            {"initial": initial_position},
+            {"initial": logprob},
+            {"initial": logprob_grad},
+            {"initial": step},
+            {"initial": log_step_size},
+            {"initial": x_avg},
+            {"initial": gradient_avg},
+            None,
+        ],
+        n_steps=1000,
+    )
+
+    step_size = aesara.function((), at.exp(states[-3]), updates=updates)
+    print(step_size())
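
The test currently only prints the trajectory of averaged step sizes. A possible follow-up once the WIP settles (hypothetical, not part of this commit) would replace the print with a sanity check on the last averaged value:

    adapted = step_size()[-1]  # final exp(x_avg) after 1000 adaptation steps
    assert np.isfinite(adapted) and adapted > 0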
