
LogDensityFunction with dimension-non-preserving transformations #1149

@Red-Portal


Hi, I noticed a problem with `LogDensityFunction`. As I understand it, `LogDensityFunction` is meant to work in both constrained and unconstrained space, depending on the configuration carried by `vi`. The problem is that certain transformations are bijective but do not preserve dimension. For instance, the transformation associated with a $k$-dimensional Dirichlet maps between the probability simplex in $\mathbb{R}^k$ and the $(k-1)$-dimensional real space $\mathbb{R}^{k-1}$. This seems to cause errors when `LogDensityFunction` is used with a `vi` that has `transformed` turned on. Here is a minimal example:

```julia
using DynamicPPL
using ReverseDiff
using Distributions
using LogDensityProblems
using ADTypes: AutoReverseDiff  # AutoReverseDiff is a type exported by ADTypes, not a package

# A single 2-component Dirichlet: x lives on the probability simplex in ℝ²
DynamicPPL.@model function dirichlet()
    x ~ Dirichlet([1.0, 1.0])
    return x
end

m = dirichlet()
vi = DynamicPPL.ldf_default_varinfo(m, DynamicPPL.getlogjoint_internal)
vi = DynamicPPL.set_transformed!!(vi, true)  # set the `transformed` flag on all variables
prob = DynamicPPL.LogDensityFunction(
    m, DynamicPPL.getlogjoint_internal, vi; adtype=AutoReverseDiff()
)
```
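For intuition about why the unconstrained dimension drops by one, here is a minimal, dependency-free sketch of a Stan-style stick-breaking transform between $\mathbb{R}^{k-1}$ and the probability simplex in $\mathbb{R}^k$. The helper names (`logistic`, `logit`, `simplex_from_unconstrained`, `unconstrained_from_simplex`) are hypothetical and not part of DynamicPPL or Bijectors.jl; as far as I know, Bijectors' simplex bijector is a variant of this construction, but its exact details may differ.

```julia
# Hypothetical helpers for a Stan-style stick-breaking transform (illustration only).
logistic(t) = 1 / (1 + exp(-t))
logit(p) = log(p / (1 - p))

# Map y ∈ ℝ^(K-1) to a point on the probability simplex in ℝ^K.
function simplex_from_unconstrained(y::AbstractVector{<:Real})
    K = length(y) + 1
    x = zeros(float(eltype(y)), K)
    remaining = one(eltype(x))                 # length of stick not yet broken off
    for k in 1:(K - 1)
        z = logistic(y[k] - log(K - k))        # fraction of the remaining stick; y = 0 gives uniform x
        x[k] = remaining * z
        remaining -= x[k]
    end
    x[K] = remaining                           # last component is whatever is left
    return x
end

# Inverse map: a point on the simplex in ℝ^K back to y ∈ ℝ^(K-1).
function unconstrained_from_simplex(x::AbstractVector{<:Real})
    K = length(x)
    y = zeros(float(eltype(x)), K - 1)
    remaining = one(eltype(y))
    for k in 1:(K - 1)
        z = x[k] / remaining
        y[k] = logit(z) + log(K - k)
        remaining -= x[k]
    end
    return y
end

x = simplex_from_unconstrained([0.0])  # K = 2 from one real coordinate: [0.5, 0.5]
```

With $k = 2$, a single real number determines the whole point on the simplex, which is why `prob` should report dimension 1 for this model.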

The correct dimension of `prob` is now 1, not 2, since the single Dirichlet variable has only one unconstrained coordinate. However, the following assertion fails:

```julia
@assert LogDensityProblems.dimension(prob) == 1
```

The following density evaluation also fails:

```julia
LogDensityProblems.logdensity(prob, [0.0])
```

It throws:

```
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [1:2]
Stacktrace:
 [1] throw_boundserror(A::Vector{Float64}, I::Tuple{UnitRange{Int64}})
   @ Base ./essentials.jl:14
 [2] checkbounds
   @ ./abstractarray.jl:699 [inlined]
 [3] getindex
   @ ./array.jl:936 [inlined]
 [4] macro expansion
   @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/varinfo.jl:0 [inlined]
 [5] unflatten_metadata(metadata::@NamedTuple{x::DynamicPPL.Metadata{…}}, x::Vector{Float64})
   @ DynamicPPL ~/.julia/packages/DynamicPPL/Ut5Ls/src/varinfo.jl:389
 [6] unflatten
   @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/varinfo.jl:369 [inlined]
 [7] logdensity_at
   @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/logdensityfunction.jl:242 [inlined]
 [8] logdensity(f::LogDensityFunction{…}, x::Vector{…})
   @ DynamicPPL ~/.julia/packages/DynamicPPL/Ut5Ls/src/logdensityfunction.jl:279
 [9] top-level scope
   @ REPL[45]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
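The trace points at `unflatten_metadata` slicing the input with `[1:2]`. A rough sketch of that kind of bookkeeping, assuming each variable remembers the range of the flat vector it occupies, shows why a 1-element input fails when the recorded range still has the constrained length 2. The `ranges` dictionary and `unflatten_sketch` are hypothetical and much simpler than DynamicPPL's actual `unflatten`.

```julia
# Hypothetical stand-in for per-variable bookkeeping: each variable name maps to
# the range of the flat parameter vector it occupies. Here the range for `x` was
# recorded from its 2-element constrained value.
ranges = Dict(:x => 1:2)

# Split a flat parameter vector into per-variable slices (sketch only).
function unflatten_sketch(ranges::Dict{Symbol,UnitRange{Int}}, θ::AbstractVector)
    return Dict(name => θ[r] for (name, r) in ranges)
end

unflatten_sketch(ranges, [0.0, 1.0])  # works: Dict(:x => [0.0, 1.0])
# unflatten_sketch(ranges, [0.0])     # throws BoundsError: 1-element vector indexed at [1:2]
```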

More importantly, the following also fails:

```julia
LogDensityProblems.logdensity(prob, [0.0, 1.0])
```

with the following error:

```
ERROR: DimensionMismatch: inconsistent array dimensions
Stacktrace:
  [1] logpdf
    @ ~/.julia/packages/Distributions/psM3H/src/common.jl:266 [inlined]
  [2] accumulate_assume!!
    @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/default_accumulators.jl:95 [inlined]
  [3] #64
    @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/abstract_varinfo.jl:311 [inlined]
  [4] map
    @ ./tuple.jl:357 [inlined]
  [5] map
    @ ./namedtuple.jl:266 [inlined]
  [6] map
    @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/accumulators.jl:204 [inlined]
  [7] map_accumulators!!
    @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/abstract_varinfo.jl:330 [inlined]
  [8] accumulate_assume!!
    @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/abstract_varinfo.jl:311 [inlined]
  [9] tilde_assume!!(::DefaultContext, right::Dirichlet{…}, vn::VarName{…}, vi::VarInfo{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/Ut5Ls/src/contexts/default.jl:36
 [10] dirichlet(__model__::Model{…}, __varinfo__::VarInfo{…})
    @ Main ./REPL[40]:2
 [11] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/model.jl:974 [inlined]
 [12] evaluate_threadunsafe!!
    @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/model.jl:940 [inlined]
 [13] evaluate!!(model::Model{…}, varinfo::VarInfo{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/Ut5Ls/src/model.jl:925
 [14] logdensity_at
    @ ~/.julia/packages/DynamicPPL/Ut5Ls/src/logdensityfunction.jl:243 [inlined]
 [15] logdensity(f::LogDensityFunction{…}, x::Vector{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/Ut5Ls/src/logdensityfunction.jl:279
 [16] top-level scope
    @ REPL[49]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
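One plausible reading of this second trace, not confirmed against DynamicPPL's or Bijectors.jl's source: the 2-element input passes `unflatten`, but an inverse simplex transform that infers $k$ from its input length turns it into a 3-component point, which the size check inside `logpdf` for a 2-component `Dirichlet` then rejects. The sketch below mimics that chain with hypothetical helpers; `check_dims` only imitates the `DimensionMismatch` shown above and is not the Distributions.jl implementation.

```julia
logistic(t) = 1 / (1 + exp(-t))

# Hypothetical Stan-style stick-breaking map from ℝ^(K-1) to the simplex in ℝ^K;
# K is inferred from the input length, so a length-2 input yields 3 components.
function simplex_from_unconstrained(y::AbstractVector{<:Real})
    K = length(y) + 1
    x = zeros(float(eltype(y)), K)
    remaining = one(eltype(x))
    for k in 1:(K - 1)
        x[k] = remaining * logistic(y[k] - log(K - k))
        remaining -= x[k]
    end
    x[K] = remaining
    return x
end

# Hypothetical stand-in for the size check a multivariate logpdf performs first.
function check_dims(expected_length::Int, x::AbstractVector)
    length(x) == expected_length ||
        throw(DimensionMismatch("inconsistent array dimensions"))
    return x
end

x = simplex_from_unconstrained([0.0, 1.0])  # 3 components, not 2
# check_dims(2, x)                          # throws DimensionMismatch
```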
