Standardize positional datamodule and argument names #7431

Merged: 27 commits merged into master from feature/datamodule-as-arg on Jun 15, 2021
Changes from 23 commits

Commits (27)
d55043f
Move trainer functions
carmocca Apr 30, 2021
4df5804
Progress
carmocca May 3, 2021
368e05d
Debugging
carmocca May 3, 2021
afebe88
Merge branch 'master' into feature/datamodule-as-arg
carmocca May 3, 2021
de3e39d
Merge branch 'master' into feature/datamodule-as-arg
carmocca May 7, 2021
1b8b28b
Progress
carmocca May 7, 2021
5a79985
Fixes
carmocca May 7, 2021
cddb652
Update CHANGELOG
carmocca May 7, 2021
ea564b8
Update tests
carmocca May 7, 2021
1de1792
Update docs
carmocca May 7, 2021
2b577da
Fix tests
carmocca May 7, 2021
31ccb44
Merge branch 'master' into feature/datamodule-as-arg
Borda May 7, 2021
2ff11bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2021
5859746
Merge branch 'master' into feature/datamodule-as-arg
awaelchli May 7, 2021
0e2559c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2021
a26a91a
Merge branch 'master' into feature/datamodule-as-arg
carmocca May 10, 2021
d17f277
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2021
62c8fa5
Merge branch 'master' into feature/datamodule-as-arg
carmocca May 13, 2021
bfabb36
Resolve TODO
carmocca May 13, 2021
f59e2f7
Resolve TODO
carmocca May 13, 2021
924cd33
Merge branch 'master' into feature/datamodule-as-arg
carmocca May 14, 2021
866b097
Merge branch 'master' into feature/datamodule-as-arg
carmocca May 24, 2021
63a0364
Merge branch 'master' into feature/datamodule-as-arg
carmocca Jun 11, 2021
48521fa
Merge branch 'master' into feature/datamodule-as-arg
awaelchli Jun 14, 2021
5bec4de
Merge branch 'master' into feature/datamodule-as-arg
carmocca Jun 14, 2021
5d8aaca
Merge branch 'master' into feature/datamodule-as-arg
carmocca Jun 15, 2021
8950e11
Remove CombinedDataLoader
carmocca Jun 15, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))


- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))
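
  A minimal sketch of the new call pattern (`LitModel` and `MNISTDataModule` are hypothetical placeholders for your own `LightningModule` and `LightningDataModule`):

      from pytorch_lightning import Trainer

      model = LitModel()              # any LightningModule (placeholder name)
      datamodule = MNISTDataModule()  # any LightningDataModule (placeholder name)

      trainer = Trainer()
      trainer.validate(model, datamodule)  # the datamodule is now accepted positionally
      trainer.test(model, datamodule)
      trainer.predict(model, datamodule)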


- Added argument `trainer.predict(ckpt_path)` ([#7430](https://github.com/PyTorchLightning/pytorch-lightning/pull/7430))


@@ -160,6 +163,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated


- Standardized the dataloader arguments of `trainer.{fit,validate,test,tune}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))
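
  A rough before/after sketch of the renamed keyword arguments, as reflected in the documentation and `DataConnector` updates below (`model` and the `*_loader` objects are placeholders):

      # before (deprecated keyword names)
      trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
      trainer.validate(model, val_dataloaders=val_loader)
      trainer.test(model, test_dataloaders=test_loader)

      # after
      trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
      trainer.validate(model, dataloaders=val_loader)
      trainer.test(model, dataloaders=test_loader)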


- Deprecated `DataModule` properties: `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test`, `has_teardown_predict` ([#7657](https://github.com/PyTorchLightning/pytorch-lightning/pull/7657/))


17 changes: 0 additions & 17 deletions docs/source/advanced/multiple_loaders.rst
@@ -91,23 +91,6 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra
Furthermore, Lightning also supports returning nested lists and dicts (or a combination of the two).

.. testcode::

class LitModel(LightningModule):

def train_dataloader(self):

loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
loader_b = torch.utils.data.DataLoader(range(16), batch_size=2)

return {'a': loader_a, 'b': loader_b}

def training_step(self, batch, batch_idx):
# access a dictionary with a batch from each dataloader
batch_a = batch["a"]
batch_b = batch["b"]


.. testcode::

class LitModel(LightningModule):
4 changes: 2 additions & 2 deletions docs/source/common/lightning_module.rst
@@ -441,12 +441,12 @@ There are two ways to call `test()`:
trainer.fit(model)

# automatically auto-loads the best weights
trainer.test(test_dataloaders=test_dataloader)
trainer.test(dataloaders=test_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)
trainer.test(model, dataloaders=test_dataloader)

----------

4 changes: 2 additions & 2 deletions docs/source/common/test_set.rst
@@ -80,10 +80,10 @@ is not available at the time your model was declared.
.. code-block:: python

# setup your data loader
test = DataLoader(...)
test_dataloader = DataLoader(...)

# test (pass in the loader)
trainer.test(test_dataloaders=test)
trainer.test(dataloaders=test_dataloader)

You can either pass in a single dataloader or a list of them. This optional named
parameter can be used in conjunction with any of the above use cases. Additionally,
2 changes: 1 addition & 1 deletion docs/source/common/trainer.rst
@@ -159,7 +159,7 @@ or after it has already been trained.

.. code-block:: python

trainer.validate(val_dataloaders=val_dataloaders)
trainer.validate(dataloaders=val_dataloaders)

------------

2 changes: 1 addition & 1 deletion docs/source/extensions/datamodules.rst
@@ -53,7 +53,7 @@ Datamodules are for you if you ever asked the questions:

What is a DataModule
--------------------
A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the
A DataModule is simply a collection of train_dataloader(s), val_dataloader(s), test_dataloader(s) along with the
matching transforms and data processing/downloads steps required.
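
As a rough illustration (a sketch, not taken from this diff), a minimal DataModule collecting such dataloaders could look like::

    from torch.utils.data import DataLoader
    from pytorch_lightning import LightningDataModule


    class ToyDataModule(LightningDataModule):
        # "ToyDataModule" is a hypothetical name used only for illustration
        def train_dataloader(self):
            return DataLoader(range(64), batch_size=8)

        def val_dataloader(self):
            return DataLoader(range(16), batch_size=8)

        def test_dataloader(self):
            return DataLoader(range(16), batch_size=8)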

Here's a simple PyTorch example:
24 changes: 11 additions & 13 deletions pytorch_lightning/core/hooks.py
@@ -13,14 +13,13 @@
# limitations under the License.
"""Various hooks to be used in the Lightning code."""

from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

import torch
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS


class ModelHooks:
@@ -428,14 +427,13 @@ def teardown(self, stage: Optional[str] = None) -> None:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
"""

def train_dataloader(self) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]:
def train_dataloader(self) -> TRAIN_DATALOADERS:
"""
Implement one or more PyTorch DataLoaders for training.

Return:
Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
this :ref:`page <multiple-training-dataloaders>`
A collection of :class:`torch.utils.data.DataLoader` specifying training samples.
In the case of multiple dataloaders, please see this :ref:`page <multiple-training-dataloaders>`.

The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
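
For example (an illustrative sketch, not taken from this diff), a dict of loaders is one of the collection forms covered by ``TRAIN_DATALOADERS``::

    def train_dataloader(self):
        loader_a = DataLoader(range(64), batch_size=8)
        loader_b = DataLoader(range(128), batch_size=16)
        # a single DataLoader, a list, or nested lists/dicts are also accepted
        return {"a": loader_a, "b": loader_b}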
@@ -503,7 +501,7 @@ def train_dataloader(self):
"""
rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")

def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
def test_dataloader(self) -> EVAL_DATALOADERS:
r"""
Implement one or multiple PyTorch DataLoaders for testing.

@@ -533,7 +531,7 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
There is no need to set it yourself.

Return:
Single or multiple PyTorch DataLoaders.
A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples.

Example::

@@ -563,7 +561,7 @@ def test_dataloader(self):
will have an argument ``dataloader_idx`` which matches the order here.
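
For instance (a sketch, not from this diff), with two test dataloaders::

    def test_dataloader(self):
        return [DataLoader(range(32), batch_size=8), DataLoader(range(64), batch_size=8)]

    def test_step(self, batch, batch_idx, dataloader_idx):
        # dataloader_idx is 0 for the first loader above and 1 for the second
        ...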
"""

def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
def val_dataloader(self) -> EVAL_DATALOADERS:
r"""
Implement one or multiple PyTorch DataLoaders for validation.

@@ -584,7 +582,7 @@ def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
There is no need to set it yourself.

Return:
Single or multiple PyTorch DataLoaders.
A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

Examples::

Expand Down Expand Up @@ -614,7 +612,7 @@ def val_dataloader(self):
will have an argument ``dataloader_idx`` which matches the order here.
"""

def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
def predict_dataloader(self) -> EVAL_DATALOADERS:
r"""
Implement one or multiple PyTorch DataLoaders for prediction.

@@ -632,7 +630,7 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
There is no need to set it yourself.

Return:
Single or multiple PyTorch DataLoaders.
A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.

Note:
In the case where you return multiple prediction dataloaders, the :meth:`predict`
36 changes: 17 additions & 19 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,17 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

from torch.utils.data import DataLoader
from typing import Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


class DataConnector(object):
class DataConnector:

def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
@@ -71,16 +70,16 @@ def can_prepare_data(self):
def attach_data(
self,
model: 'pl.LightningModule',
train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
test_dataloaders: Optional[EVAL_DATALOADERS] = None,
predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional['pl.LightningDataModule'] = None
) -> None:
# set up the passed in dataloaders (if needed)
self.attach_dataloaders(
model,
train_dataloader=train_dataloader,
train_dataloaders=train_dataloaders,
val_dataloaders=val_dataloaders,
test_dataloaders=test_dataloaders,
predict_dataloaders=predict_dataloaders,
@@ -92,15 +91,15 @@ def attach_data(
def attach_dataloaders(
self,
model: 'pl.LightningModule',
train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
test_dataloaders: Optional[EVAL_DATALOADERS] = None,
predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
) -> None:
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
model.train_dataloader = _PatchDataLoader(train_dataloader)
if train_dataloaders is not None:
model.train_dataloader = _PatchDataLoader(train_dataloaders)

if val_dataloaders is not None:
model.val_dataloader = _PatchDataLoader(val_dataloaders)
@@ -140,23 +139,22 @@ def attach_datamodule(
model.data_pipeline = datamodule.data_pipeline


class _PatchDataLoader(object):
class _PatchDataLoader:
r"""
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.

Args:
dataloader: Dataloader object to return when called.

"""

def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
self.dataloader = dataloader

# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)

def __call__(self) -> Union[List[DataLoader], DataLoader]:
def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
return self.dataloader
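
To make the patching concrete (an illustrative sketch, not code from this PR): calling the patch object simply hands back whatever loaders were passed in, which is how ``attach_dataloaders`` overrides ``model.train_dataloader`` and friends::

    from torch.utils.data import DataLoader

    loaders = {"a": DataLoader(range(8), batch_size=4), "b": DataLoader(range(16), batch_size=2)}
    patch = _PatchDataLoader(loaders)

    # the patched model.*_dataloader() now returns the user-provided loaders unchanged
    assert patch() is loaders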