From 242678eaae48ee321eedb95db8cbbf5d65053023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 9 Mar 2023 03:37:02 +0100 Subject: [PATCH] Mark internal components as protected --- .../pytorch/loops/evaluation_loop.py | 4 +- src/lightning/pytorch/loops/fit_loop.py | 4 +- src/lightning/pytorch/loops/loop.py | 6 +- .../pytorch/loops/optimization/automatic.py | 4 +- .../pytorch/loops/optimization/manual.py | 6 +- .../pytorch/loops/prediction_loop.py | 4 +- src/lightning/pytorch/loops/progress.py | 55 ++++++--------- .../pytorch/loops/training_epoch_loop.py | 6 +- src/lightning/pytorch/loops/utilities.py | 4 +- src/lightning/pytorch/strategies/ddp.py | 4 +- .../trainer/configuration_validator.py | 2 +- src/lightning/pytorch/trainer/trainer.py | 4 +- .../pytorch/utilities/distributed.py | 10 +-- tests/tests_pytorch/loops/test_loops.py | 4 +- tests/tests_pytorch/loops/test_progress.py | 68 +++++++++---------- 15 files changed, 85 insertions(+), 100 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index faf56b0648e339..4f0ac6e012fcce 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -25,7 +25,7 @@ from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher from lightning.pytorch.loops.loop import _Loop -from lightning.pytorch.loops.progress import BatchProgress +from lightning.pytorch.loops.progress import _BatchProgress from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement from lightning.pytorch.trainer import call from lightning.pytorch.trainer.connectors.data_connector import ( @@ -54,7 +54,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode: super().__init__(trainer) self.verbose = verbose self.inference_mode = inference_mode - self.batch_progress = BatchProgress() # across dataloaders + self.batch_progress = _BatchProgress() # across dataloaders self._max_batches: List[Union[int, float]] = [] self._results = _ResultCollection(training=False) diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index a42a6e1b28edba..c5778a0d9e8180 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -18,7 +18,7 @@ from lightning.fabric.utilities.data import _set_sampler_epoch from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.fetchers import _DataFetcher -from lightning.pytorch.loops.progress import Progress +from lightning.pytorch.loops.progress import _Progress from lightning.pytorch.loops.training_epoch_loop import _TrainingEpochLoop from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher from lightning.pytorch.trainer import call @@ -84,7 +84,7 @@ def __init__( self.max_epochs = max_epochs self.min_epochs = min_epochs self.epoch_loop = _TrainingEpochLoop(trainer) - self.epoch_progress = Progress() + self.epoch_progress = _Progress() self.max_batches: Union[int, float] = float("inf") self._data_source = _DataLoaderSource(None, "train_dataloader") diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index 92e0facc0babc7..2a3bf1dfc4a9b6 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -14,7 +14,7 @@ from typing import Dict, 
Optional import lightning.pytorch as pl -from lightning.pytorch.loops.progress import BaseProgress +from lightning.pytorch.loops.progress import _BaseProgress class _Loop: @@ -63,7 +63,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di for k, v in self.__dict__.items(): key = prefix + k - if isinstance(v, BaseProgress): + if isinstance(v, _BaseProgress): destination[key] = v.state_dict() elif isinstance(v, _Loop): v.state_dict(destination, key + ".") @@ -87,7 +87,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: if key not in state_dict: # compatibility with old checkpoints continue - if isinstance(v, BaseProgress): + if isinstance(v, _BaseProgress): v.load_state_dict(state_dict[key]) if prefix + "state_dict" in state_dict: # compatibility with old checkpoints self.on_load_checkpoint(state_dict[prefix + "state_dict"]) diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index 40a80c17c89ae2..d2b5fc923b662b 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -22,7 +22,7 @@ import lightning.pytorch as pl from lightning.pytorch.loops.loop import _Loop from lightning.pytorch.loops.optimization.closure import AbstractClosure, OutputResult -from lightning.pytorch.loops.progress import OptimizationProgress +from lightning.pytorch.loops.progress import _OptimizationProgress from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior from lightning.pytorch.trainer import call from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -151,7 +151,7 @@ class _AutomaticOptimization(_Loop): def __init__(self, trainer: "pl.Trainer") -> None: super().__init__(trainer) - self.optim_progress: OptimizationProgress = OptimizationProgress() + self.optim_progress: _OptimizationProgress = _OptimizationProgress() self._skip_backward: bool = False def run(self, optimizer: Optimizer, kwargs: OrderedDict) -> _OUTPUTS_TYPE: diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index d0dd93aa0adf8b..d10cde4a83ab71 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -22,7 +22,7 @@ from lightning.pytorch.core.optimizer import do_nothing_closure from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.optimization.closure import OutputResult -from lightning.pytorch.loops.progress import Progress, ReadyCompletedTracker +from lightning.pytorch.loops.progress import _Progress, _ReadyCompletedTracker from lightning.pytorch.trainer import call from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -79,8 +79,8 @@ class _ManualOptimization(_Loop): def __init__(self, trainer: "pl.Trainer") -> None: super().__init__(trainer) # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than - # `OptimizationProgress` - self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker) + # `_OptimizationProgress` + self.optim_step_progress = _Progress.from_defaults(_ReadyCompletedTracker) self._output: _OUTPUTS_TYPE = {} diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 4057a853c528db..cd9817ada6721d 100644 --- 
a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -22,7 +22,7 @@ from lightning.pytorch.callbacks import BasePredictionWriter from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher from lightning.pytorch.loops.loop import _Loop -from lightning.pytorch.loops.progress import Progress +from lightning.pytorch.loops.progress import _Progress from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher @@ -49,7 +49,7 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None: # dataloaders x batches x samples. used by PredictionWriter self.epoch_batch_indices: List[List[List[int]]] = [] self.current_batch_indices: List[int] = [] # used by PredictionWriter - self.batch_progress = Progress() # across dataloaders + self.batch_progress = _Progress() # across dataloaders self.max_batches: List[Union[int, float]] = [] self._warning_cache = WarningCache() diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index 5406318d9d1728..d2e52f44d7ba7f 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -16,7 +16,7 @@ @dataclass -class BaseProgress: +class _BaseProgress: """Mixin that implements state-loading utilities for dataclasses.""" def state_dict(self) -> dict: @@ -26,7 +26,7 @@ def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) @classmethod - def from_state_dict(cls, state_dict: dict) -> "BaseProgress": + def from_state_dict(cls, state_dict: dict) -> "_BaseProgress": obj = cls() obj.load_state_dict(state_dict) return obj @@ -37,7 +37,7 @@ def reset(self) -> None: @dataclass -class ReadyCompletedTracker(BaseProgress): +class _ReadyCompletedTracker(_BaseProgress): """Track an event's progress. Args: @@ -65,7 +65,7 @@ def reset_on_restart(self) -> None: @dataclass -class StartedTracker(ReadyCompletedTracker): +class _StartedTracker(_ReadyCompletedTracker): """Track an event's progress. Args: @@ -88,7 +88,7 @@ def reset_on_restart(self) -> None: @dataclass -class ProcessedTracker(StartedTracker): +class _ProcessedTracker(_StartedTracker): """Track an event's progress. Args: @@ -112,7 +112,7 @@ def reset_on_restart(self) -> None: @dataclass -class Progress(BaseProgress): +class _Progress(_BaseProgress): """Track aggregated and current progress. Args: @@ -120,8 +120,8 @@ class Progress(BaseProgress): current: Intended to track the current progress of an event. 
""" - total: ReadyCompletedTracker = field(default_factory=ProcessedTracker) - current: ReadyCompletedTracker = field(default_factory=ProcessedTracker) + total: _ReadyCompletedTracker = field(default_factory=_ProcessedTracker) + current: _ReadyCompletedTracker = field(default_factory=_ProcessedTracker) def __post_init__(self) -> None: if type(self.total) is not type(self.current): # noqa: E721 @@ -132,13 +132,13 @@ def increment_ready(self) -> None: self.current.ready += 1 def increment_started(self) -> None: - if not isinstance(self.total, StartedTracker): + if not isinstance(self.total, _StartedTracker): raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `started` attribute") self.total.started += 1 self.current.started += 1 def increment_processed(self) -> None: - if not isinstance(self.total, ProcessedTracker): + if not isinstance(self.total, _ProcessedTracker): raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `processed` attribute") self.total.processed += 1 self.current.processed += 1 @@ -148,7 +148,7 @@ def increment_completed(self) -> None: self.current.completed += 1 @classmethod - def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int) -> "Progress": + def from_defaults(cls, tracker_cls: Type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress": """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) @@ -168,22 +168,7 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass -class DataLoaderProgress(Progress): - """Tracks dataloader progress. - - These counters are local to a trainer rank. By default, they are not globally synced across all ranks. - - Args: - total: Tracks the total dataloader progress. - current: Tracks the current dataloader progress. - """ - - total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) - current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) - - -@dataclass -class BatchProgress(Progress): +class _BatchProgress(_Progress): """Tracks batch progress. These counters are local to a trainer rank. By default, they are not globally synced across all ranks. @@ -210,7 +195,7 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass -class SchedulerProgress(Progress): +class _SchedulerProgress(_Progress): """Tracks scheduler progress. These counters are local to a trainer rank. By default, they are not globally synced across all ranks. @@ -220,12 +205,12 @@ class SchedulerProgress(Progress): current: Tracks the current scheduler progress. """ - total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) - current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) + total: _ReadyCompletedTracker = field(default_factory=_ReadyCompletedTracker) + current: _ReadyCompletedTracker = field(default_factory=_ReadyCompletedTracker) @dataclass -class OptimizerProgress(BaseProgress): +class _OptimizerProgress(_BaseProgress): """Track optimizer progress. Args: @@ -233,8 +218,8 @@ class OptimizerProgress(BaseProgress): zero_grad: Tracks ``optimizer.zero_grad`` calls. 
""" - step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker)) - zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker)) + step: _Progress = field(default_factory=lambda: _Progress.from_defaults(_ReadyCompletedTracker)) + zero_grad: _Progress = field(default_factory=lambda: _Progress.from_defaults(_StartedTracker)) def reset(self) -> None: self.step.reset() @@ -254,14 +239,14 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass -class OptimizationProgress(BaseProgress): +class _OptimizationProgress(_BaseProgress): """Track optimization progress. Args: optimizer: Tracks optimizer progress. """ - optimizer: OptimizerProgress = field(default_factory=OptimizerProgress) + optimizer: _OptimizerProgress = field(default_factory=_OptimizerProgress) @property def optimizer_steps(self) -> int: diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index ab424e9f5080e9..fe95beb636ef43 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -21,7 +21,7 @@ from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization from lightning.pytorch.loops.optimization.automatic import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE from lightning.pytorch.loops.optimization.manual import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE -from lightning.pytorch.loops.progress import BatchProgress, SchedulerProgress +from lightning.pytorch.loops.progress import _BatchProgress, _SchedulerProgress from lightning.pytorch.loops.utilities import _is_max_limit_reached from lightning.pytorch.trainer import call from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection @@ -63,8 +63,8 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s self.min_steps = min_steps self.max_steps = max_steps - self.batch_progress = BatchProgress() - self.scheduler_progress = SchedulerProgress() + self.batch_progress = _BatchProgress() + self.scheduler_progress = _SchedulerProgress() self.automatic_optimization = _AutomaticOptimization(trainer) self.manual_optimization = _ManualOptimization(trainer) diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index a5d81a6f7f71b7..59429f841bdb62 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -25,7 +25,7 @@ from lightning.pytorch.callbacks.timer import Timer from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher -from lightning.pytorch.loops.progress import BaseProgress +from lightning.pytorch.loops.progress import _BaseProgress from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.strategy import Strategy from lightning.pytorch.trainer.states import RunningStage @@ -119,7 +119,7 @@ def _is_max_limit_reached(current: int, maximum: int = -1) -> bool: def _reset_progress(loop: _Loop) -> None: for v in vars(loop).values(): - if isinstance(v, BaseProgress): + if isinstance(v, _BaseProgress): v.reset() elif isinstance(v, _Loop): _reset_progress(v) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index e7c98ece6d1ce5..a6899b1c13307f 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ 
b/src/lightning/pytorch/strategies/ddp.py @@ -44,7 +44,7 @@ from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.strategy import TBroadcast from lightning.pytorch.trainer.states import TrainerFn -from lightning.pytorch.utilities.distributed import register_ddp_comm_hook +from lightning.pytorch.utilities.distributed import _register_ddp_comm_hook from lightning.pytorch.utilities.exceptions import _augment_message from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep @@ -204,7 +204,7 @@ def _register_ddp_hooks(self) -> None: # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 if self.root_device.type == "cuda": assert isinstance(self.model, DistributedDataParallel) - register_ddp_comm_hook( + _register_ddp_comm_hook( model=self.model, ddp_comm_state=self._ddp_comm_state, ddp_comm_hook=self._ddp_comm_hook, diff --git a/src/lightning/pytorch/trainer/configuration_validator.py b/src/lightning/pytorch/trainer/configuration_validator.py index bb486f3d408cad..060c57aecff483 100644 --- a/src/lightning/pytorch/trainer/configuration_validator.py +++ b/src/lightning/pytorch/trainer/configuration_validator.py @@ -22,7 +22,7 @@ from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature -def verify_loop_configurations(trainer: "pl.Trainer") -> None: +def _verify_loop_configurations(trainer: "pl.Trainer") -> None: r""" Checks that the model is configured correctly before the run is started. diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index e167dc380ee315..003fe9adc223a8 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -49,7 +49,7 @@ from lightning.pytorch.profilers import Profiler from lightning.pytorch.strategies import ParallelStrategy, Strategy from lightning.pytorch.trainer import call, setup -from lightning.pytorch.trainer.configuration_validator import verify_loop_configurations +from lightning.pytorch.trainer.configuration_validator import _verify_loop_configurations from lightning.pytorch.trainer.connectors.accelerator_connector import ( _AcceleratorConnector, _LITERAL_WARN, @@ -861,7 +861,7 @@ def _run( self._callback_connector._attach_model_callbacks() self._callback_connector._attach_model_logging_functions() - verify_loop_configurations(self) + _verify_loop_configurations(self) # hook log.debug(f"{self.__class__.__name__}: preparing data") diff --git a/src/lightning/pytorch/utilities/distributed.py b/src/lightning/pytorch/utilities/distributed.py index c21f0c8c4c7670..2e7afffb7b2b8f 100644 --- a/src/lightning/pytorch/utilities/distributed.py +++ b/src/lightning/pytorch/utilities/distributed.py @@ -21,7 +21,7 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info -def register_ddp_comm_hook( +def _register_ddp_comm_hook( model: DistributedDataParallel, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[Callable] = None, @@ -64,14 +64,14 @@ def register_ddp_comm_hook( >>> >>> # fp16_compress_hook for compress gradients >>> ddp_model = ... - >>> register_ddp_comm_hook( # doctest: +SKIP + >>> _register_ddp_comm_hook( # doctest: +SKIP ... model=ddp_model, ... ddp_comm_hook=default.fp16_compress_hook, ... ) >>> >>> # powerSGD_hook >>> ddp_model = ... 
- >>> register_ddp_comm_hook( # doctest: +SKIP + >>> _register_ddp_comm_hook( # doctest: +SKIP ... model=ddp_model, ... ddp_comm_state=powerSGD.PowerSGDState( ... process_group=None, @@ -84,7 +84,7 @@ def register_ddp_comm_hook( >>> # post_localSGD_hook >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP >>> ddp_model = ... - >>> register_ddp_comm_hook( # doctest: +SKIP + >>> _register_ddp_comm_hook( # doctest: +SKIP ... model=ddp_model, ... state=post_localSGD.PostLocalSGDState( ... process_group=None, @@ -96,7 +96,7 @@ def register_ddp_comm_hook( >>> >>> # fp16_compress_wrapper combined with other communication hook >>> ddp_model = ... - >>> register_ddp_comm_hook( # doctest: +SKIP + >>> _register_ddp_comm_hook( # doctest: +SKIP ... model=ddp_model, ... ddp_comm_state=powerSGD.PowerSGDState( ... process_group=None, diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ab32150d368ccd..e0a154255b61f8 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -25,7 +25,7 @@ from lightning.pytorch.callbacks import Callback, ModelCheckpoint, OnExceptionCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.loops import _Loop -from lightning.pytorch.loops.progress import BaseProgress +from lightning.pytorch.loops.progress import _BaseProgress def test_restarting_loops_recursive(): @@ -112,7 +112,7 @@ def load_state_dict(self, state_dict: Dict) -> None: def test_loop_hierarchy(): @dataclass - class SimpleProgress(BaseProgress): + class SimpleProgress(_BaseProgress): increment: int = 0 class Simple(_Loop): diff --git a/tests/tests_pytorch/loops/test_progress.py b/tests/tests_pytorch/loops/test_progress.py index 6c273bd828fa55..4eb715e657c3da 100644 --- a/tests/tests_pytorch/loops/test_progress.py +++ b/tests/tests_pytorch/loops/test_progress.py @@ -16,76 +16,76 @@ import pytest from lightning.pytorch.loops.progress import ( - BaseProgress, - OptimizerProgress, - ProcessedTracker, - Progress, - ReadyCompletedTracker, - StartedTracker, + _BaseProgress, + _OptimizerProgress, + _ProcessedTracker, + _Progress, + _ReadyCompletedTracker, + _StartedTracker, ) def test_tracker_reset(): - p = StartedTracker(ready=1, started=2) + p = _StartedTracker(ready=1, started=2) p.reset() - assert p == StartedTracker() + assert p == _StartedTracker() def test_tracker_reset_on_restart(): - t = StartedTracker(ready=3, started=3, completed=2) + t = _StartedTracker(ready=3, started=3, completed=2) t.reset_on_restart() - assert t == StartedTracker(ready=2, started=2, completed=2) + assert t == _StartedTracker(ready=2, started=2, completed=2) - t = ProcessedTracker(ready=4, started=4, processed=3, completed=2) + t = _ProcessedTracker(ready=4, started=4, processed=3, completed=2) t.reset_on_restart() - assert t == ProcessedTracker(ready=2, started=2, processed=2, completed=2) + assert t == _ProcessedTracker(ready=2, started=2, processed=2, completed=2) @pytest.mark.parametrize("attr", ("ready", "started", "processed", "completed")) def test_progress_increment(attr): - p = Progress() + p = _Progress() fn = getattr(p, f"increment_{attr}") fn() - expected = ProcessedTracker(**{attr: 1}) + expected = _ProcessedTracker(**{attr: 1}) assert p.total == expected assert p.current == expected def test_progress_from_defaults(): - actual = Progress.from_defaults(StartedTracker, completed=5) - expected = Progress(total=StartedTracker(completed=5), 
current=StartedTracker(completed=5)) + actual = _Progress.from_defaults(_StartedTracker, completed=5) + expected = _Progress(total=_StartedTracker(completed=5), current=_StartedTracker(completed=5)) assert actual == expected def test_progress_increment_sequence(): """Test sequence for incrementing.""" - batch = Progress() + batch = _Progress() batch.increment_ready() - assert batch.total == ProcessedTracker(ready=1) - assert batch.current == ProcessedTracker(ready=1) + assert batch.total == _ProcessedTracker(ready=1) + assert batch.current == _ProcessedTracker(ready=1) batch.increment_started() - assert batch.total == ProcessedTracker(ready=1, started=1) - assert batch.current == ProcessedTracker(ready=1, started=1) + assert batch.total == _ProcessedTracker(ready=1, started=1) + assert batch.current == _ProcessedTracker(ready=1, started=1) batch.increment_processed() - assert batch.total == ProcessedTracker(ready=1, started=1, processed=1) - assert batch.current == ProcessedTracker(ready=1, started=1, processed=1) + assert batch.total == _ProcessedTracker(ready=1, started=1, processed=1) + assert batch.current == _ProcessedTracker(ready=1, started=1, processed=1) batch.increment_completed() - assert batch.total == ProcessedTracker(ready=1, started=1, processed=1, completed=1) - assert batch.current == ProcessedTracker(ready=1, started=1, processed=1, completed=1) + assert batch.total == _ProcessedTracker(ready=1, started=1, processed=1, completed=1) + assert batch.current == _ProcessedTracker(ready=1, started=1, processed=1, completed=1) def test_progress_raises(): with pytest.raises(ValueError, match="instances should be of the same class"): - Progress(ReadyCompletedTracker(), ProcessedTracker()) + _Progress(_ReadyCompletedTracker(), _ProcessedTracker()) - p = Progress(ReadyCompletedTracker(), ReadyCompletedTracker()) - with pytest.raises(TypeError, match="ReadyCompletedTracker` doesn't have a `started` attribute"): + p = _Progress(_ReadyCompletedTracker(), _ReadyCompletedTracker()) + with pytest.raises(TypeError, match="_ReadyCompletedTracker` doesn't have a `started` attribute"): p.increment_started() - with pytest.raises(TypeError, match="ReadyCompletedTracker` doesn't have a `processed` attribute"): + with pytest.raises(TypeError, match="_ReadyCompletedTracker` doesn't have a `processed` attribute"): p.increment_processed() @@ -94,8 +94,8 @@ def test_optimizer_progress_default_factory(): If `default_factory` was not used, the default would be shared between instances. """ - p1 = OptimizerProgress() - p2 = OptimizerProgress() + p1 = _OptimizerProgress() + p2 = _OptimizerProgress() p1.step.increment_completed() assert p1.step.total.completed == p1.step.current.completed assert p1.step.total.completed == 1 @@ -103,6 +103,6 @@ def test_optimizer_progress_default_factory(): def test_deepcopy(): - _ = deepcopy(BaseProgress()) - _ = deepcopy(Progress()) - _ = deepcopy(ProcessedTracker()) + _ = deepcopy(_BaseProgress()) + _ = deepcopy(_Progress()) + _ = deepcopy(_ProcessedTracker())
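
Note for reviewers who want a quick feel for what the renamed internals do: below is a minimal Python sketch, assuming a lightning.pytorch checkout that already contains this patch. It mirrors the behaviour exercised in tests/tests_pytorch/loops/test_progress.py and is illustrative only, since the underscore-prefixed classes are internal rather than a supported public API.

# Minimal sketch, assuming a lightning.pytorch checkout that already contains
# this rename. The underscore prefix marks these classes as internal, so this
# is illustration only, not a supported public API.
from lightning.pytorch.loops.progress import (
    _BatchProgress,
    _ProcessedTracker,
    _Progress,
    _ReadyCompletedTracker,
)

# `_Progress` pairs a `total` and a `current` tracker (both `_ProcessedTracker`
# by default) and bumps both on every increment call.
p = _Progress()
p.increment_ready()
p.increment_started()
p.increment_processed()
p.increment_completed()
assert p.total == _ProcessedTracker(ready=1, started=1, processed=1, completed=1)
assert p.current == p.total

# Loops that only need the ready/completed stages build a simpler tracker pair,
# as `_ManualOptimization` does for its `optim_step_progress`.
step_progress = _Progress.from_defaults(_ReadyCompletedTracker, completed=2)
assert step_progress.total.completed == 2

# Progress objects are plain dataclasses with `state_dict`/`load_state_dict`,
# which is what `_Loop.state_dict()` relies on when checkpointing and restoring.
restored = _Progress.from_state_dict(p.state_dict())
assert restored == p

# `_BatchProgress` extends `_Progress`; the evaluation and training epoch loops
# each keep one as their `batch_progress`.
batch = _BatchProgress()
batch.increment_ready()
assert batch.total.ready == 1

Any code outside this repository that imported the old public names (for example ``BatchProgress`` or ``register_ddp_comm_hook``) will need to switch to the underscore-prefixed names after this patch or, preferably, stop depending on these internals.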