diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 89223f597..e490a0a7c 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -28,29 +28,30 @@ jobs: fail-fast: true # TODO: toggle matrix: version: - - '1.10' + # - '1.10' - '1.11' - '1.12' group: - - Core/Internals - - Back/DifferentiateWith - - Core/SimpleFiniteDiff - - Back/SparsityDetector - - Core/ZeroBackends - - Back/ChainRules - # - Back/Diffractor - - Back/Enzyme - - Back/FastDifferentiation - - Back/FiniteDiff - - Back/FiniteDifferences - - Back/ForwardDiff - - Back/GTPSA - - Back/Mooncake - - Back/PolyesterForwardDiff - - Back/ReverseDiff - - Back/Symbolics - - Back/Tracker - - Back/Zygote + # - Core/Internals + # - Back/DifferentiateWith + # - Core/SimpleFiniteDiff + # - Back/SparsityDetector + # - Core/ZeroBackends + # - Back/ChainRules + # # - Back/Diffractor + # - Back/Enzyme + # - Back/FastDifferentiation + # - Back/FiniteDiff + # - Back/FiniteDifferences + # - Back/ForwardDiff + # - Back/GTPSA + # - Back/Mooncake + # - Back/PolyesterForwardDiff + - Back/Reactant + # - Back/ReverseDiff + # - Back/Symbolics + # - Back/Tracker + # - Back/Zygote skip_lts: - ${{ github.event.pull_request.draft }} skip_pre: @@ -64,6 +65,8 @@ jobs: group: Back/ChainRules - version: '1.12' group: Back/Enzyme + - version: '1.12' + group: Back/Reactant - version: '1.12' group: Back/DifferentiateWith env: @@ -104,61 +107,61 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - test-DIT: - name: ${{ matrix.version }} - DIT (${{ matrix.group }}) - runs-on: ubuntu-latest - if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - timeout-minutes: 60 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: true - matrix: - version: - - '1.10' - - '1.11' - - '1.12' - group: - - Formalities - - Zero - - Standard - - Weird - skip_lts: - - ${{ github.event.pull_request.draft }} - skip_pre: - - ${{ github.event.pull_request.draft }} - exclude: - - skip_lts: true - version: '1.10' - - skip_pre: true - version: '1.12' - env: - JULIA_DIT_TEST_GROUP: ${{ matrix.group }} - JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - steps: - - uses: actions/checkout@v5 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: x64 - - uses: julia-actions/cache@v2 - - name: Install dependencies & run tests - run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' - using Pkg; - Pkg.Registry.update(); - Pkg.develop(path="./DifferentiationInterface"); - if ENV["JULIA_DI_PR_DRAFT"] == "true"; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); - else; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); - end;' - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - flags: DIT - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false + # test-DIT: + # name: ${{ matrix.version }} - DIT (${{ matrix.group }}) + # runs-on: ubuntu-latest + # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + # timeout-minutes: 60 + # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + # actions: write + # contents: read + # strategy: + # fail-fast: true + # matrix: + # version: + # - '1.10' + # - '1.11' + # - '1.12' + # group: + # - Formalities + # - Zero + # - Standard + # - Weird + # skip_lts: + # - ${{ github.event.pull_request.draft }} + # skip_pre: + # - ${{ github.event.pull_request.draft }} + # exclude: + # - skip_lts: true + # version: '1.10' + # - skip_pre: true + # version: '1.12' + # env: + # JULIA_DIT_TEST_GROUP: ${{ matrix.group }} + # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + # steps: + # - uses: actions/checkout@v5 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: ${{ matrix.version }} + # arch: x64 + # - uses: julia-actions/cache@v2 + # - name: Install dependencies & run tests + # run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' + # using Pkg; + # Pkg.Registry.update(); + # Pkg.develop(path="./DifferentiationInterface"); + # if ENV["JULIA_DI_PR_DRAFT"] == "true"; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); + # else; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); + # end;' + # - uses: julia-actions/julia-processcoverage@v1 + # with: + # directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test + # - uses: codecov/codecov-action@v5 + # with: + # files: lcov.info + # flags: DIT + # token: ${{ secrets.CODECOV_TOKEN }} + # fail_ci_if_error: false diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index ded9bd6c3..1cbea565f 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -21,6 +21,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" @@ -46,6 +47,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = [ "ForwardDiff", "DiffResults", ] +DifferentiationInterfaceReactantExt = ["Reactant", "Enzyme"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" @@ -56,7 +58,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.18.0" +ADTypes = "1.19.0" ChainRulesCore = "1.23.0" DiffResults = "1.1.0" Diffractor = "=0.2.6" @@ -71,6 +73,7 @@ GTPSA = "1.4.0" LinearAlgebra = "1" Mooncake = "0.4.175" PolyesterForwardDiff = "0.1.2" +Reactant = "0.2.178" ReverseDiff = "1.15.1" SparseArrays = "1" SparseConnectivityTracer = "0.6.14, 1" diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 0da5201a9..58abbfa09 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -19,6 +19,8 @@ We support the following dense backend choices from [ADTypes.jl](https://github. - [`AutoTracker`](@extref ADTypes.AutoTracker) - [`AutoZygote`](@extref ADTypes.AutoZygote) +In addition, we provide experimental support for [`AutoReactant`](@extref ADTypes.AutoReactant), sofar only for [`gradient`](@ref) and its variants. + ## Features Given a backend object, you can use: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl new file mode 100644 index 000000000..7740b53a6 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl @@ -0,0 +1,13 @@ +module DifferentiationInterfaceReactantExt + +using ADTypes: ADTypes, AutoReactant +import DifferentiationInterface as DI +using Reactant: @compile, ConcreteRArray, ConcreteRNumber, to_rarray + +DI.check_available(backend::AutoReactant) = DI.check_available(backend.mode) +DI.inplace_support(backend::AutoReactant) = DI.inplace_support(backend.mode) + +include("utils.jl") +include("onearg.jl") + +end # module diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl new file mode 100644 index 000000000..af82235a8 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl @@ -0,0 +1,80 @@ +struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{SIG} + _sig::Val{SIG} + xr::XR + gr::GR + compiled_gradient::CG + compiled_gradient!::CG! + compiled_value_and_gradient::CVG + compiled_value_and_gradient!::CVG! +end + +function DI.prepare_gradient_nokwarg( + strict::Val, f::F, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + _sig = DI.signature(f, rebackend, x; strict) + backend = rebackend.mode + xr = to_reac(x) + gr = to_reac(similar(x)) + contextsr = map(to_reac, contexts) + compiled_gradient = @compile DI.gradient(f, backend, xr, contextsr...) + compiled_gradient! = @compile DI.gradient!(f, gr, backend, xr, contextsr...) + compiled_value_and_gradient = @compile DI.value_and_gradient(f, backend, xr, contextsr...) + compiled_value_and_gradient! = @compile DI.value_and_gradient!(f, gr, backend, xr, contextsr...) + return ReactantGradientPrep( + _sig, + xr, + gr, + compiled_gradient, + compiled_gradient!, + compiled_value_and_gradient, + compiled_value_and_gradient!, + ) +end + +function DI.gradient( + f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode + (; xr, compiled_gradient) = prep + copyto!(xr, x) + contextsr = map(to_reac, contexts) + gr = compiled_gradient(f, backend, xr, contextsr...) + return gr +end + +function DI.value_and_gradient( + f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode + (; xr, compiled_value_and_gradient) = prep + copyto!(xr, x) + contextsr = map(to_reac, contexts) + yr, gr = compiled_value_and_gradient(f, backend, xr, contextsr...) + return yr, gr +end + +function DI.gradient!( + f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode + (; xr, gr, compiled_gradient!) = prep + copyto!(xr, x) + contextsr = map(to_reac, contexts) + compiled_gradient!(f, gr, backend, xr, contextsr...) + return copyto!(grad, gr) +end + +function DI.value_and_gradient!( + f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode + (; xr, gr, compiled_value_and_gradient!) = prep + copyto!(xr, x) + contextsr = map(to_reac, contexts) + yr, gr = compiled_value_and_gradient!(f, gr, backend, xr, contextsr...) + return yr, copyto!(grad, gr) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl new file mode 100644 index 000000000..7c22e9a4c --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl @@ -0,0 +1,7 @@ +to_reac(x::AbstractArray) = to_rarray(x) +to_reac(x::ConcreteRArray) = x +to_reac(x::Number) = ConcreteRNumber(x) +to_reac(x::ConcreteRNumber) = x + +to_reac(c::DI.Constant) = DI.Constant(to_reac(DI.unwrap(c))) +to_reac(c::DI.Cache) = DI.Cache(to_reac(DI.unwrap(c))) diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 4a07e7301..393b9051a 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -30,6 +30,7 @@ using ADTypes: AutoMooncake, AutoMooncakeForward, AutoPolyesterForwardDiff, + AutoReactant, AutoReverseDiff, AutoSymbolics, AutoTracker, @@ -118,6 +119,7 @@ export AutoGTPSA export AutoMooncake export AutoMooncakeForward export AutoPolyesterForwardDiff +export AutoReactant export AutoReverseDiff export AutoSymbolics export AutoTracker diff --git a/DifferentiationInterface/test/Back/Reactant/test.jl b/DifferentiationInterface/test/Back/Reactant/test.jl new file mode 100644 index 000000000..b0640498a --- /dev/null +++ b/DifferentiationInterface/test/Back/Reactant/test.jl @@ -0,0 +1,21 @@ +using Pkg +Pkg.add(url = "https://github.com/EnzymeAD/Enzyme.jl") +Pkg.add("Reactant") + +using DifferentiationInterface +using DifferentiationInterfaceTest +using Reactant +using Test + +backend = AutoReactant() + +@test check_available(backend) +@test check_inplace(backend) + +test_differentiation( + backend, DifferentiationInterfaceTest.default_scenarios(; + include_constantified = true, include_cachified = false + ); + excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback), + logging = false +) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index a417d9eac..cb718f78c 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -10,6 +10,7 @@ Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -40,6 +41,7 @@ DataFrames = "1.6.1" DifferentiationInterface = "0.7.7" DocStringExtensions = "0.8,0.9" ForwardDiff = "0.10.36,1" +GPUArraysCore = "0.2.0" JET = "0.9,0.10,0.11" JLArrays = "0.1,0.2,0.3" LinearAlgebra = "1" diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 51b216915..84d9d2b38 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -92,6 +92,7 @@ using DifferentiationInterface: using DifferentiationInterface: Rewrap, Context, Constant, Cache, ConstantOrCache, unwrap using DifferentiationInterface: PreparationMismatchError using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES +using GPUArraysCore: @allowscalar using JET: @test_opt using LinearAlgebra: Adjoint, Diagonal, Transpose, I, dot, parent using PrecompileTools: @compile_workload diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 5761eb346..bef174384 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -224,8 +224,8 @@ Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") function (sc::StoreInCache{:out})(x, y_cache) # no annotation otherwise Zygote.Buffer cries y = sc.f(x) if y isa Number - y_cache[1] = y - return y_cache[1] + @allowscalar y_cache[1] = y + return @allowscalar y_cache[1] else copyto!(y_cache, y) return copy(y_cache)