Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unifying trajectories #214

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
31c6517
add regression test against master
xukai92 Aug 5, 2020
097dc3f
rename EndPointTS to MetropolisTS
xukai92 Aug 5, 2020
27668bd
rename GeneralisedNoUTurn to NoUTurn
xukai92 Aug 5, 2020
72309f8
update static trajectories
xukai92 Aug 5, 2020
52827b9
update dynamic trajectories
xukai92 Aug 5, 2020
8878944
Support a richer interface for NUTS
xukai92 Aug 5, 2020
04f621e
Apply suggestions from code review
xukai92 Aug 9, 2020
9425d27
add missing imports
xukai92 Aug 9, 2020
c827dca
fix Hong's typo
xukai92 Aug 9, 2020
28887ca
fix geweke
xukai92 Aug 9, 2020
d17d79d
name back no-U-turns
xukai92 Aug 28, 2020
89f9b43
remove unnecessary interface for test
xukai92 Aug 28, 2020
cf0e5fc
AbstractKernel -> AbstractKernel
xukai92 Aug 28, 2020
ddc56f2
FixedLength -> FixedIntegrationTime
xukai92 Aug 28, 2020
90b06cb
Update src/trajectory.jl
xukai92 Aug 28, 2020
086590d
make internal naming more descriptive
xukai92 Aug 28, 2020
7c5d344
remove old comments
xukai92 Aug 31, 2020
cbd1234
Revert "remove old comments"
xukai92 Aug 31, 2020
237138f
rename TS to trajectory_sampler_type
xukai92 Aug 31, 2020
200b52f
improve internal namings
xukai92 Aug 31, 2020
242dc4e
push test toml
xukai92 Aug 31, 2020
c59a739
lower bound 1.3 on Travis
xukai92 Aug 31, 2020
e39aa93
Improve badge
xukai92 Sep 1, 2020
68d1452
Update src/trajectory.jl
xukai92 Oct 27, 2020
0e9200f
Update src/trajectory.jl
xukai92 Oct 28, 2020
bde50d8
resolve conflicts
xukai92 Jan 8, 2021
3c06f2d
Apply suggestions from code review
xukai92 Jan 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.vscode
.history
.DS_Store
Manifest.toml
test/Project.toml
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ os:
- osx

julia:
- 1.0
- 1
- 1.3
- nightly

matrix:
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ StatsFuns = "0.8, 0.9"
julia = "1"

[extras]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCDebugging = "6d524b87-5f90-4494-b601-374a5b87a94b"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributed", "Distributions", "ForwardDiff", "Plots", "MCMCDebugging", "Test", "Turing", "UnicodePlots", "Bijectors", "OrdinaryDiffEq", "Zygote"]
test = ["Distributed", "Distributions", "ForwardDiff", "Plots", "MCMCDebugging", "BSON", "Test", "Turing", "UnicodePlots", "Bijectors", "OrdinaryDiffEq", "Zygote"]
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AdvancedHMC.jl

