Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 81 additions & 78 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,30 @@ jobs:
fail-fast: true # TODO: toggle
matrix:
version:
- '1.10'
# - '1.10'
- '1.11'
- '1.12'
group:
- Core/Internals
- Back/DifferentiateWith
- Core/SimpleFiniteDiff
- Back/SparsityDetector
- Core/ZeroBackends
- Back/ChainRules
# - Back/Diffractor
- Back/Enzyme
- Back/FastDifferentiation
- Back/FiniteDiff
- Back/FiniteDifferences
- Back/ForwardDiff
- Back/GTPSA
- Back/Mooncake
- Back/PolyesterForwardDiff
- Back/ReverseDiff
- Back/Symbolics
- Back/Tracker
- Back/Zygote
# - Core/Internals
# - Back/DifferentiateWith
# - Core/SimpleFiniteDiff
# - Back/SparsityDetector
# - Core/ZeroBackends
# - Back/ChainRules
# # - Back/Diffractor
# - Back/Enzyme
# - Back/FastDifferentiation
# - Back/FiniteDiff
# - Back/FiniteDifferences
# - Back/ForwardDiff
# - Back/GTPSA
# - Back/Mooncake
# - Back/PolyesterForwardDiff
- Back/Reactant
# - Back/ReverseDiff
# - Back/Symbolics
# - Back/Tracker
# - Back/Zygote
skip_lts:
- ${{ github.event.pull_request.draft }}
skip_pre:
Expand All @@ -64,6 +65,8 @@ jobs:
group: Back/ChainRules
- version: '1.12'
group: Back/Enzyme
- version: '1.12'
group: Back/Reactant
- version: '1.12'
group: Back/DifferentiateWith
env:
Expand Down Expand Up @@ -104,61 +107,61 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false

test-DIT:
name: ${{ matrix.version }} - DIT (${{ matrix.group }})
runs-on: ubuntu-latest
if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
timeout-minutes: 60
permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
actions: write
contents: read
strategy:
fail-fast: true
matrix:
version:
- '1.10'
- '1.11'
- '1.12'
group:
- Formalities
- Zero
- Standard
- Weird
skip_lts:
- ${{ github.event.pull_request.draft }}
skip_pre:
- ${{ github.event.pull_request.draft }}
exclude:
- skip_lts: true
version: '1.10'
- skip_pre: true
version: '1.12'
env:
JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v5
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: x64
- uses: julia-actions/cache@v2
- name: Install dependencies & run tests
run: julia --project=./DifferentiationInterfaceTest --color=yes -e '
using Pkg;
Pkg.Registry.update();
Pkg.develop(path="./DifferentiationInterface");
if ENV["JULIA_DI_PR_DRAFT"] == "true";
Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]);
else;
Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true);
end;'
- uses: julia-actions/julia-processcoverage@v1
with:
directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test
- uses: codecov/codecov-action@v5
with:
files: lcov.info
flags: DIT
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
# test-DIT:
# name: ${{ matrix.version }} - DIT (${{ matrix.group }})
# runs-on: ubuntu-latest
# if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
# timeout-minutes: 60
# permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
# actions: write
# contents: read
# strategy:
# fail-fast: true
# matrix:
# version:
# - '1.10'
# - '1.11'
# - '1.12'
# group:
# - Formalities
# - Zero
# - Standard
# - Weird
# skip_lts:
# - ${{ github.event.pull_request.draft }}
# skip_pre:
# - ${{ github.event.pull_request.draft }}
# exclude:
# - skip_lts: true
# version: '1.10'
# - skip_pre: true
# version: '1.12'
# env:
# JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
# JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }}
# steps:
# - uses: actions/checkout@v5
# - uses: julia-actions/setup-julia@v2
# with:
# version: ${{ matrix.version }}
# arch: x64
# - uses: julia-actions/cache@v2
# - name: Install dependencies & run tests
# run: julia --project=./DifferentiationInterfaceTest --color=yes -e '
# using Pkg;
# Pkg.Registry.update();
# Pkg.develop(path="./DifferentiationInterface");
# if ENV["JULIA_DI_PR_DRAFT"] == "true";
# Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]);
# else;
# Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true);
# end;'
# - uses: julia-actions/julia-processcoverage@v1
# with:
# directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test
# - uses: codecov/codecov-action@v5
# with:
# files: lcov.info
# flags: DIT
# token: ${{ secrets.CODECOV_TOKEN }}
# fail_ci_if_error: false
5 changes: 4 additions & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
Expand All @@ -46,6 +47,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = [
"ForwardDiff",
"DiffResults",
]
DifferentiationInterfaceReactantExt = ["Reactant", "Enzyme"]
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
Expand All @@ -56,7 +58,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
ADTypes = "1.18.0"
ADTypes = "1.19.0"
ChainRulesCore = "1.23.0"
DiffResults = "1.1.0"
Diffractor = "=0.2.6"
Expand All @@ -71,6 +73,7 @@ GTPSA = "1.4.0"
LinearAlgebra = "1"
Mooncake = "0.4.175"
PolyesterForwardDiff = "0.1.2"
Reactant = "0.2.178"
ReverseDiff = "1.15.1"
SparseArrays = "1"
SparseConnectivityTracer = "0.6.14, 1"
Expand Down
2 changes: 2 additions & 0 deletions DifferentiationInterface/docs/src/explanation/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
- [`AutoTracker`](@extref ADTypes.AutoTracker)
- [`AutoZygote`](@extref ADTypes.AutoZygote)

In addition, we provide experimental support for [`AutoReactant`](@extref ADTypes.AutoReactant), sofar only for [`gradient`](@ref) and its variants.

## Features

