Skip to content

Unable to compute sparse Hessians with ThreadSafeVarInfo #1081

@penelopeysm

Description

@penelopeysm

MWE (minimal working example)

using DynamicPPL, Distributions
import DifferentiationInterface as DI
import SparseConnectivityTracer as SCT
import SparseMatrixColorings as SMC
import ForwardDiff

# Minimal two-variable model used to reproduce the failure: `x` is drawn from
# `Normal()`, and `y` from a normal distribution centred on `x`.
# Neither variable is passed in as a model argument, so both are treated as
# parameters. The VarInfo built from this model therefore holds two values, which
# is why the evaluation points passed to `DI.hessian` below have length 2.
@model function f()
    x ~ Normal()
    y ~ Normal(x)
end
"""
    LogDensityAt(model, varinfo)

Callable wrapper around a DynamicPPL `model` and a template `varinfo`.
Calling it with a flat parameter vector returns the model's log joint density at
that point. The evaluation goes through the thread-safe evaluation path.
"""
struct LogDensityAt{M,V}
    model::M
    varinfo::V
end

function (ld::LogDensityAt)(params)
    # Build a VarInfo that holds the values contained in `params`.
    vi_at_params = DynamicPPL.unflatten(ld.varinfo, params)

    # Deliberately the thread-safe path: this is the call that fails when `params`
    # carries SparseConnectivityTracer values. The thread-safe accumulators end up
    # typed as `LogPriorAccumulator{Float64}` and cannot store a tracer-valued
    # log-probability. Switching to `evaluate_threadunsafe!!` avoids the error.
    evaluated = DynamicPPL.evaluate_threadsafe!!(ld.model, vi_at_params)
    vi_after = last(evaluated)

    return DynamicPPL.getlogjoint(vi_after)
end
# Build the model once, plus a VarInfo for it that acts as the template
# `LogDensityAt` unflattens each evaluation point into.
model = f()
lda = LogDensityAt(model, DynamicPPL.VarInfo(model))
x0 = [0.5, 0.5]

# Dense Hessian with ForwardDiff: this succeeds.
dense_backend = DI.AutoForwardDiff()
DI.hessian(lda, dense_backend, x0)

# Sparse Hessian: while preparing, DifferentiationInterface first detects the
# sparsity pattern by tracing `lda` with SparseConnectivityTracer. It then colors
# that pattern and evaluates the compressed Hessian with ForwardDiff. The tracing
# step is where the TypeError is raised.
hess_adtype = DI.AutoSparse(
    dense_backend;
    sparsity_detector=SCT.TracerLocalSparsityDetector(),
    coloring_algorithm=SMC.GreedyColoringAlgorithm(),
)
DI.hessian(lda, hess_adtype, x0)

Traceback

julia> DI.hessian(lda, hess_adtype, [0.5, 0.5])
ERROR: TypeError: in new, expected Float64, got a value of type SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{Int64, BitSet, Dict{Int64, BitSet}, SparseConnectivityTracer.NotShared}}
Stacktrace:
  [1] LogPriorAccumulator{Float64}(logp::SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{…}})
    @ DynamicPPL ~/ppl/dppl/src/default_accumulators.jl:86
  [2] convert
    @ ~/ppl/dppl/src/default_accumulators.jl:66 [inlined]
  [3] cvt1
    @ ./essentials.jl:612 [inlined]
  [4] ntuple
    @ ./ntuple.jl:50 [inlined]
  [5] convert
    @ ./essentials.jl:614 [inlined]
  [6] Tuple{…}(nt::@NamedTuple{})
    @ Base ./namedtuple.jl:198
  [7] convert
    @ ./namedtuple.jl:185 [inlined]
  [8] convert
    @ ~/ppl/dppl/src/accumulators.jl:176 [inlined]
  [9] setindex!
    @ ./array.jl:987 [inlined]
 [10] map_accumulators!!
    @ ~/ppl/dppl/src/threadsafe.jl:66 [inlined]
 [11] accumulate_assume!!
    @ ~/ppl/dppl/src/abstract_varinfo.jl:311 [inlined]
 [12] assume(dist::Normal{…}, vn::VarName{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
    @ DynamicPPL ~/ppl/dppl/src/context_implementations.jl:127
 [13] tilde_assume
    @ ~/ppl/dppl/src/context_implementations.jl:22 [inlined]
 [14] tilde_assume!!
    @ ~/ppl/dppl/src/context_implementations.jl:69 [inlined]
 [15] f
    @ ./REPL[37]:2 [inlined]
 [16] _evaluate!!(model::Model{…}, varinfo::DynamicPPL.ThreadSafeVarInfo{…})
    @ DynamicPPL ~/ppl/dppl/src/model.jl:921
 [17] evaluate_threadsafe!!
    @ ~/ppl/dppl/src/model.jl:903 [inlined]
 [18] LogDensityAt
    @ ./REPL[39]:3 [inlined]
 [19] trace_function(::Type{SparseConnectivityTracer.Dual{…}}, f::LogDensityAt{Model{…}, VarInfo{…}}, x::Vector{Float64})
    @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/la0m5/src/trace_functions.jl:48
 [20] _local_hessian_sparsity
    @ ~/.julia/packages/SparseConnectivityTracer/la0m5/src/trace_functions.jl:124 [inlined]
 [21] hessian_sparsity
    @ ~/.julia/packages/SparseConnectivityTracer/la0m5/src/adtypes_interface.jl:202 [inlined]
 [22] hessian_sparsity_with_contexts
    @ ~/.julia/packages/DifferentiationInterface/L0TGS/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl:62 [inlined]
 [23] prepare_hessian_nokwarg(::Val{…}, ::LogDensityAt{…}, ::ADTypes.AutoSparse{…}, ::Vector{…})
    @ DifferentiationInterfaceSparseMatrixColoringsExt ~/.julia/packages/DifferentiationInterface/L0TGS/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl:29
 [24] hessian(::LogDensityAt{…}, ::ADTypes.AutoSparse{…}, ::Vector{…})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/L0TGS/src/second_order/hessian.jl:34
 [25] top-level scope
    @ REPL[47]:1
Some type information was truncated. Use `show(err)` to see complete types.

Version info

Fails on both Julia 1.11.7 and 1.12.1.

(ppl) pkg> st
Status `~/ppl/Project.toml`
  [a0c0ee7d] DifferentiationInterface v0.7.9
  [31c24e10] Distributions v0.25.122
  [366bfd00] DynamicPPL v0.37.5 `dppl`
  [f6369f11] ForwardDiff v1.2.2
  [9f842d2f] SparseConnectivityTracer v1.1.1
  [0a514795] SparseMatrixColorings v0.4.22

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions