diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 91d8df0049..84dd311f48 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -252,6 +252,10 @@ function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N} ) end +# force permutation of dims, because we can optimize it anyway +# TODO should we add a method for `PermutedDimsArray` with type params? +PermutedDimsArray(x::TracedRArray, perm) = permutedims(x, perm) + Base.conj(A::TracedRArray) = A function Base.conj(A::TracedRArray{T,N}) where {T<:Complex,N} return TracedRArray{T,N}( diff --git a/test/basic.jl b/test/basic.jl index 3796b7275c..ba23c0e087 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -640,3 +640,11 @@ end @test @jit(f_row_major(x_ra)) ≈ f_row_major(x) end + +@testset "PermutedDimsArray" begin + x = randn(2, 3) + x_re = Reactant.to_rarray(x) + + f(u) = PermutedDimsArray(u, (2, 1)) + @test f(x) == @jit f(x_re) +end