From 36a07bdad982ba0e29949ceffa8e9a2c5fbc7edc Mon Sep 17 00:00:00 2001 From: Florian VINCENT Date: Thu, 30 Jan 2025 21:37:32 +0100 Subject: [PATCH 1/2] Feature: allow colon indexing of traced vectors --- src/TracedRArray.jl | 16 ++++++++++------ test/indexing.jl | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f78b0ba8f9..4c05527cb6 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -268,12 +268,6 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N} return a end -function Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N} - v = TracedUtils.broadcast_to_size(v, size(a)) - set_mlir_data!(a, get_mlir_data(v)) - return a -end - function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) where {T,N} GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})") indices = @@ -293,6 +287,16 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) whe end function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} + if (N == 1) && (indices isa Colon) + # Remove ambiguity from the previous + # ```julia + # Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N} + # ``` + # signature, which would be confused with this one for N=1. + v = TracedUtils.broadcast_to_size(v, size(a)) + set_mlir_data!(a, get_mlir_data(v)) + return a + end maybe_assert_scalar_setindexing(a, indices...) indices = TracedUtils.normalize_indices(a, indices...) diff --git a/test/indexing.jl b/test/indexing.jl index 16212b8a40..576b9fcfe3 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -31,6 +31,23 @@ end # get_view_compiled = @compile get_view(x_concrete) end +function maskset!(y, x) + y[:] = x + return nothing +end + +@testset "setindex! with vectors & colon indexing" begin + x = Reactant.to_rarray([4.0]) + y = Reactant.to_rarray([2.0]) + @jit(maskset!(y, x)) + @test y ≈ x + + x = Reactant.to_rarray(ones(3)) + y = Reactant.to_rarray(2*ones(3)) + @jit(maskset!(y, x)) + @test y ≈ x +end + function masking(x) y = similar(x) y[1:2, :] .= 0 From ac20df68e553bb791d17ab80eeb51bcf1efbca95 Mon Sep 17 00:00:00 2001 From: Florian VINCENT Date: Fri, 31 Jan 2025 10:22:12 +0100 Subject: [PATCH 2/2] Style: fix space in mult op --- test/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/indexing.jl b/test/indexing.jl index 576b9fcfe3..ca6a5ccb01 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -43,7 +43,7 @@ end @test y ≈ x x = Reactant.to_rarray(ones(3)) - y = Reactant.to_rarray(2*ones(3)) + y = Reactant.to_rarray(2 * ones(3)) @jit(maskset!(y, x)) @test y ≈ x end