Fix save/load/resume from checkpoint for DeepSpeed Plugin #8397

Merged: 75 commits merged into master from fix/ds_saving_2 on Aug 2, 2021
Changes shown from 41 commits

Commits (75)
24a3e50  wip (Jul 7, 2021)
03a8769  Change trainer loading behaviour for validate/test/predict (Jul 9, 2021)
a943e33  Fix (Jul 9, 2021)
40a3446  Fix/add tests (Jul 9, 2021)
8c24ffd  remove (Jul 9, 2021)
1879be7  Cleanups (Jul 12, 2021)
3162ff7  Space (Jul 12, 2021)
6dd61d6  cleanups (Jul 12, 2021)
5772e17  Merge branch 'master' into feat/ckpt_load (Jul 12, 2021)
b072868  Add CHANGELOG.md (Jul 12, 2021)
de2738d  Merge branch 'master' into feat/ckpt_load (Jul 12, 2021)
6910e39  Merge branch 'feat/ckpt_load' into fix/ds_saving_2 (Jul 12, 2021)
bf5afe3  Fix (Jul 12, 2021)
f2ee8b5  Move after setup (Jul 12, 2021)
b7c24d9  Merge branch 'feat/ckpt_load' into fix/ds_saving_2 (Jul 12, 2021)
8659426  Cleanups on logic (Jul 12, 2021)
84d20f5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2021)
9e367fd  Remve (Jul 12, 2021)
6ea8b44  Merge branch 'feat/ckpt_load' into fix/ds_saving_2 (Jul 12, 2021)
3f8c3d3  Remve (Jul 12, 2021)
b8ffc39  fix test (Jul 12, 2021)
b02f35b  feedback (Jul 12, 2021)
dbb03af  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2021)
1c7b9a1  Update pytorch_lightning/trainer/properties.py (Jul 12, 2021)
444fb55  Feedback (Jul 12, 2021)
4632bba  Same fix (Jul 12, 2021)
e92b757  Same fix (Jul 12, 2021)
66bea8e  Add test for behaviour, modify based on feedback (Jul 12, 2021)
0139a19  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2021)
d48d916  Wording (Jul 12, 2021)
100d73b  Apply suggestions from code review (Jul 12, 2021)
f3f92a5  Cleanup docs (Jul 12, 2021)
2849d0b  Update pytorch_lightning/trainer/trainer.py (Jul 12, 2021)
f53c896  feedback (Jul 12, 2021)
ebc713b  Fixes to test API (Jul 12, 2021)
76e22c2  Add carlos description (Jul 12, 2021)
9a62650  Merge branch 'feat/ckpt_load' into fix/ds_saving_2 (Jul 12, 2021)
0b46226  Fixes (Jul 13, 2021)
7a85b44  Merge branch 'master' into fix/ds_saving_2 (Jul 13, 2021)
8042fb4  Changes (Jul 13, 2021)
203fd49  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 13, 2021)
8d0f260  Try delaying (Jul 14, 2021)
d4e2295  Merge branch 'master' into fix/ds_saving_2 (Jul 14, 2021)
28d7575  Fixes (Jul 27, 2021)
32f73e4  Merge branch 'master' into fix/ds_saving_2 (Jul 27, 2021)
a3c6009  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 27, 2021)
857a6aa  Merge branch 'master' into fix/ds_saving_2 (Jul 27, 2021)
e6c3bd1  Merge branch 'master' into fix/ds_saving_2 (Jul 28, 2021)
c51033a  fixes (Jul 28, 2021)
4f5bd96  Add extra condition (Jul 28, 2021)
e1fb2f0  Fix (Jul 28, 2021)
77036a2  Fix (Jul 28, 2021)
82e00be  Attempt to fix tests (Jul 28, 2021)
57355aa  Add guard (Jul 28, 2021)
3fc8f67  Fix test (Jul 29, 2021)
6adb83d  Fix (Jul 29, 2021)
607aef2  Add test (Jul 29, 2021)
0c30656  Update pytorch_lightning/plugins/training_type/deepspeed.py (Jul 29, 2021)
c9849e0  Fix description (Jul 29, 2021)
0d3866c  Add test (Jul 29, 2021)
fd7a168  Fix test (Jul 29, 2021)
256b145  Refactors (Jul 29, 2021)
c189595  add recursive (Jul 29, 2021)
670810f  Merge branch 'master' into fix/ds_saving_2 (Jul 30, 2021)
64a4eba  Merge branch 'master' into fix/ds_saving_2 (Aug 2, 2021)
0d2ec03  Fix dupe (Aug 2, 2021)
ef33d90  Merge branch 'master' into fix/ds_saving_2 (Aug 2, 2021)
5329c48  Force 0.4.3 (Aug 2, 2021)
95d1287  Address reviews (Aug 2, 2021)
88ab306  Add todo (Aug 2, 2021)
a15cd8d  Update pytorch_lightning/plugins/training_type/training_type_plugin.py (Aug 2, 2021)
9365cc0  Apply suggestions from code review (Aug 2, 2021)
5f994c4  Add asserts for properties, address reviews (Aug 2, 2021)
cdf8c25  Fix description (Aug 2, 2021)
c47abf2  Merge branch 'master' into fix/ds_saving_2 (Aug 2, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto'` ([#7808](https://github.com/PyTorchLightning/pytorch-lightning/pull/7808))


- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))


### Changed


2 changes: 1 addition & 1 deletion docs/source/common/test_set.rst
@@ -23,7 +23,7 @@ To run the test set after training completes, use this method.
trainer.test()

# (2) don't load a checkpoint, instead use the model with the latest weights
trainer.test(ckpt_path=None)
trainer.test(model)

# (3) test using a specific checkpoint
trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')
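
For reference, a minimal, self-contained sketch of the three call patterns the updated docs describe. It is not part of the diff: LitModel, its hooks, and the checkpoint path are illustrative stand-ins, and exact Trainer argument names vary slightly across Lightning versions.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        """Minimal stand-in model (hypothetical), just enough to exercise fit/test."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch[0]).sum()

        def test_step(self, batch, batch_idx):
            self.log("test_loss", self.layer(batch[0]).sum())

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

        def test_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    model = LitModel()
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model)

    # (1) no arguments: reload the best checkpoint tracked by the checkpoint callback
    trainer.test()

    # (2) pass the model: evaluate the in-memory weights, nothing is reloaded
    trainer.test(model)

    # (3) evaluate a specific checkpoint (the path is a placeholder)
    trainer.test(ckpt_path="/path/to/my_checkpoint.ckpt")
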
117 changes: 79 additions & 38 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -20,19 +20,21 @@
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union

import torch
from torch.nn import Module

import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn

if _DEEPSPEED_AVAILABLE:
import deepspeed
@@ -631,9 +633,6 @@ def _create_default_config(
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
return cfg

def _filepath_to_dir(self, filepath: str) -> str:
return os.path.dirname(filepath)

@property
def deepspeed_engine(self):
return self.model
@@ -645,55 +644,97 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
checkpoint: The checkpoint state dictionary
filepath: write-target file's path
"""
if self.world_size > 1 and self.zero_stage_3:
if self.save_full_weights:
# todo: expose this as general function in deepspeed
state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
if self.is_global_zero:
# State dict keys will include reference to wrapper LightningDeepSpeedModule
# Delete `module` prefix before saving.
state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
checkpoint['state_dict'] = state_dict
return super().save_checkpoint(checkpoint, filepath)
return

# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
save_dir = self._filepath_to_dir(filepath)
_exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
else:
super().save_checkpoint(checkpoint, filepath)

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
if self.save_full_weights or self.world_size == 1:
if self.save_full_weights and self.zero_stage_3:
# todo (sean): expose this as general function in deepspeed
state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
if self.is_global_zero:
# State dict keys will include reference to wrapper LightningDeepSpeedModule
# Delete `module` prefix before saving.
state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
checkpoint['state_dict'] = state_dict
return super().save_checkpoint(checkpoint, filepath)
return

# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
_exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
if self.save_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
checkpoint_path = self.broadcast(checkpoint_path)
return super().load_checkpoint_file(checkpoint_path)

# Rely on deepspeed to load the checkpoint and necessary information
# Rely on deepspeed completely to load the checkpoint and necessary information
from pytorch_lightning.trainer.states import TrainerFn
is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
save_dir = self._filepath_to_dir(checkpoint_path)

if self.zero_stage_3:
# TODO: Currently required as this call is missing within the deepspeed engine.
self.deepspeed_engine.optimizer._partition_all_parameters()

_, client_state = self.deepspeed_engine.load_checkpoint(
save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
)
return client_state

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
pass
if self.save_full_weights and self.zero_stage_3:
self.model_to_device()
self._restore_zero_state(checkpoint)

def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
"""
Overrides the normal load_state_dict behaviour in PyTorch to ensure
we gather parameters that may be sharded across processes before loading
the state dictionary when using ZeRO stage 3.
This is then automatically synced across processes.
Args:
ckpt: The ckpt file.
"""

def load(module: torch.nn.Module, prefix=""):

missing_keys = []
unexpected_keys = []
error_msgs = []
state_dict = ckpt['state_dict']

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if self.is_global_zero:
module._load_from_state_dict(
state_dict=state_dict,
prefix=prefix,
local_metadata=local_metadata,
strict=True,
missing_keys=missing_keys,
unexpected_keys=unexpected_keys,
error_msgs=error_msgs
)

for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")

load(self.lightning_module, prefix="")

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
pass
if self.save_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
rank_zero_warn(
"A single checkpoint file was saved using ZeRO Stage 3. This means optimizer states and "
"scheduler states cannot be restored. If you'd like to restore these states, you must "
"set save_full_weights=False, i.e. Trainer(plugins=DeepSpeedPlugin(save_full_weights=False)) "
"when training the model initially."
)

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
if self._original_accumulate_grad_batches is None:
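
To put the plugin changes above in context, here is a hedged usage sketch of how checkpoints flow through this code path. It is not part of the diff: it reuses the hypothetical LitModel from the earlier sketch, assumes a 2-GPU machine, and uses the Trainer/DeepSpeedPlugin arguments available around the 1.4 release, so treat the exact argument names as assumptions.

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    model = LitModel()  # hypothetical LightningModule from the earlier sketch

    # With save_full_weights=False, save_checkpoint() above hands the filepath to
    # deepspeed_engine.save_checkpoint(), so "my_checkpoint" becomes a directory of
    # per-rank ZeRO shards rather than a single file.
    trainer = Trainer(
        gpus=2,
        precision=16,
        plugins=DeepSpeedPlugin(stage=3, save_full_weights=False),
    )
    trainer.fit(model)
    trainer.save_checkpoint("my_checkpoint")

    # Resuming passes the same path back through load_checkpoint_file(), letting the
    # DeepSpeed engine restore module, optimizer, and LR-scheduler state.
    trainer = Trainer(
        gpus=2,
        precision=16,
        plugins=DeepSpeedPlugin(stage=3, save_full_weights=False),
        resume_from_checkpoint="my_checkpoint",
    )
    trainer.fit(model)

With save_full_weights=True under ZeRO stage 3, a single consolidated file is written instead; load_model_state_dict() then walks the module tree and uses deepspeed.zero.GatheredParameters to unpartition each layer's parameters on rank 0 before loading them, and, as the warning above notes, optimizer and scheduler states cannot be restored from such a file.
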
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -68,6 +68,11 @@ class TrainerProperties(ABC):
validate_loop: EvaluationLoop
test_loop: EvaluationLoop
predict_loop: PredictionLoop

# set by .validate(), .test() and .predict() when they load a checkpoint
validated_ckpt_path: Optional[str] = None
tested_ckpt_path: Optional[str] = None
predicted_ckpt_path: Optional[str] = None
"""
Accelerator properties
"""
@@ -548,6 +553,15 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
if self.predicting:
return self.predict_loop

@property
def _ckpt_path(self) -> Optional[str]:
if self.state.fn == TrainerFn.VALIDATING:
return self.validated_ckpt_path
if self.state.fn == TrainerFn.TESTING:
return self.tested_ckpt_path
if self.state.fn == TrainerFn.PREDICTING:
return self.predicted_ckpt_path

"""
Logging properties
"""
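
Finally, a small sketch (not part of the diff) of what the new bookkeeping attributes expose after an evaluation run. It reuses the hypothetical LitModel from the first sketch, and the explicit checkpoint path is a placeholder.

    from pytorch_lightning import Trainer

    model = LitModel()  # hypothetical LightningModule from the earlier sketch
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    # .test() resolves ckpt_path (here the "best" checkpoint tracked during fit)
    # and records which file it actually loaded.
    trainer.test(ckpt_path="best")
    print(trainer.tested_ckpt_path)

    # An explicit path is recorded as-is.
    trainer.test(ckpt_path="/path/to/my_checkpoint.ckpt")
    print(trainer.tested_ckpt_path)

    # .validate() and .predict() fill validated_ckpt_path / predicted_ckpt_path the same way.

While one of these entry points is running, the private _ckpt_path property returns whichever of the three attributes matches trainer.state.fn, so downstream code such as the DeepSpeed plugin above can ask for the checkpoint used by the current run without caring which entry point triggered it.
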