Standardize positional datamodule and argument names (#7431)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
3 people committed Jun 15, 2021
1 parent 0974d66 commit 560b197
Showing 15 changed files with 184 additions and 128 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))


+ - Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))


+ - Added argument `trainer.predict(ckpt_path)` ([#7430](https://github.com/PyTorchLightning/pytorch-lightning/pull/7430))


@@ -174,6 +177,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated


+ - Standardized the dataloader arguments of `trainer.{fit,validate,test,tune}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))


- Deprecated `DataModule` properties: `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test`, `has_teardown_predict` ([#7657](https://github.com/PyTorchLightning/pytorch-lightning/pull/7657/))


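Taken together, these entries change the trainer's calling conventions as follows — a minimal sketch, where `MyModel`, `MyDataModule`, and the loaders are hypothetical stand-ins:

    import pytorch_lightning as pl

    model = MyModel()       # hypothetical LightningModule subclass
    dm = MyDataModule()     # hypothetical LightningDataModule subclass
    trainer = pl.Trainer()

    # new: a datamodule can be passed positionally as the second argument
    trainer.validate(model, dm)  # equivalent to trainer.validate(model, datamodule=dm)
    trainer.test(model, dm)      # equivalent to trainer.test(model, datamodule=dm)

    # new: standardized dataloader argument names
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    trainer.test(model, dataloaders=test_loader)  # previously: test_dataloaders=...
    trainer.predict(model, dataloaders=pred_loader, ckpt_path="best")  # ckpt_path per #7430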
17 changes: 0 additions & 17 deletions docs/source/advanced/multiple_loaders.rst
@@ -91,23 +91,6 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra
Furthermore, Lightning also supports returning nested lists and dicts (or a combination of the two).

- .. testcode::
-
-     class LitModel(LightningModule):
-
-         def train_dataloader(self):
-
-             loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
-             loader_b = torch.utils.data.DataLoader(range(16), batch_size=2)
-
-             return {'a': loader_a, 'b': loader_b}
-
-         def training_step(self, batch, batch_idx):
-             # access a dictionary with a batch from each dataloader
-             batch_a = batch["a"]
-             batch_b = batch["b"]


.. testcode::

    class LitModel(LightningModule):
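For the nested case mentioned above, a minimal sketch (dataset sizes are illustrative):

    class LitModel(LightningModule):

        def train_dataloader(self):
            loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
            loader_b = torch.utils.data.DataLoader(range(16), batch_size=2)
            loader_c = torch.utils.data.DataLoader(range(32), batch_size=2)
            # a dict whose values mix single loaders and lists of loaders
            return {"a": loader_a, "b": [loader_b, loader_c]}

        def training_step(self, batch, batch_idx):
            # the batch mirrors the returned structure
            batch_a = batch["a"]
            batch_b, batch_c = batch["b"]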
4 changes: 2 additions & 2 deletions docs/source/common/lightning_module.rst
@@ -441,12 +441,12 @@ There are two ways to call `test()`:
trainer.fit(model)
# automatically loads the best weights
- trainer.test(test_dataloaders=test_dataloader)
+ trainer.test(dataloaders=test_dataloader)
# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
- trainer.test(model, test_dataloaders=test_dataloader)
+ trainer.test(model, dataloaders=test_dataloader)
----------

4 changes: 2 additions & 2 deletions docs/source/common/test_set.rst
@@ -80,10 +80,10 @@ is not available at the time your model was declared.
.. code-block:: python
# set up your data loader
- test = DataLoader(...)
+ test_dataloader = DataLoader(...)
# test (pass in the loader)
- trainer.test(test_dataloaders=test)
+ trainer.test(dataloaders=test_dataloader)
You can either pass in a single dataloader or a list of them. This optional named
parameter can be used in conjunction with any of the above use cases. Additionally,
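For the list case, a minimal sketch (the loaders are illustrative):

    test_loader_a = torch.utils.data.DataLoader(range(16), batch_size=4)
    test_loader_b = torch.utils.data.DataLoader(range(32), batch_size=4)

    # the test loop runs once over each loader; results are reported per dataloader
    trainer.test(dataloaders=[test_loader_a, test_loader_b])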
2 changes: 1 addition & 1 deletion docs/source/common/trainer.rst
@@ -159,7 +159,7 @@ or after it has already been trained.

.. code-block:: python
- trainer.validate(val_dataloaders=val_dataloaders)
+ trainer.validate(dataloaders=val_dataloaders)
------------

2 changes: 1 addition & 1 deletion docs/source/extensions/datamodules.rst
@@ -53,7 +53,7 @@ Datamodules are for you if you ever asked the questions:

What is a DataModule
--------------------
- A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the
+ A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) along with the
matching transforms and data processing/download steps required.

Here's a simple PyTorch example:
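A minimal sketch of such a collection, assuming illustrative datasets:

    import torch
    from pytorch_lightning import LightningDataModule

    class MyDataModule(LightningDataModule):

        def setup(self, stage=None):
            # assign illustrative datasets for each split
            self.train_set, self.val_set, self.test_set = range(64), range(16), range(16)

        def train_dataloader(self):
            return torch.utils.data.DataLoader(self.train_set, batch_size=8)

        def val_dataloader(self):
            return torch.utils.data.DataLoader(self.val_set, batch_size=8)

        def test_dataloader(self):
            return torch.utils.data.DataLoader(self.test_set, batch_size=8)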
24 changes: 11 additions & 13 deletions pytorch_lightning/core/hooks.py
@@ -13,14 +13,13 @@
# limitations under the License.
"""Various hooks to be used in the Lightning code."""

- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Optional

import torch
from torch.optim.optimizer import Optimizer
- from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
- from pytorch_lightning.utilities.types import STEP_OUTPUT
+ from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS


class ModelHooks:
@@ -428,14 +427,13 @@ def teardown(self, stage: Optional[str] = None) -> None:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
"""

- def train_dataloader(self) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]:
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
"""
Implement one or more PyTorch DataLoaders for training.
Return:
- Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
- (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
- this :ref:`page <multiple-training-dataloaders>`
+ A collection of :class:`torch.utils.data.DataLoader` specifying training samples.
+ In the case of multiple dataloaders, please see this :ref:`page <multiple-training-dataloaders>`.
The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -503,7 +501,7 @@ def train_dataloader(self):
"""
rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")

- def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
+ def test_dataloader(self) -> EVAL_DATALOADERS:
r"""
Implement one or multiple PyTorch DataLoaders for testing.
@@ -533,7 +531,7 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
There is no need to set it yourself.
Return:
- Single or multiple PyTorch DataLoaders.
+ A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples.
Example::
@@ -563,7 +561,7 @@ def test_dataloader(self):
will have an argument ``dataloader_idx`` which matches the order here.
"""

- def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
+ def val_dataloader(self) -> EVAL_DATALOADERS:
r"""
Implement one or multiple PyTorch DataLoaders for validation.
@@ -584,7 +582,7 @@ def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
There is no need to set it yourself.
Return:
- Single or multiple PyTorch DataLoaders.
+ A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
Examples::
@@ -614,7 +612,7 @@ def val_dataloader(self):
will have an argument ``dataloader_idx`` which matches the order here.
"""

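When multiple validation dataloaders are returned, ``validation_step`` gains a ``dataloader_idx`` argument, as the docstring above notes — a minimal sketch (datasets are illustrative):

    class LitModel(LightningModule):

        def val_dataloader(self):
            # two validation sets; Lightning will loop over both
            loader_a = torch.utils.data.DataLoader(range(16), batch_size=4)
            loader_b = torch.utils.data.DataLoader(range(64), batch_size=4)
            return [loader_a, loader_b]

        def validation_step(self, batch, batch_idx, dataloader_idx):
            # dataloader_idx is 0 for loader_a and 1 for loader_b,
            # matching the order of the returned list
            ...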
- def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
r"""
Implement one or multiple PyTorch DataLoaders for prediction.
@@ -632,7 +630,7 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
There is no need to set it yourself.
Return:
- Single or multiple PyTorch DataLoaders.
+ A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.
Note:
In the case where you return multiple prediction dataloaders, the :meth:`predict`
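The ``TRAIN_DATALOADERS`` and ``EVAL_DATALOADERS`` aliases used above come from ``pytorch_lightning.utilities.types``. Approximately — a sketch, not the verbatim definitions:

    from typing import Dict, Sequence, Union
    from torch.utils.data import DataLoader

    # evaluation stages accept a single loader or a sequence of loaders
    EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
    # training additionally accepts dicts and (possibly nested) combinations
    TRAIN_DATALOADERS = Union[
        DataLoader,
        Sequence[DataLoader],
        Dict[str, DataLoader],
        Dict[str, Sequence[DataLoader]],
    ]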
36 changes: 17 additions & 19 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,17 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- from typing import List, Optional, Union
-
- from torch.utils.data import DataLoader
+ from typing import Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
+ from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


- class DataConnector(object):
+ class DataConnector:

def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
@@ -71,16 +70,16 @@ def can_prepare_data(self):
def attach_data(
self,
model: 'pl.LightningModule',
- train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
- val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
- test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
- predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+ train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
+ val_dataloaders: Optional[EVAL_DATALOADERS] = None,
+ test_dataloaders: Optional[EVAL_DATALOADERS] = None,
+ predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional['pl.LightningDataModule'] = None
) -> None:
# set up the passed in dataloaders (if needed)
self.attach_dataloaders(
model,
- train_dataloader=train_dataloader,
+ train_dataloaders=train_dataloaders,
val_dataloaders=val_dataloaders,
test_dataloaders=test_dataloaders,
predict_dataloaders=predict_dataloaders,
@@ -92,15 +91,15 @@ def attach_data(
def attach_dataloaders(
self,
model: 'pl.LightningModule',
- train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
- val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
- test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
- predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+ train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
+ val_dataloaders: Optional[EVAL_DATALOADERS] = None,
+ test_dataloaders: Optional[EVAL_DATALOADERS] = None,
+ predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
) -> None:
# when dataloaders are passed via fit, patch the *_dataloader hooks
# so that they return these implementations
- if train_dataloader is not None:
-     model.train_dataloader = _PatchDataLoader(train_dataloader)
+ if train_dataloaders is not None:
+     model.train_dataloader = _PatchDataLoader(train_dataloaders)

if val_dataloaders is not None:
model.val_dataloader = _PatchDataLoader(val_dataloaders)
@@ -140,23 +139,22 @@ def attach_datamodule(
model.data_pipeline = datamodule.data_pipeline


- class _PatchDataLoader(object):
+ class _PatchDataLoader:
r"""
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.
Args:
dataloader: Dataloader object to return when called.
"""

- def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
+ def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
self.dataloader = dataloader

# __code__ cannot be pickled, so we cannot directly verify that a _PatchDataLoader
# instance exists (which would show the dataloader methods were overwritten).
# Instead, we hack around it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)

- def __call__(self) -> Union[List[DataLoader], DataLoader]:
+ def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
return self.dataloader

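In effect, ``attach_dataloaders`` replaces the model's dataloader hooks with these pickle-friendly callables. A small usage sketch, with a hypothetical ``model`` and an illustrative loader:

    import torch

    loader = torch.utils.data.DataLoader(range(8), batch_size=4)

    # what attach_dataloaders does when a loader is passed to trainer.fit()
    model.train_dataloader = _PatchDataLoader(loader)

    # the patched hook simply returns the stored loader(s) when called
    assert model.train_dataloader() is loader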