split restore_training_state into logical parts [1 / 2] #7901
@@ -207,6 +207,92 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
            scheduler['scheduler'].load_state_dict(lrs_state)

    def restore_callbacks(self) -> None:
        """ Restores all callbacks from the pre-loaded checkpoint. """
        if not self._loaded_checkpoint:
            return

        if any([key in self._loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
            raise ValueError(
                "The checkpoint you're attempting to load follows an"
                " outdated schema. You can upgrade to the current schema by running"
                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
                " where `model.ckpt` is your checkpoint file."
            )

Review comment: should this validation be done in …

Reply: good question. maybe we could.

        self.trainer.on_load_checkpoint(self._loaded_checkpoint)

    def restore_progress(self) -> None:
        """
        Restores the training progress from the pre-loaded checkpoint. This currently includes only the global step
        and current epoch.
        """
        if not self._loaded_checkpoint:
            return

        self.trainer.train_loop.global_step = self._loaded_checkpoint['global_step']
        self.trainer.train_loop.current_epoch = self._loaded_checkpoint['epoch']

        # crash if max_epochs is lower than the current epoch from the checkpoint
        if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
            raise MisconfigurationException(
                f"You restored a checkpoint with current_epoch={self.trainer.current_epoch},"
                f" but the Trainer(max_epochs={self.trainer.max_epochs})"
            )

        # Division deals with global step stepping once per accumulated batch
        # Inequality deals with different global step for odd vs even num_training_batches
        n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
        expected_steps = self.trainer.num_training_batches / n_accum
        if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
            rank_zero_warn(
                "You're resuming from a checkpoint that ended mid-epoch."
                " Training will start from the beginning of the next epoch."
                " This can cause unreliable results if further training is done,"
                " consider using an end of epoch checkpoint."
            )

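To make the mid-epoch heuristic above concrete, here is a small standalone sketch with made-up numbers (100 training batches per epoch, gradient accumulation over 4 batches, resumed at global step 130); none of these values come from the PR:

```python
# Standalone illustration of the mid-epoch check above; all values are hypothetical.
num_training_batches = 100    # batches per epoch
accumulate_grad_batches = 4   # optimizer steps once every 4 batches
global_step = 130             # restored from the checkpoint

n_accum = 1 if accumulate_grad_batches is None else accumulate_grad_batches
expected_steps = num_training_batches / n_accum   # 25 optimizer steps per full epoch

# 130 % 25 == 5, which is > 1, so the checkpoint was taken mid-epoch and the warning fires.
ended_mid_epoch = num_training_batches != 0 and global_step % expected_steps > 1
print(ended_mid_epoch)  # True
```

A checkpoint written exactly at an epoch boundary (e.g. global step 125 in this setup) would leave a remainder of 0 and skip the warning.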
    def restore_optimizers_and_schedulers(self) -> None:
        """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # validation
        if "optimizer_states" not in self._loaded_checkpoint or "lr_schedulers" not in self._loaded_checkpoint:
            raise KeyError(
                "Trying to restore training state but checkpoint contains only the model."
                " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
            )

Review comment (on lines +260 to +264): Does this chain of PRs plan to tackle the issue of restoring part of the checkpoint?

Reply: Not the aim directly, but it will definitely help, and we can directly continue with it after these PRs. There will be nothing standing in the way as far as I can tell :)

        self.restore_optimizers()
        self.restore_lr_schedulers()

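For context, the `KeyError` above is what a user hits when resuming full training from a weights-only checkpoint. A minimal sketch of how such a checkpoint typically gets written (hypothetical directory path, standard `ModelCheckpoint` flag):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoints written with save_weights_only=True contain only the model weights,
# so the 'optimizer_states' and 'lr_schedulers' entries are missing on restore.
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", save_weights_only=True)
```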
    def restore_optimizers(self) -> None:
        """ Restores the optimizer states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # restore the optimizers
        optimizer_states = self._loaded_checkpoint['optimizer_states']
        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)

            # move optimizer to GPU 1 weight at a time
            # avoids OOM
            if self.trainer.root_gpu is not None:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.trainer.root_gpu)

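The per-tensor transfer above is a general trick for relocating optimizer state without holding a second full copy of it on the device at once. A minimal standalone sketch in plain PyTorch (no Lightning; the function name and signature are my own):

```python
import torch

def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    """Move every tensor in the optimizer's state dict to `device`, one tensor at a time."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                # Replace each tensor in place so peak memory grows by only one tensor at a time.
                state[key] = value.to(device)
```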
    def restore_lr_schedulers(self) -> None:
        """ Restores the learning rate scheduler states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # restore the lr schedulers
        lr_schedulers = self._loaded_checkpoint['lr_schedulers']
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
            scheduler['scheduler'].load_state_dict(lrs_state)

    # ----------------------------------
    # PRIVATE OPS
    # ----------------------------------
Review comment: A few questions:

Reply: Hey, good question.
If you look at the "end result" #7652 (open to discussion) you will see here in the Trainer file resume_start() and resume_end() are actually in very different places, so I can't make it into a context manager.
Yes, I think it's best to document the order. The order may be important. In the future, we will want to enable configuration of what is restored, so some of these functions get called on demand and some won't be called.
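Since the order of restoration may need documenting, here is a sketch of one plausible calling sequence for the methods added in this diff; the wrapper name and the exact order are assumptions for illustration, not what the Trainer actually does:

```python
# Hypothetical resume sequence; the real call sites live inside the Trainer / checkpoint connector.
def _restore_all(connector) -> None:
    connector.restore_callbacks()                  # schema validation + Trainer.on_load_checkpoint
    connector.restore_progress()                   # global_step and current_epoch
    connector.restore_optimizers_and_schedulers()  # optimizer and LR scheduler states
```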
Reply: actually, maybe a context manager could still work. I will investigate it in #7652
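As a rough illustration of the context-manager idea, `resume_start()` and `resume_end()` (the entry points mentioned above from #7652) could be paired with `contextlib.contextmanager`; this is only a sketch under that assumption, not the implementation:

```python
from contextlib import contextmanager

@contextmanager
def resumed_checkpoint(connector):
    """Hypothetical wrapper: load/stash the checkpoint on entry, release it on exit."""
    connector.resume_start()
    try:
        yield connector
    finally:
        connector.resume_end()
```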
Review comment: I think Ananth's suggestion is good. Also, could any other class want to call the `start` and `end` methods?

Reply: No, I think we would only want to call it for unit testing, or the context manager if that works. So I know what you are saying, yes I will put the underscores everywhere.
Reply: okay, ctx manager could kinda work but I see an issue. can we move the conversation to #7652 so I can directly point to the code in trainer.py?