Call `configure_model` when loading via `LightningModule.load_from_checkpoint` #18102
Comments
Just ran into this. Are there any good workarounds y'all can recommend for now (or ways I can help push this along?)
This should be possible if […]
Our idea here was to call […]

The only workaround I know currently is to override […]

The memory-mapping is only half of the story, but yes, it can help load large checkpoints in a more memory-efficient way in the future. The other half is the instantiation of the model. Until now, the intention of […]
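Until something like this lands, a manual-loading workaround along the lines described above looks roughly like the sketch below (`my_project`, `LitModel`, and the checkpoint path are placeholders, not from the original thread):

```python
import torch
from my_project import LitModel  # hypothetical module whose layers are created in configure_model()

model = LitModel()       # construct with the same hyperparameters used for training
model.configure_model()  # create the lazily-defined layers before loading any weights
ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])  # Lightning checkpoints keep model weights under "state_dict"
```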
Hmmm that sounds right to me too, what's your concern?
Ah, this is clever. But you'd have to manually use the meta context too, right? (as IIUC […])
errr sorry, the […]
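For reference, a rough sketch of the meta-device idea being discussed (assumptions: PyTorch >= 2.1 for `load_state_dict(..., assign=True)`, and the hypothetical `LitModel` from the workaround above):

```python
import torch
from my_project import LitModel  # hypothetical module whose layers are created in configure_model()

model = LitModel()
with torch.device("meta"):   # parameters created here get no real storage allocated
    model.configure_model()
ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
# assign=True swaps the checkpoint tensors in, materializing the meta parameters
model.load_state_dict(ckpt["state_dict"], assign=True)
```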
@MF-FOOM […]
Gotcha. If I were to make a PR to call `model.configure_model()` inside `load_from_checkpoint` before the `.load_state_dict()` call, should I just use […]
My concern is that […]

The proposal here would only address the issue where loading a checkpoint requires all layers to be defined, and so calling […]

We will probably do what is proposed in the issue (?) here, but I'm saying we probably need to think about a more seamless loading experience across all the different ways users want to use the Trainer with pretrained checkpoints.
Yes, we can't do this. My proposal in the issue was to just call the method directly.
Description & Motivation
A LightningModule might have layer definitions in `configure_model` (previously named `configure_sharded_model`) which get created before training, and then all parameters get saved to the checkpoint. However, loading such a checkpoint in a new script will fail because the model has missing parameters (`configure_model` wasn't called yet). A minimal example of the failing pattern is shown below.
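The original snippet didn't survive formatting; a minimal reconstruction of the failing pattern (class name, layer, and path are illustrative):

```python
import torch.nn as nn
from lightning.pytorch import LightningModule

class LitModel(LightningModule):  # illustrative module
    def configure_model(self):
        # Layers defined here are only created once configure_model() runs,
        # e.g. right before fit() so they can be sharded or initialized efficiently.
        self.backbone = nn.Linear(32, 32)

# In a new script this fails with missing keys for backbone.*,
# because configure_model() was never called on the fresh instance:
model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt")
```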
Pitch

Call `model.configure_model()` directly after instantiating the model in `.load_from_checkpoint()`.
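Roughly, the loading path would become something like this (a heavily simplified sketch, not the actual Lightning source; the real implementation also handles hparams parsing, `map_location`, strictness, etc.):

```python
# Inside the checkpoint-loading path:
model = cls(**checkpoint["hyper_parameters"])    # instantiate the LightningModule
model.configure_model()                          # NEW: create the lazily-defined layers
model.load_state_dict(checkpoint["state_dict"])  # now all keys resolve
```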
Notes:
- `model.configure_model()` will be called again if you later call `trainer.fit()` or `trainer.test()`.
- […] `.load_from_checkpoint`.

Alternatives
Keep things as they are, but users will get missing-keys errors and won't be able to load their checkpoints conveniently.
Additional context
No response
cc @Borda @awaelchli @carmocca