diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 7ea51918f..4b6d3fb41 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -8,12 +8,8 @@ use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false function DynamicPPL.tweak_adtype( - ad::ADTypes.AutoForwardDiff{chunk_size}, - ::DynamicPPL.Model, - vi::DynamicPPL.AbstractVarInfo, + ad::ADTypes.AutoForwardDiff{chunk_size}, ::DynamicPPL.Model, params::AbstractVector ) where {chunk_size} - params = vi[:] - # Use DynamicPPL tag to improve stack traces # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ # NOTE: DifferentiationInterface disables tag checking if the diff --git a/src/chains.jl b/src/chains.jl index f176b8e68..4660a1a31 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -156,13 +156,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.LogDensityFunction, + ldf::DynamicPPL.LogDensityFunction{Tlink}, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, -) +) where {Tlink} strategy = InitFromParams( - VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + VectorWithRanges{Tlink}( + ldf._iden_varname_ranges, ldf._varname_ranges, param_vector + ), nothing, ) accs = if include_log_probs diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..b70bf2bf1 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -214,7 +214,7 @@ struct RangeAndLinked end """ - VectorWithRanges( + VectorWithRanges{Tlink}( iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, vect::AbstractVector{<:Real}, @@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to improve the NamedTuple and Dict approach. 
See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ -struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} +struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values vect::T + + function VectorWithRanges{Tlink}( + iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T + ) where {Tlink,N,T} + return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + end end function _get_range_and_linked( @@ -252,11 +258,15 @@ function init( ::Random.AbstractRNG, vn::VarName, dist::Distribution, - p::InitFromParams{<:VectorWithRanges}, -) + p::InitFromParams{<:VectorWithRanges{T}}, +) where {T} vr = p.params range_and_linked = _get_range_and_linked(vr, vn) - transform = if range_and_linked.is_linked + # T can either be `nothing` (i.e., link status is mixed, in which + # case we use the stored link status), or `true` / `false`, which + # indicates that all variables are linked / unlinked. + linked = isnothing(T) ? range_and_linked.is_linked : T + transform = if linked from_linked_vec_transform(dist) else from_vec_transform(dist) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 65eab448e..abfb61c94 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o `unflatten` + `evaluate!!` approach also fails with such models. 
""" struct LogDensityFunction{ + # true if all variables are linked; false if all variables are unlinked; nothing if + # mixed + Tlink, M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, @@ -154,30 +157,52 @@ struct LogDensityFunction{ _adprep::ADP _dim::Int + """ + function LogDensityFunction( + model::Model, + getlogdensity::Function=getlogjoint_internal, + link::Union{Bool,Set{VarName}}=false; + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + + Generate a `LogDensityFunction` for the given model. + + The `link` argument specifies which VarNames in the model should be linked. This can + either be a Bool (if `link=true` all variables are linked; if `link=false` all variables + are unlinked); or a `Set{VarName}` specifying exactly which variables should be linked. + Any sub-variables of the set's elements will be linked. + """ function LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); + link::Union{Bool,Set{VarName}}=false; adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) - # Figure out which variable corresponds to which index, and - # which variables are linked. - all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) - x = [val for val in varinfo[:]] + # Run the model once to determine variable ranges and linking. Because the + # parameters stored in the LogDensityFunction are never used, we can just use + # InitFromPrior to create new values. The actual values don't matter, only the + # length, since that's used for gradient prep. + vi = OnlyAccsVarInfo(AccumulatorTuple((RangeLinkedValueAcc(link),))) + _, vi = DynamicPPL.init!!(model, vi, InitFromPrior()) + rlvacc = first(vi.accs) + Tlink, all_iden_ranges, all_ranges, x = get_data(rlvacc) + @info Tlink, all_iden_ranges, all_ranges, x + # That gives us all the information we need to create the LogDensityFunction. 
dim = length(x) # Do AD prep if needed prep = if adtype === nothing nothing else # Make backend-specific tweaks to the adtype - adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + adtype = DynamicPPL.tweak_adtype(adtype, model, x) DI.prepare_gradient( - LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges), adtype, x, ) end return new{ + Tlink, typeof(model), typeof(adtype), typeof(getlogdensity), @@ -209,15 +234,24 @@ end fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} + + function LogDensityAt{Tlink}( + model::M, + getlogdensity::F, + iden_varname_ranges::N, + varname_ranges::Dict{VarName,RangeAndLinked}, + ) where {Tlink,M,F,N} + return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + end end -function (f::LogDensityAt)(params::AbstractVector{<:Real}) +function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing ) accs = fast_ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) @@ -225,9 +259,9 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) - return LogDensityAt( + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} + return LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges 
)( params @@ -235,10 +269,10 @@ function LogDensityProblems.logdensity( end function LogDensityProblems.logdensity_and_gradient( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} return DI.value_and_gradient( - LogDensityAt( + LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges ), ldf._adprep, @@ -247,12 +281,14 @@ function LogDensityProblems.logdensity_and_gradient( ) end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{T,M,Nothing}} +) where {T,M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} -) where {M} + ::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}} +) where {T,M} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.dimension(ldf::LogDensityFunction) @@ -263,7 +299,7 @@ end tweak_adtype( adtype::ADTypes.AbstractADType, model::Model, - varinfo::AbstractVarInfo, + params::AbstractVector ) Return an 'optimised' form of the adtype. This is useful for doing @@ -274,79 +310,108 @@ model. By default, this just returns the input unchanged. """ -tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVector) = adtype -###################################################### -# Helper functions to extract ranges and link status # -###################################################### - -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. 
So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. -""" - get_ranges_and_linked(varinfo::VarInfo) +############################## +# RangeLinkedVal accumulator # +############################## -Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter -representation, along with whether each variable is linked or unlinked. - -This function should return a tuple containing: +struct RangeLinkedValueAcc{L<:Union{Bool,Set{VarName}},N<:NamedTuple} <: AbstractAccumulator + should_link::L + current_index::Int + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} + values::Vector{Any} +end +function RangeLinkedValueAcc(should_link::Union{Bool,Set{VarName}}) + return RangeLinkedValueAcc(should_link, 1, (;), Dict{VarName,RangeAndLinked}(), Any[]) +end -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. 
-""" -function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) - all_ranges = merge(all_ranges, this_md_others) +function get_data(rlvacc::RangeLinkedValueAcc) + link_statuses = Bool[] + for ral in rlvacc.iden_varname_ranges + push!(link_statuses, ral.is_linked) end - return all_iden_ranges, all_ranges -end -function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others + for (_, ral) in rlvacc.varname_ranges + push!(link_statuses, ral.is_linked) + end + Tlink = if all(link_statuses) + true + elseif all(!s for s in link_statuses) + false + else + nothing + end + return ( + Tlink, rlvacc.iden_varname_ranges, rlvacc.varname_ranges, [v for v in rlvacc.values] + ) end -function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in md.idcs - is_linked = md.is_transformed[idx] - range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) + +accumulator_name(::Type{<:RangeLinkedValueAcc}) = :RangeLinkedValueAcc +accumulate_observe!!(acc::RangeLinkedValueAcc, dist, val, vn) = acc +function accumulate_assume!!( + acc::RangeLinkedValueAcc, val, logjac, vn::VarName{sym}, dist::Distribution +) where {sym} + link_this_vn = if acc.should_link isa Bool + acc.should_link + else + # 
Set{VarName} + any(should_link_vn -> subsumes(should_link_vn, vn), acc.should_link) + end + val = if link_this_vn + to_linked_vec_transform(dist)(val) + else + to_vec_transform(dist)(val) end - return all_iden_ranges, all_ranges, offset + new_values = vcat(acc.values, val) + len = length(val) + range = (acc.current_index):(acc.current_index + len - 1) + ral = RangeAndLinked(range, link_this_vn) + iden_varnames, other_varnames = if getoptic(vn) === identity + merge(acc.iden_varname_ranges, (sym => ral,)), acc.varname_ranges + else + acc.varname_ranges[vn] = ral + acc.iden_varname_ranges, acc.varname_ranges + end + return RangeLinkedValueAcc( + acc.should_link, acc.current_index + len, iden_varnames, other_varnames, new_values + ) end -function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in vnv.varname_to_index - is_linked = vnv.is_unconstrained[idx] - range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) +function Base.copy(acc::RangeLinkedValueAcc) + return RangeLinkedValueAcc( + acc.should_link, + acc.current_index, + acc.iden_varname_ranges, + copy(acc.varname_ranges), + copy(acc.values), + ) +end +_zero(acc::RangeLinkedValueAcc) = RangeLinkedValueAcc(acc.should_link) +reset(acc::RangeLinkedValueAcc) = _zero(acc) +split(acc::RangeLinkedValueAcc) = _zero(acc) +function combine(acc1::RangeLinkedValueAcc, acc2::RangeLinkedValueAcc) + new_values = vcat(acc1.values, acc2.values) + new_current_index = acc1.current_index + acc2.current_index - 1 + acc2_iden_varnames_shifted = NamedTuple( + k => RangeAndLinked((ral.range .+ (acc1.current_index - 1)), ral.is_linked) for + (k, ral) 
in pairs(acc2.iden_varname_ranges) + ) + new_iden_varname_ranges = merge(acc1.iden_varname_ranges, acc2_iden_varnames_shifted) + acc2_varname_ranges_shifted = Dict{VarName,RangeAndLinked}() + for (k, ral) in acc2.varname_ranges + acc2_varname_ranges_shifted[k] = RangeAndLinked( + (ral.range .+ (acc1.current_index - 1)), ral.is_linked + ) end - return all_iden_ranges, all_ranges, offset + new_varname_ranges = merge(acc1.varname_ranges, acc2_varname_ranges_shifted) + return RangeLinkedValueAcc( + # TODO: using acc1.should_link is not really 'correct', but `should_link` only + # affects model evaluation and `combine` only runs at the end of model evaluation, + # so it shouldn't matter + acc1.should_link, + new_current_index, + new_iden_varname_ranges, + new_varname_ranges, + new_values, + ) end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 44db66296..b58c3e7bc 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -14,16 +14,17 @@ using Test: @test, @testset @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() varinfo = VarInfo(model) + x = varinfo[:] @testset "Chunk size setting" for chunksize in (nothing, 0) base_adtype = AutoForwardDiff(; chunksize=chunksize) - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, x) @test new_adtype isa AutoForwardDiff{MODEL_SIZE} end @testset "Tag setting" begin base_adtype = AutoForwardDiff() - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, x) @test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} end end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index ea4ec497d..edfd67d18 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -5,11 +5,15 @@ using Test: @test, @testset import Enzyme: set_runtime_activity, 
Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test -ADTYPES = Dict( -    "EnzymeForward" => +ADTYPES = ( +    ( +        "EnzymeForward", AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), -    "EnzymeReverse" => +    ), +    ( +        "EnzymeReverse", AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), +    ), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 06492d6e1..f43ed45a4 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -108,6 +108,22 @@ end end end +@testset "LogDensityFunction: Type stability" begin +    @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS +        unlinked_vi = DynamicPPL.VarInfo(m) +        @testset "$islinked" for islinked in (false, true) +            vi = if islinked +                DynamicPPL.link!!(unlinked_vi, m) +            else +                unlinked_vi +            end +            ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, islinked) +            x = vi[:] +            @inferred LogDensityProblems.logdensity(ldf, x) +        end +    end +end + @testset "LogDensityFunction: performance" begin if Threads.nthreads() == 1 # Evaluating these three models should not lead to any allocations (but only when