
Documentation Request: Standardize the handling of the state st #515

Closed
schlichtanders opened this issue Feb 28, 2024 · 0 comments · Fixed by #553
Labels
documentation Improvements or additions to documentation


schlichtanders commented Feb 28, 2024

This is the most confusing part of using Lux for me. There are multiple ways to handle the state (sketched briefly after this list):

  • in a function-local variable
  • in a `let` block
  • with a global variable
  • ADDED: with a `Ref`
  • not at all, discarding it via `first`, `[1]`, etc.
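For concreteness, here is a minimal sketch of a few of these patterns (assuming a small `Dense` model; the `Ref` pattern is spelled out further below):

```julia
using Lux, Random

rng = Random.default_rng()
model = Dense(2 => 1)
ps, st = Lux.setup(rng, model)
X = rand(rng, Float32, 2, 4)

# Function-local variable: thread the returned state back in explicitly.
y, st = model(X, ps, st)

# `let` block: the closure owns and updates its captured copy of the state.
predict = let st = st
    X -> begin
        y, st = model(X, ps, st)
        y
    end
end

# Ignore the state entirely (only safe when the model is stateless).
y_only = first(model(X, ps, st))
```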

I am aware that in certain scenarios, like NeuralODEs, it is crucial that the state does not change from iteration to iteration, as that would change the analytical solution of the NeuralODE and hence invalidate the algorithm.

Still, I want to suggest unifying the passing of the state everywhere to a common pattern, and adding a warning note wherever the state is ignored, explaining why it needs to be ignored in those cases.

I also read that the state can change types between runs, which probably makes it hard to handle the state generically in a compiler-friendly way, especially when interacting with Optimization.jl or Turing.jl. Maybe Lux.jl needs a dedicated `Ref` equivalent which is better than a `Ref{Any}` but still allows the type of the state to change.
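To make the idea concrete, here is a hypothetical sketch (not part of Lux.jl) of such a container: the field is deliberately left untyped so the state can change type between calls, and a function barrier recovers type stability at the point of use:

```julia
# Hypothetical `StateRef`: like `Ref{Any}`, but paired with a function
# barrier so the concretely-typed inner call stays compiler-friendly.
mutable struct StateRef
    st::Any
end

# Outer call: dynamic dispatch happens once, here.
function apply_model!(model, X, ps, r::StateRef)
    y, st_new = _apply(model, X, ps, r.st)
    r.st = st_new
    return y
end

# Inner call: `st` is concretely typed, so `model(X, ps, st)` compiles well.
_apply(model, X, ps, st) = model(X, ps, st)
```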

I am thinking about something like this:

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)

function getloss(st)
    rst = Ref(st)
    function loss(X, y, ps)
        # NOTE: when differentiating `loss`, this mutation of `rst` may need
        # to be hidden from the AD backend (e.g. with `Zygote.ignore`).
        ypred, rst[] = model(X, ps, rst[])
        return sum(abs2, y .- ypred)
    end
    return rst, loss
end

rst, loss = getloss(st)
# Now use `loss` (or predict, etc.) inside Optimization, Turing, ...
```
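A minimal usage sketch, assuming `model` is a Lux layer and `X`, `y` are arrays matching its input and output sizes:

```julia
l = loss(X, y, ps)  # forward evaluation; `rst[]` now holds the updated state
```

Note that for reverse-mode AD backends such as Zygote, the `rst[] = ...` assignment is a mutation, so it would likely have to be excluded from differentiation (e.g. via `Zygote.ignore` or `ChainRulesCore.ignore_derivatives`) before handing `loss` to Optimization.jl or Turing.jl.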

I guess the experimental `TrainState` is/was supposed to fill such a need. Before adding another API, I think it would be great to standardize the handling of the state `st` along the lines above.

EDIT: Instead of using the same style everywhere, a dedicated documentation page on this topic inside Getting Started could also be very helpful.

@avik-pal avik-pal added the documentation Improvements or additions to documentation label Feb 28, 2024