Skip to content

Overlayed functions not detected inside another overlayed function #1589

@unit-of-inductance

Description

@unit-of-inductance
using LinearAlgebra, Reactant, Lux, Random

# Device adaptor used below to move data and parameters onto Reactant.
const xdev = reactant_device()

# `model` is read as a global from inside `f` and `g`; declaring it `const` (like `xdev`)
# gives those closures a concretely typed binding instead of an untyped global.
const model = Dense(2 => 1)

# NOTE: `rand(2)` produces Float64 points while `Lux.setup` produces Float32 parameters
# (the stack trace shows TracedRArray{Float64, 1} inputs against Float32 weights), so the
# dense layer is evaluated with mixed precision.
points = [rand(2), rand(2)] |> xdev
parameters, states = Lux.setup(Random.default_rng(), model) |> xdev

"""
    f(parameters, points)

Apply `model` statelessly to every element of `points` and add the results together
with a single `mapreduce` call. This is the variant that fails to compile in this issue.
"""
function f(parameters, points)
    total = mapreduce(+, points) do p
        Lux.LuxCore.stateless_apply(model, p, parameters)
    end
    return total
end

"""
    g(parameters, points)

Apply `model` statelessly to every element of `points`, collect the outputs with `map`,
and then fold them together with `reduce(+, ...)`. This variant compiles successfully.
"""
function g(parameters, points)
    outputs = map(points) do p
        Lux.LuxCore.stateless_apply(model, p, parameters)
    end
    return reduce(+, outputs)
end

# map-then-reduce variant: compiles without issues.
g_compiled = @compile(g(parameters, points))

# mapreduce variant: compilation raises the MethodError reported in this issue.
# The compile call is wrapped in `allowscalar` (that requirement is tracked in #1554).
f_compiled = Reactant.allowscalar() do
    @compile(f(parameters, points))
end

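The code above raises the following error. (This appears to me to be separate from the issue in #1554, which seems to be what necessitates the `allowscalar()` call here.) Frames [18]–[22] of the trace show Reactant's `mapreduce` overlay handing off to `Base.mapreduce`. After that, the closure's matrix product appears to reach the generic `LinearAlgebra` fallback, operating on a `Base.ReshapedArray` of traced numbers, rather than Reactant's own overlays: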

1-element ExceptionStack:
MethodError: no method matching Float64(::Reactant.TracedRNumber{Float64})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float64(::Irrational{:SQRT_HALF})
   @ Random irrationals.jl:251
  Float64(::Irrational{:catalan})
   @ Base irrationals.jl:251
  ...
Stacktrace

Stacktrace:
  [1] setindex!(a::Reactant.TracedRArray{Float64, 2}, v::Reactant.TracedRNumber{Float64}, index::CartesianIndex{2})
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/ZEral/src/TracedRArray.jl:370
  [2] _modify!
    @ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:91 [inlined]
  [3] _generic_matmatmul!(C::Reactant.TracedRArray{Float64, 2}, A::Reactant.TracedRArray{Float32, 2}, B::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 2, Reactant.TracedRArray{Float64, 1}, Tuple{}}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:926
  [4] generic_matmatmul!
    @ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
  [5] _mul!
    @ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
  [6] mul!
    @ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
  [7] mul!(C::Reactant.TracedRArray{Float64, 2}, A::Reactant.TracedRArray{Float32, 2}, B::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 2, Reactant.TracedRArray{Float64, 1}, Tuple{}})
    @ LinearAlgebra /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253
  [8] *(A::Reactant.TracedRArray{Float32, 2}, B::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 2, Reactant.TracedRArray{Float64, 1}, Tuple{}})
    @ LinearAlgebra /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114
  [9] muladd(A::Reactant.TracedRArray{Float32, 2}, y::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 2, Reactant.TracedRArray{Float64, 1}, Tuple{}}, z::Reactant.TracedRArray{Float32, 1})
    @ LinearAlgebra /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:180
 [10] matmuladd
    @ ~/.julia/packages/LuxLib/Kx6MR/src/impl/matmul.jl:13 [inlined]
 [11] matmuladd
    @ ~/.julia/packages/LuxLib/Kx6MR/src/impl/matmul.jl:7 [inlined]
 [12] fused_dense
    @ ~/.julia/packages/LuxLib/Kx6MR/src/impl/dense.jl:10 [inlined]
 [13] fused_dense_bias_activation
    @ ~/.julia/packages/LuxLib/Kx6MR/src/api/dense.jl:36 [inlined]
 [14] (::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True})(x::Reactant.TracedRArray{Float64, 1}, ps::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{})
    @ Lux ~/.julia/packages/Lux/sgU3g/src/layers/basic.jl:363
 [15] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
 [16] stateless_apply(model::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, x::Reactant.TracedRArray{Float64, 1}, ps::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}})
    @ LuxCore ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:166
 [17] (::var"#74#75"{@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}})(p::Reactant.TracedRArray{Float64, 1})
    @ Main ./REPL[19]:2
 [18] _mapreduce(f::var"#74#75"{@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}}, op::typeof(+), ::IndexLinear, A::Vector{Reactant.TracedRArray{Float64, 1}})
    @ Base ./reduce.jl:437
 [19] _mapreduce_dim(f::Function, op::Function, ::Base._InitialValue, A::Vector{Reactant.TracedRArray{Float64, 1}}, ::Colon)
    @ Base ./reducedim.jl:337
 [20] mapreduce(f::Function, op::Function, A::Vector{Reactant.TracedRArray{Float64, 1}})
    @ Base ./reducedim.jl:329
 [21] mapreduce(f::Function, op::Function, A::Vector{Reactant.TracedRArray{Float64, 1}}; kwargs::@Kwargs{})
    @ Reactant ~/.julia/packages/Reactant/ZEral/src/Overlay.jl:164
 [22] mapreduce(f::Function, op::Function, A::Vector{Reactant.TracedRArray{Float64, 1}})
    @ Reactant ~/.julia/packages/Reactant/ZEral/src/Overlay.jl:158
 [23] f
    @ ./REPL[19]:2 [inlined]
 [24] (::Nothing)(none::typeof(f), none::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, none::Vector{Reactant.TracedRArray{Float64, 1}})
    @ Reactant ./<missing>:0
 [25] f
    @ ./REPL[19]:2 [inlined]
 [26] call_with_reactant(::typeof(f), ::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, ::Vector{Reactant.TracedRArray{Float64, 1}})
    @ Reactant ~/.julia/packages/Reactant/ZEral/src/utils.jl:0
 [27] make_mlir_fn(f::typeof(f), args::Tuple{@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/ZEral/src/TracedUtils.jl:332
 [28] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, NTuple{var"#s1734", Symbol} where var"#s1734", NTuple{N, Int64} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:1549
 [29] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:1516 [inlined]
 [30] compile_xla(f::Function, args::Tuple{@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:3427
 [31] compile_xla
    @ ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:3400 [inlined]
 [32] compile(f::Function, args::Tuple{@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:3499
 [33] macro expansion
    @ ~/.julia/packages/Reactant/ZEral/src/Compiler.jl:2580 [inlined]
 [34] (::var"#78#79")()
    @ Main ./REPL[19]:13
 [35] task_local_storage(body::var"#78#79", key::Symbol, val::GPUArraysCore.ScalarIndexing)
    @ Base ./task.jl:315
 [36] allowscalar(f::Function)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:179
 [37] top-level scope
    @ REPL[19]:12
 [38] top-level scope
    @ /nix/store/0f8r9l0fn26mjj96j08p92ak8gmbgwas-julia-bin-1.11.6/share/julia/stdlib/v1.11/REPL/src/REPL.jl:1694
 [39] eval
    @ ./boot.jl:430 [inlined]
 [40] eval
    @ ./Base.jl:130 [inlined]
 [41] repleval(m::Module, code::Expr, ::String)
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:229
 [42] #112
    @ ~/.vscode/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:192 [inlined]
 [43] with_logstate(f::VSCodeServer.var"#112#114"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt}, logstate::Base.CoreLogging.LogState)
    @ Base.CoreLogging ./logging/logging.jl:524
 [44] with_logger
    @ ./logging/logging.jl:635 [inlined]
 [45] (::VSCodeServer.var"#111#113"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:193
 [46] #invokelatest#2
    @ ./essentials.jl:1055 [inlined]
 [47] invokelatest(::Any)
    @ Base ./essentials.jl:1052
 [48] (::VSCodeServer.var"#64#65")()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:34

Reproduced in an empty environment with the following `Pkg.status()`:
Status `/tmp/tmp.0D9Ovk1t8u/Project.toml`
  [b2108857] Lux v1.17.0
  [3c362404] Reactant v0.2.156

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions