Reactant's traced scalar types causing issues with Optimisers.jl's ClipNorm() #1645

@vulcan-spacemage

Description

I'm having issues whenever I attempt to use Optimisers.jl's ClipNorm() in conjunction with training via Lux + Reactant + Enzyme (e.g. as in most tutorials in Lux's current documentation, such as the Conv-Mixer, Simple CNN, or ResNet20 examples); the same _norm error occurs at the Training Loop section of all three.

I believe this is due to Optimisers.jl's ClipNorm() calling _norm, which expects an AbstractArray and a Real, but is instead passed a TracedRArray{Float32, 2} and a TracedRNumber{Float64}.

I wasn't completely sure whether this should instead be filed against Optimisers.jl, where a more general _norm could accept Reactant.jl's traced types, but I wanted to get the issue started.
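For reference, one possible shape of such a fix would be to widen the second argument of _norm from Real to Number, so that traced scalars like TracedRNumber dispatch to the same generic norm. This is only a hedged sketch of a workaround (type-piracy style, defined in user code for illustration), not a proposed upstream implementation:

```julia
using LinearAlgebra
import Optimisers

# Hypothetical workaround sketch: accept any Number (including traced
# scalar types such as Reactant.TracedRNumber) as the norm order `p`,
# delegating to LinearAlgebra.norm as the AbstractArray method does.
# Whether this traces cleanly through Reactant is an open question.
Optimisers._norm(dx::AbstractArray, p::Number) = LinearAlgebra.norm(dx, p)
```

A proper fix would presumably live in Optimisers.jl itself (or a Reactant extension), since adding methods to another package's internal function from user code is fragile.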

Here is a MWE:

using Lux, Enzyme, Optimisers, Reactant, Random

xdev = reactant_device(; force=true)

m = Dense(10 => 10)
ps, st = Lux.setup(Random.default_rng(), m) |> xdev
data = (randn(10, 32) |> xdev, randn(10, 32) |> xdev)

clip = Optimisers.ClipNorm(0.5)
opt = Optimisers.OptimiserChain(clip, Optimisers.AdamW(1e-3))
ts = Lux.Training.TrainState(m, ps, st, opt)

This runs fine and produces the TrainState correctly, but the error appears once training starts, e.g. when calling Lux's single_train_step!:

_, loss, stats, ts = Training.single_train_step!(
  AutoEnzyme(),
  MSELoss(),
  data,
  ts;
  return_gradients=Val(false),
)

which produces the following error:

MethodError: no method matching _norm(::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRNumber{Float64})
The function `_norm` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  _norm(::AbstractArray, !Matched::Real)
   @ Optimisers ~/.julia/packages/Optimisers/W5seC/src/rules.jl:729
  _norm(!Matched::Base.Broadcast.Broadcasted, !Matched::Real)
   @ Optimisers ~/.julia/packages/Optimisers/W5seC/src/rules.jl:730


Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::typeof(Optimisers._norm), ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRNumber{Float64})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:875
  [3] apply!
    @ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:720 [inlined]
  [4] (::Nothing)(none::typeof(Optimisers.apply!), none::ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, none::Nothing, none::Reactant.TracedRArray{Float32, 2}, none::Reactant.TracedRArray{Float32, 2})
    @ Reactant ./<missing>:0
  [5] getproperty
    @ ./Base.jl:49 [inlined]
  [6] apply!
    @ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:720 [inlined]
  [7] call_with_reactant(::Reactant.MustThrowError, ::typeof(Optimisers.apply!), ::ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, ::Nothing, ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRArray{Float32, 2})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
  [8] #168
    @ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:790 [inlined]
  [9] (::Nothing)(none::Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing})
    @ Reactant ./<missing>:0
 [10] indexed_iterate
    @ ./tuple.jl:159 [inlined]
 [11] #168
    @ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:787 [inlined]
 [12] call_with_reactant(::Reactant.MustThrowError, ::Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [13] BottomRF
    @ ./reduce.jl:86 [inlined]
 [14] afoldl
    @ ./operators.jl:553 [inlined]
 [15] (::Nothing)(none::typeof(Base.afoldl), none::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ./<missing>:0
 [16] call_with_reactant(::Reactant.MustThrowError, ::typeof(Base.afoldl), ::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, ::Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:501
 [17] _foldl_impl
    @ ./reduce.jl:68 [inlined]
 [18] (::Nothing)(none::typeof(Base._foldl_impl), none::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ./<missing>:0
 [19] call_with_reactant(::Reactant.MustThrowError, ::typeof(Base._foldl_impl), ::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:501
 [20] foldl_impl
    @ ./reduce.jl:48 [inlined]
 [21] (::Nothing)(none::typeof(Base.foldl_impl), none::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ./<missing>:0
 [22] foldl_impl
    @ ./reduce.jl:48 [inlined]
 [23] call_with_reactant(::Reactant.MustThrowError, ::typeof(Base.foldl_impl), ::Base.BottomRF{Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [24] mapfoldl_impl
    @ ./reduce.jl:44 [inlined]
 [25] (::Nothing)(none::typeof(Base.mapfoldl_impl), none::typeof(identity), none::Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}, none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ./<missing>:0
 [26] BottomRF
    @ ./reduce.jl:82 [inlined]
 [27] mapfoldl_impl
    @ ./reduce.jl:43 [inlined]
 [28] call_with_reactant(::typeof(Base.mapfoldl_impl), ::typeof(identity), ::Optimisers.var"#168#169"{Reactant.TracedRArray{Float32, 2}, Tuple{}}, ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [29] #mapfoldl#337
    @ ./reduce.jl:175 [inlined]
 [30] (::Nothing)(none::Base.var"##mapfoldl#337", none::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, none::typeof(mapfoldl), none::Function, none::Function, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ./<missing>:0
 [31] call_with_reactant(::Base.var"##mapfoldl#337", ::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, ::typeof(mapfoldl), ::Function, ::Function, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant reduce.jl:175
 [32] mapfoldl
    @ ./reduce.jl:175 [inlined]
 [33] foldl
    @ ./reduce.jl:198 [inlined]
 [34] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{init::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}}, none::typeof(foldl), none::Function, none::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant ./<missing>:0
 [35] call_with_reactant(::Reactant.MustThrowError, ::typeof(Core.kwcall), ::@NamedTuple{init::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}}, ::typeof(foldl), ::Function, ::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})
    @ Reactant reduce.jl:198
 [36] apply!
    @ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:786 [inlined]
 [37] (::Nothing)(none::typeof(Optimisers.apply!), none::OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}, none::Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, none::Reactant.TracedRArray{Float32, 2}, none::Reactant.TracedRArray{Float32, 2}, none::Tuple{})
    @ Reactant ./<missing>:0
 [38] apply!
    @ ~/.julia/packages/Optimisers/W5seC/src/rules.jl:786 [inlined]
 [39] call_with_reactant(::Reactant.MustThrowError, ::typeof(Optimisers.apply!), ::OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}, ::Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRArray{Float32, 2})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [40] apply!
    @ ~/.julia/packages/Lux/9B1iJ/src/helpers/optimizers.jl:26 [inlined]
 [41] (::Nothing)(none::typeof(Optimisers.apply!), none::Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, none::Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, none::Reactant.TracedRArray{Float32, 2}, none::Reactant.TracedRArray{Float32, 2})
    @ Reactant ./<missing>:0
 [42] getproperty
    @ ./Base.jl:49 [inlined]
 [43] apply!
    @ ~/.julia/packages/Lux/9B1iJ/src/helpers/optimizers.jl:26 [inlined]
 [44] call_with_reactant(::typeof(Optimisers.apply!), ::Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, ::Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRArray{Float32, 2})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [45] macro expansion
    @ ~/.julia/packages/Reactant/EsmaI/src/utils.jl:293 [inlined]
 [46] applyiterate_with_reactant(::typeof(iterate), ::typeof(Optimisers.apply!), ::Tuple{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, Reactant.TracedRArray{Float32, 2}}, ::Tuple{Reactant.TracedRArray{Float32, 2}})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:279
 [47] #_update!#10
    @ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:96 [inlined]
 [48] (::Nothing)(none::Optimisers.var"##_update!#10", none::IdDict{Optimisers.Leaf, Any}, none::IdDict{Any, Any}, none::typeof(Optimisers._update!), none::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}, none::Reactant.TracedRArray{Float32, 2})
    @ Reactant ./<missing>:0
 [49] #_update!#10
    @ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:93 [inlined]
 [50] call_with_reactant(::Optimisers.var"##_update!#10", ::IdDict{Optimisers.Leaf, Any}, ::IdDict{Any, Any}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}, ::Reactant.TracedRArray{Float32, 2})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [51] _update!
    @ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:92 [inlined]
 [52] #8
    @ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:85 [inlined]
 [53] map
    @ ./tuple.jl:383 [inlined]
 [54] map
    @ ./namedtuple.jl:266 [inlined]
 [55] mapvalue
    @ ~/.julia/packages/Optimisers/W5seC/src/utils.jl:2 [inlined]
 [56] #_update!#7
    @ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:85 [inlined]
 [57] _update!
    @ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:81 [inlined]
 [58] update!
    @ ~/.julia/packages/Optimisers/W5seC/src/interface.jl:77 [inlined]
 [59] compute_gradients_internal_and_step!
    @ ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:151 [inlined]
 [60] (::Nothing)(none::typeof(LuxReactantExt.compute_gradients_internal_and_step!), none::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, none::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, none::Tuple{Reactant.TracedRArray{Float64, 2}, Reactant.TracedRArray{Float64, 2}}, none::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, none::@NamedTuple{}, none::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}}, none::Static.False)
    @ Reactant ./<missing>:0
 [61] Const
    @ ~/.julia/packages/EnzymeCore/lmG5F/src/EnzymeCore.jl:30 [inlined]
 [62] compute_gradients_internal
    @ ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:7 [inlined]
 [63] compute_gradients_internal_and_step!
    @ ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:148 [inlined]
 [64] call_with_reactant(::typeof(LuxReactantExt.compute_gradients_internal_and_step!), ::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, ::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, ::Tuple{Reactant.TracedRArray{Float64, 2}, Reactant.TracedRArray{Float64, 2}}, ::@NamedTuple{weight::Reactant.TracedRArray{Float32, 2}, bias::Reactant.TracedRArray{Float32, 1}}, ::@NamedTuple{}, ::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}}}}, Tuple{Nothing, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}}}, ::Static.False)
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [65] make_mlir_fn(f::typeof(LuxReactantExt.compute_gradients_internal_and_step!), args::Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{}, @NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, 
bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}}, Static.False}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/EsmaI/src/TracedUtils.jl:332
 [66] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{}, @NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, 
bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}}, Static.False}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, NTuple{var"#s1732", Symbol} where var"#s1732", NTuple{N, Int64} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, 
legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:1555
 [67] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:1522 [inlined]
 [68] compile_xla(f::Function, args::Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{}, @NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, 
bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}}, Static.False}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3433
 [69] compile_xla
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3406 [inlined]
 [70] compile(f::Function, args::Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, @NamedTuple{}, @NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, 
bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{OptimiserChain{Tuple{ClipNorm{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, AdamW{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}, Tuple{Nothing, Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}}}}, Static.False}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3505
 [71] macro expansion
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:2586 [inlined]
 [72] (::LuxReactantExt.var"#15#17"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}})()
    @ LuxReactantExt ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:86
 [73] with(f::LuxReactantExt.var"#15#17"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, pair::Pair{Base.ScopedValues.ScopedValue{Union{Nothing, Tuple{Reactant.PrecisionConfig.T, Reactant.PrecisionConfig.T}, Reactant.PrecisionConfig.T}}, Reactant.PrecisionConfig.T}, rest::Pair{Base.ScopedValues.ScopedValue{Union{Nothing, Tuple{Reactant.PrecisionConfig.T, Reactant.PrecisionConfig.T}, Reactant.PrecisionConfig.T}}, Reactant.PrecisionConfig.T})
    @ Base.ScopedValues ./scopedvalues.jl:269
 [74] #with_config#14
    @ ~/.julia/packages/Reactant/EsmaI/src/Configuration.jl:62 [inlined]
 [75] with_config
    @ ~/.julia/packages/Reactant/EsmaI/src/Configuration.jl:34 [inlined]
 [76] with_default_precision_config(f::LuxReactantExt.var"#15#17"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, ps::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}})
    @ LuxReactantExt ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/LuxReactantExt.jl:67
 [77] macro expansion
    @ ~/.julia/packages/Lux/9B1iJ/ext/LuxReactantExt/training.jl:85 [inlined]
 [78] (::LuxReactantExt.var"#14#16"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}})()
    @ LuxReactantExt ~/.julia/packages/Reactant/EsmaI/src/Profiler.jl:105
 [79] annotate(f::LuxReactantExt.var"#14#16"{Lux.Training.ReactantBackend{Static.False}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, name::String, level::Int32)
    @ Reactant.Profiler ~/.julia/packages/Reactant/EsmaI/src/Profiler.jl:80
 [80] annotate
    @ ~/.julia/packages/Reactant/EsmaI/src/Profiler.jl:76 [inlined]
 [81] single_train_step_impl!
    @ ~/.julia/packages/Reactant/EsmaI/src/Profiler.jl:105 [inlined]
 [82] #single_train_step!#6
    @ ~/.julia/packages/Lux/9B1iJ/src/helpers/training.jl:297 [inlined]

Running the same line in a Jupyter notebook also produced a separate trace:

Inconsistent guaranteed error IR 198
1 ─ %1 = Base.mapfoldl::typeof(mapfoldl)
│   %2 = Base.identity::typeof(identity)
│   %3 = Core.getfield(_2, :init)::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}
│   %4 = invoke Base.:(var"#mapfoldl#337")(%3::Tuple{Tuple{}, Reactant.TracedRArray{Float32, 2}}, %1::typeof(mapfoldl), %2::Function, _4::Function, _5::Tuple{Tuple{ClipNorm{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Nothing}, Tuple{AdamW{Reactant.TracedRNumber{Float64}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}})::Any
└── return %4
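As a possible user-side stopgap (not part of the original report, and untested under Reactant's tracing), the dispatch of Optimisers.jl's internal `_norm` could be widened so that a `TracedRNumber` (which is a `Number` but not a `Real`) is accepted. This is type piracy on an internal function, shown only as a sketch of where a more general method inside Optimisers.jl itself might live:

```julia
using Optimisers, LinearAlgebra

# Hypothetical sketch: the MethodError shows `_norm(::AbstractArray, ::Real)`
# exists, but Reactant passes a TracedRNumber{Float64}, which is not a Real.
# Widening the second argument to `Number` lets the traced scalar dispatch.
# NOTE: piracy on an internal Optimisers function; illustration only.
Optimisers._norm(dx::AbstractArray, p::Number) = LinearAlgebra.norm(dx, p)
```

Whether `norm(dx, p)` itself traces correctly for a `TracedRNumber` exponent is a separate question; the robust fix is presumably a proper method in Optimisers.jl (or a package extension) rather than a user-side overload.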
