diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index e90c000763..ee00463e2e 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -323,7 +323,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr This case is not optimized and will be slow." maxlog = 1 dims = NNlib.scatter_dims(src, dst, idxs) colons = ntuple(Returns(Colon()), dims) - start_sizes = ntuple(i -> size(src, i), dims) + start_sizes = ntuple(Base.Fix1(size, src), dims) results = map(CartesianIndices(idxs)) do k res = @allowscalar src[colons..., Tuple(idxs[k])...] res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,))) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 275f6dd923..4f9ccdaa13 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -58,6 +58,58 @@ end Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data) +function generate_index_list(i1, is...) + list = reshape(i1, :, 1) .- 1 + for i in is + i = reshape(i, :, 1) + lorig = size(list, 1) + list = repeat(list, size(i, 1), 1) + i = repeat(i; inner=(lorig, 1)) .- 1 + list = hcat(list, i) + end + return list +end + +function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) where {T,N} + idx = idx .- 1 + idxs = materialize_traced_array(reshape(idx .% T(sz[1]), :, 1)) + idx = idx .÷ T(sz[1]) + for i in 2:N + idxs = hcat(idxs, idx .% T(sz[i])) + idx = idx .÷ T(sz[i]) + end + return idxs +end + +function Base.getindex( + a::TracedRArray{T,N}, indices::Union{Int,TracedRNumber{Int}} +) where {T,N} + if indices isa Int + indices = TracedUtils.promote_to(TracedRNumber{Int}, indices) + end + indices = TracedUtils.broadcast_to_size(indices, (1,)) + return Ops.gather_getindex(a, scalar_index_to_cartesian(indices, size(a)))[1] +end + +function Base.getindex(a::TracedRArray{T,N}, indices) where {T,N} + if !(indices isa TracedRArray) + indices = TracedUtils.promote_to(TracedRArray{Int,1}, collect(indices)) + end + return Ops.gather_getindex(a, scalar_index_to_cartesian(indices, size(a))) +end + +Base.getindex(a::TracedRArray{T,N}, ::Colon) where {T,N} = materialize_traced_array(vec(a)) + +function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {T,N} + indices = + materialize_traced_array( + reshape( + TracedUtils.promote_to(TracedRArray{Int,1}, vcat(Tuple(indices)...)), 1, N + ), + ) .- 1 + return Ops.gather_getindex(a, indices)[1] +end + function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} indices = map(enumerate(indices)) do (idx, i) i isa Colon && return 1:size(a, idx) @@ -65,28 +117,25 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} return i end - non_contiguous_getindex = false + use_gather_getindex = false for idxs in indices idxs isa Number && continue + if idxs isa Reactant.TracedType + use_gather_getindex = true + break + end contiguous = all(isone, diff(idxs)) - # XXX: We want to throw error even for dynamic indexing if typeof(contiguous) <: Bool && !contiguous - non_contiguous_getindex = true + use_gather_getindex = true break end end - if non_contiguous_getindex - indices_tuples = collect(Iterators.product(indices...)) - indices = Matrix{Int}( - undef, (length(indices_tuples), length(first(indices_tuples))) - ) - for (i, idx) in enumerate(indices_tuples) - indices[i, :] .= idx .- 1 - end - indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices) - res = Ops.gather_getindex(a, indices) - return Ops.reshape(res, size(indices_tuples)...) + if use_gather_getindex + indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices) + indices_list = generate_index_list(indices_list...) + res = Ops.gather_getindex(a, indices_list) + return Ops.reshape(res, length.(indices)...) end start_indices = map(indices) do i @@ -99,7 +148,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} x = TracedRArray{T,N}((), res, Tuple(length.(indices))) ddims = findall(Base.Fix2(isa, Integer), indices) - isempty(ddims) || return dropdims(x; dims=Tuple(ddims)) + isempty(ddims) || return materialize_traced_array(dropdims(x; dims=Tuple(ddims))) return x end @@ -119,27 +168,24 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { return i end - non_contiguous_setindex = false + use_scatter_setindex = false for idxs in indices idxs isa Number && continue + if idxs isa Reactant.TracedType + use_scatter_setindex = true + break + end contiguous = all(isone, diff(idxs)) - # XXX: We want to throw error even for dynamic indexing if typeof(contiguous) <: Bool && !contiguous - non_contiguous_setindex = true + use_scatter_setindex = true break end end - if non_contiguous_setindex - indices_tuples = collect(Iterators.product(indices...)) - indices = Matrix{Int}( - undef, (length(indices_tuples), length(first(indices_tuples))) - ) - for (i, idx) in enumerate(indices_tuples) - indices[i, :] .= idx .- 1 - end - indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices) - res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v))) + if use_scatter_setindex + indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices) + indices_list = generate_index_list(indices_list...) + res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v))) a.mlir_data = res.mlir_data return v end @@ -512,15 +558,16 @@ Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x) Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x) # outer repeat -# Overridden because we don't need to further recur into the definitions here -function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,M} +function Base._RepeatInnerOuter.repeat_outer( + x::AnyTracedRArray{T,N}, counts::NTuple{M,Int} +) where {T,N,M} P = max(N, M) # potentially padded # (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1) interleaved_size = ones(Int, 2P) interleaved_size[1:2:(2N)] .= size(x) - x_interleaved = reshape(x, interleaved_size...) + x_interleaved = reshape(materialize_traced_array(x), interleaved_size...) # (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP) broadcast_target_size = interleaved_size @@ -531,9 +578,31 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N, # (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP) final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1)) - x_final = reshape(x_broadcasted, final_size...) + return materialize_traced_array(reshape(x_broadcasted, final_size...)) +end + +# inner repeat +function Base._RepeatInnerOuter.repeat_inner( + x::AnyTracedRArray{T,N}, counts::NTuple{M,Int} +) where {T,N,M} + P = max(N, M) # potentially padded + + # (d1, d2, ..., dP) -> (1, d1, 1, d2, 1, ..., 1, dP) + interleaved_size = ones(Int, 2P) + interleaved_size[2:2:(2N)] .= size(x) + + x_interleaved = reshape(materialize_traced_array(x), interleaved_size...) + + # (1, d1, 1, d2, 1, ..., 1, dP) -> (r1, d1, r2, d2, ..., rP, dP) + broadcast_target_size = interleaved_size + broadcast_target_size[1:2:(2N)] .= counts + + x_broadcasted = TracedUtils.broadcast_to_size(x_interleaved, broadcast_target_size) + + # (r1, d1, r2, d2, ..., rP, dP) -> (d1*r1, d2*r2, ..., dP*rP) + final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1)) - return x_final + return materialize_traced_array(reshape(x_broadcasted, final_size...)) end end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 8df88f8113..1e5cfde557 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -84,6 +84,8 @@ for (jlop, hloop) in ( (:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^), :power), + (:(Base.mod), :remainder), + (:(Base.rem), :remainder), ) @eval function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) @@ -92,6 +94,10 @@ for (jlop, hloop) in ( end end +function Base.div(@nospecialize(lhs::TracedRNumber{T}), rhs) where {T<:Integer} + return Ops.divide(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) +end + function Base.div( @nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown) ) where {T<:Integer} diff --git a/test/basic.jl b/test/basic.jl index 3522cd59e2..620118d376 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -365,12 +365,23 @@ end end @testset "repeat" begin + fn_inner(x, counts) = repeat(x; inner=counts) + @testset for (size, counts) in Iterators.product( [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)], [(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)], ) x = rand(size...) - @test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...) + + @testset "outer repeat" begin + @test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...) + end + + length(counts) < length(size) && continue + + @testset "inner repeat" begin + @test (@jit fn_inner(Reactant.to_rarray(x), counts)) == fn_inner(x, counts) + end end end @@ -751,11 +762,11 @@ end x = rand(4, 2) x_ra = Reactant.to_rarray(x) - non_contiguous_indexing1(x) = x[[1, 3, 2], :] - non_contiguous_indexing2(x) = x[:, [1, 2, 2]] + non_contiguous_indexing3(x) = x[[1, 3, 2], :] + non_contiguous_indexing4(x) = x[:, [1, 2, 2]] - @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x) - @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x) + @test @jit(non_contiguous_indexing3(x_ra)) ≈ non_contiguous_indexing3(x) + @test @jit(non_contiguous_indexing4(x_ra)) ≈ non_contiguous_indexing4(x) x = rand(4, 4, 3) x_ra = Reactant.to_rarray(x) @@ -777,17 +788,59 @@ end x = rand(4, 2) x_ra = Reactant.to_rarray(x) - non_contiguous_indexing1!(x) = x[[1, 3, 2], :] .= 2 - non_contiguous_indexing2!(x) = x[:, [1, 2, 2]] .= 2 + non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2 + non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2 - @jit(non_contiguous_indexing1!(x_ra)) - non_contiguous_indexing1!(x) + @jit(non_contiguous_indexing3!(x_ra)) + non_contiguous_indexing3!(x) @test x_ra ≈ x x = rand(4, 2) x_ra = Reactant.to_rarray(x) - @jit(non_contiguous_indexing2!(x_ra)) - non_contiguous_indexing2!(x) + @jit(non_contiguous_indexing4!(x_ra)) + non_contiguous_indexing4!(x) @test x_ra ≈ x end + +@testset "indexing with traced arrays" begin + x = rand(4, 4, 3) + idx1 = [1, 3, 2] + idx3 = [1, 2, 1, 3] + + x_ra = Reactant.to_rarray(x) + idx1_ra = Reactant.to_rarray(idx1) + idx3_ra = Reactant.to_rarray(idx3) + + getindex1(x, idx1) = x[idx1, :, :] + getindex2(x, idx1) = x[:, idx1, :] + getindex3(x, idx3) = x[:, :, idx3] + getindex4(x, idx1, idx3) = x[idx1, :, idx3] + + @test @jit(getindex1(x_ra, idx1_ra)) ≈ getindex1(x, idx1) + @test @jit(getindex2(x_ra, idx1_ra)) ≈ getindex2(x, idx1) + @test @jit(getindex3(x_ra, idx3_ra)) ≈ getindex3(x, idx3) + @test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) ≈ getindex4(x, idx1, idx3) +end + +@testset "linear indexing" begin + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + getindex_linear_scalar(x, idx) = @allowscalar x[idx] + + @testset for i in 1:length(x) + @test @jit(getindex_linear_scalar(x_ra, i)) ≈ getindex_linear_scalar(x, i) + @test @jit( + getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=(Number,))) + ) ≈ getindex_linear_scalar(x, i) + end + + idx = rand(1:length(x), 8) + idx_ra = Reactant.to_rarray(idx) + + getindex_linear_vector(x, idx) = x[idx] + + @test @jit(getindex_linear_vector(x_ra, idx_ra)) ≈ getindex_linear_vector(x, idx) + @test @jit(getindex_linear_vector(x_ra, idx)) ≈ getindex_linear_vector(x, idx) +end