From 10b568398b81a6cf16c104e9f8921483ea50fa50 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 23 Jan 2025 17:55:12 -0500 Subject: [PATCH 1/3] feat: support dynamic indexing for reshaped arrays --- src/TracedRArray.jl | 11 +++++++++-- src/TracedUtils.jl | 30 +++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 282a5a6e78..7896325a5c 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -26,6 +26,8 @@ ReactantCore.is_traced(::TracedRArray) = true Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...) +Base.IndexStyle(::Type{<:TracedRArray}) = Base.IndexLinear() + function Base.convert(::Type{TracedRArray}, x::AnyTracedRArray) return Base.convert(TracedRArray{unwrapped_eltype(x),ndims(x)}, x) end @@ -125,9 +127,14 @@ function Base.getindex(a::TracedRArray{T,N}, indices) where {T,N} if !(indices isa TracedRArray) indices = collect(indices) eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices]) - indices = TracedUtils.promote_to(TracedRArray{Int,1}, indices) + indices = TracedUtils.promote_to(TracedRArray{Int,ndims(indices)}, indices) end - return Ops.gather_getindex(a, scalar_index_to_cartesian(indices, size(a))) + return materialize_traced_array( + reshape( + Ops.gather_getindex(a, scalar_index_to_cartesian(vec(indices), size(a))), + size(indices), + ), + ) end Base.getindex(a::TracedRArray{T,N}, ::Colon) where {T,N} = materialize_traced_array(vec(a)) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 8802ed0833..4aeecba695 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -16,7 +16,7 @@ using ..Reactant: OrderedIdDict, ReactantPrimitive, Ops -using ReactantCore: MissingTracedValue +using ReactantCore: MissingTracedValue, is_traced materialize_traced_array(x::TracedRArray) = x @@ -63,10 +63,30 @@ end function get_ancestor_indices( x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, indices... ) where {T,N,M} - cartesian_indices = CartesianIndex.(indices...) - linear_indices = LinearIndices(size(x))[cartesian_indices] - parent_cartesian_indices = CartesianIndices(size(parent(x)))[linear_indices] - return (parent_cartesian_indices,) + @assert length(indices) == N "Expected $N indices, got $(length(indices))" + if any(is_traced, indices) + # XXX: scalars are not supported + final_size = Vector{Int64}(undef, N) + for (i, idx) in enumerate(indices) + @assert ndims(idx) == 1 "Unsupported feature. Please file an issue." + final_size[i] = length(idx) + end + @show Base.strides(x) + linear_indices = mapreduce(+, enumerate(indices)) do (i, idx) + Base.stride(x, i) .* (Ops.broadcast_in_dim(idx, Int64[i], final_size) .- 1) .+ 1 + end + parent_linear_indices_all = collect(LinearIndices(size(parent(x)))) + parent_linear_indices = TracedUtils.promote_to( + TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all + )[linear_indices] + return (parent_linear_indices,) + else + # Have this as a separate code-path since we can generate non-dynamic indexing + cartesian_indices = CartesianIndex.(Iterators.product(indices...)) + linear_indices = LinearIndices(size(x))[cartesian_indices] + parent_linear_indices = LinearIndices(size(parent(x)))[linear_indices] + return (parent_linear_indices,) + end end function set_mlir_data!( From ca8780b011b6d06c627feb2cc1fb7379ea43aa88 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 23 Jan 2025 19:42:26 -0500 Subject: [PATCH 2/3] feat: support scalars --- src/Ops.jl | 19 +++++++++++++++++++ src/TracedUtils.jl | 16 ++++++++++++---- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index b953e1c596..31d79ff5f0 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -956,6 +956,25 @@ function broadcast_in_dim( return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size)) end +function broadcast_in_dim( + x::TracedRNumber{T}, + dims::Vector{Int}, + result_size::Vector{Int}; + location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__), +) where {T} + @assert length(dims) == 0 + + res = MLIR.IR.result( + stablehlo.broadcast_in_dim( + x.mlir_data; + result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1), + location, + ), + ) + return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size)) +end + @noinline function sort( xs::TracedRArray...; comparator, diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 4aeecba695..b3e8be00c6 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -65,20 +65,28 @@ function get_ancestor_indices( ) where {T,N,M} @assert length(indices) == N "Expected $N indices, got $(length(indices))" if any(is_traced, indices) - # XXX: scalars are not supported final_size = Vector{Int64}(undef, N) + ddims = Int64[] for (i, idx) in enumerate(indices) - @assert ndims(idx) == 1 "Unsupported feature. Please file an issue." + @assert ndims(idx) == 1 || ndims(idx) == 0 "Unsupported feature. Please file an issue." + ndims(idx) == 0 && push!(ddims, i) final_size[i] = length(idx) end - @show Base.strides(x) linear_indices = mapreduce(+, enumerate(indices)) do (i, idx) - Base.stride(x, i) .* (Ops.broadcast_in_dim(idx, Int64[i], final_size) .- 1) .+ 1 + bcasted_idxs = Ops.broadcast_in_dim( + idx, ndims(idx) == 0 ? Int64[] : Int64[i], final_size + ) + Base.stride(x, i) .* (bcasted_idxs .- 1) .+ 1 end parent_linear_indices_all = collect(LinearIndices(size(parent(x)))) parent_linear_indices = TracedUtils.promote_to( TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all )[linear_indices] + isempty(ddims) || ( + parent_linear_indices = materialize_traced_array( + dropdims(parent_linear_indices; dims=Tuple(ddims)) + ) + ) return (parent_linear_indices,) else # Have this as a separate code-path since we can generate non-dynamic indexing From c9f3b163e5c7f11a41ced42c834a6b48e948a6fe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 23 Jan 2025 19:56:42 -0500 Subject: [PATCH 3/3] test: reshaped arrays broadcasting --- src/TracedUtils.jl | 3 ++- test/wrapped_arrays.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index b3e8be00c6..e42e2d2875 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -76,8 +76,9 @@ function get_ancestor_indices( bcasted_idxs = Ops.broadcast_in_dim( idx, ndims(idx) == 0 ? Int64[] : Int64[i], final_size ) - Base.stride(x, i) .* (bcasted_idxs .- 1) .+ 1 + Base.stride(x, i) .* (bcasted_idxs .- 1) end + linear_indices = linear_indices .+ 1 parent_linear_indices_all = collect(LinearIndices(size(parent(x)))) parent_linear_indices = TracedUtils.promote_to( TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index c522bcd171..5069e665d0 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -202,3 +202,29 @@ end @test @jit(fn(x_ra)) ≈ fn(x) end end + +function broadcast_reshaped_array(x, idx1, idx2) + y = reshape(x, 20, 2) + return y[idx1, idx2] .+ 1 +end + +function broadcast_reshaped_array(x, idx1, idx2::Number) + y = reshape(x, 20, 2) + return y[idx1, idx2] .+ 1 +end + +@testset "Broadcast reshaped array" begin + x_ra = Reactant.to_rarray(rand(5, 4, 2)) + idx1_ra = Reactant.to_rarray(rand(1:20, 4)) + idx2_ra = Reactant.to_rarray([2, 1]) + + @test broadcast_reshaped_array(Array(x_ra), Array(idx1_ra), Array(idx2_ra)) ≈ + @jit(broadcast_reshaped_array(x_ra, idx1_ra, idx2_ra)) ≈ + @jit(broadcast_reshaped_array(x_ra, Array(idx1_ra), Array(idx2_ra))) + + idx3 = ConcreteRNumber(2) + + @test broadcast_reshaped_array(Array(x_ra), Array(idx1_ra), Int64(idx3)) ≈ + @jit(broadcast_reshaped_array(x_ra, idx1_ra, idx3)) ≈ + @jit(broadcast_reshaped_array(x_ra, Array(idx1_ra), Int64(idx3))) +end