The new implementation of LogDensityFunction still requires a VarInfo to bootstrap it, because the ranges and link status are read from its metadata:
DynamicPPL.jl/src/logdensityfunction.jl, lines 157 to 161 in 766f663:

```julia
function LogDensityFunction(
    model::Model,
    getlogdensity::Function=getlogjoint_internal,
    varinfo::AbstractVarInfo=VarInfo(model);
    adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
```
DynamicPPL.jl/src/logdensityfunction.jl, lines 299 to 302 in 766f663:

```julia
function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms}
    all_iden_ranges = NamedTuple()
    all_ranges = Dict{VarName,RangeAndLinked}()
    offset = 1
```
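To make concrete what the "ranges and link status" are used for: the LDF receives a single flat parameter vector and needs to know which slice belongs to each variable, and whether that slice is in linked (unconstrained) space. Below is a toy sketch in plain Julia, not DynamicPPL code. `ToyRangeAndLinked`, `TOY_RANGES` and `unflatten_toy` are hypothetical stand-ins, and `exp` plays the role of the inverse link for the single positive-constrained variable in the toy model.

```julia
# Hypothetical stand-in for DynamicPPL's RangeAndLinked: which slice of the
# flat parameter vector belongs to a variable, and whether it is linked.
struct ToyRangeAndLinked
    range::UnitRange{Int}
    is_linked::Bool
end

# Toy model with two variables: a positive scalar `s` (linked via log, so the
# sampler sees log(s)) and an unconstrained 2-vector `m`.
const TOY_RANGES = Dict(
    :s => ToyRangeAndLinked(1:1, true),
    :m => ToyRangeAndLinked(2:3, false),
)

# Recover each variable's value from the flat vector θ using only the ranges
# and link status; no VarInfo is consulted here.
function unflatten_toy(θ::AbstractVector{<:Real}, ranges)
    vals = Dict{Symbol,Vector{Float64}}()
    for (name, ral) in ranges
        chunk = θ[ral.range]
        # In this toy, the only linked variable is `s`, whose inverse link is exp.
        vals[name] = ral.is_linked ? exp.(chunk) : collect(chunk)
    end
    return vals
end

vals = unflatten_toy([0.0, 1.5, -2.0], TOY_RANGES)
# vals[:s] == [1.0], vals[:m] == [1.5, -2.0]
```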
This is fine for now. However, it is the only bit of code in DynamicPPL that is holding us back from fully supporting threaded assume statements (yes!!).
LDF now uses something called OnlyAccsVarInfo (OAVI). Because OAVI contains only accumulators, and because wrapping any VarInfo in a ThreadSafeVarInfo (TSVI) makes its accumulators thread-safe*, it follows that ThreadSafeVarInfo{<:OnlyAccsVarInfo} is completely thread-safe.
That means we can support generic tilde-statements inside threaded code in models, regardless of whether they are assume or observe statements. There's a proof of principle in #1137.
The only thing holding this back is the need for a VarInfo at the very beginning of LDF construction. Creating a full VarInfo is, of course, not possible with threaded assume.
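The reason accumulators make thread safety tractable is that each thread can write to its own copy, and the copies are merged once at the end. Below is a plain-Julia sketch of that pattern for a log-density sum. It assumes nothing about DynamicPPL's internals: `chunked_logdensity` and `stdnormal_logpdf` are hypothetical names, and it keeps one accumulator per spawned task rather than indexing by `threadid()`, so it is an illustration of the idea, not how ThreadSafeVarInfo is implemented.

```julia
using Base.Threads: @spawn

# Standard normal log density, written out so the sketch has no dependencies.
stdnormal_logpdf(x) = -0.5 * x^2 - 0.5 * log(2π)

# Each task owns a private accumulator (`acc`), so no two tasks ever write to
# shared state; the per-task results are combined once all tasks have finished.
function chunked_logdensity(xs::AbstractVector{<:Real}; ntasks::Int = 4)
    isempty(xs) && return 0.0
    chunks = Iterators.partition(eachindex(xs), cld(length(xs), ntasks))
    tasks = map(chunks) do idxs
        @spawn begin
            acc = 0.0
            for i in idxs
                acc += stdnormal_logpdf(xs[i])
            end
            acc
        end
    end
    # Combining step, analogous to merging per-thread accumulators at the end.
    return sum(fetch, tasks)
end

xs = [0.0, 1.0, -1.0, 2.0, 0.5]
chunked_logdensity(xs) ≈ sum(stdnormal_logpdf, xs)  # true
```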
I would like to get rid of this, and instead use an accumulator to collect ranges and link status. This is quite easy to do:
```julia
# We can do the NamedTuple micro-optimisation thing here too.
# (This is meant to live inside DynamicPPL, where AbstractAccumulator, VarName,
# RangeAndLinked and the transform helpers below are already defined.)
struct RangeAndLinkedAcc <: AbstractAccumulator
    # This could be a VarName => Bool dict if we want to support mixed linking,
    # but truthfully I don't think there's a use case.
    should_be_linked::Bool
    current_index::Int  # starts at 1
    d::Dict{VarName,RangeAndLinked}
end

# Observed values don't occupy any part of the parameter vector.
accumulate_observe!!(acc::RangeAndLinkedAcc, dist, val, vn) = acc

function accumulate_assume!!(acc::RangeAndLinkedAcc, val, logjac, vn, dist)
    # Vectorise `val` in the space the sampler works in (linked or not).
    # Assumes `to_vec_transform(dist)` is the unlinked counterpart of
    # `to_linked_vec_transform(dist)`.
    y = if acc.should_be_linked
        to_linked_vec_transform(dist)(val)
    else
        to_vec_transform(dist)(val)
    end
    range = acc.current_index:(acc.current_index + length(y) - 1)
    acc.d[vn] = RangeAndLinked(range, acc.should_be_linked)
    return RangeAndLinkedAcc(acc.should_be_linked, acc.current_index + length(y), acc.d)
end
```

We would need some clever way to combine several of these accumulators for TSVI. But once that's done, it should be completely possible to support threaded assume statements with any sampler that uses OAVI, i.e., the Hamiltonian samplers.
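Why combining is not trivial: each thread's accumulator starts counting at index 1, so naively merging their dictionaries would produce overlapping ranges, and with threads the order in which assume statements run is not deterministic anyway. One possible approach, sketched in plain Julia as an assumption rather than anything DynamicPPL currently does, is to have each per-thread accumulator record only the length of each variable, and to assign ranges in a fixed order when combining. `ToyLengthAcc`, `record` and `combine_ranges` are hypothetical names, and `Symbol` keys stand in for `VarName`s.

```julia
# Hypothetical per-thread accumulator: records how many entries each variable
# occupies, without committing to an offset into the flat parameter vector.
struct ToyLengthAcc
    should_be_linked::Bool
    lengths::Dict{Symbol,Int}
end

ToyLengthAcc(linked::Bool) = ToyLengthAcc(linked, Dict{Symbol,Int}())

function record(acc::ToyLengthAcc, name::Symbol, len::Int)
    acc.lengths[name] = len
    return acc
end

# Merge the per-thread accumulators, then hand out contiguous ranges in a fixed
# order (sorted by name) so the layout does not depend on thread scheduling.
function combine_ranges(accs::AbstractVector{ToyLengthAcc})
    all_lengths = merge((a.lengths for a in accs)...)
    linked = first(accs).should_be_linked  # all accumulators share this setting
    ranges = Dict{Symbol,Tuple{UnitRange{Int},Bool}}()
    offset = 1
    for name in sort!(collect(keys(all_lengths)))
        len = all_lengths[name]
        ranges[name] = (offset:(offset + len - 1), linked)
        offset += len
    end
    return ranges
end

# Two "threads" that happened to encounter the variables in different orders.
acc1 = record(record(ToyLengthAcc(true), :b, 3), :a, 1)
acc2 = record(ToyLengthAcc(true), :c, 2)
ranges = combine_ranges([acc1, acc2])
# ranges[:a] == (1:1, true), ranges[:b] == (2:4, true), ranges[:c] == (5:6, true)
```

Sorting by name is just one deterministic choice; it does not preserve the order in which variables appear in the model, which a real implementation might prefer to keep.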
\* Subject to the issue with `threadid()` indexing.