Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
)
end

# force permutation of dims, because we can optimize it anyway
# TODO should we add a method for `PermutedDimsArray` with type params?
PermutedDimsArray(x::TracedRArray, perm) = permutedims(x, perm)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t think this is correct since this will create a new array. Thus if you have

a
b = PermuteDimsArray(a)

a change to b should result in a change to a, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, because a must still reference the array before the tranposition

b is another array; it's just that the transposition is performed lazily on getindex/setindex!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


julia> x = ones(4,4)
4×4 Matrix{Float64}:
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0

julia> y = PermutedDimsArray(x, (2,1))
4×4 PermutedDimsArray(::Matrix{Float64}, (2, 1)) with eltype Float64:
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0

julia> y[3,4] = 2
2

julia> y
4×4 PermutedDimsArray(::Matrix{Float64}, (2, 1)) with eltype Float64:
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  2.0
 1.0  1.0  1.0  1.0

julia> x
4×4 Matrix{Float64}:
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0
 1.0  1.0  2.0  1.0

julia> 

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, you're right

what would the solution be then? my problem is with these lines https://github.com/bsc-quantic/Tenet.jl/blob/c7dbf2513d2edd80829e978319e05fd720bf7cfc/src/Numerics.jl#L21-L31

should i specialize +(::TracedRArray, ::PermuteDimsArray{TracedRArray}) and -(::TracedRArray, ::PermuteDimsArray{TracedRArray}) to call Ops.transpose just inside those methods?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should specialize for AnyTracedRArray and call materialize_traced_array for the rhs arguments

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah think so

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mofeing can you check if #342 handles your problem

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, it fixes it!


Base.conj(A::TracedRArray) = A
function Base.conj(A::TracedRArray{T,N}) where {T<:Complex,N}
return TracedRArray{T,N}(
Expand Down
8 changes: 8 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,11 @@ end

@test @jit(f_row_major(x_ra)) ≈ f_row_major(x)
end

@testset "PermutedDimsArray" begin
x = randn(2, 3)
x_re = Reactant.to_rarray(x)

f(u) = PermutedDimsArray(u, (2, 1))
@test f(x) == @jit f(x_re)
end