Skip to content

Accumulators stage 2 #925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: breaking
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
@@ -87,7 +87,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
vi = DynamicPPL.link(vi, model)
end

f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
f = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint, vi, context; adtype=adbackend
)
# The parameters at which we evaluate f.
θ = vi[:]

113 changes: 76 additions & 37 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
@@ -17,7 +17,8 @@
"""
LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity),
context::AbstractContext=DefaultContext();
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
)
@@ -28,10 +29,10 @@
- and if `adtype` is provided, calculate the gradient of the log density at
that point.

At its most basic level, a LogDensityFunction wraps the model together with its
the type of varinfo to be used, as well as the evaluation context. These must
be known in order to calculate the log density (using
[`DynamicPPL.evaluate!!`](@ref)).
At its most basic level, a LogDensityFunction wraps the model together with
the type of varinfo to be used, as well as the evaluation context and a function
to extract the log density from the VarInfo. These must be known in order to
calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).

If the `adtype` keyword argument is provided, then this struct will also store
the adtype along with other information for efficient calculation of the
@@ -73,13 +74,13 @@
1

julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
f = LogDensityFunction(model, SimpleVarInfo(model));
f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));

julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> # LogDensityFunction respects the accumulators in VarInfo:
f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
julia> # One can also specify evaluating e.g. the log prior only:
f_prior = LogDensityFunction(model, getlogprior);

julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true
@@ -94,11 +95,17 @@
```
"""
struct LogDensityFunction{
M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
M<:Model,
F<:Function,
V<:AbstractVarInfo,
C<:AbstractContext,
AD<:Union{Nothing,ADTypes.AbstractADType},
}
"model used for evaluation"
model::M
"varinfo used for evaluation"
"function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
getlogdensity::F
"varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
varinfo::V
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
context::C
@@ -109,7 +116,8 @@

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity),
context::AbstractContext=leafcontext(model.context);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
@@ -125,21 +133,28 @@
x = map(identity, varinfo[:])
if use_closure(adtype)
prep = DI.prepare_gradient(
x -> logdensity_at(x, model, varinfo, context), adtype, x
x -> logdensity_at(x, model, getlogdensity, varinfo, context), adtype, x
)
else
prep = DI.prepare_gradient(
logdensity_at,
adtype,
x,
DI.Constant(model),
DI.Constant(getlogdensity),
DI.Constant(varinfo),
DI.Constant(context),
)
end
end
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
model, varinfo, context, adtype, prep
return new{
typeof(model),
typeof(getlogdensity),
typeof(varinfo),
typeof(context),
typeof(adtype),
}(
model, getlogdensity, varinfo, context, adtype, prep
)
end
end
@@ -164,64 +179,87 @@
end
end

"""
ldf_default_varinfo(model::Model, getlogdensity::Function)

Create the default AbstractVarInfo that should be used for evaluating the log density.

Only the accumulators necessary for `getlogdensity` will be used.
"""
function ldf_default_varinfo(::Model, getlogdensity::Function)
msg = """

Check warning on line 190 in src/logdensityfunction.jl

Codecov / codecov/patch

src/logdensityfunction.jl#L189-L190

Added lines #L189 - L190 were not covered by tests
LogDensityFunction does not know what sort of VarInfo should be used when \
`getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
"""
return error(msg)

Check warning on line 194 in src/logdensityfunction.jl

Codecov / codecov/patch

src/logdensityfunction.jl#L194

Added line #L194 was not covered by tests
end

# Method selected when `getlogdensity` is `getlogjoint`: returns `VarInfo(model)` as-is.
function ldf_default_varinfo(model::Model, ::typeof(getlogjoint))
    return VarInfo(model)
end

# Method selected when `getlogdensity` is `getlogprior`: builds `VarInfo(model)` and hands
# it to `setaccs!!` together with a one-element tuple holding a `LogPriorAccumulator`.
function ldf_default_varinfo(model::Model, ::typeof(getlogprior))
    varinfo = VarInfo(model)
    accs = (LogPriorAccumulator(),)
    return setaccs!!(varinfo, accs)
end

# Method selected when `getlogdensity` is `getloglikelihood`: builds `VarInfo(model)` and
# hands it to `setaccs!!` together with a one-element tuple holding a
# `LogLikelihoodAccumulator`.
function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood))
    varinfo = VarInfo(model)
    accs = (LogLikelihoodAccumulator(),)
    return setaccs!!(varinfo, accs)
