diff --git a/src/Ops.jl b/src/Ops.jl index 203749ee2d..33f12615e7 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2906,19 +2906,42 @@ end ] end +function standardize_start_index( + sz::Int, start_index::Union{Integer,TracedRNumber{<:Integer}} +) + if (start_index isa Integer && start_index ≤ typemax(Int32)) || sz ≤ typemax(Int32) + start_index = Reactant.TracedUtils.promote_to(TracedRNumber{Int32}, start_index) + elseif start_index isa Integer + start_index = Reactant.TracedUtils.promote_to( + TracedRNumber{eltype(start_index)}, start_index + ) + end + + start_index = start_index - Reactant.unwrapped_eltype(start_index)(1) + return start_index +end + +function standardize_start_indices( + operand::TracedRArray{T,N}, start_indices::Vector +) where {T,N} + @assert length(start_indices) == N + return [ + standardize_start_index(size(operand, i), start_indices[i]).mlir_data for i in 1:N + ] +end + @noinline function dynamic_update_slice( operand::TracedRArray{T,N}, update::TracedRArray{T}, start_indices::Vector; location=mlir_stacktrace("dynamic_update_slice", @__FILE__, @__LINE__), ) where {T,N} - start_indices = [ - Reactant.TracedUtils.promote_to(TracedRNumber{Int32}, index - 1).mlir_data for - index in start_indices - ] res = MLIR.IR.result( stablehlo.dynamic_update_slice( - operand.mlir_data, update.mlir_data, start_indices; location + operand.mlir_data, + update.mlir_data, + standardize_start_indices(operand, start_indices); + location, ), 1, ) @@ -2931,15 +2954,10 @@ end slice_sizes::Vector; location=mlir_stacktrace("dynamic_slice", @__FILE__, @__LINE__), ) where {T,N} - start_indices = [ - Reactant.TracedUtils.promote_to( - TracedRNumber{Int32}, index - Reactant.unwrapped_eltype(index)(1) - ).mlir_data for index in start_indices - ] res = MLIR.IR.result( stablehlo.dynamic_slice( operand.mlir_data, - start_indices; + standardize_start_indices(operand, start_indices); slice_sizes=collect(Int64, slice_sizes), location, ),