Skip to content

Commit

Permalink
Improve sparse Hessian calculation
Browse files Browse the repository at this point in the history
- use SparsityDetection and SparseDiffTools to get Hessian sparsity and 
coloring
- implementation of "direct" Hessian decompression method
- HessianConfig type to store sparsity info and buffers
  • Loading branch information
ElOceanografo committed May 27, 2021
1 parent 9a56e96 commit cfc209c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 10 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SparsityDetection = "684fba80-ace3-11e9-3d08-3bc7ed6f96df"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
Expand Down
77 changes: 72 additions & 5 deletions src/MarginalLogDensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ using SuiteSparse
using SparseArrays
using HCubature
using Distributions
using SparsityDetection
# using Symbolics
using SparseDiffTools

export MarginalLogDensity,
HessianConfig,
AbstractMarginalizer,
LaplaceApprox,
Cubature,
Expand Down Expand Up @@ -47,18 +50,84 @@ as `mld(u, θ)`, where `u` is a length-`m` vector of the marginalized variables.
case, the return value is the same as the full conditional `logdensity` with `u` and `θ`
"""
struct MarginalLogDensity{TI<:Integer, TM<:AbstractMarginalizer,
        TV<:AbstractVector{TI}, TF, THP}
    # Callable evaluating the full (joint) log-density.
    logdensity::TF
    # Total number of variables; `ijoint` is derived as `setdiff(1:n, imarginal)`.
    n::TI
    # Indices of the variables that are integrated out.
    imarginal::TV
    # Indices of the remaining (non-marginalized) variables.
    ijoint::TV
    # Marginalization method.
    method::TM
    # Hessian sparsity information and work buffers (see `HessianConfig`).
    hessconfig::THP
end

# Was using this for testing/troubleshooting, will probably delete later
# """
# `num_hessian_sparsity(f, x, [δ=1.0])`
#
# Calculate the sparsity pattern of the Hessian matrix of function `f`. This is a brute-force
# approach, but more robust than the one in SparsityDetection
# """
# function num_hessian_sparsity(f, x, δ=1.0)
# N = length(x)
# g(x) = ForwardDiff.gradient(f, x)
# y = g(x)
# ii = Int[]
# jj = Int[]
# vv = Float64[]
# for j in 1:N
# x[j] += δ
# yj = g(x)
# di = findall(.! (yj .≈ y))
# for i in di
# push!(jj, j)
# push!(ii, i)
# push!(vv, 1.0)
# end
# x[j] -= δ
# end
# return sparse(ii, jj, vv)
# end

"""
    HessianConfig(Hsparsity, Hcolors, D, Hcomp_buffer, G, δG)

Container for the sparsity information and preallocated buffers used when computing a
sparse Hessian with `sparse_hessian!`.

# Fields
- `Hsparsity`: sparsity pattern of the Hessian.
- `Hcolors`: color assigned to each column of the Hessian.
- `D`: matrix with one column per color, used as perturbation directions.
- `Hcomp_buffer`: buffer for the compressed Hessian; same type as `D`.
- `G`: buffer for the gradient at the evaluation point.
- `δG`: buffer for the gradient at a perturbed point; same type as `G`.
"""
struct HessianConfig{SparsityT, ColorsT, MatT, VecT}
    # Structure of the Hessian
    Hsparsity::SparsityT
    Hcolors::ColorsT

    # Matrix-shaped buffers
    D::MatT
    Hcomp_buffer::MatT

    # Gradient-shaped buffers
    G::VecT
    δG::VecT
end

