From f709c157900e9307d4d3ded1ea71c6c6f2d0916f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Sep 2025 02:19:50 -0400 Subject: [PATCH 1/2] fix: set-indexing into subarray --- src/TracedUtils.jl | 4 ++-- test/indexing.jl | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 63d3c73160..917f4b55f1 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -179,12 +179,12 @@ end function get_ancestor_indices_inner( x::AnyTracedRArray{T,N}, linear_indices::AbstractArray ) where {T,N} - return _get_ancestor_indices_linear(x, linear_indices) + return (_get_ancestor_indices_linear(x, linear_indices),) end function get_ancestor_indices_inner( x::AnyTracedRArray{T,1}, linear_indices::AbstractArray ) where {T} - return _get_ancestor_indices_linear(x, linear_indices) + return (_get_ancestor_indices_linear(x, linear_indices),) end function _get_ancestor_indices_linear(x::AnyTracedRArray, indices::AbstractArray) diff --git a/test/indexing.jl b/test/indexing.jl index 9e4f24b4d5..e88d04872d 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -408,3 +408,14 @@ end @test parent_ra ≈ parent end + +function test_slice_copy!(x) + @view(x[1, :]) .= 0 + return x +end + +@testset "slice copy" begin + x = Reactant.to_rarray(rand(2, 10)) + @jit test_slice_copy!(x) + @test all(Array(x)[1, :] .== 0) +end From a1279faf2584a5f28d0b3a158060126563cfda77 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Sep 2025 03:02:02 -0400 Subject: [PATCH 2/2] fix: old tests --- src/TracedRArray.jl | 4 ++-- src/TracedUtils.jl | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 28744ceb8b..af269adba8 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -295,11 +295,11 @@ function Base.getindex( end function Base.getindex(a::WrappedArray{TracedRNumber{T}}, linear_indices) where {T} - return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, linear_indices)) + return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, linear_indices)...) end function Base.getindex(a::WrappedArray{TracedRNumber{T},1}, indices) where {T} - return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices)) + return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices)...) end function Base.getindex( a::WrappedArray{TracedRNumber{T},N}, indices::Vararg{Any,N} diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 917f4b55f1..fe5a26125c 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -179,21 +179,26 @@ end function get_ancestor_indices_inner( x::AnyTracedRArray{T,N}, linear_indices::AbstractArray ) where {T,N} - return (_get_ancestor_indices_linear(x, linear_indices),) + idxs = _get_ancestor_indices_linear(x, linear_indices) + return idxs isa Tuple ? idxs : (idxs,) end function get_ancestor_indices_inner( x::AnyTracedRArray{T,1}, linear_indices::AbstractArray ) where {T} - return (_get_ancestor_indices_linear(x, linear_indices),) + idxs = _get_ancestor_indices_linear(x, linear_indices) + return idxs isa Tuple ? idxs : (idxs,) end function _get_ancestor_indices_linear(x::AnyTracedRArray, indices::AbstractArray) + @show indices indices = CartesianIndices(x)[indices] + @show indices pidxs = parentindices(x) parent_indices = map(indices) do idx CartesianIndex(Base.reindex(pidxs, (idx.I...,))) end - return get_ancestor_indices(parent(x), parent_indices) + @show parent_indices + return @show get_ancestor_indices(parent(x), parent_indices) end Base.@nospecializeinfer function batch_ty(