end

"""
logdensity_at(
x::AbstractVector,
model::Model,
getlogdensity::Function,
varinfo::AbstractVarInfo,
context::AbstractContext
)

Evaluate the log density of the given `model` at the given parameter values `x`,
using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
only for its structure, in the sense that the parameters from the vector `x` are inserted
into it, and its own parameters are discarded. It does, however, determine whether the log
prior, likelihood, or joint is returned, based on which accumulators are set in it.
into it, and its own parameters are discarded. `getlogdensity` is the function that extracts
the log density from the evaluated varinfo.
"""
function logdensity_at(
x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
x::AbstractVector,
model::Model,
getlogdensity::Function,
varinfo::AbstractVarInfo,
context::AbstractContext,
)
varinfo_new = unflatten(varinfo, x)
varinfo_eval = last(evaluate!!(model, varinfo_new, context))
has_prior = hasacc(varinfo_eval, Val(:LogPrior))
has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood))
if has_prior && has_likelihood
return getlogjoint(varinfo_eval)
elseif has_prior
return getlogprior(varinfo_eval)
elseif has_likelihood
return getloglikelihood(varinfo_eval)
else
error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood")
end
return getlogdensity(varinfo_eval)
end

### LogDensityProblems interface

function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,C,Nothing}}
) where {M,V,C}
::Type{<:LogDensityFunction{M,F,V,C,Nothing}}
) where {M,F,V,C}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,C,AD}}
) where {M,V,C,AD<:ADTypes.AbstractADType}
::Type{<:LogDensityFunction{M,F,V,C,AD}}
) where {M,F,V,C,AD<:ADTypes.AbstractADType}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
return logdensity_at(x, f.model, f.varinfo, f.context)
return logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context)
end
function LogDensityProblems.logdensity_and_gradient(
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
) where {M,V,C,AD<:ADTypes.AbstractADType}
f::LogDensityFunction{M,F,V,C,AD}, x::AbstractVector
) where {M,F,V,C,AD<:ADTypes.AbstractADType}
f.prep === nothing &&
error("Gradient preparation not available; this should not happen")
x = map(identity, x) # Concretise type
# Make branching statically inferrable, i.e. type-stable (even if the two
# branches happen to return different types)
return if use_closure(f.adtype)
DI.value_and_gradient(
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context),

Check warning on line 259 in src/logdensityfunction.jl

Codecov / codecov/patch

src/logdensityfunction.jl#L259

Added line #L259 was not covered by tests
f.prep,
f.adtype,
x,
)
else
DI.value_and_gradient(
@@ -230,6 +268,7 @@
f.adtype,
x,
DI.Constant(f.model),
DI.Constant(f.getlogdensity),
DI.Constant(f.varinfo),
DI.Constant(f.context),
)
@@ -304,7 +343,7 @@
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
"""
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
return LogDensityFunction(model, f.getlogdensity, f.varinfo, f.context; adtype=f.adtype)
end

"""
8 changes: 6 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
@@ -606,7 +606,9 @@
x = vi.values
y, logjac = with_logabsdet_jacobian(b, x)
vi_new = Accessors.@set(vi.values = y)
vi_new = acclogprior!!(vi_new, -logjac)
if hasacc(vi_new, Val(:LogPrior))
vi_new = acclogprior!!(vi_new, -logjac)

Check warning on line 610 in src/simple_varinfo.jl

Codecov / codecov/patch

src/simple_varinfo.jl#L609-L610

Added lines #L609 - L610 were not covered by tests
end
return settrans!!(vi_new, t)
end

@@ -619,7 +621,9 @@
y = vi.values
x, logjac = with_logabsdet_jacobian(b, y)
vi_new = Accessors.@set(vi.values = x)
vi_new = acclogprior!!(vi_new, logjac)
if hasacc(vi_new, Val(:LogPrior))
vi_new = acclogprior!!(vi_new, logjac)

Check warning on line 625 in src/simple_varinfo.jl

Codecov / codecov/patch

src/simple_varinfo.jl#L624-L625

Added lines #L624 - L625 were not covered by tests
end
return settrans!!(vi_new, NoTransformation())
end

8 changes: 5 additions & 3 deletions src/test_utils/ad.jl
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
@@ -184,7 +184,7 @@

