-
-
Notifications
You must be signed in to change notification settings - Fork 6
/
hmc.py
256 lines (218 loc) · 7.57 KB
/
hmc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
from typing import Callable, Dict, Tuple
import aesara
import aesara.tensor as at
import numpy as np
from aesara.ifelse import ifelse
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable
import aehmc.integrators as integrators
import aehmc.metrics as metrics
import aehmc.trajectory as trajectory
def new_state(
    q: TensorVariable, logprob_fn: Callable
) -> Tuple[TensorVariable, TensorVariable, TensorVariable]:
    """Assemble the HMC state that corresponds to a position.

    Parameters
    ----------
    q
        Position of the chain.
    logprob_fn
        Callable evaluating the log-probability density at a position.

    Returns
    -------
    The tuple ``(q, potential_energy, potential_energy_grad)``: the position
    itself, the negated log-probability evaluated at ``q``, and the gradient
    of that quantity with respect to ``q``.

    """
    # The potential energy is defined as minus the log-density.
    neg_logprob = -logprob_fn(q)
    return q, neg_logprob, aesara.grad(neg_logprob, wrt=q)
def new_kernel(
    srng: RandomStream,
    logprob_fn: Callable,
    divergence_threshold: int = 1000,
) -> Callable:
    """Create the transition kernel of the HMC algorithm.

    Parameters
    ----------
    srng
        RandomStream instance; it is used to draw the momentum and is handed
        to the proposal generator for the accept/reject draw.
    logprob_fn
        Callable returning the log-probability density of the chain at a
        given position.
    divergence_threshold
        Energy difference (in absolute value) above which a transition is
        reported as divergent.

    Returns
    -------
    A function that maps the current state of the chain to a new state.

    References
    ----------
    .. [0]: Neal, Radford M. "MCMC using Hamiltonian dynamics." Handbook of markov
            chain monte carlo 2.11 (2011): 2.

    """

    def potential_fn(x):
        # Potential energy: minus the log-density.
        return -logprob_fn(x)

    def step(
        q: TensorVariable,
        potential_energy: TensorVariable,
        potential_energy_grad: TensorVariable,
        step_size: TensorVariable,
        inverse_mass_matrix: TensorVariable,
        num_integration_steps: int,
    ) -> Tuple[
        Tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, bool],
        Dict,
    ]:
        """Run one HMC transition.

        Parameters
        ----------
        q
            Position at the start of the transition.
        potential_energy
            Potential energy at ``q``.
        potential_energy_grad
            Gradient of the potential energy with respect to the position,
            evaluated at ``q``.
        step_size
            Step size handed to the symplectic integrator.
        inverse_mass_matrix
            One or two-dimensional array used as the inverse mass matrix that
            defines the euclidean metric.
        num_integration_steps
            Number of integrator steps performed per transition.

        Returns
        -------
        A pair. Its first element is the tuple ``(position, potential energy,
        potential energy gradient, acceptance probability, divergence flag)``
        describing the chain after the transition. Its second element is the
        dictionary of update rules returned by the proposal generator (the
        updates for the shared variables touched by the scan operator).

        """
        draw_momentum, kinetic_fn, _unused = metrics.gaussian_metric(
            inverse_mass_matrix
        )
        make_proposal = hmc_proposal(
            integrators.velocity_verlet(potential_fn, kinetic_fn),
            kinetic_fn,
            num_integration_steps,
            divergence_threshold,
        )

        # The momentum must be drawn before the proposal consumes `srng`.
        momentum = draw_momentum(srng)
        outcome, updates = make_proposal(
            srng, q, momentum, potential_energy, potential_energy_grad, step_size
        )

        # The momentum returned by the proposal (second slot) is not part of
        # the kernel's output state.
        next_q, _, next_energy, next_energy_grad, accept_prob, diverged = outcome
        return (next_q, next_energy, next_energy_grad, accept_prob, diverged), updates

    return step
def hmc_proposal(
    integrator: Callable,
    kinetic_energy: Callable[[TensorVariable], TensorVariable],
    num_integration_steps: TensorVariable,
    divergence_threshold: int,
) -> Callable:
    """Create the function that produces HMC proposals.

    Parameters
    ----------
    integrator
        Symplectic integrator used to integrate the hamiltonian dynamics.
    kinetic_energy
        Function computing the kinetic energy from a momentum.
    num_integration_steps
        Number of integrator steps to run each time the returned function is
        called.
    divergence_threshold
        Energy difference (in absolute value) above which a transition is
        reported as divergent.

    Returns
    -------
    A function that generates a new state for the chain.

    """
    run_trajectory = trajectory.static_integration(integrator, num_integration_steps)

    def propose(
        srng: RandomStream,
        q: TensorVariable,
        p: TensorVariable,
        potential_energy: TensorVariable,
        potential_energy_grad: TensorVariable,
        step_size: TensorVariable,
    ) -> Tuple[
        Tuple[
            TensorVariable,
            TensorVariable,
            TensorVariable,
            TensorVariable,
            TensorVariable,
            TensorVariable,
        ],
        Dict,
    ]:
        """Propose a new state with HMC and accept or reject it.

        Parameters
        ----------
        srng
            RandomStream instance used for the Bernoulli accept/reject draw.
        q
            Position at the start of the trajectory.
        p
            Momentum at the start of the trajectory.
        potential_energy
            Potential energy at ``q``.
        potential_energy_grad
            Gradient of the potential energy with respect to the position,
            evaluated at ``q``.
        step_size
            Step size handed to the symplectic integrator.

        Returns
        -------
        A pair. Its first element is the tuple ``(position, momentum,
        potential energy, potential energy gradient, acceptance probability,
        divergence flag)``; the first four entries are those of the proposal
        when it is accepted and those of the initial state otherwise. Its
        second element is the dictionary of update rules returned by the
        integration routine (the updates for the shared variables touched by
        the scan operator).

        """
        end_state, updates = run_trajectory(
            q, p, potential_energy, potential_energy_grad, step_size
        )
        q_end, p_end, energy_end, energy_grad_end = end_state

        # Reverse the final momentum so that detailed balance is preserved.
        p_reversed = -1.0 * p_end

        # Total energy before and after the trajectory.
        initial_total = potential_energy + kinetic_energy(p)
        proposed_total = energy_end + kinetic_energy(p_reversed)

        # A NaN difference is replaced by -inf: exp() then yields an
        # acceptance probability of 0 and the absolute value is infinite.
        energy_drop = initial_total - proposed_total
        energy_drop = at.where(at.isnan(energy_drop), -np.inf, energy_drop)

        diverged = at.abs(energy_drop) > divergence_threshold
        accept_prob = at.clip(at.exp(energy_drop), 0, 1.0)
        accepted = srng.bernoulli(accept_prob)

        proposed = (q_end, p_reversed, energy_end, energy_grad_end)
        current = (q, p, potential_energy, potential_energy_grad)
        selected = ifelse(accepted, proposed, current)

        return (*selected, accept_prob, diverged), updates

    return propose