Skip to content

Commit

Permalink
Refactor test for step size adaptation
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Oct 20, 2021
1 parent 5f4d652 commit 58ab8ba
Showing 1 changed file with 18 additions and 24 deletions.
42 changes: 18 additions & 24 deletions tests/test_step_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,40 @@
from aehmc.step_size import dual_averaging_adaptation, heuristic_adaptation


@pytest.fixture()
def init():
    """Shared setup for the step-size adaptation tests.

    Builds an HMC transition kernel over a simple 1-D Gaussian target and a
    matching initial chain state, so each test exercises adaptation against
    the same sampler configuration.

    Returns
    -------
    tuple
        ``(initial_state, kernel)`` where ``initial_state`` is the
        ``(position, logprob, logprob_grad)`` triple produced by
        ``hmc.new_state`` and ``kernel`` is the HMC transition kernel.
    """

    def logprob_fn(x):
        # Unnormalized log-density of a Gaussian centered at 1.0
        # (quadratic well), an easy target for step-size adaptation.
        return -2 * (x - 1.0) ** 2

    srng = RandomStream(seed=0)
    inverse_mass_matrix = at.as_tensor(1.0)
    kernel = hmc.kernel(srng, logprob_fn, inverse_mass_matrix, 10)

    initial_position = at.as_tensor(1.0, dtype="floatX")
    initial_state = hmc.new_state(initial_position, logprob_fn)

    return initial_state, kernel

def test_heuristic_adaptation(init):
    """Check that the heuristic step-size search returns finite values and
    that a lower target acceptance rate yields a larger step size."""
    reference_state, kernel = init

    # High target acceptance rate (0.95) should drive the step size down.
    epsilon_1 = heuristic_adaptation(
        kernel, reference_state, at.as_tensor(0.5, dtype="floatX"), 0.95
    )
    epsilon_1_val = epsilon_1.eval()
    assert epsilon_1_val != np.inf

    # Low target acceptance rate (0.05) should allow a larger step size.
    epsilon_2 = heuristic_adaptation(
        kernel, reference_state, at.as_tensor(0.5, dtype="floatX"), 0.05
    )
    epsilon_2_val = epsilon_2.eval()
    assert epsilon_2_val > epsilon_1_val
    assert epsilon_2_val != np.inf


def test_dual_averaging_adaptation():
def logprob_fn(x):
return -at.sum(0.5 * x ** 2)

srng = RandomStream(seed=0)
inverse_mass_matrix = at.as_tensor(1.0)
kernel = hmc.kernel(srng, logprob_fn, inverse_mass_matrix, 10)

initial_position = at.as_tensor(1.0, dtype="floatX")
logprob = logprob_fn(initial_position)
logprob_grad = aesara.grad(logprob, wrt=initial_position)
def test_dual_averaging_adaptation(init):
initial_state, kernel = init

step_size = at.as_tensor(1.0, dtype="floatX")
logpstepsize = at.log(step_size)
Expand All @@ -62,9 +56,9 @@ def one_step(q, logprob, logprob_grad, step, x_t, x_avg, gradient_avg):
states, updates = aesara.scan(
fn=one_step,
outputs_info=[
{"initial": initial_position},
{"initial": logprob},
{"initial": logprob_grad},
{"initial": initial_state[0]},
{"initial": initial_state[1]},
{"initial": initial_state[2]},
{"initial": step},
{"initial": logpstepsize},
{"initial": logstepsize_avg},
Expand All @@ -78,4 +72,4 @@ def one_step(q, logprob, logprob_grad, step, x_t, x_avg, gradient_avg):
step_size = aesara.function((), at.exp(states[-3][-1]), updates=updates)
assert np.mean(p_accept()) == pytest.approx(0.65, rel=10e-3)
assert step_size() < 10
assert step_size() > 10e-1
assert step_size() > 1e-1

0 comments on commit 58ab8ba

Please sign in to comment.