Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e52ec9a
Re-add Enzyme test
gdalle May 29, 2024
5a0e664
Add second order test for Enzyme
gdalle May 29, 2024
c21e108
Merge branch 'main' into gd/enzyme_second_order
gdalle May 29, 2024
cdfe12f
Add second-order capabilities with Enzyme
gdalle May 29, 2024
7c351d4
Typos
gdalle May 29, 2024
a3962a3
Fix scenarios
gdalle May 29, 2024
a0ffa2c
No mapreduce in default test scenarios
gdalle May 29, 2024
8b446bd
Fix efficiency tests
gdalle May 29, 2024
4568d9c
Replace randn with rand
gdalle May 29, 2024
92d8e76
Default alpha
gdalle May 29, 2024
ec63ffa
Fix
gdalle May 29, 2024
ad251ef
Fix missing ChainRule in 1.6
gdalle May 29, 2024
f24ccb9
Maybe Zygote on 1.6 will finaly shut up
gdalle May 29, 2024
06913e9
Zygote HVP uses ForwardDiff
gdalle May 29, 2024
e8f6ca5
No CRC second order on 1.6
gdalle May 29, 2024
ac745c3
Merge branch 'gd/better_test' into gd/enzyme_second_order
gdalle May 29, 2024
416a0da
Typo
gdalle May 29, 2024
ccc2957
Typo
gdalle May 29, 2024
69d65c4
Version bump
gdalle May 29, 2024
390a51e
Merge branch 'gd/better_test' into gd/enzyme_second_order
gdalle May 29, 2024
6f06570
Unbump
gdalle May 29, 2024
3170333
Function without mapreduce
gdalle May 29, 2024
76aecb7
Merge branch 'main' into gd/enzyme_second_order
gdalle May 29, 2024
74f0bd5
Check if overloads is causing the docs failure
gdalle May 29, 2024
589ed20
No overloads, no backends
gdalle May 29, 2024
fa4162e
Add overloads
gdalle May 29, 2024
4f8c69f
Backends page makes docs crash
gdalle May 29, 2024
f20a4eb
Better Hessian check
gdalle May 29, 2024
7182be1
Reactivate tests
gdalle May 29, 2024
8d03c96
No backends page
gdalle May 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions DifferentiationInterface/docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,23 @@ makedocs(;
format=Documenter.HTML(; assets=["assets/favicon.ico"]),
pages=[
"Home" => "index.md",
"Tutorials" => ["tutorial1.md", "tutorial2.md"],
"Reference" => ["operators.md", "backends.md", "api.md"],
"Advanced" => ["preparation.md", "overloads.md"],
"Tutorials" => [
"tutorial1.md", #
"tutorial2.md",
],
"Reference" => [
"operators.md", #
# "backends.md",
"api.md",
],
"Advanced" => [
"preparation.md", #
"overloads.md"
],
],
checkdocs=:exports,
plugins=[links],
pagesonly=true,
)

deploydocs(;
Expand Down
8 changes: 7 additions & 1 deletion DifferentiationInterface/docs/src/tutorial1.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ A common use case of automatic differentiation (AD) is optimizing real-valued fu
Let's define a simple objective and a random input vector

```@example tuto1
f(x) = sum(abs2, x)
function f(x::AbstractVector{T}) where {T}
y = zero(T)
for i in eachindex(x)
y += abs2(x[i])
end
return y
end

x = collect(1.0:5.0)
```
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
module DifferentiationInterfaceEnzymeExt

using ADTypes: ADTypes, AutoEnzyme
using Compat
import DifferentiationInterface as DI
using DifferentiationInterface:
DerivativeExtras,
GradientExtras,
HessianExtras,
JacobianExtras,
PullbackExtras,
PushforwardExtras,
SecondDerivativeExtras,
NoDerivativeExtras,
NoGradientExtras,
NoHessianExtras,
NoHVPExtras,
NoJacobianExtras,
NoPullbackExtras,
NoPushforwardExtras,
NoSecondDerivativeExtras,
pick_chunksize
using DocStringExtensions
using Enzyme:
Active,
Const,
Mode,
Duplicated,
DuplicatedNoNeed,
Forward,
Expand All @@ -27,22 +34,26 @@ using Enzyme:
ReverseSplitWithPrimal,
ReverseMode,
autodiff,
autodiff_deferred,
autodiff_thunk,
chunkedonehot,
gradient,
gradient!,
jacobian,
make_zero

const AutoMixedEnzyme = AutoEnzyme{Nothing}
const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode}
const AutoForwardOrNothingEnzyme = Union{AutoEnzyme{<:ForwardMode},AutoEnzyme{Nothing}}
const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode}
const AutoReverseOrNothingEnzyme = Union{AutoEnzyme{<:ReverseMode},AutoEnzyme{Nothing}}

