diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl
index 9c416fd133..834a12c77b 100644
--- a/ext/ReactantNNlibExt/Implementations.jl
+++ b/ext/ReactantNNlibExt/Implementations.jl
@@ -464,8 +464,9 @@ function _stack_indices(idxs::AbstractArray{<:CartesianIndex})
 end
 
 function _nnlib_gather_impl(src::AnyTracedRArray, idxs::AbstractArray, n_dims::Int)
-    idxs = TracedUtils.promote_to(TracedRArray{Int,ndims(idxs)}, idxs)
-    n_idxs = size(idxs, 1)
+    idxs = TracedUtils.promote_to(
+        TracedRArray{Reactant.unwrapped_eltype(idxs),ndims(idxs)}, idxs
+    )
     return @opcall gather(
         src,
         idxs;
@@ -473,7 +474,7 @@ function _nnlib_gather_impl(src::AnyTracedRArray, idxs::AbstractArray, n_dims::I
         collapsed_slice_dims=collect(Int64, (n_dims + 1):ndims(src)),
         operand_batching_dims=Int64[],
         start_indices_batching_dims=Int64[],
-        start_index_map=collect(Int64, (ndims(src) - n_idxs + 1):ndims(src)),
+        start_index_map=collect(Int64, (ndims(src) - size(idxs, 1) + 1):ndims(src)),
         index_vector_dim=1,
         slice_sizes=Int64[size(src)[1:n_dims]..., ones(Int64, ndims(src) - n_dims)...],
     )
@@ -553,7 +554,9 @@ function _nnlib_scatter_impl(
     idx::AbstractArray,
     n_dims::Int,
 ) where {OP,T}
-    scatter_indices = TracedUtils.promote_to(TracedRArray{Int,ndims(idx)}, idx)
+    scatter_indices = TracedUtils.promote_to(
+        TracedRArray{Reactant.unwrapped_eltype(idx),ndims(idx)}, idx
+    )
     n_idxs = size(scatter_indices, 1)
     return @opcall(
         scatter(
diff --git a/src/Ops.jl b/src/Ops.jl
index 547623881f..7741889b35 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1819,10 +1819,10 @@ instead.
 """
 @noinline function scatter_setindex(
     dest::TracedRArray{T,N},
-    scatter_indices::TracedRArray{Int64,2},
+    scatter_indices::TracedRArray{T1,2},
     updates::TracedRArray{T2,1};
     location=mlir_stacktrace("scatter_setindex", @__FILE__, @__LINE__),
-) where {T,N,T2}
+) where {T,N,T1,T2}
     @assert length(updates) == size(scatter_indices, 1)
     @assert size(scatter_indices, 2) == N
 
@@ -1873,7 +1873,7 @@ end
 
 @noinline function scatter(
     dest::Vector{TracedRArray{T,N}},
-    scatter_indices::TracedRArray{Int64},
+    scatter_indices::TracedRArray{TI},
     updates::Vector{<:TracedRArray{T}};
     update_computation::MLIR.IR.Region,
     update_window_dims::Vector{Int64},
@@ -1885,9 +1885,9 @@ end
     unique_indices::Union{Bool,Nothing}=nothing,
     indices_are_sorted::Union{Bool,Nothing}=nothing,
     location=mlir_stacktrace("scatter", @__FILE__, @__LINE__),
-) where {T,N}
+) where {T,TI,N}
     scatter_indices = subtract(
-        scatter_indices, fill(Int64(1), size(scatter_indices)); location
+        scatter_indices, fill(TI(1), size(scatter_indices)); location
     )
 
     update_window_dims = update_window_dims .- 1
@@ -1938,9 +1938,9 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
 """
 @noinline function gather_getindex(
     src::TracedRArray{T,N},
-    gather_indices::TracedRArray{Int64,2};
+    gather_indices::TracedRArray{TI,2};
     location=mlir_stacktrace("gather_getindex", @__FILE__, @__LINE__),
-) where {T,N}
+) where {T,TI,N}
     @assert size(gather_indices, 2) == N
 
     if GATHER_GETINDEX_DISABLED[]
@@ -1969,7 +1969,7 @@ end
 
 @noinline function gather(
     src::TracedRArray{T,N},
-    gather_indices::TracedRArray{Int64};
+    gather_indices::TracedRArray{TI};
     offset_dims::Vector{Int64},
     collapsed_slice_dims::Vector{Int64},
     operand_batching_dims::Vector{Int64},
@@ -1979,10 +1979,8 @@ end
     slice_sizes::Vector{Int64},
     indices_are_sorted::Bool=false,
     location=mlir_stacktrace("gather", @__FILE__, @__LINE__),
-) where {T,N}
-    gather_indices = subtract(
-        gather_indices, fill(Int64(1), size(gather_indices)); location
-    )
+) where {T,TI,N}
+    gather_indices = subtract(gather_indices, fill(TI(1), size(gather_indices)); location)
 
     offset_dims = offset_dims .- 1
     start_indices_batching_dims = start_indices_batching_dims .- 1
diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl
index c92cc300b1..27b1a01d0b 100644
--- a/test/nn/nnlib.jl
+++ b/test/nn/nnlib.jl
@@ -739,3 +739,15 @@ end
     y_ra = @jit(logsumexp(x_ra))
     @test Float32(y_ra) ≈ y
 end
+
+@testset "gather 32bit indexing" begin
+    x = rand(Float32, 10, 10)
+    x_ra = Reactant.to_rarray(x)
+
+    idxs = Int32.(rand(1:10, 32))
+    idxs_ra = Reactant.to_rarray(idxs)
+
+    @test @jit(NNlib.gather(x_ra, idxs_ra)) ≈ NNlib.gather(x, idxs)
+    hlo = repr(@code_hlo(NNlib.gather(x_ra, idxs_ra)))
+    @test !contains(hlo, "i64>")
+end