diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl index 57c26b8..124583f 100644 --- a/ext/VectorInterfaceMooncakeExt.jl +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -24,7 +24,7 @@ _needs_tangent(::Type{T}) where {T <: Number} = # scale # ----- @is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, Number} -function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}) +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, α_Δα::CoDual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) α = primal(α_Δα) @@ -43,7 +43,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra return C_ΔC, scale_pullback end -function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}) +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, α_Δα::Dual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) α, Δα = extract(α_Δα) @@ -60,7 +60,7 @@ end @is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, AbstractArray, Number} -function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}) +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) @@ -81,7 +81,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra return C_ΔC, scale_pullback end -function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}) +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) @@ -98,7 +98,7 @@ end @is_primitive DefaultCtx Tuple{typeof(add!), AbstractArray, AbstractArray, Number, Number} -function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) +function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) @@ -123,7 +123,7 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray} return C_ΔC, add_pullback end -function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) +function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) @@ -142,7 +142,7 @@ end @is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray} -function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray}, B_ΔB::CoDual{<:AbstractArray}) +function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual) # prepare arguments A, ΔA = arrayify(A_ΔA) B, ΔB = arrayify(B_ΔB) @@ -159,7 +159,7 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray return CoDual(s, NoFData()), inner_pullback end -function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractArray}, B_ΔB::Dual{<:AbstractArray}) +function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual, B_ΔB::Dual) # prepare arguments A, ΔA = arrayify(A_ΔA) B, ΔB = arrayify(B_ΔB) diff --git a/test/mooncake.jl b/test/mooncake.jl index d8466d1..3f44ca3 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -4,6 +4,7 @@ using VectorInterface using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec using Test, TestExtras using Mooncake +import Mooncake: arrayify using Random rng = Random.default_rng() @@ -11,6 +12,13 @@ rng = Random.default_rng() precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32)) precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = sqrt(eps(Float64)) +function Mooncake.arrayify(A_dA::Mooncake.CoDual{<:MinimalVec}) + return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).data.vec) +end +function Mooncake.arrayify(A_dA::Mooncake.Dual{<:MinimalVec}) + return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).fields.vec) +end + eltypes = (Float32, Float64, ComplexF64) @testset "scale ($T)" for T in eltypes