forward_mode(backend::AutoEnzyme{<:ForwardMode}) = backend.mode
# forward mode if possible
forward_mode(backend::AutoEnzyme{<:Mode}) = backend.mode
forward_mode(::AutoEnzyme{Nothing}) = Forward

reverse_mode(backend::AutoEnzyme{<:ReverseMode}) = backend.mode
# reverse mode if possible
reverse_mode(backend::AutoEnzyme{<:Mode}) = backend.mode
reverse_mode(::AutoEnzyme{Nothing}) = Reverse

DI.check_available(::AutoEnzyme) = true
Expand All @@ -60,10 +71,14 @@ function zero_sametype!(x_target, x)
return x_sametype
end

include("utils.jl")

include("forward_onearg.jl")
include("forward_twoarg.jl")

include("reverse_onearg.jl")
include("reverse_twoarg.jl")

include("common_onearg.jl")

end # module
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
## Second derivative

DI.prepare_second_derivative(f, ::AutoEnzyme, x, v) = NoSecondDerivativeExtras()

function DI.second_derivative(
f, backend::AutoForwardOrNothingEnzyme, x, ::NoSecondDerivativeExtras
)
df = DeferredDerivative(f, forward_mode(backend))
return DI.derivative(df, AutoEnzyme(forward_mode(backend)), x)
end

function DI.second_derivative!(
f, der2, backend::AutoForwardOrNothingEnzyme, x, ::NoSecondDerivativeExtras
)
df = DeferredDerivative(f, forward_mode(backend))
return DI.derivative!(df, der2, AutoEnzyme(forward_mode(backend)), x)
end

## Hessian

struct EnzymeHessianExtras{G,JE} <: HessianExtras
∇f::G
jac_extras::JE
end

function DI.prepare_hessian(f, backend::AutoEnzyme, x)
∇f = DeferredGradient(f, reverse_mode(backend))
jac_extras = DI.prepare_jacobian(∇f, AutoEnzyme(forward_mode(backend)), x)
return EnzymeHessianExtras(∇f, jac_extras)
end

function DI.hessian(f, backend::AutoEnzyme, x, extras::EnzymeHessianExtras)
@compat (; ∇f, jac_extras) = extras
return DI.jacobian(∇f, AutoEnzyme(forward_mode(backend)), x, jac_extras)
end

function DI.hessian!(f, hess, backend::AutoEnzyme, x, extras::EnzymeHessianExtras)
@compat (; ∇f, jac_extras) = extras
return DI.jacobian!(∇f, hess, AutoEnzyme(forward_mode(backend)), x, jac_extras)
end
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ function DI.value_and_pushforward(
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
)
dx_sametype = convert(typeof(x), dx)
y, new_dy = autodiff(forward_mode(backend), f, Duplicated, Duplicated(x, dx_sametype))
y, new_dy = autodiff(
forward_mode(backend), Const(f), Duplicated, Duplicated(x, dx_sametype)
)
return y, new_dy
end

