Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[sources.ReactantCore]
path = "lib/ReactantCore"

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
Expand Down Expand Up @@ -74,3 +71,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[sources.ReactantCore]
path = "lib/ReactantCore"
8 changes: 6 additions & 2 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,12 @@ for (cT, aT, bT) in (
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)
A, B = aos_to_soa(A), aos_to_soa(B)
if use_overlayed_version((C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
C2 = aos_to_soa(C)
if use_overlayed_version((C2, A, B))
TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β)
if C2 !== C
C .= C2
end
else
LinearAlgebra.mul!(C, A, B, α, β)
end
Expand Down
20 changes: 20 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using ..Reactant:
Ops,
MLIR,
ancestor,
allowscalar,
aos_to_soa,
unwrapped_eltype
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array

Expand All @@ -29,6 +31,9 @@ function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N}
end
x isa WrappedTracedRArray &&
return convert(TracedRArray{T,N}, materialize_traced_array(x))
if eltype(x) <: TracedRNumber
return convert(TracedRArray{T,N}, aos_to_soa(x))
end
return convert(TracedRArray{T,N}, Ops.constant(collect(x)))
end

Expand Down Expand Up @@ -460,6 +465,21 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
return dest
end

function _copyto!(dest::AbstractArray{<:TracedRNumber}, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

bc = Broadcast.preprocess(dest, bc)

args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = TracedUtils.elem_apply(bc.f, args...)
for I in 1:length(dest)
dest[I] = Reactant.@allowscalar res[I]
end
return dest
end

dispatch_val(x) = x
dispatch_val(::Val{D}) where {D} = D

Expand Down
34 changes: 32 additions & 2 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ..Reactant:
AnyTracedRArray,
AnyTracedRMatrix,
AnyTracedRVector,
unwrapped_eltype,
Ops,
MLIR

Expand Down Expand Up @@ -190,12 +191,12 @@ function overloaded_mul!(
end

function overloaded_mul!(
@nospecialize(C::TracedRArray{T,2}),
@nospecialize(C::TracedRArray{T,2} where {T}),
@nospecialize(A::AnyTracedRMatrix),
@nospecialize(B::AnyTracedRMatrix),
α::Number=true,
β::Number=false,
) where {T}
)
if size(C) != (size(A, 1), size(B, 2))
throw(
DimensionMismatch(
Expand All @@ -207,6 +208,7 @@ function overloaded_mul!(
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))"))
end

T = unwrapped_eltype(C)
tmp = Ops.dot_general(
T.(materialize_traced_array(A)),
T.(materialize_traced_array(B));
Expand Down Expand Up @@ -317,4 +319,32 @@ function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
return indices
end

function LinearAlgebra.ldiv!(
B::Union{
AbstractArray{<:TracedRNumber{T},1},
AbstractArray{<:TracedRNumber{T},2},
AnyTracedRArray{T,1},
AnyTracedRArray{T,2},
},
D::Diagonal,
A::AbstractVecOrMat,
) where {T}
LinearAlgebra.require_one_based_indexing(A, B)
dd = D.diag
d = length(dd)
m, n = size(A, 1), size(A, 2)
m′, n′ = size(B, 1), size(B, 2)
m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d"))
(m, n) == (m′, n′) ||
throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
B .= dd .\ A
# OG implementation below, we don't currently support the conditional throw exception
#j = findfirst(iszero, D.diag)
#isnothing(j) || throw(SingularException(j))
#@inbounds for j = 1:n, i = 1:m
# B[i, j] = dd[i] \ A[i, j]
#end
return B
end

end
Loading