-
Notifications
You must be signed in to change notification settings - Fork 38
Closed
Description
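I cannot `@compile` a function that returns a scalar obtained by indexing into an array (here `y[1]`). Compilation fails with a scalar-indexing error: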
julia> function test(x)
y = x^2
return y[1]
end
test (generic function with 1 method)
julia> x = Reactant.ConcreteRArray(rand(10, 10));
julia> test_comp = @compile test(x)
ERROR: Scalar indexing is disallowed.
Invocation of getindex(::TracedRArray, ::Vararg{Int, N}) resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
[1] error
@ ./error.jl:35 [inlined]
[2] (::Nothing)(none::typeof(error), none::String)
@ Reactant ./<missing>:0
[3] ErrorException
@ ./boot.jl:323 [inlined]
[4] error
@ ./error.jl:35 [inlined]
[5] call_with_reactant(::Reactant.MustThrowError, ::typeof(error), ::String)
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/utils.jl:0
[6] errorscalar
@ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151 [inlined]
[7] (::Nothing)(none::typeof(GPUArraysCore.errorscalar), none::String)
@ Reactant ./<missing>:0
[8] string
@ ./strings/substring.jl:236 [inlined]
[9] scalardesc
@ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:134 [inlined]
[10] errorscalar
@ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:150 [inlined]
[11] call_with_reactant(::Reactant.MustThrowError, ::typeof(GPUArraysCore.errorscalar), ::String)
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/utils.jl:0
[12] _assertscalar
@ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124 [inlined]
[13] (::Nothing)(none::typeof(GPUArraysCore._assertscalar), none::String, none::GPUArraysCore.ScalarIndexing)
@ Reactant ./<missing>:0
[14] _assertscalar
@ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:123 [inlined]
[15] call_with_reactant(::typeof(GPUArraysCore._assertscalar), ::String, ::GPUArraysCore.ScalarIndexing)
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/utils.jl:0
[16] assertscalar
@ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112 [inlined]
[17] (::Nothing)(none::typeof(GPUArraysCore.assertscalar), none::String)
@ Reactant ./<missing>:0
[18] current_task
@ ./task.jl:152 [inlined]
[19] task_local_storage
@ ./task.jl:280 [inlined]
[20] assertscalar
@ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:97 [inlined]
[21] call_with_reactant(::typeof(GPUArraysCore.assertscalar), ::String)
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/utils.jl:0
[22] getindex
@ ~/.julia/packages/Reactant/8rzTQ/src/Indexing.jl:40 [inlined]
[23] (::Nothing)(none::typeof(getindex), none::Reactant.TracedRArray{Float64, 2}, none::Tuple{Int64, Int64})
@ Reactant ./<missing>:0
[24] getindex
@ ~/.julia/packages/Reactant/8rzTQ/src/Indexing.jl:40 [inlined]
[25] call_with_reactant(::typeof(getindex), ::Reactant.TracedRArray{Float64, 2}, ::Int64, ::Int64)
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/utils.jl:0
[26] getindex
@ ~/.julia/packages/Reactant/8rzTQ/src/Indexing.jl:48 [inlined]
[27] (::Nothing)(none::typeof(getindex), none::Reactant.TracedRArray{Float64, 2}, none::Int64)
@ Reactant ./<missing>:0
[28] getproperty
@ ./Base.jl:49 [inlined]
[29] size
@ ~/.julia/packages/Reactant/8rzTQ/src/TracedRArray.jl:250 [inlined]
[30] getindex
@ ~/.julia/packages/Reactant/8rzTQ/src/Indexing.jl:48 [inlined]
[31] call_with_reactant(::typeof(getindex), ::Reactant.TracedRArray{Float64, 2}, ::Int64)
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/utils.jl:0
[32] test
@ ./REPL[11]:3 [inlined]
[33] (::Nothing)(none::typeof(test), none::Reactant.TracedRArray{Float64, 2})
@ Reactant ./<missing>:0
[34] test
@ ./REPL[11]:2 [inlined]
[35] call_with_reactant(::typeof(test), ::Reactant.TracedRArray{Float64, 2})
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/utils.jl:0
[36] make_mlir_fn(f::typeof(test), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/8rzTQ/src/TracedUtils.jl:345
[37] make_mlir_fn
@ ~/.julia/packages/Reactant/8rzTQ/src/TracedUtils.jl:275 [inlined]
[38] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(test), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:1605
[39] compile_mlir!
@ ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:1567 [inlined]
[40] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:3513
[41] compile_xla
@ ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:3485 [inlined]
[42] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:3589
[43] top-level scope
@ ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:2658
Some type information was truncated. Use `show(err)` to see complete types.
Metadata
Assignees
Labels
No labels