Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
18 changes: 7 additions & 11 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
# Used as the ground truth that others are compared against.
ref_adtype = AutoForwardDiff()

test_adtypes = if MOONCAKE_SUPPORTED
[
AutoReverseDiff(; compile=false),
AutoReverseDiff(; compile=true),
AutoMooncake(; config=nothing),
]
else
[AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)]
end
test_adtypes = [
AutoReverseDiff(; compile=false),
AutoReverseDiff(; compile=true),
AutoMooncake(; config=nothing),
]

@testset "Unsupported backends" begin
@model demo() = x ~ Normal()
Expand Down Expand Up @@ -43,13 +39,13 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
# Put predicates here to avoid long lines
is_mooncake = adtype isa AutoMooncake
is_1_10 = v"1.10" <= VERSION < v"1.11"
is_1_11 = v"1.11" <= VERSION < v"1.12"
is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13"
is_svi_vnv =
linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}

# Mooncake doesn't work with several combinations of SimpleVarInfo.
if is_mooncake && is_1_11 && is_svi_vnv
if is_mooncake && is_1_11_or_1_12 && is_svi_vnv
# https://github.com/compintell/Mooncake.jl/issues/470
@test_throws ArgumentError DynamicPPL.LogDensityFunction(
m, getlogjoint_internal, linked_varinfo; adtype=adtype
Expand Down
12 changes: 2 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using MacroTools
using MCMCChains
using StableRNGs
using ReverseDiff
using Mooncake
using Zygote

using Distributed
Expand All @@ -37,13 +38,6 @@ using DynamicPPL: getargs_dottilde, getargs_tilde
const GROUP = get(ENV, "GROUP", "All")
const AQUA = get(ENV, "AQUA", "true") == "true"

# Skip Mooncake if it doesn't work
const MOONCAKE_SUPPORTED = VERSION < v"1.12.0"
if MOONCAKE_SUPPORTED
Pkg.add("Mooncake")
using Mooncake: Mooncake
end

Random.seed!(100)
include("test_util.jl")

Expand Down Expand Up @@ -85,9 +79,7 @@ include("test_util.jl")
end
@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
if MOONCAKE_SUPPORTED
include("ext/DynamicPPLMooncakeExt.jl")
end
include("ext/DynamicPPLMooncakeExt.jl")
include("ad.jl")
end
@testset "prob and logprob macro" begin
Expand Down