Skip to content

Going from Reactant v0.2.152 to v0.2.153 breaks compilation of NNlib.logsumexp #1593

@Sleort

Description

@Sleort

In a temporary environment with Reactant v0.2.152:

(jl_oIMZ8q) pkg> st
Status `/tmp/jl_oIMZ8q/Project.toml`
  [872c559c] NNlib v0.9.31
⌃ [3c362404] Reactant v0.2.152

This works:

julia> using Reactant, NNlib

julia> x = Reactant.to_rarray(rand(10))
10-element ConcretePJRTArray{Float64,1}:
 0.08903425793914821
 0.7983627095290846
 0.6869913407798728
 0.21537352286537292
 0.2539470134164943
 0.07547134727519211
 0.8528935829249477
 0.9225987687818796
 0.8782849123401902
 0.20126944260945456

julia> @jit logsumexp(x)
ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.8566158486867264)

However, with v0.2.153 (and later), compiling `NNlib.logsumexp` fails:

(jl_khbri3) pkg> st
Status `/tmp/jl_khbri3/Project.toml`
  [872c559c] NNlib v0.9.31
⌃ [3c362404] Reactant v0.2.153

julia> using Reactant, NNlib

julia> x = Reactant.to_rarray(rand(10))
10-element ConcretePJRTArray{Float64,1}:
 0.3737505265211689
 0.8670935749296754
 0.7975370385656305
 0.8880806088880492
 0.5321583149971186
 0.5164878524049845
 0.9024030516357548
 0.5942285215483274
 0.32937350425382883
 0.45795255020414805

julia> @jit logsumexp(x)
ERROR: ArgumentError: reducing over an empty collection is not allowed; consider supplying `init` to the reducer
Stacktrace:
  [1] _empty_reduce_error()
    @ Base ./reduce.jl:319
  [2] reduce_empty(f::Function, T::Type)
    @ Base ./reduce.jl:320
  [3] reduce_empty(op::Base.BottomRF{typeof(Base.FastMath.max_fast)}, ::Type{Float64})
    @ Base ./reduce.jl:357
  [4] __default_init(::Type{Float64}, op::typeof(Base.FastMath.max_fast))
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/hB4Fs/src/TracedRArray.jl:540
  [5] overloaded_mapreduce(f::Any, op::Any, A::Reactant.TracedRArray{Float64, 1}; dims::Function, init::Reactant.TracedRNumber{Float64})
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/hB4Fs/src/TracedRArray.jl:574
  [6] overloaded_mapreduce
    @ ~/.julia/packages/Reactant/hB4Fs/src/TracedRArray.jl:559 [inlined]
  [7] mapreduce(f::Function, op::Function, A::Reactant.TracedRArray{Float64, 1}; kwargs::@Kwargs{dims::Colon, init::Reactant.TracedRNumber{Float64}})
    @ Reactant ~/.julia/packages/Reactant/hB4Fs/src/Overlay.jl:162
  [8] #reduce#930
    @ ./reducedim.jl:378 [inlined]
  [9] (::Nothing)(none::Base.var"##reduce#930", none::@Kwargs{dims::Colon, init::Reactant.TracedRNumber{Float64}}, none::typeof(reduce), none::Function, none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [10] merge
    @ namedtuple.jl:349 [inlined]
 [11] call_with_reactant(::Reactant.MustThrowError, ::Base.var"##reduce#930", ::@Kwargs{dims::Colon, init::Reactant.TracedRNumber{Float64}}, ::typeof(reduce), ::Function, ::Reactant.TracedRArray{Float64, 1})
    @ Reactant reducedim.jl:378
 [12] reduce
    @ ./reducedim.jl:378 [inlined]
 [13] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{dims::Colon, init::Reactant.TracedRNumber{Float64}}, none::typeof(reduce), none::typeof(Base.FastMath.max_fast), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [14] Pairs
    @ ./essentials.jl:483 [inlined]
 [15] pairs
    @ ./iterators.jl:279 [inlined]
 [16] reduce
    @ ./reducedim.jl:378 [inlined]
 [17] call_with_reactant(::Reactant.MustThrowError, ::typeof(Core.kwcall), ::@NamedTuple{dims::Colon, init::Reactant.TracedRNumber{Float64}}, ::typeof(reduce), ::typeof(Base.FastMath.max_fast), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/hB4Fs/src/utils.jl:0
 [18] #fast_maximum#199
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:92 [inlined]
 [19] (::Nothing)(none::NNlib.var"##fast_maximum#199", none::Function, none::typeof(NNlib.fast_maximum), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [20] TracedRNumber
    @ ~/.julia/packages/Reactant/hB4Fs/src/TracedRNumber.jl:100 [inlined]
 [21] convert
    @ number.jl:7 [inlined]
 [22] zero
    @ number.jl:309 [inlined]
 [23] float
    @ float.jl:391 [inlined]
 [24] call_with_reactant(::Reactant.MustThrowError, ::NNlib.var"##fast_maximum#199", ::Function, ::typeof(NNlib.fast_maximum), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:92
 [25] fast_maximum
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:92 [inlined]
 [26] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{dims::Colon}, none::typeof(NNlib.fast_maximum), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [27] call_with_reactant(::Reactant.MustThrowError, ::typeof(Core.kwcall), ::@NamedTuple{dims::Colon}, ::typeof(NNlib.fast_maximum), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/hB4Fs/src/utils.jl:501
 [28] #logsumexp#206
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:143 [inlined]
 [29] (::Nothing)(none::NNlib.var"##logsumexp#206", none::Function, none::typeof(logsumexp), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [30] call_with_reactant(::Reactant.MustThrowError, ::NNlib.var"##logsumexp#206", ::Function, ::typeof(logsumexp), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:143
 [31] logsumexp
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:142 [inlined]
 [32] (::Nothing)(none::typeof(logsumexp), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [33] logsumexp
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:142 [inlined]
 [34] call_with_reactant(::typeof(logsumexp), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/hB4Fs/src/utils.jl:0
 [35] make_mlir_fn(f::typeof(logsumexp), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/hB4Fs/src/TracedUtils.jl:330
 [36] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:1544
 [37] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:1511 [inlined]
 [38] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{…}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:3420
 [39] compile_xla
    @ ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:3393 [inlined]
 [40] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:3492
 [41] top-level scope
    @ ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:2573
Some type information was truncated. Use `show(err)` to see complete types.

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions