From 4699e5a8b1c574b31a9c1462753314ca270bf43d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 12:14:47 -0500 Subject: [PATCH 1/4] fix: add _accumulate_promote_op --- src/TracedRArray.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f23799fd5b..5ffc60c165 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1109,6 +1109,19 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector) return accumulate!(op, A, B; dims=1) end +if isdefined(Base, :_accumulate_promote_op) + function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T} + if init !== nothing + init isa TracedRNumber && (init = zero(unwrapped_eltype(init))) + end + return TracedRNumber{ + unwrapped_eltype( + Base._accumulate_promote_op(op, Array{T,ndims(A)}(undef, size(A)); init) + ), + } + end +end + function Base._accumulate!( op, output::AnyTracedRArray, input::AnyTracedRVector, ::Nothing, ::Nothing ) From 93378540a20087b8eb39189291b456b09f2b2bc2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 12:15:30 -0500 Subject: [PATCH 2/4] test: disable Zygote on 1.12 --- test/runtests.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index f812deee5c..bc10705e43 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,12 +53,16 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") - @safetestset "Zygote" include("integration/zygote.jl") @safetestset "MPI" begin using MPI nranks = 2 run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) end + + # Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580 + if VERSION < v"1.12-" + @safetestset "Zygote" include("integration/zygote.jl") + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From ebc9d1a363c248d526e04fa5ece857285a622248 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 12:37:39 -0500 Subject: [PATCH 3/4] fix: tril/triu working again --- src/stdlibs/LinearAlgebra.jl | 42 ++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index d4820c9966..393e77abca 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -273,26 +273,56 @@ function overloaded_mul!( return C end -function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} +if isdefined(LinearAlgebra, :_triu) + function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T} + return overloaded_triu(materialize_traced_array(A), k) + end + function LinearAlgebra._triu( + A::AnyTracedRArray{T,2}, ::Val{false}, k::Integer + ) where {T} + return overloaded_triu(materialize_traced_array(A), k) + end +end + +if isdefined(LinearAlgebra, :_tril) + function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T} + return overloaded_tril(materialize_traced_array(A), k) + end + function LinearAlgebra._tril( + A::AnyTracedRArray{T,2}, ::Val{false}, k::Integer + ) where {T} + return overloaded_tril(materialize_traced_array(A), k) + end +end + +function LinearAlgebra.triu!(X::AnyTracedRArray{T,2}, k::Integer) where {T} + set_mlir_data!(X, overloaded_triu(materialize_traced_array(X), k)) + return X +end + +function LinearAlgebra.tril!(X::AnyTracedRArray{T,2}, k::Integer) where {T} + set_mlir_data!(X, overloaded_tril(materialize_traced_array(X), k)) + return X +end + +function overloaded_triu(X::TracedRArray{T,2}, k::Integer) where {T} iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = @opcall subtract( @opcall(iota(Int64, [size(X)...]; iota_dimension=2)), Reactant.broadcast_to_size(k, size(X)), ) idxs = @opcall compare(iota_1, iota_2; comparison_direction="LE") - X.mlir_data = @opcall(select(idxs, X, zero(X))).mlir_data - return X + return @opcall select(idxs, X, zero(X)) end -function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} +function overloaded_tril(X::TracedRArray{T,2}, k::Integer) where {T} iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = @opcall subtract( @opcall(iota(Int64, [size(X)...]; iota_dimension=2)), Reactant.broadcast_to_size(k, size(X)), ) idxs = @opcall compare(iota_1, iota_2; comparison_direction="GE") - X.mlir_data = @opcall(select(idxs, X, zero(X))).mlir_data - return X + return @opcall select(idxs, X, zero(X)) end # LinearAlgebra defines norm with some conditionals which cannot be traced directly From 18372a10bfddbeb6717ec345dd61f37d58207d89 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 12:53:21 -0500 Subject: [PATCH 4/4] fix: inplace versions --- src/stdlibs/LinearAlgebra.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 393e77abca..23e8cb4390 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -296,12 +296,12 @@ if isdefined(LinearAlgebra, :_tril) end function LinearAlgebra.triu!(X::AnyTracedRArray{T,2}, k::Integer) where {T} - set_mlir_data!(X, overloaded_triu(materialize_traced_array(X), k)) + set_mlir_data!(X, get_mlir_data(overloaded_triu(materialize_traced_array(X), k))) return X end function LinearAlgebra.tril!(X::AnyTracedRArray{T,2}, k::Integer) where {T} - set_mlir_data!(X, overloaded_tril(materialize_traced_array(X), k)) + set_mlir_data!(X, get_mlir_data(overloaded_tril(materialize_traced_array(X), k))) return X end