Skip to content

Commit

Permalink
Merge pull request #83 from TuringLang/kx/bugfix-adapt
Browse files Browse the repository at this point in the history
Bugfix for metric renew
  • Loading branch information
yebai committed Jul 25, 2019
2 parents fe3de8b + e2e3c9a commit e09185d
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.2.0"
version = "0.2.1"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
2 changes: 1 addition & 1 deletion src/adaptation/Adaptation.jl
Expand Up @@ -57,7 +57,7 @@ finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa)
##
include("stan_adaption.jl")

export adapt!, finalize!, getϵ, getM⁻¹, reset!,
export adapt!, finalize!, getϵ, getM⁻¹, reset!, renew,
NesterovDualAveraging,
UnitPreconditioner, DiagPreconditioner, DensePreconditioner,
AbstractMetric, UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric,
Expand Down
7 changes: 7 additions & 0 deletions src/adaptation/precond.jl
Expand Up @@ -260,9 +260,12 @@ struct UnitEuclideanMetric{T} <: AbstractMetric
M⁻¹::UniformScaling{T}
dim::Int
end

UnitEuclideanMetric(::Type{T}, dim::Int) where {T} = UnitEuclideanMetric{T}(UniformScaling{T}(one(T)), dim)
UnitEuclideanMetric(dim::Int) = UnitEuclideanMetric(Float64, dim)

renew(ue::UnitEuclideanMetric, M⁻¹) = UnitEuclideanMetric(M⁻¹, ue.dim)

Base.length(e::UnitEuclideanMetric) = e.dim
Base.show(io::IO, uem::UnitEuclideanMetric) = print(io, "UnitEuclideanMetric($(_string_M⁻¹(ones(uem.dim))))")

Expand All @@ -281,6 +284,8 @@ end
DiagEuclideanMetric(::Type{T}, D::Int) where {T} = DiagEuclideanMetric(ones(T, D))
DiagEuclideanMetric(D::Int) = DiagEuclideanMetric(Float64, D)

renew(ue::DiagEuclideanMetric, M⁻¹) = DiagEuclideanMetric(M⁻¹)

Base.length(e::DiagEuclideanMetric) = size(e.M⁻¹, 1)
Base.show(io::IO, dem::DiagEuclideanMetric) = print(io, "DiagEuclideanMetric($(_string_M⁻¹(dem.M⁻¹)))")

Expand Down Expand Up @@ -310,6 +315,8 @@ end
DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D))
DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D)

renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)

Base.length(e::DenseEuclideanMetric) = size(e.M⁻¹, 1)
Base.show(io::IO, dem::DenseEuclideanMetric) = print(io, "DiagEuclideanMetric($(_string_M⁻¹(dem.M⁻¹)))")

