
Add docs note about saving/loading models with anonymous functions #2263

Open
darsnack opened this issue Jun 9, 2023 · 9 comments
Comments

@darsnack
Member

darsnack commented Jun 9, 2023

The new save/load docs promote JLD2.jl which does not support saving/loading anonymous functions reliably. This most commonly occurs for activation functions. The solution is to use Flux.state + Flux.loadmodel! and set the desired anonymous function in the destination of loadmodel!. This avoids needing the serialization library to correctly handle the anonymous function.
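A minimal sketch of this workflow (layer sizes and the activation are illustrative):

```julia
using Flux, JLD2

# Model with an anonymous activation function, which JLD2 cannot
# reliably round-trip on its own.
model = Chain(Dense(2 => 4, x -> leakyrelu(x, 0.1)), Dense(4 => 1))

# Save only the parameter tree, not the model object itself.
jldsave("checkpoint.jld2"; state = Flux.state(model))

# Later, in a fresh session: rebuild the model in code (re-declaring
# the anonymous activation) and load the saved parameters into it.
model2 = Chain(Dense(2 => 4, x -> leakyrelu(x, 0.1)), Dense(4 => 1))
Flux.loadmodel!(model2, JLD2.load("checkpoint.jld2", "state"))
```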

This will be problematic when the anonymous function contains data (state) that actually should be restored. A possible solution here is to make the closure an explicit struct. Maybe there are better solutions.
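As a sketch of the explicit-struct approach (the `Scale` type and its field are hypothetical):

```julia
using Flux

# Explicit struct replacing a closure like `x -> x .* scale`.
# The captured data becomes a field, visible to Flux.state.
struct Scale{T}
    scale::T
end
(s::Scale)(x) = x .* s.scale

Flux.@functor Scale  # make the field part of the model tree

model = Chain(Dense(2 => 4), Scale(rand(Float32, 4)))
```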

Regardless, new users are unlikely to realize these edge cases. We should expand the saving/loading documentation to explain how to handle these cases with code examples.

@theabhirath
Member

Activation functions have always been handled a bit awkwardly, so even though it goes slightly against the Flux style, it might not be a bad idea to have an Activation layer that does this explicitly. We've had this discussion before (FluxML/NNlib.jl#423 (comment) and FluxML/NNlib.jl#423 (comment)), so perhaps it's just time to do it.

@darsnack
Member Author

An Activation layer won't help if it wraps an anonymous function. It's a wrapper, so it just pushes the issue one node deeper in the tree.

A cleaner and correct solution is to simply name the function (e.g. myact(x) = ...). If you are closing over some data that needs to be serialized, then define a callable struct.
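Concretely (names and sizes are illustrative):

```julia
using Flux

# Named function: its type is reconstructible by name, so it
# survives serialization, unlike an anonymous closure.
myact(x) = leakyrelu(x, 0.2)

# Closing over data? Define a callable struct instead:
struct Crop
    n::Int
end
(c::Crop)(x) = x[begin:c.n, :]

model = Chain(Crop(2), Dense(2 => 1, myact))
```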

@ToucheSir
Member

I wonder if we could create a helper function that searches the model for these closures and warns the user when it finds them?
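A rough sketch of such a helper (`warn_closures` is hypothetical, not an existing Flux API):

```julia
using Flux, Functors

# Walk every leaf of the model. Anonymous functions have gensym'd
# type names starting with "#", which cannot be looked up by name
# when deserializing.
function warn_closures(model)
    Functors.fmap(model) do leaf
        if leaf isa Function && startswith(String(nameof(typeof(leaf))), "#")
            @warn "Anonymous function $leaf may not save/load reliably"
        end
        leaf
    end
    return nothing
end

warn_closures(Chain(Dense(2 => 2, x -> abs(x))))  # emits a warning
```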

@tom-plaa

Might my issue at #2339 be related to this? The model contains anonymous functions that slice the input arrays, for example x->x[begin:inputpoints, 1, :]. How would one go about correctly saving a model like this, according to your advice?

@ToucheSir
Member

Just as mentioned up top: extract the parameters with Flux.state and only save those. I suspect we'll soon be scrubbing from the docs any examples that use BSON to save the whole model, because it's just too error-prone.
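Applied to the slicing example above (a sketch, with `inputpoints` and the layer sizes chosen for illustration):

```julia
using Flux, JLD2

inputpoints = 10
model = Chain(x -> x[begin:inputpoints, 1, :], Dense(inputpoints => 1))

# The closure holds no parameters, so Flux.state skips past it and
# only the Dense layer's weight/bias arrays are written to disk.
jldsave("mymodel.jld2"; state = Flux.state(model))

# In a new session: run the same model-building code, then restore:
Flux.loadmodel!(model, JLD2.load("mymodel.jld2", "state"))
```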

@tom-plaa

After checking the docs, this implies that the model definition must be available in the session, right? Is it necessary to create a custom struct and apply the Flux.@functor macro to it before saving (like in the docs)? Must we also repeat these steps before loading (creating the same struct and applying the macro)? I'm asking because of this line in the docs:
model = MyModel(); # MyModel definition must be available

@ToucheSir
Member

state strips out any custom container types and gives you a tree of plain old Julia objects (tuples, named tuples, arrays), which should be easier to save and mostly immune to type-related breakage down the line. Any layer types with parameters or non-trainable state do need to support functor for this to work, but you'll need those declarations anyhow, because loadmodel! takes an already-constructed model to stuff the aforementioned tree of plain objects back into.
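For example, what state returns for a small model (the structure shown in the comment is approximate, based on Flux 0.14-era behaviour):

```julia
using Flux

model = Chain(Dense(2 => 3, relu))
s = Flux.state(model)
# `s` is a plain NamedTuple tree, roughly:
#   (layers = ((weight = Float32[...], bias = Float32[...], σ = ()),),)
# Functions and other unsupported leaves are replaced by `()`.
```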

@tom-plaa

tom-plaa commented Sep 20, 2023

Thank you, I managed to make it work with the loadmodel! function. I will update my other issue accordingly. It might be clearer to expand the documentation to mention that you need to rebuild the custom struct and apply the functor macro when loading as well.

@ToucheSir
Member

The reason we don't mention that in the docs is the same reason PyTorch doesn't mention that you need to define all of a model's layer types before calling model.load_state_dict(...): if you already have a model to load into, then all of the custom layer structs, @functor definitions, etc. must already be present! That said, this issue exists in the first place because the saving and loading docs could use some work, so any suggestions (ideally in the form of PRs) are appreciated :)
