From 2f9e4f0587fd8863e4bb6dc0a85c84e1090a0ec7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 7 Dec 2024 16:54:04 +0100 Subject: [PATCH 1/2] implement `PermuteDimsArray` for `TracedRArray` --- src/TracedRArray.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 91d8df0049..84dd311f48 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -252,6 +252,10 @@ function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N} ) end +# force permutation of dims, because we can optimize it anyway +# TODO should we add a method for `PermutedDimsArray` with type params? +PermutedDimsArray(x::TracedRArray, perm) = permutedims(x, perm) + Base.conj(A::TracedRArray) = A function Base.conj(A::TracedRArray{T,N}) where {T<:Complex,N} return TracedRArray{T,N}( From f44b6013e8e46c3774b7c24b7bddd08d4093b424 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 7 Dec 2024 16:56:54 +0100 Subject: [PATCH 2/2] test --- test/basic.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 3796b7275c..ba23c0e087 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -640,3 +640,11 @@ end @test @jit(f_row_major(x_ra)) ≈ f_row_major(x) end + +@testset "PermutedDimsArray" begin + x = randn(2, 3) + x_re = Reactant.to_rarray(x) + + f(u) = PermutedDimsArray(u, (2, 1)) + @test f(x) == @jit f(x_re) +end