Skip to content

Commit

Permalink
Add Stan window adaptation schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Oct 25, 2021
1 parent aee5be3 commit acdd2dc
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
101 changes: 101 additions & 0 deletions aehmc/window_adaptation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import List, Tuple


def build_schedule(
num_steps: int,
initial_buffer_size: int = 75,
final_buffer_size: int = 50,
first_window_size: int = 25,
) -> List[Tuple[int, bool]]:
"""Return the schedule for Stan's warmup.
The schedule below is intended to be as close as possible to Stan's _[1].
The warmup period is split into three stages:
1. An initial fast interval to reach the typical set. Only the step size is
adapted in this window.
2. "Slow" parameters that require global information (typically covariance)
are estimated in a series of expanding intervals with no memory; the step
size is re-initialized at the end of each window. Each window is twice the
size of the preceding window.
3. A final fast interval during which the step size is adapted using the
computed mass matrix.
Schematically:
```
+---------+---+------+------------+------------------------+------+
| fast | s | slow | slow | slow | fast |
+---------+---+------+------------+------------------------+------+
```
The distinction slow/fast comes from the speed at which the algorithms
converge to a stable value; in the common case, estimation of covariance
requires more steps than dual averaging to give an accurate value. See _[1]
for a more detailed explanation.
Fast intervals are given the label 0 and slow intervals the label 1.
Note
----
It feels awkward to return a boolean that indicates whether the current
step is the last step of a middle window, but not for other windows. This
should probably be changed to "is_window_end" and we should manage the
distinction upstream.
Parameters
----------
num_steps: int
The number of warmup steps to perform.
initial_buffer: int
The width of the initial fast adaptation interval.
first_window_size: int
The width of the first slow adaptation interval.
final_buffer_size: int
The width of the final fast adaptation interval.
Returns
-------
A list of tuples (window_label, is_middle_window_end).
References
----------
.. [1]: Stan Reference Manual v2.22 Section 15.2 "HMC Algorithm"
"""
schedule = []

# Give up on mass matrix adaptation when the number of warmup steps is too small.
if num_steps < 20:
schedule += [(0, False)] * (num_steps - 1)
else:
# When the number of warmup steps is smaller that the sum of the provided (or default)
# window sizes we need to resize the different windows.
if initial_buffer_size + first_window_size + final_buffer_size > num_steps:
initial_buffer_size = int(0.15 * num_steps)
final_buffer_size = int(0.1 * num_steps)
first_window_size = num_steps - initial_buffer_size - final_buffer_size

# First stage: adaptation of fast parameters
schedule += [(0, False)] * (initial_buffer_size - 1)
schedule.append((0, False))

# Second stage: adaptation of slow parameters in successive windows
# doubling in size.
final_buffer_start = num_steps - final_buffer_size

next_window_size = first_window_size
next_window_start = initial_buffer_size
while next_window_start < final_buffer_start:
current_start, current_size = next_window_start, next_window_size
if 3 * current_size <= final_buffer_start - current_start:
next_window_size = 2 * current_size
else:
current_size = final_buffer_start - current_start
next_window_start = current_start + current_size
schedule += [(1, False)] * (next_window_start - 1 - current_start)
schedule.append((1, True))

# Last stage: adaptation of fast parameters
schedule += [(0, False)] * (num_steps - 1 - final_buffer_start)
schedule.append((0, False))

return schedule
27 changes: 27 additions & 0 deletions tests/test_adaptation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

from aehmc import window_adaptation


@pytest.mark.parametrize(
"num_steps, expected_schedule",
[
(19, [(0, False)] * 18), # no mass matrix adaptation
(
100,
[(0, False)] * 15 + [(1, False)] * 74 + [(1, True)] + [(0, False)] * 10,
), # windows are resized
(
200,
[(0, False)] * 75
+ [(1, False)] * 24
+ [(1, True)]
+ [(1, False)] * 49
+ [(1, True)]
+ [(0, False)] * 50,
),
],
)
def test_adaptation_schedule(num_steps, expected_schedule):
adaptation_schedule = window_adaptation.build_schedule(num_steps)
assert adaptation_schedule == expected_schedule

0 comments on commit acdd2dc

Please sign in to comment.