-
Notifications
You must be signed in to change notification settings - Fork 38
Closed
Description
In a temporary environment with the following packages:
(jl_oIMZ8q) pkg> st
Status `/tmp/jl_oIMZ8q/Project.toml`
[872c559c] NNlib v0.9.31
⌃ [3c362404] Reactant v0.2.152
With Reactant v0.2.152, this works:
julia> using Reactant, NNlib
julia> x = Reactant.to_rarray(rand(10))
10-element ConcretePJRTArray{Float64,1}:
0.08903425793914821
0.7983627095290846
0.6869913407798728
0.21537352286537292
0.2539470134164943
0.07547134727519211
0.8528935829249477
0.9225987687818796
0.8782849123401902
0.20126944260945456
julia> @jit logsumexp(x)
ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.8566158486867264)
However, with Reactant v0.2.153 (and later versions), compiling `NNlib.logsumexp` fails:
(jl_khbri3) pkg> st
Status `/tmp/jl_khbri3/Project.toml`
[872c559c] NNlib v0.9.31
⌃ [3c362404] Reactant v0.2.153
julia> using Reactant, NNlib
julia> x = Reactant.to_rarray(rand(10))
10-element ConcretePJRTArray{Float64,1}:
0.3737505265211689
0.8670935749296754
0.7975370385656305
0.8880806088880492
0.5321583149971186
0.5164878524049845
0.9024030516357548
0.5942285215483274
0.32937350425382883
0.45795255020414805
julia> @jit logsumexp(x)
ERROR: ArgumentError: reducing over an empty collection is not allowed; consider supplying `init` to the reducer
Stacktrace:
[1] _empty_reduce_error()
@ Base ./reduce.jl:319
[2] reduce_empty(f::Function, T::Type)
@ Base ./reduce.jl:320
[3] reduce_empty(op::Base.BottomRF{typeof(Base.FastMath.max_fast)}, ::Type{Float64})
@ Base ./reduce.jl:357
[4] __default_init(::Type{Float64}, op::typeof(Base.FastMath.max_fast))
@ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/hB4Fs/src/TracedRArray.jl:540
[5] overloaded_mapreduce(f::Any, op::Any, A::Reactant.TracedRArray{Float64, 1}; dims::Function, init::Reactant.TracedRNumber{Float64})
@ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/hB4Fs/src/TracedRArray.jl:574
[6] overloaded_mapreduce
@ ~/.julia/packages/Reactant/hB4Fs/src/TracedRArray.jl:559 [inlined]
[7] mapreduce(f::Function, op::Function, A::Reactant.TracedRArray{Float64, 1}; kwargs::@Kwargs{dims::Colon, init::Reactant.TracedRNumber{Float64}})
@ Reactant ~/.julia/packages/Reactant/hB4Fs/src/Overlay.jl:162
[8] #reduce#930
@ ./reducedim.jl:378 [inlined]
[9] (::Nothing)(none::Base.var"##reduce#930", none::@Kwargs{dims::Colon, init::Reactant.TracedRNumber{Float64}}, none::typeof(reduce), none::Function, none::Reactant.TracedRArray{Float64, 1})
@ Reactant ./<missing>:0
[10] merge
@ namedtuple.jl:349 [inlined]
[11] call_with_reactant(::Reactant.MustThrowError, ::Base.var"##reduce#930", ::@Kwargs{dims::Colon, init::Reactant.TracedRNumber{Float64}}, ::typeof(reduce), ::Function, ::Reactant.TracedRArray{Float64, 1})
@ Reactant reducedim.jl:378
[12] reduce
@ ./reducedim.jl:378 [inlined]
[13] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{dims::Colon, init::Reactant.TracedRNumber{Float64}}, none::typeof(reduce), none::typeof(Base.FastMath.max_fast), none::Reactant.TracedRArray{Float64, 1})
@ Reactant ./<missing>:0
[14] Pairs
@ ./essentials.jl:483 [inlined]
[15] pairs
@ ./iterators.jl:279 [inlined]
[16] reduce
@ ./reducedim.jl:378 [inlined]
[17] call_with_reactant(::Reactant.MustThrowError, ::typeof(Core.kwcall), ::@NamedTuple{dims::Colon, init::Reactant.TracedRNumber{Float64}}, ::typeof(reduce), ::typeof(Base.FastMath.max_fast), ::Reactant.TracedRArray{Float64, 1})
@ Reactant ~/.julia/packages/Reactant/hB4Fs/src/utils.jl:0
[18] #fast_maximum#199
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:92 [inlined]
[19] (::Nothing)(none::NNlib.var"##fast_maximum#199", none::Function, none::typeof(NNlib.fast_maximum), none::Reactant.TracedRArray{Float64, 1})
@ Reactant ./<missing>:0
[20] TracedRNumber
@ ~/.julia/packages/Reactant/hB4Fs/src/TracedRNumber.jl:100 [inlined]
[21] convert
@ number.jl:7 [inlined]
[22] zero
@ number.jl:309 [inlined]
[23] float
@ float.jl:391 [inlined]
[24] call_with_reactant(::Reactant.MustThrowError, ::NNlib.var"##fast_maximum#199", ::Function, ::typeof(NNlib.fast_maximum), ::Reactant.TracedRArray{Float64, 1})
@ Reactant ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:92
[25] fast_maximum
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:92 [inlined]
[26] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{dims::Colon}, none::typeof(NNlib.fast_maximum), none::Reactant.TracedRArray{Float64, 1})
@ Reactant ./<missing>:0
[27] call_with_reactant(::Reactant.MustThrowError, ::typeof(Core.kwcall), ::@NamedTuple{dims::Colon}, ::typeof(NNlib.fast_maximum), ::Reactant.TracedRArray{Float64, 1})
@ Reactant ~/.julia/packages/Reactant/hB4Fs/src/utils.jl:501
[28] #logsumexp#206
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:143 [inlined]
[29] (::Nothing)(none::NNlib.var"##logsumexp#206", none::Function, none::typeof(logsumexp), none::Reactant.TracedRArray{Float64, 1})
@ Reactant ./<missing>:0
[30] call_with_reactant(::Reactant.MustThrowError, ::NNlib.var"##logsumexp#206", ::Function, ::typeof(logsumexp), ::Reactant.TracedRArray{Float64, 1})
@ Reactant ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:143
[31] logsumexp
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:142 [inlined]
[32] (::Nothing)(none::typeof(logsumexp), none::Reactant.TracedRArray{Float64, 1})
@ Reactant ./<missing>:0
[33] logsumexp
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:142 [inlined]
[34] call_with_reactant(::typeof(logsumexp), ::Reactant.TracedRArray{Float64, 1})
@ Reactant ~/.julia/packages/Reactant/hB4Fs/src/utils.jl:0
[35] make_mlir_fn(f::typeof(logsumexp), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/hB4Fs/src/TracedUtils.jl:330
[36] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:1544
[37] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:1511 [inlined]
[38] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{…}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:3420
[39] compile_xla
@ ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:3393 [inlined]
[40] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:3492
[41] top-level scope
@ ~/.julia/packages/Reactant/hB4Fs/src/Compiler.jl:2573
Some type information was truncated. Use `show(err)` to see complete types.
Metadata
Metadata
Assignees
Labels
No labels