-
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from typing import List, Tuple | ||
|
||
|
||
def build_schedule( | ||
num_steps: int, | ||
initial_buffer_size: int = 75, | ||
final_buffer_size: int = 50, | ||
first_window_size: int = 25, | ||
) -> List[Tuple[int, bool]]: | ||
"""Return the schedule for Stan's warmup. | ||
The schedule below is intended to be as close as possible to Stan's _[1]. | ||
The warmup period is split into three stages: | ||
1. An initial fast interval to reach the typical set. Only the step size is | ||
adapted in this window. | ||
2. "Slow" parameters that require global information (typically covariance) | ||
are estimated in a series of expanding intervals with no memory; the step | ||
size is re-initialized at the end of each window. Each window is twice the | ||
size of the preceding window. | ||
3. A final fast interval during which the step size is adapted using the | ||
computed mass matrix. | ||
Schematically: | ||
``` | ||
+---------+---+------+------------+------------------------+------+ | ||
| fast | s | slow | slow | slow | fast | | ||
+---------+---+------+------------+------------------------+------+ | ||
``` | ||
The distinction slow/fast comes from the speed at which the algorithms | ||
converge to a stable value; in the common case, estimation of covariance | ||
requires more steps than dual averaging to give an accurate value. See _[1] | ||
for a more detailed explanation. | ||
Fast intervals are given the label 0 and slow intervals the label 1. | ||
Note | ||
---- | ||
It feels awkward to return a boolean that indicates whether the current | ||
step is the last step of a middle window, but not for other windows. This | ||
should probably be changed to "is_window_end" and we should manage the | ||
distinction upstream. | ||
Parameters | ||
---------- | ||
num_steps: int | ||
The number of warmup steps to perform. | ||
initial_buffer: int | ||
The width of the initial fast adaptation interval. | ||
first_window_size: int | ||
The width of the first slow adaptation interval. | ||
final_buffer_size: int | ||
The width of the final fast adaptation interval. | ||
Returns | ||
------- | ||
A list of tuples (window_label, is_middle_window_end). | ||
References | ||
---------- | ||
.. [1]: Stan Reference Manual v2.22 Section 15.2 "HMC Algorithm" | ||
""" | ||
schedule = [] | ||
|
||
# Give up on mass matrix adaptation when the number of warmup steps is too small. | ||
if num_steps < 20: | ||
schedule += [(0, False)] * (num_steps - 1) | ||
else: | ||
# When the number of warmup steps is smaller that the sum of the provided (or default) | ||
# window sizes we need to resize the different windows. | ||
if initial_buffer_size + first_window_size + final_buffer_size > num_steps: | ||
initial_buffer_size = int(0.15 * num_steps) | ||
final_buffer_size = int(0.1 * num_steps) | ||
first_window_size = num_steps - initial_buffer_size - final_buffer_size | ||
|
||
# First stage: adaptation of fast parameters | ||
schedule += [(0, False)] * (initial_buffer_size - 1) | ||
schedule.append((0, False)) | ||
|
||
# Second stage: adaptation of slow parameters in successive windows | ||
# doubling in size. | ||
final_buffer_start = num_steps - final_buffer_size | ||
|
||
next_window_size = first_window_size | ||
next_window_start = initial_buffer_size | ||
while next_window_start < final_buffer_start: | ||
current_start, current_size = next_window_start, next_window_size | ||
if 3 * current_size <= final_buffer_start - current_start: | ||
next_window_size = 2 * current_size | ||
else: | ||
current_size = final_buffer_start - current_start | ||
next_window_start = current_start + current_size | ||
schedule += [(1, False)] * (next_window_start - 1 - current_start) | ||
schedule.append((1, True)) | ||
|
||
# Last stage: adaptation of fast parameters | ||
schedule += [(0, False)] * (num_steps - 1 - final_buffer_start) | ||
schedule.append((0, False)) | ||
|
||
return schedule |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import pytest | ||
|
||
from aehmc import window_adaptation | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"num_steps, expected_schedule", | ||
[ | ||
(19, [(0, False)] * 18), # no mass matrix adaptation | ||
( | ||
100, | ||
[(0, False)] * 15 + [(1, False)] * 74 + [(1, True)] + [(0, False)] * 10, | ||
), # windows are resized | ||
( | ||
200, | ||
[(0, False)] * 75 | ||
+ [(1, False)] * 24 | ||
+ [(1, True)] | ||
+ [(1, False)] * 49 | ||
+ [(1, True)] | ||
+ [(0, False)] * 50, | ||
), | ||
], | ||
) | ||
def test_adaptation_schedule(num_steps, expected_schedule): | ||
adaptation_schedule = window_adaptation.build_schedule(num_steps) | ||
assert adaptation_schedule == expected_schedule |