[![Build Status](https://travis-ci.com/TuringLang/AdvancedHMC.jl.svg?branch=master)](https://travis-ci.com/TuringLang/AdvancedHMC.jl)
[![AdvancedHMC-CI](https://github.com/TuringLang/AdvancedHMC.jl/workflows/AdvancedHMC-CI/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedHMC.jl/actions?query=workflow%3AAdvancedHMC-CI)
[![Travis CI](https://travis-ci.com/TuringLang/AdvancedHMC.jl.svg?branch=master)](https://travis-ci.com/TuringLang/AdvancedHMC.jl)
[![GitHub Actions CI](https://github.com/TuringLang/AdvancedHMC.jl/workflows/CI/badge.svg)](https://github.com/TuringLang/AdvancedHMC.jl/actions?query=workflow%3ACI)
[![DOI](https://zenodo.org/badge/72657907.svg)](https://zenodo.org/badge/latestdoi/72657907)
[![Coverage Status](https://coveralls.io/repos/github/TuringLang/AdvancedHMC.jl/badge.svg?branch=kx%2Fbug-fix)](https://coveralls.io/github/TuringLang/AdvancedHMC.jl?branch=kx%2Fbug-fix)
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://turing.ml/stable/docs/library/advancedhmc/)
Expand Down Expand Up @@ -139,7 +139,7 @@ All the combinations are tested in [this file](https://github.com/TuringLang/Adv
function sample(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
τ::AbstractProposal,
τ::AbstractKernel,
θ::AbstractVector{<:AbstractFloat},
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
Expand Down
51 changes: 43 additions & 8 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,30 @@ using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, chol
using StatsFuns: logaddexp, logsumexp
using Random: GLOBAL_RNG, AbstractRNG
using ProgressMeter: ProgressMeter
using Parameters: @unpack, reconstruct
using Parameters: @with_kw, @unpack, reconstruct
using ArgCheck: @argcheck

using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using DocStringExtensions: SIGNATURES, TYPEDEF, TYPEDFIELDS

import StatsBase: sample
import Parameters: reconstruct

include("utilities.jl")

# Notations
# ℓπ: log density of the target distribution
# θ: position variables / model parameters
# ∂ℓπ∂θ: gradient of the log density of the target distribution w.r.t θ
# r: momentum variables
# z: phase point / a pair of θ and r
# θ₀: initial position
# r₀: initial momentum
# z₀: initial phase point
# ℓπ: log density of the target distribution
# ∇ℓπ: gradient of the log density of the target distribution w.r.t θ
# κ: kernel
# τ: trajectory
# ϵ: leap-frog integration step size
# L: leap-frog integration step number
# λ: leap-frog integration time

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
Expand All @@ -35,12 +43,39 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog

include("trajectory.jl")
@deprecate find_good_eps find_good_stepsize
export EndPointTS, SliceTS, MultinomialTS,
StaticTrajectory, HMCDA, NUTS,
ClassicNoUTurn, GeneralisedNoUTurn,
StrictGeneralisedNoUTurn,
export Trajectory, HMCKernel,
FixedNSteps, FixedIntegrationTime,
ClassicNoUTurn, GeneralisedNoUTurn, NoUTurn, StrictGeneralisedNoUTurn, StrictNoUTurn,
MetropolisTS, SliceTS, MultinomialTS,
find_good_stepsize

# Deprecations for trajectory.jl

abstract type AbstractTrajectory end

struct HMC{TS} end
HMC{TS}(int::AbstractIntegrator, L) where {TS} = HMCKernel(Trajectory(int, FixedNSteps(L)), TS)
HMC(int::AbstractIntegrator, L) = HMC{MetropolisTS}(int, L)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe consider the following for clarity and performance?

Suggested change
HMC(int::AbstractIntegrator, L) = HMC{MetropolisTS}(int, L)
HMC(int::AbstractIntegrator, L) = HMCKernel(Trajectory(int, FixedNSteps(L)), MetropolisTS)

HMC(ϵ::AbstractScalarOrVec{<:Real}, L) = HMC{MetropolisTS}(Leapfrog(ϵ), L)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
HMC::AbstractScalarOrVec{<:Real}, L) = HMC{MetropolisTS}(Leapfrog(ϵ), L)
HMC::AbstractScalarOrVec{<:Real}, L) = HMCKernel(Trajectory(Leapfrog(ϵ), FixedNSteps(L)), MetropolisTS)


struct StaticTrajectory{TS} end
@deprecate StaticTrajectory{TS}(args...) where {TS} HMC{TS}(args...)
@deprecate StaticTrajectory(args...) HMC(args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar here, consider calling HMCKernel for clarity.


struct HMCDA{TS} end
HMCDA{TS}(int::AbstractIntegrator, λ) where {TS} = HMCKernel(Trajectory(int, FixedIntegrationTime(λ)), TS)
HMCDA(int::AbstractIntegrator, λ) = HMCDA{MetropolisTS}(int, λ)
HMCDA(ϵ::AbstractScalarOrVec{<:Real}, λ) = HMCDA{MetropolisTS}(Leapfrog(ϵ), λ)

struct NUTS{TS, TC} end
NUTS{TS, TC}(int::AbstractIntegrator, args...; kwargs...) where {TS, TC} =
HMCKernel(Trajectory(int, TC(args...; kwargs...)), TS)
NUTS(int::AbstractIntegrator, args...; kwargs...) =
NUTS{MultinomialTS, GeneralisedNoUTurn}(int, args...; kwargs...)
NUTS(ϵ::AbstractScalarOrVec{<:Real}) = NUTS{MultinomialTS, GeneralisedNoUTurn}(Leapfrog(ϵ))

export HMC, StaticTrajectory, HMCDA, NUTS

include("adaptation/Adaptation.jl")
using .Adaptation
import .Adaptation: StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging
Expand Down
44 changes: 23 additions & 21 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ function reconstruct(
return reconstruct(h, metric=metric)
end

reconstruct(τ::AbstractProposal, ::AbstractAdaptor) = τ
reconstruct(κ::AbstractKernel, ::AbstractAdaptor) = κ
function reconstruct(
τ::AbstractProposal, adaptor::Union{StepSizeAdaptor, NaiveHMCAdaptor, StanHMCAdaptor}
κ::AbstractKernel, adaptor::Union{StepSizeAdaptor, NaiveHMCAdaptor, StanHMCAdaptor}
)
# FIXME: this does not support change type of `ϵ` (e.g. Float to Vector)
integrator = update_nom_step_size(τ.integrator, getϵ(adaptor))
return reconstruct(τ, integrator=integrator)
τ = reconstruct(
κ.τ, integrator=update_nom_step_size(κ.τ.integrator, getϵ(adaptor))
)
return reconstruct(κ, τ=τ)
end

function resize(h::Hamiltonian, θ::AbstractVecOrMat{T}) where {T<:AbstractFloat}
Expand Down Expand Up @@ -46,28 +48,28 @@ end
function step(
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
z::PhasePoint
)
# Refresh momentum
z = refresh(rng, z, h)
# Make transition
return transition(rng, τ, h, z)
return transition(rng, κ, h, z)
end

Adaptation.adapt!(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
adaptor::Adaptation.NoAdaptation,
i::Int,
n_adapts::Int,
θ::AbstractVecOrMat{<:AbstractFloat},
α::AbstractScalarOrVec{<:AbstractFloat}
) = h, τ, false
) = h, κ, false

function Adaptation.adapt!(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
adaptor::AbstractAdaptor,
i::Int,
n_adapts::Int,
Expand All @@ -80,10 +82,10 @@ function Adaptation.adapt!(
adapt!(adaptor, θ, α)
i == n_adapts && finalize!(adaptor)
h = reconstruct(h, adaptor)
τ = reconstruct(τ, adaptor)
κ = reconstruct(κ, adaptor)
isadapted = true
end
return h, τ, isadapted
return h, κ, isadapted
end

"""
Expand All @@ -104,7 +106,7 @@ simple_pm_next!(pm, stat::NamedTuple) = ProgressMeter.next!(pm)

sample(
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
θ::AbstractVecOrMat{<:AbstractFloat},
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
Expand All @@ -116,7 +118,7 @@ sample(
) = sample(
GLOBAL_RNG,
h,
τ,
κ,
θ,
n_samples,
adaptor,
Expand All @@ -131,7 +133,7 @@ sample(
sample(
rng::AbstractRNG,
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
θ::AbstractVecOrMat{T},
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
Expand All @@ -141,7 +143,7 @@ sample(
progress::Bool=false
)

Sample `n_samples` samples using the proposal `τ` under Hamiltonian `h`.
Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`.
- The randomness is controlled by `rng`.
- If `rng` is not provided, `GLOBAL_RNG` will be used.
- The initial point is given by `θ`.
Expand All @@ -154,7 +156,7 @@ Sample `n_samples` samples using the proposal `τ` under Hamiltonian `h`.
function sample(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
τ::AbstractProposal,
κ::AbstractKernel,
θ::T,
n_samples::Int,
adaptor::AbstractAdaptor=NoAdaptation(),
Expand All @@ -174,18 +176,18 @@ function sample(
pm = progress ? ProgressMeter.Progress(n_samples, desc="Sampling", barlen=31) : nothing
time = @elapsed for i = 1:n_samples
# Make a step
t = step(rng, h, τ, t.z)
# Adapt h and τ; what mutable is the adaptor
t = transition(rng, h, κ, t.z)
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
# Adapt h and κ; what mutable is the adaptor
tstat = stat(t)
h, τ, isadapted = adapt!(h, τ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt=isadapted,))
# Update progress meter
if progress
# Do include current iteration and mass matrix
pm_next!(pm, (iterations=i, tstat..., mass_matrix=h.metric))
# Report finish of adapation
elseif verbose && isadapted && i == n_adapts
@info "Finished $n_adapts adapation steps" adaptor τ.integrator h.metric
@info "Finished $n_adapts adapation steps" adaptor κ.τ.integrator h.metric
end
# Store sample
if !drop_warmup || i > n_adapts
Expand All @@ -204,7 +206,7 @@ function sample(
EBFMI_est = "[" * join(EBFMI_est, ", ") * "]"
average_acceptance_rate = "[" * join(average_acceptance_rate, ", ") * "]"
end
@info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h τ EBFMI_est average_acceptance_rate
@info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h κ EBFMI_est average_acceptance_rate
end
return θs, stats
end