diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 9a2b20de3b..0f193ba700 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -105,6 +105,14 @@ function set_mlir_data!(x::Base.ReshapedArray{TracedRNumber{T}}, data) where {T} return x end +function get_ancestor_indices( + x::Base.ReshapedArray{TracedRNumber{T},N}, indices::Vector{CartesianIndex{N}} +) where {T,N} + linear_indices = LinearIndices(size(x))[indices] + parent_linear_indices = LinearIndices(size(parent(x)))[linear_indices] + return (parent_linear_indices,) +end + function get_ancestor_indices( x::Base.ReshapedArray{TracedRNumber{T},N}, indices... ) where {T,N} @@ -190,15 +198,12 @@ function get_ancestor_indices_inner( end function _get_ancestor_indices_linear(x::AnyTracedRArray, indices::AbstractArray) - @show indices indices = CartesianIndices(x)[indices] - @show indices pidxs = parentindices(x) parent_indices = map(indices) do idx CartesianIndex(Base.reindex(pidxs, (idx.I...,))) end - @show parent_indices - return @show get_ancestor_indices(parent(x), parent_indices) + return get_ancestor_indices(parent(x), parent_indices) end Base.@nospecializeinfer function batch_ty( diff --git a/test/indexing.jl b/test/indexing.jl index e88d04872d..5937f92785 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -419,3 +419,25 @@ end @jit test_slice_copy!(x) @test all(Array(x)[1, :] .== 0) end + +@testset "reshaped view setindex!" begin + du = rand(Float32, 10) + u = rand(Float32, 10) + p = 16.0f0 + t = 12.0f0 + + du_ra = Reactant.to_rarray(du; track_numbers=Number) + u_ra = Reactant.to_rarray(u; track_numbers=Number) + p_ra = Reactant.to_rarray(p; track_numbers=Number) + t_ra = Reactant.to_rarray(t; track_numbers=Number) + + function odef(du, u, p, t) + u = reshape(u, 2, :) + du = reshape(du, 2, :) + @view(du[1, :]) .= @view(u[1, :]) .+ p .+ t + @view(du[2, :]) .= @view(u[2, :]) .* 2 + return reshape(du, :) + end + + @test @jit(odef(du_ra, u_ra, p_ra, t_ra)) ≈ odef(du, u, p, t) +end