Skip to content

Commit

Permalink
Merge a9855c6 into 38f1400
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jun 1, 2020
2 parents 38f1400 + a9855c6 commit f14de20
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 96 deletions.
16 changes: 14 additions & 2 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,25 @@ include("diagnosis.jl")
include("sampler.jl")
export sample

include("contrib/ad.jl")

### Init

using Requires

function __init__()
include(joinpath(@__DIR__, "contrib", "diffeq.jl"))
include(joinpath(@__DIR__, "contrib", "ad.jl"))
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
export DiffEqIntegrator
include("contrib/diffeq.jl")
end

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
include("contrib/forwarddiff.jl")
end

@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("contrib/zygote.jl")
end
end

end # module
86 changes: 1 addition & 85 deletions src/contrib/ad.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
const ADSUPPORT = (:ForwardDiff, :Zygote)
const ADAVAILABLE = Dict{Module, Function}()

Hamiltonian(metric::AbstractMetric, ℓπ, m::Module) = ADAVAILABLE[m](metric, ℓπ)
Hamiltonian(metric::M, ℓπ::T, m::Module) where {M<:AbstractMetric,T} = ADAVAILABLE[m](metric, ℓπ)

function Hamiltonian(metric::AbstractMetric, ℓπ)
available = collect(keys(ADAVAILABLE))
Expand All @@ -16,87 +16,3 @@ function Hamiltonian(metric::AbstractMetric, ℓπ)
return Hamiltonian(metric, ℓπ, first(available))
end
end

### ForwardDiff

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin

import .ForwardDiff, .ForwardDiff.DiffResults

function ∂ℓπ∂θ_forwarddiff(ℓπ, θ::AbstractVector)
res = DiffResults.GradientResult(θ)
ForwardDiff.gradient!(res, ℓπ, θ)
return DiffResults.value(res), DiffResults.gradient(res)
end

# Implementation 1
function ∂ℓπ∂θ_forwarddiff(ℓπ, θ::AbstractMatrix)
jacob = similar(θ)
res = DiffResults.JacobianResult(similar(θ, size(θ, 2)), jacob)
ForwardDiff.jacobian!(res, ℓπ, θ)
jacob_full = DiffResults.jacobian(res)

d, n = size(jacob)
for i in 1:n
jacob[:,i] = jacob_full[i,1+(i-1)*d:i*d]
end
return DiffResults.value(res), jacob
end

# Implementation 2
# function ∂ℓπ∂θ_forwarddiff(ℓπ, θ::AbstractMatrix)
# local densities
# f(x) = (densities = ℓπ(x); sum(densities))
# res = DiffResults.GradientResult(θ)
# ForwardDiff.gradient!(res, f, θ)
# return ForwardDiff.value.(densities), DiffResults.gradient(res)
# end

# Implementation 3
# function ∂ℓπ∂θ_forwarddiff(ℓπ, θ::AbstractMatrix)
# v = similar(θ, size(θ, 2))
# g = similar(θ)
# for i in 1:size(θ, 2)
# res = GradientResult(θ[:,i])
# gradient!(res, ℓπ, θ[:,i])
# v[i] = value(res)
# g[:,i] = gradient(res)
# end
# return v, g
# end

function ForwardDiffHamiltonian(metric::AbstractMetric, ℓπ)
∂ℓπ∂θ::AbstractVecOrMat) = ∂ℓπ∂θ_forwarddiff(ℓπ, θ)
return Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
end

ADAVAILABLE[ForwardDiff] = ForwardDiffHamiltonian

end # @require

### Zygote

@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin

import .Zygote

function ∂ℓπ∂θ_zygote(ℓπ, θ::AbstractVector)
res, back = Zygote.pullback(ℓπ, θ)
return res, first(back(Zygote.sensitivity(res)))
end

function ∂ℓπ∂θ_zygote(ℓπ, θ::AbstractMatrix)
res, back = Zygote.pullback(ℓπ, θ)
return res, first(back(ones(eltype(res), size(res))))
end

function ZygoteADHamiltonian(metric::AbstractMetric, ℓπ)
∂ℓπ∂θ::AbstractVecOrMat) = ∂ℓπ∂θ_zygote(ℓπ, θ)
return Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
end

ADAVAILABLE[Zygote] = ZygoteADHamiltonian

# Zygote.@adjoint

end # @require
6 changes: 0 additions & 6 deletions src/contrib/diffeq.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin

import .OrdinaryDiffEq

