Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
This case is not optimized and will be slow." maxlog = 1
dims = NNlib.scatter_dims(src, dst, idxs)
colons = ntuple(Returns(Colon()), dims)
start_sizes = ntuple(i -> size(src, i), dims)
start_sizes = ntuple(Base.Fix1(size, src), dims)
results = map(CartesianIndices(idxs)) do k
res = @allowscalar src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,)))
Expand Down
135 changes: 102 additions & 33 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,84 @@ end

# Index-free access on a zero-dimensional traced array: wrap the array's
# `mlir_data` in a `TracedRNumber` with the same element type.
function Base.getindex(a::TracedRArray{T,0}) where {T}
    return TracedRNumber{T}((), a.mlir_data)
end

# Build the zero-based index table for the Cartesian product of the given
# 1-based index collections.
#
# The result has one row per combination and one column per input collection.
# Rows are ordered with the first collection varying fastest (column-major order).
function generate_index_list(i1, is...)
    table = reshape(i1, :, 1) .- 1
    for next_indices in is
        column = reshape(next_indices, :, 1)
        rows_so_far = size(table, 1)
        # Tile the existing combinations once per entry of the new collection ...
        tiled = repeat(table, size(column, 1), 1)
        # ... and hold each new entry constant across one full tile.
        stretched = repeat(column; inner=(rows_so_far, 1)) .- 1
        table = hcat(tiled, stretched)
    end
    return table
end

# Convert 1-based linear indices into zero-based per-dimension indices for an
# array of size `sz`.
#
# Returns a matrix with one row per entry of `idx` and one column per dimension.
# Each column is obtained by repeatedly taking the remainder and the integer
# quotient with the successive dimension lengths (column-major decomposition).
function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) where {T,N}
    zero_based = idx .- 1
    first_len = T(sz[1])
    coords = materialize_traced_array(reshape(zero_based .% first_len, :, 1))
    remaining = zero_based .÷ first_len
    for d in 2:N
        dim_len = T(sz[d])
        coords = hcat(coords, remaining .% dim_len)
        remaining = remaining .÷ dim_len
    end
    return coords
end

# Linear indexing with a single index, either a plain `Int` or a traced integer.
#
# The index is promoted to a traced number if needed, expanded to a length-1
# vector, decomposed into per-dimension indices and looked up through a gather;
# the single gathered element is returned.
function Base.getindex(
    a::TracedRArray{T,N}, indices::Union{Int,TracedRNumber{Int}}
) where {T,N}
    traced_index = if indices isa Int
        TracedUtils.promote_to(TracedRNumber{Int}, indices)
    else
        indices
    end
    index_vector = TracedUtils.broadcast_to_size(traced_index, (1,))
    cartesian = scalar_index_to_cartesian(index_vector, size(a))
    gathered = Ops.gather_getindex(a, cartesian)
    return gathered[1]
end

# Linear indexing with a collection of indices.
#
# Anything that is not already a `TracedRArray` is collected and promoted to a
# traced integer vector; the linear indices are then decomposed into
# per-dimension indices and looked up through a gather.
function Base.getindex(a::TracedRArray{T,N}, indices) where {T,N}
    linear = if indices isa TracedRArray
        indices
    else
        TracedUtils.promote_to(TracedRArray{Int,1}, collect(indices))
    end
    cartesian = scalar_index_to_cartesian(linear, size(a))
    return Ops.gather_getindex(a, cartesian)
end

# `a[:]`: flatten the array with `vec` and materialize the result.
function Base.getindex(a::TracedRArray{T,N}, ::Colon) where {T,N}
    flattened = vec(a)
    return materialize_traced_array(flattened)
end

# Indexing with a `CartesianIndex` that has one component per array dimension.
#
# The components are promoted to a traced integer vector, laid out as a single
# 1×N row, shifted to zero-based form and looked up through a gather; the single
# gathered element is returned.
function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {T,N}
    components = vcat(Tuple(indices)...)
    traced_components = TracedUtils.promote_to(TracedRArray{Int,1}, components)
    one_based_row = materialize_traced_array(reshape(traced_components, 1, N))
    zero_based_row = one_based_row .- 1
    return Ops.gather_getindex(a, zero_based_row)[1]
end

function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
return i
end

non_contiguous_getindex = false
use_gather_getindex = false
for idxs in indices
idxs isa Number && continue
if idxs isa Reactant.TracedType
use_gather_getindex = true
break
end
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(contiguous) <: Bool && !contiguous
non_contiguous_getindex = true
use_gather_getindex = true
break
end
end

