Skip to content

Commit a197dd8

Browse files
authored
Merge 70e1aa9 into eed80e5
2 parents eed80e5 + 70e1aa9 commit a197dd8

File tree

5 files changed

+99
-5
lines changed

5 files changed

+99
-5
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2626

2727
[weakdeps]
2828
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
29+
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
2930
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3031
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3132
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
@@ -35,6 +36,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3536

3637
[extensions]
3738
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
39+
DynamicPPLDifferentiationInterfaceTestExt = ["DifferentiationInterfaceTest"]
3840
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3941
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4042
DynamicPPLJETExt = ["JET"]
@@ -52,6 +54,7 @@ ChainRulesCore = "1"
5254
Compat = "4"
5355
ConstructionBase = "1.5.4"
5456
DifferentiationInterface = "0.6.41"
57+
DifferentiationInterfaceTest = "0.9.6"
5558
Distributions = "0.25"
5659
DocStringExtensions = "0.9"
5760
EnzymeCore = "0.6 - 0.8"
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
module DynamicPPLDifferentiationInterfaceTestExt
2+
3+
using DynamicPPL:
4+
DynamicPPL,
5+
ADTypes,
6+
LogDensityProblems,
7+
Model,
8+
DI, # DifferentiationInterface
9+
AbstractVarInfo,
10+
VarInfo,
11+
LogDensityFunction
12+
import DifferentiationInterfaceTest as DIT
13+
14+
"""
15+
REFERENCE_ADTYPE
16+
17+
Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
18+
it's the default AD backend used in Turing.jl.
19+
"""
20+
const REFERENCE_ADTYPE = ADTypes.AutoForwardDiff()
21+
22+
"""
23+
make_scenario(
24+
model::Model,
25+
adtype::ADTypes.AbstractADType,
26+
varinfo::AbstractVarInfo=VarInfo(model),
27+
params::Vector{<:Real}=varinfo[:],
28+
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
29+
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
30+
)
31+
32+
Construct a DifferentiationInterfaceTest.Scenario for the given `model` and `adtype`.
33+
34+
More docs to follow.
35+
"""
36+
function make_scenario(
37+
model::Model,
38+
adtype::ADTypes.AbstractADType;
39+
varinfo::AbstractVarInfo=VarInfo(model),
40+
params::Vector{<:Real}=varinfo[:],
41+
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
42+
expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
43+
)
44+
params = map(identity, params)
45+
context = DynamicPPL.DefaultContext()
46+
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context)
47+
if DynamicPPL.use_closure(adtype)
48+
f = x -> DynamicPPL.logdensity_at(x, model, varinfo, context)
49+
di_contexts = ()
50+
else
51+
f = DynamicPPL.logdensity_at
52+
di_contexts = (DI.Constant(model), DI.Constant(varinfo), DI.Constant(context))
53+
end
54+
55+
# Calculate ground truth to compare against
56+
grad_true = if expected_grad === nothing
57+
ldf_reference = LogDensityFunction(model; adtype=reference_adtype)
58+
LogDensityProblems.logdensity_and_gradient(ldf_reference, params)[2]
59+
else
60+
expected_grad
61+
end
62+
63+
return DIT.Scenario{:gradient,:out}(
64+
f, params; contexts=di_contexts, res1=grad_true, name="$(model.f)"
65+
)
66+
end
67+
68+
function DynamicPPL.TestUtils.AD.run_ad(
69+
model::Model,
70+
adtype::ADTypes.AbstractADType;
71+
varinfo::AbstractVarInfo=VarInfo(model),
72+
params::Vector{<:Real}=varinfo[:],
73+
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
74+
expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
75+
kwargs...,
76+
)
77+
scen = make_scenario(model, adtype; varinfo=varinfo, expected_grad=expected_grad)
78+
tweaked_adtype = DynamicPPL.tweak_adtype(
79+
adtype, model, varinfo, DynamicPPL.DefaultContext()
80+
)
81+
return DIT.test_differentiation(
82+
tweaked_adtype, [scen]; scenario_intact=false, kwargs...
83+
)
84+
end
85+
86+
end

src/test_utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,8 @@ include("test_utils/contexts.jl")
1919
include("test_utils/varinfo.jl")
2020
include("test_utils/sampler.jl")
2121

22+
module AD
23+
function run_ad end
24+
end
25+
2226
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
88
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
99
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1010
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
11+
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
1112
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1213
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1314
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"

test/ad.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DynamicPPL: LogDensityFunction
2+
import DifferentiationInterfaceTest as DIT
23

34
@testset "Automatic differentiation" begin
45
# Used as the ground truth that others are compared against.
@@ -27,7 +28,7 @@ using DynamicPPL: LogDensityFunction
2728
x = DynamicPPL.getparams(f)
2829
# Calculate reference logp + gradient of logp using ForwardDiff
2930
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
30-
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
31+
ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)[2]
3132

3233
@testset "$adtype" for adtype in test_adtypes
3334
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
@@ -56,10 +57,9 @@ using DynamicPPL: LogDensityFunction
5657
ref_ldf, adtype
5758
)
5859
else
59-
ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
60-
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
61-
@test grad ref_grad
62-
@test logp ref_logp
60+
DynamicPPL.TestUtils.AD.run_ad(
61+
m, adtype; varinfo=varinfo, expected_grad=ref_grad
62+
)
6363
end
6464
end
6565
end

0 commit comments

Comments
 (0)