diff --git a/src/Overlay.jl b/src/Overlay.jl index 1f8d2a0c3d..9091eaeb46 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -161,7 +161,9 @@ end if use_overlayed_version(A) return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...) else - return Base.inferencebarrier(Base.mapreduce)(f, op, A; kwargs...) + return Base.inferencebarrier(Base.mapreduce)( + CallWithReactant(f), CallWithReactant(op), A; kwargs... + ) end end @@ -169,7 +171,7 @@ end if use_overlayed_version(x) || any(use_overlayed_version, ys) return TracedRArrayOverrides.overloaded_map(f, x, ys...) else - return Base.inferencebarrier(Base.map)(f, x, ys...) + return Base.inferencebarrier(Base.map)(CallWithReactant(f), x, ys...) end end @@ -183,7 +185,7 @@ end ) return TracedRArrayOverrides.overloaded_map!(f, y, x, xs...) else - return Base.inferencebarrier(Base.map!)(f, y, x, xs...) + return Base.inferencebarrier(Base.map!)(CallWithReactant(f), y, x, xs...) end end @@ -191,7 +193,7 @@ end if use_overlayed_version(x) return TracedRArrayOverrides.overloaded_mapreduce(f, &, x; dims) else - return Base.inferencebarrier(Base._all)(f, x, dims) + return Base.inferencebarrier(Base._all)(CallWithReactant(f), x, dims) end end @@ -199,6 +201,6 @@ end if use_overlayed_version(x) return TracedRArrayOverrides.overloaded_mapreduce(f, |, x; dims) else - return Base.inferencebarrier(Base._any)(f, x, dims) + return Base.inferencebarrier(Base._any)(CallWithReactant(f), x, dims) end end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index e0bf9fa383..28744ceb8b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -622,7 +622,7 @@ function overloaded_mapreduce( end reduce_init = TracedUtils.promote_to(TracedRNumber{op_in_T}, reduce_init) - reduce_input = materialize_traced_array(broadcast(f, A)) + reduce_input = materialize_traced_array(TracedUtils.elem_apply(f, A)) res = @opcall reduce(reduce_input, reduce_init, dims, op) @@ -651,7 +651,7 @@ function Base.mapreducedim!( return i end tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims)) - R .= op.(R, tmp) # match native Julia's behavior + copyto!(R, op.(R, tmp)) return R end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 288d4e069f..63d3c73160 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -1074,7 +1074,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} scalar_args = map(args) do arg return promote_to(TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg) end - return f(scalar_args...) + return Reactant.call_with_reactant(f, scalar_args...) end argprefix::Symbol = gensym("broadcastarg") diff --git a/src/utils.jl b/src/utils.jl index 3aaa2a8576..0081c4be4b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,19 @@ +struct CallWithReactant{F} + f::F +end + +function Base.reducedim_init(f::F, op::CallWithReactant, A::AbstractArray, region) where {F} + return Base.reducedim_init(f, op.f, A, region) +end + +function (f::CallWithReactant{F})(args...; kwargs...) where {F} + if isempty(kwargs) + return call_with_reactant(f.f, args...) + else + return call_with_reactant(Core.kwcall, NamedTuple(kwargs), f.f, args...) + end +end + function apply(f::F, args...; kwargs...) where {F} return f(args...; kwargs...) end diff --git a/test/basic.jl b/test/basic.jl index 29c2e80a71..8c6f791e8b 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1559,3 +1559,14 @@ end hlo = repr(@code_hlo(repeat(x_ra, 2, 3))) @test !contains(hlo, "stablehlo.dynamic_update_slice") end + +@testset "call through inference barrier" begin + points = [rand(Float32, 2), rand(Float32, 2)] + params = rand(Float32, 4, 2) + points_ra = Reactant.to_rarray(points) + params_ra = Reactant.to_rarray(params) + + f(params, points) = mapreduce(Base.Fix1(*, params), +, points) + + @test @jit(f(params_ra, points_ra)) ≈ f(params, points) +end