diff --git a/CHANGELOG.md b/CHANGELOG.md index 194be8e9d6d6b..00aa84706b7e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Progress tracking - * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) + * Added dataclasses for progress tracking([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574)) + * Integrate progress tracking with the training loops ([#7976](https://github.com/PyTorchLightning/pytorch-lightning/pull/7976)) * Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) * Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244)) @@ -137,6 +138,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `restore` function and `restarting` attribute to base `Loop` ([#8247](https://github.com/PyTorchLightning/pytorch-lightning/pull/8247)) +- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307)) + + ### Changed diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 3572a79b9bd84..9019d1d695556 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -211,6 +211,7 @@ def closure_dis(): profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs) + self._trainer.fit_loop.epoch_loop.batch_loop.optim_progress.optimizer.step.increment_processed() self._total_optimizer_step_calls += 1 def __repr__(self): diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 3293b3eba29ab..bbf5d61e0d48a 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -13,11 +13,12 @@ # limitations under the License. 
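The `increment_processed()` call added to `LightningOptimizer` above relies on the progress-tracking counters that the changelog entries reference: every tracked event carries `ready`/`started`/`processed`/`completed` counts. A minimal standalone sketch of that counter pattern (a toy `ToyTracker`, not the actual `pytorch_lightning.trainer.progress.Tracker`):

```python
from dataclasses import asdict, dataclass


@dataclass
class ToyTracker:
    """Counts how far an event got: ready -> started -> processed -> completed."""
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def increment_processed(self) -> None:
        self.processed += 1

    def state_dict(self) -> dict:
        return asdict(self)

    def load_state_dict(self, state_dict: dict) -> None:
        self.__dict__.update(state_dict)


step = ToyTracker()
step.ready += 1              # the step was scheduled
step.started += 1            # the step began executing
step.increment_processed()   # the wrapped optimizer step returned
step.completed += 1          # bookkeeping after the step finished
assert step.state_dict() == {"ready": 1, "started": 1, "processed": 1, "completed": 1}
```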
from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, OrderedDict from deprecate import void import pytorch_lightning as pl +from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -46,7 +47,44 @@ class Loop(ABC): def __init__(self) -> None: self.iteration_count: int = 0 self.trainer: Optional['pl.Trainer'] = None + self._cached_state: Optional[Dict] = None self.restarting = False + self._loops = OrderedDict() + self._progress = OrderedDict() + + def __setattr__(self, name: str, value: Any) -> None: + if isinstance(value, Loop): + self._loops[name] = value + elif isinstance(value, BaseProgress): + self._progress[name] = value + else: + object.__setattr__(self, name, value) + + def __getattr__(self, name) -> Any: + loops = self.__dict__.get('_loops') + if loops is None: + raise MisconfigurationException("The Loop wasn't called parent `__init__` function.") + + if name in loops: + return loops[name] + + progress = self.__dict__.get('_progress') + + if name in progress: + return progress[name] + + if name not in self.__dict__: + raise AttributeError(f"{self.__class__.__name__} Loop doesn't have attribute {name}.") + + return self.__dict__[name] + + def __delattr__(self, name) -> None: + if name in self._loops: + del self._loops[name] + elif name in self._progress: + del self._progress[name] + else: + object.__delattr__(self, name) @property @abstractmethod @@ -89,7 +127,8 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: return self.on_skip() if self.restarting: - self.restore() + self.restore(self._cached_state) + self._cached_state = None self.restarting = False else: self.reset() @@ -108,7 +147,8 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: output = self.on_run_end() return output - def restore(self) -> None: + @abstractmethod + def restore(self, state: Optional[Dict] = None) -> None: """Restore the internal state of the loop the beginning of run if restarting is ``True``.""" @abstractmethod @@ -142,9 +182,43 @@ def on_run_end(self) -> Any: def teardown(self) -> None: """Use to release memory etc.""" - def load_state_dict(self, state_dict: Dict) -> None: - """Restore the loop state from the provided state_dict.""" - + @abstractmethod def state_dict(self) -> Dict: - """Return the loop current states.""" - return {} + """Current Loop state""" + + def get_state_dict(self, destination: Optional[OrderedDict] = None, prefix: Optional[str] = '') -> OrderedDict: + if destination is None: + destination = OrderedDict() + + destination[prefix + "state_dict"] = self.state_dict() + + for name, progress in self._progress.items(): + destination[prefix + name] = progress.state_dict() + + for name, loop in self._loops.items(): + loop.get_state_dict(destination, prefix + name + '.') + return destination + + def _load_from_state_dict(self, state_dict, prefix, strict, missing_keys, unexpected_keys, error_msgs): + self._cached_state = state_dict[prefix + "state_dict"] + + for name, progress in self._progress.items(): + progress.load_state_dict(state_dict[prefix + name]) + + def load_state_dict(self, state_dict: Dict, strict: bool = True): + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + + state_dict = state_dict.copy() + + def load(loop, prefix=''): + loop._load_from_state_dict(state_dict, prefix, True, missing_keys, unexpected_keys, error_msgs) + loop.restarting = True + for name, loop_children in loop._loops.items(): + 
if loop_children is not None: + load(loop_children, prefix + name + '.') + + load(self) + load = None # break load->load reference cycle diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 9b803a2790d9d..938e19f2f4bae 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -69,10 +69,8 @@ def connect( ) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - if optim_progress is not None: - self.optim_progress = optim_progress + self.progress = progress or self.progress + self.optim_progress = optim_progress or self.optim_progress @property def done(self) -> bool: @@ -98,6 +96,8 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") return AttributeDict(signal=0, training_step_output=[[]]) + self.progress.increment_ready() + # hook self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") @@ -114,12 +114,23 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: self.batch_outputs = None # free memory return output - def reset(self) -> None: + def _initialize(self): """Resets the loop state""" self._hiddens = None self.batch_idx = 0 self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] + def restore(self) -> None: + """Restore the loop state""" + self._initialize() + + def reset(self) -> None: + """Resets the loop state""" + self._initialize() + + # reset tracking + self.optim_progress.optimizer.reset_on_epoch() + def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): """Splits the data into tbptt splits @@ -131,6 +142,14 @@ def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): void(batch_idx, dataloader_idx) self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch))) + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + self.progress.increment_started() + return super().on_advance_start(*args, **kwargs) + + def on_advance_end(self) -> None: + self.progress.increment_completed() + return super().on_advance_end() + def advance(self, batch, batch_idx, dataloader_idx): """Runs the train step together with optimization (if necessary) on the current batch split @@ -148,7 +167,17 @@ def advance(self, batch, batch_idx, dataloader_idx): self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) if self.trainer.lightning_module.automatic_optimization: - for opt_idx, optimizer in self.get_active_optimizers(batch_idx): + active_optimizers = self.get_active_optimizers(batch_idx) + for opt_idx, optimizer in active_optimizers: + + # handle optimization restart + if self.trainer.is_restarting: + if len(active_optimizers) > 1 and opt_idx < self.progress.current.completed: + continue + + # track optimizer_idx + self.optim_progress.optimizer_idx = opt_idx + result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: self.batch_outputs[opt_idx].append(result.training_step_output) @@ -158,6 +187,8 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output) + self.progress.increment_processed() + def teardown(self) -> None: # release memory 
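The `get_state_dict`/`load_state_dict` pair added to the base `Loop` earlier in this diff flattens nested loop state into one dict keyed by dotted prefixes, much like `torch.nn.Module.state_dict`. A self-contained sketch of that aggregation with hypothetical `ToyLoop` objects (not the PR's `Loop` class):

```python
from collections import OrderedDict


class ToyLoop:
    def __init__(self, **children: "ToyLoop") -> None:
        self.children = children
        self.counter = 0

    def state_dict(self) -> dict:
        return {"counter": self.counter}

    def get_state_dict(self, destination=None, prefix=""):
        destination = OrderedDict() if destination is None else destination
        destination[prefix + "state_dict"] = self.state_dict()
        for name, child in self.children.items():
            child.get_state_dict(destination, prefix + name + ".")
        return destination

    def load_state_dict(self, state_dict, prefix=""):
        self.counter = state_dict[prefix + "state_dict"]["counter"]
        for name, child in self.children.items():
            child.load_state_dict(state_dict, prefix + name + ".")


fit = ToyLoop(epoch_loop=ToyLoop(batch_loop=ToyLoop()))
fit.children["epoch_loop"].children["batch_loop"].counter = 7
flat = fit.get_state_dict()
# keys: 'state_dict', 'epoch_loop.state_dict', 'epoch_loop.batch_loop.state_dict'
restored = ToyLoop(epoch_loop=ToyLoop(batch_loop=ToyLoop()))
restored.load_state_dict(flat)
assert restored.children["epoch_loop"].children["batch_loop"].counter == 7
```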
self._remaining_splits = None @@ -238,8 +269,14 @@ def _training_step_and_backward_closure( """ result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + if result is not None: return_result.update(result) + + # this should be done only if result.loss exists + if not self.should_accumulate(): + self.optim_progress.optimizer.step.increment_started() + return return_result.loss def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable: @@ -409,6 +446,8 @@ def _optimizer_step( # wraps into LightningOptimizer only for running step optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) + self.optim_progress.optimizer.step.increment_ready() + # model hook model_ref.optimizer_step( self.trainer.current_epoch, @@ -421,13 +460,17 @@ def _optimizer_step( using_lbfgs=is_lbfgs, ) + self.optim_progress.optimizer.step.increment_completed() + def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: """Calls the ``on_before_zero_grad`` hook. Args: optimizer: the current optimizer """ + self.optim_progress.optimizer.zero_grad.increment_started() self.trainer.call_hook('on_before_zero_grad', optimizer) + self.optim_progress.optimizer.zero_grad.increment_ready() def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None: """Zeroes out all gradients of parameters optimized by the current optimizer. @@ -437,8 +480,11 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, optimizer: the current optimizer opt_idx: the index of the current optimizer """ + self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + self.optim_progress.optimizer.zero_grad.increment_completed() + def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]: """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. @@ -700,3 +746,12 @@ def _truncated_bptt_steps(self) -> int: if lightning_module.truncated_bptt_steps > 0: return lightning_module.truncated_bptt_steps return self.trainer.truncated_bptt_steps or 0 + + def state_dict(self) -> Dict: + return {"progress": self.progress.state_dict(), "optim_progress": self.optim_progress.state_dict()} + + def load_state_dict(self, state_dict: Dict) -> None: + if "progress" in state_dict: + self.progress.load_state_dict(state_dict['progress']) + if "optim_progress" in state_dict: + self.optim_progress.load_state_dict(state_dict['optim_progress']) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 2f6e14b93b767..353dc4f7b9178 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
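Taken together, the hooks above bracket every automatic-optimization step with progress updates: the step tracker is marked ready just before the `optimizer_step` hook, started inside the closure once it is known the step will not be accumulated, processed when `LightningOptimizer` finishes, and completed when control returns to the batch loop; the `zero_grad` tracker moves through started/ready/completed around `on_before_zero_grad` and the actual gradient reset. A condensed, hypothetical trace using plain dicts in place of the trackers:

```python
# hypothetical condensed trace of one automatic-optimization step; plain dicts
# stand in for the progress Trackers wired in above
step = dict(ready=0, started=0, processed=0, completed=0)
zero_grad = dict(ready=0, started=0, completed=0)

step["ready"] += 1      # just before the model's `optimizer_step` hook
step["started"] += 1    # inside the closure, once the step is known not to be accumulated
step["processed"] += 1  # `LightningOptimizer.step` finished running the wrapped step
step["completed"] += 1  # back in the batch loop, after the hook returned

zero_grad["started"] += 1    # before the `on_before_zero_grad` hook
zero_grad["ready"] += 1      # hook returned, gradients about to be zeroed
zero_grad["completed"] += 1  # `optimizer_zero_grad` returned
```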
-from typing import Any, List, Optional, Sequence, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Union from deprecate.utils import void from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl +from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import EpochLoopProgress +from pytorch_lightning.trainer.progress import EvaluationEpochLoopProgress, Tracker from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.auto_restart import dataloader_load_state_dict, dataloader_to_state_dict from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -33,13 +36,23 @@ class EvaluationLoop(DataLoaderLoop): def __init__(self): super().__init__() self.outputs = [] - self.progress = EpochLoopProgress() + self.progress = EvaluationEpochLoopProgress() self.epoch_loop = EvaluationEpochLoop() - self._results = ResultCollection(training=False) self._max_batches: Optional[Union[int, Sequence[int]]] = None self._has_run: bool = False + self.current_dataloader_iter: Optional[Iterator] = None + self.has_called = False + + @property + def state(self): + state = {"should_check_val": self.progress.should_check_val} + if self.current_dataloader_iter is not None: + dataloader = self.dataloaders[self.progress.epoch.dataloader_idx] + state.update(dataloader=dataloader_to_state_dict(dataloader, self.current_dataloader_iter)) + + return state @property def num_dataloaders(self) -> int: @@ -67,12 +80,17 @@ def predictions(self): return self.epoch_loop.predictions def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any + self, + trainer: "pl.Trainer", + *args: Any, + epoch_loop: Optional[Loop] = None, + progress: Optional[EvaluationEpochLoopProgress] = None, + **kwargs: Any ) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress + self.progress = progress or self.progress + self.epoch_loop = epoch_loop or self.epoch_loop self.epoch_loop.connect(trainer, progress=self.progress.epoch) @property @@ -86,8 +104,7 @@ def skip(self) -> bool: max_batches = self.get_max_batches() return sum(max_batches) == 0 - def reset(self) -> None: - """Resets the internal state of the loop""" + def _initialize(self): self.iteration_count = 0 self._max_batches = self.get_max_batches() # bookkeeping @@ -96,27 +113,44 @@ def reset(self) -> None: if isinstance(self._max_batches, int): self._max_batches = [self._max_batches] * len(self.dataloaders) + def restore(self) -> None: + self._initialize() + + self.iteration_count = self.progress.epoch.dataloader_idx + + breakpoint() + + def reset(self) -> None: + """Resets the internal state of the loop""" + self._initialize() + + # reset batch / epoch progress tracking + self.progress.reset_on_epoch() + def on_skip(self) -> List: return [] def on_run_start(self, *args: Any, **kwargs: Any) -> None: """Runs the ``on_evaluation_model_eval``, ``on_evaluation_start`` and ``on_evaluation_epoch_start`` hooks""" void(*args, **kwargs) + self.progress.epoch.increment_started() + 
# hook self.on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() self.on_evaluation_start() self.on_evaluation_epoch_start() + self.progress.epoch.increment_ready() + def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader_iter = enumerate(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.epoch_loop.run( - dataloader_iter, + self.enumerate(dataloader), self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, @@ -141,6 +175,8 @@ def on_run_end(self) -> Any: if len(outputs) > 0 and self.num_dataloaders == 1: outputs = outputs[0] + self.progress.epoch.increment_processed() + # lightning module method self.evaluation_epoch_end(outputs) @@ -159,6 +195,8 @@ def on_run_end(self) -> Any: # enable train mode again self.on_evaluation_model_train() + self.progress.epoch.increment_completed() + return eval_loop_results def get_max_batches(self) -> List[Union[int, float]]: @@ -273,3 +311,33 @@ def on_evaluation_epoch_end(self) -> None: def teardown(self) -> None: self._results.cpu() self.epoch_loop.teardown() + + def enumerate(self, dataloader: DataLoader) -> Any: + self.current_dataloader_iter = iter(dataloader) + if self.has_called: + breakpoint() + for batch_idx, batch in enumerate(self.current_dataloader_iter, self.progress.epoch.batch.current.completed): + yield batch_idx, batch + + def state_dict(self) -> Dict: + return self.state + + def load_state_dict(self, state_dict: Dict) -> None: + if "should_check_val" in state_dict: + self.progress.should_check_val = state_dict["should_check_val"] + + if "epoch_loop" in state_dict: + self.epoch_loop.load_state_dict(state_dict["epoch_loop"]) + + if "dataloader" in state_dict: + self.reload_evaluation_dataloaders() + current_dataloader = self.dataloaders[self.progress.epoch.dataloader_idx] + dataloader_load_state_dict(current_dataloader, state_dict["dataloader"]) + + def fn(v: Tracker): + v.reset_on_restart() + + apply_to_collection(self.progress, Tracker, fn) + + self.has_called = True + breakpoint() diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index c01b20a5f84e2..d36533e308950 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -43,6 +43,10 @@ def __init__(self) -> None: self.outputs: List[STEP_OUTPUT] = [] self.progress = EpochProgress() + @property + def state(self): + return {} + def connect( self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any ) -> None: @@ -56,8 +60,7 @@ def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" return self.iteration_count >= self.dl_max_batches - def reset(self) -> None: - """Resets the loop's internal state.""" + def _initialize(self): self.iteration_count = 0 self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) self.dl_max_batches = None @@ -65,6 +68,18 @@ def reset(self) -> None: self.num_dataloaders = None self.outputs = [] + def restore(self): + """Restore the loop's internal state.""" + self._initialize() + + self.iteration_count = self.progress.batch.current.completed + + def reset(self) -> None: + """Resets the loop's internal state.""" + self._initialize() + + 
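The `enumerate` helper added to `EvaluationLoop` above starts numbering at `progress.epoch.batch.current.completed`, so batch indices stay consistent when a partially evaluated dataloader is restored. The underlying idiom is simply `enumerate(iterator, start)`; a standalone sketch, assuming the iterator itself has already been fast-forwarded past the processed samples (hypothetical `resume_enumerate` helper):

```python
def resume_enumerate(iterable, completed_batches: int):
    """Yield (batch_idx, batch), continuing the numbering after a restart."""
    # assumes `iterable` already skips the batches that were fully processed
    yield from enumerate(iterable, start=completed_batches)


remaining = iter([b"batch-3", b"batch-4"])        # e.g. produced by a fast-forwarded sampler
for batch_idx, batch in resume_enumerate(remaining, completed_batches=3):
    print(batch_idx, batch)                       # 3 b'batch-3', then 4 b'batch-4'
```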
self.progress.batch.current.reset() + def on_run_start( self, dataloader_iter: Iterator, @@ -85,6 +100,7 @@ def on_run_start( self.dl_max_batches = dl_max_batches self.dataloader_idx = dataloader_idx self.num_dataloaders = num_dataloaders + self.progress.dataloader_idx = dataloader_idx def advance( self, @@ -108,20 +124,29 @@ def advance( batch_idx, batch = next(dataloader_iter) + print() + print("BATCH_IDX", dataloader_idx, batch_idx) + if batch is None: raise StopIteration + self.progress.batch.increment_started() + with self.trainer.profiler.profile("evaluation_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) # hook self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) + self.progress.batch.increment_ready() + # lightning module methods with self.trainer.profiler.profile("evaluation_step_and_end"): output = self.evaluation_step(batch, batch_idx, dataloader_idx) output = self.evaluation_step_end(output) + self.progress.batch.increment_processed() + # hook + store predictions self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) @@ -131,6 +156,8 @@ def advance( # track epoch level outputs self.outputs = self._track_output_for_epoch_end(self.outputs, output) + self.progress.batch.increment_completed() + def on_run_end(self) -> List[STEP_OUTPUT]: """Returns the outputs of the whole run""" outputs = self.outputs @@ -267,3 +294,10 @@ def _track_output_for_epoch_end( output = output.cpu() outputs.append(output) return outputs + + def state_dict(self) -> Dict: + return {"progress": self.progress.state_dict()} + + def load_state_dict(self, state_dict: Dict) -> None: + if "progress" in state_dict: + self.progress.load_state_dict(state_dict["progress"]) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index bc378c6bed0fb..821bc9d6d73e3 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -21,6 +21,7 @@ from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import TrainingEpochProgress +from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -54,19 +55,28 @@ def __init__(self, min_steps: int, max_steps: int): self._dataloader_idx: Optional[int] = None self._warning_cache: WarningCache = WarningCache() self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None + self._map_dl_idx_sampler_states = {} @property def batch_idx(self) -> int: """Returns the current batch index (within this epoch)""" return self.iteration_count + @property + def total_optimizer_step(self) -> int: + return self.progress.optim.optimizer.step.total.completed + + @property + def current_batch_seen(self) -> int: + return self.progress.batch.current.completed + @property def done(self) -> bool: """Returns whether the training should be stopped. The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer signals to stop (e.g. by early stopping). 
""" - max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps + max_steps_reached = self.max_steps is not None and (self.total_optimizer_step) >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) def connect( @@ -78,13 +88,11 @@ def connect( ) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress + self.progress = progress or self.progress self.batch_loop.connect(trainer, progress=self.progress.batch, optim_progress=self.progress.optim) self.val_loop.connect(trainer, progress=self.progress.val) - def reset(self) -> None: - """Resets the internal state of the loop for a new run""" + def _initialize(self) -> None: self.iteration_count = 0 self.batches_seen = 0 self.is_last_batch = False @@ -93,12 +101,30 @@ def reset(self) -> None: # track epoch output self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] + def restore(self) -> None: + """Restore the internal state of the loop for a new run""" + self._initialize() + + self.iteration_count = self.progress.total.completed + self.batches_seen = self.progress.total.completed + + def reset(self) -> None: + """Resets the internal state of the loop for a new run""" + self._initialize() + + # reset tracking + self.progress.reset_on_epoch() + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + self.progress.increment_ready() + # hook self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") + self.progress.increment_started() + def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: """Runs a single training batch. @@ -109,6 +135,12 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: StopIteration: When the epoch is canceled by the user returning -1 """ _, (batch, is_last) = next(dataloader_iter) + batch = self._sanetize_batch(batch) + + print() + print(self.trainer.current_epoch, self.batches_seen) + print() + self.is_last_batch = is_last # ------------------------------------ @@ -159,11 +191,20 @@ def on_advance_end(self): # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- should_check_val = self._should_check_val_fx(self.iteration_count, self.is_last_batch) + + if self.trainer.is_restarting: + should_check_val = self.progress.val.should_check_val + else: + self.progress.val.should_check_val = should_check_val + if should_check_val: self.trainer.validating = True self._run_validation() self.trainer.training = True + # inform trainer that restart is completed + self.trainer.is_restarting = False + # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
# ----------------------------------------- @@ -216,11 +257,15 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: 'HINT: remove the return statement in training_epoch_end' ) + self.progress.increment_processed() + # call train epoch end hooks self._on_train_epoch_end_hook(processed_outputs) self.trainer.call_hook('on_epoch_end') self.trainer.logger_connector.on_epoch_end() + self.progress.increment_completed() + epoch_output = self._epoch_output # free memory self._epoch_output = None @@ -431,9 +476,41 @@ def _save_loggers_on_train_batch_end(self) -> None: if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() + def _sanetize_batch(self, batch: Any) -> Any: + if isinstance(batch, Dict) and AutoRestartBatchKeys.PL_SAMPLERS in batch: + current_iterations = { + k: { + batch[AutoRestartBatchKeys.PL_SAMPLERS]["id"][-1].item(): { + "current_iteration": v["current_iteration"][-1].item(), + "rng_state": None + } + } + for k, v in batch[AutoRestartBatchKeys.PL_SAMPLERS].items() if k != "id" + } + if self._dataloader_idx not in self._map_dl_idx_sampler_states: + self._map_dl_idx_sampler_states[self._dataloader_idx] = current_iterations + + for k in current_iterations.keys(): + self._map_dl_idx_sampler_states[self._dataloader_idx][k].update(current_iterations[k]) + return batch["data"] + + return batch + def state_dict(self) -> Dict: - return {"batch_loop": self.batch_loop.state_dict(), "val_loop": self.val_loop.state_dict()} + progress = { + "total": self.progress.total.state_dict(), + "current": self.progress.current.state_dict(), + "dataloader_idx": self.progress.dataloader_idx, + } + return { + "batch_loop": self.batch_loop.state_dict(), + "val_loop": self.val_loop.state_dict(), + "progress": progress, + } def load_state_dict(self, state_dict: Dict) -> None: self.batch_loop.load_state_dict(state_dict["batch_loop"]) self.val_loop.load_state_dict(state_dict["val_loop"]) + self.progress.total.load_state_dict(state_dict["progress"]["total"]) + self.progress.current.load_state_dict(state_dict["progress"]["current"]) + self.progress.dataloader_idx = state_dict["progress"]["dataloader_idx"] diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a8eb44923a241..1272c5103ebe2 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -20,9 +20,10 @@ from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import FitLoopProgress +from pytorch_lightning.trainer.progress import FitLoopProgress, Tracker from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities.apply_func import apply_to_collection log = logging.getLogger(__name__) @@ -52,7 +53,6 @@ def __init__( self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs self.progress = FitLoopProgress() - self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) @property @@ -113,6 +113,16 @@ def max_steps(self, value: int) -> None: # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided self.epoch_loop.max_steps = value + @property + def total_epoch_completed(self) -> int: + """Returns the total number of 
epoch completed""" + return self.progress.epoch.total.completed + + @property + def total_optimizer_step_completed(self) -> int: + """Returns the total number of optimizer step completed""" + return self.progress.epoch.optim.optimizer.step.total.completed + @property def running_loss(self) -> TensorRunningAccum: """Returns the running loss""" @@ -144,14 +154,14 @@ def done(self) -> bool: or if the maximum number of steps or epochs is reached. """ # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop - stop_steps = self.max_steps is not None and self.global_step >= self.max_steps - stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs + stop_steps = self.max_steps is not None and self.total_optimizer_step_completed >= self.max_steps + stop_epochs = self.max_epochs is not None and self.total_epoch_completed >= self.max_epochs should_stop = False if self.trainer.should_stop: # early stopping - met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True - met_min_steps = self.global_step >= self.min_steps if self.min_steps else True + met_min_epochs = self.total_epoch_completed >= self.min_epochs if self.min_epochs else True + met_min_steps = self.total_optimizer_step_completed >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: should_stop = True else: @@ -161,7 +171,6 @@ def done(self) -> bool: ' not been met. Training will continue...' ) self.trainer.should_stop = should_stop - return stop_steps or should_stop or stop_epochs @property @@ -174,8 +183,7 @@ def connect( ) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress + self.progress = progress or self.progress self.epoch_loop.connect(trainer, progress=self.progress.epoch) def reset(self) -> None: @@ -183,6 +191,10 @@ def reset(self) -> None: def on_run_start(self) -> None: """Calls the ``on_train_start`` hook.""" + + # reset current epoch counter to 0 + self.progress.epoch.current.reset() + self._results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") @@ -290,10 +302,18 @@ def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False) cb.on_validation_end(self.trainer, model) def state_dict(self) -> Dict: - return {"epoch_loop": self.epoch_loop.state_dict()} + return {"epoch_loop": self.epoch_loop.state_dict(), "dataloader": self.trainer.train_dataloader.state_dict()} def load_state_dict(self, state_dict: Dict) -> None: self.epoch_loop.load_state_dict(state_dict["epoch_loop"]) + # todo (tchaton) Can we avoid creating the dataloader there ? 
+ self.trainer.reset_train_dataloader(self.trainer.lightning_module) + self.trainer.train_dataloader.load_state_dict(state_dict["dataloader"]) + + def fn(v: Tracker): + v.reset_on_restart() + + apply_to_collection(self.progress, Tracker, fn) def teardown(self) -> None: self.epoch_loop.teardown() diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 8e4f4c0694e67..c8cc09c1f605e 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -165,6 +165,17 @@ def setup_environment(self) -> None: self.setup_distributed() + # share ddp pids to all processes + self.share_pids() + + def share_pids(self): + self.barrier() + pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device)) + pids = ','.join(str(pid) for pid in pids.cpu().numpy().tolist()) + os.environ["PL_INTERACTIVE_DDP_PROCS"] = pids + print(os.environ["PL_INTERACTIVE_DDP_PROCS"]) + self.barrier() + def _call_children_scripts(self): # bookkeeping of spawned processes assert self.local_rank == 0 @@ -178,6 +189,7 @@ def _call_children_scripts(self): # allow the user to pass the node rank os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) + os.environ["PL_TMPDIR"] = tempfile.mkdtemp() # create a temporary directory used to synchronize processes on deadlock. os.environ["PL_DDP_SYNC_TMPDIR"] = self._sync_dir = tempfile.mkdtemp() diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ab74c3bccfc8d..21b32da974968 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -20,9 +20,12 @@ import torch import pytorch_lightning as pl +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.trainer.supporters import CombinedLoaderIterator from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import fault_tolerant_enabled from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: @@ -141,6 +144,13 @@ def restore_model(self) -> None: # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) + gradients = self._loaded_checkpoint.get("gradients", None) + if gradients: + for name, param in model.named_parameters(): + grad = gradients.pop(name, None) + if grad is not None: + param.grad = grad + def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None: """ Restore only the model weights. """ checkpoint = self._loaded_checkpoint @@ -160,9 +170,12 @@ def restore_training_state(self) -> None: # restore precision plugin (scaler etc.) self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint) + # restore progress (loops etc.) 
self.restore_progress() + self.restore_loops() + self.restore_optimizers_and_schedulers() def restore_callbacks(self) -> None: @@ -187,9 +200,6 @@ def restore_progress(self) -> None: if not self._loaded_checkpoint: return - self.trainer.fit_loop.global_step = self._loaded_checkpoint['global_step'] - self.trainer.fit_loop.current_epoch = self._loaded_checkpoint['epoch'] - # crash if max_epochs is lower then the current epoch from the checkpoint if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs: raise MisconfigurationException( @@ -209,6 +219,26 @@ def restore_progress(self) -> None: " consider using an end of epoch checkpoint." ) + def restore_loops(self): + if not self._loaded_checkpoint: + return + + state_dict = self._loaded_checkpoint.get("loops", None) + if state_dict: + self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) + self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) + self.trainer.test_loop.load_state_dict(state_dict["test_loop"]) + self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"]) + + self.trainer.fit_loop.restarting = True + self.trainer.fit_loop.epoch_loop.restarting = True + self.trainer.fit_loop.epoch_loop.batch_loop.restarting = True + self.trainer.fit_loop.epoch_loop.val_loop.restarting = True + self.trainer.validate_loop.restarting = True + self.trainer.test_loop.restarting = True + self.trainer.predict_loop.restarting = True + self.trainer.is_restarting = True + def restore_optimizers_and_schedulers(self) -> None: """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """ if not self._loaded_checkpoint: @@ -317,12 +347,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump epoch/global_step/pytorch-lightning_version current_epoch = self.trainer.current_epoch global_step = self.trainer.global_step - has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step - - global_step += 1 - if not has_reached_max_steps: - current_epoch += 1 - model = self.trainer.lightning_module checkpoint = { @@ -332,12 +356,21 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), } + if fault_tolerant_enabled(): + checkpoint["loops"] = self.get_loops_state_dict() + # checkpoint.update({ + # 'progress': self.get_progress_state_dict(), + # 'samplers': self.get_samplers_state_dict(), + # 'gradients': self.get_gradients_state_dict(), + # 'current_workers': self.get_current_worker(), + # }) + if not weights_only: # dump callbacks checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint) optimizer_states = [] - for i, optimizer in enumerate(self.trainer.optimizers): + for _, optimizer in enumerate(self.trainer.optimizers): # Rely on accelerator to dump optimizer state optimizer_state = self.trainer.accelerator.optimizer_state(optimizer) optimizer_states.append(optimizer_state) @@ -370,6 +403,27 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint + def get_gradients_state_dict(self): + return {n: p.grad for n, p in self.trainer.lightning_module.named_parameters()} + + def get_current_worker(self): + iter = self.trainer.current_iterator + if isinstance(iter, CombinedLoaderIterator): + return self.trainer.train_dataloader.state_dict() + else: + raise NotImplementedError + + def get_loops_state_dict(self): + return { + "fit_loop": self.trainer.fit_loop.state_dict(), + "validate_loop": 
self.trainer.validate_loop.state_dict(), + "test_loop": self.trainer.test_loop.state_dict(), + "predict_loop": self.trainer.predict_loop.state_dict(), + } + + def get_progress_state_dict(self): + return {TrainerFn.FITTING.value: self.trainer.fit_loop.progress.state_dict()} + def hpc_load(self, checkpoint_path: str) -> None: """ Attempts to restore the full training and model state from a HPC checkpoint file. diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 92019edbeff56..e5144eb08877a 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -63,7 +63,7 @@ def on_trainer_init( def get_profiled_train_dataloader(self, train_dataloader): profiled_dl = self.trainer.profiler.profile_iterable( - enumerate(prefetch_iterator(train_dataloader)), "get_train_batch" + enumerate(prefetch_iterator(train_dataloader, trainer=self.trainer)), "get_train_batch" ) return profiled_dl diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 06ae55a1ca672..a71356710b5a7 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -15,6 +15,7 @@ from weakref import proxy import pytorch_lightning as pl +from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -48,6 +49,8 @@ def update_learning_rates( if opt_indices is None: opt_indices = [] + progress: OptimizationProgress = self.trainer.fit_loop.epoch_loop.batch_loop.optim_progress + for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers): if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices: continue @@ -83,11 +86,15 @@ def update_learning_rates( # update LR old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + progress.scheduler.increment_ready() + if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) else: lr_scheduler['scheduler'].step() + progress.scheduler.increment_completed() + new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] if self.trainer.dev_debugger.enabled: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index ce6caa4e2f330..61909a765f453 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -19,6 +19,7 @@ from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +import torch from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -30,6 +31,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, FastForwardSampler from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -111,6 +113,41 @@ def auto_add_worker_init_fn(self, dataloader: 
DataLoader) -> None: if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank) + def add_samplers_to_iterable_dataset(self, dataloader: DataLoader): + skip_keys = ('sampler', 'batch_sampler', 'dataset_kind') + + attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")} + + params = set(inspect.signature(dataloader.__init__).parameters) + contains_dataset = True + + if type(dataloader) is not DataLoader: + contains_dataset = "dataset" in params + params.update(inspect.signature(DataLoader.__init__).parameters) + + dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys} + + multiprocessing_context = dataloader.multiprocessing_context + dl_args['multiprocessing_context'] = multiprocessing_context + + if not contains_dataset: + dl_args.pop('dataset') + + seed = int(os.getenv("PL_GLOBAL_SEED", 0)) + self.current_epoch + + if dl_args.get("generator") is None: + dl_args["generator"] = torch.Generator().manual_seed(seed) + + if 'dataset' in dl_args: + dl_args["dataset"] = CaptureIterableDataset( + dataset=dl_args["dataset"], + initial_seed=dl_args["generator"].initial_seed(), + ) + + dataloader = type(dataloader)(**dl_args) + dataloader.multiprocessing_context = multiprocessing_context + return dataloader + def auto_add_sampler( self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None ) -> DataLoader: @@ -123,12 +160,16 @@ def auto_add_sampler( dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle) return dataloader - if not is_dataloader or is_iterable_ds: + if is_iterable_ds: + return self.add_samplers_to_iterable_dataset(dataloader) + + if not is_dataloader: return dataloader need_dist_sampler = self.accelerator_connector.is_distributed and not isinstance( dataloader.sampler, DistributedSampler ) + if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler: if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( @@ -137,10 +178,12 @@ def auto_add_sampler( ' distributed training. Either remove the sampler from your DataLoader or set' ' `replace_sampler_ddp`=False if you want to use your custom sampler.' ) - - # replace with distributed sampler sampler = self._get_distributed_sampler(dataloader, shuffle, mode=mode) - dataloader = self.replace_sampler(dataloader, sampler, mode=mode) + else: + # use current sampler + sampler = dataloader.sampler + + dataloader = self.replace_sampler(dataloader, sampler, mode=mode) return dataloader @@ -149,6 +192,7 @@ def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningS batch_sampler = getattr(dataloader, "batch_sampler") is_predicting = mode == RunningStage.PREDICTING # checking the batch sampler type is different than PyTorch default. 
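`add_samplers_to_iterable_dataset` above (like the existing `replace_sampler` path) re-creates the user's `DataLoader` by inspecting its `__init__` parameters, copying the public attributes, and swapping in the wrapped dataset or sampler. A simplified standalone sketch of that re-instantiation trick (hypothetical `rebuild_with_sampler`; it ignores the generator seeding and subclass-specific arguments the real code also handles):

```python
import inspect

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


def rebuild_with_sampler(dataloader: DataLoader, new_sampler) -> DataLoader:
    """Re-create `dataloader` with `new_sampler`, keeping its other public settings."""
    skip = {"sampler", "batch_sampler", "dataset_kind"}
    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
    params = set(inspect.signature(type(dataloader).__init__).parameters) | set(
        inspect.signature(DataLoader.__init__).parameters
    )
    kwargs = {name: attrs[name] for name in params if name in attrs and name not in skip}
    kwargs["sampler"] = new_sampler
    kwargs["shuffle"] = False
    return type(dataloader)(**kwargs)


dataset = TensorDataset(torch.arange(8))
loader = DataLoader(dataset, batch_size=2)
loader = rebuild_with_sampler(loader, SequentialSampler(dataset))
assert len(list(loader)) == 4
```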
+ if (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting: batch_sampler = type(batch_sampler)( sampler, @@ -157,17 +201,23 @@ def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningS ) if is_predicting: batch_sampler = IndexBatchSamplerWrapper(batch_sampler) - dl_args['batch_sampler'] = batch_sampler + fast_forward_sampler = FastForwardSampler(batch_sampler) + dl_args['batch_sampler'] = fast_forward_sampler dl_args['batch_size'] = 1 dl_args['shuffle'] = False dl_args['sampler'] = None dl_args['drop_last'] = False else: - dl_args['sampler'] = sampler + fast_forward_sampler = FastForwardSampler(sampler) + dl_args['sampler'] = fast_forward_sampler dl_args['shuffle'] = False dl_args['batch_sampler'] = None - return dl_args + batch_size = dl_args["batch_size"] + + fast_forward_sampler.setup(batch_size) + + return dl_args, fast_forward_sampler def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader: skip_keys = ('sampler', 'batch_sampler', 'dataset_kind') @@ -184,7 +234,7 @@ def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[Runnin dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys} - dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler, mode=mode) + dl_args, fast_forward_sampler = self._resolve_batch_sampler(dl_args, dataloader, sampler, mode=mode) multiprocessing_context = dataloader.multiprocessing_context dl_args['multiprocessing_context'] = multiprocessing_context @@ -213,6 +263,7 @@ def __init__(self, num_features, dataset, *args, **kwargs): dataloader = type(dataloader)(**dl_args) dataloader.multiprocessing_context = multiprocessing_context + dataloader.fast_forward_sampler = fast_forward_sampler return dataloader def _get_distributed_sampler( diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 25f76ad085cc6..0ea2541ad7fba 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -16,7 +16,7 @@ @dataclass -class _DataclassStateDictMixin: +class BaseProgress: def state_dict(self) -> dict: return asdict(self) @@ -25,14 +25,14 @@ def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) @classmethod - def from_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin": + def from_state_dict(cls, state_dict: dict) -> "BaseProgress": obj = cls() obj.load_state_dict(state_dict) return obj @dataclass -class Tracker(_DataclassStateDictMixin): +class Tracker(BaseProgress): """ Track an event's progress. @@ -60,6 +60,18 @@ def reset(self) -> None: if self.completed is not None: self.completed = 0 + def reset_on_restart(self): + value = self.completed if self.processed is None else self.processed + + if self.ready is not None: + self.ready = value + if self.started is not None: + self.started = value + if self.processed is not None: + self.processed = value + if self.completed is not None: + self.completed = value + def __setattr__(self, key: str, value: int) -> None: if getattr(self, key, 0) is None: raise AttributeError(f"The '{key}' attribute is meant to be unused") @@ -72,7 +84,7 @@ def __repr__(self): @dataclass -class Progress(_DataclassStateDictMixin): +class Progress(BaseProgress): """ Track aggregated and current progress. 
@@ -125,6 +137,7 @@ class BatchProgress(Progress): total: Tracks the total epoch progress current: Tracks the current epoch progress """ + should_check_val: bool = False @dataclass @@ -138,7 +151,7 @@ class EpochProgress(Progress): current: Tracks the current epoch progress batch: Tracks batch progress. """ - + dataloader_idx: int = 0 batch: BatchProgress = field(default_factory=BatchProgress) def reset_on_epoch(self) -> None: @@ -147,10 +160,11 @@ def reset_on_epoch(self) -> None: def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) self.batch.load_state_dict(state_dict["batch"]) + self.dataloader_idx = state_dict["dataloader_idx"] @dataclass -class OptimizerProgress(_DataclassStateDictMixin): +class OptimizerProgress(BaseProgress): """ Track optimizer progress. @@ -172,7 +186,7 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass -class OptimizationProgress(_DataclassStateDictMixin): +class OptimizationProgress(BaseProgress): """ Track optimization progress. @@ -182,6 +196,7 @@ class OptimizationProgress(_DataclassStateDictMixin): """ # TODO: support for multiple optimizers + optimizer_idx: int = 0 optimizer: OptimizerProgress = field(default_factory=OptimizerProgress) scheduler: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) @@ -200,10 +215,11 @@ def reset_on_epoch(self) -> None: def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) self.scheduler.load_state_dict(state_dict["scheduler"]) + self.optimizer_idx = state_dict["optimizer_idx"] @dataclass -class EpochLoopProgress(_DataclassStateDictMixin): +class EpochLoopProgress(BaseProgress): """ Tracks epoch loop progress. These counters are local to a trainer rank. By default, they are not globally synced across all ranks. @@ -211,7 +227,6 @@ class EpochLoopProgress(_DataclassStateDictMixin): Args: epoch: Tracks epochs progress. 
""" - epoch: EpochProgress = field(default_factory=EpochProgress) def increment_epoch_completed(self) -> None: @@ -226,6 +241,16 @@ def load_state_dict(self, state_dict: dict) -> None: self.epoch.load_state_dict(state_dict["epoch"]) +@dataclass +class EvaluationEpochLoopProgress(EpochLoopProgress): + + should_check_val: bool = False + + def load_state_dict(self, state_dict: dict) -> None: + self.epoch.load_state_dict(state_dict["epoch"]) + self.should_check_val = state_dict["should_check_val"] + + @dataclass class TrainingEpochProgress(EpochProgress): """ @@ -240,7 +265,7 @@ class TrainingEpochProgress(EpochProgress): """ optim: OptimizationProgress = field(default_factory=OptimizationProgress) - val: EpochLoopProgress = field(default_factory=EpochLoopProgress) + val: EvaluationEpochLoopProgress = field(default_factory=EvaluationEpochLoopProgress) def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 54d0079b9255e..7666f509d5285 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -67,6 +67,7 @@ class TrainerProperties(ABC): fit_loop: FitLoop validate_loop: EvaluationLoop test_loop: EvaluationLoop + is_restarting: bool = False predict_loop: PredictionLoop """ Accelerator properties diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index e93d87291193d..34d1be6126c34 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -14,14 +14,16 @@ import os from collections.abc import Iterable, Iterator, Mapping, Sequence +from functools import partial from typing import Any, Callable, Generator, Optional, Tuple, Union import torch from torch import Tensor from torch.utils.data import Dataset -from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader from torch.utils.data.dataset import IterableDataset +import pytorch_lightning as pl from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.data import get_len @@ -371,6 +373,55 @@ def __init__(self, loaders: Any, mode: str = 'min_size'): if self.mode == 'max_size_cycle': self._wrap_loaders_max_size_cycle() + self.loaders_iter_state_dict = None + + def state_dict(self): + from pytorch_lightning.utilities.auto_restart import ( + fetch_fast_forward_samplers_state_dict, + fetch_previous_worker_state_dict, + ) + + out = [] + apply_to_collection(self._iterator.loader_iters, Iterator, partial(fetch_previous_worker_state_dict, out=out)) + + count = 0 + apply_to_collection( + self.loaders, DataLoader, partial(fetch_fast_forward_samplers_state_dict, out=out, count=count) + ) + return out + + def load_state_dict(self, state_dict): + count = 0 + + def mock_reset_fn(self, loader, **__): + nonlocal count + state_dict[count]["loader"] = loader + count += 1 + + self.loaders_iter_state_dict = state_dict + # delay reset call. 
+ _MultiProcessingDataLoaderIter._ori_reset = _MultiProcessingDataLoaderIter._reset + _MultiProcessingDataLoaderIter._reset = mock_reset_fn + + def on_restart(self): + if self.loaders_iter_state_dict: + from pytorch_lightning.utilities.auto_restart import ( + cycle_to_next_worker, + fast_forward_sampler_load_state_dict, + ) + + count = 0 + apply_to_collection( + self._iterator.loader_iters, Iterator, + partial(cycle_to_next_worker, state_dict=self.loaders_iter_state_dict, count=count) + ) + + count = 0 + apply_to_collection( + self.loaders, DataLoader, + partial(fast_forward_sampler_load_state_dict, state_dict=self.loaders_iter_state_dict, count=count) + ) + @property def sampler(self) -> Union[Iterable, Sequence, Mapping]: """Return a collections of samplers extracting from loaders.""" @@ -398,7 +449,9 @@ def __iter__(self) -> Any: """ Create and return an iterator, `CombinedLoaderIterator`, for the combined loader. """ - return CombinedLoaderIterator(self.loaders) + self._iterator = CombinedLoaderIterator(self.loaders) + self.on_restart() + return self._iterator @staticmethod def _calc_num_batches(loaders: Any) -> Union[int, float]: @@ -514,7 +567,8 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable return compute_func(new_data) -def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]: +def prefetch_iterator(iterable: Iterable, + trainer: Optional['pl.Trainer'] = None) -> Generator[Tuple[Any, bool], None, None]: """ Returns an iterator that pre-fetches and caches the next item. The values are passed through from the given iterable with an added boolean indicating if this is the last item. @@ -522,6 +576,9 @@ def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, N """ it = iter(iterable) + if trainer: + trainer.current_iterator = it + try: # the iterator may be empty from the beginning last = next(it) @@ -534,3 +591,6 @@ def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, N last = val # yield last, no longer has next yield last, True + + if trainer: + trainer.current_iterator = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7475cd9c81326..a5cf49c9a5b76 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -13,6 +13,10 @@ # limitations under the License. 
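`prefetch_iterator` above works by always holding one element ahead of what it yields, which is how it can tag the final batch with `is_last=True` (and, with this change, expose the live iterator on the trainer for checkpointing). A standalone sketch of the same look-ahead pattern (hypothetical `lookahead` helper):

```python
from typing import Any, Generator, Iterable, Tuple


def lookahead(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
    """Yield (item, is_last) by keeping one element pre-fetched."""
    it = iter(iterable)
    try:
        last = next(it)          # the iterable may be empty
    except StopIteration:
        return
    for val in it:
        yield last, False        # there is at least one more item
        last = val
    yield last, True             # nothing left to pre-fetch


assert list(lookahead([1, 2, 3])) == [(1, False), (2, False), (3, True)]
assert list(lookahead([])) == []
```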
"""Trainer to automate the training.""" import logging +import os +import signal +import sys +import time import traceback import warnings from datetime import timedelta @@ -57,7 +61,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin -from pytorch_lightning.trainer.progress import EpochLoopProgress, FitLoopProgress +from pytorch_lightning.trainer.progress import EpochLoopProgress, EvaluationEpochLoopProgress, FitLoopProgress from pytorch_lightning.trainer.properties import TrainerProperties from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin @@ -360,8 +364,8 @@ def __init__( self.test_loop = EvaluationLoop() self.predict_loop = PredictionLoop() self.fit_loop.connect(self, progress=FitLoopProgress()) - self.validate_loop.connect(self, progress=EpochLoopProgress()) - self.test_loop.connect(self, progress=EpochLoopProgress()) + self.validate_loop.connect(self, progress=EvaluationEpochLoopProgress()) + self.test_loop.connect(self, progress=EvaluationEpochLoopProgress()) self.predict_loop.connect(self, progress=EpochLoopProgress()) # training state @@ -995,12 +999,41 @@ def _run_train(self) -> None: if distributed_available() and self.world_size > 1: # try syncing remaing processes, kill otherwise self.training_type_plugin.reconciliate_processes(traceback.format_exc()) + # save a checkpoint for fault tolerant training + self.fit_loop._check_checkpoint_callback(True) # give accelerators a chance to finish self.accelerator.on_train_end() # reset bookkeeping self.state.stage = None raise + def _syncing_processes(self): + if distributed_available(): + sync_dir = os.path.join(os.getenv("PL_TMPDIR"), ".sync") + + if not os.path.exists(sync_dir): + try: + os.makedirs(sync_dir) + except FileExistsError: + pass + + torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.p")) + + time.sleep(1) + + if len(os.listdir(sync_dir)) == self.world_size: + return + + pids = os.getenv("PL_INTERACTIVE_DDP_PROCS", None) + if pids: + print("Detected deadlock, Lightning will terminate the processes.") + for pid in pids.split(','): + pid = int(pid) + if pid != os.getpid(): + os.kill(pid, signal.SIGKILL) + del os.environ["PL_INTERACTIVE_DDP_PROCS"] + sys.exit(0) + def _run_evaluate(self) -> _EVALUATE_OUTPUT: if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py new file mode 100644 index 0000000000000..3769de516ed76 --- /dev/null +++ b/pytorch_lightning/utilities/auto_restart.py @@ -0,0 +1,310 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from copy import deepcopy +from typing import Any, Dict, Generator, Iterator, List, Optional, Union + +from torch.utils.data import get_worker_info, Sampler +from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset + +from pytorch_lightning.utilities.enums import AutoRestartBatchKeys + + +class FastForwardSampler(Sampler): + """ + This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations + performed during an epoch. It maintains a state, saved with :meth:`state_dict`, that can be reloaded with + :meth:`load_state_dict`. If the sampler is used in a multiprocessing context, the ``FastForwardSampler`` will record + the state of the current worker. + + When reloading, the ``FastForwardSampler`` will "fast-forward" the wrapped sampler by iterating through all the + samples seen in the last iterations (for the current worker). + """ + + def __init__(self, sampler: Union[Sampler, Generator]) -> None: + super().__init__(data_source=None) + self._sampler = sampler + self.restarting: bool = False + self._current_iteration = 0 + self._dataloader_batch_size: Optional[int] = None + self._cached_state_dict: Optional[Dict[str, Any]] = None + + def __getattr__(self, key: str) -> Any: + if key in self.__dict__: + return self.__dict__[key] + return getattr(self._sampler, key, None) + + def setup(self, dataloader_batch_size: Optional[int] = None) -> None: + """ + Setup the ``FastForwardSampler``. + This is required only when the provided dataset subclassed :class:`torch.utils.data.Dataset`. + """ + self._dataloader_batch_size = dataloader_batch_size + + @property + def worker_id(self) -> int: + worker_info = get_worker_info() + return worker_info.id if worker_info else 0 + + def __iter__(self) -> Iterator[Any]: + # split restart logic to avoid user with tempering with "fast-forwarding" + + if not self.restarting: + for batch in self._sampler: + self._current_iteration += 1 + yield batch + + else: + for i, batch in enumerate(self._sampler, 1): + + print(batch) + + # the `state dict` was cached as workers were available before. + if self._cached_state_dict is not None and self.worker_id in self._cached_state_dict: + + # reload the current state dict + self.load_state_dict(self._cached_state_dict, workers_initialized=True) + self._cached_state_dict = None + self.restarting = False + + # when the current index is higher than the current_iteration, we have "fast forwarded" the sampler. + if self._current_iteration < i: + self._current_iteration += 1 + yield batch + + self._current_iteration = 0 + + def __len__(self) -> int: + return len(self.sampler) + + def _compute_current_iteration(self, num_batches_processed: Optional[int] = None) -> int: + """ + This function is used to compute the effective iteration. + As DataLoader can perform ``prefecthing`` or training can fail while processing a batch, + the current iteration needs to be computed using the ``num_batches_processed`` processed information. + """ + if num_batches_processed is not None: + current_iteration = num_batches_processed + else: + current_iteration = self._current_iteration + + if self._dataloader_batch_size: + current_iteration *= self._dataloader_batch_size + + return current_iteration + + def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]: + """ Returns the state of the sampler in the current worker. 
The worker id indexes the state dict.""" + return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}} + + def load_state_dict(self, state_dict: Dict[int, Any], workers_initialized: bool = False) -> None: + """ + Loads the saved state for the wrapped sampler. + If the ``state_dict`` contains multiple states, it means there were multiple workers. + The state will be cached and fully reloaded (fast-forward) the first time :meth:`__iter__` is called. + """ + # as workers aren't available, the ``state_dict``` is cached until workers are made available. + if len(state_dict) > 1 and not workers_initialized: + self._cached_state_dict = deepcopy(state_dict) + self.restarting = self._cached_state_dict[self.worker_id]["current_iteration"] > 0 + return + self._current_iteration = state_dict[self.worker_id]["current_iteration"] + self.restarting = self._current_iteration > 0 + + +class CaptureIterableDataset(IterableDataset): + """ + The ``CaptureIterableDataset`` is used to wrap an :class:`torch.utils.data.IterableDataset`. + On ``__iter__`` function call, the ``CaptureIterableDataset`` will wrap the wrapped dataset + generators into ``FastForwardSampler`` to keep track of progress. + On ``__next__`` function call, the ``CaptureIterableDataset`` will return a dictionary containing + user data and metadata containing the ``FastForwardSampler`` samplers state_dict. + """ + + def __init__(self, dataset: IterableDataset, initial_seed: Optional[int] = None) -> None: + super().__init__() + self.dataset = deepcopy(dataset) + self.state_dict: Optional[Dict[int, Any]] = None + self.initial_seed = initial_seed + self.samplers: Optional[Dict[str, FastForwardSampler]] = None + + @property + def sampler(self) -> Sampler: + return self.dataset.sampler + + def load_state_dict(self, state_dict: Dict[int, Any]) -> None: + self.state_dict = deepcopy(state_dict) + + def _wrap_generator_samplers(self) -> None: + if self.samplers is not None: + return + + self.samplers = {} + + # access wrapped dataset attributes + dataset_dict = self.dataset.__dict__ + + # create a tuple of sampler names + samplers_names = tuple(v.__class__.__name__ for k, v in dataset_dict.items() if isinstance(v, Sampler)) + + # create a dictionary of generator present within the dataset attributes + dataset_sampler_generators = {k: v for k, v in dataset_dict.items() if isinstance(v, Generator)} + + # iterate over the generator. If a generator was created from a ``Sampler```, + # it will be wrapped into a ``FastForwardSampler``. + for (generator_attr_name, generator) in dataset_sampler_generators.items(): + + # Generator name have the the form `SamplerName.__iter__` + generator_name = generator.__qualname__.split('.')[0] + + # validate the base generator name matches a sampler name. + if any(sampler_name == generator_name for sampler_name in samplers_names): + + # wrap the generator into a ``FastForwardSampler`` + sampler = FastForwardSampler(generator) + + # if ``CaptureIterableDataset`` was available, the sampler should reload its own state. + if self.state_dict is not None: + sampler.load_state_dict(self.state_dict[generator_attr_name]) + + # store the samplers + self.samplers[generator_attr_name] = sampler + + # replace generator with the generator from the ``FastForwardSampler``. + dataset_dict[generator_attr_name] = iter(sampler) + + # reset state dict. 
+ self.state_dict = None + + def reset_on_epoch(self) -> None: + self.state_dict = None + + def __iter__(self) -> Iterator: + # create a generator from the wrapped Iterative Dataset + # if the dataset contained samplers, they will be transformers into generators + self.iter_data = iter(self.dataset) + + # wrap any generator associated to a Sampler into a ``FastForwardSampler``. + self._wrap_generator_samplers() + return self + + def __next__(self) -> Dict[str, Any]: + # fetch next data + data = next(self.iter_data) + + # create current samplers state_dict + worker_info = get_worker_info() + state_dicts = {"id": worker_info.id if worker_info is not None else 0} + state_dicts.update({k: v.state_dict() for k, v in self.samplers.items()}) + + # return both current data and samplers ``state_dict``. + return {"data": data, AutoRestartBatchKeys.PL_SAMPLERS: state_dicts} + + @staticmethod + def convert_batch_into_state_dict(batch) -> Dict[str, Dict[int, Any]]: + """ + This function is used to convert a batch into a state_dict + """ + state_dict = {} + batch_worker_id = batch[AutoRestartBatchKeys.PL_SAMPLERS].pop("id") + worker_id = batch_worker_id[-1].item() + for sampler_name, sampler_state_dict in batch[AutoRestartBatchKeys.PL_SAMPLERS].items(): + state_dict[sampler_name] = { + worker_id: { + "current_iteration": sampler_state_dict[worker_id]["current_iteration"][-1].item() + } + } + return state_dict + + +def _find_next_worker_id(iter, state_dict: Dict[str, Any], num_workers: int): + if isinstance(iter, _MultiProcessingDataLoaderIter): + next_worker = (next(iter._worker_queue_idx_cycle)) % num_workers + previous_worker = (next_worker - 1) % num_workers + while next(iter._worker_queue_idx_cycle) != previous_worker: + pass + else: + previous_worker = None + + state_dict.update({"num_workers": iter._num_workers, "previous_worker": previous_worker}) + + +def find_fast_forward_samplers(dataloader: DataLoader) -> Optional[FastForwardSampler]: + if isinstance(dataloader.sampler, FastForwardSampler): + return dataloader.sampler + + elif isinstance(dataloader.batch_sampler, FastForwardSampler): + return dataloader.batch_sampler + + +def fetch_previous_worker_state_dict(iter: Iterator, out: List): + num_workers = iter._num_workers + if isinstance(iter, _MultiProcessingDataLoaderIter): + next_worker = (next(iter._worker_queue_idx_cycle)) % num_workers + previous_worker = (next_worker - 1) % num_workers + while next(iter._worker_queue_idx_cycle) != previous_worker: + pass + else: + previous_worker = None + + out.append({"num_workers": iter._num_workers, "previous_worker": previous_worker}) + + +def fetch_fast_forward_samplers_state_dict(dataloader: DataLoader, out: List, count: int): + fast_forward_samplers = find_fast_forward_samplers(dataloader) + + if fast_forward_samplers is not None: + out[count]["sampler"] = fast_forward_samplers.state_dict() + count += 1 + + +def cycle_to_next_worker(iter: Iterator, state_dict: List[Dict[str, Any]], count: int): + current = state_dict[count] + num_workers = iter._num_workers + assert current["num_workers"] == num_workers + if isinstance(iter, _MultiProcessingDataLoaderIter): + # move back to 0 + while next(iter._worker_queue_idx_cycle) != 0: + pass + # increment previous worker + for _ in range(current["previous_worker"] - 1): + next(iter._worker_queue_idx_cycle) + iter._reset = iter._ori_reset + iter._reset(current["loader"], first_iter=True) + + count += 1 + + +def fast_forward_sampler_load_state_dict(dataloader, state_dict: List[Dict[str, Any]], count: int): + 
current_state_dict = state_dict[count]["sampler"] + fast_forward_samplers = find_fast_forward_samplers(dataloader) + + if fast_forward_samplers is not None: + fast_forward_samplers.load_state_dict(current_state_dict) + count += 1 + + +def dataloader_to_state_dict(dataloader: DataLoader, iter: Iterator) -> List[Dict[str, Any]]: + out = [] + fetch_previous_worker_state_dict(iter, out) + + count = 0 + fetch_fast_forward_samplers_state_dict(dataloader, out, count) + return out + + +def dataloader_load_state_dict(dataloader: DataLoader, state_dict: List[Dict[str, Any]]) -> None: + fast_forward_sampler = find_fast_forward_samplers(dataloader) + + if isinstance(fast_forward_sampler, Sampler): + fast_forward_sampler.load_state_dict(state_dict[0]["sampler"]) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 98f2770d03cf9..dd71a7dc72415 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -111,3 +111,11 @@ class GradClipAlgorithmType(LightningEnum): """ VALUE = 'value' NORM = 'norm' + + +class AutoRestartBatchKeys(LightningEnum): + """ + Defines special dictionary keys used to track sampler progress with multiple workers. + """ + + PL_SAMPLERS = "__pl_samplers__" diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 3125a2d38f15e..fdd5382ca751d 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -14,6 +14,7 @@ """General utilities""" import importlib import operator +import os import platform import sys from importlib.util import find_spec @@ -101,3 +102,7 @@ def _compare_version(package: str, op, version) -> bool: _IPU_AVAILABLE = poptorch.ipuHardwareIsAvailable() else: _IPU_AVAILABLE = False + + +def fault_tolerant_enabled(): + return os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") == "1" diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index af5801d2b4552..ac3aa65586b47 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
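# How the experimental behaviour is gated, following ``fault_tolerant_enabled`` shown above,
# plus the rough shape of the per-dataloader state assembled by ``dataloader_to_state_dict``.
# Only the environment variable name and the key names come from the diff; the concrete
# values below are illustrative.
import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # opt in before constructing the Trainer
assert os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") == "1"

# one entry per dataloader: worker-cycle info plus the wrapped ``FastForwardSampler`` state,
# which is itself keyed by worker id
example_state = [{
    "num_workers": 2,
    "previous_worker": 1,
    "sampler": {0: {"current_iteration": 6}},
}]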
- -from typing import Dict, Iterator +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Dict, Iterator, Optional from pytorch_lightning.loops.base import Loop +from pytorch_lightning.trainer.progress import BaseProgress def test_loop_restore(): @@ -72,3 +74,95 @@ def load_state_dict(self, state_dict: Dict) -> None: assert not loop.restarting assert loop.outputs == list(range(10)) + + +def test_loop_recursivity(): + + @dataclass + class SimpleProgress(BaseProgress): + + increment: int = 0 + + def state_dict(self): + return {"increment": self.increment} + + def load_state_dict(self, state_dict): + self.increment = state_dict["increment"] + + class Simple(Loop): + + def __init__(self, a): + super().__init__() + self.a = a + self.progress = SimpleProgress() + + def advance(self, *args: Any, **kwargs: Any) -> None: + for loop in self._loops.values(): + loop.run() + self.progress.increment += 1 + self.progress.increment += 1 + + @property + def done(self) -> bool: + return self.iteration_count > 0 + + def reset(self) -> None: + pass + + def restore(self, state: Optional[Dict]) -> None: + assert state is not None + if self.a == 1: + assert state["a"] == self.a + else: + assert state["a"] != self.a + + def state_dict(self) -> Dict: + return {"a": self.a} + + loop_parent = Simple(1) + loop_child = Simple(2) + loop_parent.loop_child = loop_child + state_dict = loop_parent.get_state_dict() + assert state_dict == OrderedDict([('state_dict', { + 'a': 1 + }), ('progress', { + 'increment': 0 + }), ('loop_child.state_dict', { + 'a': 2 + }), ('loop_child.progress', { + 'increment': 0 + })]) + + state_dict["loop_child.state_dict"]["a"] = 3 + loop_parent.load_state_dict(state_dict) + cached_state = loop_parent.loop_child._cached_state + assert cached_state == state_dict["loop_child.state_dict"] + assert loop_parent.restarting + + loop_parent.run() + + assert loop_parent._cached_state is None + assert loop_parent.loop_child._cached_state is None + assert not loop_parent.restarting + + state_dict = loop_parent.get_state_dict() + assert state_dict == OrderedDict([('state_dict', { + 'a': 1 + }), ('progress', { + 'increment': 2 + }), ('loop_child.state_dict', { + 'a': 2 + }), ('loop_child.progress', { + 'increment': 1 + })]) + + loop_parent = Simple(1) + loop_child = Simple(2) + loop_parent.loop_child = loop_child + loop_parent.load_state_dict(state_dict) + assert loop_parent.progress.increment == 2 + assert loop_parent.loop_child.progress.increment == 1 + + del loop_parent.loop_child + state_dict = loop_parent.get_state_dict() + assert state_dict == OrderedDict([('state_dict', {'a': 1}), ('progress', {'increment': 2})]) diff --git a/tests/overrides/test_distributed.py b/tests/overrides/test_distributed.py index c8d982bd733fe..ec3cbdbab2718 100644 --- a/tests/overrides/test_distributed.py +++ b/tests/overrides/test_distributed.py @@ -11,14 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
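# The flattened layout that ``Loop.get_state_dict`` produces for the parent/child pair in the
# test above: child entries are namespaced with a dotted prefix. The values are the ones
# asserted in the test; they are repeated here only to make the key scheme explicit.
from collections import OrderedDict

expected = OrderedDict([
    ("state_dict", {"a": 1}),                   # parent loop's own state
    ("progress", {"increment": 0}),             # parent's progress dataclass
    ("loop_child.state_dict", {"a": 2}),        # child state under the "loop_child." prefix
    ("loop_child.progress", {"increment": 0}),  # child progress under the same prefix
])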
+import math +import os from collections.abc import Iterable +from unittest import mock import pytest -from torch.utils.data import BatchSampler, SequentialSampler +import torch +from torch.optim import Adam +from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler +from torch.utils.data._utils.worker import get_worker_info +from torch.utils.data.dataloader import _InfiniteConstantSampler, DataLoader +from torch.utils.data.dataset import Dataset, IterableDataset -from pytorch_lightning import seed_everything -from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler +from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.overrides.distributed import ( + CaptureIterativeDataset, + FastForwardSampler, + IndexBatchSamplerWrapper, + UnrepeatedDistributedSampler, +) from pytorch_lightning.utilities.data import has_len +from pytorch_lightning.utilities.enums import BatchKeys +from tests.helpers.boring_model import BoringModel, RandomDataset @pytest.mark.parametrize("shuffle", [False, True]) @@ -67,3 +83,466 @@ def test_index_batch_sampler_methods(): assert isinstance(index_batch_sampler, Iterable) assert has_len(index_batch_sampler) + + +def test_fast_forward_on_batch_sampler(): + dataset = range(15) + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, 3, False) + index_batch_sampler = FastForwardSampler(batch_sampler) + index_batch_sampler.setup(1, 1, False) + + assert isinstance(index_batch_sampler, Iterable) + assert has_len(index_batch_sampler) + + index_batch_sampler_iter = iter(index_batch_sampler) + + assert next(index_batch_sampler_iter) == [0, 1, 2] + assert next(index_batch_sampler_iter) == [3, 4, 5] + + state_dict = index_batch_sampler.state_dict(2) + + dataset = range(15) + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, 3, False) + index_batch_sampler = FastForwardSampler(batch_sampler) + index_batch_sampler.setup(1, 1, False) + index_batch_sampler.load_state_dict(state_dict) + + index_batch_sampler_iter = iter(index_batch_sampler) + assert next(index_batch_sampler_iter) == [6, 7, 8] + + +def test_fast_forward_on_sequential_sampler(): + dataset = range(15) + sampler = FastForwardSampler(SequentialSampler(dataset)) + sampler.setup(1, 3, False) + batch_sampler = BatchSampler(sampler, 3, False) + + batch_sampler_iter = iter(batch_sampler) + + assert next(batch_sampler_iter) == [0, 1, 2] + assert next(batch_sampler_iter) == [3, 4, 5] + + state_dict = sampler.state_dict(2) + assert state_dict["current_iteration"] == 6 + + dataset = range(15) + sampler = FastForwardSampler(SequentialSampler(dataset)) + sampler.setup(1, 3, False) + batch_sampler = BatchSampler(sampler, 3, False) + sampler.load_state_dict(state_dict) + + batch_sampler_iter = iter(batch_sampler) + assert next(batch_sampler_iter) == [6, 7, 8] + + +def test_fast_forward_on_random_sampler(): + seed_everything(42) + + dataset = range(15) + sampler = FastForwardSampler(RandomSampler(dataset)) + sampler.setup(1, 3, False) + batch_sampler = BatchSampler(sampler, 3, False) + + batch_sampler_iter = iter(batch_sampler) + + assert next(batch_sampler_iter) == [14, 9, 1] + assert next(batch_sampler_iter) == [7, 11, 3] + assert next(batch_sampler_iter) == [12, 8, 2] + + state_dict = sampler.state_dict(3) + assert state_dict["current_iteration"] == 9 + state_dict["current_iteration"] = 6 + + 
dataset = range(15) + sampler = FastForwardSampler(RandomSampler(dataset)) + sampler.setup(1, 3, False) + batch_sampler = BatchSampler(sampler, 3, False) + sampler.load_state_dict(state_dict) + + batch_sampler_iter = iter(batch_sampler) + assert next(batch_sampler_iter) == [12, 8, 2] + has_raised = False + try: + for _ in range(5): + next(batch_sampler_iter) + except StopIteration: + has_raised = True + assert sampler.rng_state is None + assert sampler.current_iteration == 0 + sampler.load_state_dict(sampler.state_dict(0)) + assert has_raised + + +def test_fast_forward_sampler_replacement(tmpdir): + + seed_everything(42) + + class CustomBatchSampler(BatchSampler): + pass + + class CheckFastForwardSamplerInjection(Callback): + + def __init__(self): + self.has_called = False + + def on_train_batch_end(self, trainer, *args) -> None: + sampler = trainer.train_dataloader.loaders.fast_forward_sampler + wrapping_batch_sampler = isinstance(sampler.sampler, BatchSampler) + + num_batches = 2 + + if wrapping_batch_sampler: + assert isinstance(sampler, FastForwardSampler) + current_iteration = num_batches + else: + assert isinstance(sampler, FastForwardSampler) + current_iteration = 2 * 3 + + if trainer.fit_loop.batch_idx == 1: + assert sampler.state_dict(num_batches)["current_iteration"] == current_iteration + self.has_called = True + + dataset = RandomDataset(32, 64) + train_dataloader = DataLoader(dataset, batch_size=3, num_workers=1) + trainer_kwargs = dict( + default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, num_sanity_val_steps=0, limit_val_batches=0 + ) + model = BoringModel() + callback = CheckFastForwardSamplerInjection() + trainer = Trainer(**trainer_kwargs, callbacks=callback) + trainer.fit(model, train_dataloader=train_dataloader) + + train_dataloader = DataLoader( + dataset, + batch_sampler=CustomBatchSampler(batch_size=8, sampler=SequentialSampler(dataset), drop_last=True), + num_workers=2, + ) + + trainer = Trainer(**trainer_kwargs, callbacks=callback) + trainer.fit(model, train_dataloader=train_dataloader) + + +class RangeIterativeDataset(IterableDataset): + + def __init__(self, data, num_workers: int, batch_size: int, is_in_workers: bool, state_dict=None): + self.data = list(data) + self.batch_size = batch_size + self.num_workers = num_workers + self.is_in_workers = is_in_workers + self.state_dict = state_dict + + def __iter__(self): + worker_info = get_worker_info() + if worker_info and self.num_workers == 2: + id = worker_info.id + num_samples = len(self.data) + if id == 0: + self.data = list(self.data)[:num_samples // 2] + else: + self.data = list(self.data)[num_samples // 2:] + self.user_sampler = RandomSampler(self.data) + else: + self.user_sampler = RandomSampler(self.data) + + sampler = FastForwardSampler(self.user_sampler) + sampler.setup(self.batch_size, self.num_workers, self.is_in_workers) + if self.state_dict is not None: + sampler.load_state_dict(self.state_dict[0]["iter_sampler"]) + self.state_dict = None + self.sampler = sampler + self.iter_sampler = iter(self.sampler) + return self + + def __next__(self): + return self.data[next(self.iter_sampler)] + + +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +@pytest.mark.parametrize("batch_size", [3]) +def test_fast_forward_sampler_over_iterative_dataset(num_workers, batch_size, tmpdir): + + initial_seed = seed_everything(42) + generator = torch.Generator() + generator.manual_seed(initial_seed) + dataset = RangeIterativeDataset(range(20), num_workers, batch_size, True) + dataset = 
CaptureIterativeDataset(dataset, num_workers, batch_size, True) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) + iter_dataloader = iter(dataloader) + batches = [] + for _ in range(5): + batches.append(next(iter_dataloader)) + + if num_workers == 0: + batch_0_expected = torch.tensor([4, 3, 12]) + batch_1_expected = torch.tensor([18, 0, 7]) + batch_2_expected = torch.tensor([8, 14, 11]) + batch_3_expected = torch.tensor([5, 10, 13]) + batch_4_expected = torch.tensor([17, 19, 15]) + elif num_workers == 1: + batch_0_expected = torch.tensor([3, 18, 17]) + batch_1_expected = torch.tensor([13, 2, 19]) + batch_2_expected = torch.tensor([6, 4, 7]) + batch_3_expected = torch.tensor([1, 14, 5]) + batch_4_expected = torch.tensor([12, 8, 16]) + else: + batch_0_expected = torch.tensor([3, 4, 5]) + batch_1_expected = torch.tensor([10, 12, 14]) + batch_2_expected = torch.tensor([7, 0, 9]) + batch_3_expected = torch.tensor([16, 18, 17]) + batch_4_expected = torch.tensor([8, 1, 2]) + + assert torch.equal(batches[0]["data"], batch_0_expected) + assert torch.equal(batches[1]["data"], batch_1_expected) + assert torch.equal(batches[2]["data"], batch_2_expected) + assert torch.equal(batches[3]["data"], batch_3_expected) + assert torch.equal(batches[4]["data"], batch_4_expected) + + # restarting on batch_1 and getting 3 extra batches + + def parse_metadata(batch): + return { + k: { + batch[BatchKeys.PL_SAMPLERS]["id"][-1].item(): { + "current_iteration": v["current_iteration"][-1].item(), + "rng_state": v["rng_state"][-1] + } + } + for k, v in batch[BatchKeys.PL_SAMPLERS].items() if k != "id" + } + + state_dict = {0: {'iter_sampler': {}}} + for batch in batches[:2]: + metadata = parse_metadata(batch) + for k, v in metadata.items(): + state_dict[0][k].update(v) + + if num_workers == 2: + assert len(state_dict[0]["iter_sampler"]) == 2 + + initial_seed = seed_everything(42) + generator.manual_seed(initial_seed) + dataset = RangeIterativeDataset(range(20), num_workers, batch_size, True, state_dict=state_dict) + dataset = CaptureIterativeDataset(dataset, num_workers, batch_size, True) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) + iter_dataloader = iter(dataloader) + batches = [] + for _ in range(3): + batches.append(next(iter_dataloader)) + + assert torch.equal(batches[0]["data"], batch_2_expected) + assert torch.equal(batches[1]["data"], batch_3_expected) + assert torch.equal(batches[2]["data"], batch_4_expected) + + +class CustomIterativeDataset(IterableDataset): + + def __init__(self, dataset, num_workers: int, drop_last: bool = True): + self.dataset = list(dataset) + self.num_workers = num_workers + self.drop_last = drop_last + + if self.drop_last and len(self.dataset) % self.num_workers != 0: + self.num_samples = math.ceil((len(self.dataset) - self.num_workers) / self.num_workers) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_workers) + + self.total_size = self.num_samples * self.num_workers + + @property + def rank(self) -> int: + info = get_worker_info() + return info.id if info else 0 + + def __iter__(self): + indices = list(range(len(self.dataset))) + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it 
evenly divisible. + indices = indices[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_workers] + assert len(indices) == self.num_samples + + self.indices = indices + self.sampler = RandomSampler(indices) + self.iter_sampler = iter(self.sampler) + + return self + + def __next__(self): + return self.indices[next(self.iter_sampler)] + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_fast_forward_sampler_iterative_dataset(tmpdir): + + seed_everything(42) + + class CustomException(Exception): + pass + + class CheckFastForwardSamplerInjection(Callback): + + def __init__(self): + self.has_called = False + self.restarting = False + + def _validate_map_dl_idx_sampler_states(self, trainer, num_dataloaders, worker_iterations): + map_dl_idx_sampler_states = trainer.fit_loop.epoch_loop._map_dl_idx_sampler_states + assert len(map_dl_idx_sampler_states) == num_dataloaders + assert len(map_dl_idx_sampler_states[0]["iter_sampler"]) == len([i for i in worker_iterations if i > 0]) + if len(worker_iterations) == 1 and worker_iterations[0] > 0: + assert map_dl_idx_sampler_states[0]["iter_sampler"][0]["current_iteration"] == worker_iterations[0] + if len(worker_iterations) == 2 and worker_iterations[1] > 0: + assert map_dl_idx_sampler_states[0]["iter_sampler"][1]["current_iteration"] == worker_iterations[1] + if len(worker_iterations) == 3 and worker_iterations[2] > 0: + assert map_dl_idx_sampler_states[0]["iter_sampler"][2]["current_iteration"] == worker_iterations[2] + if len(worker_iterations) == 4 and worker_iterations[3] > 0: + assert map_dl_idx_sampler_states[0]["iter_sampler"][3]["current_iteration"] == worker_iterations[3] + + def on_train_batch_end( + self, + trainer, + pl_module, + outputs, + batch, + batch_idx, + dataloader_idx, + ) -> None: + assert isinstance(trainer.train_dataloader.loaders.sampler, _InfiniteConstantSampler) + assert isinstance(trainer.train_dataloader.loaders.dataset, CaptureIterativeDataset) + assert trainer.train_dataloader.loaders.generator.initial_seed() == 42 + assert trainer.train_dataloader.loaders.dataset.initial_seed == 42 + if not self.restarting: + if trainer.fit_loop.batch_idx == 0: + t = torch.tensor([20, 16, 24]) + self._validate_map_dl_idx_sampler_states(trainer, 1, [3]) + assert torch.equal(batch, t) + assert torch.equal(t % 4, torch.tensor([0, 0, 0])) + elif trainer.fit_loop.batch_idx == 1: + t = torch.tensor([1, 9, 5]) + self._validate_map_dl_idx_sampler_states(trainer, 1, [3, 3]) + assert torch.equal(batch, t) + assert torch.equal(t % 4, torch.tensor([1, 1, 1])) + raise CustomException + else: + if trainer.fit_loop.batch_idx == 2: + t = torch.tensor([2, 14, 22]) + self._validate_map_dl_idx_sampler_states(trainer, 1, [0, 0, 3]) + assert torch.equal(batch, t) + assert torch.equal(t % 4, torch.tensor([2, 2, 2])) + elif trainer.fit_loop.batch_idx == 3: + t = torch.tensor([7, 11, 15]) + self._validate_map_dl_idx_sampler_states(trainer, 1, [0, 0, 3, 3]) + assert torch.equal(batch, t) + assert torch.equal(t % 4, torch.tensor([3, 3, 3])) + elif trainer.fit_loop.batch_idx == 4: + t = torch.tensor([8, 4, 0]) + self._validate_map_dl_idx_sampler_states(trainer, 1, [6, 0, 3, 3]) + assert torch.equal(batch, t) + assert torch.equal(t % 4, torch.tensor([0, 0, 0])) + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + pass + + model = TestModel() + model.training_epoch_end = None + + num_workers = 4 + + dataset = 
CustomIterativeDataset(range(30), num_workers) + train_dataloader = DataLoader(dataset, batch_size=3, num_workers=num_workers) + trainer_kwargs = dict( + default_root_dir=tmpdir, max_epochs=1, limit_train_batches=10, num_sanity_val_steps=0, limit_val_batches=0 + ) + ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) + cb = CheckFastForwardSamplerInjection() + callbacks = [cb, ck] + trainer = Trainer(**trainer_kwargs, callbacks=callbacks) + try: + trainer.fit(model, train_dataloader=train_dataloader) + except CustomException: + pass + + cb.restarting = True + + dataset = CustomIterativeDataset(range(30), num_workers) + train_dataloader = DataLoader(dataset, batch_size=3, num_workers=num_workers) + trainer = Trainer(**trainer_kwargs, resume_from_checkpoint=ck.last_model_path, callbacks=callbacks) + trainer.fit(model, train_dataloader=train_dataloader) + + +class MonotonicRandomDataset(Dataset): + + def __getitem__(self, index): + # 0.{random digits} + # 1.{random digits} + # 2.{random digits} + # ... + return torch.rand(1) + index + + def __len__(self): + return 64 + + +class RandomLightningModule(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(1, 2) + self.recorded_samples = [] + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, batch_idx): + # print(batch_idx, batch) + self.recorded_samples.append(batch) + return {"loss": self(batch).sum()} + + def train_dataloader(self): + dataset = MonotonicRandomDataset() + dataloader = DataLoader(dataset, batch_size=2) + return dataloader + + def configure_optimizers(self): + return Adam(self.parameters()) + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_fastforward_sampler_and_dataset(tmpdir): + print("initial training") + seed_everything(1) + model = RandomLightningModule() + trainer = Trainer(max_steps=3, progress_bar_refresh_rate=0, weights_summary=None) + trainer.fit(model) + + print(torch.cat(model.recorded_samples)) + indices = [int(x) for x in torch.cat(model.recorded_samples).floor()] + assert indices == [0, 1, 2, 3, 4, 5] + + ckpt_file = os.path.join(tmpdir, "one.ckpt") + trainer.save_checkpoint(ckpt_file) + + print("resuming") + seed_everything(1) + model = RandomLightningModule() + trainer = Trainer(max_steps=6, progress_bar_refresh_rate=0, weights_summary=None, resume_from_checkpoint=ckpt_file) + trainer.fit(model) + + print(torch.cat(model.recorded_samples)) + indices = [int(x) for x in torch.cat(model.recorded_samples).floor()] + assert indices == [6, 7, 8, 9] diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index a3bbd5a36a2c1..1054bc5033dd8 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -11,19 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
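# The arithmetic behind the resume assertions above: with ``batch_size=2`` and three optimizer
# steps taken before the checkpoint, the fast-forward offset is 3 * 2 = 6, so the resumed run
# is expected to see sample indices starting at 6. A counting sketch only, not library code.
def fast_forward_offset(num_batches_processed: int, dataloader_batch_size: int) -> int:
    # mirrors ``FastForwardSampler._compute_current_iteration`` when a dataloader batch size is set
    return num_batches_processed * dataloader_batch_size


assert fast_forward_offset(3, 2) == 6  # matches the resumed indices beginning at 6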
+import os from copy import deepcopy +from unittest import mock import pytest +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.progress import ( BatchProgress, EpochLoopProgress, EpochProgress, FitLoopProgress, + OptimizationProgress, OptimizerProgress, Progress, Tracker, + TrainingEpochProgress, ) +from tests.helpers import BoringModel + + +class CustomException(BaseException): + pass def test_progress_geattr_setattr(): @@ -137,12 +149,15 @@ def test_optimizer_progress_default_factory(): def test_fit_loop_progress_serialization(): fit_loop = FitLoopProgress() + fit_loop.epoch.optim.optimizer_idx = 1 + fit_loop.epoch.dataloader_idx = 2 + fit_loop.epoch.val.should_check_val = True + fit_loop.epoch.increment_completed() _ = deepcopy(fit_loop) - fit_loop.epoch.increment_completed() # check `TrainingEpochProgress.load_state_dict` calls `super` state_dict = fit_loop.state_dict() # yapf: disable - assert state_dict == { + expected = { 'epoch': { # number of epochs across `fit` calls 'total': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0}, @@ -154,6 +169,7 @@ def test_fit_loop_progress_serialization(): # number of batches this epoch 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, }, + 'dataloader_idx': 2, # `fit` optimization progress 'optim': { # optimizers progress @@ -177,6 +193,7 @@ def test_fit_loop_progress_serialization(): # `scheduler.step` calls this epoch 'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': None}, }, + "optimizer_idx": 1, }, # `fit` validation progress 'val': { @@ -191,10 +208,13 @@ def test_fit_loop_progress_serialization(): # number of batches this `fit` `validation` call 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, }, - } + 'dataloader_idx': 0 + }, + 'should_check_val': True, }, } } + assert expected == state_dict # yapf: enable new_loop = FitLoopProgress.from_state_dict(state_dict) @@ -203,6 +223,7 @@ def test_fit_loop_progress_serialization(): def test_epoch_loop_progress_serialization(): loop = EpochLoopProgress() + loop.epoch.dataloader_idx = 1 _ = deepcopy(loop) state_dict = loop.state_dict() @@ -219,9 +240,257 @@ def test_epoch_loop_progress_serialization(): # number of batches this `validate` call 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, }, + 'dataloader_idx': 1, } } # yapf: enable new_loop = EpochLoopProgress.from_state_dict(state_dict) assert loop == new_loop + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.parametrize("use_multiple_optimizers", [False, True]) +@pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) +def test_progress_tracking(use_multiple_optimizers, accumulate_grad_batches, tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + if use_multiple_optimizers: + self.configure_optimizers = self.configure_optimizers_3 + self.should_fail = True + + def training_step(self, batch, batch_idx, optimizer_idx: int = None): + # breaking on global_step 4 + if self.should_fail and self.trainer.current_epoch == 1 and batch_idx == 1 and optimizer_idx == ( + 1 if use_multiple_optimizers else None + ): + raise CustomException + return super().training_step(batch, batch_idx) + + def configure_optimizers_3(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + optimizer_1 = 
torch.optim.Adam(self.layer.parameters(), lr=0.1) + lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + return [optimizer, optimizer_1, optimizer_2], \ + [lr_scheduler, {"scheduler": lr_scheduler_1, "interval": "step"}] + + model = TestModel() + model.training_epoch_end = None + + chk = ModelCheckpoint(dirpath=tmpdir, filename=str(use_multiple_optimizers), save_last=True) + chk.last_model_path = None + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=3, + limit_val_batches=0, + callbacks=chk, + accumulate_grad_batches=accumulate_grad_batches, + resume_from_checkpoint=None, + ) + + # simulate random failure in training_step + try: + trainer.fit(model) + except CustomException: + pass + + assert isinstance(trainer.fit_loop.progress, FitLoopProgress) + assert isinstance(trainer.fit_loop.epoch_loop.progress, TrainingEpochProgress) + assert isinstance(trainer.fit_loop.epoch_loop.batch_loop.progress, BatchProgress) + assert isinstance(trainer.fit_loop.epoch_loop.batch_loop.optim_progress, OptimizationProgress) + assert isinstance(trainer.fit_loop.epoch_loop.val_loop.progress, EpochLoopProgress) + assert isinstance(trainer.fit_loop.epoch_loop.val_loop.epoch_loop.progress, EpochProgress) + + assert trainer.fit_loop.progress.epoch == trainer.fit_loop.epoch_loop.progress + + pr = trainer.fit_loop.epoch_loop.progress + + assert pr.total == Tracker(ready=2, started=2, processed=1, completed=1) + assert pr.current == Tracker(ready=2, started=2, processed=1, completed=1) + + assert pr.batch.total == Tracker(ready=5, started=5, processed=4, completed=4) + assert pr.batch.current == Tracker(ready=2, started=2, processed=1, completed=1) + + num_optimizers = 3 if use_multiple_optimizers else 1 + + optim = trainer.fit_loop.epoch_loop.batch_loop.optim_progress + + # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) + total = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches + + # we raised expection on the first optimizer + current = (1 if use_multiple_optimizers else 0) + + if accumulate_grad_batches == 2 and use_multiple_optimizers: + total += 1 + + assert optim.optimizer.step.total == Tracker(ready=total + 1, started=total, processed=None, completed=total) + assert optim.optimizer.step.current == Tracker( + ready=current + 1, started=current, processed=None, completed=current + ) + + if accumulate_grad_batches == 2: + # that's weird ! todo (tchaton) investigate this + total = (9 if use_multiple_optimizers else 3) + current = 0 # same there. 
+ + assert optim.optimizer.zero_grad.total == Tracker(ready=total, started=total, processed=None, completed=total) + assert optim.optimizer.zero_grad.current == Tracker( + ready=current, started=current, processed=None, completed=current + ) + + # for multiple optimizers: 4 batches + 1 on epoch + total = (5 if use_multiple_optimizers else 1) // accumulate_grad_batches + + if accumulate_grad_batches == 2: + total += 1 + + assert optim.scheduler.total == Tracker(ready=total, started=None, processed=None, completed=total) + # assert optim.scheduler.current == Tracker(ready=0, started=None, processed=None, completed=0) + + assert optim.optimizer_idx == (1 if use_multiple_optimizers else 0) + + checkpoint = torch.load(trainer.checkpoint_callback.last_model_path) + assert checkpoint["epoch"] == 1 + assert checkpoint["global_step"] == 4 // accumulate_grad_batches + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + limit_train_batches=3, + limit_val_batches=0, + resume_from_checkpoint=chk.last_model_path, + accumulate_grad_batches=accumulate_grad_batches + ) + + model.should_fail = False + trainer.fit(model) + + pr = trainer.fit_loop.epoch_loop.progress + + assert pr.total == Tracker(ready=3, started=3, processed=3, completed=3) + assert pr.current == Tracker(ready=2, started=2, processed=2, completed=2) + + assert pr.batch.total == Tracker(ready=9, started=9, processed=9, completed=9) + assert pr.batch.current == Tracker(ready=3, started=3, processed=3, completed=3) + + optim = trainer.fit_loop.epoch_loop.progress.optim + + if accumulate_grad_batches == 2: + total = 2 * 3 * (3 if use_multiple_optimizers else 1) + else: + total = (3 * 3 * (3 if use_multiple_optimizers else 1)) + current = (3 if use_multiple_optimizers else 1) + + assert optim.optimizer.step.total == Tracker(ready=total, started=total, processed=None, completed=total) + assert optim.optimizer.step.current == Tracker(ready=current, started=current, processed=None, completed=current) + + assert optim.optimizer.zero_grad.total == Tracker(ready=total, started=total, processed=None, completed=total) + assert optim.optimizer.zero_grad.current == Tracker( + ready=current, started=current, processed=None, completed=current + ) + + # for multiple optimizers: 4 batches + 1 on epoch + if accumulate_grad_batches == 2: + total = (2 * 3 + 3 if use_multiple_optimizers else 3) + else: + total = (3 * 3 + 3 if use_multiple_optimizers else 3) + current = (2 if use_multiple_optimizers else 1) + + assert optim.scheduler.total == Tracker(ready=total, started=None, processed=None, completed=total) + # assert optim.scheduler.current == Tracker(ready=current, started=None, processed=None, completed=current) + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_progress_tracking_validation_multiple_datasets(tmpdir): + + class ValidationModel(BoringModel): + + def __init__(self): + super().__init__() + + def validation_step(self, batch, batch_idx, dataloader_idx): + if self.trainer.fit_loop.epoch_loop.batch_idx == 3 and batch_idx == 1 and dataloader_idx == 1: + raise CustomException + return super().validation_step(batch, batch_idx) + + def val_dataloader(self): + return [super().val_dataloader(), super().val_dataloader(), super().val_dataloader()] + + model = ValidationModel() + model.validation_epoch_end = None + + chk = ModelCheckpoint(dirpath=tmpdir, save_last=True) + chk.last_model_path = None + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=3, + 
callbacks=chk, + resume_from_checkpoint=None, + val_check_interval=2, + num_sanity_val_steps=0, + ) + + # simulate random failure in training_step + try: + trainer.fit(model) + except CustomException: + pass + + pr = trainer.fit_loop.epoch_loop.val_loop.progress + + assert isinstance(pr, EpochLoopProgress) + assert isinstance(pr.epoch, EpochProgress) + assert isinstance(pr.epoch.batch, BatchProgress) + assert pr.epoch.total == Tracker(ready=2, started=2, processed=1, completed=1) + assert pr.epoch.current == Tracker(ready=1, started=1, processed=0, completed=0) + + # 3 dataloaders with 3 samples for batch_idx == 1 + first dataloader on batch_idx == 1 + failure on batch_idx = 1 + current = 2 + total = 3 * 3 + 3 + current + assert pr.epoch.batch.total == Tracker(ready=total, started=total, processed=total - 1, completed=total - 1) + assert pr.epoch.batch.current == Tracker( + ready=current, started=current, processed=current - 1, completed=current - 1 + ) + + assert pr.epoch.dataloader_idx == 1 + + print() + print("RESTARTING") + print() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=3, + callbacks=chk, + resume_from_checkpoint=chk.last_model_path, + val_check_interval=2, + num_sanity_val_steps=0, # TODO (tchaton) This fails when increasing to 1 + ) + + trainer.fit(model) + + pr = trainer.fit_loop.epoch_loop.progress + + assert pr.total == Tracker(ready=1, started=1, processed=1, completed=1) + assert pr.current == Tracker(ready=1, started=1, processed=1, completed=1) + + pr = trainer.fit_loop.epoch_loop.val_loop.progress + + assert pr.epoch.total == Tracker(ready=2, started=2, processed=2, completed=2) + assert pr.epoch.current == Tracker(ready=1, started=1, processed=1, completed=1) + + # total = 3 (num validation samples) * 3 (num dataloaders) * 2 (num validation) + assert pr.epoch.batch.total == Tracker(ready=18, started=18, processed=18, completed=18) + assert pr.epoch.batch.current == Tracker(ready=3, started=3, processed=3, completed=3) + assert pr.epoch.dataloader_idx == 2 diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py new file mode 100644 index 0000000000000..87f5687aec592 --- /dev/null +++ b/tests/utilities/test_auto_restart.py @@ -0,0 +1,649 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
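# One way to read the ``Tracker`` assertions above: the four counters form a pipeline in which
# a unit of work is first ready, then started, then processed, then completed, so a failure
# mid-batch leaves the last unit counted as ready/started but never processed/completed. This
# is an interpretation of the asserted values, not a statement of the library API.
interrupted_epoch = dict(ready=2, started=2, processed=1, completed=1)
assert (
    interrupted_epoch["ready"]
    >= interrupted_epoch["started"]
    >= interrupted_epoch["processed"]
    >= interrupted_epoch["completed"]
)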
+import math +import os +import random +from collections.abc import Iterable +from typing import Optional + +import numpy as np +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, SequentialSampler +from torch.utils.data._utils.worker import get_worker_info +from torch.utils.data.dataloader import DataLoader, default_collate +from torch.utils.data.dataset import Dataset, IterableDataset + +import tests.helpers.utils as tutils +from pytorch_lightning import seed_everything +from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, FastForwardSampler +from pytorch_lightning.utilities.enums import AutoRestartBatchKeys +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.runif import RunIf + + +# Credit to PyTorch Team. +# Taken from: +# https://github.com/pytorch/pytorch/blob/3b977a0d2834d300c0301a0c6af98c8e939019ce/torch/utils/data/_utils/worker.py#L151 +# Not available in PyTorch 1.4. +def _generate_state(base_seed, worker_id): + INIT_A = 0x43b0d7e5 + MULT_A = 0x931e8875 + INIT_B = 0x8b51f9dd + MULT_B = 0x58f38ded + MIX_MULT_L = 0xca01f9dd + MIX_MULT_R = 0x4973f715 + XSHIFT = 4 * 8 // 2 + MASK32 = 0xFFFFFFFF + + entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] + pool = [0] * 4 + + hash_const_A = INIT_A + + def hash(value): + nonlocal hash_const_A + value = (value ^ hash_const_A) & MASK32 + hash_const_A = (hash_const_A * MULT_A) & MASK32 + value = (value * hash_const_A) & MASK32 + value = (value ^ (value >> XSHIFT)) & MASK32 + return value + + def mix(x, y): + result_x = (MIX_MULT_L * x) & MASK32 + result_y = (MIX_MULT_R * y) & MASK32 + result = (result_x - result_y) & MASK32 + result = (result ^ (result >> XSHIFT)) & MASK32 + return result + + # Add in the entropy to the pool. + for i in range(len(pool)): + pool[i] = hash(entropy[i]) + + # Mix all bits together so late bits can affect earlier bits. + for i_src in range(len(pool)): + for i_dst in range(len(pool)): + if i_src != i_dst: + pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) + + hash_const_B = INIT_B + state = [] + for i_dst in range(4): + data_val = pool[i_dst] + data_val = (data_val ^ hash_const_B) & MASK32 + hash_const_B = (hash_const_B * MULT_B) & MASK32 + data_val = (data_val * hash_const_B) & MASK32 + data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 + state.append(data_val) + return state + + +def test_fast_forward_getattr(): + dataset = range(15) + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, 3, False) + index_batch_sampler = FastForwardSampler(batch_sampler) + + assert index_batch_sampler.batch_size == 3 + assert index_batch_sampler.sampler == sampler + + +def test_fast_forward_on_batch_sampler(): + """ + This test ensures ``FastForwardSampler`` applied to ``BatchSampler`` correctly retrived + the right next batch on restart. 
+ """ + dataset = range(15) + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, 3, False) + index_batch_sampler = FastForwardSampler(batch_sampler) + + assert isinstance(index_batch_sampler, Iterable) + + index_batch_sampler_iter = iter(index_batch_sampler) + + assert next(index_batch_sampler_iter) == [0, 1, 2] + assert next(index_batch_sampler_iter) == [3, 4, 5] + + state_dict = index_batch_sampler.state_dict(2) + + index_batch_sampler = FastForwardSampler(batch_sampler) + index_batch_sampler.load_state_dict(state_dict) + + index_batch_sampler_iter = iter(index_batch_sampler) + assert next(index_batch_sampler_iter) == [6, 7, 8] + + +def test_fast_forward_on_sequential_sampler(): + """ + This test ensures ``FastForwardSampler`` applied to ``SequentialSampler`` correctly retrived + the right next batch on restart. + """ + dataset = range(15) + sequential_sampler = SequentialSampler(dataset) + sampler = FastForwardSampler(sequential_sampler) + sampler.setup(3) + batch_sampler = BatchSampler(sampler, 3, False) + + batch_sampler_iter = iter(batch_sampler) + + assert next(batch_sampler_iter) == [0, 1, 2] + assert next(batch_sampler_iter) == [3, 4, 5] + + state_dict = sampler.state_dict(2) + assert state_dict[0]["current_iteration"] == 6 + + sampler.load_state_dict(state_dict) + + batch_sampler_iter = iter(batch_sampler) + assert next(batch_sampler_iter) == [6, 7, 8] + + +@RunIf(min_torch="1.6.0") +@pytest.mark.skipif(torch.cuda.is_available(), reason="todo (tchaton) Need more investigation") +def test_fast_forward_on_random_sampler(): + """ + This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` correctly retrived + the right next batch on restart. + """ + seed_everything(42) + + dataset = range(15) + random_sampler = RandomSampler(dataset) + sampler = FastForwardSampler(random_sampler) + sampler.setup(3) + batch_sampler = BatchSampler(sampler, 3, False) + + batch_sampler_iter = iter(batch_sampler) + + assert next(batch_sampler_iter) == [14, 9, 1] + assert next(batch_sampler_iter) == [7, 11, 3] + assert next(batch_sampler_iter) == [12, 8, 2] + + state_dict = sampler.state_dict(3) + assert state_dict[0]["current_iteration"] == 9 + state_dict[0]["current_iteration"] = 6 + + seed_everything(42) + sampler = FastForwardSampler(random_sampler) + sampler.setup(3) + batch_sampler = BatchSampler(sampler, 3, False) + sampler.load_state_dict(state_dict) + + batch_sampler_iter = iter(batch_sampler) + assert next(batch_sampler_iter) == [12, 8, 2] + has_raised = False + try: + for _ in range(5): + next(batch_sampler_iter) + except StopIteration: + has_raised = True + assert sampler._current_iteration == 0 + sampler.load_state_dict(sampler.state_dict(0)) + assert has_raised + + +class RangeIterableDataset(IterableDataset): + + def __init__(self, data, num_workers: int, batch_size: int, is_in_workers: bool, state_dict=None): + self.data = list(data) + self.batch_size = batch_size + self.num_workers = num_workers + self.is_in_workers = is_in_workers + self.state_dict = state_dict + + def __iter__(self): + worker_info = get_worker_info() + if worker_info and self.num_workers == 2: + id = worker_info.id + num_samples = len(self.data) + if id == 0: + self.data = list(self.data)[:num_samples // 2] + else: + self.data = list(self.data)[num_samples // 2:] + self.user_sampler = RandomSampler(self.data) + else: + self.user_sampler = RandomSampler(self.data) + + self.iter_sampler = iter(self.user_sampler) + return self + + def __next__(self): + return 
self.data[next(self.iter_sampler)] + + +@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 30 sec and should be skipped in Azure CI") +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +@RunIf(min_torch="1.6.0") +def test_fast_forward_sampler_over_iterative_dataset(num_workers): + """ + This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly being + used to capture workers states. + """ + batch_size = 3 + initial_seed = seed_everything(42) + generator = torch.Generator() + generator.manual_seed(initial_seed) + dataset = RangeIterableDataset(range(20), num_workers, batch_size, True) + dataset = CaptureIterableDataset(dataset, num_workers) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) + iter_dataloader = iter(dataloader) + batches = [] + for _ in range(5): + batches.append(next(iter_dataloader)) + + if num_workers == 0: + batch_0_expected = torch.tensor([4, 3, 12]) + batch_1_expected = torch.tensor([18, 0, 7]) + batch_2_expected = torch.tensor([8, 14, 11]) + batch_3_expected = torch.tensor([5, 10, 13]) + batch_4_expected = torch.tensor([17, 19, 15]) + elif num_workers == 1: + batch_0_expected = torch.tensor([3, 18, 17]) + batch_1_expected = torch.tensor([13, 2, 19]) + batch_2_expected = torch.tensor([6, 4, 7]) + batch_3_expected = torch.tensor([1, 14, 5]) + batch_4_expected = torch.tensor([12, 8, 16]) + else: + batch_0_expected = torch.tensor([3, 4, 5]) + batch_1_expected = torch.tensor([10, 12, 14]) + batch_2_expected = torch.tensor([7, 0, 9]) + batch_3_expected = torch.tensor([16, 18, 17]) + batch_4_expected = torch.tensor([8, 1, 2]) + + assert torch.equal(batches[0]["data"], batch_0_expected) + assert torch.equal(batches[1]["data"], batch_1_expected) + assert torch.equal(batches[2]["data"], batch_2_expected) + assert torch.equal(batches[3]["data"], batch_3_expected) + assert torch.equal(batches[4]["data"], batch_4_expected) + + # restarting on batch_1 and getting 3 extra batches + + state_dict = {'iter_sampler': {}} + for batch in batches[:2]: + _state_dict = CaptureIterableDataset.convert_batch_into_state_dict(batch) + for k, v in _state_dict.items(): + state_dict[k].update(v) + + assert len(state_dict["iter_sampler"]) == (num_workers if num_workers > 1 else 1) + + initial_seed = seed_everything(42) + generator.manual_seed(initial_seed) + dataset = RangeIterableDataset(range(20), num_workers, batch_size, True, state_dict=state_dict) + dataset = CaptureIterableDataset(dataset) + dataset.load_state_dict(state_dict) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) + iter_dataloader = iter(dataloader) + batches = [] + for _ in range(3): + batches.append(next(iter_dataloader)) + + assert torch.equal(batches[0]["data"], batch_2_expected) + assert torch.equal(batches[1]["data"], batch_3_expected) + assert torch.equal(batches[2]["data"], batch_4_expected) + + +def _setup_ddp(rank, worldsize): + os.environ["MASTER_ADDR"] = "localhost" + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=worldsize) + + +def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize): + _setup_ddp(rank, worldsize) + + initial_seed = seed_everything(42) + + generator = torch.Generator() + generator.manual_seed(initial_seed) + + num_workers = 2 + batch_size = 4 + + dataset = range(30) + sampler = FastForwardSampler( + DistributedSampler(dataset, num_replicas=worldsize, rank=rank, drop_last=True, 
seed=initial_seed) + ) + sampler.setup(batch_size) + dataloader = DataLoader( + dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler + ) + + iter_dataloader = iter(dataloader) + + num_yielded = 0 + batches = [] + while True: + try: + batches.append(next(iter_dataloader)) + num_yielded += 1 + except StopIteration: + break + + expected = torch.tensor([17, 27, 24]) if rank == 0 else torch.tensor([19, 5, 3]) + assert torch.equal(batches[-1], expected) + + assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16 + + reload_state_dict = sampler.state_dict(num_yielded - 1) + assert reload_state_dict[0]["current_iteration"] == 12 + + sampler = FastForwardSampler( + DistributedSampler(dataset, num_replicas=worldsize, rank=rank, drop_last=True, seed=initial_seed) + ) + sampler.setup(batch_size) + sampler.load_state_dict(reload_state_dict) + dataloader = DataLoader( + dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler + ) + + iter_dataloader = iter(dataloader) + + batches = [] + while True: + try: + batches.append(next(iter_dataloader)) + except StopIteration: + break + + assert torch.equal(batches[-1], expected) + assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16 + + +@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 25 sec and should be skipped in Azure CI") +@RunIf(skip_windows=True, min_torch="1.6.0") +def test_fast_forward_sampler_with_distributed_sampler(): + """Make sure result logging works with DDP""" + tutils.set_random_master_port() + worldsize = 2 + mp.spawn(_test_fast_forward_sampler_with_distributed_sampler, args=(worldsize, ), nprocs=worldsize) + + +class MetaLearningDataset(IterableDataset): + + def __init__( + self, + dataset: Dataset, + batch_size: int, + drop_last: bool, + task_num_classes: int = 5, + num_workers: Optional[int] = None, + global_rank: Optional[int] = None, + world_size: Optional[int] = None, + initial_seed: Optional[torch.Generator] = None, + shuffle: bool = True, + debugging: bool = False, + ): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.num_workers = num_workers or 1 + self.global_rank = global_rank + self.world_size = world_size + self.task_num_classes = task_num_classes + self.labels = labels = getattr(dataset, "labels") + self.initial_seed = initial_seed + self.generator: Optional[torch.Generator] = None + self.current_task_iteration = 0 + self.shuffle = shuffle + self.debugging = debugging + + if labels is None: + raise MisconfigurationException(f"Provided {self.dataset} should have an attribute labels.") + + if len(labels) != len(dataset): + raise MisconfigurationException("Found provided ``labels`` don't match the dataset length.") + + if ((isinstance(global_rank, int) and world_size is None) + or (isinstance(world_size, int) and global_rank is None)): # noqa E129 + raise MisconfigurationException("Both ``world_size`` and ``global_rank`` should be provided !") + + self.unique_labels = np.unique(self.labels) + + @property + def worker_id(self) -> int: + worker_info = get_worker_info() + return worker_info.id if worker_info else 0 + + @property + def is_distributed(self) -> bool: + return self.world_size is not None and self.world_size > 1 + + def set_seed(self, shared: bool = False): + initial_seed = self.initial_seed + self.current_task_iteration + if shared: + seed = initial_seed + np_seed = _generate_state(initial_seed, 0) + else: + seed = initial_seed + self.worker_id 
+            np_seed = _generate_state(initial_seed, self.worker_id + self.global_rank)
+
+        random.seed(seed)
+        torch.manual_seed(seed)
+        np.random.seed(np_seed)
+
+    def sample_task_indices(self):
+        self.set_seed(shared=True)
+        self.selected_indexes = np.random.choice(self.unique_labels, self.task_num_classes, replace=False)
+        self.selected_indexes.sort()
+
+        # subset of indices from the entire dataset where the labels are actually among the
+        # task_num_classes selected_indexes
+
+        self.task_indices = np.arange(len(self.dataset))[np.in1d(self.labels, self.selected_indexes)]
+        self.task_length = len(self.task_indices)
+        self.set_seed(shared=False)
+
+    @property
+    def worker_rank(self) -> int:
+        worker_id = self.worker_id
+        is_global_zero = self.global_rank == 0
+        return self.global_rank + worker_id + int(not is_global_zero)
+
+    def create_sampler(self):
+        data = range(self.task_length)
+        if self.world_size == 1 and self.num_workers in (0, 1):
+            if self.shuffle:
+                self.sampler = RandomSampler(data, generator=self.generator)
+            else:
+                self.sampler = SequentialSampler(data)
+        else:
+            num_workers = 1 if self.num_workers in (None, 0) else self.num_workers
+            num_replicas = num_workers * self.world_size
+            current_seed = self.initial_seed + self.current_task_iteration
+            self.sampler = DistributedSampler(
+                data, num_replicas=num_replicas, rank=self.worker_rank, shuffle=self.shuffle, seed=current_seed
+            )
+
+    def __iter__(self):
+        if self.generator is None:
+            self.generator = torch.Generator().manual_seed(self.initial_seed)
+        self.sample_task_indices()
+        self.create_sampler()
+        self.batch_sampler = BatchSampler(self.sampler, batch_size=self.batch_size, drop_last=self.drop_last)
+        self.iter_sampler = iter(self.batch_sampler)
+        self.is_first_batch = True
+        self.current_task_iteration += 1
+        return self
+
+    def increment_iteration(self):
+        self.current_task_iteration += 1
+
+    def __next__(self):
+        # this is optional, but useful to accumulate gradient over the entire task.
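+        # When ``debugging`` (or on worker 0 otherwise), the first call returns the task metadata
+        # (``task_length`` and the ``selected_indexes`` of the sampled classes) instead of a data
+        # batch; every subsequent call draws indices from ``iter_sampler`` and returns the
+        # collated samples.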
+        is_first_batch = self.is_first_batch if self.debugging else (self.is_first_batch and self.worker_id == 0)
+        if is_first_batch:
+            self.is_first_batch = False
+            return {"task_length": len(self.batch_sampler), "selected_indexes": self.selected_indexes}
+
+        random_indices = next(self.iter_sampler)
+        task_indices = [self.task_indices[idx] for idx in random_indices]
+        return default_collate([self.dataset[idx] for idx in task_indices])
+
+
+class ClassificationDataset(Dataset):
+
+    def __init__(self, inputs, labels):
+        self.inputs = inputs
+        self.labels = labels
+        assert len(self.inputs) == len(self.labels)
+
+    def __getitem__(self, index):
+        return (self.inputs[index], self.labels[index])
+
+    def __len__(self):
+        return len(self.inputs)
+
+
+def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(rank, worldsize):
+    if worldsize > 1:
+        _setup_ddp(rank, worldsize)
+
+    def all_gather(tensor, world_size):
+        tensor_list = [torch.zeros_like(tensor, dtype=torch.int64) for _ in range(world_size)]
+        torch.distributed.all_gather(tensor_list, tensor)
+        return tensor_list
+
+    initial_seed = seed_everything(42)
+
+    generator = torch.Generator()
+    generator.manual_seed(initial_seed)
+
+    num_workers = 2
+    batch_size = 4
+    dataset_length = 60
+    num_classes = 10
+
+    labels = np.random.randint(0, num_classes, dataset_length)
+
+    dataset = ClassificationDataset(range(dataset_length), labels)
+    dataset = MetaLearningDataset(
+        dataset,
+        batch_size=batch_size,
+        drop_last=True,
+        num_workers=num_workers,
+        global_rank=rank,
+        world_size=worldsize,
+        initial_seed=initial_seed,
+        debugging=True,
+        shuffle=True,
+    )
+    dataset = CaptureIterableDataset(dataset, initial_seed=initial_seed)
+    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
+
+    epoch_results = []
+    for _ in range(2):
+        iter_dataloader = iter(dataloader)
+        batches = []
+        while True:
+            try:
+                batches.append(next(iter_dataloader))
+            except StopIteration:
+                break
+        epoch_results.append(batches)
+        dataloader.dataset.dataset.current_task_iteration += 1
+
+    assert len(epoch_results) == 2
+
+    assert len(epoch_results[0]) == math.ceil((dataset_length / (num_workers * worldsize)) / batch_size) + 2
+
+    if worldsize == 1:
+        assert epoch_results[0][0]["data"]["task_length"] == epoch_results[0][1]["data"]["task_length"]
+        assert torch.equal(
+            epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"]
+        )
+        assert epoch_results[0][2][AutoRestartBatchKeys.PL_SAMPLERS]["id"] == 0
+        assert epoch_results[0][3][AutoRestartBatchKeys.PL_SAMPLERS]["id"] == 1
+        assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0])
+    else:
+        first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize)
+        second_task_metadata = all_gather(epoch_results[0][1]["data"]["task_length"], worldsize)
+        assert torch.equal(first_task_metadata[0], first_task_metadata[1])
+        assert torch.equal(second_task_metadata[0], second_task_metadata[1])
+        assert torch.equal(first_task_metadata[0], second_task_metadata[1])
+
+        first_batch_list = all_gather(epoch_results[0][2]["data"][0], worldsize)
+        assert not torch.equal(first_batch_list[0], first_batch_list[1])
+        second_batch_list = all_gather(epoch_results[0][3]["data"][0], worldsize)
+        assert not torch.equal(second_batch_list[0], second_batch_list[1])
+
+    # restarting on epoch 0 / real batch 2
+    state_dict = {'iter_sampler': {}}
+    for batch in epoch_results[0][2:4]:
+        metadata = CaptureIterableDataset.convert_batch_into_state_dict(batch)
+        for k, v in metadata.items():
+            state_dict[k].update(v)
+
+    dataset = ClassificationDataset(range(dataset_length), labels)
+    dataset = MetaLearningDataset(
+        dataset,
+        batch_size=batch_size,
+        drop_last=True,
+        num_workers=num_workers,
+        global_rank=rank,
+        world_size=worldsize,
+        initial_seed=initial_seed,
+        debugging=True,
+        shuffle=True,
+    )
+
+    dataset = CaptureIterableDataset(dataset, initial_seed=initial_seed)
+    dataset.load_state_dict(state_dict)
+    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
+
+    epoch_results_restart = []
+    for _ in range(2):
+        iter_dataloader = iter(dataloader)
+        batches = []
+        while True:
+            try:
+                batches.append(next(iter_dataloader))
+            except StopIteration:
+                break
+        epoch_results_restart.append(batches)
+        dataloader.dataset.dataset.increment_iteration()
+        dataloader.dataset.reset_on_epoch()
+
+    assert len(epoch_results_restart[0]) + 2 == len(epoch_results[0])
+    epoch_tensors = [e["data"][0] for e in epoch_results[0][4:]]
+    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[0][2:]]
+
+    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
+        assert torch.equal(t, tr)
+
+    epoch_tensors = [e["data"][0] for e in epoch_results[1][2:]]
+    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[1][2:]]
+
+    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
+        assert torch.equal(t, tr)
+
+
+@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 45 sec and should be skipped in Azure CI")
+@RunIf(min_torch="1.6.0")
+def test_fast_forward_sampler_iterative_dataset():
+    _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(0, 1)
+
+
+@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 55 sec and should be skipped in Azure CI")
+@RunIf(skip_windows=True, min_torch="1.6.0")
+def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset():
+    """Make sure ``FastForwardSampler`` and ``CaptureIterableDataset`` resume correctly across DDP processes."""
+    tutils.set_random_master_port()
+    worldsize = 2
+    mp.spawn(
+        _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset, args=(worldsize, ), nprocs=worldsize
+    )
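
For reference, the save/restore contract these tests exercise can be summarised in a few lines. The sketch below is distilled from `_test_fast_forward_sampler_with_distributed_sampler` above; the import path and the use of a plain `RandomSampler` (so it runs in a single process without DDP) are assumptions, not part of this diff.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler

# Assumed import path -- the diff only shows the tests, not where the class lives.
from pytorch_lightning.utilities.auto_restart import FastForwardSampler

dataset = range(15)
batch_size = 3

# Wrap a sampler so the number of already-fetched batches can be captured.
sampler = FastForwardSampler(RandomSampler(dataset, generator=torch.Generator().manual_seed(42)))
sampler.setup(batch_size)  # the wrapper needs the batch size to know how far to fast-forward

loader = iter(DataLoader(dataset, batch_size=batch_size, sampler=sampler))
batches = [next(loader), next(loader)]

# Capture the sampler state after 2 batches; the returned dict is keyed by worker id.
state = sampler.state_dict(2)

# Later (e.g. after a crash): rebuild the sampler, load the state and resume mid-epoch.
restarted = FastForwardSampler(RandomSampler(dataset, generator=torch.Generator().manual_seed(42)))
restarted.setup(batch_size)
restarted.load_state_dict(state)
resumed = iter(DataLoader(dataset, batch_size=batch_size, sampler=restarted))
```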