Skip to content

Commit

Permalink
Merge 6ec9e9b into f31dc17
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed May 23, 2020
2 parents f31dc17 + 6ec9e9b commit 15d6721
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 18 deletions.
1 change: 1 addition & 0 deletions .github/workflows/AHMC-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ on:
jobs:
test:
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
matrix:
version:
Expand Down
42 changes: 33 additions & 9 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,7 @@ function transition(
)
H0 = energy(z)
integrator = jitter(rng, τ.integrator)
z′ = samplecand(rng, τ, h, z)
# Are we going to accept the `z′` via MH criteria?
is_accept, α = mh_accept_ratio(rng, energy(z), energy(z′))
z′, is_accept, α = samplecand(rng, τ, h, z)
# Do the actual accept / reject
z = accept_phasepoint!(z, z′, is_accept) # NOTE: this function changes `z′` in place in matrix-parallel mode
# Reverse momentum variable to preserve reversibility
Expand Down Expand Up @@ -223,16 +221,26 @@ end

### Use end-point from trajecory as proposal

samplecand(rng, τ::StaticTrajectory{EndPointTS}, h, z) = step.integrator, h, z, τ.n_steps)
function samplecand(rng, τ::StaticTrajectory{EndPointTS}, h, z)
z′ = step.integrator, h, z, τ.n_steps)
# Are we going to accept the `z′` via MH criteria?
is_accept, α = mh_accept_ratio(rng, energy(z), energy(z′))
return z′, is_accept, α
end

### Multinomial sampling from trajecory

randcat(rng::AbstractRNG, zs::AbstractVector{<:PhasePoint}, unnorm_ℓp::AbstractVector) = zs[randcat_logp(rng, unnorm_ℓp)]
function randcat(rng::AbstractRNG, zs::AbstractVector{<:PhasePoint}, unnorm_ℓp::AbstractVector)
p = exp.(unnorm_ℓp .- logsumexp(unnorm_ℓp))
i = randcat(rng, p)
return zs[i]
end

# zs is in the form of Vector{PhasePoint{Matrix}} and has shape [n_steps][dim, n_chains]
function randcat(rng, zs::AbstractVector{<:PhasePoint}, unnorm_ℓP::AbstractMatrix)
z = similar(first(zs))
is = randcat_logp(rng, unnorm_ℓP)
P = exp.(unnorm_ℓP .- logsumexp(unnorm_ℓP; dims=2)) # (n_chians, n_steps)
is = randcat(rng, P)
foreach(enumerate(is)) do (i_chain, i_step)
zi = zs[i_step]
z.θ[:,i_chain] = zi.θ[:,i_chain]
Expand All @@ -246,13 +254,29 @@ function randcat(rng, zs::AbstractVector{<:PhasePoint}, unnorm_ℓP::AbstractMat
end

function samplecand(rng, τ::StaticTrajectory{MultinomialTS}, h, z)
zs = step.integrator, h, z, τ.n_steps; res=[z for _ in 1:abs.n_steps)])
n_steps = abs.n_steps)
n_fwd = rand(0:n_steps) # FIXME: deal with multi-chain generically
zs_fwd = step.integrator, h, z, n_fwd; fwd=true, res=[z for _ in 1:n_fwd])
n_bwd = n_steps - n_fwd
zs_bwd = step.integrator, h, z, n_bwd; fwd=false, res=[z for _ in 1:n_bwd])
zs = vcat(reverse(zs_bwd)..., z, zs_fwd...)
ℓws = -energy.(zs)
if eltype(ℓws) <: AbstractVector
ℓws = hcat(ℓws...)
ℓws = cat(ℓws...; dims=2)
end
unnorm_ℓprob = ℓws
return randcat(rng, zs, unnorm_ℓprob)
z′ = randcat(rng, zs, unnorm_ℓprob)
# Computing adaptation statistics for dual averaging as done in NUTS
Hs = -ℓws
ΔH = Hs .- energy(z)
α = exp.(min.(0, -ΔH)) # this is a matrix for vectorized mode and a vector otherwise
α =
if typeof(α) <: AbstractVector
mean(α)
else
vec(mean(α; dims=2))
end
return z′, true, α
end

abstract type DynamicTrajectory{I<:AbstractIntegrator} <: AbstractTrajectory{I} end
Expand Down
10 changes: 1 addition & 9 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ end

# Sample from Categorical distributions

randcat_logp(rng::AbstractRNG, unnorm_ℓp::AbstractVector) =
randcat(rng, exp.(unnorm_ℓp .- logsumexp(unnorm_ℓp)))

function randcat(rng::AbstractRNG, p::AbstractVector{T}) where {T}
u = rand(rng, T)
c = zero(eltype(p))
Expand All @@ -31,17 +28,12 @@ function randcat(rng::AbstractRNG, p::AbstractVector{T}) where {T}
return max(i, 1)
end

randcat_logp(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
unnorm_ℓP::AbstractMatrix
) = randcat(rng, exp.(unnorm_ℓP .- logsumexp(unnorm_ℓP; dims=2)))

function randcat(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
P::AbstractMatrix{T}
) where {T}
u = rand(rng, T, size(P, 1))
C = cumsum(P; dims=2)
is = convert.(Int, vec(sum(C .< u; dims=2)))
is = convert.(Int, vec(sum(C .< u; dims=2))) .+ 1
return max.(is, 1)
end

0 comments on commit 15d6721

Please sign in to comment.