-
Notifications
You must be signed in to change notification settings - Fork 36
Open
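Description

Computing a sparse Hessian with `DI.AutoSparse` (SparseConnectivityTracer local sparsity detection plus SparseMatrixColorings greedy coloring) fails when the model is evaluated through `DynamicPPL.evaluate_threadsafe!!`. The dense ForwardDiff Hessian of the same function works, and switching to `evaluate_threadunsafe!!` also avoids the error.

The traceback suggests a cause. `map_accumulators!!` for `ThreadSafeVarInfo` writes accumulators back into an array whose element type appears to pin them to `Float64` (e.g. `LogPriorAccumulator{Float64}`). The resulting `convert` of an accumulator that holds a SparseConnectivityTracer `Dual` therefore fails.

MWE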
using DynamicPPL, Distributions
import DifferentiationInterface as DI
import SparseConnectivityTracer as SCT
import SparseMatrixColorings as SMC
import ForwardDiff
# Two-variable DynamicPPL model used by the reproducer. `f()` takes no
# arguments, so neither `x` nor `y` is supplied as data; the MWE below
# evaluates its log density at a length-2 vector, one entry per variable.
@model function f()
x ~ Normal()
# `y`'s distribution is centred on the current value of `x`.
y ~ Normal(x)
end
"""
    LogDensityAt(model, varinfo)

Callable wrapper that bundles a model with the varinfo used as a template when
evaluating the model's log density at a flat parameter vector.

# Fields
- `model`: the model whose log density is evaluated.
- `varinfo`: template varinfo into which parameter vectors are unflattened.
"""
struct LogDensityAt{TModel,TVarInfo}
    model::TModel
    varinfo::TVarInfo
end
# Evaluate the log joint density of `ld.model` at the flat parameter vector
# `params`: the template `ld.varinfo` is combined with `params` via
# `DynamicPPL.unflatten`, the model is run through the thread-safe evaluation
# path, and the log joint is read from the varinfo that evaluation returns.
function (ld::LogDensityAt)(params)
    vi_in = DynamicPPL.unflatten(ld.varinfo, params)
    # This call is the point of the reproducer: with `evaluate_threadsafe!!`,
    # sparse Hessian detection via SparseConnectivityTracer throws a TypeError,
    # whereas `evaluate_threadunsafe!!` works.
    evaluated = DynamicPPL.evaluate_threadsafe!!(ld.model, vi_in)
    vi_out = last(evaluated)
    return DynamicPPL.getlogjoint(vi_out)
end
# Wrap an instance of the model together with a `VarInfo` built from it.
lda = LogDensityAt(f(), DynamicPPL.VarInfo(f()))
x0 = [0.5, 0.5]

# Dense Hessian with ForwardDiff: works.
DI.hessian(lda, DI.AutoForwardDiff(), x0)

# Sparse Hessian (SCT local sparsity detection + greedy coloring): throws a
# `TypeError` during sparsity detection.
hess_adtype = DI.AutoSparse(
    DI.AutoForwardDiff();
    sparsity_detector=SCT.TracerLocalSparsityDetector(),
    coloring_algorithm=SMC.GreedyColoringAlgorithm(),
)
DI.hessian(lda, hess_adtype, x0)
Traceback
julia> DI.hessian(lda, hess_adtype, [0.5, 0.5])
ERROR: TypeError: in new, expected Float64, got a value of type SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{Int64, BitSet, Dict{Int64, BitSet}, SparseConnectivityTracer.NotShared}}
Stacktrace:
[1] LogPriorAccumulator{Float64}(logp::SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{…}})
@ DynamicPPL ~/ppl/dppl/src/default_accumulators.jl:86
[2] convert
@ ~/ppl/dppl/src/default_accumulators.jl:66 [inlined]
[3] cvt1
@ ./essentials.jl:612 [inlined]
[4] ntuple
@ ./ntuple.jl:50 [inlined]
[5] convert
@ ./essentials.jl:614 [inlined]
[6] Tuple{…}(nt::@NamedTuple{…})
@ Base ./namedtuple.jl:198
[7] convert
@ ./namedtuple.jl:185 [inlined]
[8] convert
@ ~/ppl/dppl/src/accumulators.jl:176 [inlined]
[9] setindex!
@ ./array.jl:987 [inlined]
[10] map_accumulators!!
@ ~/ppl/dppl/src/threadsafe.jl:66 [inlined]
[11] accumulate_assume!!
@ ~/ppl/dppl/src/abstract_varinfo.jl:311 [inlined]
[12] assume(dist::Normal{…}, vn::VarName{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
@ DynamicPPL ~/ppl/dppl/src/context_implementations.jl:127
[13] tilde_assume
@ ~/ppl/dppl/src/context_implementations.jl:22 [inlined]
[14] tilde_assume!!
@ ~/ppl/dppl/src/context_implementations.jl:69 [inlined]
[15] f
@ ./REPL[37]:2 [inlined]
[16] _evaluate!!(model::Model{…}, varinfo::DynamicPPL.ThreadSafeVarInfo{…})
@ DynamicPPL ~/ppl/dppl/src/model.jl:921
[17] evaluate_threadsafe!!
@ ~/ppl/dppl/src/model.jl:903 [inlined]
[18] LogDensityAt
@ ./REPL[39]:3 [inlined]
[19] trace_function(::Type{SparseConnectivityTracer.Dual{…}}, f::LogDensityAt{Model{…}, VarInfo{…}}, x::Vector{Float64})
@ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/la0m5/src/trace_functions.jl:48
[20] _local_hessian_sparsity
@ ~/.julia/packages/SparseConnectivityTracer/la0m5/src/trace_functions.jl:124 [inlined]
[21] hessian_sparsity
@ ~/.julia/packages/SparseConnectivityTracer/la0m5/src/adtypes_interface.jl:202 [inlined]
[22] hessian_sparsity_with_contexts
@ ~/.julia/packages/DifferentiationInterface/L0TGS/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl:62 [inlined]
[23] prepare_hessian_nokwarg(::Val{…}, ::LogDensityAt{…}, ::ADTypes.AutoSparse{…}, ::Vector{…})
@ DifferentiationInterfaceSparseMatrixColoringsExt ~/.julia/packages/DifferentiationInterface/L0TGS/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl:29
[24] hessian(::LogDensityAt{…}, ::ADTypes.AutoSparse{…}, ::Vector{…})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/L0TGS/src/second_order/hessian.jl:34
[25] top-level scope
@ REPL[47]:1
Some type information was truncated. Use `show(err)` to see complete types.
Version info
The failure reproduces on both Julia 1.11.7 and 1.12.1.
(ppl) pkg> st
Status `~/ppl/Project.toml`
[a0c0ee7d] DifferentiationInterface v0.7.9
[31c24e10] Distributions v0.25.122
[366bfd00] DynamicPPL v0.37.5 `dppl`
[f6369f11] ForwardDiff v1.2.2
[9f842d2f] SparseConnectivityTracer v1.1.1
[0a514795] SparseMatrixColorings v0.4.22
Metadata
Assignees
Labels
No labels