Expand All @@ -15,7 +17,9 @@ function DI.pushforward(
)
dx_sametype = convert(typeof(x), dx)
new_dy = only(
autodiff(forward_mode(backend), f, DuplicatedNoNeed, Duplicated(x, dx_sametype))
autodiff(
forward_mode(backend), Const(f), DuplicatedNoNeed, Duplicated(x, dx_sametype)
),
)
return new_dy
end
Expand Down Expand Up @@ -121,3 +125,17 @@ function DI.value_and_jacobian!(
y, new_jac = DI.value_and_jacobian(f, backend, x, extras)
return y, copyto!(jac, new_jac)
end

## HVP

DI.prepare_hvp(f, ::AutoForwardOrNothingEnzyme, x, v) = NoHVPExtras()

function DI.hvp(f, backend::AutoForwardOrNothingEnzyme, x, v, ::NoHVPExtras)
∇f = DeferredGradient(f, reverse_mode(backend))
return DI.pushforward(∇f, AutoEnzyme(forward_mode(backend)), x, v)
end

function DI.hvp!(f, p, backend::AutoForwardOrNothingEnzyme, x, v, ::NoHVPExtras)
∇f = DeferredGradient(f, reverse_mode(backend))
return DI.pushforward!(∇f, p, AutoEnzyme(forward_mode(backend)), x, v)
end
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function DI.value_and_pushforward(
dy_sametype = zero(y)
autodiff(
forward_mode(backend),
f!,
Const(f!),
Const,
Duplicated(y, dy_sametype),
Duplicated(x, dx_sametype),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ DI.prepare_pullback(f, ::AutoReverseOrNothingEnzyme, x, dy) = NoPullbackExtras()
function DI.value_and_pullback(
f, ::AutoReverseOrNothingEnzyme, x::Number, dy::Number, ::NoPullbackExtras
)
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
der, y = autodiff(ReverseWithPrimal, Const(f), Active, Active(x))
new_dx = dy * only(der)
return y, new_dx
end
Expand Down Expand Up @@ -43,7 +43,7 @@ function DI.value_and_pullback!(
f, dx, ::AutoReverseOrNothingEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras
)
dx_sametype = zero_sametype!(dx, x)
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
_, y = autodiff(ReverseWithPrimal, Const(f), Active, Duplicated(x, dx_sametype))
dx_sametype .*= dy
return y, copyto!(dx, dx_sametype)
end
Expand Down Expand Up @@ -155,3 +155,17 @@ function DI.value_and_jacobian!(
end

=#

## HVP

DI.prepare_hvp(f, ::AutoReverseEnzyme, x, v) = NoHVPExtras()

function DI.hvp(f, backend::AutoReverseEnzyme, x, v, ::NoHVPExtras)
∇f = DeferredGradient(f, reverse_mode(backend))
return DI.pullback(∇f, backend, x, v)
end

function DI.hvp!(f, p, backend::AutoReverseEnzyme, x, v, ::NoHVPExtras)
∇f = DeferredGradient(f, reverse_mode(backend))
return DI.pullback!(∇f, p, backend, x, v)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
struct DeferredGradient{F,M<:Mode}
f::F
mode::M
end

function (def_grad::DeferredGradient{F,<:ForwardMode})(z) where {F}
return error("Not implemented yet")
end

function (def_grad::DeferredGradient{F,<:ReverseMode})(z) where {F}
@compat (; f, mode) = def_grad
grad = make_zero(z)
autodiff_deferred(mode, Const(f), Active, Duplicated(z, grad))
return grad
end

struct DeferredDerivative{F,M<:Mode}
f::F
mode::M
end

function (def_der::DeferredDerivative{F,<:ForwardMode})(z) where {F}
@compat (; f, mode) = def_der
return only(autodiff_deferred(mode, Const(f), DuplicatedNoNeed, Duplicated(z, one(z))))
end

function (def_der::DeferredDerivative{F,<:ReverseMode})(z) where {F}
return error("Not implemented yet")
end
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/utils/check.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Check whether `backend` supports differentiation of two-argument functions.
"""
check_twoarg(backend::AbstractADType) = Bool(twoarg_support(backend))

sqnorm(x::AbstractArray) = sum(abs2, x)
sqnorm(x::AbstractArray) = sum(abs2.(x))

"""
check_hessian(backend)
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/utils/printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function backend_str(backend::AbstractADType)
elseif mode(backend) isa SymbolicMode
return "$bs (symbolic)"
elseif mode(backend) isa ForwardOrReverseMode
return "$bs (forward/reverse)"
return "$bs (forward|reverse)"
else
error("Unknown mode")
end
Expand Down
24 changes: 13 additions & 11 deletions DifferentiationInterface/test/Single/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,31 @@ dense_backends = [
AutoEnzyme(; mode=Enzyme.Reverse),
]

sparse_backends = [
AutoSparse(
AutoEnzyme(; mode=Enzyme.Forward);
sparse_backends =
AutoSparse.(
dense_backends,
sparsity_detector=TracerSparsityDetector(),
coloring_algorithm=GreedyColoringAlgorithm(),
),
AutoSparse(
AutoEnzyme(; mode=Enzyme.Reverse);
sparsity_detector=TracerSparsityDetector(),
coloring_algorithm=GreedyColoringAlgorithm(),
),
]
)

for backend in vcat(dense_backends, sparse_backends)
@test check_available(backend)
@test check_twoarg(backend)
@test !check_hessian(backend; verbose=false)
# @test !check_hessian(backend; verbose=false)
end

## Dense backends

test_differentiation(dense_backends; second_order=false, logging=LOGGING);

test_differentiation(AutoEnzyme(); first_order=false, logging=LOGGING);
test_differentiation(
AutoEnzyme(; mode=Enzyme.Forward);
first_order=false,
excluded=[HVPScenario, HessianScenario],
logging=LOGGING,
);

test_differentiation(
AutoEnzyme(; mode=Enzyme.Forward); # TODO: add more
correctness=false,
Expand Down
1 change: 0 additions & 1 deletion DifferentiationInterfaceTest/src/scenarios/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#=
Constraints on the scenarios:
- non-allocating whenever possible
- type-stable
- GPU-compatible (no scalar indexing)
- vary shapes to be tricky
Expand Down