# Patch summary (leading text before the first `diff --git` header; patch tools skip it).
#
# Project.toml
#   - Moves the `[sources.ReactantCore]` table (`path = "lib/ReactantCore"`) from just
#     before `[extensions]` to after the `PythonCall` entry (no trailing context is shown
#     in that hunk, so this is presumably the end of the file -- confirm).
#
# src/Overlay.jl
#   - In the method generated inside the `for (cT, aT, bT) in (` loop (falls back to
#     `LinearAlgebra.mul!`, so presumably the 5-argument `mul!` overlay -- confirm):
#     `C` is now also passed through `aos_to_soa` (as `C2`); `use_overlayed_version` and
#     `TracedLinearAlgebra.overloaded_mul!` receive `C2`, and the result is copied back
#     with `C .= C2` only when `C2 !== C`. The fallback branch still writes into `C`.
#
# src/TracedRArray.jl
#   - Adds `allowscalar` and `aos_to_soa` to the `using ..Reactant:` list.
#     NOTE(review): the visible new code uses the macro `Reactant.@allowscalar`; confirm
#     the plain `allowscalar` import is actually needed.
#   - `Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray)`: when
#     `eltype(x) <: TracedRNumber`, converts `aos_to_soa(x)` instead of reaching
#     `Ops.constant(collect(x))`.
#   - New `_copyto!(dest::AbstractArray{<:TracedRNumber}, bc::Broadcasted)`: checks that
#     `axes(dest) == axes(bc)` (else `Broadcast.throwdm`), returns early for empty `dest`,
#     preprocesses `bc`, broadcasts every materialized argument to `size(bc)` with
#     `TracedUtils.broadcast_to_size`, applies `bc.f` through `TracedUtils.elem_apply`,
#     then stores `res[I]` into `dest[I]` one element at a time under `@allowscalar`.
#     NOTE(review): the loop runs over `1:length(dest)`; `eachindex(dest)` would be the
#     idiomatic spelling if `res` supports the same indices -- confirm before changing.
#
# src/stdlibs/LinearAlgebra.jl
#   - Adds `unwrapped_eltype` to the `using ..Reactant:` list.
#   - `overloaded_mul!` for `C::TracedRArray{T,2}`: the method-level `where {T}` is
#     replaced by `TracedRArray{T,2} where {T}` inside `@nospecialize`, and `T` is now
#     obtained in the body via `T = unwrapped_eltype(C)` before `Ops.dot_general`.
#   - New `LinearAlgebra.ldiv!(B, D::Diagonal, A::AbstractVecOrMat)` for 1- or 2-d `B`
#     that is either an `AbstractArray` of `TracedRNumber{T}` or an `AnyTracedRArray{T}`:
#     requires one-based indexing, throws `DimensionMismatch` when the row count of `A`
#     differs from `length(D.diag)` or when `B` and `A` differ in size, then computes
#     `B .= dd .\ A` and returns `B`.
#     NOTE(review): the singular-diagonal check (`SingularException`) of the reference
#     implementation is left commented out, so a zero on the diagonal is not reported.
#
diff --git a/Project.toml b/Project.toml index d92734c75f..df4addc473 100644 --- a/Project.toml +++ b/Project.toml @@ -30,9 +30,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -[sources.ReactantCore] -path = "lib/ReactantCore" - [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" @@ -74,3 +71,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" + +[sources.ReactantCore] +path = "lib/ReactantCore" diff --git a/src/Overlay.jl b/src/Overlay.jl index b00c45f63c..fc035f94f7 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -125,8 +125,12 @@ for (cT, aT, bT) in ( C::$cT, A::$aT, B::$bT, α::Number, β::Number ) A, B = aos_to_soa(A), aos_to_soa(B) - if use_overlayed_version((C, A, B)) - TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) + C2 = aos_to_soa(C) + if use_overlayed_version((C2, A, B)) + TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β) + if C2 !== C + C .= C2 + end else LinearAlgebra.mul!(C, A, B, α, β) end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index c5abb13612..8fb56623c3 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -13,6 +13,8 @@ using ..Reactant: Ops, MLIR, ancestor, + allowscalar, + aos_to_soa, unwrapped_eltype using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array @@ -29,6 +31,9 @@ function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N} end x isa WrappedTracedRArray && return convert(TracedRArray{T,N}, materialize_traced_array(x)) + if eltype(x) <: TracedRNumber + return convert(TracedRArray{T,N}, aos_to_soa(x)) + end return convert(TracedRArray{T,N}, Ops.constant(collect(x))) end @@ -460,6 +465,21 @@ function _copyto!(dest::AnyTracedRArray, 
bc::Broadcasted) return dest end +function _copyto!(dest::AbstractArray{<:TracedRNumber}, bc::Broadcasted) + axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) + isempty(dest) && return dest + + bc = Broadcast.preprocess(dest, bc) + + args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) + + res = TracedUtils.elem_apply(bc.f, args...) + for I in 1:length(dest) + dest[I] = Reactant.@allowscalar res[I] + end + return dest +end + dispatch_val(x) = x dispatch_val(::Val{D}) where {D} = D diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 2294a1cdec..7b2c5bbf5d 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -6,6 +6,7 @@ using ..Reactant: AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector, + unwrapped_eltype, Ops, MLIR @@ -190,12 +191,12 @@ function overloaded_mul!( end function overloaded_mul!( - @nospecialize(C::TracedRArray{T,2}), + @nospecialize(C::TracedRArray{T,2} where {T}), @nospecialize(A::AnyTracedRMatrix), @nospecialize(B::AnyTracedRMatrix), α::Number=true, β::Number=false, -) where {T} +) if size(C) != (size(A, 1), size(B, 2)) throw( DimensionMismatch( @@ -207,6 +208,7 @@ function overloaded_mul!( throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))")) end + T = unwrapped_eltype(C) tmp = Ops.dot_general( T.(materialize_traced_array(A)), T.(materialize_traced_array(B)); @@ -317,4 +319,32 @@ function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0) return indices end +function LinearAlgebra.ldiv!( + B::Union{ + AbstractArray{<:TracedRNumber{T},1}, + AbstractArray{<:TracedRNumber{T},2}, + AnyTracedRArray{T,1}, + AnyTracedRArray{T,2}, + }, + D::Diagonal, + A::AbstractVecOrMat, +) where {T} + LinearAlgebra.require_one_based_indexing(A, B) + dd = D.diag + d = length(dd) + m, n = size(A, 1), size(A, 2) + m′, n′ = size(B, 1), size(B, 2) + m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d")) + 
(m, n) == (m′, n′) || + throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′")) + B .= dd .\ A + # OG implementation below, we don't currently support the conditional throw exception + #j = findfirst(iszero, D.diag) + #isnothing(j) || throw(SingularException(j)) + #@inbounds for j = 1:n, i = 1:m + # B[i, j] = dd[i] \ A[i, j] + #end + return B +end + end