Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,17 @@ end
if use_overlayed_version(A)
return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...)
else
return Base.inferencebarrier(Base.mapreduce)(f, op, A; kwargs...)
return Base.inferencebarrier(Base.mapreduce)(
CallWithReactant(f), CallWithReactant(op), A; kwargs...
)
end
end

@reactant_overlay @noinline function Base.map(f, x::AbstractArray, ys::AbstractArray...)
if use_overlayed_version(x) || any(use_overlayed_version, ys)
return TracedRArrayOverrides.overloaded_map(f, x, ys...)
else
return Base.inferencebarrier(Base.map)(f, x, ys...)
return Base.inferencebarrier(Base.map)(CallWithReactant(f), x, ys...)
end
end

Expand All @@ -183,22 +185,22 @@ end
)
return TracedRArrayOverrides.overloaded_map!(f, y, x, xs...)
else
return Base.inferencebarrier(Base.map!)(f, y, x, xs...)
return Base.inferencebarrier(Base.map!)(CallWithReactant(f), y, x, xs...)
end
end

@reactant_overlay @noinline function Base._all(f, x::AbstractArray, dims)
if use_overlayed_version(x)
return TracedRArrayOverrides.overloaded_mapreduce(f, &, x; dims)
else
return Base.inferencebarrier(Base._all)(f, x, dims)
return Base.inferencebarrier(Base._all)(CallWithReactant(f), x, dims)
end
end

@reactant_overlay @noinline function Base._any(f, x::AbstractArray, dims)
if use_overlayed_version(x)
return TracedRArrayOverrides.overloaded_mapreduce(f, |, x; dims)
else
return Base.inferencebarrier(Base._any)(f, x, dims)
return Base.inferencebarrier(Base._any)(CallWithReactant(f), x, dims)
end
end
4 changes: 2 additions & 2 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ function overloaded_mapreduce(
end
reduce_init = TracedUtils.promote_to(TracedRNumber{op_in_T}, reduce_init)

reduce_input = materialize_traced_array(broadcast(f, A))
reduce_input = materialize_traced_array(TracedUtils.elem_apply(f, A))

res = @opcall reduce(reduce_input, reduce_init, dims, op)

Expand Down Expand Up @@ -651,7 +651,7 @@ function Base.mapreducedim!(
return i
end
tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims))
R .= op.(R, tmp) # match native Julia's behavior
copyto!(R, op.(R, tmp))
return R
end

Expand Down
2 changes: 1 addition & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
scalar_args = map(args) do arg
return promote_to(TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg)
end
return f(scalar_args...)
return Reactant.call_with_reactant(f, scalar_args...)
end

argprefix::Symbol = gensym("broadcastarg")
Expand Down
16 changes: 16 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
struct CallWithReactant{F}
f::F
end

function Base.reducedim_init(f::F, op::CallWithReactant, A::AbstractArray, region) where {F}
return Base.reducedim_init(f, op.f, A, region)
end

function (f::CallWithReactant{F})(args...; kwargs...) where {F}
if isempty(kwargs)
return call_with_reactant(f.f, args...)
else
return call_with_reactant(Core.kwcall, NamedTuple(kwargs), f.f, args...)
end
end

function apply(f::F, args...; kwargs...) where {F}
return f(args...; kwargs...)
end
Expand Down
11 changes: 11 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1559,3 +1559,14 @@ end
hlo = repr(@code_hlo(repeat(x_ra, 2, 3)))
@test !contains(hlo, "stablehlo.dynamic_update_slice")
end

@testset "call through inference barrier" begin
points = [rand(Float32, 2), rand(Float32, 2)]
params = rand(Float32, 4, 2)
points_ra = Reactant.to_rarray(points)
params_ra = Reactant.to_rarray(params)

f(params, points) = mapreduce(Base.Fix1(*, params), +, points)

@test @jit(f(params_ra, points_ra)) ≈ f(params, points)
end
Loading