Commit

[Feat] Add utilities for CombinedLoader state dict and dataloader state dict 1/n (#8364)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <justus.schock@posteo.de>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
6 people committed Jul 19, 2021
1 parent 257fabd commit 374fae5
Showing 8 changed files with 603 additions and 87 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -95,6 +95,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Set `Loop.restarting=False` at the end of the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Save the loops state with the checkpoint (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Save a checkpoint to restore the state on exception (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Added `state_dict` and `load_state_dict` utilities for `CombinedLoader` + utilities for dataloader ([#8364](https://github.com/PyTorchLightning/pytorch-lightning/pull/8364))


- Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
28 changes: 18 additions & 10 deletions pytorch_lightning/trainer/data_loading.py
@@ -30,9 +30,11 @@
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _sampler_metadata_collate
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function

@@ -259,6 +261,10 @@ def reset_train_dataloader(self, model: 'pl.LightningModule') -> None:
# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_enabled():
apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)
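
For context, `apply_to_collection` (imported at the top of this file) applies a function to every element of a given type inside an arbitrarily nested collection while preserving its structure; this is how the `worker_init_fn` and the metadata-collecting `collate_fn` reach every dataloader. A minimal sketch with made-up loaders:

```python
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection

# hypothetical nested collection of train dataloaders
loaders = {"a": DataLoader(range(10)), "b": [DataLoader(range(5)), DataLoader(range(3))]}

# the function is applied to every DataLoader; the surrounding structure is preserved
lengths = apply_to_collection(loaders, DataLoader, len)
# -> {"a": 10, "b": [5, 3]}
```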

@@ -460,9 +466,6 @@ def reset_train_val_dataloaders(self, model) -> None:
def request_dataloader(self, model: 'pl.LightningModule', stage: str) -> DataLoader:
"""Handles downloading data in the GPU or TPU case.
Args:
dataloader_fx: The bound dataloader getter
Returns:
The dataloader
"""
@@ -474,11 +477,16 @@ def request_dataloader(self, model: 'pl.LightningModule', stage: str) -> DataLoader:

def _flatten_dl_only(self, dataloaders):
# handles user error when they return:
# `return dl1, dl2` vs `return (dl1, dl2)`
if isinstance(dataloaders, tuple) and all(isinstance(x, Iterable) for x in dataloaders):
return list(dataloaders)
return dataloaders

@staticmethod
def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:
"""
Wrap the default collate function to retrieve the ``FastForwardSampler`` state dict when fault-tolerant training is enabled.
"""
dataloader.collate_fn = partial(
_sampler_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
)
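
To make the `functools.partial` pattern above concrete, here is a hedged, self-contained sketch; `my_metadata_collate` is a hypothetical stand-in for the real `_sampler_metadata_collate` from `pytorch_lightning.utilities.auto_restart`, which is not shown in this diff:

```python
from functools import partial
from torch.utils.data import DataLoader

# hypothetical stand-in: collate as usual; the real utility would also gather
# the ``FastForwardSampler`` state from ``dataset`` here
def my_metadata_collate(samples, dataset, default_collate):
    return default_collate(samples)

loader = DataLoader(range(8), batch_size=4)
# rebind the loader's collate_fn, keeping the original one available as ``default_collate``
loader.collate_fn = partial(
    my_metadata_collate, dataset=loader.dataset, default_collate=loader.collate_fn
)
print(next(iter(loader)))  # tensor([0, 1, 2, 3])
```
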
144 changes: 121 additions & 23 deletions pytorch_lightning/trainer/supporters.py
@@ -14,18 +14,25 @@

import os
from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import Any, Callable, Generator, Optional, Tuple, Union
from functools import partial
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_cycle_to_next_worker_and_reset,
_find_current_worker,
CaptureIterableDataset,
)
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled


class TensorRunningAccum(object):
@@ -172,12 +179,10 @@ class CycleIterator(object):

def __init__(self, loader: Any, length: Optional[int] = None):
"""
Args:
loader: the loader to restart for cyclic (and optionally infinite) sampling
length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration
if None: infinite
"""
if length is None:
length = float('inf')
@@ -193,7 +198,6 @@ def __iter__(self) -> Any:
Returns:
CycleIterator: self
"""
self.counter = 0
self._loader_iter = iter(self.loader)
@@ -209,7 +213,6 @@ def __next__(self) -> Any:
Raises:
StopIteration: if more than :attr:`length` batches have been returned
"""
# Note: if self.length is `inf`, then the iterator will never stop
if self.counter >= self.__len__():
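
A quick usage sketch of `CycleIterator`, assuming the restart-on-exhaustion behavior described in the docstring above:

```python
from torch.utils.data import DataLoader

loader = DataLoader(range(4), batch_size=2)  # yields [0, 1] and [2, 3]
cycled = CycleIterator(loader, length=5)     # restarts the loader until 5 batches were produced
print([batch.tolist() for batch in cycled])
# expected: [[0, 1], [2, 3], [0, 1], [2, 3], [0, 1]]
```
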
@@ -237,13 +240,11 @@ class CombinedDataset(object):

def __init__(self, datasets: Union[Sequence, Mapping], mode: str = 'min_size'):
"""
Args:
datasets: a sequence/mapping of datasets. Can be a collection of torch.utils.data.Dataset,
Iterable or even None.
mode: whether to use the minimum or the maximum number of batches across all datasets.
"""
self.datasets = datasets
if mode not in self.COMPUTE_FUNCS.keys():
@@ -273,7 +274,6 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
Returns:
length: the length of `CombinedDataset`
"""
if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
raise MisconfigurationException(f"Invalid Mode: {mode}")
@@ -319,10 +319,14 @@ def __len__(self) -> int:
return self._calc_num_data(self.datasets, self.mode)


class DataLoaderDict(Dict):
# behaves exactly like a dict; used to simplify ``apply_to_collection``.
pass
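
As far as this diff shows, the marker subclass exists for type-based dispatch: `apply_to_collection` / `apply_to_collections` would normally recurse into a plain `Mapping`, whereas a `DataLoaderDict` can be matched as a leaf. A hedged illustration (keys taken from the worker-related entries used further down):

```python
# each per-dataloader state is wrapped in a DataLoaderDict so that
# apply_to_collections(loaders, states, (DataLoader, DataLoaderDict), fn)
# can pair a DataLoader with its whole state dict instead of descending into it
state = DataLoaderDict({"num_worker": 2, "previous_worker": 0})
assert isinstance(state, dict)  # still behaves exactly like a dict
```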


class CombinedLoader(object):
"""
Combines different dataloaders and allows sampling in parallel.
Supported modes are 'min_size', which raises StopIteration after the shortest loader
(the one with the lowest number of batches) is done, and 'max_size_cycle', which raises
StopIteration after the longest loader (the one with most batches) is done, while cycling
@@ -342,18 +346,15 @@ class CombinedLoader(object):
... print(item)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
"""
SUPPORTED_MODES = ('min_size', 'max_size_cycle')

def __init__(self, loaders: Any, mode: str = 'min_size'):
"""
Args:
loaders: the loaders to sample from. Can be any kind of collection
mode: the mode. Supported are 'min_size', which stops if the shortest loader is exhausted, and
'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.
"""
if mode not in self.SUPPORTED_MODES:
raise MisconfigurationException(f"Invalid Mode: {mode}")
@@ -371,6 +372,84 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
if self.mode == 'max_size_cycle':
self._wrap_loaders_max_size_cycle()

self._loaders_iter_state_dict = None
self._iterator = None # assigned in __iter__

@staticmethod
def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
# find next worker if multiple workers were used
state = _find_current_worker(iterator)
if isinstance(dataloader.dataset, CaptureIterableDataset):
# the sampler state dict is extracted in ``CombinedLoaderIterator``
if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
state.update(iterator._sampler_state_dict[0])
else:
# fetch directly from fast forward sampler
state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
return DataLoaderDict(state)

def state_dict(self, num_batches_processed: int) -> Dict:
"""
The state dict includes all states from wrapped dataloaders and their samplers through the
``CaptureIterableDataset`` and fast-forward samplers.
Args:
num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
may have already prefetched more batches by the time a state dict is requested.
"""
if not _fault_tolerant_enabled():
return DataLoaderDict()

state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)

return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)

def load_state_dict(self, state_dict):
# store the samplers' state.
# They will be reloaded once the ``CombinedLoaderIterator`` has been created
# and the workers are created.
self._loaders_iter_state_dict = state_dict

def mock_reset_fn(self, *_, **__):
pass

# mock the reset call, so we can rotate the ``_worker_queue_idx_cycle`` to the failed worker
# and get the first batch from it
_MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
_MultiProcessingDataLoaderIter._reset = mock_reset_fn
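
Putting `state_dict` and `load_state_dict` together, a hedged end-to-end sketch; `loader_a` and `loader_b` are assumed dataloaders whose samplers/datasets are already wrapped for fault-tolerant training, and `_fault_tolerant_enabled()` is assumed to return `True`:

```python
# save: capture where every wrapped dataloader is after one processed batch
combined = CombinedLoader({"a": loader_a, "b": loader_b}, mode="max_size_cycle")
iterator = iter(combined)
_ = next(iterator)
loader_states = combined.state_dict(num_batches_processed=1)

# restore: hand the states back before iterating; ``on_restart`` (below) fast-forwards
# the underlying samplers when ``__iter__`` creates the new ``CombinedLoaderIterator``
restored = CombinedLoader({"a": loader_a, "b": loader_b}, mode="max_size_cycle")
restored.load_state_dict(loader_states)
iterator = iter(restored)
```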

def on_restart(self, iterator: Iterator):
if not self._loaders_iter_state_dict:
return

# this happens inside the workers, if any were specified.

def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
if isinstance(dataloader.dataset, CaptureIterableDataset):
# provide the ``state_dict`` to the ``CaptureIterableDataset``
# as it is responsible for passing down the state to associated ``FastForwardSampler``
dataloader.dataset.load_state_dict(state_dict)
else:
# for a map-style dataset, the ``fast_forward_sampler`` was attached
# to the dataloader for simplicity
dataloader.fast_forward_sampler.load_state_dict(state_dict)

# cycle back the iterator to the failed worker if multiple workers were provided
iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)

if isinstance(dataloader.dataset, CaptureIterableDataset):
# remove keys related to iterator
state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
# need to re-attach the state dict into the iterator for future collection.
iterator._sampler_state_dict = [state_dict]
return iterator

# apply ``create_loader_iters`` on the collection of ``DataLoader`` / ``Iterator``.
# each ``Iterator`` was created from the corresponding ``DataLoader``.
iterator._loader_iters = apply_to_collections(
self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
)

@property
def sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of samplers extracting from loaders."""
@@ -382,7 +461,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
Returns:
the wrapped loaders
"""
all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))

@@ -398,7 +476,18 @@ def __iter__(self) -> Any:
"""
Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.
"""
return CombinedLoaderIterator(self.loaders)

# prevent ``NotImplementedError`` from PyTorch:
# https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
def __getstate__patch__(*_):
return {}

_BaseDataLoaderIter.__getstate__ = __getstate__patch__
iterator = CombinedLoaderIterator(self.loaders)
# handle fault tolerant restart logic.
self.on_restart(iterator)
self._iterator = iterator
return iterator

@staticmethod
def _calc_num_batches(loaders: Any) -> Union[int, float]:
@@ -410,7 +499,6 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
Returns:
length: the minimum length of loaders
"""
all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))

@@ -429,10 +517,8 @@ class CombinedLoaderIterator(object):

def __init__(self, loaders: Any):
"""
Args:
loaders: the loaders to sample from. Can be any kind of collection
"""
self.loaders = loaders
self._loader_iters = None
@@ -456,7 +542,6 @@ def __next__(self) -> Any:
Returns:
a collection of batch data
"""
return self.request_next_batch(self.loader_iters)

@@ -470,9 +555,23 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any:
Returns
Any: a collection of batch data
"""
return apply_to_collection(loader_iters, Iterator, next)

def next_fn(iterator: Iterator):
batch = next(iterator)
if not _fault_tolerant_enabled():
return batch
# when fault tolerant is enabled, the iterator will return
# the ``FastForwardSampler`` state_dict metadata
# alongside the user data.
# the metadata is extracted and stored directly on the iterator
# to simplify the collection on ``state_dict`` call.
batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
# store the ``sampler_state_dict`` on the iterator
CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
return batch

return apply_to_collection(loader_iters, Iterator, next_fn)
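
With fault tolerance disabled, `request_next_batch` reduces to a structured `next` over the iterator collection; a small sketch with assumed loaders:

```python
from torch.utils.data import DataLoader

loader_iters = {
    "a": iter(DataLoader(range(4), batch_size=2)),
    "b": iter(DataLoader(range(6), batch_size=3)),
}
batch = CombinedLoaderIterator.request_next_batch(loader_iters)
# -> {"a": tensor([0, 1]), "b": tensor([0, 1, 2])}
```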

@staticmethod
def create_loader_iters(
@@ -486,7 +585,6 @@ def create_loader_iters(
Returns
a collection of iterators
"""
# dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences
return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))
