Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Towards v0.2.0 #81

Merged
merged 25 commits into from Jul 23, 2019
Merged
Changes from 2 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d079ef1
introduce Divergence type
xukai92 Jul 11, 2019
00290da
change the filed name of divergence from s to div in FullBinaryTree
xukai92 Jul 11, 2019
5ea10cc
draft the interface to return stats
xukai92 Jul 11, 2019
d5d761d
improve the formatting of verbose
xukai92 Jul 11, 2019
d111369
implement Base.show for user side types (#36)
xukai92 Jul 11, 2019
83c1345
remove return_stats - need to check @cpfiffer's interface change
xukai92 Jul 11, 2019
687c877
RFC Adapter types #28
xukai92 Jul 11, 2019
e35cef4
change version number to v0.2.0
xukai92 Jul 11, 2019
0d24742
add step size jittering
xukai92 Jul 11, 2019
d94894a
quote Stan
xukai92 Jul 11, 2019
fe4e0d6
implement renew to all internal types except AbstractMetric
xukai92 Jul 11, 2019
f401e3d
remove unused caller renew
xukai92 Jul 11, 2019
3120395
implement renew for AbstractMetric; all caller renwers are removed (#45)
xukai92 Jul 11, 2019
bfb96c2
improve _string_diag
xukai92 Jul 11, 2019
c5e4a4c
return stats via NamedTuple
xukai92 Jul 11, 2019
a061e5f
fix test and update README.md
xukai92 Jul 12, 2019
e5da205
ignore logdensity and graident function in Base.show for Hamiltonian …
xukai92 Jul 12, 2019
c722ac5
remove unused stats type
xukai92 Jul 18, 2019
f4e5aa0
update named tuple signature
xukai92 Jul 19, 2019
fd6ebcd
Divergence -> Termination
xukai92 Jul 19, 2019
68caa8d
div -> termination
xukai92 Jul 19, 2019
5c0d1d8
remove renew except those for metric
xukai92 Jul 19, 2019
7ea83e9
remove renew except those for metric
xukai92 Jul 19, 2019
3f4e02d
remove renew except those for metric
xukai92 Jul 19, 2019
356bb8f
remove renew for metric
xukai92 Jul 23, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 28 additions & 13 deletions src/trajectory.jl
Expand Up @@ -202,6 +202,19 @@ end
### The doubling tree algorithm for expanding trajectory.
###

"""
Divergence reasons
- `dynamic`: due to stoping criteria
- `numeric`: due to large energy deviation from starting (possibly numeric errors)
"""
struct Divergence
yebai marked this conversation as resolved.
Show resolved Hide resolved
dynamic::Bool
numeric::Bool
end

Base.:*(d1::Divergence, d2::Divergence) = Divergence(d1.dynamic || d2.dynamic, d1.numeric || d2.numeric)
isdivergent(d::Divergence) = d.dynamic || d.numeric

"""
A full binary tree trajectory with only necessary leaves and information stored.
"""
Expand All @@ -210,7 +223,7 @@ struct FullBinaryTree{S<:AbstractTreeSampler}
zright # right most leaf node
zcand # candidate leaf node
sampler::S # condidate sampler
s # termination stats, i.e. 0 means termination and 1 means continuation
div # divergence reasons
α # MH stats, i.e. sum of MH accept prob for all leapfrog steps
nα # total # of leap frog steps, i.e. phase points in a trajectory
end
Expand All @@ -220,7 +233,8 @@ Detect U turn for two phase points (`zleft` and `zright`) under given Hamiltonia
"""
function isturn(h::Hamiltonian, zleft::PhasePoint, zright::PhasePoint)
θdiff = zright.θ - zleft.θ
return (dot(θdiff, ∂H∂r(h, zleft.r)) >= 0 ? 1 : 0) * (dot(θdiff, ∂H∂r(h, zright.r)) >= 0 ? 1 : 0)
s = (dot(θdiff, ∂H∂r(h, zleft.r)) >= 0 ? 1 : 0) * (dot(θdiff, ∂H∂r(h, zright.r)) >= 0 ? 1 : 0)
return Divergence(s == 0, false)
end

"""
Expand All @@ -231,13 +245,13 @@ isdivergent(
nt::NUTS,
H0::F,
H′::F
) where {F<:AbstractFloat} = (s.logu < nt.Δ_max + -H′) ? 0 : 1
) where {F<:AbstractFloat} = Divergence(false, !(s.logu < nt.Δ_max + -H′))
yebai marked this conversation as resolved.
Show resolved Hide resolved
isdivergent(
s::MultinomialTreeSampler,
nt::NUTS,
H0::F,
H′::F
) where {F<:AbstractFloat} = (-H0 < nt.Δ_max + -H′) ? 0 : 1
) where {F<:AbstractFloat} = Divergence(false, !(-H0 < nt.Δ_max + -H′))

"""
combine(h::Hamiltonian, tleft::FullBinaryTree, tright::FullBinaryTree)
Expand All @@ -255,8 +269,8 @@ function combine(
zright = tright.zright
zcand = combine(rng, tleft, tright)
sampler = combine(tleft.sampler, tright.sampler)
s = tleft.s * tright.s * isturn(h, zleft, zright)
return FullBinaryTree(zleft, zright, zcand, sampler, s, tright.α + tright.α, tright.nα + tright.nα)
div = tleft.div * tright.div * isturn(h, zleft, zright)
return FullBinaryTree(zleft, zright, zcand, sampler, div, tright.α + tright.α, tright.nα + tright.nα)
end

"""
Expand Down Expand Up @@ -299,14 +313,14 @@ function build_tree(
z′ = step(nt.integrator, h, z, v)
H′ = -neg_energy(z′)
basesampler = makebase(sampler, H′)
s′ = 1 - isdivergent(basesampler, nt, H0, H′)
div = isdivergent(basesampler, nt, H0, H′)
α′ = exp(min(0, H0 - H′))
return FullBinaryTree(z′, z′, z′, basesampler, s′, α′, 1)
return FullBinaryTree(z′, z′, z′, basesampler, div, α′, 1)
else
# Recursion - build the left and right subtrees.
t′ = build_tree(rng, nt, h, z, sampler, v, j - 1, H0)
# Expand tree if not terminated
if t′.s == 1
if !isdivergent(t′.div)
# Expand left
if v == -1
t′′ = build_tree(rng, nt, h, t′.zleft, sampler, v, j - 1, H0) # left tree
Expand Down Expand Up @@ -341,10 +355,11 @@ function transition(
) where {I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractTreeSampler}
H0 = -neg_energy(z0)

zleft = z0; zright = z0; zcand = z0; j = 0; s = 1; sampler = S(rng, H0)
zleft = z0; zright = z0; zcand = z0;
j = 0; div = Divergence(false, false); sampler = S(rng, H0)

local t
while s == 1 && j <= nt.max_depth
while !isdivergent(div) && j <= nt.max_depth
# Sample a direction; `-1` means left and `1` means right
v = rand(rng, [-1, 1])
if v == -1
Expand All @@ -357,13 +372,13 @@ function transition(
zright = t.zright
end
# Perform a MH step if not terminated
if t.s == 1 && mh_accept(rng, sampler, t.sampler)
if !isdivergent(t.div) && mh_accept(rng, sampler, t.sampler)
zcand = t.zcand
end
# Combine the sampler from the proposed tree and the current tree
sampler = combine(sampler, t.sampler)
# Detect termination
s = s * t.s * isturn(h, zleft, zright)
div = div * t.div * isturn(h, zleft, zright)
# Increment tree depth
j = j + 1
end
Expand Down