-
Notifications
You must be signed in to change notification settings - Fork 39
/
integrator.jl
250 lines (205 loc) · 7.93 KB
/
integrator.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
####
#### Numerical methods for simulating Hamiltonian trajectory.
####
# TODO: The type `<:Tuple{Integer,Bool}` is introduced to address
# https://github.com/TuringLang/Turing.jl/pull/941#issuecomment-549191813
# We might want to simplify it to `Tuple{Int,Bool}` once we have figured out
# why it behaves unexpectedly on 32-bit Windows.
"""
$(TYPEDEF)
Represents an integrator used to simulate the Hamiltonian system.
# Implementation
A `AbstractIntegrator` is expected to have the following implementations:
- `stat`(@ref)
- `nom_step_size`(@ref)
- `step_size`(@ref)
"""
abstract type AbstractIntegrator end
# Generic fallback: an integrator reports an empty `NamedTuple` unless a more
# specific `stat` method is defined for its type.
function stat(::AbstractIntegrator)
    return NamedTuple()
end
"""
nom_step_size(::AbstractIntegrator)
Get the nominal integration step size. The current integration step size may
differ from this, for example if the step size is jittered. Nominal step size is
usually used in adaptation.
"""
nom_step_size(i::AbstractIntegrator) = step_size(i)
"""
step_size(::AbstractIntegrator)
Get the current integration step size.
"""
function step_size end
"""
update_nom_step_size(i::AbstractIntegrator, ϵ) -> AbstractIntegrator
Return a copy of the integrator `i` with the new nominal step size ([`nom_step_size`](@ref))
`ϵ`.
"""
function update_nom_step_size end
# Common supertype of the leapfrog integrators in this file. The concrete subtypes
# pass their step-size type as `T`, and the generic methods defined on this type
# read and update the field `ϵ`.
abstract type AbstractLeapfrog{T} <: AbstractIntegrator end
# The current step size of a leapfrog integrator is stored in its `ϵ` field.
function step_size(lf::AbstractLeapfrog)
    return lf.ϵ
end
# No-op fallback: a leapfrog integrator without a dedicated `jitter` method is
# returned unchanged, whether a single RNG or a vector of RNGs is supplied.
function jitter(::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, lf::AbstractLeapfrog)
    return lf
end
# No-op fallback: a leapfrog integrator without a dedicated `temper` method leaves
# the momentum `r` untouched. The step descriptor and step count are ignored.
function temper(
    lf::AbstractLeapfrog,
    r,
    ::NamedTuple{(:i, :is_half),<:Tuple{Integer,Bool}},
    ::Int,
)
    return r
end
# Report both the step size currently in use and the nominal one.
function stat(lf::AbstractLeapfrog)
    current = step_size(lf)
    nominal = nom_step_size(lf)
    return (step_size = current, nom_step_size = nominal)
end
# Return a copy of `lf` whose `ϵ` field is replaced by the given value; `lf` itself
# is not mutated.
function update_nom_step_size(lf::AbstractLeapfrog, ϵ)
    return @set lf.ϵ = ϵ
end
"""
$(TYPEDEF)
Leapfrog integrator with fixed step size `ϵ`.
# Fields
$(TYPEDFIELDS)
"""
struct Leapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
"Step size."
ϵ::T
end
# Compact display; the step size is rounded to three significant digits
# (broadcast, so it also works when `ϵ` is a vector).
function Base.show(io::IO, l::Leapfrog)
    return print(io, "Leapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)))")
