diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f78b0ba8f9..4c05527cb6 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -268,12 +268,6 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N} return a end -function Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N} - v = TracedUtils.broadcast_to_size(v, size(a)) - set_mlir_data!(a, get_mlir_data(v)) - return a -end - function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) where {T,N} GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})") indices = @@ -293,6 +287,16 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) whe end function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} + if (N == 1) && (indices isa Colon) + # Remove ambiguity from the previous + # ```julia + # Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N} + # ``` + # signature, which would be confused with this one for N=1. + v = TracedUtils.broadcast_to_size(v, size(a)) + set_mlir_data!(a, get_mlir_data(v)) + return a + end maybe_assert_scalar_setindexing(a, indices...) indices = TracedUtils.normalize_indices(a, indices...) diff --git a/test/indexing.jl b/test/indexing.jl index 16212b8a40..ca6a5ccb01 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -31,6 +31,23 @@ end # get_view_compiled = @compile get_view(x_concrete) end +function maskset!(y, x) + y[:] = x + return nothing +end + +@testset "setindex! with vectors & colon indexing" begin + x = Reactant.to_rarray([4.0]) + y = Reactant.to_rarray([2.0]) + @jit(maskset!(y, x)) + @test y ≈ x + + x = Reactant.to_rarray(ones(3)) + y = Reactant.to_rarray(2 * ones(3)) + @jit(maskset!(y, x)) + @test y ≈ x +end + function masking(x) y = similar(x) y[1:2, :] .= 0