Skip to content

Commit

Permalink
Merge 1c6b77d into 9923f23
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Jun 2, 2020
2 parents 9923f23 + 1c6b77d commit 9ee12e4
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 26 deletions.
1 change: 1 addition & 0 deletions .github/workflows/AHMC-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ on:
jobs:
test:
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
matrix:
version:
Expand Down
8 changes: 7 additions & 1 deletion src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ function step(
else
res = z
end
!isfinite(z) && break
if !isfinite(z)
# Remove undef
if FullTraj
res = res[isassigned.(Ref(res), 1:n_steps)]
end
break
end
end
return res
end
Expand Down
48 changes: 34 additions & 14 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ function transition(
)
H0 = energy(z)
integrator = jitter(rng, τ.integrator)
z′ = samplecand(rng, τ, h, z)
# Are we going to accept the `z′` via MH criteria?
is_accept, α = mh_accept_ratio(rng, energy(z), energy(z′))
z′, is_accept, α = sample_phasepoint(rng, τ, h, z)
# Do the actual accept / reject
z = accept_phasepoint!(z, z′, is_accept) # NOTE: this function changes `z′` in place in matrix-parallel mode
# Reverse momentum variable to preserve reversibility
Expand Down Expand Up @@ -261,18 +259,27 @@ function accept_phasepoint!(z::T, z′::T, is_accept) where {T<:PhasePoint{<:Abs
return z′
end

### Use end-point from trajectory as proposal
### Use end-point from the trajectory as a proposal and apply MH correction

samplecand(rng, τ::StaticTrajectory{EndPointTS}, h, z) = step.integrator, h, z, τ.n_steps)
"""
    sample_phasepoint(rng, τ::StaticTrajectory{EndPointTS}, h, z)

Use the end-point of the trajectory as the proposal and apply an MH correction.

The proposal `z′` is obtained by calling `step` with `τ.integrator`, `h`, `z` and
`τ.n_steps`. The acceptance decision and ratio are whatever `mh_accept_ratio`
returns for the energies of `z` and `z′`.

Returns the tuple `(z′, is_accept, α)`.
"""
function sample_phasepoint(rng, τ::StaticTrajectory{EndPointTS}, h, z)
    z′ = step(τ.integrator, h, z, τ.n_steps)
    is_accept, α = mh_accept_ratio(rng, energy(z), energy(z′))
    return z′, is_accept, α
end

### Multinomial sampling from trajectory

randcat(rng::AbstractRNG, zs::AbstractVector{<:PhasePoint}, unnorm_ℓp::AbstractVector) = zs[randcat_logp(rng, unnorm_ℓp)]
"""
    randcat(rng::AbstractRNG, zs::AbstractVector{<:PhasePoint}, unnorm_ℓp::AbstractVector)

Pick one element of `zs` at random, using `unnorm_ℓp` as unnormalised
log-probabilities: they are normalised with `logsumexp`, exponentiated, and the
resulting probability vector is passed to `randcat(rng, p)` to draw the index.
"""
function randcat(
    rng::AbstractRNG,
    zs::AbstractVector{<:PhasePoint},
    unnorm_ℓp::AbstractVector,
)
    ℓnormaliser = logsumexp(unnorm_ℓp)
    probs = exp.(unnorm_ℓp .- ℓnormaliser)
    return zs[randcat(rng, probs)]
end

# zs is in the form of Vector{PhasePoint{Matrix}} and has shape [n_steps][dim, n_chains]
function randcat(rng, zs::AbstractVector{<:PhasePoint}, unnorm_ℓP::AbstractMatrix)
z = similar(first(zs))
is = randcat_logp(rng, unnorm_ℓP)
P = exp.(unnorm_ℓP .- logsumexp(unnorm_ℓP; dims=2)) # (n_chains, n_steps)
is = randcat(rng, P')
foreach(enumerate(is)) do (i_chain, i_step)
zi = zs[i_step]
z.θ[:,i_chain] = zi.θ[:,i_chain]
Expand All @@ -285,14 +292,27 @@ function randcat(rng, zs::AbstractVector{<:PhasePoint}, unnorm_ℓP::AbstractMat
return z
end

function samplecand(rng, τ::StaticTrajectory{MultinomialTS}, h, z)
zs = step.integrator, h, z, τ.n_steps; full_trajectory = Val(true))
ℓws = -energy.(zs)
if eltype(ℓws) <: AbstractVector
ℓws = hcat(ℓws...)
function sample_phasepoint(rng, τ::StaticTrajectory{MultinomialTS}, h, z)
n_steps = abs.n_steps)
# TODO: Deal with vectorized-mode generically.
# Currently the direction of multiple chains are always coupled
n_steps_fwd = rand_coupled(rng, 0:n_steps)
zs_fwd = step.integrator, h, z, n_steps_fwd; fwd=true, full_trajectory=Val(true))
n_steps_bwd = n_steps - n_steps_fwd
zs_bwd = step.integrator, h, z, n_steps_bwd; fwd=false, full_trajectory=Val(true))
zs = vcat(reverse(zs_bwd)..., z, zs_fwd...)
ℓweights = -energy.(zs)
if eltype(ℓweights) <: AbstractVector
ℓweights = cat(ℓweights...; dims=2)
end
unnorm_ℓprob = ℓws
return randcat(rng, zs, unnorm_ℓprob)
unnorm_ℓprob = ℓweights
z′ = randcat(rng, zs, unnorm_ℓprob)
# Computing adaptation statistics for dual averaging as done in NUTS
Hs = -ℓweights
ΔH = Hs .- energy(z)
α = exp.(min.(0, -ΔH)) # this is a matrix for vectorized mode and a vector otherwise
α = typeof(α) <: AbstractVector ? mean(α) : vec(mean(α; dims=2))
return z′, true, α
end

