Skip to content

Commit

Permalink
improve internal namings
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Aug 31, 2020
1 parent 237138f commit 200b52f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 30 deletions.
56 changes: 34 additions & 22 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ struct Trajectory{I<:AbstractIntegrator,TC<:AbstractTerminationCriterion}
"Integrator used to simulate trajectory."
integrator::I
"Criterion to terminate the simulation."
criterion::TC
termination_criterion::TC
end

Base.show(io::IO, τ::Trajectory) = print(io, "Trajectory(integrator=$(τ.integrator), criterion=$(τ.criterion))")
Base.show(io::IO, τ::Trajectory) = print(io, "Trajectory(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))")

struct HMCKernel{T<:Trajectory, TS<:AbstractTrajectorySampler} <: AbstractKernel
τ::T
Expand All @@ -216,10 +216,10 @@ end
Base.show(io::IO, κ::HMCKernel) = print(io, "HMCKernel(\n τ=$(κ.τ),\n trajectory_sampler_type=$(κ.trajectory_sampler_type)\n)")

function transition(rng, h::Hamiltonian, κ::HMCKernel, z::PhasePoint)
@unpack τ, TS = κ
@unpack τ, trajectory_sampler_type = κ
τ = reconstruct(τ, integrator=jitter(rng, τ.integrator))
z = refresh(rng, z, h) # refresh momentum variable
return transition(rng, h, τ, TS, z)
return transition(rng, h, τ, trajectory_sampler_type, z)
end

"""
Expand All @@ -238,17 +238,17 @@ end
###

function transition(rng, h, τ::Trajectory{I, <:FixedNSteps}, ::Type{TS}, z) where {I, TS}
@unpack integrator, criterion = τ
@unpack integrator, termination_criterion = τ
H0 = energy(z)
z′, is_accept, α = sample_phasepoint(rng, integrator, criterion, TS, h, z)
z′, is_accept, α = sample_phasepoint(rng, integrator, termination_criterion, TS, h, z)
# Do the actual accept / reject
z = accept_phasepoint!(z, z′, is_accept) # NOTE: this function changes `z′` in place in matrix-parallel mode
# Reverse momentum variable to preserve reversibility
z = PhasePoint(z.θ, -z.r, z.ℓπ, z.ℓκ)
H = energy(z)
tstat = merge(
(
n_steps=criterion.L,
n_steps=termination_criterion.L,
is_accept=is_accept,
acceptance_rate=α,
log_density=z.ℓπ.value,
Expand All @@ -261,9 +261,9 @@ function transition(rng, h, τ::Trajectory{I, <:FixedNSteps}, ::Type{TS}, z) whe
end

function transition(rng, h, τ::Trajectory{I, <:FixedIntegrationTime}, ::Type{TS}, z) where {I, TS}
@unpack integrator, criterion = τ
# Create the corresponding `FixedNSteps` criterion
L = max(1, floor(Int, criterion.λ / nom_step_size(integrator)))
@unpack integrator, termination_criterion = τ
# Create the corresponding `FixedNSteps` termination criterion
L = max(1, floor(Int, termination_criterion.λ / nom_step_size(integrator)))
τ = Trajectory(integrator, FixedNSteps(L))
return transition(rng, h, τ, TS, z)
end
Expand Down Expand Up @@ -554,7 +554,7 @@ using the generalised no-U-turn criterion with additional U-turn checks.
## References
1. https://arxiv.org/abs/1701.02434
1. Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. [arXiv preprint arXiv:1701.02434](https://arxiv.org/abs/1701.02434).
2. https://github.com/stan-dev/stan/pull/2800
3. https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727/33
"""
Expand Down Expand Up @@ -600,10 +600,18 @@ function generalised_uturn_criterion(rho, p_sharp_minus, p_sharp_plus)
return (dot(rho, p_sharp_minus) <= 0) || (dot(rho, p_sharp_plus) <= 0)
end

"""
Recursivly build a tree for a given depth `j`.
"""
function build_tree(rng, integrator, tc, h, z, sampler::TS, v, j, H0) where {TS}
"Recursivly build a tree for a given depth `j`."
function build_tree(
rng::AbstractRNG,
integrator::AbstractIntegrator,
tc::AbstractTerminationCriterion,
h::Hamiltonian,
z::PhasePoint,
sampler::TS,
v::Int,
j::Int,
H0::AbstractFloat,
) where {TS}
if j == 0
# Base case - take one leapfrog step in the direction v.
z′ = step(integrator, h, z, v)
Expand Down Expand Up @@ -635,28 +643,32 @@ function build_tree(rng, integrator, tc, h, z, sampler::TS, v, j, H0) where {TS}
end

