diff --git a/src/Overlay.jl b/src/Overlay.jl index 43c24ae817..1f8d2a0c3d 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -165,6 +165,28 @@ end end end +@reactant_overlay @noinline function Base.map(f, x::AbstractArray, ys::AbstractArray...) + if use_overlayed_version(x) || any(use_overlayed_version, ys) + return TracedRArrayOverrides.overloaded_map(f, x, ys...) + else + return Base.inferencebarrier(Base.map)(f, x, ys...) + end +end + +@reactant_overlay @noinline function Base.map!( + f, y::AbstractArray, x::AbstractArray, xs::AbstractArray... +) + if ( + use_overlayed_version(y) || + use_overlayed_version(x) || + any(use_overlayed_version, xs) + ) + return TracedRArrayOverrides.overloaded_map!(f, y, x, xs...) + else + return Base.inferencebarrier(Base.map!)(f, y, x, xs...) + end +end + @reactant_overlay @noinline function Base._all(f, x::AbstractArray, dims) if use_overlayed_version(x) return TracedRArrayOverrides.overloaded_mapreduce(f, &, x; dims) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index af0a7aac62..226e3994c2 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1249,10 +1249,25 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin return (values, linear_indices) end -Base.map(f, x::AnyTracedRArray) = f.(x) +function overloaded_map(f, x::AbstractArray, xs::AbstractArray...) + @assert allequal((axes(x), axes.(xs)...)) "Expected axes of all inputs to map to be \ + equal" + + inputs = () + for input in (x, xs...) + if input isa AnyTracedRArray + input = Reactant.materialize_traced_array(input) + else + input = TracedUtils.promote_to(TracedRArray{eltype(input),ndims(input)}, input) + end + inputs = (inputs..., input) + end + + return TracedUtils.elem_apply(f, inputs...) +end -function Base.map!(f, y::AnyTracedRArray, x::AbstractArray) - y .= f.(x) +function overloaded_map!(f, y::AnyTracedRArray, x::AbstractArray, xs::AbstractArray...) + copyto!(y, overloaded_map(f, x, xs...)) return y end diff --git a/src/utils.jl b/src/utils.jl index bed592af38..3aaa2a8576 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,4 +1,3 @@ - function apply(f::F, args...; kwargs...) where {F} return f(args...; kwargs...) end diff --git a/test/basic.jl b/test/basic.jl index 8359b053de..11a2c2b91c 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1532,3 +1532,23 @@ end @test res == mod1(xᵢ, y) end end + +map_test_1(i, xᵢ, yᵢ) = xᵢ + yᵢ + max(xᵢ, yᵢ) + +@testset "multi-argument map" begin + x = collect(Float32, 1:10) + y = collect(Float32, 31:40) + + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + gt = map(map_test_1, 1:length(x), x, y) + @test @jit(map(map_test_1, 1:length(x), x_ra, y_ra)) ≈ gt + + z = similar(x) + z_ra = Reactant.to_rarray(z) + map!(map_test_1, z, 1:length(x), x, y) + @jit map!(map_test_1, z_ra, 1:length(x), x_ra, y_ra) + @test z ≈ z_ra + @test z_ra ≈ gt +end