Skip to content

Commit

Permalink
Merge 569eeba into d435a0b
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Nov 6, 2020
2 parents d435a0b + 569eeba commit 7d8382b
Show file tree
Hide file tree
Showing 7 changed files with 540 additions and 117 deletions.
22 changes: 22 additions & 0 deletions src/hamiltonian.jl
Expand Up @@ -26,6 +26,12 @@ end
Base.similar(dv::DualValue{<:AbstractVecOrMat{T}}) where {T<:AbstractFloat} =
DualValue(zeros(T, size(dv.value)...), zeros(T, size(dv.gradient)...))

function accept!(d::T, d′::T, is_accept::AbstractVector{Bool}) where {T<:DualValue{<:AbstractVector}}
accept!(d.value, d′.value, is_accept)
accept!(d.gradient, d′.gradient, is_accept)
return d
end

# `∂H∂θ` now returns `(logprob, -∂ℓπ∂θ)`
function ∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat)
res = h.∂ℓπ∂θ(θ)
Expand Down Expand Up @@ -84,6 +90,22 @@ Base.isfinite(v::DualValue) = all(isfinite, v.value) && all(isfinite, v.gradient
Base.isfinite(v::AbstractVecOrMat) = all(isfinite, v)
Base.isfinite(z::PhasePoint) = isfinite(z.ℓπ) && isfinite(z.ℓκ)

# Return the accepted phase point
function accept!(z::T, z′::T, is_accept::AbstractVector{Bool}) where {T<:PhasePoint{<:AbstractMatrix}}
# Revert unaccepted proposals in `z′`
if any(!, is_accept)
accept!(z.θ, z′.θ, is_accept)
accept!(z.r, z′.r, is_accept)
accept!(z.ℓπ, z′.ℓπ, is_accept)
accept!(z.ℓκ, z′.ℓκ, is_accept)
end
# Always return `z′` as any unaccepted proposal is already reverted
# NOTE: This in place treatment of `z′` is for memory efficient consideration.
# We can also copy `z′ and avoid mutating the original `z′`. But this is
# not efficient and immutability of `z′` is not important in this local scope.
return z′
end

###
### Negative energy (or log probability) functions.
### NOTE: the general form (i.e. non-Euclidean) of K depends on both θ and r.
Expand Down
6 changes: 3 additions & 3 deletions src/integrator.jl
Expand Up @@ -60,13 +60,13 @@ function step(
h::Hamiltonian,
z::P,
n_steps::Int=1;
fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
fwd::Union{Bool,AbstractVector{Bool}}=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
full_trajectory::Val{FullTraj} = Val(false)
) where {T<:AbstractScalarOrVec{<:AbstractFloat}, P<:PhasePoint, FullTraj}
n_steps = abs(n_steps) # to support `n_steps < 0` cases

ϵ = fwd ? step_size(lf) : -step_size(lf)
ϵ = ϵ'
ϵ = step_size(lf)
ϵ = ifelse.(fwd, ϵ, .-ϵ)'

res = if FullTraj
Vector{P}(undef, n_steps)
Expand Down

0 comments on commit 7d8382b

Please sign in to comment.