Remove deprecated prepare_data_per_node in Trainer (#12536)
* remove deprecated prepare_data_per_node from Trainer
* remove deprecated test for prepare_data_per_node
* remove doc for deprecated prepare_data_per_node
* remove inconsistency test
* remove deprecated prepare_data_per_node
* remove doc mentioning Trainer(prepare_data_per_node)
* update changelog
* remove unused code

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
4 people committed Apr 2, 2022
1 parent 63f97e2 commit 56a3485
Showing 8 changed files with 15 additions and 93 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -66,7 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `progress_bar_refresh_rate` argument from the `Trainer` constructor ([#12514](https://github.com/PyTorchLightning/pytorch-lightning/pull/12514))


-
- Removed the deprecated `prepare_data_per_node` argument from the `Trainer` constructor ([#12536](https://github.com/PyTorchLightning/pytorch-lightning/pull/12536))


-
33 changes: 0 additions & 33 deletions docs/source/common/trainer.rst
@@ -1171,39 +1171,6 @@ To define your own behavior, subclass the relevant class and pass it in. Here's
trainer = Trainer(plugins=[MyCluster()], ...)
prepare_data_per_node
^^^^^^^^^^^^^^^^^^^^^
.. warning:: ``prepare_data_per_node`` has been deprecated in v1.5 and will be removed in v1.7.
Please set its value directly on ``LightningDataModule`` and/or ``LightningModule``, as shown
in the following code:

.. testcode::

class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/prepare_data_per_node.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/prepare_data_per_node.mp4"></video>

|
If set to ``True``, ``prepare_data()`` is called on LOCAL_RANK=0 of every node.
If set to ``False``, it is called only on NODE_RANK=0, LOCAL_RANK=0.

.. testcode::

# default
Trainer(prepare_data_per_node=True)

# use only NODE_RANK=0, LOCAL_RANK=0
Trainer(prepare_data_per_node=False)

precision
^^^^^^^^^

3 changes: 1 addition & 2 deletions docs/source/starter/core_guide.rst
@@ -596,8 +596,7 @@ will cause all sorts of issues.
To solve this problem, make sure your download code is in the ``prepare_data`` method in the DataModule.
In this method we do all the preparation we need to do once (instead of on every GPU).

``prepare_data`` can be called in two ways, once per node or only on the root node
(``Trainer(prepare_data_per_node=False)``).
``prepare_data`` can be called in two ways, once per node or only on the root node.

.. code-block:: python
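As a rough illustration of the pattern this paragraph describes (download once in ``prepare_data``, assign per-process state in ``setup``), here is a sketch with illustrative dataset names, not code from this commit:

```python
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from pytorch_lightning import LightningDataModule


class MNISTDataModule(LightningDataModule):
    def prepare_data(self):
        # Runs once per node (or once per cluster, depending on
        # self.prepare_data_per_node): a safe place for downloads.
        MNIST("./data", train=True, download=True)

    def setup(self, stage=None):
        # Runs on every process: a safe place to assign state.
        dataset = MNIST("./data", train=True, transform=transforms.ToTensor())
        self.train_set, self.val_set = random_split(dataset, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=32)
```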
17 changes: 12 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -361,21 +361,27 @@ def prepare_data(self):
self.split = data_split
self.some_state = some_other_state()
In DDP ``prepare_data`` can be called in two ways (using Trainer(prepare_data_per_node)):
In a distributed environment, ``prepare_data`` can be called in two ways
(using :ref:`prepare_data_per_node<common/lightning_module:prepare_data_per_node>`)
1. Once per node. This is the default and is only called on LOCAL_RANK=0.
2. Once in total. Only called on GLOBAL_RANK=0.
See :ref:`prepare_data_per_node<common/lightning_module:prepare_data_per_node>`.
Example::
# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
Trainer(prepare_data_per_node=True)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True
# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False
This is called before requesting the dataloaders:
@@ -387,6 +393,7 @@ def prepare_data(self):
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
"""

def setup(self, stage: Optional[str] = None) -> None:
27 changes: 1 addition & 26 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -39,7 +39,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import PossibleUserWarning, WarningCache

@@ -73,18 +73,9 @@ def on_trainer_init(
self,
check_val_every_n_epoch: int,
reload_dataloaders_every_n_epochs: int,
prepare_data_per_node: Optional[bool] = None,
) -> None:
self.trainer.datamodule = None

if prepare_data_per_node is not None:
rank_zero_deprecation(
"Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0 and will be removed in"
" v1.7.0. Please set `prepare_data_per_node` in `LightningDataModule` and/or `LightningModule`"
" directly instead."
)
self.trainer.prepare_data_per_node = prepare_data_per_node

if not isinstance(check_val_every_n_epoch, int):
raise MisconfigurationException(
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
@@ -112,28 +103,12 @@ def prepare_data(self) -> None:
# check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
if datamodule is not None:
dm_prepare_data_per_node = datamodule.prepare_data_per_node
dm_eq_prepare_data = datamodule.prepare_data_per_node == self.trainer.prepare_data_per_node
if self.trainer.prepare_data_per_node is not None and not dm_eq_prepare_data:
raise MisconfigurationException(
"Inconsistent settings found for `prepare_data_per_node`."
f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`"
f" and `DataModule.prepare_data_per_node={datamodule.prepare_data_per_node}`."
" Move `prepare_data_per_node` setting to DataModule property."
)
if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
self.trainer.datamodule.prepare_data()
# handle lightning module prepare data:
# check for prepare_data_per_node before calling lightning_module.prepare_data
if lightning_module is not None:
lm_prepare_data_per_node = lightning_module.prepare_data_per_node
lm_eq_prepare_data = lightning_module.prepare_data_per_node == self.trainer.prepare_data_per_node
if (self.trainer.prepare_data_per_node is not None) and not lm_eq_prepare_data:
raise MisconfigurationException(
"Inconsistent settings found for `prepare_data_per_node`."
f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`"
f" and `LightningModule.prepare_data_per_node={lightning_module.prepare_data_per_node}`."
" Move `prepare_data_per_node` setting to LightningModule property."
)
if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
self.trainer._call_lightning_module_hook("prepare_data")
self.trainer._is_data_prepared = True
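With the Trainer-level flag gone, the dispatch that remains reduces to a rank check on the module's own ``prepare_data_per_node`` attribute, which is why the two ``MisconfigurationException`` branches above could be deleted. A simplified sketch of that gating (not the exact source; the helper name is made up):

```python
def _should_call_prepare_data(prepare_data_per_node: bool, local_rank: int, node_rank: int) -> bool:
    """Mirrors the rank condition kept in DataConnector.prepare_data (simplified)."""
    local_rank_zero = local_rank == 0
    global_rank_zero = local_rank == 0 and node_rank == 0
    if prepare_data_per_node:
        # Default: one call per node, on that node's LOCAL_RANK=0 process.
        return local_rank_zero
    # Shared-filesystem setups: a single call for the whole cluster.
    return global_rank_zero
```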
10 changes: 0 additions & 10 deletions pytorch_lightning/trainer/trainer.py
@@ -181,7 +181,6 @@ def __init__(
replace_sampler_ddp: bool = True,
detect_anomaly: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: Optional[bool] = None,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
amp_backend: str = "native",
amp_level: Optional[str] = None,
@@ -314,14 +313,6 @@ def __init__(
log_every_n_steps: How often to log within steps.
Default: ``50``.
prepare_data_per_node: If True, each LOCAL_RANK=0 will call ``prepare_data``.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will call ``prepare_data``.
.. deprecated:: v1.5
Deprecated in v1.5.0 and will be removed in v1.7.0
Please set ``prepare_data_per_node`` in ``LightningDataModule`` and/or
``LightningModule`` directly instead.
process_position: Orders the progress bar when running multiple models on same machine.
.. deprecated:: v1.5
@@ -542,7 +533,6 @@ def __init__(
self._data_connector.on_trainer_init(
check_val_every_n_epoch,
reload_dataloaders_every_n_epochs,
prepare_data_per_node,
)

if terminate_on_nan is not None:
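For callers migrating off the removed constructor argument, the replacement is to set the property on the module itself, as the deprecation notice above suggested. A minimal migration sketch (the class name is illustrative):

```python
import pytorch_lightning as pl


class MyDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        # Replaces the removed Trainer(prepare_data_per_node=...) flag:
        #   True  -> prepare_data() runs on LOCAL_RANK=0 of every node (default)
        #   False -> prepare_data() runs only on NODE_RANK=0, LOCAL_RANK=0
        self.prepare_data_per_node = False


# The Trainer no longer accepts the flag; it reads the property from the
# LightningDataModule / LightningModule instead.
trainer = pl.Trainer()
# trainer.fit(model, datamodule=MyDataModule())
```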
11 changes: 0 additions & 11 deletions tests/core/test_datamodules.py
@@ -487,14 +487,3 @@ class BoringDataModule2(LightningDataModule):
assert hasattr(BoringDataModule2, "__repr__")
assert BoringDataModule2(batch_size=32).prepare_data() is None
assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)


def test_inconsistent_prepare_data_per_node(tmpdir):
    with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
        model = BoringModel()
        dm = BoringDataModule()
        with pytest.deprecated_call(match="prepare_data_per_node` with the trainer flag is deprecated"):
            trainer = Trainer(prepare_data_per_node=False)
        trainer.model = model
        trainer.datamodule = dm
        trainer._data_connector.prepare_data()
5 changes: 0 additions & 5 deletions tests/deprecated_api/test_remove_1-7.py
@@ -125,11 +125,6 @@ def get_progress_bar_dict(self):
_ = trainer.progress_bar_dict


def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
    with pytest.deprecated_call(match="Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0"):
        _ = Trainer(prepare_data_per_node=False)


@pytest.mark.parametrize("terminate_on_nan", [True, False])
def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
    with pytest.deprecated_call(
