diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl index f23b4b3a3f..cb43c436e6 100644 --- a/ext/ReactantArrayInterfaceExt.jl +++ b/ext/ReactantArrayInterfaceExt.jl @@ -7,15 +7,9 @@ using Reactant: ArrayInterface.can_setindex(::Type{<:RArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false -function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} - x_c = ConcreteRArray(zeros(T, size(x))) - x_c .= x - return x_c -end - -ArrayInterface.aos_to_soa(x::TracedRArray) = x -function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} - return Ops.reshape(vcat(x...), size(x)...) +for aType in + (AbstractArray{<:ConcreteRNumber}, AbstractArray{<:TracedRNumber}, TracedRArray) + @eval ArrayInterface.aos_to_soa(x::$aType) = Reactant.aos_to_soa(x) end end diff --git a/src/Overlay.jl b/src/Overlay.jl index bc8c62bbef..d2e1ce3f58 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -127,6 +127,7 @@ for (cT, aT, bT) in ( @reactant_overlay @noinline function LinearAlgebra.mul!( C::$cT, A::$aT, B::$bT, α::Number, β::Number ) + A, B = aos_to_soa(A), aos_to_soa(B) if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) else @@ -137,11 +138,7 @@ for (cT, aT, bT) in ( # Needed mostly for 1.10 where 3-arg mul is often specialized @reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT) - if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) - TracedLinearAlgebra.overloaded_mul!(C, A, B, true, false) - else - LinearAlgebra.mul!(C, A, B) - end + call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false) return C end end diff --git a/src/Reactant.jl b/src/Reactant.jl index d06784c136..addf7089b8 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -172,6 +172,22 @@ unwrapped_eltype(::RArray{T,N}) where {T,N} = T unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T) unwrapped_eltype(::AnyTracedRArray{T,N}) where {T,N} = T +aos_to_soa(x::AbstractArray) = x +aos_to_soa(x::TracedRArray) = x +function aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} + x_c = ConcreteRArray(zeros(T, size(x))) + x_c .= x + return x_c +end +function aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} + for i in eachindex(x) + if !isassigned(x, i) + x[i] = TracedUtils.promote_to(TracedRNumber{T}, 0) + end + end + return Ops.reshape(vcat(x...), size(x)...) +end + include("Ops.jl") include("TracedUtils.jl") diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index c03df03c03..aa56c7b92c 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -163,26 +163,6 @@ function TracedUtils.set_mlir_data!( return x end -function overloaded_mul!( - @nospecialize(C::TracedRArray), - @nospecialize(A::AbstractArray{<:TracedRNumber}) - @nospecialize(B::AbstractArray), - α::Number=true, - β::Number=false, -) where {T} - overloaded_mul!(C, Ops.reshape(vcat(A...), size(A)...), B, α, β) -end - -function overloaded_mul!( - @nospecialize(C::TracedRArray), - @nospecialize(A::AbstractArray), - @nospecialize(B::AbstractArray{<:TracedRNumber}) - α::Number=true, - β::Number=false, -) where {T} - overloaded_mul!(C, A, Ops.reshape(vcat(B...), size(B)...), α, β) -end - # Core functions function overloaded_mul!( @nospecialize(C::TracedRArray{T,1}),