diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl
index 15ab869c6c..411542aec7 100644
--- a/src/ConcreteRArray.jl
+++ b/src/ConcreteRArray.jl
@@ -393,14 +393,6 @@ buffer_on_cpu(::Any) = true
 buffer_on_cpu(x::ConcretePJRTArray) = all(XLA.buffer_on_cpu, x.data)
 buffer_on_cpu(x::ConcreteIFRTArray) = XLA.buffer_on_cpu(x.data)
 
-function Ops.constant(x::AbstractConcreteArray; kwargs...)
-    return Ops.constant(Base.convert(Array, x); kwargs...)
-end
-
-function Ops.constant(x::AbstractConcreteNumber{T}; kwargs...) where {T}
-    return Ops.constant(Base.convert(T, x); kwargs...)
-end
-
 function Base.zero(x::ConcretePJRTArray{T,N}) where {T,N}
     return ConcretePJRTArray(
         zeros(T, size(x)...); client=XLA.client(x), device=XLA.device(x), x.sharding
@@ -464,3 +456,9 @@ function Base.mapreducedim!(
     fn(f, op, R, A)
     return R
 end
+
+function Base.map!(f, R::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, A::AbstractArray)
+    fn = compile(Base.map!, (f, R, A))
+    fn(f, R, A)
+    return R
+end
diff --git a/src/Ops.jl b/src/Ops.jl
index 2053f6647c..df1e42f27e 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -119,6 +119,16 @@ end
     end
 end
 
+@noinline function constant(
+    x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
+) where {T,N}
+    return constant(collect(x); location)
+end
+
+@noinline function constant(x::Reactant.AbstractConcreteArray; kwargs...)
+    return constant(Base.convert(Array, x); kwargs...)
+end
+
 @noinline function constant(
     x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
 ) where {T<:Number}
@@ -127,6 +137,10 @@ end
     return TracedRNumber{T}((), res.mlir_data)
 end
 
+@noinline function constant(x::Reactant.AbstractConcreteNumber{T}; kwargs...) where {T}
+    return constant(Base.convert(T, x); kwargs...)
+end
+
 function fill(
     v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
 )
@@ -374,7 +388,7 @@ end
 end
 
 # shape ops
-function reshape(x::TracedRArray, dims...; kwargs...)
+function reshape(x::TracedRArray, dims::Integer...; kwargs...)
     return reshape(x, collect(dims); kwargs...)
 end
 
@@ -2377,7 +2391,7 @@
         x::TracedRArray{T},
         init_values::TracedRNumber{T},
         dimensions::Vector{Int},
-        fn::Function,
+        fn::Function;
         location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
     )
 
@@ -2409,25 +2423,43 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
 - **CPU version & Julia's `reduce`**:
   - Reduce along dimension 1 → `[(15) (21); (18) (24)]`
   - Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]`
-
+
 - **GPU version**:
   - Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
   - Reduce along dimension 3 → `[37 49]`
 """
 @noinline function reduce(
     x::TracedRArray{T},
-    init_values::TracedRNumber{T},
+    init_values::Union{TracedRNumber{T},Nothing},
     dimensions::Vector{Int},
-    fn::Function,
+    fn::Function;
     location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
 ) where {T}
+    elT = T
+    if init_values === nothing
+        if fn === min || fn === Base.FastMath.min_fast
+            init = typemax(elT)
+        elseif fn === max || fn === Base.FastMath.max_fast
+            init = typemin(elT)
+        else
+            init = Base.reduce_empty(Base.BottomRF(fn), elT)
+        end
+
+        initT = unwrapped_eltype(typeof(init))
+        if initT != elT # Bool, etc. reductions
+            elT = promote_type(initT, elT)
+            x = elT.(x)
+        end
+        init_values = Reactant.TracedUtils.promote_to(TracedRNumber{elT}, init)
+    end
+
     reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))
 
-    result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
+    result_type = mlir_type(TracedRArray{elT,length(reduced_shape)}, reduced_shape)
 
     sample_inputs = [
-        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
-        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
+        Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
+        Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
     ]
 
     func =
@@ -2441,14 +2473,8 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
             return_dialect=:stablehlo,
         ).f
     @assert MLIR.IR.nregions(func) == 1
-    fn_name = String(
-        MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
-    )
-    ftype_attr = MLIR.IR.attr(func, "function_type")
-    ftype = MLIR.IR.Type(ftype_attr)
-    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) error (
-        "$fn return type is not tensor"
-    )
+    ftype = MLIR.IR.Type(MLIR.IR.attr(func, "function_type"))
+    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(elT)) "$fn return type is not tensor"
     fn = MLIR.IR.Region()
     MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
     MLIR.IR.rmfromparent!(func)
@@ -2466,7 +2492,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
         ),
     )
 
-    return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
+    return TracedRArray{elT,length(reduced_shape)}((), res, reduced_shape)
 end
 
 end # module Ops
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 662b9d1e9f..507b44a8a7 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -468,100 +468,29 @@ function Base.mapreduce(
     dims=:,
     init=nothing,
 ) where {T,N}
-    A = materialize_traced_array(A)
+    inp = broadcast(f, materialize_traced_array(A))
 
-    if dims isa Int
-        dims = [dims]
-    end
-
-    op_in_T = Core.Compiler.return_type(f, Tuple{T})
-
-    if init === nothing
-        if op === min
-            init = typemax(op_in_T)
-        elseif op === max
-            init = typemin(op_in_T)
-        else
-            init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
-        end
+    dims isa Number && (dims = (dims,))
 
-        if typeof(init) != op_in_T
-            op_in_T = typeof(init)
-            A = typeof(init).(A)
-        end
+    if init !== nothing && typeof(init) != unwrapped_eltype(inp)
+        inp = typeof(init).(inp)
     end
 
-    init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]
-
-    inp = [broadcast(f, A).mlir_data]
+    rdims = dims == (:) ? collect(Int64, 1:N) : collect(Int64, dims)
 
-    rdims = Int64[]
+    reduction_result = Ops.reduce(inp, nothing, rdims, op)
 
-    if dims == (:)
-        for i in 0:(N - 1)
-            push!(rdims, i)
-        end
+    reduction_result = if dims != (:)
+        Ops.reshape(reduction_result, Int64[i ∈ rdims ? 1 : size(A, i) for i in 1:N])
     else
-        for i in dims
-            push!(rdims, i - 1)
-        end
+        TracedRNumber{unwrapped_eltype(reduction_result)}((), reduction_result.mlir_data)
     end
 
-    in_tys = [
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
-    ]
-
-    fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])
-
-    args = (
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
-    )
-
-    resty = MLIR.IR.block!(fnbody) do
-        tmp = TracedUtils.broadcast_to_size(op(args...), ())
-        Ops.return_(tmp)
-        return eltype(MLIR.IR.type(tmp.mlir_data))
-    end
-
-    toonedims = Int[]
-    outdims = Int[]
-    for i in 1:N
-        tmp = if in(i - 1, rdims)
-            1
-        else
-            sz = size(A, i)
-            push!(outdims, sz)
-            sz
-        end
-        push!(toonedims, tmp)
-    end
-
-    TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]
-
-    body = MLIR.IR.Region()
-    push!(body, fnbody)
-    red = MLIR.Dialects.stablehlo.reduce(
-        inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
-    )
-
-    red = MLIR.IR.result(red, 1)
-    redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))
-
-    if dims != (:)
-        red = Ops.reshape(TracedRArray(red), toonedims...)
-    else
-        if length(outdims) == 0
-            red = TracedRNumber{redT}((), red)
-        else
-            red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
-        end
-    end
-    return red
+    init === nothing && return reduction_result
+    return broadcast(op, reduction_result, init)
 end
 
-function Base.mapreducedim!(
+function Base._mapreducedim!(
     @nospecialize(f),
     @nospecialize(op),
     @nospecialize(R::AnyTracedRArray),
@@ -573,9 +502,11 @@ function Base.mapreducedim!(
         @assert sR == 1
         return i
     end
+
+    isempty(A) && return R
+
     tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims))
-    # set_mlir_data!(R, get_mlir_data(tmp))
-    R .= op.(R, tmp) # match native Julia's behavior
+    R .= op.(R, tmp)
     return R
 end
 
@@ -1084,4 +1015,11 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
     return (values, linear_indices)
 end
 
+Base.map(f, x::AnyTracedRArray) = f.(x)
+
+function Base.map!(f, y::AnyTracedRArray, x::AbstractArray)
+    y .= f.(x)
+    return y
+end
+
 end
diff --git a/test/basic.jl b/test/basic.jl
index d94d16e4ad..de8a4109c3 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -938,3 +938,15 @@ end
         rv
     )
 end
+
+@testset "mapreduce with init" begin
+    x = reshape(collect(Float32, 1:12), 3, 4)
+    x_ra = Reactant.to_rarray(x)
+
+    init = 3.0
+    init_ra = Reactant.to_rarray(init; track_numbers=Number)
+
+    fn(x, init; kwargs...) = sum(x; init, kwargs...)
+
+    @test @jit(fn(x_ra, init_ra; dims=2)) ≈ fn(x, init; dims=2)
+end
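A standalone sketch of the neutral-element selection the new `Ops.reduce` hunk performs when `init_values === nothing`, runnable in plain Julia without Reactant; `unwrapped_eltype` is stubbed here since the real helper lives in Reactant, and `neutral_init` is a hypothetical name used only for illustration:

```julia
# Stub: Reactant's unwrapped_eltype also strips traced wrappers.
unwrapped_eltype(::Type{T}) where {T} = T

function neutral_init(fn, elT::Type)
    # Base.reduce_empty throws for min/max (empty reductions are an error),
    # so the patch special-cases them with the lattice identities.
    if fn === min || fn === Base.FastMath.min_fast
        init = typemax(elT)
    elseif fn === max || fn === Base.FastMath.max_fast
        init = typemin(elT)
    else
        init = Base.reduce_empty(Base.BottomRF(fn), elT)
    end
    # reduce_empty(+, Bool) yields an Int zero, so the element type widens;
    # this is the "# Bool, etc. reductions" promotion in the patch.
    return init, promote_type(unwrapped_eltype(typeof(init)), elT)
end

neutral_init(+, Bool)       # (0, Int): Bool sums accumulate in Int, as in Base
neutral_init(max, Float32)  # (-Inf32, Float32)
```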
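On semantics: the rewritten `Base.mapreduce` reduces with that neutral element and then folds `init` in exactly once per output slot via `broadcast(op, reduction_result, init)`, which is what Base computes when both `dims` and `init` are given. A plain-Julia check of the identity the new test exercises (no Reactant required; the data mirrors the test):

```julia
x = reshape(collect(Float32, 1:12), 3, 4)

# `init` seeds each slice reduction once, so for `+` it is equivalent to
# combining `init` elementwise with the plain reduction result.
@assert sum(x; dims=2, init=3.0) == 3.0 .+ sum(x; dims=2)
@assert sum(x; init=3.0) == 3.0 + sum(x)
```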
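Finally, a hypothetical usage sketch for the new concrete-array `map!` overload; it assumes a Reactant build that includes this patch and an initialized XLA client, so treat it as illustrative rather than tested output:

```julia
using Reactant

x = Reactant.to_rarray(Float32[1, 2, 3])
y = Reactant.to_rarray(zeros(Float32, 3))

# The new method compiles Base.map! for these argument types and runs it;
# on traced arrays inside compiled functions, map/map! now lower to broadcast.
map!(abs2, y, x)
@assert convert(Array, y) == Float32[1, 4, 9]
```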