From ee1f4565c12345f51c2c9c23281d2435ea134827 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 Jan 2025 12:07:37 -0500 Subject: [PATCH 1/2] fix: generalize broadcast_in_dims for setindex --- src/Ops.jl | 36 ++++++++++++++++++------------------ src/TracedRArray.jl | 17 +++++++++++++++-- src/TracedUtils.jl | 25 +------------------------ 3 files changed, 34 insertions(+), 44 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index f673007876..0d94cb75d4 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -936,24 +936,24 @@ end end # broadcast ops -# function broadcast_in_dim( -# x::TracedRArray{T,N}, -# dims::Vector{Int}; -# location=mlir_stacktrace( -# "broadcast_in_dim", @__FILE__, @__LINE__ -# ), -# ) where {T,N} -# rsize = restype = MLIR.IR.TensorType([...], mlir_type(T)) # mlir_type(TracedRArray{T,N}, size(x)) -# res = MLIR.IR.result( -# stablehlo.broadcast_in_dim( -# x.mlir_data; -# result_0=restype, -# broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), -# location, -# ), -# ) -# return TracedRArray{T,N}((), res, size(x)) -# end +function broadcast_in_dim( + x::TracedRArray{T,N}, + dims::Vector{Int}, + result_size::Vector{Int}; + location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__), +) where {T,N} + @assert length(dims) == N + + res = MLIR.IR.result( + stablehlo.broadcast_in_dim( + x.mlir_data; + result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1), + location, + ), + ) + return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size)) +end @noinline function sort( x::TracedRArray{T,N}; diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a7d72b3caa..8b7835bcd5 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -218,8 +218,21 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { return v end - v = TracedUtils.broadcast_to_size(v, length.(indices)) - v = TracedUtils.promote_to(TracedRArray{T,N}, v) + if v isa Number + v = TracedUtils.broadcast_to_size(v, length.(indices)) + v = TracedUtils.promote_to(TracedRArray{T,N}, v) + else + v = TracedUtils.promote_to(TracedRArray{T,ndims(v)}, v) + non_integer_indices = [!(idx isa Integer) for idx in indices] + broadcast_dims = findall(non_integer_indices) + if length(broadcast_dims) == N + v = TracedUtils.broadcast_to_size(v, length.(indices)) + else + v = Ops.broadcast_in_dim( + materialize_traced_array(v), broadcast_dims, Int64.(length.(indices)) + ) + end + end indices = [ ( diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index ee90875573..ab7a556432 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -496,30 +496,7 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize) end @noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T} - dims = collect(Int64, 0:(length(size(x)) - 1)) - - if length(size(MLIR.IR.type(get_mlir_data(x)))) != length(dims) - @show x - @show arg - @show rsize - @show rsize2 - @show dims - end - @assert length(size(MLIR.IR.type(get_mlir_data(x)))) == length(dims) - mlirty = MLIR.IR.type(get_mlir_data(x)) - - return TracedRArray{T,Int(length(rsize))}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.broadcast_in_dim( - get_mlir_data(x); - result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), - ), - 1, - ), - collect(rsize), - ) + return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize)) end end From ac0125cbaa655c5b0c46e8bf4372a2228204bd40 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 Jan 2025 12:22:53 -0500 Subject: [PATCH 2/2] test: writing with less dims --- test/basic.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 531fec16e0..83336f4325 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -422,6 +422,39 @@ end # get_view_compiled = @compile get_view(x_concrete) end +function write_with_broadcast1!(x, y) + x[1, :, :] .= reshape(y, 4, 3) + return x +end +function write_with_broadcast2!(x, y) + x[:, 1, :] .= view(y, :, 1:3) + return x +end + +@testset "write_with_broadcast" begin + x_ra = Reactant.to_rarray(zeros(3, 4, 3)) + y_ra = Reactant.to_rarray(rand(3, 4)) + + res = @jit write_with_broadcast1!(x_ra, y_ra) + + @test res.data === x_ra.data + + res = Array(res) + y = Array(y_ra) + @test res[1, :, :] ≈ reshape(y, 4, 3) + + x_ra = Reactant.to_rarray(zeros(3, 4, 3)) + y_ra = Reactant.to_rarray(rand(3, 4)) + + res = @jit write_with_broadcast2!(x_ra, y_ra) + + @test res.data === x_ra.data + + res = Array(res) + y = Array(y_ra) + @test res[:, 1, :] ≈ view(y, :, 1:3) +end + function masking(x) y = similar(x) y[1:2, :] .= 0