
Generate RangeAndLinked via an accumulator, i.e. threaded assume #1153

@penelopeysm

The new implementation of LogDensityFunction still requires a VarInfo to bootstrap it, because the ranges + link status are read from its metadata:

function LogDensityFunction(
    model::Model,
    getlogdensity::Function=getlogjoint_internal,
    varinfo::AbstractVarInfo=VarInfo(model);
    adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,

function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms}
    all_iden_ranges = NamedTuple()
    all_ranges = Dict{VarName,RangeAndLinked}()
    offset = 1

This is fine for now. However, the truth is that this is the only bit of the code (in DynamicPPL) that is holding us back from fully supporting threaded assume (yes!!).

LDF now uses this thing called OnlyAccsVarInfo (OAVI). Because an OAVI contains only accumulators, and because wrapping any VarInfo in a ThreadSafeVarInfo makes its accumulators thread-safe*, it follows that ThreadSafeVarInfo{<:OnlyAccsVarInfo} is completely thread-safe.
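
To illustrate why "only accumulators" is the easy case for threading, here is a minimal sketch in plain Julia. It is not how ThreadSafeVarInfo is implemented (per the footnote, that involves threadid() indexing); it only shows the general pattern of giving each task its own accumulator and combining afterwards. `ToyLogProbAcc`, `toy_accumulate`, `toy_combine` and `threaded_logp` are made-up names, not DynamicPPL API.

```julia
using Base.Threads: @spawn

# Hypothetical toy accumulator: it only sums log-probability contributions.
struct ToyLogProbAcc
    logp::Float64
end

# Accumulators are immutable here, so "updating" returns a new value.
toy_accumulate(acc::ToyLogProbAcc, lp::Float64) = ToyLogProbAcc(acc.logp + lp)
toy_combine(a::ToyLogProbAcc, b::ToyLogProbAcc) = ToyLogProbAcc(a.logp + b.logp)

function threaded_logp(contributions::Vector{Float64}; nchunks::Int=4)
    chunksize = max(1, cld(length(contributions), nchunks))
    chunks = Iterators.partition(contributions, chunksize)
    # Each task folds over its own chunk with its own accumulator,
    # so no state is shared between tasks while they run.
    tasks = map(chunks) do chunk
        @spawn foldl(toy_accumulate, chunk; init=ToyLogProbAcc(0.0))
    end
    # Combine the per-task results once every task has finished.
    return foldl(toy_combine, fetch.(tasks); init=ToyLogProbAcc(0.0)).logp
end

total = threaded_logp([-1.0, -2.0, -3.0, -4.0, -5.0])
```

Because nothing is written to shared storage during the parallel part, the only requirement is that the accumulator has a sensible way to be combined.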

That means we can support generic tilde-statements inside models, regardless of whether they are assume or observe statements. There's a proof of principle in #1137.

The only thing holding this back is the need for a VarInfo at the very beginning of LDF generation. Creating a full VarInfo is of course not possible with threaded assume.

I would like to get rid of this, and instead use an accumulator to collect ranges and link status. This is quite easy to do:

# We can do the NamedTuple micro-optimisation here too.
struct RangeAndLinkedAcc <: AbstractAccumulator
    # This could be a VarName => Bool dict if we want to support mixed linking,
    # but truthfully I don't think there's a use case.
    should_be_linked::Bool
    current_index::Int # starts at 1
    d::Dict{VarName,RangeAndLinked}
end

accumulate_observe!!(acc::RangeAndLinkedAcc, dist, val, vn) = acc
function accumulate_assume!!(acc::RangeAndLinkedAcc, val, logjac, vn, dist)
    if acc.should_be_linked
        y = to_linked_vec_transform(dist)(val)
        range = acc.current_index : acc.current_index + length(y) - 1
        acc.d[vn] = RangeAndLinked(range, acc.should_be_linked)
        return RangeAndLinkedAcc(acc.should_be_linked, acc.current_index + length(y), acc.d)
    else
        # similar...
    end
end
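
For anyone who wants to poke at the bookkeeping without loading DynamicPPL, here is a self-contained toy version of the same idea. Everything in it is a stand-in: `ToyRangeAndLinked` and `ToyRangeAcc` are hypothetical types, `Symbol` replaces `VarName`, and the vectorised length is passed in directly as `veclen` instead of being computed from the distribution and value.

```julia
# Toy stand-ins: these are NOT the DynamicPPL types, just enough to run the idea.
struct ToyRangeAndLinked
    range::UnitRange{Int}
    is_linked::Bool
end

struct ToyRangeAcc
    should_be_linked::Bool
    current_index::Int                    # next free index, starts at 1
    d::Dict{Symbol,ToyRangeAndLinked}     # Symbol stands in for VarName
end

ToyRangeAcc(should_be_linked::Bool) =
    ToyRangeAcc(should_be_linked, 1, Dict{Symbol,ToyRangeAndLinked}())

# `veclen` stands in for the length of the (possibly linked) vectorised value.
function toy_accumulate_assume(acc::ToyRangeAcc, vn::Symbol, veclen::Int)
    range = acc.current_index:(acc.current_index + veclen - 1)
    acc.d[vn] = ToyRangeAndLinked(range, acc.should_be_linked)
    # Return a new accumulator whose index points past the range just assigned.
    return ToyRangeAcc(acc.should_be_linked, acc.current_index + veclen, acc.d)
end

acc = ToyRangeAcc(true)
acc = toy_accumulate_assume(acc, :mu, 1)     # mu    => 1:1
acc = toy_accumulate_assume(acc, :sigma, 1)  # sigma => 2:2
acc = toy_accumulate_assume(acc, :beta, 3)   # beta  => 3:5
```

Note that, as in the sketch above it, the dictionary is mutated in place while the index is carried in a fresh struct. That is fine for a single thread of execution but is exactly the part that needs care once several accumulators exist.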

We would need some clever way to combine several of these accumulators for ThreadSafeVarInfo (TSVI). But once that's done, it should be entirely possible to support threaded assume statements with any sampler that uses OAVI, i.e., Hamiltonian samplers.
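
One possible way to do the combining, sketched under assumptions: let each thread-local accumulator record ranges relative to its own start, then shift the second accumulator's ranges by the total length of the first when merging. `LocalRanges`, `record` and `combine_ranges` are hypothetical names, and `Symbol` again stands in for `VarName`. This is one candidate scheme, not a settled design.

```julia
# Hypothetical per-thread record: ranges are relative to this thread's own start.
struct LocalRanges
    total_length::Int
    ranges::Dict{Symbol,UnitRange{Int}}   # Symbol stands in for VarName
end

LocalRanges() = LocalRanges(0, Dict{Symbol,UnitRange{Int}}())

function record(lr::LocalRanges, vn::Symbol, veclen::Int)
    start = lr.total_length + 1
    lr.ranges[vn] = start:(start + veclen - 1)
    return LocalRanges(lr.total_length + veclen, lr.ranges)
end

# Merge `b` after `a`: every range in `b` is shifted by the length of `a`.
function combine_ranges(a::LocalRanges, b::LocalRanges)
    merged = copy(a.ranges)
    for (vn, r) in b.ranges
        haskey(merged, vn) && error("variable $vn was assumed on two threads")
        merged[vn] = (first(r) + a.total_length):(last(r) + a.total_length)
    end
    return LocalRanges(a.total_length + b.total_length, merged)
end

t1 = record(record(LocalRanges(), :a, 2), :b, 1)  # locally: a => 1:2, b => 3:3
t2 = record(LocalRanges(), :c, 3)                 # locally: c => 1:3
all_ranges = foldl(combine_ranges, [t1, t2])      # c is shifted to 4:6
```

The catch is that the final layout depends on the order in which the accumulators are combined, so that order would have to be deterministic and agree with however the parameter vector itself is assembled.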

* subject to the issue with threadid() indexing
