split restore_training_state into logical parts [1 / 2] #7901

Merged
6 commits merged on Jun 10, 2021
Changes from 4 commits
86 changes: 86 additions & 0 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -207,6 +207,92 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
            scheduler['scheduler'].load_state_dict(lrs_state)

    def restore_callbacks(self) -> None:
        """ Restores all callbacks from the pre-loaded checkpoint. """
        if not self._loaded_checkpoint:
            return
Comment on lines +211 to +213
@ananthsub (Contributor), Jun 9, 2021

A few questions:

  • Is there a risk of these restoration functions being called outside this context? Should the start and end of restoring from checkpoint be replaced with a dedicated context manager?
  • In splitting these out, should we be prescriptive about the order in which they are loaded?

@awaelchli (Member, Author), Jun 9, 2021

Hey, good question.
If you look at the "end result" in #7652 (open to discussion), you will see that in the Trainer file resume_start() and resume_end() are actually called in very different places, so I can't make this into a context manager.

Yes, I think it's best to document the order; the order may be important. In the future we will want to enable configuring what is restored, so some of these functions will get called on demand and some won't be called at all.

@awaelchli (Member, Author)

Actually, maybe a context manager could still work. I will investigate it in #7652.

Contributor

I think Ananth's suggestion is good.

Also, could any other class want to call the start and end methods?

@awaelchli (Member, Author)

No, I think we would only want to call them for unit testing, or from the context manager if that works out.
So I see what you are saying; yes, I will put the underscores everywhere.

@awaelchli (Member, Author)

Okay, a ctx manager could kind of work, but I see an issue. Can we move the conversation to #7652 so I can directly point to the code in trainer.py?
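
For illustration, a minimal sketch of the context-manager idea discussed above, assuming hypothetical resume_start() / resume_end() methods on the connector as proposed in #7652 (names and signatures here are assumptions, not part of this PR):

    from contextlib import contextmanager

    @contextmanager
    def restore_from_checkpoint(checkpoint_connector):
        # Assumed API: resume_start() pre-loads the checkpoint into the connector,
        # resume_end() clears the loaded state again once restoration is done.
        checkpoint_connector.resume_start()
        try:
            yield checkpoint_connector
        finally:
            checkpoint_connector.resume_end()

Whether something like this is usable depends on how far apart the start and end calls end up sitting in trainer.py, which is exactly the concern raised above.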


        if any([key in self._loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
            raise ValueError(
                "The checkpoint you're attempting to load follows an"
                " outdated schema. You can upgrade to the current schema by running"
                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
                " where `model.ckpt` is your checkpoint file."
            )
Contributor

Should this validation be done in resume_start?

@awaelchli (Member, Author), Jun 9, 2021

Good question. Maybe we could.
One thought, though: in the future we will have a way to configure what to load, so if these functions get called individually we may want to keep the validation together with the particular objects that are being restored.

        self.trainer.on_load_checkpoint(self._loaded_checkpoint)

    def restore_progress(self) -> None:
        """
        Restores the training progress from the pre-loaded checkpoint. This currently includes only the global step
        and current epoch.
        """
        if not self._loaded_checkpoint:
            return

        self.trainer.train_loop.global_step = self._loaded_checkpoint['global_step']
        self.trainer.train_loop.current_epoch = self._loaded_checkpoint['epoch']

        # crash if max_epochs is lower than the current epoch from the checkpoint
        if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
            raise MisconfigurationException(
                f"You restored a checkpoint with current_epoch={self.trainer.current_epoch},"
                f" but the Trainer(max_epochs={self.trainer.max_epochs})"
            )

        # Division deals with global step stepping once per accumulated batch
        # Inequality deals with different global step for odd vs even num_training_batches
        n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
        expected_steps = self.trainer.num_training_batches / n_accum
        if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
            rank_zero_warn(
                "You're resuming from a checkpoint that ended mid-epoch."
                " Training will start from the beginning of the next epoch."
                " This can cause unreliable results if further training is done,"
                " consider using an end of epoch checkpoint."
            )
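
As a worked example of the mid-epoch check above (numbers made up for illustration): with 100 training batches and accumulate_grad_batches=4, an epoch contributes 25 optimizer steps, so a checkpoint restored at global_step=60 sits 10 steps into an epoch and triggers the warning.

    # Hypothetical numbers, mirroring the check in restore_progress above.
    num_training_batches = 100
    n_accum = 4
    expected_steps = num_training_batches / n_accum  # 25.0 optimizer steps per epoch
    global_step = 60                                 # value restored from the checkpoint
    mid_epoch = num_training_batches != 0 and global_step % expected_steps > 1
    print(mid_epoch)  # True: 60 % 25.0 == 10.0 > 1, so the checkpoint was taken mid-epoch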

    def restore_optimizers_and_schedulers(self) -> None:
        """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # validation
        if "optimizer_states" not in self._loaded_checkpoint or "lr_schedulers" not in self._loaded_checkpoint:
            raise KeyError(
                "Trying to restore training state but checkpoint contains only the model."
                " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
            )
Comment on lines +260 to +264
Contributor

Does this chain of PRs plan to tackle the issue of restoring only part of the checkpoint?

#5339

@awaelchli (Member, Author), Jun 9, 2021

Not the aim directly, but it will definitely help, and we can continue with it right after these PRs. There will be nothing standing in the way as far as I can tell :)
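
As a purely hypothetical illustration of where this split could lead for #5339, the individual restore methods could then be invoked selectively instead of always restoring everything (this assumes a checkpoint has already been pre-loaded into the connector and is not part of this PR):

    # Sketch: restore only callbacks and optimizer states, skip schedulers and progress.
    connector = trainer.checkpoint_connector
    connector.restore_callbacks()
    connector.restore_optimizers()
    # connector.restore_lr_schedulers() and connector.restore_progress() intentionally skipped.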

        self.restore_optimizers()
        self.restore_lr_schedulers()

    def restore_optimizers(self) -> None:
        """ Restores the optimizer states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # restore the optimizers
        optimizer_states = self._loaded_checkpoint['optimizer_states']
        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)

            # move optimizer to GPU 1 weight at a time
            # avoids OOM
            if self.trainer.root_gpu is not None:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.trainer.root_gpu)

    def restore_lr_schedulers(self) -> None:
        """ Restores the learning rate scheduler states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # restore the lr schedulers
        lr_schedulers = self._loaded_checkpoint['lr_schedulers']
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
            scheduler['scheduler'].load_state_dict(lrs_state)

    # ----------------------------------
    # PRIVATE OPS
    # ----------------------------------