Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions ext/ReactantNNlibExt/Implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,16 +464,17 @@ function _stack_indices(idxs::AbstractArray{<:CartesianIndex})
end

function _nnlib_gather_impl(src::AnyTracedRArray, idxs::AbstractArray, n_dims::Int)
idxs = TracedUtils.promote_to(TracedRArray{Int,ndims(idxs)}, idxs)
n_idxs = size(idxs, 1)
idxs = TracedUtils.promote_to(
TracedRArray{Reactant.unwrapped_eltype(idxs),ndims(idxs)}, idxs
)
return @opcall gather(
src,
idxs;
offset_dims=collect(Int64, 1:n_dims),
collapsed_slice_dims=collect(Int64, (n_dims + 1):ndims(src)),
operand_batching_dims=Int64[],
start_indices_batching_dims=Int64[],
start_index_map=collect(Int64, (ndims(src) - n_idxs + 1):ndims(src)),
start_index_map=collect(Int64, (ndims(src) - size(idxs, 1) + 1):ndims(src)),
index_vector_dim=1,
slice_sizes=Int64[size(src)[1:n_dims]..., ones(Int64, ndims(src) - n_dims)...],
)
Expand Down Expand Up @@ -553,7 +554,9 @@ function _nnlib_scatter_impl(
idx::AbstractArray,
n_dims::Int,
) where {OP,T}
scatter_indices = TracedUtils.promote_to(TracedRArray{Int,ndims(idx)}, idx)
scatter_indices = TracedUtils.promote_to(
TracedRArray{Reactant.unwrapped_eltype(idx),ndims(idx)}, idx
)
n_idxs = size(scatter_indices, 1)
return @opcall(
scatter(
Expand Down
22 changes: 10 additions & 12 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1819,10 +1819,10 @@ instead.
"""
@noinline function scatter_setindex(
dest::TracedRArray{T,N},
scatter_indices::TracedRArray{Int64,2},
scatter_indices::TracedRArray{T1,2},
updates::TracedRArray{T2,1};
location=mlir_stacktrace("scatter_setindex", @__FILE__, @__LINE__),
) where {T,N,T2}
) where {T,N,T1,T2}
@assert length(updates) == size(scatter_indices, 1)
@assert size(scatter_indices, 2) == N

Expand Down Expand Up @@ -1873,7 +1873,7 @@ end

@noinline function scatter(
dest::Vector{TracedRArray{T,N}},
scatter_indices::TracedRArray{Int64},
scatter_indices::TracedRArray{TI},
updates::Vector{<:TracedRArray{T}};
update_computation::MLIR.IR.Region,
update_window_dims::Vector{Int64},
Expand All @@ -1885,9 +1885,9 @@ end
unique_indices::Union{Bool,Nothing}=nothing,
indices_are_sorted::Union{Bool,Nothing}=nothing,
location=mlir_stacktrace("scatter", @__FILE__, @__LINE__),
) where {T,N}
) where {T,TI,N}
scatter_indices = subtract(
scatter_indices, fill(Int64(1), size(scatter_indices)); location
scatter_indices, fill(TI(1), size(scatter_indices)); location
)

update_window_dims = update_window_dims .- 1
Expand Down Expand Up @@ -1938,9 +1938,9 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
"""
@noinline function gather_getindex(
src::TracedRArray{T,N},
gather_indices::TracedRArray{Int64,2};
gather_indices::TracedRArray{TI,2};
location=mlir_stacktrace("gather_getindex", @__FILE__, @__LINE__),
) where {T,N}
) where {T,TI,N}
@assert size(gather_indices, 2) == N

if GATHER_GETINDEX_DISABLED[]
Expand Down Expand Up @@ -1969,7 +1969,7 @@ end

@noinline function gather(
src::TracedRArray{T,N},
gather_indices::TracedRArray{Int64};
gather_indices::TracedRArray{TI};
offset_dims::Vector{Int64},
collapsed_slice_dims::Vector{Int64},
operand_batching_dims::Vector{Int64},
Expand All @@ -1979,10 +1979,8 @@ end
slice_sizes::Vector{Int64},
indices_are_sorted::Bool=false,
location=mlir_stacktrace("gather", @__FILE__, @__LINE__),
) where {T,N}
gather_indices = subtract(
gather_indices, fill(Int64(1), size(gather_indices)); location
)
) where {T,TI,N}
gather_indices = subtract(gather_indices, fill(TI(1), size(gather_indices)); location)

offset_dims = offset_dims .- 1
start_indices_batching_dims = start_indices_batching_dims .- 1
Expand Down
12 changes: 12 additions & 0 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,15 @@ end
y_ra = @jit(logsumexp(x_ra))
@test Float32(y_ra) ≈ y
end

@testset "gather 32bit indexing" begin
x = rand(Float32, 10, 10)
x_ra = Reactant.to_rarray(x)

idxs = Int32.(rand(1:10, 32))
idxs_ra = Reactant.to_rarray(idxs)

@test @jit(NNlib.gather(x_ra, idxs_ra)) ≈ NNlib.gather(x, idxs)
hlo = repr(@code_hlo(NNlib.gather(x_ra, idxs_ra)))
@test !contains(hlo, "i64>")
end
Loading