diff --git a/HISTORY.md b/HISTORY.md
index 038968ef1..a9daf473d 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,12 @@
 # AdvancedHMC Changelog
 
+## 0.8.4
+
+  - Introduces an experimental way to improve *diagonal* mass matrix adaptation using gradient information, similar to [nutpie](https://github.com/pymc-devs/nutpie).
+    Until a new interface for selecting the adaptation method is introduced in an upcoming breaking release,
+    the new adaptor has to be constructed explicitly for a `metric` of type `DiagEuclideanMetric`
+    via `mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))`.
+
 ## 0.8.0
 
   - To make an MCMC transtion from phasepoint `z` using trajectory `τ`(or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`(if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).
diff --git a/Project.toml b/Project.toml
index 8c32e813a..f38dda02a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedHMC"
 uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
-version = "0.8.3"
+version = "0.8.4"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/docs/src/api.md b/docs/src/api.md
index a1c488fb8..e7caf2d0c 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -32,11 +32,15 @@ where `ϵ` is the step size of leapfrog integration.
 
 ### Adaptor (`adaptor`)
 
   - Adapt the mass matrix `metric` of the Hamiltonian dynamics: `mma = MassMatrixAdaptor(metric)`
-    
+
       + This is lowered to `UnitMassMatrix`, `WelfordVar` or `WelfordCov` based on the type of the mass matrix `metric`
+      + There is an experimental adaptor that improves *diagonal* mass matrix adaptation using gradient information, similar to [nutpie](https://github.com/pymc-devs/nutpie).
+        Until a new interface for selecting the adaptation method is introduced in an upcoming breaking release,
+        it has to be constructed explicitly for a `metric` of type `DiagEuclideanMetric`
+        via `mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))`.
   - Adapt the step size of the leapfrog integrator `integrator`: `ssa = StepSizeAdaptor(δ, integrator)`
-    
+
       + It uses Nesterov's dual averaging with `δ` as the target acceptance rate.
   - Combine the two above *naively*: `NaiveHMCAdaptor(mma, ssa)`
   - Combine the first two using Stan's windowed adaptation: `StanHMCAdaptor(mma, ssa)`
@@ -61,12 +65,12 @@ sample(
 
 Draw `n_samples` samples using the kernel `κ` under the Hamiltonian system `h`
 
   - The randomness is controlled by `rng`.
-    
+
       + If `rng` is not provided, the default random number generator (`Random.default_rng()`) will be used.
   - The initial point is given by `θ`.
   - The adaptor is set by `adaptor`, for which the default is no adaptation.
-    
+
       + It will perform `n_adapts` steps of adaptation, for which the default is `1_000` or 10% of `n_samples`, whichever is lower.
   - `drop_warmup` specifies whether to drop samples.
   - `verbose` controls the verbosity.
diff --git a/research/src/riemannian_hmc_utility.jl b/research/src/riemannian_hmc_utility.jl
index 8ceab303c..5efbdc12a 100644
--- a/research/src/riemannian_hmc_utility.jl
+++ b/research/src/riemannian_hmc_utility.jl
@@ -2,47 +2,74 @@ using Random, LinearAlgebra, ReverseDiff, ForwardDiff, MCMCLogDensityProblems
 
 # Fisher information metric
 function gen_∂G∂θ_rev(Vfunc, x; f=identity)
-    _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, ReverseDiff.track.(x))
-    Hfunc = x -> _Hfunc(x)[3]
+    Hfunc = gen_hess_fwd(Vfunc, ReverseDiff.track.(x))
+    # QUES What's the best output format of this function?
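For orientation, a minimal end-to-end sketch of the experimental adaptor described in the changelog and `docs/src/api.md` entries above. It follows the usual README-style `LogDensityProblems` setup; the toy target, the dimension, and the tuning constants (`0.8`, `2_000`, `1_000`) are illustrative assumptions rather than part of this patch:

```julia
using AdvancedHMC, ForwardDiff
using LogDensityProblems

# Toy standard-normal target via the LogDensityProblems interface (illustrative only).
struct ToyTarget
    dim::Int
end
LogDensityProblems.logdensity(::ToyTarget, θ) = -sum(abs2, θ) / 2
LogDensityProblems.dimension(p::ToyTarget) = p.dim
LogDensityProblems.capabilities(::Type{ToyTarget}) = LogDensityProblems.LogDensityOrder{0}()

D = 5
ℓπ, θ₀ = ToyTarget(D), randn(D)
metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

integrator = Leapfrog(find_good_stepsize(hamiltonian, θ₀))
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))

# Gradient-informed diagonal adaptation in place of the default `WelfordVar`,
# wrapped in Stan-style windowed adaptation together with step size adaptation.
adaptor = StanHMCAdaptor(
    AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹)),
    StepSizeAdaptor(0.8, integrator),
)

samples, stats = sample(hamiltonian, kernel, θ₀, 2_000, adaptor, 1_000; progress=false)
```

Apart from swapping `MassMatrixAdaptor(metric)` for `AdvancedHMC.NutpieVar(...)`, the construction is unchanged, which is what makes it a drop-in replacement for `WelfordVar`.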
return x -> ReverseDiff.jacobian(x -> f(Hfunc(x)), x) # default output shape [∂H∂x₁; ∂H∂x₂; ...] end # TODO Refactor this using https://juliadiff.org/ForwardDiff.jl/stable/user/api/#Preallocating/Configuring-Work-Buffers +function gen_hess_fwd_precompute_cfg(func, x::AbstractVector) + cfg = ForwardDiff.HessianConfig(func, x) + H = Matrix{eltype(x)}(undef, length(x), length(x)) + + function hess(x::AbstractVector) + ForwardDiff.hessian!(H, func, x, cfg) + return H + end + return hess +end + function gen_hess_fwd(func, x::AbstractVector) + cfg = nothing + H = nothing + function hess(x::AbstractVector) - return nothing, nothing, ForwardDiff.hessian(func, x) + if cfg === nothing + cfg = ForwardDiff.HessianConfig(func, x) + H = Matrix{eltype(x)}(undef, length(x), length(x)) + end + ForwardDiff.hessian!(H, func, x, cfg) + return H end return hess end function gen_∂G∂θ_fwd(Vfunc, x; f=identity) - _Hfunc = gen_hess_fwd(Vfunc, x) - Hfunc = x -> _Hfunc(x)[3] - # QUES What's the best output format of this function? - cfg = ForwardDiff.JacobianConfig(Hfunc, x) + chunk = ForwardDiff.Chunk(x) + tag = ForwardDiff.Tag(Vfunc, eltype(x)) + jac_cfg = ForwardDiff.JacobianConfig(Vfunc, x, chunk, tag) + hess_cfg = ForwardDiff.HessianConfig(Vfunc, jac_cfg.duals, chunk, tag) + d = length(x) out = zeros(eltype(x), d^2, d) - return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) - return out # default output shape [∂H∂x₁; ∂H∂x₂; ...] + + function ∂G∂θ_fwd(y) + hess = z -> Symmetric(ForwardDiff.hessian(Vfunc, z, hess_cfg, Val{false}())) + ForwardDiff.jacobian!(out, hess, y, jac_cfg, Val{false}()) + return out + end + + return ∂G∂θ_fwd end -# 1.764 ms -# fwd -> 5.338 μs -# cfg -> 3.651 μs function reshape_∂G∂θ(H) d = size(H, 2) - return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3) + return reshape(H, d, d, :) end function prepare_sample_target(hps, θ₀, ℓπ) Vfunc = x -> -ℓπ(x) # potential energy is the negative log-probability - _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, θ₀) # x -> (value, gradient, hessian) - Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug + Hfunc = gen_hess_fwd_precompute_cfg(Vfunc, θ₀) # x -> (value, gradient, hessian) - fstabilize = H -> H + hps.λ * I + fstabilize = H -> begin + @inbounds for i in 1:size(H, 1) + H[i, i] += hps.λ + end + H + end Gfunc = x -> begin - H = fstabilize(Hfunc(x)[3]) + H = fstabilize(Hfunc(x)) all(isfinite, H) ? 
H : diagm(ones(length(x))) end _∂G∂θfunc = gen_∂G∂θ_fwd(Vfunc, θ₀; f=fstabilize) # size==(4, 2) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 41d934e60..699196885 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -59,12 +59,11 @@ export Hamiltonian include("integrator.jl") export Leapfrog, JitteredLeapfrog, TemperedLeapfrog -include("riemannian/integrator.jl") -export GeneralizedLeapfrog include("riemannian/metric.jl") -export IdentityMap, SoftAbsMap, DenseRiemannianMetric - +export AbstractRiemannianMetric, DenseRiemannianMetric, IdentityMap, SoftAbsMap +include("riemannian/integrator.jl") +export GeneralizedLeapfrog, ImplicitMidpoint include("riemannian/hamiltonian.jl") include("trajectory.jl") @@ -89,7 +88,7 @@ export find_good_eps include("adaptation/Adaptation.jl") using .Adaptation import .Adaptation: - StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation + StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation, PositionOrPhasePoint # Helpers for initializing adaptors via AHMC structs @@ -131,6 +130,7 @@ export StepSizeAdaptor, MassMatrixAdaptor, UnitMassMatrix, WelfordVar, + NutpieVar, WelfordCov, NaiveHMCAdaptor, StanHMCAdaptor, diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 413e9de6f..f6ce40009 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -196,7 +196,7 @@ function AbstractMCMC.step( # Adapt h and spl. tstat = stat(t) - h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate) + h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate) tstat = merge(tstat, (is_adapt=isadapted,)) # Compute next transition and state. diff --git a/src/adaptation/Adaptation.jl b/src/adaptation/Adaptation.jl index 4f2fde83c..10a9a9805 100644 --- a/src/adaptation/Adaptation.jl +++ b/src/adaptation/Adaptation.jl @@ -4,13 +4,13 @@ export Adaptation using LinearAlgebra: LinearAlgebra using Statistics: Statistics -using ..AdvancedHMC: AbstractScalarOrVec +using ..AdvancedHMC: AbstractScalarOrVec, PhasePoint using DocStringExtensions """ $(TYPEDEF) -Abstract type for HMC adaptors. +Abstract type for HMC adaptors. """ abstract type AbstractAdaptor end function getM⁻¹ end @@ -21,12 +21,17 @@ function initialize! end function finalize! 
end export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹ +get_position(x::PhasePoint) = x.θ +get_position(x::AbstractVecOrMat{<:AbstractFloat}) = x +const PositionOrPhasePoint = Union{AbstractVecOrMat{<:AbstractFloat}, PhasePoint} + struct NoAdaptation <: AbstractAdaptor end export NoAdaptation include("stepsize.jl") export StepSizeAdaptor, NesterovDualAveraging + include("massmatrix.jl") -export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, WelfordCov +export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, NutpieVar, WelfordCov ## ## Composite adaptors @@ -47,18 +52,14 @@ getϵ(ca::NaiveHMCAdaptor) = getϵ(ca.ssa) # TODO: implement consensus adaptor function adapt!( nca::NaiveHMCAdaptor, - θ::AbstractVecOrMat{<:AbstractFloat}, + z_or_theta::PositionOrPhasePoint, α::AbstractScalarOrVec{<:AbstractFloat}, ) - adapt!(nca.ssa, θ, α) - adapt!(nca.pc, θ, α) - return nothing -end -function reset!(aca::NaiveHMCAdaptor) - reset!(aca.ssa) - reset!(aca.pc) + adapt!(nca.ssa, z_or_theta, α) + adapt!(nca.pc, z_or_theta, α) return nothing end + initialize!(adaptor::NaiveHMCAdaptor, n_adapts::Int) = nothing finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa) diff --git a/src/adaptation/massmatrix.jl b/src/adaptation/massmatrix.jl index 105d3baeb..13f360e32 100644 --- a/src/adaptation/massmatrix.jl +++ b/src/adaptation/massmatrix.jl @@ -9,16 +9,18 @@ finalize!(::MassMatrixAdaptor) = nothing function adapt!( adaptor::MassMatrixAdaptor, - θ::AbstractVecOrMat{<:AbstractFloat}, - α::AbstractScalarOrVec{<:AbstractFloat}, + z_or_theta::PositionOrPhasePoint, + ::AbstractScalarOrVec{<:AbstractFloat}, is_update::Bool=true, ) - resize_adaptor!(adaptor, size(θ)) - push!(adaptor, θ) + resize_adaptor!(adaptor, size(get_position(z_or_theta))) + push!(adaptor, z_or_theta) is_update && update!(adaptor) return nothing end +Base.push!(a::MassMatrixAdaptor, z_or_theta::PositionOrPhasePoint) = push!(a, get_position(z_or_theta)) + ## Unit mass matrix adaptor struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end @@ -39,7 +41,7 @@ getM⁻¹(::UnitMassMatrix{T}) where {T} = LinearAlgebra.UniformScaling{T}(one(T function adapt!( ::UnitMassMatrix, - ::AbstractVecOrMat{<:AbstractFloat}, + ::PositionOrPhasePoint, ::AbstractScalarOrVec{<:AbstractFloat}, is_update::Bool=true, ) @@ -47,7 +49,6 @@ function adapt!( end ## Diagonal mass matrix adaptor - abstract type DiagMatrixEstimator{T} <: MassMatrixAdaptor end getM⁻¹(ve::DiagMatrixEstimator) = ve.var @@ -70,7 +71,7 @@ NaiveVar{T}(sz::Tuple{Int,Int}) where {T<:AbstractFloat} = NaiveVar(Vector{Matri NaiveVar(sz::Union{Tuple{Int},Tuple{Int,Int}}) = NaiveVar{Float64}(sz) -Base.push!(nv::NaiveVar, s::AbstractVecOrMat) = push!(nv.S, s) +Base.push!(nv::NaiveVar, s::AbstractVecOrMat{<:AbstractFloat}) = push!(nv.S, s) reset!(nv::NaiveVar) = resize!(nv.S, 0) @@ -135,7 +136,7 @@ function reset!(wv::WelfordVar{T}) where {T<:AbstractFloat} return nothing end -function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T} +function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T<:AbstractFloat} wv.n += 1 (; δ, μ, M, n) = wv n = T(n) @@ -153,6 +154,90 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat} return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5)) end +""" + NutpieVar + +Nutpie-style diagonal mass matrix estimator (using positions and gradients). + +Expected to converge faster and to a better mass matrix than [`WelfordVar`](@ref), for which it is a drop-in replacement. 
+ +Can be initialized via `NutpieVar(sz)` where `sz` is either a `Tuple{Int}` or a `Tuple{Int,Int}`. + +# Fields + +$(FIELDS) +""" +mutable struct NutpieVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T} + "Online variance estimator of the posterior positions." + position_estimator::WelfordVar{T,E,V} + "Online variance estimator of the posterior gradients." + gradient_estimator::WelfordVar{T,E,V} + "The number of observations collected so far." + n::Int + "The minimal number of observations after which the estimate of the variances can be updated." + n_min::Int + "The estimated variances - initialized to ones, updated after calling [`update!`](@ref) if `n > n_min`." + var::V + function NutpieVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::V) where {E,V} + return new{eltype(E),E,V}( + WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)), + WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)), + n, n_min, var + ) + end +end + +function Base.show(io::IO, ::NutpieVar{T}) where {T} + return print(io, "NutpieVar{", T, "} adaptor") +end + +function NutpieVar{T}( + sz::Union{Tuple{Int},Tuple{Int,Int}}=(2,); n_min::Int=10, var=ones(T, sz) +) where {T<:AbstractFloat} + return NutpieVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var) +end + +function NutpieVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) + return NutpieVar{Float64}(sz; kwargs...) +end + +function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int,Int}) where {T<:AbstractFloat} + if size_θ != size(nv.var) + @assert nv.n == 0 "Cannot resize a var estimator when it contains samples." + resize_adaptor!(nv.position_estimator, size_θ) + resize_adaptor!(nv.gradient_estimator, size_θ) + nv.var = ones(T, size_θ) + end +end + +function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int}) where {T<:AbstractFloat} + length_θ = first(size_θ) + if length_θ != size(nv.var, 1) + @assert nv.n == 0 "Cannot resize a var estimator when it contains samples." 
+ resize_adaptor!(nv.position_estimator, size_θ) + resize_adaptor!(nv.gradient_estimator, size_θ) + fill!(resize!(nv.var, length_θ), T(1)) + end +end + +function reset!(nv::NutpieVar) + nv.n = 0 + reset!(nv.position_estimator) + reset!(nv.gradient_estimator) +end + +Base.push!(::NutpieVar, x::AbstractVecOrMat{<:AbstractFloat}) = error("`NutpieVar` adaptation requires position and gradient information!") + +function Base.push!(nv::NutpieVar, z::PhasePoint) + nv.n += 1 + push!(nv.position_estimator, z.θ) + push!(nv.gradient_estimator, z.ℓπ.gradient) + return nothing +end + +# Ref: https://github.com/pymc-devs/nutpie +get_estimation(nv::NutpieVar) = sqrt.(get_estimation(nv.position_estimator) ./ get_estimation(nv.gradient_estimator)) + ## Dense mass matrix adaptor abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end @@ -175,7 +260,7 @@ end NaiveCov{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveCov(Vector{Vector{T}}()) -Base.push!(nc::NaiveCov, s::AbstractVector) = push!(nc.S, s) +Base.push!(nc::NaiveCov, s::AbstractVector{<:AbstractFloat}) = push!(nc.S, s) reset!(nc::NaiveCov{T}) where {T} = resize!(nc.S, 0) @@ -225,7 +310,7 @@ function reset!(wc::WelfordCov{T}) where {T<:AbstractFloat} return nothing end -function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T} +function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T<:AbstractFloat} wc.n += 1 (; δ, μ, n, M) = wc n = T(n) diff --git a/src/adaptation/stan_adaptor.jl b/src/adaptation/stan_adaptor.jl index b36a22597..931e741a0 100644 --- a/src/adaptation/stan_adaptor.jl +++ b/src/adaptation/stan_adaptor.jl @@ -136,20 +136,20 @@ is_window_end(a::StanHMCAdaptor) = a.state.i in a.state.window_splits function adapt!( tp::StanHMCAdaptor, - θ::AbstractVecOrMat{<:AbstractFloat}, + z_or_theta::PositionOrPhasePoint, α::AbstractScalarOrVec{<:AbstractFloat}, ) tp.state.i += 1 - adapt!(tp.ssa, θ, α) + adapt!(tp.ssa, z_or_theta, α) - resize_adaptor!(tp.pc, size(θ)) # Resize pre-conditioner if necessary. + resize_adaptor!(tp.pc, size(get_position(z_or_theta))) # Resize pre-conditioner if necessary. # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp if is_in_window(tp) # We accumlate stats from θ online and only trigger the update of M⁻¹ in the end of window. is_update_M⁻¹ = is_window_end(tp) - adapt!(tp.pc, θ, α, is_update_M⁻¹) + adapt!(tp.pc, z_or_theta, α, is_update_M⁻¹) end if is_window_end(tp) diff --git a/src/adaptation/stepsize.jl b/src/adaptation/stepsize.jl index 2afbb651e..cacb463db 100644 --- a/src/adaptation/stepsize.jl +++ b/src/adaptation/stepsize.jl @@ -174,7 +174,7 @@ end # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp # Note: This function is not merged with `adapt!` to empahsize the fact that # step size adaptation is not dependent on `θ`. -# Note 2: `da.state` and `α` support vectorised HMC but should do so together. +# Note 2: `da.state` and `α` support vectorised HMC but should do so together. 
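A quick sanity check of the `get_estimation(nv::NutpieVar)` ratio defined above, under the simplifying assumption of a Gaussian target: positions drawn from `N(0, σ²)` have variance `σ²`, their log-density gradients `-θ/σ²` have variance `1/σ²`, so the square root of the ratio recovers `σ²`, the ideal diagonal inverse mass. Illustrative only, not part of the patch:

```julia
using Statistics

σ² = 4.0
θs = sqrt(σ²) .* randn(100_000)   # draws from N(0, σ²)
grads = -θs ./ σ²                 # gradients of the log-density at those draws
sqrt(var(θs) / var(grads))        # ≈ 4.0 == σ²
```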
function adapt_stepsize!( da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{T} ) where {T<:AbstractFloat} @@ -211,7 +211,7 @@ end function adapt!( da::NesterovDualAveraging, - θ::AbstractVecOrMat{<:AbstractFloat}, + ::PositionOrPhasePoint, α::AbstractScalarOrVec{<:AbstractFloat}, ) adapt_stepsize!(da, α) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index c782e1a24..ece931c44 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -101,7 +101,11 @@ function Base.similar(z::PhasePoint{<:AbstractVecOrMat{T}}) where {T<:AbstractFl end function phasepoint( - h::Hamiltonian, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r)) + h::Hamiltonian, + θ::T, + r::T; + ℓπ=∂H∂θ(h, θ), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), ) where {T<:AbstractVecOrMat} return PhasePoint(θ, r, ℓπ, ℓκ) end @@ -115,7 +119,7 @@ function phasepoint( _r::T2; r=safe_rsimilar(θ, _r), ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r)), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), ) where {T1<:AbstractVecOrMat,T2<:AbstractVecOrMat} return PhasePoint(θ, r, ℓπ, ℓκ) end diff --git a/src/integrator.jl b/src/integrator.jl index 004028e1e..bacc1dcdc 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -89,14 +89,14 @@ Leapfrog integrator with randomly "jittered" step size `ϵ` for every trajectory $(TYPEDFIELDS) # Description -This is the same as `LeapFrog`(@ref) but with a "jittered" step size. This means -that at the beginning of each trajectory we sample a step size `ϵ` by adding or -subtracting from the nominal/base step size `ϵ0` some random proportion of `ϵ0`, +This is the same as `LeapFrog`(@ref) but with a "jittered" step size. This means +that at the beginning of each trajectory we sample a step size `ϵ` by adding or +subtracting from the nominal/base step size `ϵ0` some random proportion of `ϵ0`, with the proportion specified by `jitter`, i.e. `ϵ = ϵ0 - jitter * ϵ0 * rand()`. p Jittering might help alleviate issues related to poor interactions with a fixed step size: -- In regions with high "curvature" the current choice of step size might mean over-shoot - leading to almost all steps being rejected. Randomly sampling the step size at the +- In regions with high "curvature" the current choice of step size might mean over-shoot + leading to almost all steps being rejected. Randomly sampling the step size at the beginning of the trajectories can therefore increase the probability of escaping such high-curvature regions. - Exact periodicity of the simulated trajectories might occur, i.e. you might be so @@ -168,7 +168,7 @@ $(TYPEDFIELDS) # Description -Tempering can potentially allow greater exploration of the posterior, e.g. +Tempering can potentially allow greater exploration of the posterior, e.g. in a multi-modal posterior jumps between the modes can be more likely to occur. """ struct TemperedLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} <: AbstractLeapfrog{T} @@ -226,9 +226,7 @@ function step( ϵ = fwd ? step_size(lf) : -step_size(lf) ϵ = ϵ' - if FullTraj - res = Vector{P}(undef, n_steps) - end + res = FullTraj ? 
Vector{P}(undef, n_steps) : nothing (; θ, r) = z (; value, gradient) = z.ℓπ @@ -248,20 +246,16 @@ function step( # Create a new phase point by caching the logdensity and gradient z = phasepoint(h, θ, r; ℓπ=DualValue(value, gradient)) # Update result - if FullTraj + if !isnothing(res) res[i] = z end if !isfinite(z) # Remove undef - if FullTraj + if !isnothing(res) resize!(res, i) end break end end - return if FullTraj - res - else - z - end + return FullTraj === true ? res : z end diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index f8acc7971..aaa9174a3 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,14 +1,70 @@ -#! Eq (14) of Girolami & Calderhead (2011) -function ∂H∂r( - h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, - θ::AbstractVecOrMat, - r::AbstractVecOrMat, +import AdvancedHMC: refresh, phasepoint, neg_energy, ∂H∂θ, ∂H∂r +using AdvancedHMC: + FullMomentumRefreshment, PartialMomentumRefreshment, DualValue, PhasePoint +using LinearAlgebra: logabsdet, tr, diagm, logdet, Diagonal + +# Specialized phasepoint for Riemannian metrics that need θ for momentum gradient +function phasepoint( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + θ::AbstractVecOrMat{T}, + h::Hamiltonian{<:DenseRiemannianMetric}, +) where {T<:Real} + return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) +end + +# To change L191 of hamiltonian.jl +function refresh( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + ::FullMomentumRefreshment, + h::Hamiltonian{<:DenseRiemannianMetric}, + z::PhasePoint, ) - H = h.metric.G(θ) - G = h.metric.map(H) - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't + return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ)) +end + +# To change L215 of hamiltonian.jl +function refresh( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + ref::PartialMomentumRefreshment, + h::Hamiltonian{<:DenseRiemannianMetric}, + z::PhasePoint, +) + return phasepoint( + h, + z.θ, + ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, z.θ), + ) +end + +### +### DenseRiemannianMetric-specific Hamiltonian methods +### + +# Specialized phasepoint for DenseRiemannianMetric that passes θ to ∂H∂r +function phasepoint( + h::Hamiltonian{<:DenseRiemannianMetric}, + θ::T, + r::T; + ℓπ=∂H∂θ(h, θ), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), +) where {T<:AbstractVecOrMat} + return PhasePoint(θ, r, ℓπ, ℓκ) +end + +# Negative kinetic energy +#! 
Eq (13) of Girolami & Calderhead (2011) +function neg_energy( + h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T +) where {T<:AbstractVecOrMat} + G = h.metric.map(h.metric.G(θ)) + D = size(G, 1) + # Need to consider the normalizing term as it is no longer same for different θs + logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined + mul!(h.metric._temp, inv(G), r) + return -logZ - dot(r, h.metric._temp) / 2 end +# Position gradient with Riemannian correction terms function ∂H∂θ( h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, @@ -58,6 +114,7 @@ function ∂H∂θ( ) where {T} return ∂H∂θ_cache(h, θ, r) end + function ∂H∂θ_cache( h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, @@ -73,7 +130,7 @@ function ∂H∂θ_cache( G, Q, λ, softabsλ = softabs(H, h.metric.map.α) - R = diagm(1 ./ softabsλ) + R = Diagonal(1 ./ softabsλ) # softabsΛ = diagm(softabsλ) # M = inv(softabsΛ) * Q' * r @@ -81,46 +138,78 @@ function ∂H∂θ_cache( J = make_J(λ, h.metric.map.α) + tmp1 = similar(H) + tmp2 = similar(H) + tmp3 = similar(H) + tmp4 = similar(softabsλ) + #! Based on the two equations from the right column of Page 3 of Betancourt (2012) - term_1_cached = Q * (R .* J) * Q' + tmp1 = R .* J + # tmp2 = Q * tmp1 + mul!(tmp2, Q, tmp1) + + # tmp1 = tmp2 * Q' + mul!(tmp1, tmp2, Q') + + term_1_cached = tmp1 + + # Cache first part of the equation + term_1_prod = similar(∂ℓπ∂θ) + @inbounds for i in 1:length(∂ℓπ∂θ) + ∂H∂θᵢ = ∂H∂θ[:, :, i] + term_1_prod[i] = ∂ℓπ∂θ[i] - 1/2 * tr(term_1_cached * ∂H∂θᵢ) + end + else - ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache + ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_prod, tmp1, tmp2, tmp3, tmp4 = cache end d = length(∂ℓπ∂θ) - D = diagm((Q' * r) ./ softabsλ) - term_2_cached = Q * D * J * D * Q' - g = - -mapreduce(vcat, 1:d) do i - ∂H∂θᵢ = ∂H∂θ[:, :, i] - # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) - # NOTE Some further optimization can be done here: cache the 1st product all together - ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly - end + mul!(tmp4, Q', r) + D = Diagonal(tmp4 ./ softabsλ) - dv = DualValue(ℓπ, g) - return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv -end + # tmp1 = D * J + mul!(tmp1, D, J) + # tmp2 = tmp1 * D + mul!(tmp2, tmp1, D) + # tmp1 = Q * tmp2 + mul!(tmp1, Q, tmp2) + # tmp2 = tmp1 * Q' + mul!(tmp2, tmp1, Q') + # term_2_cached = tmp2 -# QUES Do we want to change everything to position dependent by default? 
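For cross-checking `neg_energy` (Eq. (13)) and `∂H∂r` (Eq. (14)) in this file against Girolami & Calderhead (2011): with `G(θ)` the mapped metric and `D = length(θ)`, the relevant terms read (restated here for convenience; `neg_energy` returns minus the momentum-dependent part, and `∂H∂r` computes `G(θ) \ r`):

```latex
H(\theta, r) = -\log \pi(\theta)
             + \frac{1}{2}\log\!\bigl((2\pi)^{D}\det G(\theta)\bigr)
             + \frac{1}{2}\, r^{\top} G(\theta)^{-1} r ,
\qquad
\frac{\partial H}{\partial r} = G(\theta)^{-1} r .
```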
-# Add θ to ∂H∂r for DenseRiemannianMetric -function phasepoint( - h::Hamiltonian{<:DenseRiemannianMetric}, - θ::T, - r::T; - ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), -) where {T<:AbstractVecOrMat} - return PhasePoint(θ, r, ℓπ, ℓκ) + # g = + # -mapreduce(vcat, 1:d) do i + # ∂H∂θᵢ = ∂H∂θ[:, :, i] + # # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) + # # NOTE Some further optimization can be done here: cache the 1st product all together + # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly + # end + g = similar(∂ℓπ∂θ) + @inbounds for i in 1:d + ∂H∂θᵢ = ∂H∂θ[:, :, i] + g[i] = term_1_prod[i] + 1/2 * tr(tmp2 * ∂H∂θᵢ) + end + g .*= -1 + + dv = DualValue(ℓπ, g) + return if return_cache + (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_prod, tmp1, tmp2, tmp3, tmp4)) + else + dv + end end -#! Eq (13) of Girolami & Calderhead (2011) -function neg_energy( - h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T -) where {T<:AbstractVecOrMat} - G = h.metric.map(h.metric.G(θ)) - D = size(G, 1) - # Need to consider the normalizing term as it is no longer same for different θs - logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined - mul!(h.metric._temp, inv(G), r) - return -logZ - dot(r, h.metric._temp) / 2 +#! Eq (14) of Girolami & Calderhead (2011) +function ∂H∂r( + h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T} +) where {T} + H = h.metric.G(θ) + # if !all(isfinite, H) + # println("θ: ", θ) + # println("H: ", H) + # end + G = h.metric.map(H) + # return inv(G) * r + # println("G \ r: ", G \ r) + return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't end diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 6ce594768..94269cfc1 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -1,3 +1,6 @@ +import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step +using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size + """ $(TYPEDEF) @@ -8,7 +11,7 @@ Generalized leapfrog integrator with fixed step size `ϵ`. $(TYPEDFIELDS) -## References +## References 1. Girolami, Mark, and Ben Calderhead. "Riemann manifold Langevin and Hamiltonian Monte Carlo methods." Journal of the Royal Statistical Society Series B: Statistical Methodology 73, no. 2 (2011): 123-214. """ @@ -21,18 +24,63 @@ function Base.show(io::IO, l::GeneralizedLeapfrog) return print(io, "GeneralizedLeapfrog(ϵ=", round.(l.ϵ; sigdigits=3), ", n=", l.n, ")") end -# fallback to ignore return_cache & cache kwargs for other ∂H∂θ -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) - dv = ∂H∂θ(h, θ, r) - return return_cache ? (dv, nothing) : dv +abstract type AbstractImplicitMidpoint{T} <: AbstractIntegrator end + +step_size(lf::AbstractImplicitMidpoint) = lf.ϵ +function jitter( + ::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, lf::AbstractImplicitMidpoint +) + lf +end +function temper( + lf::AbstractImplicitMidpoint, + r, + ::NamedTuple{(:i, :is_half),<:Tuple{Integer,Bool}}, + ::Int, +) + return r end +function stat(lf::AbstractImplicitMidpoint) + (step_size=step_size(lf), nom_step_size=nom_step_size(lf)) +end +update_nom_step_size(lf::AbstractImplicitMidpoint, ϵ) = @set lf.ϵ = ϵ + +""" +$(TYPEDEF) + +Implicit midpoint integrator with fixed step size `ϵ`. 
+ +# Fields + +$(TYPEDFIELDS) + + +## References + +1. James A. Brofos, Roy R. Lederman. "Evaluating the Implicit Midpoint +Integrator for Riemannian Manifold Hamiltonian Monte Carlo" +""" +struct ImplicitMidpoint{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} + "Step size." + ϵ::T + n::Int +end +function Base.show(io::IO, l::ImplicitMidpoint) + return print(io, "ImplicitMidpoint(ϵ=", round.(l.ϵ; sigdigits=3), ", n=", l.n, ")") +end + +# fallback to ignore return_cache & cache kwargs for other ∂H∂θ +# function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) +# dv = ∂H∂θ(h, θ, r) +# return return_cache ? (dv, nothing) : dv +# end # TODO(Kai) make sure vectorization works # TODO(Kai) check if tempering is valid -# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl` +# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl` function step( lf::GeneralizedLeapfrog{T}, - h::Hamiltonian, + h::Hamiltonian{<:DenseRiemannianMetric}, z::P, n_steps::Int=1; fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 @@ -59,7 +107,7 @@ function step( #r = temper(lf, r, (i=i, is_half=true), n_steps) # eq (16) of Girolami & Calderhead (2011) r_half = r_init - local cache + local cache = nothing for j in 1:(lf.n) # Reuse cache for the first iteration if j == 1 @@ -101,3 +149,62 @@ function step( end return res end + +function step( + lf::ImplicitMidpoint{T}, + h::Hamiltonian{<:DenseRiemannianMetric}, + z::P, + n_steps::Int=1; + fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 + full_trajectory::Val{FullTraj}=Val(false), +) where {T<:AbstractScalarOrVec{<:AbstractFloat},TP,P<:PhasePoint{TP},FullTraj} + n_steps = abs(n_steps) # to support `n_steps < 0` cases + + ϵ = fwd ? step_size(lf) : -step_size(lf) + ϵ = ϵ' + + if !(T <: AbstractFloat) || !(TP <: AbstractVector) + @warn "Vectorization is not tested for ImplicitMidpoint." + end + + res = if FullTraj + Vector{P}(undef, n_steps) + else + z + end + + for i in 1:n_steps + θ_init, r_init = z.θ, z.r + + θ_full = θ_init + r_full = r_init + for j in 1:(lf.n) + θ_bar = (θ_full + θ_init) / 2 + r_bar = (r_full + r_init) / 2 + + dHdr = ∂H∂r(h, θ_bar, r_bar) + (; value, gradient) = ∂H∂θ(h, θ_bar, r_bar) + + θ_full = θ_init + ϵ * dHdr + r_full = r_init - ϵ * gradient + end + + (; value, gradient) = ∂H∂θ(h, θ_full, r_full) + z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) + + if FullTraj + res[i] = z + else + res = z + end + if !isfinite(z) + # Remove undef + if FullTraj + res = res[isassigned.(Ref(res), 1:n_steps)] + end + break + end + end + + return res +end diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index 41d11127c..e3beb0441 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -1,3 +1,9 @@ +using AdvancedHMC: AbstractMetric +using LinearAlgebra: eigen, cholesky, Symmetric +import Base: eltype + +# _randn is defined in utilities.jl which is included before this file + abstract type AbstractRiemannianMetric <: AbstractMetric end abstract type AbstractHessianMap end @@ -10,18 +16,28 @@ struct SoftAbsMap{T} <: AbstractHessianMap α::T end -function softabs(X, α=20.0) - F = eigen(X) # ReverseDiff cannot diff through `eigen` +# TODO Register softabs with ReverseDiff +#! 
The definition of SoftAbs from Page 3 of Betancourt (2012) +function softabs(X::AbstractMatrix{T}, α=20.0) where {T<:Real} + # Enforce symmetry for type stability + F = eigen(Symmetric(X)) # ReverseDiff cannot diff through `eigen` Q = hcat(F.vectors) λ = F.values softabsλ = λ .* coth.(α * λ) return Q * diagm(softabsλ) * Q', Q, λ, softabsλ end +function softabs_decomp(X::AbstractMatrix{T}, α=20.0) where {T<:Real} + # Enforce symmetry for type stability + F = eigen(Symmetric(X)) # ReverseDiff cannot diff through `eigen` + Q = hcat(F.vectors) + λ = F.values + softabsλ = λ .* coth.(α * λ) + return Q, softabsλ +end + (map::SoftAbsMap)(x) = softabs(x, map.α)[1] -# TODO Register softabs with ReverseDiff -#! The definition of SoftAbs from Page 3 of Betancourt (2012) struct DenseRiemannianMetric{ T, TM<:AbstractHessianMap, @@ -39,15 +55,19 @@ end # TODO Make dense mass matrix support matrix-mode parallel function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) - _temp = Vector{Float64}(undef, first(size)) + _temp = Vector{Float64}(undef, size[1]) return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) end Base.size(e::DenseRiemannianMetric) = e.size Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] -function Base.show(io::IO, drm::DenseRiemannianMetric) - return print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric") -end +Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") + +#function eltype(m::DenseRiemannianMetric) +# return eltype(m._temp) +#end + +eltype(::DenseRiemannianMetric{T}) where {T} = T function rand_momentum( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, @@ -56,8 +76,19 @@ function rand_momentum( θ::AbstractVecOrMat, ) where {T} r = _randn(rng, T, size(metric)...) - G⁻¹ = inv(metric.map(metric.G(θ))) - chol = cholesky(Symmetric(G⁻¹)) - ldiv!(chol.U, r) + chol = cholesky(Symmetric(metric.map(metric.G(θ)))) + r = chol.L * r + return r +end + +function rand_momentum( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + metric::DenseRiemannianMetric{T,<:SoftAbsMap}, + kinetic, + θ::AbstractVecOrMat, +) where {T} + r = _randn(rng, T, size(metric)...) + Q, softabsλ = softabs_decomp(metric.G(θ), metric.map.α) + r = Q * Diagonal(sqrt.(softabsλ)) * r return r end diff --git a/src/sampler.jl b/src/sampler.jl index 3e477ba3a..1b282383b 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -60,11 +60,11 @@ end function Adaptation.adapt!( h::Hamiltonian, κ::AbstractMCMCKernel, - adaptor::Adaptation.NoAdaptation, - i::Int, - n_adapts::Int, - θ::AbstractVecOrMat{<:AbstractFloat}, - α::AbstractScalarOrVec{<:AbstractFloat}, + ::Adaptation.NoAdaptation, + ::Int, + ::Int, + ::PositionOrPhasePoint, + ::AbstractScalarOrVec{<:AbstractFloat}, ) return h, κ, false end @@ -75,19 +75,18 @@ function Adaptation.adapt!( adaptor::AbstractAdaptor, i::Int, n_adapts::Int, - θ::AbstractVecOrMat{<:AbstractFloat}, + z_or_theta::PositionOrPhasePoint, α::AbstractScalarOrVec{<:AbstractFloat}, ) - isadapted = false - if i <= n_adapts + adapt = i <= n_adapts + if adapt i == 1 && Adaptation.initialize!(adaptor, n_adapts) - adapt!(adaptor, θ, α) + adapt!(adaptor, z_or_theta, α) i == n_adapts && finalize!(adaptor) h = update(h, adaptor) κ = update(κ, adaptor) - isadapted = true end - return h, κ, isadapted + return h, κ, adapt end """ @@ -148,7 +147,7 @@ end progress::Bool=false ) Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`. -- The randomness is controlled by `rng`. +- The randomness is controlled by `rng`. 
- If `rng` is not provided, the default random number generator (`Random.default_rng()`) will be used. - The initial point is given by `θ`. - The adaptor is set by `adaptor`, for which the default is no adaptation. @@ -185,7 +184,7 @@ function sample( t = transition(rng, h, κ, t.z) # Adapt h and κ; what mutable is the adaptor tstat = stat(t) - h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate) + h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate) if isadapted num_divergent_transitions_during_adaption += tstat.numerical_error else diff --git a/src/trajectory.jl b/src/trajectory.jl index 2e3c1d550..8ef4700b7 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -141,8 +141,9 @@ $(TYPEDEF) Slice sampler for the starting single leaf tree. Slice variable is initialized. """ -SliceTS(rng::AbstractRNG, z0::PhasePoint) = +function SliceTS(rng::AbstractRNG, z0::PhasePoint) SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) +end """ $(TYPEDEF) @@ -552,7 +553,7 @@ function isterminated(::ClassicNoUTurn, h::Hamiltonian, t::BinaryTree) # z0 is starting point and z1 is ending point z0, z1 = t.zleft, t.zright Δθ = z1.θ - z0.θ - s = (dot(Δθ, ∂H∂r(h, -z0.r)) >= 0) || (dot(-Δθ, ∂H∂r(h, z1.r)) >= 0) + s = (dot(Δθ, ∂H∂r(h, z0.θ, -z0.r)) >= 0) || (dot(-Δθ, ∂H∂r(h, z1.θ, z1.r)) >= 0) return Termination(s, false) end @@ -565,7 +566,9 @@ Ref: https://arxiv.org/abs/1701.02434 """ function isterminated(::GeneralisedNoUTurn, h::Hamiltonian, t::BinaryTree) rho = t.ts.rho - s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.r), ∂H∂r(h, t.zright.r)) + s = generalised_uturn_criterion( + rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, t.zright.θ, t.zright.r) + ) return Termination(s, false) end @@ -595,7 +598,9 @@ phase point of `tright`, the right subtree. """ function check_left_subtree(h::Hamiltonian, t::T, tleft::T, tright::T) where {T<:BinaryTree} rho = tleft.ts.rho + tright.zleft.r - s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.r), ∂H∂r(h, tright.zleft.r)) + s = generalised_uturn_criterion( + rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, tright.zleft.θ, tright.zleft.r) + ) return Termination(s, false) end @@ -608,7 +613,9 @@ function check_right_subtree( h::Hamiltonian, t::T, tleft::T, tright::T ) where {T<:BinaryTree} rho = tleft.zright.r + tright.ts.rho - s = generalised_uturn_criterion(rho, ∂H∂r(h, tleft.zright.r), ∂H∂r(h, t.zright.r)) + s = generalised_uturn_criterion( + rho, ∂H∂r(h, tleft.zright.θ, tleft.zright.r), ∂H∂r(h, t.zright.θ, t.zright.r) + ) return Termination(s, false) end diff --git a/test/adaptation.jl b/test/adaptation.jl index 346423eaa..df72c159e 100644 --- a/test/adaptation.jl +++ b/test/adaptation.jl @@ -1,6 +1,8 @@ using ReTest, LinearAlgebra, Distributions, AdvancedHMC, Random, ForwardDiff +using AdvancedHMC: + PhasePoint, DualValue using AdvancedHMC.Adaptation: - WelfordVar, NaiveVar, WelfordCov, NaiveCov, get_estimation, get_estimation, reset! + DiagMatrixEstimator, WelfordVar, NutpieVar, NaiveVar, WelfordCov, NaiveCov, get_estimation, get_estimation, reset! 
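To make the new `PositionOrPhasePoint` plumbing concrete: position-only estimators keep accepting plain vectors (a `PhasePoint` is reduced to `z.θ` via `get_position`), while `NutpieVar` requires the full `PhasePoint` so that it can also read the cached gradient. A hedged sketch; the toy `PhasePoint` and the concrete numbers are made up for illustration:

```julia
using AdvancedHMC
using AdvancedHMC: PhasePoint, DualValue

θ, r = randn(4), randn(4)
z = PhasePoint(θ, r, DualValue(0.0, -θ), DualValue(0.0, r))   # made-up phase point

welford = MassMatrixAdaptor(DiagEuclideanMetric(4))           # position-only estimator
nutpie = AdvancedHMC.NutpieVar(size(DiagEuclideanMetric(4)))  # needs positions and gradients

AdvancedHMC.adapt!(welford, θ, 1.0)   # plain position vectors still work
AdvancedHMC.adapt!(welford, z, 1.0)   # a PhasePoint is reduced to z.θ internally
AdvancedHMC.adapt!(nutpie, z, 1.0)    # the specialized push! also consumes z.ℓπ.gradient
# AdvancedHMC.adapt!(nutpie, θ, 1.0)  # would error: NutpieVar needs gradient information
```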
function runnuts(ℓπ, metric; n_samples=10_000) D = size(metric, 1) @@ -18,7 +20,37 @@ function runnuts(ℓπ, metric; n_samples=10_000) return (samples=samples, stats=stats, adaptor=adaptor) end +# Temporary function until we've settled on a different interface +function runnuts_nutpie(ℓπ, metric::DiagEuclideanMetric; n_samples=10_000) + D = size(metric, 1) + n_adapts = 5_000 + θ_init = rand(D) + rng = MersenneTwister(0) + + nuts = NUTS(0.8) + h = Hamiltonian(metric, ℓπ, ForwardDiff) + step_size = AdvancedHMC.make_step_size(rng, nuts, h, θ_init) + integrator = AdvancedHMC.make_integrator(nuts, step_size) + κ = AdvancedHMC.make_kernel(nuts, integrator) + # Constructing like this until we've settled on a different interface + adaptor = AdvancedHMC.StanHMCAdaptor( + AdvancedHMC.Adaptation.NutpieVar(size(metric); var=copy(metric.M⁻¹)), + AdvancedHMC.StepSizeAdaptor(nuts.δ, integrator) + ) + samples, stats = sample(h, κ, θ_init, n_samples, adaptor, n_adapts; verbose=false) + return (samples=samples, stats=stats, adaptor=adaptor) +end +""" +Computes the condition number of a covariance matrix `cov::AbstractMatrix` after preconditioning with the (diagonal) mass matrix estimated in `a::DiagMatrixEstimator`. + +This is a simple but serviceable proxy for eventual sampling efficiency, but see also https://arxiv.org/abs/1905.09813 for a more involved estimate. + +(A lower number generally means that the estimated mass matrix is better). +""" +preconditioned_cond(a::DiagMatrixEstimator, cov::AbstractMatrix) = cond(sqrt(Diagonal(a.var)) \ cov / sqrt(Diagonal(a.var))) + @testset "Adaptation" begin + Random.seed!(1) # Check that the estimated variance is approximately correct. @testset "Online v.s. naive v.s. true var/cov estimation" begin D = 10 @@ -60,15 +92,32 @@ end @testset "MassMatrixAdaptor constructors" begin θ = [0.0, 0.0, 0.0, 0.0] + z = PhasePoint( + θ, θ, DualValue(0., θ), DualValue(0., θ) + ) pc1 = MassMatrixAdaptor(UnitEuclideanMetric) # default dim = 2 pc2 = MassMatrixAdaptor(DiagEuclideanMetric) + # Constructing like this until we've settled on a different interface + pc2_nutpie = NutpieVar{Float64}((2, )) pc3 = MassMatrixAdaptor(DenseEuclideanMetric) - # Var adaptor dimention should be increased to length(θ) from 2 + # Var adaptor dimension should be increased to length(θ) from 2 AdvancedHMC.adapt!(pc1, θ, 1.0) AdvancedHMC.adapt!(pc2, θ, 1.0) + AdvancedHMC.adapt!(pc2_nutpie, z, 1.0) AdvancedHMC.adapt!(pc3, θ, 1.0) @test AdvancedHMC.Adaptation.getM⁻¹(pc2) == ones(length(θ)) + @test AdvancedHMC.Adaptation.getM⁻¹(pc2_nutpie) == ones(length(θ)) + @test AdvancedHMC.Adaptation.getM⁻¹(pc3) == + LinearAlgebra.diagm(0 => ones(length(θ))) + + # Making sure "all" MassMatrixAdaptors support getting a PhasePoint instead of a Vector + AdvancedHMC.adapt!(pc1, z, 1.0) + AdvancedHMC.adapt!(pc2, z, 1.0) + AdvancedHMC.adapt!(pc2_nutpie, z, 1.0) + AdvancedHMC.adapt!(pc3, z, 1.0) + @test AdvancedHMC.Adaptation.getM⁻¹(pc2) == ones(length(θ)) + @test AdvancedHMC.Adaptation.getM⁻¹(pc2_nutpie) == ones(length(θ)) @test AdvancedHMC.Adaptation.getM⁻¹(pc3) == LinearAlgebra.diagm(0 => ones(length(θ))) end @@ -82,10 +131,14 @@ end adaptor2 = StanHMCAdaptor( MassMatrixAdaptor(DiagEuclideanMetric), NesterovDualAveraging(0.8, 0.5) ) + # Constructing like this until we've settled on a different interface + adaptor2_nutpie = StanHMCAdaptor( + NutpieVar{Float64}((2, )), NesterovDualAveraging(0.8, 0.5) + ) adaptor3 = StanHMCAdaptor( MassMatrixAdaptor(DenseEuclideanMetric), NesterovDualAveraging(0.8, 0.5) ) - for a in [adaptor1, 
adaptor2, adaptor3] + for a in [adaptor1, adaptor2, adaptor2_nutpie, adaptor3] AdvancedHMC.initialize!(a, 1_000) @test a.state.window_start == 76 @test a.state.window_end == 950 @@ -93,6 +146,7 @@ end AdvancedHMC.adapt!(a, θ, 1.0) end @test AdvancedHMC.Adaptation.getM⁻¹(adaptor2) == ones(length(θ)) + @test AdvancedHMC.Adaptation.getM⁻¹(adaptor2_nutpie) == ones(length(θ)) @test AdvancedHMC.Adaptation.getM⁻¹(adaptor3) == LinearAlgebra.diagm(0 => ones(length(θ))) @@ -112,26 +166,32 @@ end @testset "Adapted mass v.s. true variance" begin D = 10 - n_tests = 5 - @testset "DiagEuclideanMetric" begin + n_tests = 10 + @testset "'Diagonal' MvNormal target" begin for _ in 1:n_tests - Random.seed!(1) # Random variance σ² = 1 .+ abs.(randn(D)) + Σ = Diagonal(σ²) # Diagonal Gaussian - ℓπ = LogDensityDistribution(MvNormal(Diagonal(σ²))) + ℓπ = LogDensityDistribution(MvNormal(Σ)) res = runnuts(ℓπ, DiagEuclideanMetric(D)) @test res.adaptor.pc.var ≈ σ² rtol = 0.2 + # For this target, Nutpie (without regularization) will arrive at the true variances after two draws. + res_nutpie = runnuts_nutpie(ℓπ, DiagEuclideanMetric(D)) + @test res.adaptor.pc.var ≈ σ² rtol = 0.2 + @test preconditioned_cond(res_nutpie.adaptor.pc, Σ) < preconditioned_cond(res.adaptor.pc, Σ) + res = runnuts(ℓπ, DenseEuclideanMetric(D)) @test res.adaptor.pc.cov ≈ Diagonal(σ²) rtol = 0.25 end end - @testset "DenseEuclideanMetric" begin + @testset "'Dense' MvNormal target" begin + n_nutpie_superior = 0 for _ in 1:n_tests # Random covariance m = randn(D, D) @@ -143,9 +203,17 @@ end res = runnuts(ℓπ, DiagEuclideanMetric(D)) @test res.adaptor.pc.var ≈ diag(Σ) rtol = 0.2 + # For this target, Nutpie will NOT converge towards the true variances, even after infinite draws. + # HOWEVER, it will asymptotically (but also generally more quickly than Stan) + # find the best preconditioner for the target. + # As these are statistical algorithms, superiority is not always guaranteed, hence this way of testing. + res_nutpie = runnuts_nutpie(ℓπ, DiagEuclideanMetric(D)) + n_nutpie_superior += preconditioned_cond(res_nutpie.adaptor.pc, Σ) < preconditioned_cond(res.adaptor.pc, Σ) + res = runnuts(ℓπ, DenseEuclideanMetric(D)) @test res.adaptor.pc.cov ≈ Σ rtol = 0.25 end + @test n_nutpie_superior > 1 + n_tests / 2 end end @@ -156,6 +224,10 @@ end res = runnuts(ℓπ, DiagEuclideanMetric(mass_init); n_samples=1) @test res.adaptor.pc.var == mass_init + mass_init = fill(0.5, D) + res = runnuts_nutpie(ℓπ, DiagEuclideanMetric(mass_init); n_samples=1) + @test res.adaptor.pc.var == mass_init + mass_init = diagm(0 => fill(0.5, D)) res = runnuts(ℓπ, DenseEuclideanMetric(mass_init); n_samples=1) @test res.adaptor.pc.cov == mass_init
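As background for the condition-number comparisons in the tests above (this is what `preconditioned_cond` measures): diagonal preconditioning can remove per-coordinate scale differences but not correlations, which is why the 'Dense' target only checks that the Nutpie preconditioner is usually no worse than the Welford one. A standalone illustration, not part of the test suite:

```julia
using LinearAlgebra

Σ_scale = Diagonal([100.0, 1.0])   # pure scale mismatch between coordinates
P = sqrt(Σ_scale)                  # the ideal diagonal preconditioner for this target
cond(P \ Σ_scale / P)              # 1.0: the scale mismatch is removed completely

Σ_corr = [1.0 0.9; 0.9 1.0]        # pure correlation with unit variances
cond(Σ_corr)                       # 19.0: the target stays ill-conditioned
                                   # under any purely diagonal rescaling
```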