Unifying trajectories #214

Closed
wants to merge 27 commits into from
Changes from 6 commits
Commits
27 commits
31c6517
add regression test against master
xukai92 Aug 5, 2020
097dc3f
rename EndPointTS to MetropolisTS
xukai92 Aug 5, 2020
27668bd
rename GeneralisedNoUTurn to NoUTurn
xukai92 Aug 5, 2020
72309f8
update static trajectories
xukai92 Aug 5, 2020
52827b9
update dynamic trajectories
xukai92 Aug 5, 2020
8878944
Support a richer interface for NUTS
xukai92 Aug 5, 2020
04f621e
Apply suggestions from code review
xukai92 Aug 9, 2020
9425d27
add missing imports
xukai92 Aug 9, 2020
c827dca
fix Hong's typo
xukai92 Aug 9, 2020
28887ca
fix geweke
xukai92 Aug 9, 2020
d17d79d
name back no-U-turns
xukai92 Aug 28, 2020
89f9b43
remove unnecessary interface for test
xukai92 Aug 28, 2020
cf0e5fc
AbstractKernel -> AbstractKernel
xukai92 Aug 28, 2020
ddc56f2
FixedLength -> FixedIntegrationTime
xukai92 Aug 28, 2020
90b06cb
Update src/trajectory.jl
xukai92 Aug 28, 2020
086590d
make internal naming more descriptive
xukai92 Aug 28, 2020
7c5d344
remove old comments
xukai92 Aug 31, 2020
cbd1234
Revert "remove old comments"
xukai92 Aug 31, 2020
237138f
rename TS to trajectory_sampler_type
xukai92 Aug 31, 2020
200b52f
improve internal namings
xukai92 Aug 31, 2020
242dc4e
push test toml
xukai92 Aug 31, 2020
c59a739
lower bound 1.3 on Travis
xukai92 Aug 31, 2020
e39aa93
Improve badge
xukai92 Sep 1, 2020
68d1452
Update src/trajectory.jl
xukai92 Oct 27, 2020
0e9200f
Update src/trajectory.jl
xukai92 Oct 28, 2020
bde50d8
resolve conflicts
xukai92 Jan 8, 2021
3c06f2d
Apply suggestions from code review
xukai92 Jan 8, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
.vscode
.history
.DS_Store
Manifest.toml
test/Project.toml
3 changes: 2 additions & 1 deletion Project.toml
@@ -38,6 +38,7 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"

[targets]
test = ["Distributed", "Distributions", "ForwardDiff", "Plots", "MCMCDebugging", "Test", "Turing", "UnicodePlots", "Bijectors", "OrdinaryDiffEq", "Zygote"]
test = ["Distributed", "Distributions", "ForwardDiff", "Plots", "MCMCDebugging", "BSON", "Test", "Turing", "UnicodePlots", "Bijectors", "OrdinaryDiffEq", "Zygote"]
6 changes: 3 additions & 3 deletions README.md
@@ -54,7 +54,7 @@ integrator = Leapfrog(initial_ϵ)
# - multinomial sampling scheme,
# - generalised No-U-Turn criteria, and
# - windowed adaption for step-size and diagonal mass matrix
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
proposal = NUTS{MultinomialTS, NoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

# Run the sampler to draw samples from the specified Gaussian, where
@@ -115,9 +115,9 @@ where `ϵ` is the step size of leapfrog integration.
- Static HMC with a fixed number of steps (`n_steps`) (Neal, R. M. (2011)): `StaticTrajectory(integrator, n_steps)`
- HMC with a fixed total trajectory length (`trajectory_length`) (Neal, R. M. (2011)): `HMCDA(integrator, trajectory_length)`
- Original NUTS with slice sampling (Hoffman, M. D., & Gelman, A. (2014)): `NUTS{SliceTS,ClassicNoUTurn}(integrator)`
- Generalised NUTS with slice sampling (Betancourt, M. (2017)): `NUTS{SliceTS,GeneralisedNoUTurn}(integrator)`
- Generalised NUTS with slice sampling (Betancourt, M. (2017)): `NUTS{SliceTS,NoUTurn}(integrator)`
- Original NUTS with multinomial sampling (Betancourt, M. (2017)): `NUTS{MultinomialTS,ClassicNoUTurn}(integrator)`
- Generalised NUTS with multinomial sampling (Betancourt, M. (2017)): `NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)`
- Generalised NUTS with multinomial sampling (Betancourt, M. (2017)): `NUTS{MultinomialTS,NoUTurn}(integrator)`

### Adaptor (`adaptor`)

49 changes: 42 additions & 7 deletions src/AdvancedHMC.jl
@@ -7,7 +7,7 @@ using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, chol
using StatsFuns: logaddexp, logsumexp
using Random: GLOBAL_RNG, AbstractRNG
using ProgressMeter: ProgressMeter
using Parameters: @unpack, reconstruct
using Parameters: @with_kw, @unpack, reconstruct
using ArgCheck: @argcheck

using DocStringExtensions: TYPEDEF, TYPEDFIELDS
@@ -18,11 +18,19 @@ import Parameters: reconstruct
include("utilities.jl")

# Notations
# ℓπ: log density of the target distribution
# θ: position variables / model parameters
# ∂ℓπ∂θ: gradient of the log density of the target distribution w.r.t θ
# r: momentum variables
# z: phase point / a pair of θ and r
# θ₀: initial position
# r₀: initial momentum
# z₀: initial phase point
# ℓπ: log density of the target distribution
# ∇ℓπ: gradient of the log density of the target distribution w.r.t θ
# κ: kernel
# τ: trajectory
# ϵ: step size
# L: step number
# λ: integration time

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
@@ -35,12 +43,39 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog

include("trajectory.jl")
@deprecate find_good_eps find_good_stepsize
export EndPointTS, SliceTS, MultinomialTS,
StaticTrajectory, HMCDA, NUTS,
ClassicNoUTurn, GeneralisedNoUTurn,
StrictGeneralisedNoUTurn,
export Trajectory, HMCKernel,
FixedNSteps, FixedLength,
ClassicNoUTurn, NoUTurn, StrictNoUTurn,
MetropolisTS, SliceTS, MultinomialTS,
find_good_stepsize

# Deprecations for trajectory.jl

abstract type AbstractTrajectory end

struct HMC{TS} end
HMC{TS}(int::AbstractIntegrator, L) where {TS} = HMCKernel(Trajectory(int, FixedNSteps(L)), TS)
HMC(int::AbstractIntegrator, L) = HMC{MetropolisTS}(int, L)
Member


Maybe consider the following for clarity and performance?

Suggested change
HMC(int::AbstractIntegrator, L) = HMC{MetropolisTS}(int, L)
HMC(int::AbstractIntegrator, L) = HMCKernel(Trajectory(int, FixedNSteps(L)), MetropolisTS)

HMC(ϵ::AbstractScalarOrVec{<:Real}, L) = HMC{MetropolisTS}(Leapfrog(ϵ), L)
Member


Suggested change
HMC(ϵ::AbstractScalarOrVec{<:Real}, L) = HMC{MetropolisTS}(Leapfrog(ϵ), L)
HMC(ϵ::AbstractScalarOrVec{<:Real}, L) = HMCKernel(Trajectory(Leapfrog(ϵ), FixedNSteps(L)), MetropolisTS)


struct StaticTrajectory{TS} end
@deprecate StaticTrajectory{TS}(args...) where {TS} HMC{TS}(args...)
@deprecate StaticTrajectory(args...) HMC(args...)
Member


Similar here, consider calling HMCKernel for clarity.
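
For illustration only, here is a minimal self-contained sketch of what forwarding the deprecated name straight to `HMCKernel` might look like. The type definitions are hypothetical stand-ins so the snippet runs on its own: field names such as `termination_criterion` and `sampler_type` are assumptions rather than names taken from this PR, and the sketch uses `Base.depwarn` instead of `@deprecate` to keep it small.

```julia
# Hypothetical stand-ins for the types introduced in this PR.
abstract type AbstractIntegrator end

struct Leapfrog{T<:Real} <: AbstractIntegrator
    ϵ::T  # step size
end

struct FixedNSteps
    L::Int  # number of leapfrog steps
end

struct Trajectory{I<:AbstractIntegrator, TC}
    integrator::I
    termination_criterion::TC  # assumed field name
end

struct MetropolisTS end

struct HMCKernel{T<:Trajectory, TS}
    τ::T
    sampler_type::Type{TS}  # assumed field name
end

# The deprecated name builds the kernel directly,
# without going through the intermediate `HMC` wrapper.
struct StaticTrajectory{TS} end

function StaticTrajectory{TS}(int::AbstractIntegrator, L) where {TS}
    Base.depwarn(
        "`StaticTrajectory` is deprecated, construct an `HMCKernel` instead.",
        :StaticTrajectory,
    )
    return HMCKernel(Trajectory(int, FixedNSteps(L)), TS)
end

StaticTrajectory(int::AbstractIntegrator, L) = StaticTrajectory{MetropolisTS}(int, L)

# Usage: the old constructor call returns a kernel.
κ = StaticTrajectory(Leapfrog(0.1), 10)
```

Under these assumptions a reader of the deprecation sees the resulting kernel at the call site, rather than having to follow `StaticTrajectory` to `HMC` to `HMCKernel`.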


struct HMCDA{TS} end
HMCDA{TS}(int::AbstractIntegrator, λ) where {TS} = HMCKernel(Trajectory(int, FixedLength(λ)), TS)
HMCDA(int::AbstractIntegrator, λ) = HMCDA{MetropolisTS}(int, λ)
HMCDA(ϵ::AbstractScalarOrVec{<:Real}, λ) = HMCDA{MetropolisTS}(Leapfrog(ϵ), λ)

struct NUTS{TS, TC} end
NUTS{TS, TC}(int::AbstractIntegrator, args...; kwargs...) where {TS, TC} =
HMCKernel(Trajectory(int, TC(args...; kwargs...)), TS)
NUTS(int::AbstractIntegrator, args...; kwargs...) =
NUTS{MultinomialTS, NoUTurn}(int, args...; kwargs...)
NUTS(ϵ::AbstractScalarOrVec{<:Real}) = NUTS{MultinomialTS, NoUTurn}(Leapfrog(ϵ))

export AbstractTrajectory, HMC, StaticTrajectory, HMCDA, NUTS

include("adaptation/Adaptation.jl")
using .Adaptation
import .Adaptation: StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging
47 changes: 26 additions & 21 deletions src/sampler.jl
@@ -8,14 +8,19 @@ function reconstruct(
return reconstruct(h, metric=metric)
end

reconstruct(τ::AbstractProposal, ::AbstractAdaptor) = τ
reconstruct(κ::AbstractProposal, ::AbstractAdaptor) = κ
function reconstruct(
τ::AbstractProposal, adaptor::Union{StepSizeAdaptor, NaiveHMCAdaptor, StanHMCAdaptor}
κ::AbstractProposal, adaptor::Union{StepSizeAdaptor, NaiveHMCAdaptor, StanHMCAdaptor}
)
return reconstruct(κ, getϵ(adaptor))
end
function reconstruct(κ::AbstractProposal, ϵ::AbstractScalarOrVec{<:Real})
# FIXME: this does not support change type of `ϵ` (e.g. Float to Vector)
# FIXME: this is buggy for `JitteredLeapfrog`
integrator = reconstruct(τ.integrator, ϵ=getϵ(adaptor))
return reconstruct(τ, integrator=integrator)
τ = reconstruct(
κ.τ, integrator=reconstruct(κ.τ.integrator, ϵ=ϵ)
)
return reconstruct(κ, τ=τ)
end

function resize(h::Hamiltonian, θ::AbstractVecOrMat{T}) where {T<:AbstractFloat}
@@ -47,28 +52,28 @@ end
function step(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractProposal,
z::PhasePoint
)
# Refresh momentum
z = refresh(rng, z, h)
# Make transition
return transition(rng, τ, h, z)
return transition(rng, κ, h, z)
end

Adaptation.adapt!(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractProposal,
adaptor::Adaptation.NoAdaptation,
i::Int,
n_adapts::Int,
θ::AbstractVecOrMat{<:AbstractFloat},
α::AbstractScalarOrVec{<:AbstractFloat}
) = h, τ, false
) = h, κ, false

function Adaptation.adapt!(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractProposal,
adaptor::AbstractAdaptor,
i::Int,
n_adapts::Int,
@@ -81,10 +86,10 @@ function Adaptation.adapt!(
adapt!(adaptor, θ, α)
i == n_adapts && finalize!(adaptor)
h = reconstruct(h, adaptor)
τ = reconstruct(τ, adaptor)
κ = reconstruct(κ, adaptor)
isadapted = true
end
return h, τ, isadapted
return h, κ, isadapted
end

"""
@@ -105,7 +110,7 @@ simple_pm_next!(pm, stat::NamedTuple) = ProgressMeter.next!(pm)

sample(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractProposal,
θ::AbstractVecOrMat{<:AbstractFloat},
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
@@ -117,7 +122,7 @@ sample(
) = sample(
GLOBAL_RNG,
h,
τ,
κ,
θ,
n_samples,
adaptor,
@@ -132,7 +137,7 @@ sample(
sample(
rng::AbstractRNG,
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractProposal,
θ::AbstractVecOrMat{T},
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
@@ -142,7 +147,7 @@ sample(
progress::Bool=false
)

Sample `n_samples` samples using the proposal `τ` under Hamiltonian `h`.
Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`.
- The randomness is controlled by `rng`.
- If `rng` is not provided, `GLOBAL_RNG` will be used.
- The initial point is given by `θ`.
@@ -155,7 +160,7 @@ Sample `n_samples` samples using the proposal `τ` under Hamiltonian `h`.
function sample(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractProposal,
θ::T,
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
@@ -175,18 +180,18 @@ function sample(
pm = progress ? ProgressMeter.Progress(n_samples, desc="Sampling", barlen=31) : nothing
time = @elapsed for i = 1:n_samples
# Make a step
t = step(rng, h, τ, t.z)
# Adapt h and τ; what mutable is the adaptor
t = transition(rng, h, κ, t.z)
# Adapt h and κ; what mutable is the adaptor
tstat = stat(t)
h, τ, isadapted = adapt!(h, τ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt=isadapted,))
# Update progress meter
if progress
# Do include current iteration and mass matrix
pm_next!(pm, (iterations=i, tstat..., mass_matrix=h.metric))
# Report finish of adapation
elseif verbose && isadapted && i == n_adapts
@info "Finished $n_adapts adapation steps" adaptor τ.integrator h.metric
@info "Finished $n_adapts adapation steps" adaptor κ.τ.integrator h.metric
end
# Store sample
if !drop_warmup || i > n_adapts
@@ -205,7 +210,7 @@ function sample(
EBFMI_est = "[" * join(EBFMI_est, ", ") * "]"
average_acceptance_rate = "[" * join(average_acceptance_rate, ", ") * "]"
end
@info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h τ EBFMI_est average_acceptance_rate
@info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h κ EBFMI_est average_acceptance_rate
end
return θs, stats
end