-
Notifications
You must be signed in to change notification settings - Fork 33
Closed
Description
using LinearAlgebra, Reactant, Lux, Random
# Device adaptor that moves host data onto the Reactant (XLA) backend.
const xdev = reactant_device()

# A single dense layer mapping 2 inputs to 1 output.
model = Dense(2 => 1)

# Two random input vectors. `rand(2)` yields Float64 data, while the layer's
# parameters end up Float32 (see the `TracedRArray{Float32, 2}` weight in the
# stack trace). That precision mismatch is intentionally kept to reproduce the report.
points = xdev([rand(2) for _ in 1:2])

# Initialise the layer with the default RNG, then move (parameters, states) to
# the device and destructure the resulting pair.
parameters, states = xdev(Lux.setup(Random.default_rng(), model))
function f(parameters, points)
    # Run the global `model` statelessly on one point with the given parameters.
    evaluate(p) = Lux.LuxCore.stateless_apply(model, p, parameters)
    # Evaluate every point and fold the outputs with `+` in a single
    # `mapreduce` call; this is the variant that fails under `@compile`.
    return mapreduce(evaluate, +, points)
end
function g(parameters, points)
    # First materialise the per-point model outputs...
    outputs = map(points) do p
        Lux.LuxCore.stateless_apply(model, p, parameters)
    end
    # ...then sum them in a separate `reduce` step (this variant compiles fine).
    return reduce(+, outputs)
end
# Compiling the map-then-reduce variant `g` works without issues.
g_compiled = Reactant.@compile g(parameters, points)
# Raises an error while tracing the `mapreduce` variant `f`.
# Scalar indexing is allowed here only because issue #1554 otherwise requires it.
f_compiled = Reactant.allowscalar() do
    @compile f(parameters, points)
