Skip to content

Commit

Permalink
Merge 242dc4e into 07236bf
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Aug 31, 2020
2 parents 07236bf + 242dc4e commit eef5484
Show file tree
Hide file tree
Showing 20 changed files with 466 additions and 408 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.vscode
.history
.DS_Store
Manifest.toml
test/Project.toml
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"

[targets]
test = ["Distributed", "Distributions", "ForwardDiff", "Plots", "MCMCDebugging", "Test", "Turing", "UnicodePlots", "Bijectors", "OrdinaryDiffEq", "Zygote"]
test = ["Distributed", "Distributions", "ForwardDiff", "Plots", "MCMCDebugging", "BSON", "Test", "Turing", "UnicodePlots", "Bijectors", "OrdinaryDiffEq", "Zygote"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ All the combinations are tested in [this file](https://github.com/TuringLang/Adv
function sample(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
τ::AbstractProposal,
τ::AbstractKernel,
θ::AbstractVector{<:AbstractFloat},
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
Expand Down
51 changes: 43 additions & 8 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,30 @@ using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, chol
using StatsFuns: logaddexp, logsumexp
using Random: GLOBAL_RNG, AbstractRNG
using ProgressMeter: ProgressMeter
using Parameters: @unpack, reconstruct
using Parameters: @with_kw, @unpack, reconstruct
using ArgCheck: @argcheck

using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using DocStringExtensions: SIGNATURES, TYPEDEF, TYPEDFIELDS

import StatsBase: sample
import Parameters: reconstruct

include("utilities.jl")

# Notations
# ℓπ: log density of the target distribution
# θ: position variables / model parameters
# ∂ℓπ∂θ: gradient of the log density of the target distribution w.r.t θ
# r: momentum variables
# z: phase point / a pair of θ and r
# θ₀: initial position
# r₀: initial momentum
# z₀: initial phase point
# ℓπ: log density of the target distribution
# ∇ℓπ: gradient of the log density of the target distribution w.r.t θ
# κ: kernel
# τ: trajectory
# ϵ: leap-frog integration step size
# L: leap-frog integration step number
# λ: leap-frog integration time

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
Expand All @@ -35,12 +43,39 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog

include("trajectory.jl")
@deprecate find_good_eps find_good_stepsize
export EndPointTS, SliceTS, MultinomialTS,
StaticTrajectory, HMCDA, NUTS,
ClassicNoUTurn, GeneralisedNoUTurn,
StrictGeneralisedNoUTurn,
export Trajectory, HMCKernel,
FixedNSteps, FixedIntegrationTime,
ClassicNoUTurn, GeneralisedNoUTurn, NoUTurn, StrictGeneralisedNoUTurn, StrictNoUTurn,
MetropolisTS, SliceTS, MultinomialTS,
find_good_stepsize

# Deprecations for trajectory.jl

abstract type AbstractTrajectory end

struct HMC{TS} end
HMC{TS}(int::AbstractIntegrator, L) where {TS} = HMCKernel(Trajectory(int, FixedNSteps(L)), TS)
HMC(int::AbstractIntegrator, L) = HMC{MetropolisTS}(int, L)
HMC::AbstractScalarOrVec{<:Real}, L) = HMC{MetropolisTS}(Leapfrog(ϵ), L)

struct StaticTrajectory{TS} end
@deprecate StaticTrajectory{TS}(args...) where {TS} HMC{TS}(args...)
@deprecate StaticTrajectory(args...) HMC(args...)

struct HMCDA{TS} end
HMCDA{TS}(int::AbstractIntegrator, λ) where {TS} = HMCKernel(Trajectory(int, FixedIntegrationTime(λ)), TS)
HMCDA(int::AbstractIntegrator, λ) = HMCDA{MetropolisTS}(int, λ)
HMCDA::AbstractScalarOrVec{<:Real}, λ) = HMCDA{MetropolisTS}(Leapfrog(ϵ), λ)

struct NUTS{TS, TC} end
NUTS{TS, TC}(int::AbstractIntegrator, args...; kwargs...) where {TS, TC} =
HMCKernel(Trajectory(int, TC(args...; kwargs...)), TS)
NUTS(int::AbstractIntegrator, args...; kwargs...) =
NUTS{MultinomialTS, GeneralisedNoUTurn}(int, args...; kwargs...)
NUTS::AbstractScalarOrVec{<:Real}) = NUTS{MultinomialTS, GeneralisedNoUTurn}(Leapfrog(ϵ))

export AbstractTrajectory, HMC, StaticTrajectory, HMCDA, NUTS

