Refactor loop.setup_data with utility functions (#16918)
carmocca committed Mar 9, 2023
1 parent aa7f252 commit bdd9b12
Showing 17 changed files with 316 additions and 282 deletions.
11 changes: 6 additions & 5 deletions src/lightning/fabric/utilities/data.py
@@ -416,23 +416,24 @@ def _replace_value_in_saved_args(
return False, args, kwargs


def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
def _set_sampler_epoch(dataloader: object, epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler of the given dataloader.
Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""
objects = set()
# cannot use a set because samplers might be unhashable: use a dict based on the id to drop duplicates
objects: Dict[int, Any] = {}
# check dataloader.sampler
if (sampler := getattr(dataloader, "sampler", None)) is not None:
objects.add(sampler)
objects[id(sampler)] = sampler
# check dataloader.batch_sampler.sampler
if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
sampler := getattr(batch_sampler, "sampler", None)
) is not None:
objects.add(sampler)
for obj in objects:
objects[id(sampler)] = sampler
for obj in objects.values():
set_epoch = getattr(obj, "set_epoch", None)
if callable(set_epoch):
set_epoch(epoch)
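
For context, a minimal usage sketch of the private helper and why the deduplication by `id()` matters: a standard `DataLoader` exposes the same sampler instance via both `sampler` and `batch_sampler.sampler`, so `set_epoch` runs exactly once per epoch even when the sampler is unhashable. The single-replica `DistributedSampler` below is illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

from lightning.fabric.utilities.data import _set_sampler_epoch

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
# explicit num_replicas/rank so no process group is needed for this example
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

# `loader.sampler` and `loader.batch_sampler.sampler` are the same object, so the
# id()-keyed dict calls `set_epoch` once, even for samplers that are unhashable
_set_sampler_epoch(loader, epoch=3)
assert sampler.epoch == 3
```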
46 changes: 38 additions & 8 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -28,10 +28,17 @@
from lightning.pytorch.loops.progress import BatchProgress
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
from lightning.pytorch.trainer.connectors.data_connector import (
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
_request_dataloader,
_resolve_overfit_batches,
)
from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import SIGTERMException
from lightning.pytorch.utilities.model_helpers import is_overridden

@@ -115,8 +122,9 @@ def run(self) -> List[_OUT_DICT]:

def setup_data(self) -> None:
trainer = self.trainer
trainer_fn = trainer.state.fn

if self._combined_loader is not None and trainer.state.fn == "fit" and not self._should_reload_val_dl:
if self._combined_loader is not None and trainer_fn == "fit" and not self._should_reload_val_dl:
return

source = self._data_source
@@ -128,20 +136,42 @@ def setup_data(self) -> None:

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
# it should not reload again if it has already reloaded during sanity_check
if trainer.state.fn == "fit" and (
if trainer_fn == "fit" and (
(trainer.sanity_checking and trainer.fit_loop.epoch_loop._should_check_val_epoch())
or not trainer.sanity_checking
):
self._last_val_dl_reload_epoch = trainer.current_epoch

stage = trainer.state.stage
assert stage is not None
self._max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)

if trainer.state.fn != "fit": # if we are fitting, we need to do this in the loop
for dl in combined_loader.flattened:
# some users want validation shuffling based on the training progress
_set_sampler_epoch(dl, trainer.fit_loop.epoch_progress.current.processed)
dataloaders = _request_dataloader(source)
trainer.strategy.barrier(f"{stage.dataloader_prefix}_dataloader()")

if not isinstance(dataloaders, CombinedLoader):
combined_loader = CombinedLoader(dataloaders, "sequential")
else:
combined_loader = dataloaders

if trainer_fn == "fit" and trainer.overfit_batches > 0:
_resolve_overfit_batches(combined_loader, stage)

allow_zero_length = pl_module.allow_zero_length_dataloader_with_multiple_devices
if trainer.datamodule is not None:
allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

dataloaders = []
self._max_batches = []
for dl in combined_loader.flattened:
dl = _process_dataloader(trainer, dl)
dataloaders.append(dl)

# determine number of batches
length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
limit_batches = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches")
num_batches = _parse_num_batches(stage, length, limit_batches)
self._max_batches.append(num_batches)
combined_loader.flattened = dataloaders
self._combined_loader = combined_loader

# this depends on the data used, so reset it too
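
The per-dataloader cap above comes from the new `_parse_num_batches` helper. Below is a hypothetical sketch of what it presumably reduces to, inferred from the inline `limit_train_batches` logic this commit removes from `fit_loop.py` further down; the shipped helper in `data_connector.py` may differ in edge cases and error wording.

```python
from typing import Union

from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException


def _parse_num_batches_sketch(
    stage: RunningStage, length: Union[int, float], limit_batches: Union[int, float]
) -> Union[int, float]:
    """Hypothetical reimplementation for illustration; not the shipped helper."""
    if isinstance(limit_batches, int):
        # an integer limit caps the batch count directly
        return min(length, limit_batches)
    if length != float("inf"):
        # a float limit is interpreted as a fraction of the (finite) dataloader length
        return int(length * limit_batches)
    if limit_batches != 1.0:
        # no usable length: only 1.0 (all batches) or an int limit make sense
        raise MisconfigurationException(
            f"When using an `IterableDataset`, `Trainer(limit_{stage.dataloader_prefix}_batches)`"
            " must be `1.0` or an int."
        )
    return length
```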
74 changes: 25 additions & 49 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -15,14 +15,20 @@
from typing import Optional, Union

import lightning.pytorch as pl
from lightning.fabric.utilities.data import _auto_add_worker_init_fn, _set_sampler_epoch
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.fetchers import _DataFetcher
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.loops.training_epoch_loop import _TrainingEpochLoop
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
from lightning.pytorch.trainer.connectors.data_connector import (
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
_request_dataloader,
_resolve_overfit_batches,
)
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.combined_loader import CombinedLoader
@@ -197,7 +203,7 @@ def run(self) -> None:
self._restarting = False
self.on_run_end()

def setup_data(self, shuffle: bool = True) -> None:
def setup_data(self) -> None:
if self._combined_loader is not None and not self._should_reload_train_dl:
return

@@ -209,52 +215,36 @@ def setup_data(self, shuffle: bool = True) -> None:

log.debug(f"{self.__class__.__name__}: resetting train dataloader")

train_dataloader = trainer._data_connector._request_dataloader()
train_dataloader = _request_dataloader(source)
trainer.strategy.barrier("train_dataloader()")

if not isinstance(train_dataloader, CombinedLoader):
combined_loader = CombinedLoader(train_dataloader, "max_size_cycle")
else:
combined_loader = train_dataloader

if trainer.overfit_batches > 0:
trainer._data_connector._resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)

dataloaders = []
for i, dl in enumerate(combined_loader.flattened):
# automatically add samplers
dl = trainer._data_connector._prepare_dataloader(dl, shuffle=shuffle, mode=RunningStage.TRAINING)
# let the strategy inject its logic
dl = trainer.strategy.process_dataloader(dl)
# check the workers
trainer._data_connector._worker_check(dl, "train_dataloader")
# add worker_init_fn for correct seeding in worker processes
_auto_add_worker_init_fn(dl, trainer.global_rank)
dataloaders.append(dl)
_resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)

dataloaders = [_process_dataloader(trainer, dl) for dl in combined_loader.flattened]
combined_loader.flattened = dataloaders
self._combined_loader = combined_loader

module = pl_module or trainer.datamodule
orig_train_batches = self.max_batches = (
len(self._combined_loader)
if has_len_all_ranks(self._combined_loader, trainer.strategy, module)
else float("inf")
)
if orig_train_batches == 0:
allow_zero_length = pl_module.allow_zero_length_dataloader_with_multiple_devices
if trainer.datamodule is not None:
allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)
self.max_batches = len(combined_loader) if has_len_all_ranks_ else float("inf")
if self.max_batches == 0:
return

stage = RunningStage.TRAINING
self.max_batches = _parse_num_batches(stage, self.max_batches, trainer.limit_train_batches)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = trainer.current_epoch

if isinstance(trainer.limit_train_batches, int):
self.max_batches = min(orig_train_batches, trainer.limit_train_batches)
elif self.max_batches != float("inf"):
self.max_batches = int(orig_train_batches * trainer.limit_train_batches)
elif trainer.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int."
"An int specifies `num_training_batches` to use."
)

if isinstance(trainer.val_check_interval, int):
trainer.val_check_batch = trainer.val_check_interval
if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
@@ -265,7 +255,7 @@ def setup_data(self, shuffle: bool = True) -> None:
" If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
)
else:
if not has_len_all_ranks(self._combined_loader, trainer.strategy, module):
if not has_len_all_ranks_:
if trainer.val_check_interval == 1.0:
trainer.val_check_batch = float("inf")
else:
Expand All @@ -286,20 +276,6 @@ def setup_data(self, shuffle: bool = True) -> None:
category=PossibleUserWarning,
)

if (
self.max_batches == 0
and trainer.limit_train_batches > 0.0
and isinstance(trainer.limit_train_batches, float)
and orig_train_batches != float("inf")
):
min_percentage = 1.0 / orig_train_batches
raise MisconfigurationException(
f"You requested to check {trainer.limit_train_batches} of the `train_dataloader` but"
f" {trainer.limit_train_batches} * {orig_train_batches} < 1. Please increase the"
f" `limit_train_batches` argument. Try at least"
f" `limit_train_batches={min_percentage}`"
)

def reset(self) -> None:
"""Resets the internal state of this loop."""
if self.restarting:
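
The four inline steps deleted above (sampler injection, the strategy hook, the worker check, and `worker_init_fn` seeding) are what the new `_process_dataloader(trainer, dl)` call is presumed to bundle. A hedged sketch follows, reusing the private helpers with the signatures they have in the removed lines; the shipped helper derives the stage itself and may have adjusted those signatures in this commit.

```python
import lightning.pytorch as pl
from lightning.fabric.utilities.data import _auto_add_worker_init_fn
from lightning.pytorch.trainer.states import RunningStage


def _process_dataloader_sketch(trainer: "pl.Trainer", dl: object) -> object:
    """Hypothetical consolidation for illustration; not the shipped helper."""
    stage = trainer.state.stage
    assert stage is not None
    shuffle = stage == RunningStage.TRAINING
    # automatically add (distributed) samplers, shuffling only during training
    dl = trainer._data_connector._prepare_dataloader(dl, shuffle=shuffle, mode=stage)
    # let the strategy inject its logic (e.g. XLA wraps the loader, see below)
    dl = trainer.strategy.process_dataloader(dl)
    # warn about suspicious num_workers settings
    trainer._data_connector._worker_check(dl, f"{stage.dataloader_prefix}_dataloader")
    # add worker_init_fn for correct seeding in worker processes
    _auto_add_worker_init_fn(dl, trainer.global_rank)
    return dl
```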
38 changes: 30 additions & 8 deletions src/lightning/pytorch/loops/prediction_loop.py
@@ -19,7 +19,6 @@

import lightning.pytorch as pl
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
from lightning.pytorch.loops.loop import _Loop
@@ -28,9 +27,15 @@
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
from lightning.pytorch.trainer.connectors.data_connector import (
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
_request_dataloader,
)
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import _PREDICT_OUTPUT

@@ -114,17 +119,34 @@ def run(self) -> Optional[_PREDICT_OUTPUT]:
def setup_data(self) -> None:
trainer = self.trainer
source = self._data_source
pl_module = trainer.lightning_module
# a default `predict_step` exists in the LightningModule, so no need to check if it's overridden
if not source.is_defined() or trainer.limit_predict_batches == 0:
return

self.max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)
dataloaders = _request_dataloader(source)
trainer.strategy.barrier("predict_dataloader()")

if not isinstance(dataloaders, CombinedLoader):
combined_loader = CombinedLoader(dataloaders, "sequential")
else:
combined_loader = dataloaders

allow_zero_length = trainer.lightning_module.allow_zero_length_dataloader_with_multiple_devices
if trainer.datamodule is not None:
allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

stage = RunningStage.PREDICTING
dataloaders = []
self.max_batches = []
for dl in combined_loader.flattened:
# some users want prediction shuffling based on the training progress
_set_sampler_epoch(dl, trainer.fit_loop.epoch_progress.current.processed)
dl = _process_dataloader(trainer, dl)
dataloaders.append(dl)

# determine number of batches
length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
num_batches = _parse_num_batches(stage, length, trainer.limit_predict_batches)
self.max_batches.append(num_batches)
combined_loader.flattened = dataloaders
self._combined_loader = combined_loader

def reset(self) -> None:
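
For completeness, a minimal usage sketch of what the rewritten prediction loop produces: with two predict dataloaders and an integer `limit_predict_batches`, `max_batches` ends up holding one cap per dataloader. The tiny module and the expected `[2, 2]` outcome are illustrative assumptions rather than a tested guarantee.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class TinyModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        (x,) = batch
        return self.layer(x)


data = TensorDataset(torch.randn(32, 4))
loaders = [DataLoader(data, batch_size=8), DataLoader(data, batch_size=4)]

trainer = pl.Trainer(limit_predict_batches=2, accelerator="cpu", devices=1, logger=False)
predictions = trainer.predict(TinyModel(), dataloaders=loaders)
# after the run, trainer.predict_loop.max_batches is expected to be [2, 2]
```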
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/strategy.py
@@ -14,7 +14,7 @@
import contextlib
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union

import torch
from torch import Tensor
@@ -395,7 +395,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert isinstance(self.model, PredictStep)
return self.model.predict_step(*args, **kwargs)

def process_dataloader(self, dataloader: Iterable) -> Iterable:
def process_dataloader(self, dataloader: object) -> object:
"""Wraps the dataloader if necessary.
Args:
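
Relaxing the annotation from `Iterable` to `object` (also applied in `xla.py` below) presumably lets strategies return arbitrary dataloader-like wrappers without fighting the type checker. A hedged sketch of a custom override relying on that; `CountingSingleDeviceStrategy` and its wrapper are illustrative and not part of this commit.

```python
from typing import Any, Iterable, Iterator, cast

from lightning.pytorch.strategies import SingleDeviceStrategy


class _CountingLoader:
    """Wraps an iterable dataloader and counts the batches it yields."""

    def __init__(self, loader: Iterable) -> None:
        self.loader = loader
        self.batches_seen = 0

    def __iter__(self) -> Iterator[Any]:
        for batch in self.loader:
            self.batches_seen += 1
            yield batch


class CountingSingleDeviceStrategy(SingleDeviceStrategy):
    def process_dataloader(self, dataloader: object) -> object:
        # the loops only need something they can iterate over, hence `object`
        return _CountingLoader(cast(Iterable, dataloader))
```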
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/xla.py
@@ -13,7 +13,7 @@
# limitations under the License.
import io
import os
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

import torch
from torch import Tensor
@@ -146,7 +146,7 @@ def is_distributed(self) -> bool:

return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1

def process_dataloader(self, dataloader: Iterable) -> "MpDeviceLoader":
def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
XLAStrategy._validate_dataloader(dataloader)
from torch_xla.distributed.parallel_loader import MpDeviceLoader
