[RFC] Introduce .fit(ckpt_path="last") #11912

Closed
carmocca opened this issue Feb 14, 2022 · 3 comments · Fixed by #12816
Labels
checkpointing Related to checkpointing design Includes a design discussion fault tolerance feature Is an improvement or enhancement

@carmocca
Contributor

carmocca commented Feb 14, 2022

🚀 Feature

Add support for passing just "last" to trainer.{fit,validate,test,predict}(ckpt_path=...).

This would pick the latest model checkpoint saved from the set of:

  • "top-k" checkpoints
  • "last" checkpoints
  • "fault-tolerant" checkpoints.

The tracking logic would most likely be inside the checkpoint connector.

Also, just for trainer.fit, ckpt_path should default to "last" when fault-tolerance is enabled.
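A minimal sketch of how the proposed defaulting could resolve `ckpt_path`, assuming a hypothetical `resolve_ckpt_path` helper and a `latest_saved` value standing in for whatever the connector tracked. Returning `None` when nothing has been saved yet is an assumption of this sketch, not documented Lightning behavior.

```python
from typing import Optional


def resolve_ckpt_path(
    fn: str,
    ckpt_path: Optional[str],
    fault_tolerant_enabled: bool,
    latest_saved: Optional[str],
) -> Optional[str]:
    # Proposed default: only `fit`, and only with fault tolerance enabled,
    # implicitly resumes from "last".
    if ckpt_path is None and fn == "fit" and fault_tolerant_enabled:
        ckpt_path = "last"
    if ckpt_path == "last":
        # Assumption: with nothing saved yet this is None (start from scratch)
        # rather than an error.
        return latest_saved
    # Any explicit path (or None outside the default case) is passed through.
    return ckpt_path


print(resolve_ckpt_path("fit", None, True, "ckpts/last.ckpt"))          # ckpts/last.ckpt
print(resolve_ckpt_path("validate", None, True, "ckpts/last.ckpt"))     # None
print(resolve_ckpt_path("test", "best.ckpt", True, "ckpts/last.ckpt"))  # best.ckpt
```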

Motivation

Until now, we've recommended passing trainer.model_checkpoint.last_model_path, but with fault-tolerance enabled, the most recent checkpoint might be the one generated by fault-tolerance.

Fault-tolerance logic is meant to be hidden from users, so we don't want them to have to track this.

Alternatives

The best alternative (after #11862) would be:

# Prefer the fault-tolerance checkpoint if that callback is configured
ft_checkpoints = [cb for cb in trainer.callbacks if isinstance(cb, _FaultToleranceCheckpoint)]
if ft_checkpoints:
    last_ckpt_path = ft_checkpoints[0].ckpt_path
else:
    # Otherwise fall back to the ModelCheckpoint "last" path
    last_ckpt_path = trainer.model_checkpoint.last_model_path
trainer.fit(model, ckpt_path=last_ckpt_path)

but it does not account for the case where model_checkpoint.last_model_path was saved after the fault-tolerance checkpoint.
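The snippet above always prefers the fault-tolerance checkpoint when that callback exists. A timestamp-based alternative compares the candidates by modification time instead. This is a sketch using only the standard library; `most_recent_checkpoint` is a hypothetical helper and the demo files are placeholders.

```python
import os
import tempfile
from typing import Iterable, Optional


def most_recent_checkpoint(paths: Iterable[Optional[str]]) -> Optional[str]:
    """Return the existing candidate with the newest modification time, or None.

    None/empty strings (e.g. a `last_model_path` that was never set) and paths
    that do not exist on disk are ignored.
    """
    existing = [p for p in paths if p and os.path.isfile(p)]
    if not existing:
        return None
    return max(existing, key=os.path.getmtime)


# Demo with placeholder files: the "last" checkpoint is written after the
# fault-tolerance one, so it should win.
with tempfile.TemporaryDirectory() as tmp:
    ft_path = os.path.join(tmp, "ft.ckpt")
    last_path = os.path.join(tmp, "last.ckpt")
    for path in (ft_path, last_path):
        open(path, "wb").close()
    # Set explicit modification times so the demo does not depend on clock resolution.
    os.utime(ft_path, (1_000, 1_000))
    os.utime(last_path, (2_000, 2_000))
    winner = os.path.basename(most_recent_checkpoint([ft_path, last_path, ""]))

print(winner)  # last.ckpt
```

On filesystems with coarse timestamp resolution two saves can share an mtime, which is one reason to prefer tracking saves as they happen.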

We should track this for the user.

Additional context

Proposed by @tchaton



cc @Borda @tchaton @justusschock @awaelchli @ananthsub @ninginthecloud @rohitgr7 @carmocca

@carmocca carmocca added feature Is an improvement or enhancement design Includes a design discussion checkpointing Related to checkpointing fault tolerance labels Feb 14, 2022
@ananthsub
Contributor

Some n00b questions:

  • Should Lightning use "last.ckpt" as the default path for both the model checkpoint and the fault tolerant checkpoint path? This way we don't have to even distinguish these.
  • Why is "last.ckpt" customizable by the end user? Doesn't that make distinguishing the most recently saved checkpoint harder for the framework?
  • Does save_last still make sense on the ModelCheckpoint callback if the fault tolerant feature is enabled & rolled out? Is there redundancy if both are enabled?
  • What happens if multiple model checkpoint callbacks are configured with save_last? Which callback do we pick last_model_path from?

@tchaton
Contributor

tchaton commented Feb 15, 2022

Hey @ananthsub.

Should Lightning use "last.ckpt" as the default path for both the model checkpoint and the fault tolerant checkpoint path? This way we don't have to even distinguish these.

Yes, this was my first idea. However, fault-tolerant checkpoints can be harder to get right, whereas an epoch-end checkpoint is quite stable. That is why we originally decided to keep them apart.

Why is "last.ckpt" customizable by the end user? Doesn't that make distinguishing the most recently saved checkpoint harder for the framework?

I'm not 100% sure I understand your question. However, ideally the last checkpoint should be the one with the most recent timestamp.

Does save_last still make sense on the ModelCheckpoint callback if the fault tolerant feature is enabled & rolled out? Is there redundancy if both are enabled?

Yes, it still makes sense. The concept of "last" doesn't change. When fault tolerance is triggered, its checkpoint is the most recently created one. But after restarting, as soon as ModelCheckpoint saves a new checkpoint, that one becomes the latest.

What happens if multiple model checkpoint callbacks are configured with save_last? Which callback do we pick last_model_path from?

I think checkpoints should be timestamped and the latest one should be used.

@carmocca
Contributor Author

carmocca commented Feb 15, 2022

I'll answer too, adding to Thomas' comments:

Should Lightning use "last.ckpt" as the default path

No, because the fault-tolerance checkpoint is saved on exception, so we cannot really guarantee that it will be usable on reload. On the other hand, "last" checkpoints are saved normally.

Why is "last.ckpt" customizable by the end user? Doesn't that make distinguishing the most recently saved checkpoint harder for the framework?

Because it's quite common to want to append the epoch number in case you start multiple runs using the same directory.

It would be harder if we went off the checkpoint name, but we should either use their timestamps or "listen" on the save_checkpoint procedure.
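To illustrate why going off the name is fragile: with a customized "last" name that embeds the epoch number (the filename pattern here is illustrative, not Lightning's default), plain string ordering disagrees with save order once the epoch reaches two digits.

```python
# Hypothetical filenames from a customized "last" name that embeds the epoch,
# listed in the order they were saved.
saved_in_order = ["last-epoch=8.ckpt", "last-epoch=9.ckpt", "last-epoch=10.ckpt"]

# Picking the "greatest" name lexicographically does not give the latest save,
# because "9" sorts after "1".
by_name = max(saved_in_order)
print(by_name)  # last-epoch=9.ckpt

# Recording the order of saves (or comparing timestamps) does.
by_save_order = saved_in_order[-1]
print(by_save_order)  # last-epoch=10.ckpt
```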

Does save_last still make sense on the ModelCheckpoint callback if the fault tolerant feature is enabled & rolled out? Is there redundancy if both are enabled?

I don't think it's redundant for the reasons above.

What happens if multiple model checkpoint callbacks are configured with save_last? Which callback do we pick last_model_path from?

Currently, they are overwritten. There's a proposal to avoid this in #5030.