struct DiffEqIntegrator{T<:AbstractScalarOrVec{<:AbstractFloat}, DiffEqSolver} <: AbstractLeapfrog{T}
Expand Down Expand Up @@ -43,7 +41,3 @@ function step(
end
return res
end

export DiffEqIntegrator

end # @require
50 changes: 50 additions & 0 deletions src/contrib/forwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import .ForwardDiff, .ForwardDiff.DiffResults

function ∂ℓπ∂θ_forwarddiff(ℓπ, θ::AbstractVector)
res = DiffResults.GradientResult(θ)
ForwardDiff.gradient!(res, ℓπ, θ)
return DiffResults.value(res), DiffResults.gradient(res)
end

# Implementation 1
function ∂ℓπ∂θ_forwarddiff(ℓπ, θ::AbstractMatrix)
jacob = similar(θ)
res = DiffResults.JacobianResult(similar(θ, size(θ, 2)), jacob)
ForwardDiff.jacobian!(res, ℓπ, θ)
jacob_full = DiffResults.jacobian(res)

d, n = size(jacob)
for i in 1:n
jacob[:,i] = jacob_full[i,1+(i-1)*d:i*d]
end
return DiffResults.value(res), jacob
end

# Implementation 2
# function ∂ℓπ∂θ_forwarddiff(ℓπ, θ::AbstractMatrix)
# local densities
# f(x) = (densities = ℓπ(x); sum(densities))
# res = DiffResults.GradientResult(θ)
# ForwardDiff.gradient!(res, f, θ)
# return ForwardDiff.value.(densities), DiffResults.gradient(res)
# end

# Implementation 3
# function ∂ℓπ∂θ_forwarddiff(ℓπ, θ::AbstractMatrix)
# v = similar(θ, size(θ, 2))
# g = similar(θ)
# for i in 1:size(θ, 2)
# res = GradientResult(θ[:,i])
# gradient!(res, ℓπ, θ[:,i])
# v[i] = value(res)
# g[:,i] = gradient(res)
# end
# return v, g
# end

function ForwardDiffHamiltonian(metric::AbstractMetric, ℓπ)
∂ℓπ∂θ::AbstractVecOrMat) = ∂ℓπ∂θ_forwarddiff(ℓπ, θ)
return Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
end

ADAVAILABLE[ForwardDiff] = ForwardDiffHamiltonian
18 changes: 18 additions & 0 deletions src/contrib/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import .Zygote

function ∂ℓπ∂θ_zygote(ℓπ, θ::AbstractVector)
res, back = Zygote.pullback(ℓπ, θ)
return res, first(back(Zygote.sensitivity(res)))
end

function ∂ℓπ∂θ_zygote(ℓπ, θ::AbstractMatrix)
res, back = Zygote.pullback(ℓπ, θ)
return res, first(back(ones(eltype(res), size(res))))
end

function ZygoteADHamiltonian(metric::AbstractMetric, ℓπ)
∂ℓπ∂θ::AbstractVecOrMat) = ∂ℓπ∂θ_zygote(ℓπ, θ)
return Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
end

ADAVAILABLE[Zygote] = ZygoteADHamiltonian
6 changes: 6 additions & 0 deletions test/sampler-vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ include("common.jl")
@test show(metric) == nothing
@test show(h) == nothing
@test show(τ) == nothing

# NoAdaptation
Random.seed!(100)
samples, stats = sample(h, τ, θ_init_list[i_test], n_samples; verbose=false)
@test mean(samples) zeros(D, n_chains) atol=RNDATOL * n_chains

# Adaptation
for adaptor in [
MassMatrixAdaptor(metric),
Expand All @@ -51,9 +54,12 @@ include("common.jl")
]
τ isa HMCDA && continue
@test show(adaptor) == nothing

Random.seed!(100)
samples, stats = sample(h, τ, θ_init_list[i_test], n_samples, adaptor, n_adapts; verbose=false, progress=false)
@test mean(samples) zeros(D, n_chains) atol=RNDATOL * n_chains
end

# Passing a vector of same RNGs
rng = [MersenneTwister(1) for _ in 1:n_chains]
h = Hamiltonian(metricT((D, n_chains)), ℓπ, ∂ℓπ∂θ)
Expand Down
6 changes: 3 additions & 3 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ include("common.jl")
θ_init = rand(MersenneTwister(1), D)
ϵ = 0.2
n_steps = 10
n_samples = 12_000
n_adapts = 2_000
n_samples = 22_000
n_adapts = 4_000

function test_stats(::Union{StaticTrajectory,HMCDA}, stats, n_adapts)
for name in (:step_size, :nom_step_size, :n_steps, :is_accept, :acceptance_rate, :log_density, :hamiltonian_energy, :hamiltonian_energy_error, :is_adapt)
Expand Down Expand Up @@ -85,7 +85,7 @@ end
# For other adapatation methods that are able to adpat the step size, we use `find_good_stepsize`.
τ_used = adaptorsym == :MassMatrixAdaptorOnly ? τ : reconstruct(τ, integrator=reconstruct(lf, ϵ=find_good_stepsize(h, θ_init)))
samples, stats = sample(h, τ_used , θ_init, n_samples, adaptor, n_adapts; verbose=false, progress=PROGRESS)
@test mean(samples) zeros(D) atol=RNDATOL
@test mean(samples[(n_adapts+1):end]) zeros(D) atol=RNDATOL
test_stats(τ_used, stats, n_adapts)
end
end
Expand Down

0 comments on commit f14de20

Please sign in to comment.