From d067f2fbfaa5dddc8d5a9817c51d98680be27db2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Sep 2025 15:56:50 -0400 Subject: [PATCH 1/3] fix: force call with reactant if using inferencebarrier --- src/Overlay.jl | 12 +++++++----- src/TracedRArray.jl | 8 ++++---- src/TracedUtils.jl | 2 +- src/utils.jl | 12 ++++++++++++ test/basic.jl | 11 +++++++++++ 5 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 1f8d2a0c3d..9091eaeb46 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -161,7 +161,9 @@ end if use_overlayed_version(A) return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...) else - return Base.inferencebarrier(Base.mapreduce)(f, op, A; kwargs...) + return Base.inferencebarrier(Base.mapreduce)( + CallWithReactant(f), CallWithReactant(op), A; kwargs... + ) end end @@ -169,7 +171,7 @@ end if use_overlayed_version(x) || any(use_overlayed_version, ys) return TracedRArrayOverrides.overloaded_map(f, x, ys...) else - return Base.inferencebarrier(Base.map)(f, x, ys...) + return Base.inferencebarrier(Base.map)(CallWithReactant(f), x, ys...) end end @@ -183,7 +185,7 @@ end ) return TracedRArrayOverrides.overloaded_map!(f, y, x, xs...) else - return Base.inferencebarrier(Base.map!)(f, y, x, xs...) + return Base.inferencebarrier(Base.map!)(CallWithReactant(f), y, x, xs...) end end @@ -191,7 +193,7 @@ end if use_overlayed_version(x) return TracedRArrayOverrides.overloaded_mapreduce(f, &, x; dims) else - return Base.inferencebarrier(Base._all)(f, x, dims) + return Base.inferencebarrier(Base._all)(CallWithReactant(f), x, dims) end end @@ -199,6 +201,6 @@ end if use_overlayed_version(x) return TracedRArrayOverrides.overloaded_mapreduce(f, |, x; dims) else - return Base.inferencebarrier(Base._any)(f, x, dims) + return Base.inferencebarrier(Base._any)(CallWithReactant(f), x, dims) end end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index e0bf9fa383..3ed3803e58 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -622,11 +622,11 @@ function overloaded_mapreduce( end reduce_init = TracedUtils.promote_to(TracedRNumber{op_in_T}, reduce_init) - reduce_input = materialize_traced_array(broadcast(f, A)) + reduce_input = materialize_traced_array(TracedUtils.elem_apply(f, A)) res = @opcall reduce(reduce_input, reduce_init, dims, op) - init !== nothing && (res = op.(res, init)) + init !== nothing && (res = Reactant.call_with_reactant(broadcast, op, res, init)) if original_dims isa Colon @assert size(res) == () "expected size of result to be (), got $(size(res))" @@ -650,8 +650,8 @@ function Base.mapreducedim!( @assert sR == 1 return i end - tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims)) - R .= op.(R, tmp) # match native Julia's behavior + tmp = overloaded_mapreduce(f, op, A; dims=filter(!isnothing, dims), init=R) + copyto!(R, tmp) return R end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 288d4e069f..63d3c73160 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -1074,7 +1074,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} scalar_args = map(args) do arg return promote_to(TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg) end - return f(scalar_args...) + return Reactant.call_with_reactant(f, scalar_args...) end argprefix::Symbol = gensym("broadcastarg") diff --git a/src/utils.jl b/src/utils.jl index 3aaa2a8576..b4dcaac87a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,15 @@ +struct CallWithReactant{F} + f::F +end + +function (f::CallWithReactant{F})(args...; kwargs...) where {F} + if isempty(kwargs) + return call_with_reactant(f.f, args...) + else + return call_with_reactant(Core.kwcall, NamedTuple(kwargs), f.f, args...) + end +end + function apply(f::F, args...; kwargs...) where {F} return f(args...; kwargs...) end diff --git a/test/basic.jl b/test/basic.jl index 29c2e80a71..8c6f791e8b 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1559,3 +1559,14 @@ end hlo = repr(@code_hlo(repeat(x_ra, 2, 3))) @test !contains(hlo, "stablehlo.dynamic_update_slice") end + +@testset "call through inference barrier" begin + points = [rand(Float32, 2), rand(Float32, 2)] + params = rand(Float32, 4, 2) + points_ra = Reactant.to_rarray(points) + params_ra = Reactant.to_rarray(params) + + f(params, points) = mapreduce(Base.Fix1(*, params), +, points) + + @test @jit(f(params_ra, points_ra)) ≈ f(params, points) +end From bffdf48b9bfced497f15d289213c8231097b9e78 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Sep 2025 17:53:30 -0400 Subject: [PATCH 2/3] fix: broadcast --- src/TracedRArray.jl | 2 +- src/utils.jl | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 3ed3803e58..a99a68fe03 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -626,7 +626,7 @@ function overloaded_mapreduce( res = @opcall reduce(reduce_input, reduce_init, dims, op) - init !== nothing && (res = Reactant.call_with_reactant(broadcast, op, res, init)) + init !== nothing && (res = op.(res, init)) if original_dims isa Colon @assert size(res) == () "expected size of result to be (), got $(size(res))" diff --git a/src/utils.jl b/src/utils.jl index b4dcaac87a..0081c4be4b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,10 @@ struct CallWithReactant{F} f::F end +function Base.reducedim_init(f::F, op::CallWithReactant, A::AbstractArray, region) where {F} + return Base.reducedim_init(f, op.f, A, region) +end + function (f::CallWithReactant{F})(args...; kwargs...) where {F} if isempty(kwargs) return call_with_reactant(f.f, args...) From 47888964de41d58a98c00a896ec5a5ecac448c68 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Sep 2025 19:20:23 -0400 Subject: [PATCH 3/3] fix: mapreducedim! --- src/TracedRArray.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a99a68fe03..28744ceb8b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -650,8 +650,8 @@ function Base.mapreducedim!( @assert sR == 1 return i end - tmp = overloaded_mapreduce(f, op, A; dims=filter(!isnothing, dims), init=R) - copyto!(R, tmp) + tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims)) + copyto!(R, op.(R, tmp)) return R end