end
# Floating-point type underlying the step size of a leapfrog integrator.
#
# Scalar step size: the step-size type itself is the float type.
integrator_eltype(i::AbstractLeapfrog{T}) where {T<:AbstractFloat} = T
# Vector of step sizes: report the element type of the vector. Without this method a
# vector-valued step size would raise a `MethodError`.
integrator_eltype(i::AbstractLeapfrog{<:AbstractVector{T}}) where {T<:AbstractFloat} = T
### Jittering
"""
$(TYPEDEF)
Leapfrog integrator with randomly "jittered" step size `ϵ` for every trajectory.
# Fields
$(TYPEDFIELDS)
# Description
This is the same as `LeapFrog`(@ref) but with a "jittered" step size. This means
that at the beginning of each trajectory we sample a step size `ϵ` by adding or
subtracting from the nominal/base step size `ϵ0` some random proportion of `ϵ0`,
with the proportion specified by `jitter`, i.e. `ϵ = ϵ0 - jitter * ϵ0 * rand()`.
p
Jittering might help alleviate issues related to poor interactions with a fixed step size:
- In regions with high "curvature" the current choice of step size might mean over-shoot
leading to almost all steps being rejected. Randomly sampling the step size at the
beginning of the trajectories can therefore increase the probability of escaping such
high-curvature regions.
- Exact periodicity of the simulated trajectories might occur, i.e. you might be so
unlucky as to simulate the trajectory forwards in time `L ϵ` and ending up at the
same point (which results in non-ergodicity; see Section 3.2 in [1]). If momentum
is refreshed before each trajectory, then this should not happen *exactly* but it
can still be an issue in practice. Randomly choosing the step-size `ϵ` might help
alleviate such problems.
# References
1. Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of Markov chain Monte Carlo, 2(11), 2. ([arXiv](https://arxiv.org/pdf/1206.1901))
"""
struct JitteredLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} <: AbstractLeapfrog{T}
"Nominal (non-jittered) step size."
ϵ0::T
"The proportion of the nominal step size `ϵ0` that may be added or subtracted."
jitter::FT
"Current (jittered) step size."
ϵ::T
end
# Convenience constructor: the current step size `ϵ` starts out equal to the
# nominal step size `ϵ0`.
function JitteredLeapfrog(ϵ0, jitter)
    return JitteredLeapfrog(ϵ0, jitter, ϵ0)
end
# Compact display; every numeric field is rounded to three significant digits.
function Base.show(io::IO, l::JitteredLeapfrog)
    text = "JitteredLeapfrog(ϵ0=$(round.(l.ϵ0; sigdigits=3)), jitter=$(round.(l.jitter; sigdigits=3)), ϵ=$(round.(l.ϵ; sigdigits=3)))"
    return print(io, text)
end
# For a jittered integrator the nominal step size is `ϵ0`, not the current `ϵ`.
function nom_step_size(lf::JitteredLeapfrog)
    return lf.ϵ0
end
# Return a copy of `lf` with the nominal step size `ϵ0` replaced. The current
# (jittered) step size `ϵ` is carried over unchanged.
function update_nom_step_size(lf::JitteredLeapfrog, ϵ0)
    return @set lf.ϵ0 = ϵ0
end
# Jitter step size; ref: https://github.com/stan-dev/stan/blob/1bb054027b01326e66ec610e95ef9b2a60aa6bec/src/stan/mcmc/hmc/base_hmc.hpp#L177-L178
#
# Returns a copy of `lf` whose current step size `ϵ` is the nominal step size `ϵ0`
# scaled by `1 + jitter * (2u - 1)`, where `u` is the draw obtained from `rng`.
function _jitter(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    lf::JitteredLeapfrog{FT,T},
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}}
    # NOTE(review): for a vector of RNGs the meaning of `rand(rng)` depends on a
    # method defined elsewhere — presumably one draw per RNG; confirm.
    u = rand(rng)
    # Relative perturbation, then the perturbed step size.
    rel = lf.jitter .* (2 .* u .- 1)
    ϵ_new = lf.ϵ0 .* (1 .+ rel)
    # Convert back to the integrator's float type before storing.
    return @set lf.ϵ = FT.(ϵ_new)
end
# Single-RNG entry point: delegate to the shared `_jitter` implementation.
function jitter(rng::AbstractRNG, lf::JitteredLeapfrog)
    return _jitter(rng, lf)
end
# Vector-of-RNGs entry point: delegate to the shared `_jitter` implementation.
function jitter(
    rng::AbstractVector{<:AbstractRNG},
    lf::JitteredLeapfrog{FT,T},
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}}
    return _jitter(rng, lf)
end
### Tempering
# TODO: add ref or at least explain what exactly we're doing
"""
$(TYPEDEF)
Tempered leapfrog integrator with fixed step size `ϵ` and "temperature" `α`.
# Fields
$(TYPEDFIELDS)
# Description
Tempering can potentially allow greater exploration of the posterior, e.g.
in a multi-modal posterior jumps between the modes can be more likely to occur.
"""
struct TemperedLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} <: AbstractLeapfrog{T}
"Step size."
ϵ::T
"Temperature parameter."
α::FT
end
# Compact display; step size and temperature are rounded to three significant digits.
function Base.show(io::IO, l::TemperedLeapfrog)
    text = "TemperedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), α=$(round.(l.α; sigdigits=3)))"
    return print(io, text)