if non_contiguous_getindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(
undef, (length(indices_tuples), length(first(indices_tuples)))
)
for (i, idx) in enumerate(indices_tuples)
indices[i, :] .= idx .- 1
end
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
res = Ops.gather_getindex(a, indices)
return Ops.reshape(res, size(indices_tuples)...)
if use_gather_getindex
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
indices_list = generate_index_list(indices_list...)
res = Ops.gather_getindex(a, indices_list)
return Ops.reshape(res, length.(indices)...)
end

start_indices = map(indices) do i
Expand All @@ -99,7 +148,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}

x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
ddims = findall(Base.Fix2(isa, Integer), indices)
isempty(ddims) || return dropdims(x; dims=Tuple(ddims))
isempty(ddims) || return materialize_traced_array(dropdims(x; dims=Tuple(ddims)))
return x
end

Expand All @@ -119,27 +168,24 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
return i
end

non_contiguous_setindex = false
use_scatter_setindex = false
for idxs in indices
idxs isa Number && continue
if idxs isa Reactant.TracedType
use_scatter_setindex = true
break
end
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(contiguous) <: Bool && !contiguous
non_contiguous_setindex = true
use_scatter_setindex = true
break
end
end

if non_contiguous_setindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(
undef, (length(indices_tuples), length(first(indices_tuples)))
)
for (i, idx) in enumerate(indices_tuples)
indices[i, :] .= idx .- 1
end
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
if use_scatter_setindex
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
indices_list = generate_index_list(indices_list...)
res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v)))
a.mlir_data = res.mlir_data
return v
end
Expand Down Expand Up @@ -512,15 +558,16 @@ Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x)
# `any` with a predicate: apply `f` to every element and combine the results with `|`.
function Base.any(f::Function, x::AnyTracedRArray)
    return mapreduce(f, |, x)
end

# outer repeat
# Overridden because we don't need to further recur into the definitions here
function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,M}
function Base._RepeatInnerOuter.repeat_outer(
x::AnyTracedRArray{T,N}, counts::NTuple{M,Int}
) where {T,N,M}
P = max(N, M) # potentially padded

# (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1)
interleaved_size = ones(Int, 2P)
interleaved_size[1:2:(2N)] .= size(x)

x_interleaved = reshape(x, interleaved_size...)
x_interleaved = reshape(materialize_traced_array(x), interleaved_size...)

# (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP)
broadcast_target_size = interleaved_size
Expand All @@ -531,9 +578,31 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,
# (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP)
final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1))

x_final = reshape(x_broadcasted, final_size...)
return materialize_traced_array(reshape(x_broadcasted, final_size...))
end

# inner repeat
# Inner repetition (`repeat(x; inner=counts)`) for traced arrays.
#
# Each element of `x` is repeated `counts[d]` times along dimension `d`. This is
# done without loops: every dimension of `x` is preceded by a singleton
# dimension, the singleton is broadcast to the requested count, and each
# (count, dim) pair is then merged back into one dimension.
#
# `counts` may have a different length `M` than the array rank `N`; missing
# entries on either side are treated as 1.
function Base._RepeatInnerOuter.repeat_inner(
    x::AnyTracedRArray{T,N}, counts::NTuple{M,Int}
) where {T,N,M}
    P = max(N, M) # potentially padded

    # (d1, d2, ..., dP) -> (1, d1, 1, d2, 1, ..., 1, dP)
    interleaved_size = ones(Int, 2P)
    interleaved_size[2:2:(2N)] .= size(x)

    x_interleaved = reshape(materialize_traced_array(x), interleaved_size...)

    # (1, d1, 1, d2, 1, ..., 1, dP) -> (r1, d1, r2, d2, ..., rP, dP)
    # `counts` holds M entries, so exactly M odd slots are filled. Working on a
    # copy keeps `interleaved_size` describing the un-broadcast shape.
    broadcast_target_size = copy(interleaved_size)
    broadcast_target_size[1:2:(2M)] .= counts

    x_broadcasted = TracedUtils.broadcast_to_size(x_interleaved, broadcast_target_size)

    # (r1, d1, r2, d2, ..., rP, dP) -> (d1*r1, d2*r2, ..., dP*rP)
    final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1))

    return materialize_traced_array(reshape(x_broadcasted, final_size...))
end

