diff --git a/Project.toml b/Project.toml index f12b8a6..cc14c2d 100644 --- a/Project.toml +++ b/Project.toml @@ -8,16 +8,20 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] VectorInterfaceChainRulesCoreExt = "ChainRulesCore" +VectorInterfaceEnzymeExt = "Enzyme" VectorInterfaceMooncakeExt = "Mooncake" [compat] Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" +Enzyme = "0.13.131" +EnzymeTestUtils = "0.2.6" LinearAlgebra = "1" Mooncake = "0.5" Random = "1" @@ -29,10 +33,12 @@ julia = "1" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [targets] -test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Random"] +test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Enzyme", "EnzymeTestUtils", "Random"] diff --git a/ext/VectorInterfaceEnzymeExt.jl b/ext/VectorInterfaceEnzymeExt.jl new file mode 100644 index 0000000..d08aea9 --- /dev/null +++ b/ext/VectorInterfaceEnzymeExt.jl @@ -0,0 +1,287 @@ +module VectorInterfaceEnzymeExt + +# COV_EXCL_START +# Enzyme rules aren't reachable by coverage +using VectorInterface +using Enzyme +using Enzyme.EnzymeCore +using Enzyme.EnzymeCore: EnzymeRules + +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + C::Annotation, + α::Annotation{<:Number}, + ) where {RT} + dret = !isa(C, Const) ? C.dval : nothing + cacheα = EnzymeRules.overwritten(config)[3] ? copy(α.val) : α.val + cache = (cacheα, copy(C.val)) # is this better than just unscaling? + ret = scale!(C.val, α.val) + shadow = EnzymeRules.needs_shadow(config) ? dret : nothing + primal = EnzymeRules.needs_primal(config) ? ret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + cache, + C::Annotation, + α::Annotation{<:Number}, + ) where {RT} + αval, Cval = cache + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Cval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + scale!(C.dval, conj(αval)) + return (nothing, Δα) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + C::Annotation, + α::Annotation{<:Number}, + ) where {RT} + if !isa(α, Const) && !isa(C, Const) + add!(C.dval, C.val, α.dval, α.val) + elseif !isa(C, Const) + scale!(C.dval, α.val) + end + scale!(C.val, α.val) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return C + elseif EnzymeRules.needs_primal(config) + return C.val + elseif EnzymeRules.needs_shadow(config) + return C.dval + else + return nothing + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + C::Annotation, + A::Annotation, + α::Annotation{<:Number}, + ) where {RT} + cacheA = EnzymeRules.overwritten(config)[3] ? copy(A.val) : A.val + cacheα = EnzymeRules.overwritten(config)[4] ? copy(α.val) : α.val + cache = (cacheA, cacheα) + ret = scale!(C.val, A.val, α.val) + dret = !isa(C, Const) ? C.dval : nothing + shadow = EnzymeRules.needs_shadow(config) ? dret : nothing + primal = EnzymeRules.needs_primal(config) ? ret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + cache, + C::Annotation, + A::Annotation, + α::Annotation{<:Number}, + ) where {RT} + Aval, αval = cache + !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(αval)) + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Aval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + zerovector!(C.dval) + return (nothing, nothing, Δα) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + C::Annotation, + A::Annotation, + α::Annotation{<:Number}, + ) where {RT} + scale!(C.val, A.val, α.val) + !isa(C, Const) && !isa(A, Const) && scale!(C.dval, A.dval, α.val) + !isa(α, Const) && !isa(C, Const) && add!(C.dval, A.val, α.dval, One()) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return C + elseif EnzymeRules.needs_primal(config) + return C.val + elseif EnzymeRules.needs_shadow(config) + return C.dval + else + return nothing + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(add!)}, + ::Type{RT}, + C::Annotation, + A::Annotation, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + dret = !isa(C, Const) ? C.dval : nothing + # only need copy of A if α is not constant + cacheA = !isa(α, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : A.val + cacheα = EnzymeRules.overwritten(config)[4] ? copy(α.val) : α.val + cacheβ = EnzymeRules.overwritten(config)[5] ? copy(β.val) : β.val + # only need copy of C if β is not constant + cacheC = !isa(β, Const) ? copy(C.val) : C.val + cache = (cacheA, cacheα, cacheβ, cacheC) + ret = add!(C.val, A.val, α.val, β.val) + shadow = EnzymeRules.needs_shadow(config) ? dret : nothing + primal = EnzymeRules.needs_primal(config) ? ret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(add!)}, + ::Type{RT}, + cache, + C::Annotation, + A::Annotation, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + Aval, αval, βval, Cval = cache + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Aval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + Δβ = if !isa(β, Const) && !isa(C, Const) + project_scalar(β.val, inner(Cval, C.dval)) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(αval)) + !isa(C, Const) && scale!(C.dval, conj(βval)) + return (nothing, nothing, Δα, Δβ) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(add!)}, + ::Type{RT}, + C::Annotation, + A::Annotation, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + !isa(C, Const) && !isa(A, Const) && add!(C.dval, A.dval, α.val, β.val) + !isa(C, Const) && !isa(α, Const) && add!(C.dval, A.val, α.dval, One()) + !isa(C, Const) && !isa(β, Const) && add!(C.dval, C.val, β.dval, One()) + add!(C.val, A.val, α.val, β.val) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return C + elseif EnzymeRules.needs_primal(config) + return C.val + elseif EnzymeRules.needs_shadow(config) + return C.dval + else + return nothing + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + ::Type{RT}, + A::Annotation, + B::Annotation, + ) where {RT} + cacheA = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val + cacheB = EnzymeRules.overwritten(config)[3] ? copy(B.val) : B.val + cache = (cacheA, cacheB) + ret = inner(A.val, B.val) + shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing + primal = EnzymeRules.needs_primal(config) ? ret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + dret::Active, + cache, + A::Annotation, + B::Annotation, + ) + ΔS = dret.val + Aval, Bval = cache + !isa(A, Const) && add!(A.dval, Bval, conj(ΔS)) + !isa(B, Const) && add!(B.dval, Aval, ΔS) + return (nothing, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + RT::Type{<:Const}, + cache, + A::Annotation, + B::Annotation, + ) + return (nothing, nothing) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(inner)}, + ::Type{RT}, + A::Annotation, + B::Annotation, + ) where {RT} + ret = inner(A.val, B.val) + if EnzymeRules.needs_shadow(config) # only compute this if actually needed + dret = zero(ret) + !isa(A, Const) && (dret += inner(A.dval, B.val)) + !isa(B, Const) && (dret += inner(A.val, B.dval)) + else + dret = nothing + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, dret) + elseif EnzymeRules.needs_primal(config) + return ret + elseif EnzymeRules.needs_shadow(config) + return dret + else + return nothing + end +end +# COV_EXCL_STOP + +end diff --git a/test/chainrules.jl b/test/chainrules.jl index 4648f92..42d6d03 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -16,9 +16,8 @@ end function ChainRulesTestUtils.test_approx(x::MinimalVec, ::AbstractZero, msg = ""; kwargs...) return test_approx(x, zerovector(x), msg; kwargs...) end -Base.collect(x::MinimalVec) = x.vec -eltypes = (Float32, Float64, ComplexF64) +eltypes = (Float64, ComplexF64) @testset "scale pullbacks ($T)" for T in eltypes n = 12 diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 0000000..3f4d251 --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,142 @@ +module EnzymeTests + +using VectorInterface +using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec +using Enzyme, EnzymeTestUtils +using Test, TestExtras +using Random + +rng = Random.default_rng() + +precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32)) +precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = 2 * sqrt(eps(Float64)) + +eltypes = (Float64, ComplexF64) + +@testset "scale ($T)" for T in eltypes + n = 12 + atol = rtol = n * precision(T) + + # Vector + x = randn(T, n) + y = randn(T, n) + α = randn(T) + for Tα in (Const, Active) + test_reverse(scale, Duplicated, (x, Duplicated), (α, Tα); atol, rtol) + test_reverse(scale!!, Duplicated, (x, Duplicated), (α, Tα); atol, rtol) + test_reverse(scale!!, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα); atol, rtol) + end + for Tα in (Const, Duplicated) + test_forward(scale, Duplicated, (x, Duplicated), (α, Tα); atol, rtol) + test_forward(scale!!, Duplicated, (x, Duplicated), (α, Tα); atol, rtol) + test_forward(scale!!, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα); atol, rtol) + end + + # MinimalMVec + mx = MinimalMVec(x) + my = MinimalMVec(y) + for Tα in (Const, Active) + test_reverse(scale, Duplicated, (mx, Duplicated), (α, Tα); atol, rtol) + test_reverse(scale!!, Duplicated, (mx, Duplicated), (α, Tα); atol, rtol) + test_reverse(scale!!, Duplicated, (my, Duplicated), (mx, Duplicated), (α, Tα); atol, rtol) + end + for Tα in (Const, Duplicated) + test_forward(scale, Duplicated, (mx, Duplicated), (α, Tα); atol, rtol) + test_forward(scale!!, Duplicated, (mx, Duplicated), (α, Tα); atol, rtol) + test_forward(scale!!, Duplicated, (my, Duplicated), (mx, Duplicated), (α, Tα); atol, rtol) + end + + # MinimalSVec + mx = MinimalSVec(x) + my = MinimalSVec(y) + for Tα in (Const, Active) + test_reverse(scale, Duplicated, (mx, Duplicated), (α, Tα); atol, rtol) + test_reverse(scale!!, Duplicated, (mx, Duplicated), (α, Tα); atol, rtol) + test_reverse(scale!!, Duplicated, (my, Duplicated), (mx, Duplicated), (α, Tα); atol, rtol) + end + for Tα in (Const, Duplicated) + test_forward(scale, Duplicated, (mx, Duplicated), (α, Tα); atol, rtol) + test_forward(scale!!, Duplicated, (mx, Duplicated), (α, Tα); atol, rtol) + test_forward(scale!!, Duplicated, (my, Duplicated), (mx, Duplicated), (α, Tα); atol, rtol) + end +end + +@testset "add ($T)" for T in eltypes + n = 12 + atol = rtol = n * precision(T) + + # Vector + x = randn(T, n) + y = randn(T, n) + α = randn(T) + β = randn(T) + for Tα in (Const, Active), Tβ in (Const, Active) + test_reverse(add, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(add!!, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + end + for Tα in (Const, Duplicated), Tβ in (Const, Duplicated) + test_forward(add, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + test_forward(add!!, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + end + + # MinimalMVec + mx = MinimalMVec(x) + my = MinimalMVec(y) + for Tα in (Const, Active), Tβ in (Const, Active) + test_reverse(add, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(add!!, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + end + for Tα in (Const, Duplicated), Tβ in (Const, Duplicated) + test_forward(add, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + test_forward(add!!, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + end + + # MinimalSVec + mx = MinimalSVec(x) + my = MinimalSVec(y) + for Tα in (Const, Active), Tβ in (Const, Active) + test_reverse(add, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(add!!, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + end + for Tα in (Const, Duplicated), Tβ in (Const, Duplicated) + test_forward(add, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + test_forward(add!!, Duplicated, (y, Duplicated), (x, Duplicated), (α, Tα), (β, Tβ); atol, rtol) + end +end + +@testset "inner ($T)" for T in eltypes + n = 12 + atol = rtol = n * precision(T) + + # Vector + x = randn(T, n) + y = randn(T, n) + for RT in (Const, Active) + test_reverse(inner, RT, (x, Duplicated), (y, Duplicated); atol, rtol) + end + for RT in (Const, Duplicated) + test_forward(inner, RT, (x, Duplicated), (y, Duplicated); atol, rtol) + end + + # MinimalMVec + mx = MinimalMVec(x) + my = MinimalMVec(y) + for RT in (Const, Active) + test_reverse(inner, RT, (x, Duplicated), (y, Duplicated); atol, rtol) + end + for RT in (Const, Duplicated) + test_forward(inner, RT, (x, Duplicated), (y, Duplicated); atol, rtol) + end + + # MinimalSVec + mx = MinimalSVec(x) + my = MinimalSVec(y) + for RT in (Const, Active) + test_reverse(inner, RT, (x, Duplicated), (y, Duplicated); atol, rtol) + end + for RT in (Const, Duplicated) + test_forward(inner, RT, (x, Duplicated), (y, Duplicated); atol, rtol) + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 2e53df6..010ce8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using VectorInterface +using VectorInterface: MinimalVec using Test println("Testing One and Zero") println("====================") @@ -35,6 +36,7 @@ module AquaVectorInterface Aqua.test_all(VectorInterface) end +Base.collect(x::MinimalVec) = x.vec @static if isdefined(Base, :get_extension) && isempty(VERSION.prerelease) println("Testing AD rules") println("================") @@ -44,4 +46,7 @@ end println("Testing Mooncake") println("==================") include("mooncake.jl") + println("Testing Enzyme") + println("==================") + include("enzyme.jl") end