end
"""
temper(lf::TemperedLeapfrog, r, step::NamedTuple{(:i, :is_half),<:Tuple{Integer,Bool}}, n_steps::Int)
Tempering step. `step` is a named tuple with
- `i` being the current leapfrog iteration and
- `is_half` indicating whether or not it's (the first) half momentum/tempering step
"""
function temper(
lf::TemperedLeapfrog,
r,
step::NamedTuple{(:i, :is_half),<:Tuple{Integer,Bool}},
n_steps::Int,
)
if step.i > n_steps
throw(BoundsError("Current leapfrog iteration exceeds the total number of steps."))
end
i_temper = 2(step.i - 1) + 1 + !step.is_half # counter for half temper steps
return i_temper <= n_steps ? r * sqrt(lf.α) : r / sqrt(lf.α)
end
# `step` method for the integrators above.
# The method for `DiffEqIntegrator` is defined in the OrdinaryDiffEq extension.
#
# `DefaultLeapfrog` is the union of the three concrete leapfrog integrators defined
# in this file; it is the integrator type accepted by the `step` method below.
const DefaultLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} =
Union{Leapfrog{T},JitteredLeapfrog{FT,T},TemperedLeapfrog{FT,T}}
"""
    step(lf::DefaultLeapfrog, h::Hamiltonian, z::PhasePoint, n_steps::Int=1;
         fwd::Bool=n_steps > 0, full_trajectory::Val=Val(false))

Simulate `abs(n_steps)` leapfrog iterations of the Hamiltonian `h` starting from the
phase point `z`, using the step size of `lf`.

# Keywords

- `fwd`: integrate with a positive step size when `true` and a negated step size
  otherwise. Defaults to `n_steps > 0`, so a negative `n_steps` simulates backward.
- `full_trajectory`: with `Val(false)` (default) only the last phase point is
  returned; with `Val(true)` a vector holding the phase point after every iteration
  is returned.

# Returns

The last phase point, or the vector of visited phase points. The simulation stops at
the first phase point for which `isfinite` is `false`; that phase point is the last
one returned. With `n_steps == 0` the input `z` (or an empty vector) is returned.
"""
function step(
    lf::DefaultLeapfrog{FT,T},
    h::Hamiltonian,
    z::P,
    n_steps::Int = 1;
    fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0
    full_trajectory::Val{FullTraj} = Val(false),
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT},P<:PhasePoint,FullTraj}
    n_steps = abs(n_steps)  # to support `n_steps < 0` cases

    ϵ = fwd ? step_size(lf) : -step_size(lf)
    # Adjoint: leaves a real scalar unchanged and turns a vector of step sizes into a
    # row, so that it broadcasts along the second dimension of `θ`/`r`
    # (presumably one column per chain — TODO confirm).
    ϵ = ϵ'

    res = if FullTraj
        Vector{P}(undef, n_steps)
    else
        z
    end

    @unpack θ, r = z
    @unpack value, gradient = z.ℓπ
    for i in 1:n_steps
        # Tempering
        r = temper(lf, r, (i = i, is_half = true), n_steps)
        # Take a half leapfrog step for momentum variable
        r = r - ϵ / 2 .* gradient
        # Take a full leapfrog step for position variable
        ∇r = ∂H∂r(h, r)
        θ = θ + ϵ .* ∇r
        # Take a half leapfrog step for momentum variable
        @unpack value, gradient = ∂H∂θ(h, θ)
        r = r - ϵ / 2 .* gradient
        # Tempering
        r = temper(lf, r, (i = i, is_half = false), n_steps)
        # Create a new phase point by caching the logdensity and gradient
        z = phasepoint(h, θ, r; ℓπ = DualValue(value, gradient))
        # Update result
        if FullTraj
            res[i] = z
        else
            res = z
        end
        if !isfinite(z)
            # Drop the entries that were never written. Exactly the first `i` slots
            # have been filled at this point. Slicing by index (rather than testing
            # `isassigned`) also works when `P` is a bits type, for which every
            # in-bounds slot counts as assigned.
            if FullTraj
                res = res[1:i]
            end
            break
        end
    end
    return res
end