diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 26ec35b65..54a302a6f 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -86,7 +86,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend) + f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index e7565d137..3c092c06b 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,7 +18,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model); + getlogdensity::Function=getlogjoint, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -28,9 +29,10 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with the -type of varinfo to be used. These must be known in order to calculate the log -density (using [`DynamicPPL.evaluate!!`](@ref)). +At its most basic level, a LogDensityFunction wraps the model together with a +function that specifies how to extract the log density, and the type of +VarInfo to be used. These must be known in order to calculate the log density +(using [`DynamicPPL.evaluate!!`](@ref)). If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -72,13 +74,13 @@ julia> LogDensityProblems.dimension(f) 1 julia> # By default it uses `VarInfo` under the hood, but this is not necessary. 
- f = LogDensityFunction(model, SimpleVarInfo(model)); + f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model)); julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 -julia> # LogDensityFunction respects the accumulators in VarInfo: - f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),))); +julia> # One can also specify evaluating e.g. the log prior only: + f_prior = LogDensityFunction(model, getlogprior); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -93,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) ``` """ struct LogDensityFunction{ - M<:Model,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} + M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} } <: AbstractModel "model used for evaluation" model::M - "varinfo used for evaluation" + "function to be called on `varinfo` to extract the log density. By default `getlogjoint`." + getlogdensity::F + "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." varinfo::V "AD type used for evaluation of log density gradient. 
If `nothing`, no gradient can be calculated" adtype::AD @@ -106,7 +110,8 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model); + getlogdensity::Function=getlogjoint, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) if adtype === nothing @@ -120,15 +125,22 @@ struct LogDensityFunction{ # Get a set of dummy params to use for prep x = map(identity, varinfo[:]) if use_closure(adtype) - prep = DI.prepare_gradient(LogDensityAt(model, varinfo), adtype, x) + prep = DI.prepare_gradient( + LogDensityAt(model, getlogdensity, varinfo), adtype, x + ) else prep = DI.prepare_gradient( - logdensity_at, adtype, x, DI.Constant(model), DI.Constant(varinfo) + logdensity_at, + adtype, + x, + DI.Constant(model), + DI.Constant(getlogdensity), + DI.Constant(varinfo), ) end end - return new{typeof(model),typeof(varinfo),typeof(adtype)}( - model, varinfo, adtype, prep + return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( + model, getlogdensity, varinfo, adtype, prep ) end end @@ -149,83 +161,112 @@ function LogDensityFunction( return if adtype === f.adtype f # Avoid recomputing prep if not needed else - LogDensityFunction(f.model, f.varinfo; adtype=adtype) + LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) end end +""" + ldf_default_varinfo(model::Model, getlogdensity::Function) + +Create the default AbstractVarInfo that should be used for evaluating the log density. + +Only the accumulators necessary for `getlogdensity` will be used. +""" +function ldf_default_varinfo(::Model, getlogdensity::Function) + msg = """ + LogDensityFunction does not know what sort of VarInfo should be used when \ + `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. 
+ """ + return error(msg) +end + +ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model) + +function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) +end + +function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) + return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) +end + """ logdensity_at( x::AbstractVector, model::Model, + getlogdensity::Function, varinfo::AbstractVarInfo, ) -Evaluate the log density of the given `model` at the given parameter values `x`, -using the given `varinfo`. Note that the `varinfo` argument is provided only -for its structure, in the sense that the parameters from the vector `x` are -inserted into it, and its own parameters are discarded. It does, however, -determine whether the log prior, likelihood, or joint is returned, based on -which accumulators are set in it. +Evaluate the log density of the given `model` at the given parameter values +`x`, using the given `varinfo`. Note that the `varinfo` argument is provided +only for its structure, in the sense that the parameters from the vector `x` +are inserted into it, and its own parameters are discarded. `getlogdensity` is +the function that extracts the log density from the evaluated varinfo. 
""" -function logdensity_at(x::AbstractVector, model::Model, varinfo::AbstractVarInfo) +function logdensity_at( + x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo +) varinfo_new = unflatten(varinfo, x) varinfo_eval = last(evaluate!!(model, varinfo_new)) - has_prior = hasacc(varinfo_eval, Val(:LogPrior)) - has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) - if has_prior && has_likelihood - return getlogjoint(varinfo_eval) - elseif has_prior - return getlogprior(varinfo_eval) - elseif has_likelihood - return getloglikelihood(varinfo_eval) - else - error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") - end + return getlogdensity(varinfo_eval) end """ - LogDensityAt{M<:Model,V<:AbstractVarInfo}( + LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( model::M + getlogdensity::F, varinfo::V ) A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -varinfo)`. +getlogdensity, varinfo)`. """ -struct LogDensityAt{M<:Model,V<:AbstractVarInfo} +struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} model::M + getlogdensity::F varinfo::V end -(ld::LogDensityAt)(x::AbstractVector) = logdensity_at(x, ld.model, ld.varinfo) +function (ld::LogDensityAt)(x::AbstractVector) + return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) +end ### LogDensityProblems interface function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,Nothing}} -) where {M,V} + ::Type{<:LogDensityFunction{M,F,V,Nothing}} +) where {M,F,V} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,AD}} -) where {M,V,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,F,V,AD}} +) where {M,F,V,AD<:ADTypes.AbstractADType} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.varinfo) + return 
logdensity_at(x, f.model, f.getlogdensity, f.varinfo) end function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,V,AD}, x::AbstractVector -) where {M,V,AD<:ADTypes.AbstractADType} + f::LogDensityFunction{M,F,V,AD}, x::AbstractVector +) where {M,F,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") x = map(identity, x) # Concretise type # Make branching statically inferrable, i.e. type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) - DI.value_and_gradient(LogDensityAt(f.model, f.varinfo), f.prep, f.adtype, x) + DI.value_and_gradient( + LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x + ) else DI.value_and_gradient( - logdensity_at, f.prep, f.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo) + logdensity_at, + f.prep, + f.adtype, + x, + DI.Constant(f.model), + DI.Constant(f.getlogdensity), + DI.Constant(f.varinfo), ) end end @@ -264,9 +305,9 @@ There are two ways of dealing with this: 1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) -2. Use a constant context. This lets us pass a two-argument function to - DifferentiationInterface, as long as we also give it the 'inactive argument' - (i.e. the model) wrapped in `DI.Constant`. +2. Use a constant DI.Context. This lets us pass a two-argument function to DI, + as long as we also give it the 'inactive argument' (i.e. the model) wrapped + in `DI.Constant`. The relative performance of the two approaches, however, depends on the AD backend used. Some benchmarks are provided here: @@ -292,7 +333,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. 
""" function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.varinfo; adtype=f.adtype) + return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) end """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ea371c7da..5dabcc361 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -613,7 +613,9 @@ function link!!( x = vi.values y, logjac = with_logabsdet_jacobian(b, x) vi_new = Accessors.@set(vi.values = y) - vi_new = acclogprior!!(vi_new, -logjac) + if hasacc(vi_new, Val(:LogPrior)) + vi_new = acclogprior!!(vi_new, -logjac) + end return settrans!!(vi_new, t) end @@ -626,7 +628,9 @@ function invlink!!( y = vi.values x, logjac = with_logabsdet_jacobian(b, y) vi_new = Accessors.@set(vi.values = x) - vi_new = acclogprior!!(vi_new, logjac) + if hasacc(vi_new, Val(:LogPrior)) + vi_new = acclogprior!!(vi_new, logjac) + end return settrans!!(vi_new, NoTransformation()) end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 5285391b1..cbcd44d49 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,14 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: - Model, - LogDensityFunction, - VarInfo, - AbstractVarInfo, - link, - DefaultContext, - AbstractContext +using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: Random, Xoshiro using Statistics: median @@ -58,6 +51,8 @@ $(TYPEDFIELDS) struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} "The DynamicPPL model that was tested" model::Model + "The function used to extract the log density from the model" + getlogdensity::Function "The VarInfo that was used" varinfo::AbstractVarInfo "The values at which the model was evaluated" @@ -184,6 +179,7 @@ function run_ad( 
benchmark::Bool=false, value_atol::AbstractFloat=1e-6, grad_atol::AbstractFloat=1e-6, + getlogdensity::Function=getlogjoint, varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, reference_adtype::AbstractADType=REFERENCE_ADTYPE, @@ -197,7 +193,7 @@ function run_ad( verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, varinfo; adtype=adtype) + ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) grad = collect(grad) @@ -206,7 +202,9 @@ function run_ad( if test # Calculate ground truth to compare against value_true, grad_true = if expected_value_and_grad === nothing - ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype) + ldf_reference = LogDensityFunction( + model, getlogdensity, varinfo; adtype=reference_adtype + ) logdensity_and_gradient(ldf_reference, params) else expected_value_and_grad @@ -234,6 +232,7 @@ function run_ad( return ADResult( model, + getlogdensity, varinfo, params, adtype, diff --git a/src/varinfo.jl b/src/varinfo.jl index b3380e7f9..71364a854 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1154,7 +1154,9 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. 
setval!(md, yvec, vn) - vi = acclogprior!!(vi, -logjac) + if hasacc(vi, Val(:LogPrior)) + vi = acclogprior!!(vi, -logjac) + end return vi end @@ -1191,7 +1193,9 @@ function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1205,7 +1209,9 @@ function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1354,7 +1360,9 @@ function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1368,7 +1376,9 @@ function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end diff --git a/src/varname.jl b/src/varname.jl index c16587065..3eb1f2460 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -7,7 +7,7 @@ This is a very restricted version `subumes(u::VarName, v::VarName)` only really - Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 
2][3]`, etc. ## Note -- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` +- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`, and similarly to `v`. But this is slow. """ diff --git a/test/ad.jl b/test/ad.jl index 0947c017a..510da82c2 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -29,10 +29,12 @@ using DynamicPPL: LogDensityFunction @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, linked_varinfo) + f = LogDensityFunction(m, getlogjoint, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype) + ref_ldf = LogDensityFunction( + m, getlogjoint, linked_varinfo; adtype=ref_adtype + ) ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) @testset "$adtype" for adtype in test_adtypes @@ -109,9 +111,10 @@ using DynamicPPL: LogDensityFunction # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) - vi = VarInfo(model) sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction(sampling_model, vi; adtype=AutoReverseDiff(; compile=true)) + ldf = LogDensityFunction( + sampling_model, getlogjoint; adtype=AutoReverseDiff(; compile=true) + ) @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index d6e66ec59..c4d0d6beb 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -15,8 +15,18 @@ end vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + vi = first(varinfos) + theta = vi[:] + ldf_joint = DynamicPPL.LogDensityFunction(model) + @test 
LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) + ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) + @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) + ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) + @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ + loglikelihood(model, vi) + @testset "$(varinfo)" for varinfo in varinfos - logdensity = DynamicPPL.LogDensityFunction(model, varinfo) + logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) θ = varinfo[:] @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ)