Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand Down
19 changes: 3 additions & 16 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,7 @@ using Random: Random
using StableRNGs: StableRNG
using Test
using ..Models: gdemo_default
import ForwardDiff, ReverseDiff

# Skip Mooncake on 1.12 as it is not compatible yet
const INCLUDE_MOONCAKE = VERSION < v"1.12"
if INCLUDE_MOONCAKE
import Pkg
Pkg.add("Mooncake")
using Mooncake: Mooncake
end
import ForwardDiff, ReverseDiff, Mooncake

"""Element types that are always valid for a VarInfo regardless of ADType."""
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
Expand All @@ -33,10 +25,8 @@ eltypes_by_adtype = Dict(
ReverseDiff.TrackedVecOrMat,
ReverseDiff.TrackedVector,
),
AutoMooncake => (Mooncake.CoDual,),
)
if INCLUDE_MOONCAKE
eltypes_by_adtype[AutoMooncake] = (Mooncake.CoDual,)
end

"""
AbstractWrongADBackendError
Expand Down Expand Up @@ -177,10 +167,7 @@ end
"""
All the ADTypes on which we want to run the tests.
"""
ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)]
if INCLUDE_MOONCAKE
push!(ADTYPES, AutoMooncake(; config=nothing))
end
ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoMooncake()]

# Check that ADTypeCheckContext itself works as expected.
@testset "ADTypeCheckContext" begin
Expand Down