abstract type DynamicTrajectory{I<:AbstractIntegrator} <: AbstractTrajectory{I} end
Expand Down
62 changes: 51 additions & 11 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,24 @@ function Base.randn(rng::AbstractVector{<:AbstractRNG}, T, dim::Int, n_chains::I
return cat(randn.(rng, T, dim)...; dims=2)
end

# Sample from Categorical distributions
"""
`rand_coupled` produces coupled randomness given a vector of RNGs. For example,
when a vector of RNGs is provided, `rand_coupled` returns the result of a single
`rand` call (made on the first RNG) rather than one result per RNG, while still
advancing every other RNG with a discarded call so that all RNGs stay synchronised.
This is important if we want to couple multiple Markov chains.
"""

randcat_logp(rng::AbstractRNG, unnorm_ℓp::AbstractVector) =
randcat(rng, exp.(unnorm_ℓp .- logsumexp(unnorm_ℓp)))
# Single-RNG method: there is nothing to keep in sync, so forward straight to `rand`.
function rand_coupled(rng::AbstractRNG, args...)
    return rand(rng, args...)
end

# Vector-of-RNGs method: draw once from the first RNG and return that value.
# Every other RNG receives the same `rand` call with its result discarded, so
# that all RNGs advance by the same amount.
function rand_coupled(rngs::AbstractVector{<:AbstractRNG}, args...)
    # Dummy calls to sync RNGs; `Iterators.drop` skips the first RNG without
    # allocating a copy of the vector.
    for rng in Iterators.drop(rngs, 1)
        rand(rng, args...)
    end
    return rand(first(rngs), args...)
end

# Sample from Categorical distributions

function randcat(rng::AbstractRNG, p::AbstractVector{T}) where {T}
u = rand(rng, T)
Expand All @@ -31,17 +45,43 @@ function randcat(rng::AbstractRNG, p::AbstractVector{T}) where {T}
return max(i, 1)
end

randcat_logp(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
unnorm_ℓP::AbstractMatrix
) = randcat(rng, exp.(unnorm_ℓP .- logsumexp(unnorm_ℓP; dims=2)))
"""
randcat(rng, P::AbstractMatrix)
Generate Categorical random variables in a vectorized mode.
`P` is expected to be a matrix of size (D, N) where each column is a probability vector.
Example
```
P = [
0.5 0.3;
0.4 0.6;
0.1 0.1
]
u = [0.3, 0.4]
C = [
0.5 0.3
0.9 0.9
1.0 1.0
]
```
Then `C .< u'` is
```
[
0 1
0 0
0 0
]
```
thus `convert.(Int, vec(sum(C .< u'; dims=1))) .+ 1` equals `[1, 2]`.
"""
function randcat(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
P::AbstractMatrix{T}
) where {T}
u = rand(rng, T, size(P, 1))
C = cumsum(P; dims=2)
is = convert.(Int, vec(sum(C .< u; dims=2)))
return max.(is, 1)
u = rand(rng, T, size(P, 2))
C = cumsum(P; dims=1)
indices = convert.(Int, vec(sum(C .< u'; dims=1))) .+ 1
return max.(indices, 1) # prevent numerical issue for Float32
end

0 comments on commit 9ee12e4

Please sign in to comment.