Skip to content

Ambiguous method for fill! #1601

@yolhan83

Description

@yolhan83

Hello, coming from this Discourse thread: https://discourse.julialang.org/t/using-reactant-with-lux-and-enzyme-to-speed-up-training-in-physics-context/131898/17. Here is the MWE:

using Reactant

function foo(X)  # build an all-zero Array with the same element type and size as X
    Z = zeros(eltype(X), size(X))  # this `zeros` call is where tracing hits `fill!`
    return Z
end
X = Reactant.to_rarray(rand(4,5))  # random 4x5 Float64 input converted with Reactant.to_rarray
@jit foo(X)  # compiling/tracing foo here raises the ambiguous `fill!` MethodError shown below

and here is the resulting error:

ERROR: LoadError: MethodError: fill!(::Matrix{Reactant.TracedRNumber{Float64}}, ::Reactant.TracedRNumber{Float64}) is ambiguous.

Candidates:
  fill!(dest::Array{T}, x) where T
    @ Base array.jl:326
  fill!(A::AbstractArray{Reactant.TracedRNumber{T}, N}, x::Reactant.TracedRNumber{T2}) where {T, N, T2}
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/gBXlB/src/TracedRArray.jl:658        

Possible fix, define
  fill!(::Array{Reactant.TracedRNumber{T}, N}, ::Reactant.TracedRNumber{T2}) where {T, N, T2}        

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/gBXlB/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::typeof(fill!), ::Matrix{Reactant.TracedRNumber{Float64}}, ::Reactant.TracedRNumber{Float64})
    @ Reactant ~/.julia/packages/Reactant/gBXlB/src/utils.jl:875
  [3] zeros
    @ ./array.jl:590 [inlined]
  [4] (::Nothing)(none::typeof(zeros), none::Type{Reactant.TracedRNumber{Float64}}, none::Tuple{Int64, Int64})
    @ Reactant ./<missing>:0
  [5] Array
    @ ./boot.jl:592 [inlined]
  [6] zeros
    @ ./array.jl:589 [inlined]
  [7] call_with_reactant(::Reactant.MustThrowError, ::typeof(zeros), ::Type{Reactant.TracedRNumber{Float64}}, ::Tuple{Int64, Int64})
    @ Reactant ~/.julia/packages/Reactant/gBXlB/src/utils.jl:0
  [8] foo
    @ /mnt/c/Users/yolha/Desktop/juju_tests/mini/test/main2.jl:4 [inlined]
  [9] (::Nothing)(none::typeof(foo), none::Reactant.TracedRArray{Float64, 2})
    @ Reactant ./<missing>:0
 [10] getproperty
    @ ./Base.jl:49 [inlined]
 [11] size
    @ ~/.julia/packages/Reactant/gBXlB/src/TracedRArray.jl:527 [inlined]
 [12] foo
    @ /mnt/c/Users/yolha/Desktop/juju_tests/mini/test/main2.jl:4 [inlined]
 [13] call_with_reactant(::typeof(foo), ::Reactant.TracedRArray{Float64, 2})
    @ Reactant ~/.julia/packages/Reactant/gBXlB/src/utils.jl:0
 [14] make_mlir_fn(f::typeof(foo), args::Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/gBXlB/src/TracedUtils.jl:332
 [15] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, NTuple{var"#s1734", Symbol} where var"#s1734", NTuple{N, Int64} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/gBXlB/src/Compiler.jl:1555
 [16] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/gBXlB/src/Compiler.jl:1522 [inlined]
 [17] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/gBXlB/src/Compiler.jl:3433
 [18] compile_xla
    @ ~/.julia/packages/Reactant/gBXlB/src/Compiler.jl:3406 [inlined]
 [19] compile(f::Function, args::Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/gBXlB/src/Compiler.jl:3505
 [20] top-level scope
    @ ~/.julia/packages/Reactant/gBXlB/src/Compiler.jl:2586
in expression starting at /mnt/c/Users/yolha/Desktop/juju_tests/mini/test/main2.jl:8

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions