Closed
Description
Sometimes I need a non-lazy transpose, e.g. to change the data layout in memory to suit the access patterns of some algorithms. I tried `copy` on a lazily transposed array, but that currently runs into scalar-indexing trouble (the same happens with `collect` instead of `copy`):
```julia
julia> using Reactant; Reactant.set_default_backend("cpu")
julia> A = rand(Float32, 100, 10); rA = Reactant.to_rarray(A);
julia> foo(A) = copy(transpose(A))
foo (generic function with 1 method)
julia> typeof(foo(A))
Matrix{Float32} (alias for Array{Float32, 2})
julia> @code_hlo foo(rA)
ERROR: Scalar indexing is disallowed.
Invocation of getindex(::TracedRArray, ::Vararg{Int, N}) resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error
    @ ./error.jl:35 [inlined]
  [2] (::Nothing)(none::typeof(error), none::String)
    @ Reactant ./<missing>:0
  [3] ErrorException
    @ ./boot.jl:323 [inlined]
  [4] error
    @ ./error.jl:35 [inlined]
  [5] call_with_reactant(::Reactant.MustThrowError, ::typeof(error), ::String)
    @ Reactant /user/.julia/packages/Reactant/gBXlB/src/utils.jl:0
[...]
```
This works, though:
```julia
julia> function bar(A)
           A_transposed = transpose(A)
           B = similar(A_transposed)
           B .= A_transposed
       end
julia> typeof(bar(A))
Matrix{Float32} (alias for Array{Float32, 2})
julia> @code_hlo bar(rA)
module @reactant_bar attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<10x100xf32>) -> tensor<100x10xf32> {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x100xf32>) -> tensor<100x10xf32>
    return %0 : tensor<100x10xf32>
  }
}
```

Under the hood, `copy(transpose(A))` does something like this:
```julia
julia> using LinearAlgebra
julia> function baz(A)
           B = similar(A, reverse(axes(A)))  # uninitialized destination with swapped axes
           transpose!(B, A)                  # eager (non-lazy) transpose of A into B
       end
baz (generic function with 1 method)
julia> typeof(baz(A))
Matrix{Float32} (alias for Array{Float32, 2})
```

but that fails with the same error:

```julia
julia> @code_hlo baz(rA)
ERROR: Scalar indexing is disallowed.
```

I guess we need Reactant specializations of `transpose!` and `adjoint!`?
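For context, here is a minimal sketch in plain Julia (no Reactant needed) of why an element-wise fallback trips the scalar-indexing check, next to two whole-array formulations. `naive_transpose!`, `eager_transpose_bcast` and `eager_transpose_perm` are hypothetical helpers written only for this illustration. `naive_transpose!` is a simplified stand-in, not the actual LinearAlgebra code, though the stack trace above shows that the real fallback likewise ends up calling scalar `getindex` on the `TracedRArray`. Whether `permutedims` on a traced array lowers to a single `stablehlo.transpose` (as `bar` does) is an assumption worth checking with `@code_hlo`.

```julia
using LinearAlgebra

# Hypothetical helper (illustration only): a simplified element-wise transpose,
# standing in for a generic fallback. Every `A[i, j]` read is a scalar getindex,
# which is the operation Reactant rejects on traced arrays.
function naive_transpose!(B::AbstractMatrix, A::AbstractMatrix)
    (axes(B, 1) == axes(A, 2) && axes(B, 2) == axes(A, 1)) ||
        throw(DimensionMismatch("B must have the axes of A reversed"))
    for j in axes(A, 2), i in axes(A, 1)
        B[j, i] = transpose(A[i, j])  # one scalar read and one scalar write per element
    end
    return B
end

# Hypothetical whole-array formulations with no explicit per-element indexing:
# broadcasting into a fresh array (the same pattern as `bar` above) ...
eager_transpose_bcast(A::AbstractMatrix) = (B = similar(transpose(A)); B .= transpose(A); B)
# ... and `permutedims`, which returns a new, non-lazy array.
eager_transpose_perm(A::AbstractMatrix) = permutedims(A, (2, 1))

A = rand(Float32, 100, 10)
B_loop  = naive_transpose!(similar(A, reverse(axes(A))), A)
B_bcast = eager_transpose_bcast(A)
B_perm  = eager_transpose_perm(A)
println(size(B_loop), " ", B_loop == B_bcast == B_perm == copy(transpose(A)))  # prints (10, 100) true
```

For real element types `adjoint` and `transpose` coincide; for complex data the eager counterpart of `copy(adjoint(A))` would be `conj.(permutedims(A, (2, 1)))`.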