Expand Down
7 changes: 2 additions & 5 deletions src/sampler.jl
Expand Up @@ -62,12 +62,9 @@ function sample(
if !(adaptor isa Adaptation.NoAdaptation)
if i <= n_adapts
adapt!(adaptor, θs[i], αs[i])
# Finalize adapation
if i == n_adapts
finalize!(adaptor)
(verbose && !progress) && @info "Finished $n_adapts adapation steps" adaptor τ.integrator h.metric
end
i == n_adapts && finalize!(adaptor)
h, τ = update(h, τ, adaptor)
(i == n_adapts && verbose && !progress) && @info "Finished $n_adapts adapation steps" adaptor τ.integrator h.metric
end
# Progress info for adapation
progress && (showvalues[:step_size] = τ.integrator.ϵ; showvalues[:precondition] = h.metric)
Expand Down
9 changes: 7 additions & 2 deletions src/trajectory.jl
Expand Up @@ -239,6 +239,11 @@ function isturn(h::Hamiltonian, zleft::PhasePoint, zright::PhasePoint)
s = (dot(θdiff, ∂H∂r(h, zleft.r)) >= 0 ? 1 : 0) * (dot(θdiff, ∂H∂r(h, zright.r)) >= 0 ? 1 : 0)
return Termination(s == 0, false)
end
# function isturn(h::Hamiltonian, zleft::PhasePoint, zright::PhasePoint)
# θdiff = zright.θ - zleft.θ
# s = (dot(-θdiff, ∂H∂r(h, zleft.r)) < 0) && (dot(θdiff, ∂H∂r(h, zright.r)) < 0)
# return Termination(s, false)
# end

"""
Check termination of a Hamiltonian trajectory.
Expand Down Expand Up @@ -508,7 +513,7 @@ function update(
τ::AbstractProposal,
pc::Adaptation.AbstractPreconditioner
)
metric = reconstruct(h.metric, M⁻¹=getM⁻¹(pc))
metric = renew(h.metric, getM⁻¹(pc))
h = reconstruct(h, metric=metric)
return h, τ
end
Expand All @@ -528,7 +533,7 @@ function update(
τ::AbstractProposal,
ca::Union{Adaptation.NaiveHMCAdaptor, Adaptation.StanHMCAdaptor}
)
metric = reconstruct(h.metric, M⁻¹=getM⁻¹(ca.pc))
metric = renew(h.metric, getM⁻¹(ca.pc))
h = reconstruct(h, metric=metric)
integrator = reconstruct.integrator, ϵ=getϵ(ca.ssa))
τ = reconstruct(τ, integrator=integrator)
Expand Down
21 changes: 17 additions & 4 deletions test/common.jl
Expand Up @@ -8,15 +8,28 @@ const DETATOL = 1e-3 * D
# Random tolerance
const RNDATOL = 5e-2 * D

using Distributions: logpdf, MvNormal

ℓπ(θ) = logpdf(MvNormal(zeros(D), ones(D)), θ)

using Distributions: logpdf, MvNormal, InverseGamma, Normal
using DiffResults: GradientResult, value, gradient
using ForwardDiff: gradient!

ℓπ(θ) = logpdf(MvNormal(zeros(D), ones(D)), θ)

function ∂ℓπ∂θ(θ)
res = GradientResult(θ)
gradient!(res, ℓπ, θ)
return (value(res), gradient(res))
end

function ℓπ_gdemo(θ)
s = exp(θ[1])
m = θ[2]
logprior = logpdf(InverseGamma(2, 3), s) + log(s) + logpdf(Normal(0, sqrt(s)), m)
loglikelihood = logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
return logprior + loglikelihood
end

function ∂ℓπ∂θ_gdemo(θ)
res = GradientResult(θ)
gradient!(res, ℓπ_gdemo, θ)
return (value(res), gradient(res))
end
24 changes: 24 additions & 0 deletions test/models.jl
@@ -0,0 +1,24 @@
using Test, Random, AdvancedHMC
using Statistics: mean
include("common.jl")

@testset "`gdemo`" begin
rng = MersenneTwister(0)

n_samples = 5_000
n_adapts = 1_000

θ_init = randn(rng, 2)

metric = DiagEuclideanMetric(2)
h = Hamiltonian(metric, ℓπ_gdemo, ∂ℓπ∂θ_gdemo)
init_eps = Leapfrog(0.1)
prop = NUTS(init_eps)
adaptor = StanHMCAdaptor(n_adapts, Preconditioner(metric), NesterovDualAveraging(0.8, prop.integrator.ϵ))

samples, _ = sample(rng, h, prop, θ_init, n_samples, adaptor, n_adapts)

m_est = mean(map(_s -> [exp(_s[1]), _s[2]], samples[1000:end]))

@test m_est [49 / 24, 7 / 6] atol=RNDATOL
end
1 change: 1 addition & 0 deletions test/runtests.jl
Expand Up @@ -6,6 +6,7 @@ using Distributed, Test
"hamiltonian",
"integrator",
"trajectory",
"models",
"hmc",
]

Expand Down

0 comments on commit e09185d

Please sign in to comment.