|
| 1 | +module DynamicPPLDifferentiationInterfaceTestExt |
| 2 | + |
| 3 | +using DynamicPPL: |
| 4 | + DynamicPPL, |
| 5 | + ADTypes, |
| 6 | + LogDensityProblems, |
| 7 | + Model, |
| 8 | + DI, # DifferentiationInterface |
| 9 | + AbstractVarInfo, |
| 10 | + VarInfo, |
| 11 | + LogDensityFunction |
| 12 | +import DifferentiationInterfaceTest as DIT |
| 13 | + |
"""
    REFERENCE_ADTYPE

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.

Used by `make_scenario` to compute the ground-truth gradient when the caller
does not supply `expected_grad` explicitly.
"""
const REFERENCE_ADTYPE = ADTypes.AutoForwardDiff()
| 21 | + |
"""
    make_scenario(
        model::Model,
        adtype::ADTypes.AbstractADType;
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
    )

Construct a DifferentiationInterfaceTest.Scenario for the given `model` and `adtype`.

# Keywords
- `varinfo`: the `AbstractVarInfo` used to evaluate the model's log density.
- `params`: the parameter vector at which the gradient is taken (defaults to
  the values stored in `varinfo`).
- `reference_adtype`: AD backend used to compute the ground-truth gradient
  when `expected_grad` is not given.
- `expected_grad`: if provided, used directly as the expected gradient instead
  of computing one with `reference_adtype`.
"""
function make_scenario(
    model::Model,
    adtype::ADTypes.AbstractADType;
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
    expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
)
    # Narrow the element type of `params` (e.g. Vector{Real} -> Vector{Float64}).
    params = map(identity, params)
    context = DynamicPPL.DefaultContext()
    # NOTE(review): presumably adjusts the backend for this model/varinfo
    # combination — confirm against DynamicPPL's `tweak_adtype` docs.
    adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context)
    if DynamicPPL.use_closure(adtype)
        # Differentiate a closure that captures model/varinfo/context.
        f = x -> DynamicPPL.logdensity_at(x, model, varinfo, context)
        di_contexts = ()
    else
        # Differentiate the raw function, passing the extra arguments to DI
        # as non-differentiated constants.
        f = DynamicPPL.logdensity_at
        di_contexts = (DI.Constant(model), DI.Constant(varinfo), DI.Constant(context))
    end

    # Calculate ground truth to compare against
    grad_true = if expected_grad === nothing
        ldf_reference = LogDensityFunction(model; adtype=reference_adtype)
        LogDensityProblems.logdensity_and_gradient(ldf_reference, params)[2]
    else
        expected_grad
    end

    return DIT.Scenario{:gradient,:out}(
        f, params; contexts=di_contexts, res1=grad_true, name="$(model.f)"
    )
end
| 67 | + |
"""
    DynamicPPL.TestUtils.AD.run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
        kwargs...,
    )

Test AD on `model` with `adtype` by building a
DifferentiationInterfaceTest.Scenario (see [`make_scenario`](@ref)) and running
`DIT.test_differentiation` on it. Remaining `kwargs` are forwarded to
`DIT.test_differentiation`.
"""
function DynamicPPL.TestUtils.AD.run_ad(
    model::Model,
    adtype::ADTypes.AbstractADType;
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
    expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
    kwargs...,
)
    # Forward *all* user-supplied options to make_scenario. Previously,
    # `params` and `reference_adtype` were accepted here but never passed
    # through, so caller-provided values were silently ignored.
    scen = make_scenario(
        model,
        adtype;
        varinfo=varinfo,
        params=params,
        reference_adtype=reference_adtype,
        expected_grad=expected_grad,
    )
    # make_scenario tweaks the adtype internally before building the closure;
    # apply the same tweak here so test_differentiation runs the scenario with
    # the backend the scenario was actually built for.
    tweaked_adtype = DynamicPPL.tweak_adtype(
        adtype, model, varinfo, DynamicPPL.DefaultContext()
    )
    return DIT.test_differentiation(
        tweaked_adtype, [scen]; scenario_intact=false, kwargs...
    )
end
| 85 | + |
| 86 | +end |
0 commit comments