This is the most confusing part of using Lux for me. There are multiple ways to handle the state:
in a function local variable
in a let block
with a global variable
ADDED: With a Ref
not at all, using `first`, `[1]`, etc.
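For concreteness, two of the patterns above can be sketched with a stand-in stateful "model". The `counting_model` name and its behavior are made up for illustration; it is not a Lux API, but like a Lux layer it returns `(output, new_state)`:

```julia
# Stand-in for a Lux layer call: returns (output, new_state).
counting_model(x, ps, st) = (ps .* x, (; calls = st.calls + 1))

# Pattern 1: state in a function-local variable, threaded through a loop.
function run_local(x, ps)
    st = (; calls = 0)
    for _ in 1:3
        _, st = counting_model(x, ps, st)
    end
    return st
end

# Pattern 2: state captured in a Ref, so callers can observe the updates.
st_ref = Ref{Any}((; calls = 0))
function run_ref!(x, ps)
    y, st_ref[] = counting_model(x, ps, st_ref[])
    return y
end
```

After `run_local([1.0], 2.0)` the local state has seen three calls, while each call to `run_ref!` mutates the shared `st_ref` in place.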
I am aware that in certain scenarios like NeuralODEs, it is super important that the state does not change from iteration to iteration, as this would change the analytical solution of the NeuralODE and hence invalidate the algorithms.
Still, I want to suggest unifying the passing of the state everywhere into a common pattern, and adding a warning note wherever it is ignored, explaining why the state must be ignored in those cases.
I also read that the state can change types between runs, which probably makes it hard to handle the state generically in a compiler-friendly way, especially when interacting with Optimization.jl or Turing.jl. Maybe Lux.jl needs a dedicated `Ref` equivalent that is better than a `Ref{Any}` but still allows the state to change type.
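One hypothetical sketch of such a wrapper: a `Ref{Any}` combined with a function barrier, so the stored state may change type between runs, yet the inner loss computation is still compiled for the concrete state type. `StateBox`, `loss!`, and `_loss` are illustrative names, not part of Lux:

```julia
# Illustrative wrapper: holds state of any type.
struct StateBox
    st::Base.RefValue{Any}
end
StateBox(st) = StateBox(Ref{Any}(st))

# Outer call is type-unstable (st::Any), but `_loss` acts as a function
# barrier: it is compiled for the concrete type of `st`, so the actual
# work runs on concretely typed values.
function loss!(box::StateBox, X, y, ps, model)
    st = box.st[]                       # ::Any — dynamic dispatch here only
    l, box.st[] = _loss(model, X, y, ps, st)
    return l
end

function _loss(model, X, y, ps, st)     # st is concrete inside
    ypred, stnew = model(X, ps, st)
    return sum(abs2, y .- ypred), stnew
end
```

The design idea is that the single dynamic dispatch per call happens at the barrier, not inside the hot loop.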
I am thinking about something like:
```julia
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)

function getloss(st)
    rst = Ref(st)
    function loss(X, y, ps)
        ypred, rst[] = model(X, ps, rst[])
        return sum(abs2, y .- ypred)
    end
    return rst, loss
end

rst, loss = getloss(st)
# Now use `loss` (or predict, etc.) inside Optimization, Turing, ...
```
I guess the experimental training state is/was supposed to fill such a need. Before adding another API, I think it would be great to standardize the handling of the state `st` similarly to the above.
EDIT: Instead of using the same style everywhere, dedicated documentation for this topic inside Getting Started could also be very helpful.