-
Notifications
You must be signed in to change notification settings - Fork 33
Description
I'm having some issues whenever I attempt to use Optimisers.jl's ClipNorm() in conjunction with training using Lux + Reactant + Enzyme (e.g. as in most tutorials in Lux's current documentation), such as the Conv-Mixer, Simple CNN, or ResNet20 examples (specifically, the same _norm error occurs at the Training Loop section of all three).
I believe this is due to Optimisers.jl's ClipNorm() calling _norm, which expects an AbstractArray and a Real, but which will instead be passed a TracedRArray{Float32, 2} and a TracedRNumber{Float64}.
I wasn't completely sure whether this should instead be an issue submitted to Optimisers.jl, where a more general _norm could properly accept Reactant.jl types, but I wanted to get the issue started.
Here is a MWE below:
using Lux, Enzyme, Optimisers, Reactant, Random
xdev = reactant_device(; force=true)
m = Dense(10 => 10)
ps, st = Lux.setup(Random.default_rng(), m) |> xdev
data = (randn(10, 32) |> xdev, randn(10, 32) |> xdev)
clip = Optimisers.ClipNorm(0.5)
opt = Optimisers.OptimiserChain(clip, Optimisers.AdamW(1e-3))
ts = Lux.Training.TrainState(m, ps, st, opt)

This is fine and will produce the TrainState correctly, but the error starts when beginning to train, e.g. by calling Lux's single_train_step!:
_, loss, stats, ts = Training.single_train_step!(
AutoEnzyme(),
MSELoss(),
data,
ts;
return_gradients=Val(false),
)

which produces the following error:
MethodError: no method matching _norm(::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRNumber{Float64})
The function `_norm` exists, but no method is defined for this combination of argument types.
Closest candidates are:
_norm(::AbstractArray, !Matched::Real)
@ Optimisers ~/.julia/packages/Optimisers/W5seC/src/rules.jl:729
_norm(!Matched::Base.Broadcast.Broadcasted, !Matched::Real)
@ Optimisers ~/.julia/packages/Optimisers/W5seC/src/rules.jl:730
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0 [inlined]
[2] call_with_reactant(::Reactant.MustThrowError, ::typeof(Optimisers._norm), ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRNumber{Float64})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:875
[3] apply!
@ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:720 [inlined]
[4] (::Nothing)(none::typeof(Optimisers.apply!), none::ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, none::Nothing, none::Reactant.TracedRArray{Float32, 2}, none::Reactant.TracedRArray{Float32, 2})
@ Reactant ./<missing>:0
[5] getproperty
@ ./Base.jl:49 [inlined]
[6] apply!
@ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:720 [inlined]
[7] call_with_reactant(::Reactant.MustThrowError, ::typeof(Optimisers.apply!), ::ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, ::Nothing, ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRArray{Float32, 2})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
[8] #168
@ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:790 [inlined]
[9] (::Nothing)(none::Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing})
@ Reactant ./<missing>:0
[10] indexed_iterate
@ ./tuple.jl:159 [inlined]
[11] #168
@ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:787 [inlined]
[12] call_with_reactant(::Reactant.MustThrowError, ::Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
[13] BottomRF
@ ./reduce.jl:86 [inlined]
[14] afoldl
@ ./operators.jl:553 [inlined]
[15] (::Nothing)(none::typeof(Base.afoldl), none::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ./<missing>:0
[16] call_with_reactant(::Reactant.MustThrowError, ::typeof(Base.afoldl), ::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, ::Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:501
[17] _foldl_impl
@ ./reduce.jl:68 [inlined]
[18] (::Nothing)(none::typeof(Base._foldl_impl), none::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ./<missing>:0
[19] call_with_reactant(::Reactant.MustThrowError, ::typeof(Base._foldl_impl), ::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:501
[20] foldl_impl
@ ./reduce.jl:48 [inlined]
[21] (::Nothing)(none::typeof(Base.foldl_impl), none::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ./<missing>:0
[22] foldl_impl
@ ./reduce.jl:48 [inlined]
[23] call_with_reactant(::Reactant.MustThrowError, ::typeof(Base.foldl_impl), ::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
[24] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[25] (::Nothing)(none::typeof(Base.mapfoldl_impl), none::typeof(identity), none::Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ./<missing>:0
[26] BottomRF
@ ./reduce.jl:82 [inlined]
[27] mapfoldl_impl
@ ./reduce.jl:43 [inlined]
[28] call_with_reactant(::typeof(Base.mapfoldl_impl), ::typeof(identity), ::Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
[29] #mapfoldl#337
@ ./reduce.jl:175 [inlined]
[30] (::Nothing)(none::Base.var"##mapfoldl#337", none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::typeof(mapfoldl), none::Function, none::Function, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ./<missing>:0
[31] call_with_reactant(::Base.var"##mapfoldl#337", ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::typeof(mapfoldl), ::Function, ::Function, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant reduce.jl:175
[32] mapfoldl
@ ./reduce.jl:175 [inlined]
[33] foldl
@ ./reduce.jl:198 [inlined]
[34] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{init::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}}, none::typeof(foldl), none::Function, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant ./<missing>:0
[35] call_with_reactant(::Reactant.MustThrowError, ::typeof(Core.kwcall), ::@NamedTuple{init::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}}, ::typeof(foldl), ::Function, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
@ Reactant reduce.jl:198
[36] apply!
@ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:786 [inlined]
[37] (::Nothing)(none::typeof(Optimisers.apply!), none::OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}, none::Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, none::Reactant.TracedRArray{Float32, 2}, none::Reactant.TracedRArray{Float32, 2}, none::Tuple{})
@ Reactant ./<missing>:0
[38] apply!
@ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:786 [inlined]
[39] call_with_reactant(::Reactant.MustThrowError, ::typeof(Optimisers.apply!), ::OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}, ::Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRArray{Float32, 2})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
[40] apply!
@ ~/.julia/packages/Lux/9B1iJ/src/helpers/optimizers.jl:26 [inlined]
[41] (::Nothing)(none::typeof(Optimisers.apply!), none::Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, none::Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, none::Reactant.TracedRArray{Float32, 2}, none::Reactant.TracedRArray{Float32, 2})
@ Reactant ./<missing>:0
[42] getproperty
@ ./Base.jl:49 [inlined]
[43] apply!
@ ~/.julia/packages/Lux/9B1iJ/src/helpers/optimizers.jl:26 [inlined]
[44] call_with_reactant(::typeof(Optimisers.apply!), ::Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, ::Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRArray{Float32, 2})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
[45] macro expansion
@ ~/.julia/packages/Reactant/EsmaI/src/utils.jl:293 [inlined]
[46] applyiterate_with_reactant(::typeof(iterate), ::typeof(Optimisers.apply!), ::Tuple{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{Reactant.TracedRArray{Float32, 2}})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:279
[47] #_update!#10
@ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:96 [inlined]
[48] (::Nothing)(none::Optimisers.var"##_update!#10", none::IdDict{Optimisers.Leaf, Any}, none::IdDict{Any, Any}, none::typeof(Optimisers._update!), none::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}, none::Reactant.TracedRArray{Float32, 2})
@ Reactant ./<missing>:0
[49] #_update!#10
@ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:93 [inlined]
[50] call_with_reactant(::Optimisers.var"##_update!#10", ::IdDict{Optimisers.Leaf, Any}, ::IdDict{Any, Any}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}, ::Reactant.TracedRArray{Float32, 2})
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
[51] _update!
@ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:92 [inlined]
[52] #8
@ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:85 [inlined]
[53] map
@ ./tuple.jl:383 [inlined]
[54] map
@ ./namedtuple.jl:266 [inlined]
[55] mapvalue
@ ~/.julia/packages/Optimisers/W5seC/src/utils.jl:2 [inlined]
[56] #_update!#7
@ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:85 [inlined]
[57] _update!
@ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:81 [inlined]
[58] update!
@ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:77 [inlined]
[59] compute_gradients_internal_and_step!
@ ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:151 [inlined]
[60] (::Nothing)(none::typeof(LuxReactantExt.compute_gradients_internal_and_step!), none::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, none::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, none::Tuple{Reactant.TracedRArray{Float64, 2}, Reactant.TracedRArray{Float64, 2}}, none::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, none::@NamedTuple{}, none::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}}, none::Static.False)
@ Reactant ./<missing>:0
[61] Const
@ ~/.julia/packages/EnzymeCore/lmG5F/src/EnzymeCore.jl:30 [inlined]
[62] compute_gradients_internal
@ ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:7 [inlined]
[63] compute_gradients_internal_and_step!
@ ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:148 [inlined]
[64] call_with_reactant(::typeof(LuxReactantExt.compute_gradients_internal_and_step!), ::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, ::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, ::Tuple{Reactant.TracedRArray{Float64, 2}, Reactant.TracedRArray{Float64, 2}}, ::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, ::@NamedTuple{}, ::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}}, ::Static.False)
@ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
[65] make_mlir_fn(f::typeof(LuxReactantExt.compute_gradients_internal_and_step!), args::Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{}, @NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, 
bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}}, Static.False}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/EsmaI/src/TracedUtils.jl:332
[66] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{}, @NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, 
bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}}, Static.False}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, NTuple{var"#s1732", Symbol} where var"#s1732", NTuple{N, Int64} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, 
legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:1555
[67] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:1522 [inlined]
[68] compile_xla(f::Function, args::Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{}, @NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, 
bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}}, Static.False}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3433
[69] compile_xla
@ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3406 [inlined]
[70] compile(f::Function, args::Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{}, @NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, 
bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}}, Static.False}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
@ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3505
[71] macro expansion
@ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:2586 [inlined]
[72] (::LuxReactantExt.var"#15#17"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}})()
@ LuxReactantExt ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:86
[73] with(f::LuxReactantExt.var"#15#17"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, pair::Pair{Base.ScopedValues.ScopedValue{Union{Nothing, Tuple{Reactant.PrecisionConfig.T, Reactant.PrecisionConfig.T}, Reactant.PrecisionConfig.T}}, Reactant.PrecisionConfig.T}, rest::Pair{Base.ScopedValues.ScopedValue{Union{Nothing, Tuple{Reactant.PrecisionConfig.T, Reactant.PrecisionConfig.T}, Reactant.PrecisionConfig.T}}, Reactant.PrecisionConfig.T})
@ Base.ScopedValues ./scopedvalues.jl:269
[74] #with_config#14
@ ~/.julia/packages/Reactant/EsmaI/src/Configuration.jl:62 [inlined]
[75] with_config
@ ~/.julia/packages/Reactant/EsmaI/src/Configuration.jl:34 [inlined]
[76] with_default_precision_config(f::LuxReactantExt.var"#15#17"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, ps::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}})
@ LuxReactantExt ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/LuxReactantExt.jl:67
[77] macro expansion
@ ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:85 [inlined]
[78] (::LuxReactantExt.var"#14#16"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}})()
@ LuxReactantExt ~/.julia/packages/Reactant/EsmaI/src/Profiler.jl:105
[79] annotate(f::LuxReactantExt.var"#14#16"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, name::String, level::Int32)
@ Reactant.Profiler ~/.julia/packages/Reactant/EsmaI/src/Profiler.jl:80
[80] annotate
@ ~/.julia/packages/Reactant/EsmaI/src/Profiler.jl:76 [inlined]
[81] single_train_step_impl!
@ ~/.julia/packages/Reactant/EsmaI/src/Profiler.jl:105 [inlined]
[82] #single_train_step!#6
@ ~/.julia/packages/Lux/9B1iJ/src/helpers/training.jl:297 [inlined]There was also a separate trace that the Jupyter notebook produced for me below when running this same line:
"Inconsistent guaranteed error IR 198 1 ─ %1 = Base.mapfoldl::typeof(mapfoldl)\n │ %2 = Base.identity::typeof(identity)\n │ %3 = Core.getfield(_2, :init)::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}\n │ %4 = invoke Base.:(var\"#mapfoldl#337\")(%3::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, %1::typeof(mapfoldl), %2::Function, _4::Function, _5::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})::Any\n └── return %4\n "