Skip to content

Commit

Permalink
add helper to move momentum variable to CuArrays if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Aug 14, 2019
1 parent 6f98dd0 commit 6c557b5
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/hamiltonian.jl
Expand Up @@ -49,6 +49,17 @@ phasepoint(
ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r))
) where {T<:AbstractVector} = PhasePoint(θ, r, ℓπ, ℓκ)

# If position variable and momentum variable are in different container,
# move the momentum variable to that of the position variable.
# This is neeced for AHMC to work with CuArrays (without depending on it).
phasepoint(
h::Hamiltonian,
θ::T1,
_r::T2;
r=T1(_r),
ℓπ=∂H∂θ(h, θ),
ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r))
) where {T1<:AbstractVector,T2<:AbstractVector} = PhasePoint(θ, r, ℓπ, ℓκ)

Base.isfinite(v::DualValue) = all(isfinite, v.value) && all(isfinite, v.gradient)
Base.isfinite(v::AbstractVector) = all(isfinite, v)
Expand Down

0 comments on commit 6c557b5

Please sign in to comment.