
Call configure_model when loading via LightningModule.load_from_checkpoint #18102

Closed
awaelchli opened this issue Jul 17, 2023 · 9 comments · Fixed by #19036
Labels: checkpointing (Related to checkpointing), feature (Is an improvement or enhancement), strategy: deepspeed, strategy: fsdp (Fully Sharded Data Parallel)

Comments

awaelchli (Member) commented Jul 17, 2023

Description & Motivation

A LightningModule might define some of its layers in configure_model (previously named configure_sharded_model). These layers get created before training, and all of their parameters get saved to the checkpoint. However, if you load like this:

model = MyModel.load_from_checkpoint("path")

in a new script, it will fail because the model has missing parameters (configure_model wasn't called yet).

Pitch

Call model.configure_model() directly after instantiating the model in .load_from_checkpoint().

Notes:

  • It is important that the user makes the hook idempotent: model.configure_model() will be called again if you later call trainer.fit() or trainer.test().
  • This won't help if your model is so large that the checkpoint or the model itself doesn't fit in RAM. For supermassive models, you shouldn't load the weights using .load_from_checkpoint.
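
A rough sketch of this pitch, with the loading logic simplified for illustration (this is not Lightning's actual implementation; the function name and keyword handling here are made up):

    # Sketch only: the proposed configure_model() call happens right after
    # instantiation and before the state dict is applied.
    import torch


    def load_from_checkpoint_sketch(cls, checkpoint_path, **init_kwargs):
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model = cls(**init_kwargs)  # instantiate the LightningModule
        model.configure_model()     # proposed: creates the layers defined in the hook
        model.load_state_dict(checkpoint["state_dict"])  # Lightning stores weights under "state_dict"
        return model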

Alternatives

Keep things as they are, but then users get missing-key errors and cannot load their checkpoints conveniently.

Additional context

No response

cc @Borda @awaelchli @carmocca

@awaelchli awaelchli added feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers checkpointing Related to checkpointing strategy: deepspeed strategy: fsdp Fully Sharded Data Parallel and removed needs triage Waiting to be triaged by maintainers labels Jul 17, 2023
@awaelchli awaelchli self-assigned this Jul 17, 2023
@awaelchli awaelchli added this to the 2.1 milestone Jul 17, 2023
MF-FOOM (Contributor) commented Sep 26, 2023

Just ran into this. Are there any good workarounds y'all can recommend for now (or ways I can help push this along)?

MF-FOOM (Contributor) commented Sep 26, 2023

For supermassive models, you shouldn't load the weights using .load_from_checkpoint.

This should be possible if load_from_checkpoint just supported torch.load(..., mmap=True), no?
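
For reference, the memory-mapped loading being suggested looks roughly like this in recent PyTorch versions (the checkpoint path is a placeholder):

    # With mmap=True, checkpoint tensors stay backed by the file on disk
    # instead of being read fully into RAM up front.
    import torch

    checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu", mmap=True)
    state_dict = checkpoint["state_dict"]  # Lightning checkpoints keep the weights here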

awaelchli (Member, Author) commented Sep 26, 2023

Our idea here was to call model.configure_model() inside the load_from_checkpoint function before the .load_state_dict() call. We're not sure if this is the smartest approach, but we currently have no better idea. It's not hard to implement.

The only workaround I know of currently is to override load_state_dict() inside your LightningModule to call self.configure_model(), assuming the hook is idempotent.
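
A minimal sketch of that workaround, assuming a toy model (the class name and layer are placeholders; the important parts are that configure_model() is idempotent and runs before the weights are applied):

    import torch
    from lightning.pytorch import LightningModule


    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = None  # created lazily in configure_model()

        def configure_model(self):
            # Idempotent: only build the layer if it doesn't exist yet.
            if self.layer is None:
                self.layer = torch.nn.Linear(32, 32)

        def load_state_dict(self, state_dict, strict=True):
            # Make sure all layers exist before the checkpoint weights are loaded.
            self.configure_model()
            return super().load_state_dict(state_dict, strict=strict)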

The memory-mapping is only half of the story, but yes, it can help load large checkpoints in a more memory-efficient way in the future. The other half is the instantiation of the model itself. Until now, the intention of load_from_checkpoint was always to fully instantiate the model and load it into memory, e.g., for inference, but that is not practical for massive models.

MF-FOOM (Contributor) commented Sep 26, 2023

We're not sure if this is the smartest but we currently have no better idea. It's not hard to implement.

Hmm, that sounds right to me too. What's your concern?

The only workaround I know currently is to override load_state_dict() inside your LightningModule to call self.configure_model and assume it is idempotent.

Ah, this is clever. But you'd have to manually use the meta-device context too, right? (If I understand correctly, configure_model assumes it's being called within an empty-init context.)

MF-FOOM (Contributor) commented Sep 26, 2023

Err, sorry, the init_module context? I'm actually a bit confused about which context manager gets used.

carmocca (Member) commented Sep 26, 2023

@MF-FOOM configure_model is just a hook run under these context managers (https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/trainer/call.py#L108C9-L108C95) and is equivalent to fabric.init_module().

trainer.init_module() doesn't use the same context managers (https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/trainer/trainer.py#L1115) because processes have not been launched at the time users will call it.
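
For illustration, the Fabric equivalent of that context looks roughly like this (the accelerator choice and the Linear layer are placeholders):

    # fabric.init_module() wraps module creation in the strategy's
    # initialization context (target device, default dtype, and, for sharded
    # strategies, empty/meta-device initialization).
    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1)
    fabric.launch()

    with fabric.init_module():
        model = torch.nn.Linear(32, 32)  # parameters created directly in the right context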

MF-FOOM (Contributor) commented Sep 26, 2023

Gotcha. If I were to make a PR that calls model.configure_model() inside load_from_checkpoint before the .load_state_dict() call, should I just use fabric.init_module? The _call_configure_model implementation depends on accessing the trainer's strategy, which we won't have in load_from_checkpoint, no?

awaelchli (Member, Author) commented

Hmmm that sounds right to me too, what's your concern?

My concern is that load_from_checkpoint isn't the right tool for the job. Loading these large models efficiently and correctly would only be possible under certain conditions that most likely aren't fulfilled at the time the user calls this function (are processes launched? is distributed initialized?), and at the same time load_from_checkpoint would have to handle things that the Trainer/strategy would normally do (meta-device initialization? model sharding?). All of this currently doesn't really fit into this function, IMO.

The proposal here would only address the issue where loading a checkpoint requires all layers to be defined, and so calling configure_model() would solve that issue. That won't work in some (many?) cases though, because a) the model may not fit in memory as a whole, and b) the user may do manual wrapping with fsdp.wrap(), which requires the distributed backend to be initialized.

We will probably do what is proposed in this issue, but I'm saying we probably need to think about a more seamless loading experience across all the different ways users want to use the Trainer with pretrained checkpoints.

awaelchli (Member, Author) commented

because the _call_configure_model implementation depends on accessing the trainer's strategy, which we won't have in load_from_checkpoint no

Correct, we can't do that. My proposal in the issue was to just call the method directly.
