From 24a3e50a0d019de916457cb1b5627eb8f0bdeab2 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 7 Jul 2021 11:54:37 +0100
Subject: [PATCH 01/60] wip

---
 .../plugins/training_type/deepspeed.py  | 53 ++++++++++---------
 tests/plugins/test_deepspeed_plugin.py  | 35 ++++++------
 2 files changed, 47 insertions(+), 41 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 4d229e4bff43a..868de2b4f2a1a 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -27,12 +27,13 @@
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
-from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning
+from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn
 
 if _DEEPSPEED_AVAILABLE:
     import deepspeed
@@ -647,29 +648,26 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
             checkpoint: The checkpoint state dictionary
             filepath: write-target file's path
         """
-        if self.world_size > 1 and self.zero_stage_3:
-            if self.save_full_weights:
-                # todo: expose this as general function in deepspeed
-                state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
-                if self.is_global_zero:
-                    # State dict keys will include reference to wrapper LightningDeepSpeedModule
-                    # Delete `module` prefix before saving.
-                    state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
-                    checkpoint['state_dict'] = state_dict
-                    return super().save_checkpoint(checkpoint, filepath)
-                return
-
-            # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
-            # dump states as a checkpoint dictionary object
-            save_dir = self._filepath_to_dir(filepath)
-            _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
-            checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
-            self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
-        else:
-            super().save_checkpoint(checkpoint, filepath)
+        if self.save_full_weights and self.zero_stage_3:
+            # todo (sean): expose this as general function in deepspeed
+            state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
+            if self.is_global_zero:
+                # State dict keys will include reference to wrapper LightningDeepSpeedModule
+                # Delete `module` prefix before saving.
+                state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
+                checkpoint['state_dict'] = state_dict
+                return super().save_checkpoint(checkpoint, filepath)
+            return
+
+        # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
+        # dump states as a checkpoint dictionary object
+        save_dir = self._filepath_to_dir(filepath)
+        _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
+        checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
+        self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
 
     def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
-        if self.save_full_weights or self.world_size == 1:
+        if self.save_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
             checkpoint_path = self.broadcast(checkpoint_path)
@@ -691,11 +689,18 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
-        pass
+        if self.save_full_weights and self.zero_stage_3:
+            self.lightning_module.load_state_dict(checkpoint['state_dict'])
 
     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
-        pass
+        if self.save_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            rank_zero_warn(
+                "A single checkpoint file was saved using ZeRO Stage 3. This means optimizer states and "
+                "scheduler states can not be restored. If you'd like to restore these states, you must "
+                "set save_full_weights=False, i.e. Trainer(plugins=DeepSpeedPlugin(save_full_weights=False)) "
+                "when training the model initially."
+ ) def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: if self._original_accumulate_grad_batches is None: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index dcb4ff00b219b..f9fba56bb4849 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -507,6 +507,10 @@ def configure_optimizers(self): 'interval': 'step', }] + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + if not hasattr(self, 'model'): + self.configure_sharded_model() + class ManualModelParallelClassificationModel(ModelParallelClassificationModel): @@ -589,27 +593,24 @@ def run_checkpoint_test( results = trainer.test(model, datamodule=dm) assert results[0]['test_acc'] > 0.7 - saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm) assert saved_results[0]['test_acc'] > 0.7 assert saved_results == results - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=10, - plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)], - gpus=2, - precision=16, - accumulate_grad_batches=2, - callbacks=[ck], - resume_from_checkpoint=ck.best_model_path - ) - results = trainer.test(model, datamodule=dm) - assert results[0]['test_acc'] > 0.7 + if automatic_optimization: + model_cls = ModelParallelClassificationModel() + else: + model_cls = ManualModelParallelClassificationModel() + if trainer.is_global_zero: + trainer = Trainer(default_root_dir=tmpdir, gpus=1, precision=16) + saved_model = model_cls.load_from_checkpoint(ck.best_model_path) - dm.predict_dataloader = dm.test_dataloader - results = trainer.predict(datamodule=dm) - assert results[-1] > 0.7 + results = trainer.test(saved_model, datamodule=dm) + assert results[0]['test_acc'] > 0.7 + + dm.predict_dataloader = dm.test_dataloader + results = trainer.predict(datamodule=dm) + assert results[-1] > 0.7 @RunIf(min_gpus=2, deepspeed=True, special=True) @@ -621,7 +622,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): run_checkpoint_test(tmpdir, save_full_weights=False) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, special=False) def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir): """ Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, From 03a87699a1d10fa5dbc664a280ab1e6ef7ef65d1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 9 Jul 2021 14:23:49 +0100 Subject: [PATCH 02/60] Change trainer loading behaviour for validate/test/predict --- pytorch_lightning/trainer/trainer.py | 72 +++++++++++++++------------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7475cd9c81326..970c375e03561 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -528,7 +528,7 @@ def validate( self, model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = 'best', + ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, val_dataloaders=None, # noqa TODO: remove with 1.6 @@ -542,9 +542,9 @@ def validate( dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. 
- ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. - If ``None``, use the current weights of the model. - When the model is given as argument, this parameter will not apply. + ckpt_path: Path to the checkpoint you wish to use to validate. + If ``None``, use the best weights based on the checkpoint callback if the model isn't provided. + If model is provided and ``ckpt_path`` is ``None``, this parameter does not apply. verbose: If True, prints the validation results. @@ -579,7 +579,6 @@ def validate( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`') - model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -589,8 +588,7 @@ def validate( # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - if not model_provided: - self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) + self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) # run validate results = self._run(model) @@ -604,7 +602,7 @@ def test( self, model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = 'best', + ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, test_dataloaders=None, # noqa TODO: remove with 1.6 @@ -619,9 +617,9 @@ def test( dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. - ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the current weights of the model. - When the model is given as argument, this parameter will not apply. + ckpt_path: Path to the checkpoint you wish to use to test. + If ``None``, use the best weights based on the checkpoint callback if the model isn't provided. + If model is provided and ``ckpt_path`` is ``None``, this parameter does not apply. verbose: If True, prints the test results. @@ -654,7 +652,6 @@ def test( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`') - model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -664,8 +661,7 @@ def test( # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - if not model_provided: - self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) + self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) # run test results = self._run(model) @@ -681,7 +677,7 @@ def predict( dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, - ckpt_path: Optional[str] = 'best', + ckpt_path: Optional[str] = None, ) -> Optional[_PREDICT_OUTPUT]: r""" @@ -699,9 +695,9 @@ def predict( return_predictions: Whether to return predictions. ``True`` by default except when an accelerator that spawns processes is used (not supported). - ckpt_path: Either ``best`` or path to the checkpoint you wish to use to predict. - If ``None``, use the current weights of the model. 
- When the model is given as argument, this parameter will not apply. + ckpt_path: Path to the checkpoint you wish to use to predict. + If ``None``, use the best weights based on the checkpoint callback if the model isn't provided. + If model is provided and ``ckpt_path`` is ``None``, this parameter does not apply. Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. @@ -725,7 +721,6 @@ def predict( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') - model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -735,8 +730,7 @@ def predict( # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - if not model_provided: - self.predicted_ckpt_path = self.__load_ckpt_weights(ckpt_path) + self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) results = self._run(model) @@ -807,6 +801,15 @@ def tune( return result + @property + def ckpt_path(self) -> Optional[str]: + if self.state.fn == TrainerFn.VALIDATING: + return self.validated_ckpt_path + if self.state.fn == TrainerFn.TESTING: + return self.tested_ckpt_path + if self.state.fn == TrainerFn.PREDICTING: + return self.predicted_ckpt_path + def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): @@ -828,13 +831,16 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self.accelerator.connect(model) self.accelerator.setup_environment() self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment - # restore modules after setup self.checkpoint_connector.restore_datamodule() self.checkpoint_connector.restore_model() # restore callback states self.checkpoint_connector.restore_callbacks() + if self.ckpt_path: + rank_zero_info(f"Loading checkpoint from {self.ckpt_path}") + self.checkpoint_connector.restore_model_weights(self.ckpt_path) + self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module @@ -1059,13 +1065,18 @@ def _run_sanity_check(self, ref_model): # restore the previous stage when the sanity check if finished self.state.stage = stage - def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: - if ckpt_path is None: + def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool) -> Optional[str]: + if model_provided and ckpt_path is None: return fn = self.state.fn.value - if ckpt_path == 'best': + if not model_provided and ckpt_path == 'best': + rank_zero_deprecation( + f'`.{fn}(ckpt_path="best")` was deprecated in v1.4 and will be removed in v1.6. Please use `.{fn}()`' + ) + + if not model_provided and ckpt_path is None: # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: @@ -1074,23 +1085,16 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.' ) raise MisconfigurationException( - f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' 
+ f'`.{fn}(ckpt_path=None)` is set but `ModelCheckpoint` is not configured to save the best model.' ) # load best weights ckpt_path = self.checkpoint_callback.best_model_path - if not ckpt_path: + if not ckpt_path and not model_provided: raise MisconfigurationException( f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please' f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' ) - - # only one process running at this point for TPUs, as spawn isn't triggered yet - # todo: move this logic internally within the barrier. - if not self._device_type == DeviceType.TPU: - self.training_type_plugin.barrier() - - self.checkpoint_connector.restore_model_weights(ckpt_path) return ckpt_path def _call_setup_hook(self, model: 'pl.LightningModule') -> None: From a943e33cd0d617d076fa0bee23d596cd5a11622f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 9 Jul 2021 16:51:16 +0100 Subject: [PATCH 03/60] Fix --- pytorch_lightning/trainer/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 970c375e03561..d19b6d0bae759 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1073,8 +1073,10 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool) -> Opt if not model_provided and ckpt_path == 'best': rank_zero_deprecation( - f'`.{fn}(ckpt_path="best")` was deprecated in v1.4 and will be removed in v1.6. Please use `.{fn}()`' + f'`.{fn}(ckpt_path="best")` was deprecated in v1.4 and will be removed in v1.6. Do not provide ' + f'ckpt_path, i.e `.{fn}()` ' ) + ckpt_path = None if not model_provided and ckpt_path is None: # if user requests the best checkpoint but we don't have it, error From 40a3446fe5cc16e4a913063ac471e1fd0b2fb256 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 9 Jul 2021 17:03:45 +0100 Subject: [PATCH 04/60] Fix/add tests --- pytorch_lightning/trainer/trainer.py | 9 ++++++--- tests/deprecated_api/test_remove_1-6.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d19b6d0bae759..163de94180de6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -579,6 +579,7 @@ def validate( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`') + model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -588,7 +589,7 @@ def validate( # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) + self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided) # run validate results = self._run(model) @@ -652,6 +653,7 @@ def test( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`') + model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -661,7 +663,7 @@ def test( # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not 
None) + self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided) # run test results = self._run(model) @@ -721,6 +723,7 @@ def predict( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') + model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -730,7 +733,7 @@ def predict( # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) + self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided) results = self._run(model) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 69d2a45530607..fd3009cf20b31 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -15,6 +15,7 @@ import pytest from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin @@ -303,3 +304,13 @@ def test_v1_6_0_deprecated_disable_validation(): trainer = Trainer() with pytest.deprecated_call(match="disable_validation` is deprecated in v1.4"): _ = trainer.disable_validation + + +def test_v1_6_0_deprecated_ckpt_best(tmpdir): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer.fit(model) + + for fn in (trainer.test, trainer.validate, trainer.predict): + with pytest.deprecated_call(match="deprecated in v1.4 and will be removed in v1.6. 
Do not provide"): + fn(ckpt_path='best') From 8c24ffd28436b065b5233795ba85a25ba44c022d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 9 Jul 2021 17:06:44 +0100 Subject: [PATCH 05/60] remove --- tests/deprecated_api/test_remove_1-6.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index fd3009cf20b31..e1934169244bd 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -15,7 +15,6 @@ import pytest from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin From 1879be7eb9a82e547301c1feb2c0cb1cbb01414e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 10:27:11 +0100 Subject: [PATCH 06/60] Cleanups --- pytorch_lightning/trainer/trainer.py | 44 ++++++++++--------------- tests/deprecated_api/test_remove_1-6.py | 10 ------ tests/trainer/test_trainer.py | 3 ++ 3 files changed, 20 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 163de94180de6..bcd96017fa06e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -528,7 +528,7 @@ def validate( self, model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = None, + ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, val_dataloaders=None, # noqa TODO: remove with 1.6 @@ -542,9 +542,9 @@ def validate( dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. - ckpt_path: Path to the checkpoint you wish to use to validate. - If ``None``, use the best weights based on the checkpoint callback if the model isn't provided. - If model is provided and ``ckpt_path`` is ``None``, this parameter does not apply. + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. + If ``None``, use the current weights of the model. + When the model is given as argument, we load the ckpt path. verbose: If True, prints the validation results. 
@@ -579,7 +579,6 @@ def validate( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`') - model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -589,7 +588,7 @@ def validate( # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided) + self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) # run validate results = self._run(model) @@ -603,7 +602,7 @@ def test( self, model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = None, + ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, test_dataloaders=None, # noqa TODO: remove with 1.6 @@ -618,9 +617,9 @@ def test( dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. - ckpt_path: Path to the checkpoint you wish to use to test. - If ``None``, use the best weights based on the checkpoint callback if the model isn't provided. - If model is provided and ``ckpt_path`` is ``None``, this parameter does not apply. + ckpt_path: Either ``best`` or path to the checkpoint you wish to test. + If ``None``, use the current weights of the model. + When the model is given as argument, we load the ckpt path. verbose: If True, prints the test results. @@ -653,7 +652,6 @@ def test( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`') - model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -663,7 +661,7 @@ def test( # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided) + self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) # run test results = self._run(model) @@ -679,7 +677,7 @@ def predict( dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, - ckpt_path: Optional[str] = None, + ckpt_path: Optional[str] = 'best', ) -> Optional[_PREDICT_OUTPUT]: r""" @@ -697,9 +695,9 @@ def predict( return_predictions: Whether to return predictions. ``True`` by default except when an accelerator that spawns processes is used (not supported). - ckpt_path: Path to the checkpoint you wish to use to predict. - If ``None``, use the best weights based on the checkpoint callback if the model isn't provided. - If model is provided and ``ckpt_path`` is ``None``, this parameter does not apply. + ckpt_path: Either ``best`` or path to the checkpoint you wish to predict. + If ``None``, use the current weights of the model. + When the model is given as argument, we load the ckpt path. Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. 
@@ -723,7 +721,6 @@ def predict( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') - model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -733,7 +730,7 @@ def predict( # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided) + self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) results = self._run(model) @@ -1074,14 +1071,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool) -> Opt fn = self.state.fn.value - if not model_provided and ckpt_path == 'best': - rank_zero_deprecation( - f'`.{fn}(ckpt_path="best")` was deprecated in v1.4 and will be removed in v1.6. Do not provide ' - f'ckpt_path, i.e `.{fn}()` ' - ) - ckpt_path = None - - if not model_provided and ckpt_path is None: + if model_provided and ckpt_path is 'best': # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: @@ -1095,7 +1085,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool) -> Opt # load best weights ckpt_path = self.checkpoint_callback.best_model_path - if not ckpt_path and not model_provided: + if ckpt_path is None: raise MisconfigurationException( f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please' f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index e1934169244bd..69d2a45530607 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -303,13 +303,3 @@ def test_v1_6_0_deprecated_disable_validation(): trainer = Trainer() with pytest.deprecated_call(match="disable_validation` is deprecated in v1.4"): _ = trainer.disable_validation - - -def test_v1_6_0_deprecated_ckpt_best(tmpdir): - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - trainer.fit(model) - - for fn in (trainer.test, trainer.validate, trainer.predict): - with pytest.deprecated_call(match="deprecated in v1.4 and will be removed in v1.6. 
Do not provide"): - fn(ckpt_path='best') diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4f2043d80c805..40a0334a6cdfc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -701,6 +701,9 @@ def predict_step(self, batch, *_): trainer_fn(ckpt_path=ckpt_path) assert getattr(trainer, path_attr) == ckpt_path + trainer_fn(model, ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == ckpt_path + def test_disabled_training(tmpdir): """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`.""" From 3162ff7e9e630f7798bc71492fd410de99fbddca Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 10:29:30 +0100 Subject: [PATCH 07/60] Space --- pytorch_lightning/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bcd96017fa06e..54346f2d797e1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -831,6 +831,7 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self.accelerator.connect(model) self.accelerator.setup_environment() self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment + # restore modules after setup self.checkpoint_connector.restore_datamodule() self.checkpoint_connector.restore_model() From 6dd61d6d75e9670433329b4c79f00c2f9116d1eb Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 10:32:27 +0100 Subject: [PATCH 08/60] cleanups --- pytorch_lightning/trainer/trainer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 54346f2d797e1..18f1650e3c123 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -839,6 +839,11 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self.checkpoint_connector.restore_callbacks() if self.ckpt_path: + # only one process running at this point for TPUs, as spawn isn't triggered yet + # todo: move this logic internally within the barrier. + if not self._device_type == DeviceType.TPU: + self.training_type_plugin.barrier() + rank_zero_info(f"Loading checkpoint from {self.ckpt_path}") self.checkpoint_connector.restore_model_weights(self.ckpt_path) @@ -1072,7 +1077,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool) -> Opt fn = self.state.fn.value - if model_provided and ckpt_path is 'best': + if model_provided and ckpt_path == 'best': # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: @@ -1081,12 +1086,12 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool) -> Opt f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.' ) raise MisconfigurationException( - f'`.{fn}(ckpt_path=None)` is set but `ModelCheckpoint` is not configured to save the best model.' + f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' ) # load best weights ckpt_path = self.checkpoint_callback.best_model_path - if ckpt_path is None: + if not ckpt_path: raise MisconfigurationException( f'`.{fn}()` found no path for the best weights: "{ckpt_path}". 
Please' f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' From b07286831eee550de4cb6388d166133f2f37d609 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 10:35:13 +0100 Subject: [PATCH 09/60] Add CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14326c781b118..e3e586e35af7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto'` ([#7808](https://github.com/PyTorchLightning/pytorch-lightning/pull/7808)) +- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))) + + ### Changed From bf5afe3707b1aa44544e4b2f2ffb1668555fbf58 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 11:04:46 +0100 Subject: [PATCH 10/60] Fix --- tests/plugins/test_deepspeed_plugin.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 5cac991b4baec..85fb3a4bfc86c 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -598,18 +598,17 @@ def run_checkpoint_test( assert saved_results == results if automatic_optimization: - model_cls = ModelParallelClassificationModel() + model = ModelParallelClassificationModel() else: - model_cls = ManualModelParallelClassificationModel() + model = ManualModelParallelClassificationModel() if trainer.is_global_zero: trainer = Trainer(default_root_dir=tmpdir, gpus=1, precision=16) - saved_model = model_cls.load_from_checkpoint(ck.best_model_path) - results = trainer.test(saved_model, datamodule=dm) + results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path) assert results[0]['test_acc'] > 0.7 dm.predict_dataloader = dm.test_dataloader - results = trainer.predict(datamodule=dm) + results = trainer.predict(model, datamodule=dm, ckpt_path=ck.best_model_path) assert results[-1] > 0.7 From f2ee8b58391d58208e3ed58c6f2753582019a665 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 11:05:19 +0100 Subject: [PATCH 11/60] Move after setup --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a83b2c3edd2fa..39959e0d4ac25 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -838,6 +838,9 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, # restore callback states self.checkpoint_connector.restore_callbacks() + self._call_configure_sharded_model(model) # allow user to setup in model sharded environment + self.accelerator.setup(self, model) # note: this sets up self.lightning_module + if self.ckpt_path: # only one process running at this point for TPUs, as spawn isn't triggered yet # todo: move this logic internally within the barrier. 
@@ -847,9 +850,6 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, rank_zero_info(f"Loading checkpoint from {self.ckpt_path}") self.checkpoint_connector.restore_model_weights(self.ckpt_path) - self._call_configure_sharded_model(model) # allow user to setup in model sharded environment - self.accelerator.setup(self, model) # note: this sets up self.lightning_module - # ---------------------------- # INSPECT THE CORE LOOPS # ---------------------------- From 86594263a0f1d9044104dbba3039ad52408c54df Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 11:59:37 +0100 Subject: [PATCH 12/60] Cleanups on logic --- pytorch_lightning/trainer/trainer.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 39959e0d4ac25..78e5c3ace4cb4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -579,6 +579,7 @@ def validate( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`') + model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -588,7 +589,7 @@ def validate( # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) + self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None) # run validate results = self._run(model) @@ -652,6 +653,7 @@ def test( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`') + model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -661,7 +663,7 @@ def test( # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) + self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None) # run test results = self._run(model) @@ -721,6 +723,7 @@ def predict( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') + model_provided = model is not None model = model or self.lightning_module if model is None: raise MisconfigurationException( @@ -730,7 +733,7 @@ def predict( # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None) + self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None) results = self._run(model) @@ -1084,13 +1087,22 @@ def _run_sanity_check(self, ref_model): # restore the previous stage when the sanity check if finished self.state.stage = stage - def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool) -> Optional[str]: - if model_provided and ckpt_path is None: - return + def __set_ckpt_path(self, ckpt_path: Optional[str], 
model_provided: bool, model_connected:bool) -> Optional[str]: + """ + If a user passes the model, we want to use this regardless. + If a user passes the model, with best weights ckpt_path being something other than best or None, we load the weight + + """ + + if model_provided and (ckpt_path in ('best', None)): + return # use passed model to function without loading weights + + if model_connected and ckpt_path is None: + return # use connected model without loading weights fn = self.state.fn.value - if model_provided and ckpt_path == 'best': + if model_connected and ckpt_path == 'best': # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: From 84d20f5aca5230221ec6c981b393435a6d194263 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jul 2021 11:00:52 +0000 Subject: [PATCH 13/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 78e5c3ace4cb4..b19551369be28 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -589,7 +589,9 @@ def validate( # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None) + self.validated_ckpt_path = self.__set_ckpt_path( + ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None + ) # run validate results = self._run(model) @@ -663,7 +665,9 @@ def test( # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None) + self.tested_ckpt_path = self.__set_ckpt_path( + ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None + ) # run test results = self._run(model) @@ -733,7 +737,9 @@ def predict( # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None) + self.predicted_ckpt_path = self.__set_ckpt_path( + ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None + ) results = self._run(model) @@ -1087,7 +1093,7 @@ def _run_sanity_check(self, ref_model): # restore the previous stage when the sanity check if finished self.state.stage = stage - def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected:bool) -> Optional[str]: + def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]: """ If a user passes the model, we want to use this regardless. 
If a user passes the model, with best weights ckpt_path being something other than best or None, we load the weight From 9e367fd6be63cade8772b0574297e18539c39b46 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 12:02:55 +0100 Subject: [PATCH 14/60] Remve --- pytorch_lightning/trainer/trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b19551369be28..6db59e7f34554 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1094,12 +1094,6 @@ def _run_sanity_check(self, ref_model): self.state.stage = stage def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]: - """ - If a user passes the model, we want to use this regardless. - If a user passes the model, with best weights ckpt_path being something other than best or None, we load the weight - - """ - if model_provided and (ckpt_path in ('best', None)): return # use passed model to function without loading weights From 3f8c3d344ae79803b24b3f5203fc230886fc1906 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 12:21:28 +0100 Subject: [PATCH 15/60] Remve --- tests/plugins/test_deepspeed_plugin.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 85fb3a4bfc86c..6782d39e80052 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -601,15 +601,10 @@ def run_checkpoint_test( model = ModelParallelClassificationModel() else: model = ManualModelParallelClassificationModel() - if trainer.is_global_zero: - trainer = Trainer(default_root_dir=tmpdir, gpus=1, precision=16) + trainer = Trainer(default_root_dir=tmpdir, gpus=1, precision=16) - results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path) - assert results[0]['test_acc'] > 0.7 - - dm.predict_dataloader = dm.test_dataloader - results = trainer.predict(model, datamodule=dm, ckpt_path=ck.best_model_path) - assert results[-1] > 0.7 + results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path) + assert results[0]['test_acc'] > 0.7 @RunIf(min_gpus=2, deepspeed=True, special=True) From b8ffc39c41b96bd044625210b5121b467632e979 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 12:50:20 +0100 Subject: [PATCH 16/60] fix test --- tests/trainer/test_dataloaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index c4044935f4bd3..b4017a0e344cb 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -782,6 +782,8 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' ): if stage == 'test': + if ckpt_path == 'specific': + trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) else: From b02f35bf29df3e503e20188dee2d05564330408f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 12:57:32 +0100 Subject: [PATCH 17/60] feedback --- pytorch_lightning/trainer/properties.py | 15 ++++++++++++++ pytorch_lightning/trainer/trainer.py | 26 
++++++------------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 54d0079b9255e..a1b7f3326d17d 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -68,6 +68,12 @@ class TrainerProperties(ABC): validate_loop: EvaluationLoop test_loop: EvaluationLoop predict_loop: PredictionLoop + + # .validate() and .test() set this when they load a checkpoint + validated_ckpt_path: str = None + tested_ckpt_path: str = None + predicted_ckpt_path: str = None + """ Accelerator properties """ @@ -570,6 +576,15 @@ def _results(self) -> Optional[ResultCollection]: if active_loop is not None: return active_loop._results + @property + def ckpt_path(self) -> Optional[str]: + if self.state.fn == TrainerFn.VALIDATING: + return self.validated_ckpt_path + if self.state.fn == TrainerFn.TESTING: + return self.tested_ckpt_path + if self.state.fn == TrainerFn.PREDICTING: + return self.predicted_ckpt_path + """ Other """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6db59e7f34554..8a5d391d15fc8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -458,15 +458,10 @@ def _setup_on_init( self.test_dataloaders = None self.val_dataloaders = None - # .validate() and .test() set this when they load a checkpoint - self.validated_ckpt_path = None - self.tested_ckpt_path = None - # when true, print evaluation results in .validate() and .test() self.verbose_evaluate = True self.num_predict_batches = [] - self.predicted_ckpt_path = None def fit( self, @@ -810,15 +805,6 @@ def tune( return result - @property - def ckpt_path(self) -> Optional[str]: - if self.state.fn == TrainerFn.VALIDATING: - return self.validated_ckpt_path - if self.state.fn == TrainerFn.TESTING: - return self.tested_ckpt_path - if self.state.fn == TrainerFn.PREDICTING: - return self.predicted_ckpt_path - def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): @@ -850,14 +836,14 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module - if self.ckpt_path: + if self._ckpt_path: # only one process running at this point for TPUs, as spawn isn't triggered yet # todo: move this logic internally within the barrier. 
if not self._device_type == DeviceType.TPU: self.training_type_plugin.barrier() - rank_zero_info(f"Loading checkpoint from {self.ckpt_path}") - self.checkpoint_connector.restore_model_weights(self.ckpt_path) + rank_zero_info(f"Loading checkpoint from {self._ckpt_path}") + self.checkpoint_connector.restore_model_weights(self._ckpt_path) # ---------------------------- # INSPECT THE CORE LOOPS @@ -1094,12 +1080,12 @@ def _run_sanity_check(self, ref_model): self.state.stage = stage def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]: - if model_provided and (ckpt_path in ('best', None)): - return # use passed model to function without loading weights - if model_connected and ckpt_path is None: return # use connected model without loading weights + if model_provided and ckpt_path in ('best', None): + return # use passed model to function without loading weights + fn = self.state.fn.value if model_connected and ckpt_path == 'best': From dbb03afd257c1bd3da337bccfea5e4204bd97f81 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jul 2021 11:59:17 +0000 Subject: [PATCH 18/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/properties.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index a1b7f3326d17d..ae8f4cc2064f6 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -73,7 +73,6 @@ class TrainerProperties(ABC): validated_ckpt_path: str = None tested_ckpt_path: str = None predicted_ckpt_path: str = None - """ Accelerator properties """ From 1c7b9a10fdf69b9cac837451b4c9956a031eb116 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 12 Jul 2021 13:48:56 +0100 Subject: [PATCH 19/60] Update pytorch_lightning/trainer/properties.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/trainer/properties.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index ae8f4cc2064f6..e159bb241fa17 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -70,9 +70,9 @@ class TrainerProperties(ABC): predict_loop: PredictionLoop # .validate() and .test() set this when they load a checkpoint - validated_ckpt_path: str = None - tested_ckpt_path: str = None - predicted_ckpt_path: str = None + validated_ckpt_path: Optional[str] = None + tested_ckpt_path: Optional[str] = None + predicted_ckpt_path: Optional[str] = None """ Accelerator properties """ From 444fb5582e0927cffda5a4e848313d4d263b0547 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 13:49:23 +0100 Subject: [PATCH 20/60] Feedback --- pytorch_lightning/trainer/properties.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index e159bb241fa17..ca929e71158dd 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -553,6 +553,15 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop if self.predicting: return self.predict_loop + @property + def _ckpt_path(self) -> Optional[str]: + if 
self.state.fn == TrainerFn.VALIDATING: + return self.validated_ckpt_path + if self.state.fn == TrainerFn.TESTING: + return self.tested_ckpt_path + if self.state.fn == TrainerFn.PREDICTING: + return self.predicted_ckpt_path + """ Logging properties """ @@ -575,15 +584,6 @@ def _results(self) -> Optional[ResultCollection]: if active_loop is not None: return active_loop._results - @property - def ckpt_path(self) -> Optional[str]: - if self.state.fn == TrainerFn.VALIDATING: - return self.validated_ckpt_path - if self.state.fn == TrainerFn.TESTING: - return self.tested_ckpt_path - if self.state.fn == TrainerFn.PREDICTING: - return self.predicted_ckpt_path - """ Other """ From 4632bbafc93355c9b4c9318c86da532d0af2c8b7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 14:16:06 +0100 Subject: [PATCH 21/60] Same fix --- tests/trainer/test_dataloaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index b4017a0e344cb..ef634b0e2a55c 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -740,6 +740,8 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' ): if stage == 'test': + if ckpt_path == 'specific': + trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path) else: From e92b7571042deef966f934e59e7ed760c89e6595 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 14:16:17 +0100 Subject: [PATCH 22/60] Same fix --- tests/trainer/test_dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index ef634b0e2a55c..39e3feb7a4fd2 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -741,7 +741,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): ): if stage == 'test': if ckpt_path == 'specific': - trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path) else: From 66bea8e2d9baf36e4366f8d2b1ebea12e4c1b132 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 15:03:57 +0100 Subject: [PATCH 23/60] Add test for behaviour, modify based on feedback --- pytorch_lightning/trainer/trainer.py | 212 ++++++++++++++------------- tests/trainer/test_trainer.py | 15 +- 2 files changed, 123 insertions(+), 104 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8a5d391d15fc8..e623351d93bdc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -85,7 +85,7 @@ # warnings to ignore in trainer warnings.filterwarnings( 'ignore', message='torch.distributed.reduce_op is deprecated, ' - 'please use torch.distributed.ReduceOp instead' + 'please use torch.distributed.ReduceOp instead' ) @@ -102,64 +102,65 @@ class Trainer( @_defaults_from_env_vars def __init__( - self, - logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, - checkpoint_callback: bool = True, - 
callbacks: Optional[Union[List[Callback], Callback]] = None, - default_root_dir: Optional[str] = None, - gradient_clip_val: float = 0.0, - gradient_clip_algorithm: str = 'norm', - process_position: int = 0, - num_nodes: int = 1, - num_processes: int = 1, - gpus: Optional[Union[List[int], str, int]] = None, - auto_select_gpus: bool = False, - tpu_cores: Optional[Union[List[int], str, int]] = None, - ipus: Optional[int] = None, - log_gpu_memory: Optional[str] = None, - progress_bar_refresh_rate: Optional[int] = None, - overfit_batches: Union[int, float] = 0.0, - track_grad_norm: Union[int, float, str] = -1, - check_val_every_n_epoch: int = 1, - fast_dev_run: Union[int, bool] = False, - accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, - max_epochs: Optional[int] = None, - min_epochs: Optional[int] = None, - max_steps: Optional[int] = None, - min_steps: Optional[int] = None, - max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, - limit_train_batches: Union[int, float] = 1.0, - limit_val_batches: Union[int, float] = 1.0, - limit_test_batches: Union[int, float] = 1.0, - limit_predict_batches: Union[int, float] = 1.0, - val_check_interval: Union[int, float] = 1.0, - flush_logs_every_n_steps: int = 100, - log_every_n_steps: int = 50, - accelerator: Optional[Union[str, Accelerator]] = None, - sync_batchnorm: bool = False, - precision: int = 32, - weights_summary: Optional[str] = 'top', - weights_save_path: Optional[str] = None, - num_sanity_val_steps: int = 2, - truncated_bptt_steps: Optional[int] = None, - resume_from_checkpoint: Optional[Union[Path, str]] = None, - profiler: Optional[Union[BaseProfiler, str]] = None, - benchmark: bool = False, - deterministic: bool = False, - reload_dataloaders_every_n_epochs: int = 0, - reload_dataloaders_every_epoch: bool = False, - auto_lr_find: Union[bool, str] = False, - replace_sampler_ddp: bool = True, - terminate_on_nan: bool = False, - auto_scale_batch_size: Union[str, bool] = False, - prepare_data_per_node: bool = True, - plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None, - amp_backend: str = 'native', - amp_level: str = 'O2', - distributed_backend: Optional[str] = None, - move_metrics_to_cpu: bool = False, - multiple_trainloader_mode: str = 'max_size_cycle', - stochastic_weight_avg: bool = False + self, + logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, + checkpoint_callback: bool = True, + callbacks: Optional[Union[List[Callback], Callback]] = None, + default_root_dir: Optional[str] = None, + gradient_clip_val: float = 0.0, + gradient_clip_algorithm: str = 'norm', + process_position: int = 0, + num_nodes: int = 1, + num_processes: int = 1, + gpus: Optional[Union[List[int], str, int]] = None, + auto_select_gpus: bool = False, + tpu_cores: Optional[Union[List[int], str, int]] = None, + ipus: Optional[int] = None, + log_gpu_memory: Optional[str] = None, + progress_bar_refresh_rate: Optional[int] = None, + overfit_batches: Union[int, float] = 0.0, + track_grad_norm: Union[int, float, str] = -1, + check_val_every_n_epoch: int = 1, + fast_dev_run: Union[int, bool] = False, + accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, + max_epochs: Optional[int] = None, + min_epochs: Optional[int] = None, + max_steps: Optional[int] = None, + min_steps: Optional[int] = None, + max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, + limit_train_batches: Union[int, float] = 1.0, + limit_val_batches: Union[int, float] 
= 1.0, + limit_test_batches: Union[int, float] = 1.0, + limit_predict_batches: Union[int, float] = 1.0, + val_check_interval: Union[int, float] = 1.0, + flush_logs_every_n_steps: int = 100, + log_every_n_steps: int = 50, + accelerator: Optional[Union[str, Accelerator]] = None, + sync_batchnorm: bool = False, + precision: int = 32, + weights_summary: Optional[str] = 'top', + weights_save_path: Optional[str] = None, + num_sanity_val_steps: int = 2, + truncated_bptt_steps: Optional[int] = None, + resume_from_checkpoint: Optional[Union[Path, str]] = None, + profiler: Optional[Union[BaseProfiler, str]] = None, + benchmark: bool = False, + deterministic: bool = False, + reload_dataloaders_every_n_epochs: int = 0, + reload_dataloaders_every_epoch: bool = False, + auto_lr_find: Union[bool, str] = False, + replace_sampler_ddp: bool = True, + terminate_on_nan: bool = False, + auto_scale_batch_size: Union[str, bool] = False, + prepare_data_per_node: bool = True, + plugins: Optional[ + Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None, + amp_backend: str = 'native', + amp_level: str = 'O2', + distributed_backend: Optional[str] = None, + move_metrics_to_cpu: bool = False, + multiple_trainloader_mode: str = 'max_size_cycle', + stochastic_weight_avg: bool = False ): r""" Customize every aspect of training via flags @@ -437,8 +438,8 @@ def __init__( self.on_init_end() def _setup_on_init( - self, - num_sanity_val_steps: int, + self, + num_sanity_val_steps: int, ) -> None: self._log_device_info() @@ -464,12 +465,12 @@ def _setup_on_init( self.num_predict_batches = [] def fit( - self, - model: 'pl.LightningModule', - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - train_dataloader=None, # noqa TODO: remove with 1.6 + self, + model: 'pl.LightningModule', + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional[LightningDataModule] = None, + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> None: r""" Runs the full optimization routine. @@ -520,13 +521,13 @@ def fit( self.training = False def validate( - self, - model: Optional['pl.LightningModule'] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = 'best', - verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - val_dataloaders=None, # noqa TODO: remove with 1.6 + self, + model: Optional['pl.LightningModule'] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + val_dataloaders=None, # noqa TODO: remove with 1.6 ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the validation set. 
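For orientation, a minimal runnable sketch of how the `fit`/`validate` entry points whose signatures appear above are typically driven (the tiny model, random data and hyperparameters below are illustrative placeholders, not part of the patch, and assume the package at roughly the revision being patched here):

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class TinyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            self.log("val_loss", torch.nn.functional.cross_entropy(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=8)

    model = TinyModel()
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)  # full optimization routine
    trainer.validate(model, dataloaders=loader)  # one evaluation epoch over the validation set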
@@ -597,13 +598,13 @@ def validate( return results def test( - self, - model: Optional['pl.LightningModule'] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = 'best', - verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - test_dataloaders=None, # noqa TODO: remove with 1.6 + self, + model: Optional['pl.LightningModule'] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + test_dataloaders=None, # noqa TODO: remove with 1.6 ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the test set. It's separated from @@ -673,12 +674,12 @@ def test( return results def predict( - self, - model: Optional['pl.LightningModule'] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - datamodule: Optional[LightningDataModule] = None, - return_predictions: Optional[bool] = None, - ckpt_path: Optional[str] = 'best', + self, + model: Optional['pl.LightningModule'] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, + datamodule: Optional[LightningDataModule] = None, + return_predictions: Optional[bool] = None, + ckpt_path: Optional[str] = 'best', ) -> Optional[_PREDICT_OUTPUT]: r""" @@ -744,14 +745,14 @@ def predict( return results def tune( - self, - model: 'pl.LightningModule', - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, - lr_find_kwargs: Optional[Dict[str, Any]] = None, - train_dataloader=None, # noqa TODO: remove with 1.6 + self, + model: 'pl.LightningModule', + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional[LightningDataModule] = None, + scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, + lr_find_kwargs: Optional[Dict[str, Any]] = None, + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> Dict[str, Optional[Union[int, _LRFinder]]]: r""" Runs routines to tune hyperparameters before training. @@ -1080,15 +1081,21 @@ def _run_sanity_check(self, ref_model): self.state.stage = stage def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]: - if model_connected and ckpt_path is None: - return # use connected model without loading weights - - if model_provided and ckpt_path in ('best', None): + if model_provided and ckpt_path is None: return # use passed model to function without loading weights fn = self.state.fn.value - if model_connected and ckpt_path == 'best': + if model_connected and ckpt_path is None: + rank_zero_warn( + f"`.{fn}(ckpt_path=None)` was called without a model. " + f"The best model of the previous `fit` call will be used. " + f"You can pass `ckpt_path='best'` to avoid this warning " + f"or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model." 
+ ) + ckpt_path = 'best' + + if (model_connected or model_provided) and ckpt_path == 'best': # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: @@ -1188,8 +1195,9 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: return output def _parse_devices( - self, gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, tpu_cores: Optional[Union[List[int], - str, int]] + self, gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, + tpu_cores: Optional[Union[List[int], + str, int]] ) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]: if auto_select_gpus and isinstance(gpus, int): gpus = pick_multiple_gpus(gpus) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 40a0334a6cdfc..06df07b271d6b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -680,14 +680,25 @@ def predict_step(self, batch, *_): if save_top_k == 0: with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): trainer_fn(ckpt_path=ckpt_path) + with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): + trainer_fn(model, ckpt_path=ckpt_path) else: trainer_fn(ckpt_path=ckpt_path) assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path + + trainer_fn(model, ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and - # use the weights from the end of training - trainer_fn(ckpt_path=ckpt_path) + # use the model + trainer_fn(model, ckpt_path=ckpt_path) assert getattr(trainer, path_attr) is None + + if save_top_k > 0: + # ckpt_path is None with no model provided means load the best weights + with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"): + trainer_fn(ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path else: # specific checkpoint, pick one from saved ones if save_top_k == 0: From 0139a19889d72e58110beb11f22c3d147ae36029 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jul 2021 14:05:13 +0000 Subject: [PATCH 24/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 196 +++++++++++++-------------- 1 file changed, 97 insertions(+), 99 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e623351d93bdc..dc9fdaf85e088 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -85,7 +85,7 @@ # warnings to ignore in trainer warnings.filterwarnings( 'ignore', message='torch.distributed.reduce_op is deprecated, ' - 'please use torch.distributed.ReduceOp instead' + 'please use torch.distributed.ReduceOp instead' ) @@ -102,65 +102,64 @@ class Trainer( @_defaults_from_env_vars def __init__( - self, - logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, - checkpoint_callback: bool = True, - callbacks: Optional[Union[List[Callback], Callback]] = None, - default_root_dir: Optional[str] = None, - gradient_clip_val: float = 0.0, - gradient_clip_algorithm: str = 'norm', - process_position: int = 0, - num_nodes: int = 1, - num_processes: int = 1, - gpus: 
Optional[Union[List[int], str, int]] = None, - auto_select_gpus: bool = False, - tpu_cores: Optional[Union[List[int], str, int]] = None, - ipus: Optional[int] = None, - log_gpu_memory: Optional[str] = None, - progress_bar_refresh_rate: Optional[int] = None, - overfit_batches: Union[int, float] = 0.0, - track_grad_norm: Union[int, float, str] = -1, - check_val_every_n_epoch: int = 1, - fast_dev_run: Union[int, bool] = False, - accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, - max_epochs: Optional[int] = None, - min_epochs: Optional[int] = None, - max_steps: Optional[int] = None, - min_steps: Optional[int] = None, - max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, - limit_train_batches: Union[int, float] = 1.0, - limit_val_batches: Union[int, float] = 1.0, - limit_test_batches: Union[int, float] = 1.0, - limit_predict_batches: Union[int, float] = 1.0, - val_check_interval: Union[int, float] = 1.0, - flush_logs_every_n_steps: int = 100, - log_every_n_steps: int = 50, - accelerator: Optional[Union[str, Accelerator]] = None, - sync_batchnorm: bool = False, - precision: int = 32, - weights_summary: Optional[str] = 'top', - weights_save_path: Optional[str] = None, - num_sanity_val_steps: int = 2, - truncated_bptt_steps: Optional[int] = None, - resume_from_checkpoint: Optional[Union[Path, str]] = None, - profiler: Optional[Union[BaseProfiler, str]] = None, - benchmark: bool = False, - deterministic: bool = False, - reload_dataloaders_every_n_epochs: int = 0, - reload_dataloaders_every_epoch: bool = False, - auto_lr_find: Union[bool, str] = False, - replace_sampler_ddp: bool = True, - terminate_on_nan: bool = False, - auto_scale_batch_size: Union[str, bool] = False, - prepare_data_per_node: bool = True, - plugins: Optional[ - Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None, - amp_backend: str = 'native', - amp_level: str = 'O2', - distributed_backend: Optional[str] = None, - move_metrics_to_cpu: bool = False, - multiple_trainloader_mode: str = 'max_size_cycle', - stochastic_weight_avg: bool = False + self, + logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, + checkpoint_callback: bool = True, + callbacks: Optional[Union[List[Callback], Callback]] = None, + default_root_dir: Optional[str] = None, + gradient_clip_val: float = 0.0, + gradient_clip_algorithm: str = 'norm', + process_position: int = 0, + num_nodes: int = 1, + num_processes: int = 1, + gpus: Optional[Union[List[int], str, int]] = None, + auto_select_gpus: bool = False, + tpu_cores: Optional[Union[List[int], str, int]] = None, + ipus: Optional[int] = None, + log_gpu_memory: Optional[str] = None, + progress_bar_refresh_rate: Optional[int] = None, + overfit_batches: Union[int, float] = 0.0, + track_grad_norm: Union[int, float, str] = -1, + check_val_every_n_epoch: int = 1, + fast_dev_run: Union[int, bool] = False, + accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, + max_epochs: Optional[int] = None, + min_epochs: Optional[int] = None, + max_steps: Optional[int] = None, + min_steps: Optional[int] = None, + max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, + limit_train_batches: Union[int, float] = 1.0, + limit_val_batches: Union[int, float] = 1.0, + limit_test_batches: Union[int, float] = 1.0, + limit_predict_batches: Union[int, float] = 1.0, + val_check_interval: Union[int, float] = 1.0, + flush_logs_every_n_steps: int = 100, + log_every_n_steps: int = 50, + accelerator: Optional[Union[str, 
Accelerator]] = None, + sync_batchnorm: bool = False, + precision: int = 32, + weights_summary: Optional[str] = 'top', + weights_save_path: Optional[str] = None, + num_sanity_val_steps: int = 2, + truncated_bptt_steps: Optional[int] = None, + resume_from_checkpoint: Optional[Union[Path, str]] = None, + profiler: Optional[Union[BaseProfiler, str]] = None, + benchmark: bool = False, + deterministic: bool = False, + reload_dataloaders_every_n_epochs: int = 0, + reload_dataloaders_every_epoch: bool = False, + auto_lr_find: Union[bool, str] = False, + replace_sampler_ddp: bool = True, + terminate_on_nan: bool = False, + auto_scale_batch_size: Union[str, bool] = False, + prepare_data_per_node: bool = True, + plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None, + amp_backend: str = 'native', + amp_level: str = 'O2', + distributed_backend: Optional[str] = None, + move_metrics_to_cpu: bool = False, + multiple_trainloader_mode: str = 'max_size_cycle', + stochastic_weight_avg: bool = False ): r""" Customize every aspect of training via flags @@ -438,8 +437,8 @@ def __init__( self.on_init_end() def _setup_on_init( - self, - num_sanity_val_steps: int, + self, + num_sanity_val_steps: int, ) -> None: self._log_device_info() @@ -465,12 +464,12 @@ def _setup_on_init( self.num_predict_batches = [] def fit( - self, - model: 'pl.LightningModule', - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - train_dataloader=None, # noqa TODO: remove with 1.6 + self, + model: 'pl.LightningModule', + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional[LightningDataModule] = None, + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> None: r""" Runs the full optimization routine. @@ -521,13 +520,13 @@ def fit( self.training = False def validate( - self, - model: Optional['pl.LightningModule'] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = 'best', - verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - val_dataloaders=None, # noqa TODO: remove with 1.6 + self, + model: Optional['pl.LightningModule'] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + val_dataloaders=None, # noqa TODO: remove with 1.6 ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the validation set. @@ -598,13 +597,13 @@ def validate( return results def test( - self, - model: Optional['pl.LightningModule'] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = 'best', - verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - test_dataloaders=None, # noqa TODO: remove with 1.6 + self, + model: Optional['pl.LightningModule'] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + test_dataloaders=None, # noqa TODO: remove with 1.6 ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the test set. 
It's separated from @@ -674,12 +673,12 @@ def test( return results def predict( - self, - model: Optional['pl.LightningModule'] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - datamodule: Optional[LightningDataModule] = None, - return_predictions: Optional[bool] = None, - ckpt_path: Optional[str] = 'best', + self, + model: Optional['pl.LightningModule'] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, + datamodule: Optional[LightningDataModule] = None, + return_predictions: Optional[bool] = None, + ckpt_path: Optional[str] = 'best', ) -> Optional[_PREDICT_OUTPUT]: r""" @@ -745,14 +744,14 @@ def predict( return results def tune( - self, - model: 'pl.LightningModule', - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, - lr_find_kwargs: Optional[Dict[str, Any]] = None, - train_dataloader=None, # noqa TODO: remove with 1.6 + self, + model: 'pl.LightningModule', + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional[LightningDataModule] = None, + scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, + lr_find_kwargs: Optional[Dict[str, Any]] = None, + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> Dict[str, Optional[Union[int, _LRFinder]]]: r""" Runs routines to tune hyperparameters before training. @@ -1195,9 +1194,8 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: return output def _parse_devices( - self, gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, - tpu_cores: Optional[Union[List[int], - str, int]] + self, gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, tpu_cores: Optional[Union[List[int], + str, int]] ) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]: if auto_select_gpus and isinstance(gpus, int): gpus = pick_multiple_gpus(gpus) From d48d9167d64f0d1fe27d7965b11128b00bc9e293 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 15:07:38 +0100 Subject: [PATCH 25/60] Wording --- tests/trainer/test_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 06df07b271d6b..3eca8bd64ebdc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -689,8 +689,7 @@ def predict_step(self, batch, *_): trainer_fn(model, ckpt_path=ckpt_path) assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: - # ckpt_path is None, meaning we don't load any checkpoints and - # use the model + # ckpt_path is None, meaning we don't load any checkpoints and use the provided model trainer_fn(model, ckpt_path=ckpt_path) assert getattr(trainer, path_attr) is None From 100d73b2eaba5acb2036485d3a8c4c7ec7d4002c Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 12 Jul 2021 15:32:13 +0100 Subject: [PATCH 26/60] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: Carlos Mocholí --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py 
b/pytorch_lightning/trainer/trainer.py index dc9fdaf85e088..cc79fe0532f2e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -539,7 +539,7 @@ def validate( ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. If ``None``, use the current weights of the model. - When the model is given as argument, we load the ckpt path. + When the model and the ckpt path are passed as arguments, we load the ckpt path. verbose: If True, prints the validation results. @@ -1094,7 +1094,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ ) ckpt_path = 'best' - if (model_connected or model_provided) and ckpt_path == 'best': + if ckpt_path == 'best': # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: From f3f92a50c56711a0cc69591e0166249aa394f5ab Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 15:42:44 +0100 Subject: [PATCH 27/60] Cleanup docs --- pytorch_lightning/trainer/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cc79fe0532f2e..54a6ef64dc176 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -523,7 +523,7 @@ def validate( self, model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = 'best', + ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, val_dataloaders=None, # noqa TODO: remove with 1.6 @@ -538,7 +538,7 @@ def validate( or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. - If ``None``, use the current weights of the model. + If ``None``, use the current weights of the model if provided, or the best model from ``trainer.fit``. When the model and the ckpt path are passed as arguments, we load the ckpt path. verbose: If True, prints the validation results. @@ -600,7 +600,7 @@ def test( self, model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = 'best', + ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, test_dataloaders=None, # noqa TODO: remove with 1.6 @@ -616,8 +616,8 @@ def test( or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the current weights of the model. - When the model is given as argument, we load the ckpt path. + If ``None``, use the current weights of the model if provided, or the best model from ``trainer.fit``. + When the model and the ckpt path are passed as arguments, we load the ckpt path. verbose: If True, prints the test results. 
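For reference, the behaviour described by these docstring updates, sketched with placeholder objects (this assumes a `trainer` that has already completed `fit` with a `ModelCheckpoint` callback configured, and `model` is an in-memory `LightningModule`):

    # uses the current (in-memory) weights of the model that is passed in
    trainer.test(model)

    # no model passed and ckpt_path left as None: the best checkpoint from the
    # previous `fit` call is loaded (a warning points this out)
    trainer.test()

    # validate/test/predict can still load an explicit checkpoint
    trainer.test(ckpt_path="/path/to/my_checkpoint.ckpt")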
@@ -678,7 +678,7 @@ def predict( dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, - ckpt_path: Optional[str] = 'best', + ckpt_path: Optional[str] = None, ) -> Optional[_PREDICT_OUTPUT]: r""" From 2849d0b2ce1a6da32144791dd8ba148799db490f Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 12 Jul 2021 15:49:28 +0100 Subject: [PATCH 28/60] Update pytorch_lightning/trainer/trainer.py Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 54a6ef64dc176..28f1419bd4c99 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1088,9 +1088,9 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ if model_connected and ckpt_path is None: rank_zero_warn( f"`.{fn}(ckpt_path=None)` was called without a model. " - f"The best model of the previous `fit` call will be used. " - f"You can pass `ckpt_path='best'` to avoid this warning " - f"or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model." + "The best model of the previous `fit` call will be used. " + f"You can pass `{fn}(ckpt_path='best')` to avoid this warning " + "or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model." ) ckpt_path = 'best' From f53c896ea633f2a62768e0f860af79ed6891a955 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 15:49:45 +0100 Subject: [PATCH 29/60] feedback --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 28f1419bd4c99..7d1e1aae7c889 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1081,7 +1081,8 @@ def _run_sanity_check(self, ref_model): def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]: if model_provided and ckpt_path is None: - return # use passed model to function without loading weights + # use passed model to function without loading weights + return fn = self.state.fn.value From ebc713b41398a9fc90d4587cb81b739129ebe1fc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 16:31:28 +0100 Subject: [PATCH 30/60] Fixes to test API --- docs/source/common/test_set.rst | 2 +- tests/callbacks/test_callbacks.py | 2 +- tests/models/test_hooks.py | 2 +- tests/models/test_restore.py | 2 +- tests/trainer/flags/test_fast_dev_run.py | 4 ++-- tests/trainer/logging_/test_logger_connector.py | 2 +- tests/trainer/test_dataloaders.py | 14 +++++++------- tests/trainer/test_trainer.py | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/source/common/test_set.rst b/docs/source/common/test_set.rst index 5703d71d956de..54f7f161a205b 100644 --- a/docs/source/common/test_set.rst +++ b/docs/source/common/test_set.rst @@ -23,7 +23,7 @@ To run the test set after training completes, use this method. 
trainer.test() # (2) don't load a checkpoint, instead use the model with the latest weights - trainer.test(ckpt_path=None) + trainer.test(model) # (3) test using a specific checkpoint trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt') diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 57fdd1bf66322..d92441f07ceb0 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -109,6 +109,6 @@ def configure_callbacks(self): callbacks_after = trainer.callbacks.copy() assert callbacks_after == callbacks_after_fit - trainer_fn(ckpt_path=None) + trainer_fn(model) callbacks_after = trainer.callbacks.copy() assert callbacks_after == callbacks_after_fit diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index d89fc090c401f..1c1940abb2796 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -830,7 +830,7 @@ def predict_dataloader(self): trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - trainer.test(ckpt_path=None) + trainer.test(model) preds = trainer.predict(model) assert len(preds) == 2 diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index b1b8e73861ef1..e4f350b4e1c5e 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -364,7 +364,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template): # fit model trainer = Trainer(**trainer_options) trainer.fit(model) - trainer.test(ckpt_path=None) + trainer.test(model) # correct result and ok accuracy assert trainer.state.finished, f"Training failed with {trainer.state}" diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 8320134058c4e..3669fc9ef9499 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -108,7 +108,7 @@ def _make_fast_dev_run_assertions(trainer, model): train_val_step_model = FastDevRunModel() trainer = Trainer(**trainer_config) trainer.fit(train_val_step_model) - trainer.test(ckpt_path=None) + trainer.test(train_val_step_model) assert trainer.state.finished, f"Training failed with {trainer.state}" _make_fast_dev_run_assertions(trainer, train_val_step_model) @@ -121,7 +121,7 @@ def _make_fast_dev_run_assertions(trainer, model): trainer = Trainer(**trainer_config) trainer.fit(train_step_only_model) - trainer.test(ckpt_path=None) + trainer.test(train_step_only_model) assert trainer.state.finished, f"Training failed with {trainer.state}" _make_fast_dev_run_assertions(trainer, train_step_only_model) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 27598b40fbd31..5a353b52e7e74 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -202,7 +202,7 @@ def test_dataloader(self): max_epochs=1, ) trainer.fit(model) - trainer.test(model, ckpt_path=None) + trainer.test(model) def test_can_return_tensor_with_more_than_one_element(tmpdir): diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 39e3feb7a4fd2..6502859976a20 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -470,7 +470,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim assert trainer.num_training_batches == expected_train_batches assert trainer.num_val_batches == expected_val_batches - trainer.test(ckpt_path=None) + trainer.test(model) 
expected_test_batches = [int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders] assert trainer.num_test_batches == expected_test_batches @@ -507,7 +507,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v # ------------------------------------------- assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders) - trainer.test(ckpt_path=None) + trainer.test(model) # when the limit is greater than the number of test batches it should be the num in loaders test_dataloader_lengths = [len(x) for x in model.test_dataloader()] @@ -586,7 +586,7 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): assert trainer.num_training_batches == fast_dev_run assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders) - trainer.test(ckpt_path=None) + trainer.test(model) assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders) # verify sanity check batches match as expected @@ -740,7 +740,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' ): if stage == 'test': - if ckpt_path == 'specific': + if ckpt_path in ('specific', 'best'): trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path) @@ -784,7 +784,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' ): if stage == 'test': - if ckpt_path == 'specific': + if ckpt_path in ('specific', 'best'): trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) @@ -1097,7 +1097,7 @@ def test_dataloader_distributed_sampler(tmpdir): callbacks=[DistribSamplerCallback(expected_seeds=(123, 123, 123))], ) trainer.fit(model) - trainer.test(ckpt_path=None) + trainer.test(model) class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): @@ -1617,7 +1617,7 @@ def predict_dataloader(self): trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - trainer.test(ckpt_path=None) + trainer.test(model) preds = trainer.predict(model) assert len(preds) == 2 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3eca8bd64ebdc..595b4302b55fc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1314,9 +1314,9 @@ def setup(self, model, stage): if stage == "fit": trainer.fit(model) elif stage == "validate": - trainer.validate(model, ckpt_path=None) + trainer.validate(model) else: - trainer.test(model, ckpt_path=None) + trainer.test(model) assert trainer.stage == stage assert trainer.lightning_module.stage == stage From 76e22c268bb0756b3d932a78de6867fd99a2b4cc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 12 Jul 2021 16:40:20 +0100 Subject: [PATCH 31/60] Add carlos description --- pytorch_lightning/trainer/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 
7d1e1aae7c889..3c4d276ff5734 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -538,8 +538,8 @@ def validate( or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. - If ``None``, use the current weights of the model if provided, or the best model from ``trainer.fit``. - When the model and the ckpt path are passed as arguments, we load the ckpt path. + If ``None`` and the model instance was passed, use the current weights. + Otherwise, the best model from the previous ``trainer.fit`` call will be loaded. verbose: If True, prints the validation results. @@ -616,8 +616,8 @@ def test( or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the current weights of the model if provided, or the best model from ``trainer.fit``. - When the model and the ckpt path are passed as arguments, we load the ckpt path. + If ``None`` and the model instance was passed, use the current weights. + Otherwise, the best model from the previous ``trainer.fit`` call will be loaded. verbose: If True, prints the test results. @@ -697,8 +697,8 @@ def predict( ``True`` by default except when an accelerator that spawns processes is used (not supported). ckpt_path: Either ``best`` or path to the checkpoint you wish to predict. - If ``None``, use the current weights of the model. - When the model is given as argument, we load the ckpt path. + If ``None`` and the model instance was passed, use the current weights. + Otherwise, the best model from the previous ``trainer.fit`` call will be loaded. Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. From 0b4622676236789289020e021c939671cff6031b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 13 Jul 2021 10:32:33 +0100 Subject: [PATCH 32/60] Fixes --- .../plugins/training_type/deepspeed.py | 51 +++++++++++++++++-- pytorch_lightning/trainer/trainer.py | 18 +++---- tests/plugins/test_deepspeed_plugin.py | 19 ++++--- 3 files changed, 67 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 1c9d3e4704f1b..7951825885155 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -676,10 +676,6 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING save_dir = self._filepath_to_dir(checkpoint_path) - if self.zero_stage_3: - # TODO: Currently required as this call is missing within the deepspeed engine. 
- self.deepspeed_engine.optimizer._partition_all_parameters() - _, client_state = self.deepspeed_engine.load_checkpoint( save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting ) @@ -688,7 +684,52 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` if self.save_full_weights and self.zero_stage_3: - self.lightning_module.load_state_dict(checkpoint['state_dict']) + self.model_to_device() + self._restore_zero_state(checkpoint) + + def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: + """ + Overrides the normal load_state_dict behaviour in PyTorch to ensure + we gather parameters that may be sharded across processes before loading + the state dictionary when using ZeRO stage 3. + This is then automatically synced across processes. + Args: + ckpt: The ckpt file. + """ + + def load(module: torch.nn.Module, prefix=""): + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + state_dict = ckpt['state_dict'] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if self.is_global_zero: + module._load_from_state_dict( + state_dict=state_dict, + prefix=prefix, + local_metadata=local_metadata, + strict=True, + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + error_msgs=error_msgs + ) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + load(self.lightning_module, prefix="") def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3c4d276ff5734..cf050dab5ea6c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -836,15 +836,6 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module - if self._ckpt_path: - # only one process running at this point for TPUs, as spawn isn't triggered yet - # todo: move this logic internally within the barrier. - if not self._device_type == DeviceType.TPU: - self.training_type_plugin.barrier() - - rank_zero_info(f"Loading checkpoint from {self._ckpt_path}") - self.checkpoint_connector.restore_model_weights(self._ckpt_path) - # ---------------------------- # INSPECT THE CORE LOOPS # ---------------------------- @@ -886,6 +877,15 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, # restore optimizers, etc. 
self.checkpoint_connector.restore_training_state() + if self._ckpt_path: + # only one process running at this point for TPUs, as spawn isn't triggered yet + # todo: move this logic internally within the barrier. + if not self._device_type == DeviceType.TPU: + self.training_type_plugin.barrier() + + rank_zero_info(f"Loading checkpoint from {self._ckpt_path}") + self.checkpoint_connector.restore_model_weights(self._ckpt_path) + # dispatch `start_training` or `start_evaluating` or `start_predicting` self._dispatch() diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 6782d39e80052..4a3ede2af2be8 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -196,7 +196,7 @@ def test_deepspeed_with_invalid_config_path(tmpdir): """ with pytest.raises( - MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist" + MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist" ): DeepSpeedPlugin(config='invalid_path.json') @@ -234,7 +234,7 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): plugins='deepspeed', ) with pytest.raises( - MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.' + MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.' ): trainer.fit(model) @@ -570,7 +570,7 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config def run_checkpoint_test( - tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2 + tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2 ): seed_everything(1) if automatic_optimization: @@ -591,7 +591,7 @@ def run_checkpoint_test( ) trainer.fit(model, datamodule=dm) - results = trainer.test(model, datamodule=dm) + results = trainer.test(ckpt_path='best', datamodule=dm) assert results[0]['test_acc'] > 0.7 saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm) assert saved_results[0]['test_acc'] > 0.7 @@ -601,7 +601,12 @@ def run_checkpoint_test( model = ModelParallelClassificationModel() else: model = ManualModelParallelClassificationModel() - trainer = Trainer(default_root_dir=tmpdir, gpus=1, precision=16) + trainer = Trainer( + default_root_dir=tmpdir, + gpus=2, + plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)], + precision=16 + ) results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path) assert results[0]['test_acc'] > 0.7 @@ -616,7 +621,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): run_checkpoint_test(tmpdir, save_full_weights=False) -@RunIf(min_gpus=2, deepspeed=True, special=False) +@RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir): """ Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, @@ -656,7 +661,7 @@ def __init__(self): self.on_train_batch_start_called = False def on_train_batch_start( - self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int ) -> None: deepspeed_engine = trainer.training_type_plugin.model assert trainer.global_step == deepspeed_engine.global_steps From 8042fb4670b3349ea993816b92191864fc6a9efc Mon Sep 17 00:00:00 
2001 From: SeanNaren Date: Tue, 13 Jul 2021 12:47:58 +0100 Subject: [PATCH 33/60] Changes --- .../plugins/training_type/deepspeed.py | 16 +--- pytorch_lightning/trainer/trainer.py | 15 ++- tests/plugins/test_deepspeed_plugin.py | 95 ++++++++++++++++++- 3 files changed, 106 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 7951825885155..0b2a09e4d3295 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch +from torch.nn import Module import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler @@ -632,9 +633,6 @@ def _create_default_config( cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg} return cfg - def _filepath_to_dir(self, filepath: str) -> str: - return os.path.dirname(filepath) - @property def deepspeed_engine(self): return self.model @@ -659,27 +657,23 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: # Use deepspeed's internal checkpointing function to handle partitioned weights across processes # dump states as a checkpoint dictionary object - save_dir = self._filepath_to_dir(filepath) _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers'] checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} - self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint) + self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint) - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: if self.save_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing checkpoint_path = self.broadcast(checkpoint_path) return super().load_checkpoint_file(checkpoint_path) - # Rely on deepspeed to load the checkpoint and necessary information + # Rely on deepspeed completely to load the checkpoint and necessary information from pytorch_lightning.trainer.states import TrainerFn is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING - save_dir = self._filepath_to_dir(checkpoint_path) - _, client_state = self.deepspeed_engine.load_checkpoint( - save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting + checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting ) - return client_state def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cf050dab5ea6c..ddf45b4319733 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -512,8 +512,6 @@ def fit( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) - self.checkpoint_connector.resume_start() - self._run(model) assert self.state.stopped @@ -827,12 +825,6 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self.accelerator.setup_environment() self._call_setup_hook(model) # allow user to setup 
lightning_module in accelerator environment - # restore modules after setup - self.checkpoint_connector.restore_datamodule() - self.checkpoint_connector.restore_model() - # restore callback states - self.checkpoint_connector.restore_callbacks() - self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module @@ -874,6 +866,13 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() + # restore modules after setup + self.checkpoint_connector.resume_start() + self.checkpoint_connector.restore_datamodule() + self.checkpoint_connector.restore_model() + # restore callback states + self.checkpoint_connector.restore_callbacks() + # restore optimizers, etc. self.checkpoint_connector.restore_training_state() diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 4a3ede2af2be8..aec62ba168c88 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -581,7 +581,6 @@ def run_checkpoint_test( ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) trainer = Trainer( default_root_dir=tmpdir, - progress_bar_refresh_rate=0, max_epochs=10, plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)], gpus=2, @@ -630,6 +629,100 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir): run_checkpoint_test(tmpdir, save_full_weights=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_deepspeed_multigpu_stage_3_full_weights_warns_resume_training(tmpdir): + """ + Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning + that the optimizer state and scheduler states cannot be restored. + """ + model = ModelParallelClassificationModel() + dm = ClassifDataModule() + + ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + plugins=DeepSpeedPlugin(stage=3), + gpus=1, + precision=16, + callbacks=[ck] + ) + trainer.fit(model, datamodule=dm) + model = ModelParallelClassificationModel() + with pytest.warns(UserWarning, match="A single checkpoint file was saved using ZeRO Stage 3. " + "This means optimizer states and scheduler states can not be restored"): + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins='deepspeed_stage_3', + gpus=1, + precision=16, + resume_from_checkpoint=ck.best_model_path + ) + trainer.fit(model, datamodule=dm) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): + """ + Test to ensure with Stage 3 and multiple GPUs that we can resume training if save_full_weights is false. 
+ """ + initial_model = ModelParallelClassificationModel() + dm = ClassifDataModule() + + ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) + initial_trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + plugins=DeepSpeedPlugin(stage=3, save_full_weights=False), + gpus=1, + precision=16, + callbacks=[ck] + ) + initial_trainer.fit(initial_model, datamodule=dm) + + class TestCallback(Callback): + def on_train_batch_start( + self, + trainer: 'pl.Trainer', + pl_module: 'pl.LightningModule', + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin + current_deepspeed_plugin = trainer.accelerator.training_type_plugin + + assert isinstance(original_deepspeed_plugin, DeepSpeedPlugin) + assert isinstance(current_deepspeed_plugin, DeepSpeedPlugin) + # assert optimizer states are the correctly loaded + original_optimizer_dict = original_deepspeed_plugin.deepspeed_engine.optimizer.state_dict() + current_optimizer_dict = current_deepspeed_plugin.deepspeed_engine.optimizer.state_dict() + for orig_tensor, current_tensor in zip(original_optimizer_dict['fp32_flat_groups'], current_optimizer_dict['fp32_flat_groups']): + assert torch.all(orig_tensor.eq(current_tensor)) + # assert model state is loaded correctly + for current_param, initial_param in zip(pl_module.parameters(), initial_model.parameters()): + assert torch.equal(current_param.cpu(), initial_param.cpu()) + + model = ModelParallelClassificationModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins=DeepSpeedPlugin(stage=3, save_full_weights=False), + gpus=1, + precision=16, + resume_from_checkpoint=ck.best_model_path, + callbacks=TestCallback() + ) + trainer.fit(model, datamodule=dm) + + @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir): """ From 203fd4935a467be0c83c05f98e39256565b41bed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jul 2021 11:52:08 +0000 Subject: [PATCH 34/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../plugins/training_type/deepspeed.py | 1 + tests/plugins/test_deepspeed_plugin.py | 52 +++++++++++-------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 0b2a09e4d3295..4645a0e1d0d98 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -723,6 +723,7 @@ def load(module: torch.nn.Module, prefix=""): for name, child in module._modules.items(): if child is not None: load(child, prefix + name + ".") + load(self.lightning_module, prefix="") def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index aec62ba168c88..02d259a1b06a9 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -196,7 +196,7 @@ def test_deepspeed_with_invalid_config_path(tmpdir): """ with pytest.raises( - MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist" + MisconfigurationException, match="You passed in a path 
to a DeepSpeed config but the path does not exist" ): DeepSpeedPlugin(config='invalid_path.json') @@ -234,7 +234,7 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): plugins='deepspeed', ) with pytest.raises( - MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.' + MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.' ): trainer.fit(model) @@ -570,7 +570,7 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config def run_checkpoint_test( - tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2 + tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2 ): seed_everything(1) if automatic_optimization: @@ -652,17 +652,20 @@ def test_deepspeed_multigpu_stage_3_full_weights_warns_resume_training(tmpdir): ) trainer.fit(model, datamodule=dm) model = ModelParallelClassificationModel() - with pytest.warns(UserWarning, match="A single checkpoint file was saved using ZeRO Stage 3. " - "This means optimizer states and scheduler states can not be restored"): - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - plugins='deepspeed_stage_3', - gpus=1, - precision=16, - resume_from_checkpoint=ck.best_model_path - ) - trainer.fit(model, datamodule=dm) + with pytest.warns( + UserWarning, + match="A single checkpoint file was saved using ZeRO Stage 3. " + "This means optimizer states and scheduler states can not be restored" + ): + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins='deepspeed_stage_3', + gpus=1, + precision=16, + resume_from_checkpoint=ck.best_model_path + ) + trainer.fit(model, datamodule=dm) @RunIf(min_gpus=1, deepspeed=True, special=True) @@ -688,14 +691,15 @@ def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): initial_trainer.fit(initial_model, datamodule=dm) class TestCallback(Callback): + def on_train_batch_start( - self, - trainer: 'pl.Trainer', - pl_module: 'pl.LightningModule', - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: + self, + trainer: 'pl.Trainer', + pl_module: 'pl.LightningModule', + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin current_deepspeed_plugin = trainer.accelerator.training_type_plugin @@ -704,7 +708,9 @@ def on_train_batch_start( # assert optimizer states are the correctly loaded original_optimizer_dict = original_deepspeed_plugin.deepspeed_engine.optimizer.state_dict() current_optimizer_dict = current_deepspeed_plugin.deepspeed_engine.optimizer.state_dict() - for orig_tensor, current_tensor in zip(original_optimizer_dict['fp32_flat_groups'], current_optimizer_dict['fp32_flat_groups']): + for orig_tensor, current_tensor in zip( + original_optimizer_dict['fp32_flat_groups'], current_optimizer_dict['fp32_flat_groups'] + ): assert torch.all(orig_tensor.eq(current_tensor)) # assert model state is loaded correctly for current_param, initial_param in zip(pl_module.parameters(), initial_model.parameters()): @@ -754,7 +760,7 @@ def __init__(self): self.on_train_batch_start_called = False def on_train_batch_start( - self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int ) -> None: deepspeed_engine = trainer.training_type_plugin.model assert 
trainer.global_step == deepspeed_engine.global_steps From 8d0f260ebe69e8e8cbfdd625dd825b752a446ce0 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Jul 2021 10:12:35 +0100 Subject: [PATCH 35/60] Try delaying --- pytorch_lightning/accelerators/accelerator.py | 9 +++++++ .../plugins/training_type/deepspeed.py | 4 +++ .../training_type/training_type_plugin.py | 9 +++++++ pytorch_lightning/trainer/trainer.py | 25 ++++++++++++------- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 5d86c54028b6e..1d93e7f9b49c1 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -529,6 +529,15 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: """ return self.training_type_plugin.setup_optimizers_in_pre_dispatch + @property + def restore_checkpoint_after_pre_dispatch(self) -> bool: + """ + Override to delay restoring from checkpoint till after predispatch. + This is useful when the plugin requires all the setup hooks to run before loading checkpoint. + Returns: If true, restore checkpoint after pre_dispatch. + """ + return self.training_type_plugin.restore_checkpoint_after_pre_dispatch + def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 4645a0e1d0d98..2986672b7de7e 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -342,6 +342,10 @@ def setup_distributed(self): self._format_config() self._config_initialized = True + @property + def restore_checkpoint_after_pre_dispatch(self) -> bool: + return True + def pre_dispatch(self): self.init_deepspeed() self.barrier() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index e49d170a93d66..4407e8a020226 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -235,6 +235,15 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: """ return False + @property + def restore_checkpoint_after_pre_dispatch(self) -> bool: + """ + Override to delay restoring from checkpoint till after predispatch. + This is useful when the plugin requires all the setup hooks to run before loading checkpoint. + Returns: If true, restore checkpoint after pre_dispatch. + """ + return False + def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: """ Provide a hook to count optimizer step calls. 
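For orientation: the hook added above is generic, so any training type plugin (not just DeepSpeed) can opt into delayed checkpoint restoration. A minimal sketch follows, assuming a custom DDP-based plugin; the subclass name is hypothetical and only illustrates overriding the property introduced in this patch.

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin


class DelayedRestoreDDPPlugin(DDPPlugin):
    """Hypothetical plugin whose setup hooks must run before checkpoint weights are restored."""

    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        # Ask the Trainer to run the checkpoint-restore path only after pre_dispatch(),
        # i.e. once the wrapped model has been fully set up on each process.
        return True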
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ddf45b4319733..a6dfbe3373f9b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -803,6 +803,14 @@ def tune( return result + def _restore_training(self) -> None: + # restore modules after setup + self.checkpoint_connector.resume_start() + self.checkpoint_connector.restore_datamodule() + self.checkpoint_connector.restore_model() + # restore callback states + self.checkpoint_connector.restore_callbacks() + def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): @@ -825,6 +833,12 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self.accelerator.setup_environment() self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment + if not self.accelerator.restore_checkpoint_after_pre_dispatch: + self._restore_training() + + # restore optimizers, etc. + self.checkpoint_connector.restore_training_state() + self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module @@ -866,15 +880,8 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() - # restore modules after setup - self.checkpoint_connector.resume_start() - self.checkpoint_connector.restore_datamodule() - self.checkpoint_connector.restore_model() - # restore callback states - self.checkpoint_connector.restore_callbacks() - - # restore optimizers, etc. - self.checkpoint_connector.restore_training_state() + if self.accelerator.restore_checkpoint_after_pre_dispatch: + self._restore_training() if self._ckpt_path: # only one process running at this point for TPUs, as spawn isn't triggered yet From 28d7575e7fbcd389c70b4c742bd404878e49e8be Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 27 Jul 2021 14:53:13 +0100 Subject: [PATCH 36/60] Fixes --- .../plugins/training_type/deepspeed.py | 43 +++++++------ .../training_type/training_type_plugin.py | 5 ++ .../connectors/checkpoint_connector.py | 3 +- pytorch_lightning/trainer/trainer.py | 5 +- tests/plugins/test_deepspeed_plugin.py | 63 +++++++++++-------- 5 files changed, 69 insertions(+), 50 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 2986672b7de7e..70d3578ffb244 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -20,7 +20,6 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch -from torch.nn import Module import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler @@ -34,7 +33,9 @@ from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE -from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn +from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn, WarningCache + +warning_cache = WarningCache() if _DEEPSPEED_AVAILABLE: import deepspeed @@ -119,7 +120,7 @@ def __init__( 
cpu_checkpointing: bool = False, contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, - save_full_weights: bool = True, + load_full_weights: bool = False, cpu_offload: bool = False, cpu_offload_params: bool = False, cpu_offload_use_pin_memory: bool = False, @@ -250,7 +251,7 @@ def __init__( synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. - save_full_weights: Gathers weights across all processes before saving to disk + load_full_weights: Gathers weights across all processes before saving to disk when using ZeRO Stage 3. This allows a single weight file to contain the entire model, rather than individual sharded weight files. Disable to save sharded states individually. @@ -314,7 +315,7 @@ def __init__( deepspeed.utils.logging.logger.setLevel(logging_level) self.remote_device = remote_device - self.save_full_weights = save_full_weights + self.load_full_weights = load_full_weights # default FP16 parameters. self.loss_scale = loss_scale @@ -641,6 +642,10 @@ def _create_default_config( def deepspeed_engine(self): return self.model + @property + def _multi_device(self) -> bool: + return self.num_processes > 1 or self.num_nodes > 1 + def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -648,17 +653,12 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: checkpoint: The checkpoint state dictionary filepath: write-target file's path """ - if self.save_full_weights and self.zero_stage_3: - # todo (sean): expose this as general function in deepspeed - state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict() - if self.is_global_zero: - # State dict keys will include reference to wrapper LightningDeepSpeedModule - # Delete `module` prefix before saving. - state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()} - checkpoint['state_dict'] = state_dict - return super().save_checkpoint(checkpoint, filepath) - return - + if self.zero_stage_3 and self._multi_device: + warning_cache.warn( + 'When saving the DeepSpeed Stage 3 checkpoint, ' + 'each worker will save a shard of the checkpoint within a directory.' + 'If a single file is required after training, see for instructions.' 
+ ) # Use deepspeed's internal checkpointing function to handle partitioned weights across processes # dump states as a checkpoint dictionary object _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers'] @@ -666,7 +666,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint) def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: - if self.save_full_weights and self.zero_stage_3: + if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing checkpoint_path = self.broadcast(checkpoint_path) @@ -678,10 +678,15 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Di _, client_state = self.deepspeed_engine.load_checkpoint( checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting ) + return client_state + + @property + def lightning_restore_optimizer_and_schedulers(self) -> bool: + return False # managed by DeepSpeed def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` - if self.save_full_weights and self.zero_stage_3: + if self.load_full_weights and self.zero_stage_3: self.model_to_device() self._restore_zero_state(checkpoint) @@ -732,7 +737,7 @@ def load(module: torch.nn.Module, prefix=""): def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` - if self.save_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: rank_zero_warn( "A single checkpoint file was saved using ZeRO Stage 3. This means optimizer states and " "scheduler states can not be restored. If you'd like to restore these states, you must" diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 4407e8a020226..af79520aed260 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -134,6 +134,11 @@ def lightning_module(self) -> 'pl.LightningModule': """Returns the pure LightningModule without potential wrappers""" return unwrap_lightning_module(self._model) + @property + def lightning_restore_optimizer_and_schedulers(self) -> bool: + """Whether to allow Lightning to restore optimizers/schedulers.""" + return True + @property def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: """ diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ab74c3bccfc8d..c6a6a0afab938 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -211,7 +211,8 @@ def restore_progress(self) -> None: def restore_optimizers_and_schedulers(self) -> None: """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. 
""" - if not self._loaded_checkpoint: + if not self._loaded_checkpoint or \ + not self.trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers: return # validation diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b3348b1e123e4..a20d966297fb2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -810,6 +810,8 @@ def _restore_training(self) -> None: self.checkpoint_connector.restore_model() # restore callback states self.checkpoint_connector.restore_callbacks() + # restore optimizers, etc. + self.checkpoint_connector.restore_training_state() def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams @@ -836,9 +838,6 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, if not self.accelerator.restore_checkpoint_after_pre_dispatch: self._restore_training() - # restore optimizers, etc. - self.checkpoint_connector.restore_training_state() - self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 02d259a1b06a9..8d80a407ecb68 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -569,9 +569,7 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim) -def run_checkpoint_test( - tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2 -): +def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumulate_grad_batches: int = 2): seed_everything(1) if automatic_optimization: model = ModelParallelClassificationModel() @@ -582,7 +580,7 @@ def run_checkpoint_test( trainer = Trainer( default_root_dir=tmpdir, max_epochs=10, - plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)], + plugins=[DeepSpeedPlugin(stage=3)], gpus=2, precision=16, accumulate_grad_batches=accumulate_grad_batches, @@ -600,12 +598,7 @@ def run_checkpoint_test( model = ModelParallelClassificationModel() else: model = ManualModelParallelClassificationModel() - trainer = Trainer( - default_root_dir=tmpdir, - gpus=2, - plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)], - precision=16 - ) + trainer = Trainer(default_root_dir=tmpdir, gpus=2, plugins=[DeepSpeedPlugin(stage=3)], precision=16) results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path) assert results[0]['test_acc'] > 0.7 @@ -617,16 +610,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, and see convergence. """ - run_checkpoint_test(tmpdir, save_full_weights=False) - - -@RunIf(min_gpus=2, deepspeed=True, special=True) -def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir): - """ - Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, - where we save the full weights to one file. 
- """ - run_checkpoint_test(tmpdir, save_full_weights=True) + run_checkpoint_test(tmpdir) @RunIf(min_gpus=1, deepspeed=True, special=True) @@ -668,10 +652,33 @@ def test_deepspeed_multigpu_stage_3_full_weights_warns_resume_training(tmpdir): trainer.fit(model, datamodule=dm) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, special=False) +def test_deepspeed_multigpu_stage_3_save_warning(tmpdir): + """ + Test to ensure with Stage 3 and multiple GPUs that we recieve a warning that we're saving sharded checkpoints. + """ + initial_model = ModelParallelClassificationModel() + dm = ClassifDataModule() + + ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) + initial_trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + plugins=DeepSpeedPlugin(stage=3), + gpus=1, + precision=16, + callbacks=[ck] + ) + initial_trainer.fit(initial_model, datamodule=dm) + + +@RunIf(min_gpus=1, deepspeed=True, special=False) def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): """ - Test to ensure with Stage 3 and multiple GPUs that we can resume training if save_full_weights is false. + Test to ensure with Stage 3 and multiple GPUs that we can resume training. """ initial_model = ModelParallelClassificationModel() dm = ClassifDataModule() @@ -683,7 +690,7 @@ def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, - plugins=DeepSpeedPlugin(stage=3, save_full_weights=False), + plugins=DeepSpeedPlugin(stage=3), gpus=1, precision=16, callbacks=[ck] @@ -694,8 +701,8 @@ class TestCallback(Callback): def on_train_batch_start( self, - trainer: 'pl.Trainer', - pl_module: 'pl.LightningModule', + trainer: Trainer, + pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int, @@ -715,12 +722,14 @@ def on_train_batch_start( # assert model state is loaded correctly for current_param, initial_param in zip(pl_module.parameters(), initial_model.parameters()): assert torch.equal(current_param.cpu(), initial_param.cpu()) + # assert epoch has correctly been restored + assert trainer.current_epoch == 1 model = ModelParallelClassificationModel() trainer = Trainer( default_root_dir=tmpdir, fast_dev_run=True, - plugins=DeepSpeedPlugin(stage=3, save_full_weights=False), + plugins=DeepSpeedPlugin(stage=3), gpus=1, precision=16, resume_from_checkpoint=ck.best_model_path, @@ -735,7 +744,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir): Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, where we save the full weights to one file. 
""" - run_checkpoint_test(tmpdir, save_full_weights=True, automatic_optimization=False, accumulate_grad_batches=1) + run_checkpoint_test(tmpdir, automatic_optimization=False, accumulate_grad_batches=1) @RunIf(min_gpus=2, deepspeed=True, special=True) From a3c600909da18825dc44606a9a7dbc0a20946243 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jul 2021 13:59:13 +0000 Subject: [PATCH 37/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../plugins/training_type/deepspeed.py | 15 +++++----- .../connectors/checkpoint_connector.py | 6 ++-- pytorch_lightning/trainer/trainer.py | 4 +-- tests/plugins/test_deepspeed_plugin.py | 30 ++++++++----------- tests/trainer/test_dataloaders.py | 4 +-- 5 files changed, 27 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 39ac39a8df614..d86781b19a15f 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -16,7 +16,6 @@ import logging import os import platform - from collections import OrderedDict from pathlib import Path from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union @@ -680,13 +679,13 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: """ if self.zero_stage_3 and self._multi_device: warning_cache.warn( - 'When saving the DeepSpeed Stage 3 checkpoint, ' - 'each worker will save a shard of the checkpoint within a directory.' - 'If a single file is required after training, see for instructions.' + "When saving the DeepSpeed Stage 3 checkpoint, " + "each worker will save a shard of the checkpoint within a directory." + "If a single file is required after training, see for instructions." 
) # Use deepspeed's internal checkpointing function to handle partitioned weights across processes # dump states as a checkpoint dictionary object - _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers'] + _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint) @@ -731,10 +730,10 @@ def load(module: torch.nn.Module, prefix=""): missing_keys = [] unexpected_keys = [] error_msgs = [] - state_dict = ckpt['state_dict'] + state_dict = ckpt["state_dict"] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata @@ -752,7 +751,7 @@ def load(module: torch.nn.Module, prefix=""): strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, - error_msgs=error_msgs + error_msgs=error_msgs, ) for name, child in module._modules.items(): diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 51db7cd1eca8e..8365f589aa8b3 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -221,8 +221,10 @@ def restore_loops(self) -> None: def restore_optimizers_and_schedulers(self) -> None: """Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint.""" - if not self._loaded_checkpoint or \ - not self.trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers: + if ( + not self._loaded_checkpoint + or not self.trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers + ): return # validation diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5a65d9c3726a1..931070451cc62 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1152,9 +1152,9 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ f"You can pass `{fn}(ckpt_path='best')` to avoid this warning " "or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model." 
) - ckpt_path = 'best' + ckpt_path = "best" - if ckpt_path == 'best': + if ckpt_path == "best": # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 74e63fbbf374e..01f74f7c2bcf2 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -470,7 +470,7 @@ def configure_optimizers(self): return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - if not hasattr(self, 'model'): + if not hasattr(self, "model"): self.configure_sharded_model() @@ -542,7 +542,7 @@ def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumu trainer.fit(model, datamodule=dm) results = trainer.test(model, datamodule=dm) - assert results[0]['test_acc'] > 0.7 + assert results[0]["test_acc"] > 0.7 saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm) assert saved_results[0]["test_acc"] > 0.7 assert saved_results == results @@ -554,7 +554,7 @@ def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumu trainer = Trainer(default_root_dir=tmpdir, gpus=2, plugins=[DeepSpeedPlugin(stage=3)], precision=16) results = trainer.test(model, datamodule=dm) - assert results[0]['test_acc'] > 0.7 + assert results[0]["test_acc"] > 0.7 @RunIf(min_gpus=2, deepspeed=True, special=True) @@ -585,22 +585,22 @@ def test_deepspeed_multigpu_stage_3_full_weights_warns_resume_training(tmpdir): plugins=DeepSpeedPlugin(stage=3), gpus=1, precision=16, - callbacks=[ck] + callbacks=[ck], ) trainer.fit(model, datamodule=dm) model = ModelParallelClassificationModel() with pytest.warns( UserWarning, match="A single checkpoint file was saved using ZeRO Stage 3. 
" - "This means optimizer states and scheduler states can not be restored" + "This means optimizer states and scheduler states can not be restored", ): trainer = Trainer( default_root_dir=tmpdir, fast_dev_run=True, - plugins='deepspeed_stage_3', + plugins="deepspeed_stage_3", gpus=1, precision=16, - resume_from_checkpoint=ck.best_model_path + resume_from_checkpoint=ck.best_model_path, ) trainer.fit(model, datamodule=dm) @@ -623,7 +623,7 @@ def test_deepspeed_multigpu_stage_3_save_warning(tmpdir): plugins=DeepSpeedPlugin(stage=3), gpus=1, precision=16, - callbacks=[ck] + callbacks=[ck], ) initial_trainer.fit(initial_model, datamodule=dm) @@ -646,19 +646,13 @@ def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): plugins=DeepSpeedPlugin(stage=3), gpus=1, precision=16, - callbacks=[ck] + callbacks=[ck], ) initial_trainer.fit(initial_model, datamodule=dm) class TestCallback(Callback): - def on_train_batch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - batch: Any, - batch_idx: int, - dataloader_idx: int, + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int ) -> None: original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin current_deepspeed_plugin = trainer.accelerator.training_type_plugin @@ -669,7 +663,7 @@ def on_train_batch_start( original_optimizer_dict = original_deepspeed_plugin.deepspeed_engine.optimizer.state_dict() current_optimizer_dict = current_deepspeed_plugin.deepspeed_engine.optimizer.state_dict() for orig_tensor, current_tensor in zip( - original_optimizer_dict['fp32_flat_groups'], current_optimizer_dict['fp32_flat_groups'] + original_optimizer_dict["fp32_flat_groups"], current_optimizer_dict["fp32_flat_groups"] ): assert torch.all(orig_tensor.eq(current_tensor)) # assert model state is loaded correctly @@ -686,7 +680,7 @@ def on_train_batch_start( gpus=1, precision=16, resume_from_checkpoint=ck.best_model_path, - callbacks=TestCallback() + callbacks=TestCallback(), ) trainer.fit(model, datamodule=dm) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 91ad7f94e6827..93b6a1a9288c9 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -685,7 +685,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers', ): if stage == "test": - if ckpt_path in ('specific', 'best'): + if ckpt_path in ("specific", "best"): trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path) @@ -724,7 +724,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers', ): if stage == "test": - if ckpt_path in ('specific', 'best'): + if ckpt_path in ("specific", "best"): trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) From c51033aa23b3b4ce467f6ced503ef1a9c8636427 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 28 Jul 2021 11:36:16 +0100 Subject: [PATCH 38/60] fixes --- pytorch_lightning/trainer/trainer.py | 29 
++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7d9d8a2ef9383..c4d9874bf9b10 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -838,15 +838,13 @@ def tune( return result - def _restore_training(self) -> None: + def _restore_checkpoint(self) -> None: # restore modules after setup self.checkpoint_connector.resume_start() self.checkpoint_connector.restore_datamodule() self.checkpoint_connector.restore_model() # restore callback states self.checkpoint_connector.restore_callbacks() - # restore optimizers, etc. - self.checkpoint_connector.restore_training_state() def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams @@ -862,14 +860,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.data_connector.prepare_data(model) self.callback_connector._attach_model_callbacks(model, self) - if self._ckpt_path: - # only one process running at this point for TPUs, as spawn isn't triggered yet - # todo: move this logic internally within the barrier. - if not self._device_type == DeviceType.TPU: - self.training_type_plugin.barrier() - - rank_zero_info(f"Loading checkpoint from {self._ckpt_path}") - self.checkpoint_connector.restore_model_weights(self._ckpt_path) + if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch: + self._load_checkpoint_weights() # ---------------------------- # SET UP TRAINING @@ -880,7 +872,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment if not self.accelerator.restore_checkpoint_after_pre_dispatch: - self._restore_training() + self._restore_checkpoint() self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module @@ -922,6 +914,11 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() + + if self.accelerator.restore_checkpoint_after_pre_dispatch: + self._load_checkpoint_weights() + self._restore_checkpoint() + # restore optimizers, etc. self.checkpoint_connector.restore_training_state() @@ -947,6 +944,14 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, return self.accelerator.results + def _load_checkpoint_weights(self): + # only one process running at this point for TPUs, as spawn isn't triggered yet + # todo: move this logic internally within the barrier. 
+ if not self._device_type == DeviceType.TPU: + self.training_type_plugin.barrier() + rank_zero_info(f"Loading checkpoint from {self._ckpt_path}") + self.checkpoint_connector.restore_model_weights(self._ckpt_path) + def _pre_dispatch(self): self.accelerator.pre_dispatch(self) self._log_hyperparams() From 4f5bd960fddd161f0753d8796d16db72c7bfd67e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 28 Jul 2021 12:28:44 +0100 Subject: [PATCH 39/60] Add extra condition --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c4d9874bf9b10..4a91b17ac13a3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -871,7 +871,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.accelerator.setup_environment() self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment - if not self.accelerator.restore_checkpoint_after_pre_dispatch: + if not self.accelerator.restore_checkpoint_after_pre_dispatch and self.state.fn == TrainerFn.FITTING: self._restore_checkpoint() self._call_configure_sharded_model(model) # allow user to setup in model sharded environment @@ -915,7 +915,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() - if self.accelerator.restore_checkpoint_after_pre_dispatch: + if self.accelerator.restore_checkpoint_after_pre_dispatch and self.state.fn == TrainerFn.FITTING: self._load_checkpoint_weights() self._restore_checkpoint() From e1fb2f04f350f751011aeedbfa4932ff8e9fdad2 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 28 Jul 2021 12:55:44 +0100 Subject: [PATCH 40/60] Fix --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4a91b17ac13a3..9842e1c923204 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -916,7 +916,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self._pre_dispatch() if self.accelerator.restore_checkpoint_after_pre_dispatch and self.state.fn == TrainerFn.FITTING: - self._load_checkpoint_weights() + if self._ckpt_path: + self._load_checkpoint_weights() self._restore_checkpoint() # restore optimizers, etc. From 77036a2f3b2039ca520b0f22386c34cbb51c8e6d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 28 Jul 2021 15:21:44 +0100 Subject: [PATCH 41/60] Fix --- tests/plugins/test_deepspeed_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 334ead4b24f2d..b80088ebf7579 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -610,7 +610,7 @@ def test_deepspeed_multigpu_stage_3_full_weights_warns_resume_training(tmpdir): trainer.fit(model, datamodule=dm) -@RunIf(min_gpus=1, deepspeed=True, special=False) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_save_warning(tmpdir): """ Test to ensure with Stage 3 and multiple GPUs that we recieve a warning that we're saving sharded checkpoints. 
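As a rough illustration of the user-facing behaviour this test covers, a hedged sketch follows; the path is illustrative and `model` is assumed to be a LightningModule with its dataloaders configured elsewhere.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

trainer = Trainer(gpus=2, precision=16, plugins=DeepSpeedPlugin(stage=3))
trainer.fit(model)

# With ZeRO Stage 3 on more than one device, the path below is written as a directory
# containing one shard per worker rather than a single file, and the UserWarning
# introduced earlier in this series is emitted.
trainer.save_checkpoint("stage3_checkpoint.ckpt")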
@@ -633,7 +633,7 @@ def test_deepspeed_multigpu_stage_3_save_warning(tmpdir): initial_trainer.fit(initial_model, datamodule=dm) -@RunIf(min_gpus=1, deepspeed=True, special=False) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): """ Test to ensure with Stage 3 and multiple GPUs that we can resume training. From 82e00be99c676f08313875b6fe54e3267c592edf Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 28 Jul 2021 17:27:53 +0100 Subject: [PATCH 42/60] Attempt to fix tests --- .../plugins/training_type/deepspeed.py | 7 ++--- tests/plugins/test_deepspeed_plugin.py | 30 +++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 948f8eb7ffda8..f5d11ee71e2b8 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -766,10 +766,9 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: rank_zero_warn( - "A single checkpoint file was saved using ZeRO Stage 3. This means optimizer states and " - "scheduler states can not be restored. If you'd like to restore these states, you must" - "set save_full_weights=False, i.e Trainer(plugins=DeepSpeedPlugin(save_full_weights=False)) " - "when training the model initially." + "A single checkpoint file has been given. This means optimizer states and " + "scheduler states can not be restored. If you'd like to restore these states, you must " + "provide a path to the originally saved DeepSpeed checkpoint." ) def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index b80088ebf7579..e506d67a3f8ac 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -6,6 +6,7 @@ import pytest import torch import torch.nn.functional as F +from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from torch import nn, Tensor from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -329,8 +330,6 @@ def on_train_start(self, trainer, pl_module) -> None: trainer.fit(model) trainer.test(model) - _assert_save_model_is_equal(model, tmpdir, trainer) - @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_custom_precision_params(tmpdir): @@ -396,15 +395,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module) -> None: @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu(tmpdir, deepspeed_config): """ - Test to ensure that DeepSpeed with multiple GPUs works, without ZeRO Optimization as this requires compilation. + Test to ensure that DeepSpeed with multiple GPUs works. 
""" model = BoringModel() trainer = Trainer( - default_root_dir=tmpdir, - plugins=[DeepSpeedPlugin(zero_optimization=False, stage=2)], - gpus=2, - fast_dev_run=True, - precision=16, + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=2)], gpus=2, fast_dev_run=True, precision=16 ) trainer.fit(model) trainer.test(model) @@ -508,7 +503,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): trainer.fit(model) trainer.test(model) - _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel) + _assert_save_model_is_equal(model, tmpdir, trainer) @RunIf(min_gpus=2, deepspeed=True, special=True) @@ -524,7 +519,7 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config trainer.fit(model) trainer.test(model) - _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim) + _assert_save_model_is_equal(model, tmpdir, trainer) def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumulate_grad_batches: int = 2): @@ -782,17 +777,22 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat assert os.environ["LOCAL_RANK"] == str(trainer.training_type_plugin.local_rank) -def _assert_save_model_is_equal(model, tmpdir, trainer, cls=BoringModel): +def _assert_save_model_is_equal(model, tmpdir, trainer): + checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) + + single_ckpt_path = os.path.join(tmpdir, "single_model.pt") + # carry out the check only on rank 0 if trainer.global_rank == 0: - saved_model = cls.load_from_checkpoint(checkpoint_path) + convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path) + state_dict = torch.load(single_ckpt_path) if model.dtype == torch.half: - saved_model = saved_model.half() # model is loaded in float32 as default, move it to float16 + model = model.float() # moved model to float32 for comparison with single fp32 saved weights model = model.cpu() # Assert model parameters are identical after loading - for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): + for orig_param, trained_model_param in zip(model.parameters(), state_dict.values()): assert torch.equal(orig_param, trained_model_param) @@ -807,4 +807,4 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir): ) trainer.fit(model) - _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelNoSchedulers) + _assert_save_model_is_equal(model, tmpdir, trainer) From 57355aafab8dc0dd245b9e299d733b3c13cf375f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 28 Jul 2021 17:53:48 +0100 Subject: [PATCH 43/60] Add guard --- tests/plugins/test_deepspeed_plugin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index e506d67a3f8ac..b6ff0ed9e0949 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -6,7 +6,6 @@ import pytest import torch import torch.nn.functional as F -from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from torch import nn, Tensor from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -17,10 +16,14 @@ from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException +from 
pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf +if _DEEPSPEED_AVAILABLE: + from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict + class ModelParallelBoringModel(BoringModel): def __init__(self): From 3fc8f67a01cb46ce54949869352c51f585fc6e60 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 29 Jul 2021 10:56:22 +0100 Subject: [PATCH 44/60] Fix test --- tests/plugins/test_deepspeed_plugin.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index b6ff0ed9e0949..f673c0d6b9d0c 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -781,22 +781,24 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat def _assert_save_model_is_equal(model, tmpdir, trainer): - checkpoint_path = os.path.join(tmpdir, "model.pt") + checkpoint_path = trainer.accelerator.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) - - single_ckpt_path = os.path.join(tmpdir, "single_model.pt") + trainer.accelerator.barrier() # carry out the check only on rank 0 - if trainer.global_rank == 0: + if trainer.is_global_zero: + single_ckpt_path = os.path.join(tmpdir, "single_model.pt") convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path) state_dict = torch.load(single_ckpt_path) - if model.dtype == torch.half: - model = model.float() # moved model to float32 for comparison with single fp32 saved weights + model = model.cpu() # Assert model parameters are identical after loading - for orig_param, trained_model_param in zip(model.parameters(), state_dict.values()): - assert torch.equal(orig_param, trained_model_param) + for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()): + if model.dtype == torch.half: + # moved model to float32 for comparison with single fp32 saved weights + saved_model_param = saved_model_param.half() + assert torch.equal(orig_param, saved_model_param) @RunIf(min_gpus=2, deepspeed=True, special=True) From 6adb83dec57a21357c41673d651ed52628be9674 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 29 Jul 2021 12:27:06 +0100 Subject: [PATCH 45/60] Fix --- tests/plugins/test_deepspeed_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index f673c0d6b9d0c..8b23a14dcdfe0 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -402,7 +402,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config): """ model = BoringModel() trainer = Trainer( - default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=2)], gpus=2, fast_dev_run=True, precision=16 + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 ) trainer.fit(model) trainer.test(model) From 607aef2e01955e7b0e20e25955bbed4325242ecd Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 29 Jul 2021 12:47:33 +0100 Subject: [PATCH 46/60] Add test --- .../plugins/training_type/deepspeed.py | 5 ++++ pytorch_lightning/trainer/trainer.py | 5 ++-- tests/plugins/test_deepspeed_plugin.py | 29 ++++++++++++++++++- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git 
a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index f5d11ee71e2b8..469c57e9e18cf 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -705,6 +705,11 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Di _, client_state = self.deepspeed_engine.load_checkpoint( checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting ) + if client_state is None: + raise MisconfigurationException( + "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint " + "or a single checkpoint file with Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))." + ) return client_state @property diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9842e1c923204..af25c634bb6fe 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -915,10 +915,11 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() - if self.accelerator.restore_checkpoint_after_pre_dispatch and self.state.fn == TrainerFn.FITTING: + if self.accelerator.restore_checkpoint_after_pre_dispatch: if self._ckpt_path: self._load_checkpoint_weights() - self._restore_checkpoint() + if self.state.fn == TrainerFn.FITTING: + self._restore_checkpoint() # restore optimizers, etc. self.checkpoint_connector.restore_training_state() diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 8b23a14dcdfe0..d50d1fe64be89 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -396,7 +396,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module) -> None: @RunIf(min_gpus=2, deepspeed=True, special=True) -def test_deepspeed_multigpu(tmpdir, deepspeed_config): +def test_deepspeed_multigpu(tmpdir): """ Test to ensure that DeepSpeed with multiple GPUs works. """ @@ -417,6 +417,33 @@ def test_deepspeed_fp32_works(tmpdir): trainer.fit(model) +@RunIf(min_gpus=1, deepspeed=True, special=False) +def test_deepspeed_multigpu_single_file(tmpdir): + """ + Test to ensure that DeepSpeed with multiple GPUs works, loading from a single file checkpoint. 
+ """ + model = BoringModel() + checkpoint_path = os.path.join(tmpdir, "model.pt") + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model) + trainer.save_checkpoint(checkpoint_path) + + trainer = Trainer( + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16 + ) + with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."): + trainer.test(model, ckpt_path=checkpoint_path) + + trainer = Trainer( + default_root_dir=tmpdir, + plugins=[DeepSpeedPlugin(stage=3, load_full_weights=True)], + gpus=1, + fast_dev_run=True, + precision=16, + ) + trainer.test(model, ckpt_path=checkpoint_path) + + class ModelParallelClassificationModel(LightningModule): def __init__(self, lr: float = 0.01, num_blocks: int = 5): super().__init__() From 0c30656589ca3dc9778b68a05b4aff64335327d3 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 29 Jul 2021 12:47:59 +0100 Subject: [PATCH 47/60] Update pytorch_lightning/plugins/training_type/deepspeed.py --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 469c57e9e18cf..70d262d24922d 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -698,7 +698,7 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Di checkpoint_path = self.broadcast(checkpoint_path) return super().load_checkpoint_file(checkpoint_path) - # Rely on deepspeed completely to load the checkpoint and necessary information + # Rely on deepspeed to load the checkpoint and necessary information from pytorch_lightning.trainer.states import TrainerFn is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING From c9849e0f203308a1e0c59b7f854066578bbf623a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 29 Jul 2021 12:49:20 +0100 Subject: [PATCH 48/60] Fix description --- pytorch_lightning/plugins/training_type/deepspeed.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 70d262d24922d..f5b9218f36f90 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -253,10 +253,9 @@ def __init__( synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. - load_full_weights: Gathers weights across all processes before saving to disk - when using ZeRO Stage 3. This allows a single weight file to contain the entire model, - rather than individual sharded weight files. - Disable to save sharded states individually. + load_full_weights: True when loading a single checkpoint file containing the model state dict + when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards + per worker. 
""" if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( From 0d3866c47b98b948c353082683dd90b2d67af615 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 29 Jul 2021 12:56:25 +0100 Subject: [PATCH 49/60] Add test --- .../plugins/training_type/deepspeed.py | 2 +- tests/plugins/test_deepspeed_plugin.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index f5b9218f36f90..e8b5f60ee8d6c 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -678,7 +678,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: checkpoint: The checkpoint state dictionary filepath: write-target file's path """ - if self.zero_stage_3 and self._multi_device: + if self.zero_stage_3 and self._multi_device and self.is_global_zero: warning_cache.warn( "When saving the DeepSpeed Stage 3 checkpoint, " "each worker will save a shard of the checkpoint within a directory." diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index d50d1fe64be89..8cbceec23c502 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -417,7 +417,22 @@ def test_deepspeed_fp32_works(tmpdir): trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=False) +@RunIf(min_gpus=2, deepspeed=True, special=True) +def test_deepspeed_stage_3_save_warning(tmpdir): + """ + Test to ensure that DeepSpeed Stage 3 gives a warning when saving. + """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 + ) + trainer.fit(model) + checkpoint_path = os.path.join(tmpdir, "model.pt") + with pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory."): + trainer.save_checkpoint(checkpoint_path) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_single_file(tmpdir): """ Test to ensure that DeepSpeed with multiple GPUs works, loading from a single file checkpoint. 
From fd7a16835a9620792fb1fea3a5cb6f794e0a6b2d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 29 Jul 2021 13:31:59 +0100 Subject: [PATCH 50/60] Fix test --- tests/plugins/test_deepspeed_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 8cbceec23c502..fbb7d358210c6 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -586,7 +586,7 @@ def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumu ) trainer.fit(model, datamodule=dm) - results = trainer.test(model, datamodule=dm) + results = trainer.test(datamodule=dm) assert results[0]["test_acc"] > 0.7 saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm) assert saved_results[0]["test_acc"] > 0.7 @@ -598,7 +598,7 @@ def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumu model = ManualModelParallelClassificationModel() trainer = Trainer(default_root_dir=tmpdir, gpus=2, plugins=[DeepSpeedPlugin(stage=3)], precision=16) - results = trainer.test(model, datamodule=dm) + results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path) assert results[0]["test_acc"] > 0.7 From 256b145765d23a6982e85b98f56a353498d20c3e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 29 Jul 2021 14:43:14 +0100 Subject: [PATCH 51/60] Refactors --- .../plugins/training_type/deepspeed.py | 16 +++++---- tests/plugins/test_deepspeed_plugin.py | 35 +++++++------------ 2 files changed, 22 insertions(+), 29 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index e8b5f60ee8d6c..84030f532c5d9 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -713,7 +713,14 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Di @property def lightning_restore_optimizer_and_schedulers(self) -> bool: - return False # managed by DeepSpeed + # managed by DeepSpeed + if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + rank_zero_warn( + "A single checkpoint file has been given. This means optimizer states and " + "scheduler states can not be restored. If you'd like to restore these states, you must " + "provide a path to the originally saved DeepSpeed checkpoint." + ) + return False def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` @@ -768,12 +775,7 @@ def load(module: torch.nn.Module, prefix=""): def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` - if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: - rank_zero_warn( - "A single checkpoint file has been given. This means optimizer states and " - "scheduler states can not be restored. If you'd like to restore these states, you must " - "provide a path to the originally saved DeepSpeed checkpoint." 
- ) + pass def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: if self._original_accumulate_grad_batches is None: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index fbb7d358210c6..b6d4bb79efe2b 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -611,42 +611,33 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): run_checkpoint_test(tmpdir) -@RunIf(min_gpus=1, deepspeed=True, special=True) -def test_deepspeed_multigpu_stage_3_full_weights_warns_resume_training(tmpdir): +@RunIf(min_gpus=1, deepspeed=True, special=False) +def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): """ Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the optimizer state and scheduler states cannot be restored. """ - model = ModelParallelClassificationModel() dm = ClassifDataModule() + model = BoringModel() + checkpoint_path = os.path.join(tmpdir, "model.pt") + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model) + trainer.save_checkpoint(checkpoint_path) - ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - plugins=DeepSpeedPlugin(stage=3), + fast_dev_run=True, + plugins=DeepSpeedPlugin(stage=3, load_full_weights=True), gpus=1, precision=16, - callbacks=[ck], + resume_from_checkpoint=checkpoint_path, ) - trainer.fit(model, datamodule=dm) - model = ModelParallelClassificationModel() with pytest.warns( UserWarning, - match="A single checkpoint file was saved using ZeRO Stage 3. " - "This means optimizer states and scheduler states can not be restored", + match="A single checkpoint file has been given. This means optimizer states and " + "scheduler states can not be restored. 
If you'd like to restore these states, you must " + "provide a path to the originally saved DeepSpeed checkpoint.", ): - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - plugins="deepspeed_stage_3", - gpus=1, - precision=16, - resume_from_checkpoint=ck.best_model_path, - ) trainer.fit(model, datamodule=dm) From c18959537ac22fb8e4b54af5849bcffb4c8f5908 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 29 Jul 2021 15:15:40 +0100 Subject: [PATCH 52/60] add recursive --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 0928410e01107..1c8f3ba24857e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -525,7 +525,7 @@ def save_function(self, value: Optional[Callable]) -> None: def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None: if trainer.should_rank_save_checkpoint and self._fs.exists(filepath): - self._fs.rm(filepath) + self._fs.rm(filepath, recursive=True) log.debug(f"Removed checkpoint: {filepath}") def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None: From 0d2ec03752f538afa764c12f8095972b9c86c1ea Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 2 Aug 2021 10:19:07 +0100 Subject: [PATCH 53/60] Fix dupe --- .../plugins/training_type/training_type_plugin.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 0b9ae25825f0c..66d098223464d 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -134,11 +134,6 @@ def lightning_module(self) -> "pl.LightningModule": """Returns the pure LightningModule without potential wrappers""" return unwrap_lightning_module(self._model) - @property - def lightning_restore_optimizer_and_schedulers(self) -> bool: - """Whether to allow Lightning to restore optimizers/schedulers.""" - return True - @property def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: """ From 5329c48f7211b81d541a0634a8ac4dc54b8c9c56 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 2 Aug 2021 12:13:19 +0100 Subject: [PATCH 54/60] Force 0.4.3 --- .azure-pipelines/gpu-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 5a7bcff3bb69f..31555ea1db503 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -51,7 +51,7 @@ jobs: - bash: | python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)" pip install fairscale>=0.3.4 - pip install "deepspeed>=0.4.0, !=0.4.4" # FIXME: bug with 0.4.4 + pip install "deepspeed>=0.4.3, !=0.4.4" # FIXME: bug with 0.4.4 pip install . 
--requirement requirements/devel.txt pip list displayName: 'Install dependencies' From 95d1287020e322c00c55c4cff493b918964ef1ec Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 2 Aug 2021 14:07:16 +0100 Subject: [PATCH 55/60] Address reviews --- CHANGELOG.md | 8 ++++---- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- .../plugins/training_type/training_type_plugin.py | 8 ++++++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a9fed5b154d9..fec1997f721c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,7 +88,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613)) -- +- Fixed save/load/resume from checkpoint for DeepSpeed Plugin ( + [#8397](https://github.com/PyTorchLightning/pytorch-lightning/pull/8397), + [#8644](https://github.com/PyTorchLightning/pytorch-lightning/pull/8644), + [#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627)) ## [1.4.0] - 2021-07-27 @@ -168,9 +171,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for providing callables to the Lightning CLI instead of types ([#8400](https://github.com/PyTorchLightning/pytorch-lightning/pull/8400)) -- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))) - - ### Changed - Decoupled device parsing logic from Accelerator connector to Trainer ([#8180](https://github.com/PyTorchLightning/pytorch-lightning/pull/8180)) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 84030f532c5d9..78535eb400ece 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -707,7 +707,7 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Di if client_state is None: raise MisconfigurationException( "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint " - "or a single checkpoint file with Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))." + "or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`." ) return client_state diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 66d098223464d..04a529fab0f9c 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -231,7 +231,9 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: Override to delay setting optimizers and schedulers till after dispatch. This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. However this may break certain precision plugins such as APEX which require optimizers to be set. - Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. + + Returns: + If True, delay setup optimizers till pre_dispatch, else call within setup. """ return False @@ -240,7 +242,9 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool: """ Override to delay restoring from checkpoint till after predispatch. 
This is useful when the plugin requires all the setup hooks to run before loading checkpoint. - Returns: If true, restore checkpoint after pre_dispatch. + + Returns: + If true, restore checkpoint after pre_dispatch. """ return False From 88ab3068ef6572bf815fae4f1dbd88ec4349d0b5 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 2 Aug 2021 14:11:31 +0100 Subject: [PATCH 56/60] Add todo --- pytorch_lightning/plugins/training_type/deepspeed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 78535eb400ece..f633d792fb5e7 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -679,6 +679,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: filepath: write-target file's path """ if self.zero_stage_3 and self._multi_device and self.is_global_zero: + # todo (sean): Add link to docs once docs are merged. warning_cache.warn( "When saving the DeepSpeed Stage 3 checkpoint, " "each worker will save a shard of the checkpoint within a directory." From a15cd8d04a4353a996c3acda31c3814e52491180 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 2 Aug 2021 14:12:44 +0100 Subject: [PATCH 57/60] Update pytorch_lightning/plugins/training_type/training_type_plugin.py --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 04a529fab0f9c..a8b444de0bd27 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -240,7 +240,7 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: @property def restore_checkpoint_after_pre_dispatch(self) -> bool: """ - Override to delay restoring from checkpoint till after predispatch. + Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading checkpoint. Returns: From 9365cc0260d4766818b7322c154f211950bac956 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 2 Aug 2021 16:11:35 +0100 Subject: [PATCH 58/60] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/plugins/training_type/deepspeed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index f633d792fb5e7..0cfa5ad8aabe3 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -682,7 +682,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: # todo (sean): Add link to docs once docs are merged. warning_cache.warn( "When saving the DeepSpeed Stage 3 checkpoint, " - "each worker will save a shard of the checkpoint within a directory." + "each worker will save a shard of the checkpoint within a directory. " "If a single file is required after training, see for instructions." 
) # Use deepspeed's internal checkpointing function to handle partitioned weights across processes @@ -735,6 +735,7 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: we gather parameters that may be sharded across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced across processes. + Args: ckpt: The ckpt file. """ From 5f994c4d4af3a8a044081bfbc9ef7af46842a71b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 2 Aug 2021 15:25:24 +0100 Subject: [PATCH 59/60] Add asserts for properties, address reviews --- .../plugins/training_type/deepspeed.py | 2 +- tests/plugins/test_deepspeed_plugin.py | 29 ++++--------------- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 0cfa5ad8aabe3..ee023e4d40b3b 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -735,7 +735,7 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: we gather parameters that may be sharded across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced across processes. - + Args: ckpt: The ckpt file. """ diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index eeb3f6b1e93f7..3a1f93cae7bb4 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -446,6 +446,9 @@ def test_deepspeed_multigpu_single_file(tmpdir): trainer = Trainer( default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16 ) + plugin = trainer.training_type_plugin + assert isinstance(plugin, DeepSpeedPlugin) + assert not plugin.load_full_weights with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."): trainer.test(model, ckpt_path=checkpoint_path) @@ -456,6 +459,9 @@ def test_deepspeed_multigpu_single_file(tmpdir): fast_dev_run=True, precision=16, ) + plugin = trainer.training_type_plugin + assert isinstance(plugin, DeepSpeedPlugin) + assert plugin.load_full_weights trainer.test(model, ckpt_path=checkpoint_path) @@ -641,29 +647,6 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): trainer.fit(model, datamodule=dm) -@RunIf(min_gpus=1, deepspeed=True, special=True) -def test_deepspeed_multigpu_stage_3_save_warning(tmpdir): - """ - Test to ensure with Stage 3 and multiple GPUs that we recieve a warning that we're saving sharded checkpoints. 
- """ - initial_model = ModelParallelClassificationModel() - dm = ClassifDataModule() - - ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) - initial_trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - plugins=DeepSpeedPlugin(stage=3), - gpus=1, - precision=16, - callbacks=[ck], - ) - initial_trainer.fit(initial_model, datamodule=dm) - - @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): """ From cdf8c25c8d28f770d3f371bded0b5247553e5bed Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 2 Aug 2021 16:54:58 +0100 Subject: [PATCH 60/60] Fix description --- tests/plugins/test_deepspeed_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 3a1f93cae7bb4..f0c1d7d49b586 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -435,7 +435,7 @@ def test_deepspeed_stage_3_save_warning(tmpdir): @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_single_file(tmpdir): """ - Test to ensure that DeepSpeed with multiple GPUs works, loading from a single file checkpoint. + Test to ensure that DeepSpeed loads from a single file checkpoint. """ model = BoringModel() checkpoint_path = os.path.join(tmpdir, "model.pt")