end
6 changes: 6 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ for (jlop, hloop) in (
(:(Base.:*), :multiply),
(:(Base.:/), :divide),
(:(Base.:^), :power),
(:(Base.mod), :remainder),
(:(Base.rem), :remainder),
)
@eval function $(jlop)(
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
Expand All @@ -92,6 +94,10 @@ for (jlop, hloop) in (
end
end

# Two-argument `div` for traced integers: promote `rhs` to a traced number of the
# same element type as `lhs`, then emit a divide op.
function Base.div(@nospecialize(lhs::TracedRNumber{T}), rhs) where {T<:Integer}
    divisor = TracedUtils.promote_to(TracedRNumber{T}, rhs)
    return Ops.divide(lhs, divisor)
end

function Base.div(
@nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown)
) where {T<:Integer}
Expand Down
75 changes: 64 additions & 11 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,23 @@ end
end

# `repeat` on traced arrays must agree with plain Julia arrays, both for the
# positional (outer) form and for the `inner` keyword form, across every
# combination of array shape and repetition counts listed below.
@testset "repeat" begin
    fn_inner(x, counts) = repeat(x; inner=counts)

    @testset for (size, counts) in Iterators.product(
        [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)],
        [(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)],
    )
        x = rand(size...)

        @testset "outer repeat" begin
            @test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...)
        end

        # The inner form is only exercised when `counts` has at least as many
        # entries as the array has dimensions.
        length(counts) < length(size) && continue

        @testset "inner repeat" begin
            @test (@jit fn_inner(Reactant.to_rarray(x), counts)) == fn_inner(x, counts)
        end
    end
end

Expand Down Expand Up @@ -751,11 +762,11 @@ end
x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1(x) = x[[1, 3, 2], :]
non_contiguous_indexing2(x) = x[:, [1, 2, 2]]
non_contiguous_indexing3(x) = x[[1, 3, 2], :]
non_contiguous_indexing4(x) = x[:, [1, 2, 2]]

@test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x)
@test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x)
@test @jit(non_contiguous_indexing3(x_ra)) ≈ non_contiguous_indexing3(x)
@test @jit(non_contiguous_indexing4(x_ra)) ≈ non_contiguous_indexing4(x)

x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)
Expand All @@ -777,17 +788,59 @@ end
x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1!(x) = x[[1, 3, 2], :] .= 2
non_contiguous_indexing2!(x) = x[:, [1, 2, 2]] .= 2
non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2
non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2

@jit(non_contiguous_indexing1!(x_ra))
non_contiguous_indexing1!(x)
@jit(non_contiguous_indexing3!(x_ra))
non_contiguous_indexing3!(x)
@test x_ra ≈ x

x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

@jit(non_contiguous_indexing2!(x_ra))
non_contiguous_indexing2!(x)
@jit(non_contiguous_indexing4!(x_ra))
non_contiguous_indexing4!(x)
@test x_ra ≈ x
end

# Indexing a traced array with traced index vectors must match plain Julia
# indexing. The index vector is placed in each of the three dimensions in turn,
# and finally in the first and third dimensions at once.
@testset "indexing with traced arrays" begin
x = rand(4, 4, 3)
# idx1 is non-contiguous; idx3 contains a repeated entry.
idx1 = [1, 3, 2]
idx3 = [1, 2, 1, 3]

# Traced counterparts of the data and of both index vectors.
x_ra = Reactant.to_rarray(x)
idx1_ra = Reactant.to_rarray(idx1)
idx3_ra = Reactant.to_rarray(idx3)

getindex1(x, idx1) = x[idx1, :, :]
getindex2(x, idx1) = x[:, idx1, :]
getindex3(x, idx3) = x[:, :, idx3]
getindex4(x, idx1, idx3) = x[idx1, :, idx3]

# Each compiled result is compared against the same function on plain arrays.
@test @jit(getindex1(x_ra, idx1_ra)) ≈ getindex1(x, idx1)
@test @jit(getindex2(x_ra, idx1_ra)) ≈ getindex2(x, idx1)
@test @jit(getindex3(x_ra, idx3_ra)) ≈ getindex3(x, idx3)
@test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) ≈ getindex4(x, idx1, idx3)
end

# Linear (single-index) `getindex` on a traced 3-D array must match plain Julia:
# first for every scalar position, then for a vector of linear indices.
@testset "linear indexing" begin
x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)

getindex_linear_scalar(x, idx) = @allowscalar x[idx]

# Every linear position is checked twice: with a plain integer index and with
# the index passed through `to_rarray` (presumably yielding a traced number via
# `track_numbers` — confirm against `Reactant.to_rarray`).
@testset for i in 1:length(x)
@test @jit(getindex_linear_scalar(x_ra, i)) ≈ getindex_linear_scalar(x, i)
@test @jit(
getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=(Number,)))
) ≈ getindex_linear_scalar(x, i)
end

# NOTE(review): the indices are drawn from the global RNG without a seed, so the
# exact positions exercised differ from run to run.
idx = rand(1:length(x), 8)
idx_ra = Reactant.to_rarray(idx)

getindex_linear_vector(x, idx) = x[idx]

# The vector of linear indices is passed both traced and as a plain Julia vector.
@test @jit(getindex_linear_vector(x_ra, idx_ra)) ≈ getindex_linear_vector(x, idx)
@test @jit(getindex_linear_vector(x_ra, idx)) ≈ getindex_linear_vector(x, idx)
end
Loading