From 3fdb2061dd9747c748c41ab97a1dd56501ec2769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Wed, 12 Mar 2025 16:54:05 +0000 Subject: [PATCH 1/2] Revert "fix: apply init values after reduction (#881)" This reverts commit 6936dbebfb2969167f6bdfae76a0a9773a1d72aa. --- src/ConcreteRArray.jl | 14 +++--- src/Ops.jl | 60 +++++++---------------- src/TracedRArray.jl | 108 +++++++++++++++++++++++++++++++++--------- test/basic.jl | 12 ----- 4 files changed, 110 insertions(+), 84 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 411542aec7..15ab869c6c 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -393,6 +393,14 @@ buffer_on_cpu(::Any) = true buffer_on_cpu(x::ConcretePJRTArray) = all(XLA.buffer_on_cpu, x.data) buffer_on_cpu(x::ConcreteIFRTArray) = XLA.buffer_on_cpu(x.data) +function Ops.constant(x::AbstractConcreteArray; kwargs...) + return Ops.constant(Base.convert(Array, x); kwargs...) +end + +function Ops.constant(x::AbstractConcreteNumber{T}; kwargs...) where {T} + return Ops.constant(Base.convert(T, x); kwargs...) +end + function Base.zero(x::ConcretePJRTArray{T,N}) where {T,N} return ConcretePJRTArray( zeros(T, size(x)...); client=XLA.client(x), device=XLA.device(x), x.sharding @@ -456,9 +464,3 @@ function Base.mapreducedim!( fn(f, op, R, A) return R end - -function Base.map!(f, R::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, A::AbstractArray) - fn = compile(Base.map!, (f, R, A)) - fn(f, R, A) - return R -end diff --git a/src/Ops.jl b/src/Ops.jl index 8e0a58dc16..39a38f9505 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -122,16 +122,6 @@ end end end -@noinline function constant( - x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) -) where {T,N} - return constant(collect(x); location) -end - -@noinline function constant(x::Reactant.AbstractConcreteArray; kwargs...) - return constant(Base.convert(Array, x); kwargs...) -end - @noinline function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} @@ -140,10 +130,6 @@ end return TracedRNumber{T}((), res.mlir_data) end -@noinline function constant(x::Reactant.AbstractConcreteNumber{T}; kwargs...) where {T} - return constant(Base.convert(T, x); kwargs...) -end - function fill( v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__) ) @@ -391,7 +377,7 @@ end end # shape ops -function reshape(x::TracedRArray, dims::Integer...; kwargs...) +function reshape(x::TracedRArray, dims...; kwargs...) return reshape(x, collect(dims); kwargs...) end @@ -2394,7 +2380,7 @@ end x::TracedRArray{T}, init_values::TracedRNumber{T}, dimensions::Vector{Int}, - fn::Function; + fn::Function, location=mlir_stacktrace("rand", @__FILE__, @__LINE__), ) @@ -2426,43 +2412,25 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`, - **CPU version & Julia's `reduce`**: - Reduce along dimension 1 → `[(15) (21); (18) (24)]` - Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]` - + - **GPU version**: - Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]` - Reduce along dimension 3 → `[37 49]` """ @noinline function reduce( x::TracedRArray{T}, - init_values::Union{TracedRNumber{T},Nothing}, + init_values::TracedRNumber{T}, dimensions::Vector{Int}, - fn::Function; + fn::Function, location=mlir_stacktrace("reduce", @__FILE__, @__LINE__), ) where {T} - elT = T - if init_values === nothing - if fn === min || fn === Base.FastMath.min_fast - init = typemax(elT) - elseif fn === max || fn === Base.FastMath.max_fast - init = typemin(elT) - else - init = Base.reduce_empty(Base.BottomRF(fn), elT) - end - - initT = unwrapped_eltype(typeof(init)) - if initT != elT # Bool, etc. reductions - elT = promote_type(initT, elT) - x = elT.(x) - end - init_values = Reactant.TracedUtils.promote_to(TracedRNumber{elT}, init) - end - reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions)) - result_type = mlir_type(TracedRArray{elT,length(reduced_shape)}, reduced_shape) + result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape) sample_inputs = [ - Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0), - Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0), + Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0), + Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0), ] func = @@ -2476,8 +2444,14 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`, return_dialect=:stablehlo, ).f @assert MLIR.IR.nregions(func) == 1 - ftype = MLIR.IR.Type(MLIR.IR.attr(func, "function_type")) - @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(elT)) "$fn return type is not tensor" + fn_name = String( + MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())) + ) + ftype_attr = MLIR.IR.attr(func, "function_type") + ftype = MLIR.IR.Type(ftype_attr) + @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) error ( + "$fn return type is not tensor" + ) fn = MLIR.IR.Region() MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1)) MLIR.IR.rmfromparent!(func) @@ -2495,7 +2469,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`, ), ) - return TracedRArray{elT,length(reduced_shape)}((), res, reduced_shape) + return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape) end end # module Ops diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 507b44a8a7..662b9d1e9f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -468,29 +468,100 @@ function Base.mapreduce( dims=:, init=nothing, ) where {T,N} - inp = broadcast(f, materialize_traced_array(A)) + A = materialize_traced_array(A) - dims isa Number && (dims = (dims,)) + if dims isa Int + dims = [dims] + end + + op_in_T = Core.Compiler.return_type(f, Tuple{T}) + + if init === nothing + if op === min + init = typemax(op_in_T) + elseif op === max + init = typemin(op_in_T) + else + init = Base.reduce_empty(Base.BottomRF(op), op_in_T) + end - if init !== nothing && typeof(init) != unwrapped_eltype(inp) - inp = typeof(init).(inp) + if typeof(init) != op_in_T + op_in_T = typeof(init) + A = typeof(init).(A) + end end - rdims = dims == (:) ? collect(Int64, 1:N) : collect(Int64, dims) + init = [TracedUtils.broadcast_to_size(init, ()).mlir_data] + + inp = [broadcast(f, A).mlir_data] - reduction_result = Ops.reduce(inp, nothing, rdims, op) + rdims = Int64[] - reduction_result = if dims != (:) - Ops.reshape(reduction_result, Int64[i ∈ rdims ? 1 : size(A, i) for i in 1:N]) + if dims == (:) + for i in 0:(N - 1) + push!(rdims, i) + end else - TracedRNumber{unwrapped_eltype(reduction_result)}((), reduction_result.mlir_data) + for i in dims + push!(rdims, i - 1) + end end - init === nothing && return reduction_result - return broadcast(op, reduction_result, init) + in_tys = [ + MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))), + MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))), + ] + + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()]) + + args = ( + TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)), + TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)), + ) + + resty = MLIR.IR.block!(fnbody) do + tmp = TracedUtils.broadcast_to_size(op(args...), ()) + Ops.return_(tmp) + return eltype(MLIR.IR.type(tmp.mlir_data)) + end + + toonedims = Int[] + outdims = Int[] + for i in 1:N + tmp = if in(i - 1, rdims) + 1 + else + sz = size(A, i) + push!(outdims, sz) + sz + end + push!(toonedims, tmp) + end + + TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)] + + body = MLIR.IR.Region() + push!(body, fnbody) + red = MLIR.Dialects.stablehlo.reduce( + inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body + ) + + red = MLIR.IR.result(red, 1) + redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red))) + + if dims != (:) + red = Ops.reshape(TracedRArray(red), toonedims...) + else + if length(outdims) == 0 + red = TracedRNumber{redT}((), red) + else + red = TracedRArray{redT,length(outdims)}((), red, (outdims...,)) + end + end + return red end -function Base._mapreducedim!( +function Base.mapreducedim!( @nospecialize(f), @nospecialize(op), @nospecialize(R::AnyTracedRArray), @@ -502,11 +573,9 @@ function Base._mapreducedim!( @assert sR == 1 return i end - - isempty(A) && return R - tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims)) - R .= op.(R, tmp) + # set_mlir_data!(R, get_mlir_data(tmp)) + R .= op.(R, tmp) # match native Julia's behavior return R end @@ -1015,11 +1084,4 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin return (values, linear_indices) end -Base.map(f, x::AnyTracedRArray) = f.(x) - -function Base.map!(f, y::AnyTracedRArray, x::AbstractArray) - y .= f.(x) - return y -end - end diff --git a/test/basic.jl b/test/basic.jl index de8a4109c3..d94d16e4ad 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -938,15 +938,3 @@ end rv ) end - -@testset "mapreduce with init" begin - x = reshape(collect(Float32, 1:12), 3, 4) - x_ra = Reactant.to_rarray(x) - - init = 3.0 - init_ra = Reactant.to_rarray(init; track_numbers=Number) - - fn(x, init; kwargs...) = sum(x; init, kwargs...) - - @test @jit(fn(x_ra, init_ra; dims=2)) ≈ fn(x, init; dims=2) -end From a077d8165471c35eb0802fdfcbbc9399e873305e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Wed, 12 Mar 2025 17:15:18 +0000 Subject: [PATCH 2/2] Bump version to v0.2.40 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8895c3a132..426438ff00 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "] -version = "0.2.39" +version = "0.2.40" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"