include("adaptation/Adaptation.jl")
using .Adaptation
import .Adaptation: StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging
Expand Down
45 changes: 24 additions & 21 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@ function reconstruct(
return reconstruct(h, metric=metric)
end

reconstruct(τ::AbstractProposal, ::AbstractAdaptor) = τ
reconstruct(κ::AbstractKernel, ::AbstractAdaptor) = κ
function reconstruct(
τ::AbstractProposal, adaptor::Union{StepSizeAdaptor, NaiveHMCAdaptor, StanHMCAdaptor}
κ::AbstractKernel, adaptor::Union{StepSizeAdaptor, NaiveHMCAdaptor, StanHMCAdaptor}
)
# FIXME: this does not support change type of `ϵ` (e.g. Float to Vector)
# FIXME: this is buggy for `JitteredLeapfrog`
integrator = reconstruct.integrator, ϵ=getϵ(adaptor))
return reconstruct(τ, integrator=integrator)
ϵ = getϵ(adaptor)
τ = reconstruct(
κ.τ, integrator=reconstruct.τ.integrator, ϵ=ϵ)
)
return reconstruct(κ, τ=τ)
end

function resize(h::Hamiltonian, θ::AbstractVecOrMat{T}) where {T<:AbstractFloat}
Expand Down Expand Up @@ -47,28 +50,28 @@ end
function step(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
z::PhasePoint
)
# Refresh momentum
z = refresh(rng, z, h)
# Make transition
return transition(rng, τ, h, z)
return transition(rng, κ, h, z)
end

Adaptation.adapt!(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
adaptor::Adaptation.NoAdaptation,
i::Int,
n_adapts::Int,
θ::AbstractVecOrMat{<:AbstractFloat},
α::AbstractScalarOrVec{<:AbstractFloat}
) = h, τ, false
) = h, κ, false

function Adaptation.adapt!(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
adaptor::AbstractAdaptor,
i::Int,
n_adapts::Int,
Expand All @@ -81,10 +84,10 @@ function Adaptation.adapt!(
adapt!(adaptor, θ, α)
i == n_adapts && finalize!(adaptor)
h = reconstruct(h, adaptor)
τ = reconstruct(τ, adaptor)
κ = reconstruct(κ, adaptor)
isadapted = true
end
return h, τ, isadapted
return h, κ, isadapted
end

"""
Expand All @@ -105,7 +108,7 @@ simple_pm_next!(pm, stat::NamedTuple) = ProgressMeter.next!(pm)

sample(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
θ::AbstractVecOrMat{<:AbstractFloat},
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
Expand All @@ -117,7 +120,7 @@ sample(
) = sample(
GLOBAL_RNG,
h,
τ,
κ,
θ,
n_samples,
adaptor,
Expand All @@ -132,7 +135,7 @@ sample(
sample(
rng::AbstractRNG,
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
θ::AbstractVecOrMat{T},
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
Expand All @@ -142,7 +145,7 @@ sample(
progress::Bool=false
)
Sample `n_samples` samples using the proposal `τ` under Hamiltonian `h`.
Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`.
- The randomness is controlled by `rng`.
- If `rng` is not provided, `GLOBAL_RNG` will be used.
- The initial point is given by `θ`.
Expand All @@ -155,7 +158,7 @@ Sample `n_samples` samples using the proposal `τ` under Hamiltonian `h`.
function sample(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
θ::T,
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
Expand All @@ -175,18 +178,18 @@ function sample(
pm = progress ? ProgressMeter.Progress(n_samples, desc="Sampling", barlen=31) : nothing
time = @elapsed for i = 1:n_samples
# Make a step
t = step(rng, h, τ, t.z)
# Adapt h and τ; what mutable is the adaptor
t = transition(rng, h, κ, t.z)
# Adapt h and κ; what mutable is the adaptor
tstat = stat(t)
h, τ, isadapted = adapt!(h, τ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt=isadapted,))
# Update progress meter
if progress
# Do include current iteration and mass matrix
pm_next!(pm, (iterations=i, tstat..., mass_matrix=h.metric))
# Report finish of adapation
elseif verbose && isadapted && i == n_adapts
@info "Finished $n_adapts adapation steps" adaptor τ.integrator h.metric
@info "Finished $n_adapts adapation steps" adaptor κ.τ.integrator h.metric
end
# Store sample
if !drop_warmup || i > n_adapts
Expand All @@ -205,7 +208,7 @@ function sample(
EBFMI_est = "[" * join(EBFMI_est, ", ") * "]"
average_acceptance_rate = "[" * join(average_acceptance_rate, ", ") * "]"
end
@info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h τ EBFMI_est average_acceptance_rate
@info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h κ EBFMI_est average_acceptance_rate
end
return θs, stats
end
Loading

0 comments on commit eef5484

Please sign in to comment.