function transition(
rng, h, τ::Trajectory{I, C}, ::Type{TS}, z0
rng::AbstractRNG,
h::Hamiltonian,
τ::Trajectory{I, C},
::Type{TS},
z0::PhasePoint,
) where {I, C<:DynamicTerminationCriterion, TS}
@unpack integrator, criterion = τ
@unpack integrator, termination_criterion = τ
H0 = energy(z0)
tree = BinaryTree(z0, z0, criterion, zero(H0), zero(Int), zero(H0))
tree = BinaryTree(z0, z0, termination_criterion, zero(H0), zero(Int), zero(H0))
sampler = TS(rng, z0)
termination = Termination(false, false)
zcand = z0

τ = reconstruct(τ, integrator=integrator)

j = 0
while !isterminated(termination) && j < criterion.max_depth
while !isterminated(termination) && j < termination_criterion.max_depth
# Sample a direction; `-1` means left and `1` means right
v = rand(rng, [-1, 1])
if v == -1
# Create a tree with depth `j` on the left
tree′, sampler′, termination′ = build_tree(rng, integrator, criterion, h, tree.zleft, sampler, v, j, H0)
tree′, sampler′, termination′ = build_tree(rng, integrator, termination_criterion, h, tree.zleft, sampler, v, j, H0)
treeleft, treeright = tree′, tree
else
# Create a tree with depth `j` on the right
tree′, sampler′, termination′ = build_tree(rng, integrator, criterion, h, tree.zright, sampler, v, j, H0)
tree′, sampler′, termination′ = build_tree(rng, integrator, termination_criterion, h, tree.zright, sampler, v, j, H0)
treeleft, treeright = tree, tree′
end
# Perform a MH step and increse depth if not terminated
Expand All @@ -671,7 +683,7 @@ function transition(
# Update sampler
sampler = combine(zcand, sampler, sampler′)
# update termination
termination = termination * termination′ * isterminated(criterion, h, tree, treeleft, treeright)
termination = termination * termination′ * isterminated(termination_criterion, h, tree, treeleft, treeright)
end

H = energy(zcand)
Expand Down
2 changes: 1 addition & 1 deletion test/sampler-vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ include("common.jl")
StepSizeAdaptor(0.8, lfi),
),
]
κ.τ.criterion isa FixedIntegrationTime && continue
κ.τ.termination_criterion isa FixedIntegrationTime && continue

@test show(adaptor) == nothing; println()

Expand Down
2 changes: 1 addition & 1 deletion test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ end
end
samples, stats = sample(h, κ_used , θ_init, n_samples, adaptor, n_adapts; verbose=false, progress=PROGRESS)
@test mean(samples[(n_adapts+1):end]) zeros(D) atol=RNDATOL
test_stats(κ_used.τ.criterion, stats, n_adapts)
test_stats(κ_used.τ.termination_criterion, stats, n_adapts)
end
end
end
Expand Down
12 changes: 6 additions & 6 deletions test/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,17 @@ function ahmc_isturn_classic(z0, z1, rho, v=1)
return AdvancedHMC.isterminated(ClassicNoUTurn(), h, tree).dynamic
end

function hand_isturn(z0, z1, rho, v=1)
function hand_isturn_generalised(z0, z1, rho, v=1)
s = (dot(rho, -z0.r) >= 0) || (dot(-rho, z1.r) >= 0)
return s
end

function ahmc_isturn(z0, z1, rho, v=1)
function ahmc_isturn_generalised(z0, z1, rho, v=1)
tree = AdvancedHMC.BinaryTree(z0, z1, rho, 0, 0, 0.0)
return AdvancedHMC.isterminated(GeneralisedNoUTurn(), h, tree).dynamic
end

function ahmc_isturn_strict(z0, z1, rho, v=1)
function ahmc_isturn_strict_generalised(z0, z1, rho, v=1)
t = AdvancedHMC.isterminated(
StrictGeneralisedNoUTurn(), h,
AdvancedHMC.BinaryTree(z0, z1, rho, 0, 0, 0.0),
Expand Down Expand Up @@ -255,10 +255,10 @@ end
ts_hand_isturn_fwd = hand_isturn_classic.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)], Ref(1))
ts_ahmc_isturn_fwd = ahmc_isturn_classic.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)], Ref(1))

ts_hand_isturn_generalised_fwd = hand_isturn.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)], Ref(1))
ts_ahmc_isturn_generalised_fwd = ahmc_isturn.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)], Ref(1))
ts_hand_isturn_generalised_fwd = hand_isturn_generalised.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)], Ref(1))
ts_ahmc_isturn_generalised_fwd = ahmc_isturn_generalised.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)], Ref(1))

ts_ahmc_isturn_strictgeneralised_fwd = ahmc_isturn_strict.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)], Ref(1))
ts_ahmc_isturn_strictgeneralised_fwd = ahmc_isturn_strict_generalised.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)], Ref(1))

check_subtree_u_turns.(Ref(traj_z[1]), traj_z, [rho[:,i] for i = 1:length(traj_z)])

Expand Down

0 comments on commit 200b52f

Please sign in to comment.