diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1f7014a71d9a3..740cf5c9c9c11 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))
     * Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
     * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
+    * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
 -
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 3fa32bc72da5e..074090f10e3fe 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -394,52 +394,6 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
     return iter_dataloader
 
 
-def _dataloader_to_state_dict(
-    dataloader: DataLoader, iterator: Iterator, num_batches_processed: int = None
-) -> List[Dict[str, Any]]:
-    """Convert a dataloader to its associated state dict."""
-    out = {}
-    if iterator is not None:
-        out.update(_find_current_worker(iterator))
-
-    if not isinstance(dataloader.dataset, CaptureIterableDataset):
-        fast_forward_sampler = _find_fast_forward_samplers(dataloader)
-        if fast_forward_sampler is not None:
-            out.update(fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed))
-    return out
-
-
-def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: List[Dict[str, Any]]) -> DataLoader:
-    """Reload ``DataLoader`` fast-forward sampler state dict."""
-    fast_forward_sampler = _find_fast_forward_samplers(dataloader)
-
-    if isinstance(fast_forward_sampler, Sampler):
-        state_dict = {k: v for k, v in state_dict.items() if k not in ("num_workers", "previous_worker")}
-        fast_forward_sampler.load_state_dict(state_dict)
-
-    return dataloader
-
-
-def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]:
-    """Find the current DataLoader Iterator worker if multiple workers were used."""
-    # get the current number of workers
-    num_workers = getattr(iterator, "_num_workers", 0)
-    if isinstance(iterator, _MultiProcessingDataLoaderIter):
-        # fetch next worker
-        next_worker = (next(iterator._worker_queue_idx_cycle)) % num_workers
-        # get the current worker from next one
-        previous_worker = (next_worker - 1) % num_workers
-        # reset back the `worker_queue_idx` to current one, so we can keep
-        # going without perturbation.
-        while next(iterator._worker_queue_idx_cycle) != previous_worker:
-            pass
-    else:
-        previous_worker = None
-
-    # return the captured metadata.
-    return {"num_workers": num_workers, "previous_worker": previous_worker}
-
-
 def _capture_metadata_collate(
     samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode
 ) -> Any:
@@ -476,6 +430,52 @@ def _capture_metadata_collate(
     return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}
 
 
+# TODO: Merge this code within stateful DataLoaderIter.
+def _next_data_wrapper(
+    fn: Callable,
+    it: Iterator,
+    dl: DataLoader,
+    num_batches_fetched: int,
+    data_fetcher: "pl.utilities.fetching.AbstractDataFetcher",
+) -> Callable:
+    @wraps(fn)
+    def wrapper() -> Any:
+        nonlocal num_batches_fetched
+
+        dataset = dl.dataset
+        combined_batch = fn()
+
+        batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]
+        num_batches_fetched += 1
+
+        if isinstance(dataset, CaptureIterableDataset):
+            state = [
+                IteratorState(
+                    num_workers=dl.num_workers,
+                    sampler_state=iterator_state,
+                    num_batches_fetched=num_batches_fetched,
+                    worker_id=list(iterator_state.keys())[0],
+                    name=sampler_iter_name,
+                )
+                for sampler_iter_name, iterator_state in state.items()
+            ]
+        elif isinstance(dataset, CaptureMapDataset):
+            ff_sampler = _find_fast_forward_samplers(dl)
+            state = [
+                IteratorState(
+                    num_workers=dl.num_workers,
+                    sampler_state=ff_sampler.state_dict(num_batches_fetched),
+                    dataset_state=state,
+                    worker_id=list(state.keys())[0],
+                    num_batches_fetched=num_batches_fetched,
+                )
+            ]
+        data_fetcher._store_dataloader_iter_state(it, state)
+        return batch
+
+    return wrapper
+
+
 def patch_dataloader_iterator(
     dataloader: DataLoader,
     iterator: Iterator,
@@ -506,48 +506,9 @@ def patch_dataloader_iterator(
         return
 
     assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))
-
-    def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable:
-        @wraps(fn)
-        def wrapper():
-            nonlocal num_batches_fetched
-            nonlocal it
-            nonlocal dl
-
-            dataset = dl.dataset
-            combined_batch = fn()
-
-            batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]
-            num_batches_fetched += 1
-
-            if isinstance(dataset, CaptureIterableDataset):
-                state = [
-                    IteratorState(
-                        num_workers=dataloader.num_workers,
-                        sampler_state=iterator_state,
-                        num_batches_fetched=num_batches_fetched,
-                        worker_id=list(iterator_state.keys())[0],
-                        name=sampler_iter_name,
-                    )
-                    for sampler_iter_name, iterator_state in state.items()
-                ]
-            elif isinstance(dataset, CaptureMapDataset):
-                ff_sampler = _find_fast_forward_samplers(dl)
-                state = [
-                    IteratorState(
-                        num_workers=dataloader.num_workers,
-                        sampler_state=ff_sampler.state_dict(num_batches_fetched),
-                        dataset_state=state,
-                        worker_id=list(state.keys())[0],
-                        num_batches_fetched=num_batches_fetched,
-                    )
-                ]
-            data_fetcher._store_dataloader_iter_state(it, state)
-            return batch
-
-        return wrapper
-
-    iterator._next_data = _next_data_wrapper(iterator._next_data, iterator, dataloader, num_batches_fetched)
+    iterator._next_data = _next_data_wrapper(
+        iterator._next_data, iterator, dataloader, num_batches_fetched, data_fetcher
+    )
 
 
 def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
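
For readers following the refactor above: `patch_dataloader_iterator` now delegates to a module-level `_next_data_wrapper` that receives the `data_fetcher` explicitly instead of closing over surrounding names. A minimal, self-contained sketch of that wrapping pattern follows (plain Python with hypothetical names, not Lightning's actual API):

from functools import wraps
from typing import Any, Callable, Dict, List


def _wrap_next(fn: Callable[[], Any], store: List[Dict[str, int]], num_batches_fetched: int = 0) -> Callable[[], Any]:
    @wraps(fn)
    def wrapper() -> Any:
        nonlocal num_batches_fetched
        batch = fn()
        num_batches_fetched += 1
        # hand progress to an explicit store, mirroring how the hunk above calls
        # `data_fetcher._store_dataloader_iter_state(it, state)`
        store.append({"num_batches_fetched": num_batches_fetched})
        return batch

    return wrapper


states: List[Dict[str, int]] = []
it = iter(range(3))
next_data = _wrap_next(lambda: next(it), states)
print([next_data() for _ in range(3)])  # [0, 1, 2]
print(states[-1])                       # {"num_batches_fetched": 3}
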
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
index 5b56940460ca4..9f725c37d3f23 100644
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -24,8 +24,8 @@
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
+from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.seed import pl_worker_init_function
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -246,17 +246,8 @@ def _get_dataloader_init_kwargs(
         dl_kwargs["batch_sampler"] = None
         dl_kwargs["sampler"] = None
 
-    if _fault_tolerant_training():
-        dataset = dl_kwargs["dataset"]
-        if isinstance(dataset, IterableDataset):
-            # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
-            dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
-        elif get_len(dataset) != float("inf"):
-            dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
-        else:
-            raise MisconfigurationException(
-                "This shouldn't happen, please open an issue on Lightning Github repository."
-            )
+    if _FaultTolerantMode.detect_current_mode().is_automatic:
+        dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)
 
     return dl_kwargs
 
@@ -271,6 +262,7 @@ def _dataloader_init_kwargs_resolve_sampler(
     Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
     `FastForwardSampler`.
     """
+    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
     batch_sampler = getattr(dataloader, "batch_sampler")
     is_predicting = mode == RunningStage.PREDICTING
     # checking the batch sampler type is different than PyTorch default.
@@ -283,7 +275,7 @@ def _dataloader_init_kwargs_resolve_sampler(
         if is_predicting:
             batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
 
-        if _fault_tolerant_training():
+        if fault_tolerant_mode.is_automatic:
            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)
 
@@ -295,7 +287,7 @@ def _dataloader_init_kwargs_resolve_sampler(
             "drop_last": False,
         }
 
-    if _fault_tolerant_training():
+    if fault_tolerant_mode.is_automatic:
         fast_forward_sampler = sampler = FastForwardSampler(sampler)
         fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)
 
@@ -305,3 +297,15 @@
 def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
     if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
         dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
+
+
+def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
+    dataset = dl_kwargs["dataset"]
+    if isinstance(dataset, IterableDataset):
+        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
+        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
+    elif get_len(dataset) != float("inf"):
+        dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
+    else:
+        raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
+    return dl_kwargs
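
The hunks above swap the boolean `_fault_tolerant_training()` check for `_FaultTolerantMode.detect_current_mode().is_automatic`, so callers can distinguish automatic from manual fault tolerance. The enum itself is not part of this diff; the sketch below is only an assumed shape for illustration, and the environment-variable name and accepted values are guesses:

import os
from enum import Enum


class FaultTolerantMode(Enum):
    DISABLED = "disabled"
    AUTOMATIC = "automatic"
    MANUAL = "manual"

    @property
    def is_automatic(self) -> bool:
        return self is FaultTolerantMode.AUTOMATIC

    @classmethod
    def detect_current_mode(cls) -> "FaultTolerantMode":
        # assumed env var name for this sketch only
        value = os.environ.get("PL_FAULT_TOLERANT_TRAINING", "disabled").lower()
        return cls(value) if value in {m.value for m in cls} else cls.DISABLED


if FaultTolerantMode.detect_current_mode().is_automatic:
    print("would apply the automatic capture dataset wrapper here")
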
- """ - - def create_dataloader(): - dataset = range(50) - batch_size = 8 - sampler = FastForwardSampler(SequentialSampler(dataset)) - sampler.setup(batch_size) - - return DataLoader(dataset, sampler=sampler, batch_size=batch_size) - - dataloader = create_dataloader() - iter_dataloader = iter(dataloader) - _ = next(iter_dataloader) - _ = next(iter_dataloader) - - state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict == { - "num_workers": 0, - "previous_worker": None, - 0: {"current_iteration": 16}, - } - - dataloader = create_dataloader() - dataloader = _dataloader_load_state_dict(dataloader, state_dict) - iter_dataloader = iter(dataloader) - _ = next(iter_dataloader) - - state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict == { - "num_workers": 0, - "previous_worker": None, - 0: {"current_iteration": 24}, - } - - @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir): """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled."""