implement PermuteDimsArray for TracedRArray
#340
Conversation
# force permutation of dims, because we can optimize it anyway
# TODO should we add a method for `PermutedDimsArray` with type params?
PermutedDimsArray(x::TracedRArray, perm) = permutedims(x, perm)
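For context, under this definition the lazy wrapper construction is rewritten into an eager copy. A minimal sketch of the resulting behavior (assuming a tracing context in which `a` is a `TracedRArray`):

```julia
# Inside a traced function, with `a isa TracedRArray`:
b = PermutedDimsArray(a, (2, 1))  # dispatches to permutedims(a, (2, 1))
# `b` is a fresh TracedRArray backed by a stablehlo.transpose of `a`,
# not a lazy view sharing storage with `a`.
```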
I don't think this is correct, since this will create a new array. Thus if you have
a
b = PermutedDimsArray(a, perm)
then a change to `b` should result in a change to `a`, right?
No, because `a` must still reference the array before the transposition.
`b` is another array; it's just that the transposition is performed lazily on `getindex`/`setindex!`.
julia> x = ones(4,4)
4×4 Matrix{Float64}:
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
julia> y = PermutedDimsArray(x, (2,1))
4×4 PermutedDimsArray(::Matrix{Float64}, (2, 1)) with eltype Float64:
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
julia> y[3,4] = 2
2
julia> y
4×4 PermutedDimsArray(::Matrix{Float64}, (2, 1)) with eltype Float64:
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
1.0 1.0 1.0 2.0
1.0 1.0 1.0 1.0
julia> x
4×4 Matrix{Float64}:
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
1.0 1.0 2.0 1.0
Ah, you're right.
What would the solution be then? My problem is with these lines: https://github.com/bsc-quantic/Tenet.jl/blob/c7dbf2513d2edd80829e978319e05fd720bf7cfc/src/Numerics.jl#L21-L31
Should I specialize `+(::TracedRArray, ::PermutedDimsArray{TracedRArray})` and `-(::TracedRArray, ::PermutedDimsArray{TracedRArray})` to call `Ops.transpose` just inside those methods?
We should specialize for `AnyTracedRArray` and call `materialize_traced_array` on the rhs arguments.
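A minimal sketch of that suggestion, assuming `Ops.add` and `Ops.subtract` wrappers exist alongside the `Ops.transpose` mentioned above (the actual signatures in Reactant may differ):

```julia
# Hedged sketch: dispatch on AnyTracedRArray (traced arrays plus wrappers such
# as PermutedDimsArray around them), materializing any lazy wrapper into a
# plain TracedRArray before emitting the elementwise op.
function Base.:+(x::AnyTracedRArray, y::AnyTracedRArray)
    return Ops.add(materialize_traced_array(x), materialize_traced_array(y))
end

function Base.:-(x::AnyTracedRArray, y::AnyTracedRArray)
    return Ops.subtract(materialize_traced_array(x), materialize_traced_array(y))
end
```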
Might need something like https://github.com/EnzymeAD/Reactant.jl/pull/342/files#diff-8282ff437f9070db415d2444bd8aed4b3b09fdff70a505f92709deb76c5403c3R65 for `PermutedDimsArray`.
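A hedged sketch of what such a `materialize_traced_array` method might look like (mirroring the linked pattern; the actual method in #342 may differ):

```julia
# Hedged sketch: materialize a lazy PermutedDimsArray over a traced array by
# materializing its parent and emitting an explicit permutation.
function materialize_traced_array(
    x::PermutedDimsArray{T,N,perm}
) where {T,N,perm}
    return permutedims(materialize_traced_array(parent(x)), perm)
end
```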
Yeah, I think so.
Yeah, it fixes it!
Closing in favor of #342
The implementation forces the dimension permutation to be performed eagerly (unlike the semantics of `PermutedDimsArray`, which just stores the unmodified parent array and an affine map) because the MLIR `tensor` type doesn't work with MLIR affine maps, and we can optimize the `stablehlo.transpose` anyway.
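The semantic difference is the one demonstrated in the REPL session above, and can be restated with ordinary `Array`s (plain Julia, no traced types involved):

```julia
x = ones(4, 4)

lazy = PermutedDimsArray(x, (2, 1))  # wrapper: shares storage with x
lazy[3, 4] = 2.0                     # writes through to x[4, 3]
@assert x[4, 3] == 2.0

eager = permutedims(x, (2, 1))       # copy: independent storage
eager[1, 1] = 5.0                    # mutating the copy leaves x untouched
@assert x[1, 1] == 1.0
```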