Given a backend object, you can use:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module DifferentiationInterfaceReactantExt

using ADTypes: ADTypes, AutoReactant
import DifferentiationInterface as DI
using Reactant: @compile, ConcreteRArray, ConcreteRNumber, to_rarray

DI.check_available(backend::AutoReactant) = DI.check_available(backend.mode)
DI.inplace_support(backend::AutoReactant) = DI.inplace_support(backend.mode)

include("utils.jl")
include("onearg.jl")

end # module
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{SIG}
_sig::Val{SIG}
xr::XR
gr::GR
compiled_gradient::CG
compiled_gradient!::CG!
compiled_value_and_gradient::CVG
compiled_value_and_gradient!::CVG!
end

function DI.prepare_gradient_nokwarg(
strict::Val, f::F, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
) where {F, C}
_sig = DI.signature(f, rebackend, x; strict)
backend = rebackend.mode
xr = to_reac(x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't save anything as a prep argument if a reactant array, I would keep this as if reactant array then xr is nothing otherwise to_rarray(x)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable

gr = to_reac(similar(x))
contextsr = map(to_reac, contexts)
compiled_gradient = @compile DI.gradient(f, backend, xr, contextsr...)
compiled_gradient! = @compile DI.gradient!(f, gr, backend, xr, contextsr...)
compiled_value_and_gradient = @compile DI.value_and_gradient(f, backend, xr, contextsr...)
compiled_value_and_gradient! = @compile DI.value_and_gradient!(f, gr, backend, xr, contextsr...)
return ReactantGradientPrep(
_sig,
xr,
gr,
compiled_gradient,
compiled_gradient!,
compiled_value_and_gradient,
compiled_value_and_gradient!,
)
end

function DI.gradient(
f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
) where {F, C}
DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, compiled_gradient) = prep
copyto!(xr, x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only do this if x is not a reactantarray

contextsr = map(to_reac, contexts)
gr = compiled_gradient(f, backend, xr, contextsr...)
return gr
end

function DI.value_and_gradient(
f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
) where {F, C}
DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, compiled_value_and_gradient) = prep
copyto!(xr, x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here

contextsr = map(to_reac, contexts)
yr, gr = compiled_value_and_gradient(f, backend, xr, contextsr...)
return yr, gr
end

function DI.gradient!(
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
) where {F, C}
DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, gr, compiled_gradient!) = prep
copyto!(xr, x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Etc

contextsr = map(to_reac, contexts)
compiled_gradient!(f, gr, backend, xr, contextsr...)
return copyto!(grad, gr)
end

function DI.value_and_gradient!(
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
) where {F, C}
DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, gr, compiled_value_and_gradient!) = prep
copyto!(xr, x)
contextsr = map(to_reac, contexts)
yr, gr = compiled_value_and_gradient!(f, gr, backend, xr, contextsr...)
return yr, copyto!(grad, gr)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
to_reac(x::AbstractArray) = to_rarray(x)
to_reac(x::ConcreteRArray) = x
to_reac(x::Number) = ConcreteRNumber(x)
to_reac(x::ConcreteRNumber) = x

to_reac(c::DI.Constant) = DI.Constant(to_reac(DI.unwrap(c)))
to_reac(c::DI.Cache) = DI.Cache(to_reac(DI.unwrap(c)))
2 changes: 2 additions & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using ADTypes:
AutoMooncake,
AutoMooncakeForward,
AutoPolyesterForwardDiff,
AutoReactant,
AutoReverseDiff,
AutoSymbolics,
AutoTracker,
Expand Down Expand Up @@ -118,6 +119,7 @@ export AutoGTPSA
export AutoMooncake
export AutoMooncakeForward
export AutoPolyesterForwardDiff
export AutoReactant
export AutoReverseDiff
export AutoSymbolics
export AutoTracker
Expand Down
21 changes: 21 additions & 0 deletions DifferentiationInterface/test/Back/Reactant/test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Pkg
Pkg.add(url = "https://github.com/EnzymeAD/Enzyme.jl")
Pkg.add("Reactant")

using DifferentiationInterface
using DifferentiationInterfaceTest
using Reactant
using Test

backend = AutoReactant()

@test check_available(backend)
@test check_inplace(backend)

test_differentiation(
backend, DifferentiationInterfaceTest.default_scenarios(;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test that the prep contains no data except the compiled fn if compiled for a reactant array

include_constantified = true, include_cachified = false
);
excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback),
logging = false
)
2 changes: 2 additions & 0 deletions DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand Down Expand Up @@ -40,6 +41,7 @@ DataFrames = "1.6.1"
DifferentiationInterface = "0.7.7"
DocStringExtensions = "0.8,0.9"
ForwardDiff = "0.10.36,1"
GPUArraysCore = "0.2.0"
JET = "0.9,0.10,0.11"
JLArrays = "0.1,0.2,0.3"
LinearAlgebra = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ using DifferentiationInterface:
using DifferentiationInterface: Rewrap, Context, Constant, Cache, ConstantOrCache, unwrap
using DifferentiationInterface: PreparationMismatchError
using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES
using GPUArraysCore: @allowscalar
using JET: @test_opt
using LinearAlgebra: Adjoint, Diagonal, Transpose, I, dot, parent
using PrecompileTools: @compile_workload
Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterfaceTest/src/scenarios/modify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))")
function (sc::StoreInCache{:out})(x, y_cache) # no annotation otherwise Zygote.Buffer cries
y = sc.f(x)
if y isa Number
y_cache[1] = y
return y_cache[1]
@allowscalar y_cache[1] = y
return @allowscalar y_cache[1]
else
copyto!(y_cache, y)
return copy(y_cache)
Expand Down
Loading