diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f23799fd5b..5ffc60c165 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1109,6 +1109,19 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector) return accumulate!(op, A, B; dims=1) end +if isdefined(Base, :_accumulate_promote_op) + function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T} + if init !== nothing + init isa TracedRNumber && (init = zero(unwrapped_eltype(init))) + end + return TracedRNumber{ + unwrapped_eltype( + Base._accumulate_promote_op(op, Array{T,ndims(A)}(undef, size(A)); init) + ), + } + end +end + function Base._accumulate!( op, output::AnyTracedRArray, input::AnyTracedRVector, ::Nothing, ::Nothing ) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index d4820c9966..23e8cb4390 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -273,26 +273,56 @@ function overloaded_mul!( return C end -function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} +if isdefined(LinearAlgebra, :_triu) + function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T} + return overloaded_triu(materialize_traced_array(A), k) + end + function LinearAlgebra._triu( + A::AnyTracedRArray{T,2}, ::Val{false}, k::Integer + ) where {T} + return overloaded_triu(materialize_traced_array(A), k) + end +end + +if isdefined(LinearAlgebra, :_tril) + function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T} + return overloaded_tril(materialize_traced_array(A), k) + end + function LinearAlgebra._tril( + A::AnyTracedRArray{T,2}, ::Val{false}, k::Integer + ) where {T} + return overloaded_tril(materialize_traced_array(A), k) + end +end + +function LinearAlgebra.triu!(X::AnyTracedRArray{T,2}, k::Integer) where {T} + set_mlir_data!(X, get_mlir_data(overloaded_triu(materialize_traced_array(X), k))) + return X +end + +function LinearAlgebra.tril!(X::AnyTracedRArray{T,2}, k::Integer) where {T} + set_mlir_data!(X, get_mlir_data(overloaded_tril(materialize_traced_array(X), k))) + return X +end + +function overloaded_triu(X::TracedRArray{T,2}, k::Integer) where {T} iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = @opcall subtract( @opcall(iota(Int64, [size(X)...]; iota_dimension=2)), Reactant.broadcast_to_size(k, size(X)), ) idxs = @opcall compare(iota_1, iota_2; comparison_direction="LE") - X.mlir_data = @opcall(select(idxs, X, zero(X))).mlir_data - return X + return @opcall select(idxs, X, zero(X)) end -function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} +function overloaded_tril(X::TracedRArray{T,2}, k::Integer) where {T} iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = @opcall subtract( @opcall(iota(Int64, [size(X)...]; iota_dimension=2)), Reactant.broadcast_to_size(k, size(X)), ) idxs = @opcall compare(iota_1, iota_2; comparison_direction="GE") - X.mlir_data = @opcall(select(idxs, X, zero(X))).mlir_data - return X + return @opcall select(idxs, X, zero(X)) end # LinearAlgebra defines norm with some conditionals which cannot be traced directly diff --git a/test/runtests.jl b/test/runtests.jl index f812deee5c..bc10705e43 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,12 +53,16 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") - @safetestset "Zygote" include("integration/zygote.jl") @safetestset "MPI" begin using MPI nranks = 2 run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) end + + # Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580 + if VERSION < v"1.12-" + @safetestset "Zygote" include("integration/zygote.jl") + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"