verbose && @info "Running AD on $(model.f) with $(adtype)\n"
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model, varinfo; adtype=adtype)
ldf = LogDensityFunction(model, getlogjoint, varinfo; adtype=adtype)

Check warning on line 187 in src/test_utils/ad.jl

Codecov / codecov/patch

src/test_utils/ad.jl#L187

Added line #L187 was not covered by tests

value, grad = logdensity_and_gradient(ldf, params)
grad = collect(grad)
@@ -193,7 +193,9 @@
if test
# Calculate ground truth to compare against
value_true, grad_true = if expected_value_and_grad === nothing
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
ldf_reference = LogDensityFunction(

Check warning on line 196 in src/test_utils/ad.jl

Codecov / codecov/patch

src/test_utils/ad.jl#L196

Added line #L196 was not covered by tests
model, getlogjoint, varinfo; adtype=reference_adtype
)
logdensity_and_gradient(ldf_reference, params)
else
expected_value_and_grad
20 changes: 15 additions & 5 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
@@ -1241,7 +1241,9 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f)
setrange!(md, vn, start:(start + length(yvec) - 1))
# Set the new value.
setval!(md, yvec, vn)
vi = acclogprior!!(vi, -logjac)
if hasacc(vi, Val(:LogPrior))
vi = acclogprior!!(vi, -logjac)
end
return vi
end

@@ -1278,7 +1280,9 @@ function _link(model::Model, varinfo::VarInfo, vns)
varinfo = deepcopy(varinfo)
md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns)
new_varinfo = VarInfo(md, varinfo.accs)
new_varinfo = acclogprior!!(new_varinfo, -logjac)
if hasacc(new_varinfo, Val(:LogPrior))
new_varinfo = acclogprior!!(new_varinfo, -logjac)
end
return new_varinfo
end

@@ -1292,7 +1296,9 @@ function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
varinfo = deepcopy(varinfo)
md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns)
new_varinfo = VarInfo(md, varinfo.accs)
new_varinfo = acclogprior!!(new_varinfo, -logjac)
if hasacc(new_varinfo, Val(:LogPrior))
new_varinfo = acclogprior!!(new_varinfo, -logjac)
end
return new_varinfo
end

@@ -1441,7 +1447,9 @@ function _invlink(model::Model, varinfo::VarInfo, vns)
varinfo = deepcopy(varinfo)
md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns)
new_varinfo = VarInfo(md, varinfo.accs)
new_varinfo = acclogprior!!(new_varinfo, -logjac)
if hasacc(new_varinfo, Val(:LogPrior))
new_varinfo = acclogprior!!(new_varinfo, -logjac)
end
return new_varinfo
end

@@ -1455,7 +1463,9 @@ function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
varinfo = deepcopy(varinfo)
md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns)
new_varinfo = VarInfo(md, varinfo.accs)
new_varinfo = acclogprior!!(new_varinfo, -logjac)
if hasacc(new_varinfo, Val(:LogPrior))
new_varinfo = acclogprior!!(new_varinfo, -logjac)
end
return new_varinfo
end

2 changes: 1 addition & 1 deletion src/varname.jl
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ This is a very restricted version `subumes(u::VarName, v::VarName)` only really
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.

## Note
- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)`
- To get the same matching capabilities as `AbstractPPL.subsumes(u::VarName, v::VarName)`
for strings, one can always do `eval(varname(Meta.parse(u)))` to get the `VarName` of `u`,
and similarly for `v`. But this is slow.
"""
12 changes: 9 additions & 3 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -24,10 +24,12 @@ using DynamicPPL: LogDensityFunction

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
linked_varinfo = DynamicPPL.link(varinfo, m)
f = LogDensityFunction(m, linked_varinfo)
f = LogDensityFunction(m, getlogjoint, linked_varinfo)
x = DynamicPPL.getparams(f)
# Calculate reference logp + gradient of logp using ForwardDiff
ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
ref_ldf = LogDensityFunction(
m, getlogjoint, linked_varinfo; adtype=ref_adtype
)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)

@testset "$adtype" for adtype in test_adtypes
@@ -106,7 +108,11 @@ using DynamicPPL: LogDensityFunction
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = LogDensityFunction(
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
model,
getlogjoint,
vi,
SamplingContext(spl);
adtype=AutoReverseDiff(; compile=true),
)
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
end
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.