diff --git a/test/Project.toml b/test/Project.toml index 2dbd5b455..d5988119a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,6 +19,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/ad.jl b/test/ad.jl index d7505aab2..0236c232f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -5,15 +5,11 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest # Used as the ground truth that others are compared against. ref_adtype = AutoForwardDiff() - test_adtypes = if MOONCAKE_SUPPORTED - [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - else - [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] - end + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] @testset "Unsupported backends" begin @model demo() = x ~ Normal() @@ -43,13 +39,13 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest # Put predicates here to avoid long lines is_mooncake = adtype isa AutoMooncake is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11 = v"1.11" <= VERSION < v"1.12" + is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13" is_svi_vnv = linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} # Mooncake doesn't work with several combinations of SimpleVarInfo. - if is_mooncake && is_1_11 && is_svi_vnv + if is_mooncake && is_1_11_or_1_12 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 @test_throws ArgumentError DynamicPPL.LogDensityFunction( m, getlogjoint_internal, linked_varinfo; adtype=adtype diff --git a/test/runtests.jl b/test/runtests.jl index 861d3bb87..5e40635e6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ using MacroTools using MCMCChains using StableRNGs using ReverseDiff +using Mooncake using Zygote using Distributed @@ -37,13 +38,6 @@ using DynamicPPL: getargs_dottilde, getargs_tilde const GROUP = get(ENV, "GROUP", "All") const AQUA = get(ENV, "AQUA", "true") == "true" -# Skip Mooncake if it doesn't work -const MOONCAKE_SUPPORTED = VERSION < v"1.12.0" -if MOONCAKE_SUPPORTED - Pkg.add("Mooncake") - using Mooncake: Mooncake -end - Random.seed!(100) include("test_util.jl") @@ -85,9 +79,7 @@ include("test_util.jl") end @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") - if MOONCAKE_SUPPORTED - include("ext/DynamicPPLMooncakeExt.jl") - end + include("ext/DynamicPPLMooncakeExt.jl") include("ad.jl") end @testset "prob and logprob macro" begin