"""
    HessianConfig(logdensity, imarginal, ijoint)

Build a `HessianConfig` for `logdensity`.

The Hessian sparsity pattern is detected at a vector of ones whose length is
`length(imarginal) + length(ijoint)`, then restricted to the rows and columns listed in
`imarginal`. The lower triangle of that pattern is colored, and buffers sized for the
marginal block are allocated.
"""
function HessianConfig(logdensity, imarginal, ijoint)
    nmarginal = length(imarginal)
    x0 = ones(nmarginal + length(ijoint))

    # Detect the pattern of the full Hessian, then keep only the marginal block.
    full_pattern = hessian_sparsity(logdensity, x0)
    Hsparsity = full_pattern[imarginal, imarginal]
    Hcolors = matrix_colors(tril(Hsparsity))

    # D[r, c] is 1.0 when column r of the Hessian has color c, and 0.0 otherwise.
    ncolors = maximum(Hcolors)
    D = float.(Hcolors .== (1:ncolors)')

    Hcomp_buffer = similar(D)
    G = zeros(nmarginal)
    δG = zeros(nmarginal)
    return HessianConfig(Hsparsity, Hcolors, D, Hcomp_buffer, G, δG)
end

"""
    sparse_hessian!(f, θ, hessconfig::HessianConfig, δ=sqrt(eps(Float64)))

Approximate the Hessian of `f` at `θ` as a sparse matrix, overwriting the buffers stored
in `hessconfig`.

Gradients are computed with `ForwardDiff.gradient!`. For every color (column of
`hessconfig.D`) the gradient is differenced along that column with step `δ`, and the
result is stored in the matching column of `hessconfig.Hcomp_buffer`. The entries listed
in `hessconfig.Hsparsity` are then read back out of the compressed buffer.

# Arguments
- `f`: scalar function of a vector.
- `θ`: point at which the Hessian is evaluated.
- `hessconfig`: sparsity pattern, coloring and work buffers.
- `δ`: forward finite-difference step.

# Returns
A sparse matrix with the same dimensions and stored entries as `hessconfig.Hsparsity`.
"""
function sparse_hessian!(f, θ, hessconfig::HessianConfig, δ=sqrt(eps(Float64)))
    g!(G, x) = ForwardDiff.gradient!(G, f, x)

    # The unperturbed gradient does not depend on the color, so evaluate it once.
    g!(hessconfig.G, θ)

    # One perturbed gradient per color. `D` has one column per color; `Hcolors` is
    # indexed as a vector below, so `size(Hcolors, 2)` would always be 1.
    for j in axes(hessconfig.D, 2)
        g!(hessconfig.δG, θ + δ * view(hessconfig.D, :, j))
        hessconfig.Hcomp_buffer[:, j] .= (hessconfig.δG .- hessconfig.G) ./ δ
    end

    # NOTE(review): reading H[i, j] from the column of color `Hcolors[j]` is exact only
    # if columns sharing a color have no common nonzero row in `Hsparsity` -- confirm
    # this holds for the coloring stored in `Hcolors`.
    ii, jj, _ = findnz(hessconfig.Hsparsity)
    vv = Float64[
        hessconfig.Hcomp_buffer[i, hessconfig.Hcolors[j]] for (i, j) in zip(ii, jj)
    ]
    # Pass the dimensions explicitly so trailing empty rows/columns are not dropped.
    return sparse(ii, jj, vv, size(hessconfig.Hsparsity)...)
end

"""
    MarginalLogDensity(logdensity::Function, n, imarginal, method=LaplaceApprox())

Construct a `MarginalLogDensity` for the `n`-variable function `logdensity`, marginalizing
the variables whose indices are listed in `imarginal`.

The remaining indices are computed as `setdiff(1:n, imarginal)`, and a `HessianConfig`
is built once here and stored in the returned object.
"""
function MarginalLogDensity(logdensity::Function, n::TI,
        imarginal::AbstractVector{TI}, method=LaplaceApprox()) where {TI<:Integer}
    # `imarginal` and `ijoint` share one type parameter in the struct and `setdiff`
    # returns a `Vector`, so normalise `imarginal` (a no-op when it already is one).
    imarg = convert(Vector{TI}, imarginal)
    ijoint = setdiff(1:n, imarg)
    hessconfig = HessianConfig(logdensity, imarg, ijoint)
    return MarginalLogDensity(logdensity, n, imarg, ijoint, method, hessconfig)
end

dimension(mld::MarginalLogDensity) = mld.n
Expand Down Expand Up @@ -137,9 +206,7 @@ function _marginalize(mld::MarginalLogDensity, θjoint::AbstractVector{T},
f(θmarginal) = -mld(θmarginal, θjoint)
N = nmarginal(mld)
opt = optimize(f, zeros(N), LBFGS(), autodiff=:forward)
# H = ForwardDiff.hessian(f, opt.minimizer)
Hv = HesVec(f, opt.minimizer)
H = reduce(hcat, sparse(Hv * i) for i in eachcol(I(N)))
H = sparse_hessian!(f, opt.minimizer, mld.hessconfig)
integral = -opt.minimum + 0.5 * (log((2π)^N) - logdet(H))
return integral
end
Expand Down
8 changes: 4 additions & 4 deletions test/example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ end
θtrue = [μ0; log(σ0); aa; b; log(σ)]
= lengthtrue)
logliktrue)

θmarg = θtrue[[1, 2, 11, 12]]

mld = MarginalLogDensity(loglik, nθ, collect(3:10))
@btime mld(aa, [μ0, log(σ0), b, log(σ)])
@btime mld([μ0, log(σ0), b, log(σ)])
@profiler for i in 1:1000
@btime mld($aa, $θmarg) # 5.3 μs
@btime mld($θmarg) # 115 μs
@profiler for i in 1:5000
mld([μ0, log(σ0), b, log(σ)])
end

Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ ij = [2]
mld = MarginalLogDensity(logdensity, N, im)

@testset "Constructors" begin
mld1 = MarginalLogDensity(logdensity, N, im, ij, LaplaceApprox())
hp = HessianConfig(zeros(N, N), zeros(N), zeros(N, N), zeros(N, N), zeros(N), zeros(N))
mld1 = MarginalLogDensity(logdensity, N, im, ij, LaplaceApprox(), hp)
mld2 = MarginalLogDensity(logdensity, N, im)
mld3 = MarginalLogDensity(logdensity, N, im, LaplaceApprox())
mld4 = MarginalLogDensity(logdensity, N, im,
Expand Down

0 comments on commit cfc209c

Please sign in to comment.