Remove deprecated prepare_data_per_node in Trainer #12536

Merged
17 commits merged on Apr 2, 2022
Changes from 15 commits
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `progress_bar_refresh_rate` argument from the `Trainer` constructor ([#12514](https://github.com/PyTorchLightning/pytorch-lightning/pull/12514))


-
- Removed the deprecated `prepare_data_per_node` flag from the `Trainer` constructor ([#12536](https://github.com/PyTorchLightning/pytorch-lightning/pull/12536))


-
33 changes: 0 additions & 33 deletions docs/source/common/trainer.rst
@@ -1171,39 +1171,6 @@ To define your own behavior, subclass the relevant class and pass it in. Here's

trainer = Trainer(plugins=[MyCluster()], ...)


prepare_data_per_node
^^^^^^^^^^^^^^^^^^^^^
.. warning:: ``prepare_data_per_node`` has been deprecated in v1.5 and will be removed in v1.7.
Please set its value inside ``LightningDataModule`` and/or ``LightningModule`` directly described
in the following code:

.. testcode::

class LitDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self.prepare_data_per_node = True

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/prepare_data_per_node.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/prepare_data_per_node.mp4"></video>

|

If set to ``True`` will call ``prepare_data()`` on LOCAL_RANK=0 for every node.
If set to ``False`` will only call from NODE_RANK=0, LOCAL_RANK=0.

.. testcode::

# default
Trainer(prepare_data_per_node=True)

# use only NODE_RANK=0, LOCAL_RANK=0
Trainer(prepare_data_per_node=False)

precision
^^^^^^^^^

3 changes: 1 addition & 2 deletions docs/source/starter/core_guide.rst
@@ -596,8 +596,7 @@ will cause all sorts of issues.
To solve this problem, make sure your download code is in the ``prepare_data`` method in the DataModule.
In this method we do all the preparation we need to do once (instead of on every GPU).

``prepare_data`` can be called in two ways, once per node or only on the root node
(``Trainer(prepare_data_per_node=False)``).
``prepare_data`` can be called in two ways, once per node or only on the root node.

.. code-block:: python

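Since the guide now describes the two modes without pointing at the removed Trainer flag, here is a minimal sketch of how the per-node versus root-node choice is expressed on the DataModule itself. This is not part of this diff; the class name, dataset, and paths are illustrative placeholders.

.. code-block:: python

    from pytorch_lightning import LightningDataModule
    from torchvision.datasets import MNIST


    class MNISTDataModule(LightningDataModule):
        def __init__(self, data_dir: str = "./data"):
            super().__init__()
            self.data_dir = data_dir
            # True (default): prepare_data() runs on LOCAL_RANK=0 of every node.
            # False: it runs only on NODE_RANK=0, LOCAL_RANK=0 (e.g. shared filesystems).
            self.prepare_data_per_node = False

        def prepare_data(self):
            # download only; no state should be assigned here
            MNIST(self.data_dir, train=True, download=True)
            MNIST(self.data_dir, train=False, download=True)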
17 changes: 12 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -361,21 +361,27 @@ def prepare_data(self):
self.split = data_split
self.some_state = some_other_state()

In DDP ``prepare_data`` can be called in two ways (using Trainer(prepare_data_per_node)):
In DDP ``prepare_data`` can be called in two ways
(using :ref:`prepare_data_per_node<common/lightning_module:prepare_data_per_node>`)

1. Once per node. This is the default and is only called on LOCAL_RANK=0.
2. Once in total. Only called on GLOBAL_RANK=0.

See :ref:`prepare_data_per_node<common/lightning_module:prepare_data_per_node>`.

Example::

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
Trainer(prepare_data_per_node=True)
class LitDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)
class LitDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self.prepare_data_per_node = False

This is called before requesting the dataloaders:

@@ -387,6 +393,7 @@ def prepare_data(self):
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
"""

def setup(self, stage: Optional[str] = None) -> None:
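To make the division of labour in this docstring concrete: ``prepare_data`` runs on a subset of ranks, so it must not set state, while ``setup`` runs on every process. A hedged sketch of the same pattern on a ``LightningModule``, which honours the same ``prepare_data_per_node`` attribute (class and field names are illustrative, not from this PR):

.. code-block:: python

    from typing import Optional

    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            # same switch as on a DataModule; True is the default
            self.prepare_data_per_node = True

        def prepare_data(self):
            # download / tokenize / write caches to disk; runs on selected ranks only
            ...

        def setup(self, stage: Optional[str] = None):
            # runs on every process: the safe place to assign state
            self.some_state = "built per process"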
27 changes: 1 addition & 26 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -39,7 +39,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import PossibleUserWarning, WarningCache

@@ -73,18 +73,9 @@ def on_trainer_init(
self,
check_val_every_n_epoch: int,
reload_dataloaders_every_n_epochs: int,
prepare_data_per_node: Optional[bool] = None,
) -> None:
self.trainer.datamodule = None

if prepare_data_per_node is not None:
rank_zero_deprecation(
"Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0 and will be removed in"
" v1.7.0. Please set `prepare_data_per_node` in `LightningDataModule` and/or `LightningModule`"
" directly instead."
)
self.trainer.prepare_data_per_node = prepare_data_per_node

if not isinstance(check_val_every_n_epoch, int):
raise MisconfigurationException(
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
@@ -112,28 +103,12 @@ def prepare_data(self) -> None:
# check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
if datamodule is not None:
dm_prepare_data_per_node = datamodule.prepare_data_per_node
dm_eq_prepare_data = datamodule.prepare_data_per_node == self.trainer.prepare_data_per_node
if self.trainer.prepare_data_per_node is not None and not dm_eq_prepare_data:
raise MisconfigurationException(
"Inconsistent settings found for `prepare_data_per_node`."
f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`"
f" and `DataModule.prepare_data_per_node={datamodule.prepare_data_per_node}`."
" Move `prepare_data_per_node` setting to DataModule property."
)
if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
self.trainer.datamodule.prepare_data()
# handle lightning module prepare data:
# check for prepare_data_per_node before calling lightning_module.prepare_data
if lightning_module is not None:
lm_prepare_data_per_node = lightning_module.prepare_data_per_node
lm_eq_prepare_data = lightning_module.prepare_data_per_node == self.trainer.prepare_data_per_node
if (self.trainer.prepare_data_per_node is not None) and not lm_eq_prepare_data:
raise MisconfigurationException(
"Inconsistent settings found for `prepare_data_per_node`."
f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`"
f" and `LightningModule.prepare_data_per_node={lightning_module.prepare_data_per_node}`."
" Move `prepare_data_per_node` setting to LightningModule property."
)
if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
self.trainer._call_lightning_module_hook("prepare_data")
self.trainer._is_data_prepared = True
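With the Trainer flag and its consistency checks gone, the connector only evaluates the module attribute against the process ranks. A standalone sketch of that predicate (a hypothetical helper mirroring the gating kept in ``DataConnector.prepare_data``, not code from this PR):

.. code-block:: python

    def should_prepare(prepare_data_per_node: bool, local_rank: int, node_rank: int) -> bool:
        local_rank_zero = local_rank == 0
        global_rank_zero = node_rank == 0 and local_rank == 0
        # per-node: LOCAL_RANK=0 of every node prepares; otherwise only the root process
        return local_rank_zero if prepare_data_per_node else global_rank_zero


    # on node 1 of a 2-node job:
    assert should_prepare(True, local_rank=0, node_rank=1)
    assert not should_prepare(False, local_rank=0, node_rank=1)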
10 changes: 0 additions & 10 deletions pytorch_lightning/trainer/trainer.py
@@ -181,7 +181,6 @@ def __init__(
replace_sampler_ddp: bool = True,
detect_anomaly: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: Optional[bool] = None,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
amp_backend: str = "native",
amp_level: Optional[str] = None,
@@ -314,14 +313,6 @@ def __init__(
log_every_n_steps: How often to log within steps.
Default: ``50``.

prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

.. deprecated:: v1.5
Deprecated in v1.5.0 and will be removed in v1.7.0
Please set ``prepare_data_per_node`` in ``LightningDataModule`` and/or
``LightningModule`` directly instead.

process_position: Orders the progress bar when running multiple models on same machine.

.. deprecated:: v1.5
@@ -542,7 +533,6 @@ def __init__(
self._data_connector.on_trainer_init(
check_val_every_n_epoch,
reload_dataloaders_every_n_epochs,
prepare_data_per_node,
)

if terminate_on_nan is not None:
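For downstream code, this signature change reduces to dropping the keyword and relying on the module attribute shown above. A brief hedged before/after (nothing here is taken verbatim from the PR):

.. code-block:: python

    from pytorch_lightning import Trainer

    # Before (deprecated since v1.5):
    #     trainer = Trainer(prepare_data_per_node=False)
    # After this PR the keyword no longer exists on the constructor, so the intent
    # is declared on the LightningDataModule / LightningModule instead.
    trainer = Trainer(max_epochs=1)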
11 changes: 0 additions & 11 deletions tests/core/test_datamodules.py
@@ -487,14 +487,3 @@ class BoringDataModule2(LightningDataModule):
assert hasattr(BoringDataModule2, "__repr__")
assert BoringDataModule2(batch_size=32).prepare_data() is None
assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)


def test_inconsistent_prepare_data_per_node(tmpdir):
with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
model = BoringModel()
dm = BoringDataModule()
with pytest.deprecated_call(match="prepare_data_per_node` with the trainer flag is deprecated"):
trainer = Trainer(prepare_data_per_node=False)
trainer.model = model
trainer.datamodule = dm
trainer._data_connector.prepare_data()
5 changes: 0 additions & 5 deletions tests/deprecated_api/test_remove_1-7.py
@@ -125,11 +125,6 @@ def get_progress_bar_dict(self):
_ = trainer.progress_bar_dict


def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
with pytest.deprecated_call(match="Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0"):
_ = Trainer(prepare_data_per_node=False)


@pytest.mark.parametrize("terminate_on_nan", [True, False])
def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
with pytest.deprecated_call(