diff --git a/Project.toml b/Project.toml index 2945298..c54c8ca 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.3.0" +version = "1.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/index.md b/docs/src/index.md index 5154899..9640a33 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -92,3 +92,9 @@ ADTypes.ForwardOrReverseMode ADTypes.ReverseMode ADTypes.SymbolicMode ``` + +## Miscellaneous + +```@docs +ADTypes.Auto +``` diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 7fefb2b..56ebaae 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -20,6 +20,7 @@ include("mode.jl") include("dense.jl") include("sparse.jl") include("legacy.jl") +include("symbols.jl") if !isdefined(Base, :get_extension) include("../ext/ADTypesChainRulesCoreExt.jl") diff --git a/src/symbols.jl b/src/symbols.jl new file mode 100644 index 0000000..f349e84 --- /dev/null +++ b/src/symbols.jl @@ -0,0 +1,29 @@ +""" + ADTypes.Auto(package::Symbol) + +A shortcut that converts an AD package name into an instance of [`AbstractADType`](@ref), with all parameters set to their default values. + +!!! warning + + This function is type-unstable by design and might lead to suboptimal performance. + In most cases, you should never need it: use the individual backend types directly. + +# Example + +```jldoctest +import ADTypes +backend = ADTypes.Auto(:Zygote) + +# output + +ADTypes.AutoZygote() +``` +""" +Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...) + +for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation, + :FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff, + :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) + @eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...) +end + diff --git a/test/runtests.jl b/test/runtests.jl index 23d0d96..e7d72f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,6 +66,9 @@ end @testset "Sparse" begin include("sparse.jl") end + @testset "Symbols" begin + include("symbols.jl") + end @testset "Legacy" begin include("legacy.jl") end diff --git a/test/symbols.jl b/test/symbols.jl new file mode 100644 index 0000000..bc39c7e --- /dev/null +++ b/test/symbols.jl @@ -0,0 +1,20 @@ +using ADTypes +using Test + +@test ADTypes.Auto(:ChainRules, 1) isa AutoChainRules{Int64} +@test ADTypes.Auto(:Diffractor) isa AutoDiffractor +@test ADTypes.Auto(:Enzyme) isa AutoEnzyme +@test ADTypes.Auto(:FastDifferentiation) isa AutoFastDifferentiation +@test ADTypes.Auto(:FiniteDiff) isa AutoFiniteDiff +@test ADTypes.Auto(:FiniteDifferences, 1.0) isa AutoFiniteDifferences{Float64} +@test ADTypes.Auto(:ForwardDiff) isa AutoForwardDiff +@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff +@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff +@test ADTypes.Auto(:Symbolics) isa AutoSymbolics +@test ADTypes.Auto(:Tapir) isa AutoTapir +@test ADTypes.Auto(:Tracker) isa AutoTracker +@test ADTypes.Auto(:Zygote) isa AutoZygote + +@test_throws MethodError ADTypes.Auto(:ThisPackageDoesNotExist) +@test_throws UndefKeywordError ADTypes.Auto(:ChainRules) +@test_throws UndefKeywordError ADTypes.Auto(:FiniteDifferences)