end
The above code raises the following error (this appears to me to be separate from the issue in #1554, which seems to necessitate the `allowscalar()` call here):
1-element ExceptionStack:
MethodError: no method matching Float64(::Reactant.TracedRNumber{Float64})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
(::Type{T})(::T) where T<:Number
@ Core boot.jl:900
Float64(::Irrational{:SQRT_HALF})
@ Random irrationals.jl:251
Float64(::Irrational{:catalan})
@ Base irrationals.jl:251
...
Stacktrace
Stacktrace:
[1] setindex!(a::Reactant.TracedRArray{Float64, 2}, v::Reactant.TracedRNumber{Float64}, index::CartesianIndex{2})
@ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/ZEral/src/TracedRArray.jl:370
[2] _modify!
@ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:91 [inlined]
[3] _generic_matmatmul!(C::Reactant.TracedRArray{Float64, 2}, A::Reactant.TracedRArray{Float32, 2}, B::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 2, Reactant.TracedRArray{Float64, 1}, Tuple{}}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:926
[4] generic_matmatmul!
@ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
[5] _mul!
@ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
[6] mul!
@ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
[7] mul!(C::Reactant.TracedRArray{Float64, 2}, A::Reactant.TracedRArray{Float32, 2}, B::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 2, Reactant.TracedRArray{Float64, 1}, Tuple{}})
@ LinearAlgebra /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253
[8] *(A::Reactant.TracedRArray{Float32, 2}, B::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 2, Reactant.TracedRArray{Float64, 1}, Tuple{}})
@ LinearAlgebra /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114
[9] muladd(A::Reactant.TracedRArray{Float32, 2}, y::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 2, Reactant.TracedRArray{Float64, 1}, Tuple{}}, z::Reactant.TracedRArray{Float32, 1})
@ LinearAlgebra /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:180
[10] matmuladd
@ ~/.julia/packages/LuxLib/Kx6MR/src/impl/matmul.jl:13 [inlined]
[11] matmuladd
@ ~/.julia/packages/LuxLib/Kx6MR/src/impl/matmul.jl:7 [inlined]
[12] fused_dense
@ ~/.julia/packages/LuxLib/Kx6MR/src/impl/dense.jl:10 [inlined]
[13] fused_dense_bias_activation
@ ~/.julia/packages/LuxLib/Kx6MR/src/api/dense.jl:36 [inlined]
[14] (::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True})(x::Reactant.TracedRArray{Float64, 1}, ps::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{})
@ Lux ~/.julia/packages/Lux/sgU3g/src/layers/basic.jl:363
[15] apply
@ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
[16] stateless_apply(model::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, x::Reactant.TracedRArray{Float64, 1}, ps::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}})
@ LuxCore ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:166
[17] (::var"#74#75"{@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}})(p::Reactant.TracedRArray{Float64, 1})
@ Main ./REPL[19]:2
[18] _mapreduce(f::var"#74#75"{@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}}, op::typeof(+), ::IndexLinear, A::Vector{Reactant.TracedRArray{Float64, 1}})
@ Base ./reduce.jl:437
[19] _mapreduce_dim(f::Function, op::Function, ::Base._InitialValue, A::Vector{Reactant.TracedRArray{Float64, 1}}, ::Colon)
@ Base ./reducedim.jl:337
[20] mapreduce(f::Function, op::Function, A::Vector{Reactant.TracedRArray{Float64, 1}})
@ Base ./reducedim.jl:329
[21] mapreduce(f::Function, op::Function, A::Vector{Reactant.TracedRArray{Float64, 1}}; kwargs::@Kwargs{})
@ Reactant ~/.julia/packages/Reactant/ZEral/src/Overlay.jl:164
[22] mapreduce(f::Function, op::Function, A::Vector{Reactant.TracedRArray{Float64, 1}})
@ Reactant ~/.julia/packages/Reactant/ZEral/src/Overlay.jl:158
[23] f
@ ./REPL[19]:2 [inlined]
[24] (::Nothing)(none::typeof(f), none::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, none::Vector{Reactant.TracedRArray{Float64, 1}})
@ Reactant ./<missing>:0
[25] f
@ ./REPL[19]:2 [inlined]
[26] call_with_reactant(::typeof(f), ::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, ::Vector{Reactant.TracedRArray{Float64, 1}})
@ Reactant ~/.julia/packages/Reactant/ZEral/src/utils.jl:0
[27] make_mlir_fn(f::typeof(f), args::Tuple{@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/ZEral/src/TracedUtils.jl:332
[28] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, NTuple{var"#s1734", Symbol} where var"#s1734", NTuple{N, Int64} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:1549
[29] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:1516 [inlined]
[30] compile_xla(f::Function, args::Tuple{@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:3427
[31] compile_xla
@ ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:3400 [inlined]
[32] compile(f::Function, args::Tuple{@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
@ Reactant.Compiler ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:3499
[33] macro expansion
@ ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:2580 [inlined]
[34] (::var"#78#79")()
@ Main ./REPL[19]:13
[35] task_local_storage(body::var"#78#79", key::Symbol, val::GPUArraysCore.ScalarIndexing)
@ Base ./task.jl:315
[36] allowscalar(f::Function)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:179
[37] top-level scope
@ REPL[19]:12
[38] top-level scope
@ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/REPL/src/REPL.jl:1694
[39] eval
@ ./boot.jl:430 [inlined]
[40] eval
@ ./Base.jl:130 [inlined]
[41] repleval(m::Module, code::Expr, ::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:229
[42] #112
@ ~/.vscode/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:192 [inlined]
[43] with_logstate(f::VSCodeServer.var"#112#114"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt}, logstate::Base.CoreLogging.LogState)
@ Base.CoreLogging ./logging/logging.jl:524
[44] with_logger
@ ./logging/logging.jl:635 [inlined]
[45] (::VSCodeServer.var"#111#113"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:193
[46] #invokelatest#2
@ ./essentials.jl:1055 [inlined]
[47] invokelatest(::Any)
@ Base ./essentials.jl:1052
[48] (::VSCodeServer.var"#64#65")()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:34
Status `/tmp/tmp.0D9Ovk1t8u/Project.toml`
[b2108857] Lux v1.17.0
[3c362404] Reactant v0.2.156
Metadata
Metadata
Assignees
Labels
No labels