
Mark internal components as protected
carmocca committed Mar 9, 2023
1 parent 5520cdf commit 242678e
Showing 15 changed files with 85 additions and 100 deletions.
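Most of the diff is a mechanical rename: the affected classes and functions keep their signatures but gain a leading underscore, Python's usual signal that a name is internal rather than public API. A minimal illustration (hypothetical usage, assuming this commit's module layout):

```python
# Before this commit:  from lightning.pytorch.loops.progress import BatchProgress
# After it, only the underscore-prefixed (protected) spelling exists at this path.
from lightning.pytorch.loops.progress import _BatchProgress

progress = _BatchProgress()
progress.increment_ready()
progress.increment_completed()
print(progress.current.ready, progress.current.completed)  # 1 1
```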
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -25,7 +25,7 @@
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import BatchProgress
from lightning.pytorch.loops.progress import _BatchProgress
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import (
@@ -54,7 +54,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
super().__init__(trainer)
self.verbose = verbose
self.inference_mode = inference_mode
self.batch_progress = BatchProgress() # across dataloaders
self.batch_progress = _BatchProgress() # across dataloaders
self._max_batches: List[Union[int, float]] = []

self._results = _ResultCollection(training=False)
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -18,7 +18,7 @@
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.fetchers import _DataFetcher
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.loops.progress import _Progress
from lightning.pytorch.loops.training_epoch_loop import _TrainingEpochLoop
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher
from lightning.pytorch.trainer import call
@@ -84,7 +84,7 @@ def __init__(
self.max_epochs = max_epochs
self.min_epochs = min_epochs
self.epoch_loop = _TrainingEpochLoop(trainer)
self.epoch_progress = Progress()
self.epoch_progress = _Progress()
self.max_batches: Union[int, float] = float("inf")

self._data_source = _DataLoaderSource(None, "train_dataloader")
6 changes: 3 additions & 3 deletions src/lightning/pytorch/loops/loop.py
@@ -14,7 +14,7 @@
from typing import Dict, Optional

import lightning.pytorch as pl
from lightning.pytorch.loops.progress import BaseProgress
from lightning.pytorch.loops.progress import _BaseProgress


class _Loop:
@@ -63,7 +63,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di

for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
if isinstance(v, _BaseProgress):
destination[key] = v.state_dict()
elif isinstance(v, _Loop):
v.state_dict(destination, key + ".")
@@ -87,7 +87,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
if key not in state_dict:
# compatibility with old checkpoints
continue
if isinstance(v, BaseProgress):
if isinstance(v, _BaseProgress):
v.load_state_dict(state_dict[key])
if prefix + "state_dict" in state_dict: # compatibility with old checkpoints
self.on_load_checkpoint(state_dict[prefix + "state_dict"])
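The `state_dict` / `_load_from_state_dict` pair above walks the loop's `__dict__`, serializes any `_BaseProgress` attribute, and recurses into child loops under a dotted prefix. A toy, self-contained mirror of that traversal (not the Lightning classes themselves), just to make the resulting key layout concrete:

```python
from dataclasses import asdict, dataclass


@dataclass
class ToyProgress:
    """Stand-in for `_BaseProgress`: a plain dataclass with counters."""
    ready: int = 0
    completed: int = 0


class ToyLoop:
    """Stand-in for `_Loop`: collects progress objects and nested loops."""

    def __init__(self) -> None:
        self.batch_progress = ToyProgress()

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = {}
        for k, v in self.__dict__.items():
            key = prefix + k
            if isinstance(v, ToyProgress):
                destination[key] = asdict(v)          # progress objects become dicts
            elif isinstance(v, ToyLoop):
                v.state_dict(destination, key + ".")  # child loops get a dotted prefix
        return destination


class ToyFitLoop(ToyLoop):
    def __init__(self) -> None:
        super().__init__()
        self.epoch_loop = ToyLoop()


print(ToyFitLoop().state_dict())
# {'batch_progress': {'ready': 0, 'completed': 0},
#  'epoch_loop.batch_progress': {'ready': 0, 'completed': 0}}
```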
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/optimization/automatic.py
@@ -22,7 +22,7 @@
import lightning.pytorch as pl
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.optimization.closure import AbstractClosure, OutputResult
from lightning.pytorch.loops.progress import OptimizationProgress
from lightning.pytorch.loops.progress import _OptimizationProgress
from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior
from lightning.pytorch.trainer import call
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -151,7 +151,7 @@ class _AutomaticOptimization(_Loop):

def __init__(self, trainer: "pl.Trainer") -> None:
super().__init__(trainer)
self.optim_progress: OptimizationProgress = OptimizationProgress()
self.optim_progress: _OptimizationProgress = _OptimizationProgress()
self._skip_backward: bool = False

def run(self, optimizer: Optimizer, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
6 changes: 3 additions & 3 deletions src/lightning/pytorch/loops/optimization/manual.py
@@ -22,7 +22,7 @@
from lightning.pytorch.core.optimizer import do_nothing_closure
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.optimization.closure import OutputResult
from lightning.pytorch.loops.progress import Progress, ReadyCompletedTracker
from lightning.pytorch.loops.progress import _Progress, _ReadyCompletedTracker
from lightning.pytorch.trainer import call
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -79,8 +79,8 @@ class _ManualOptimization(_Loop):
def __init__(self, trainer: "pl.Trainer") -> None:
super().__init__(trainer)
# since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than
# `OptimizationProgress`
self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)
# `_OptimizationProgress`
self.optim_step_progress = _Progress.from_defaults(_ReadyCompletedTracker)

self._output: _OUTPUTS_TYPE = {}

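The comment above gives the reasoning: manual optimization only needs ready/completed counts for `optimizer.step` calls, so a flat `_Progress` replaces the nested `_OptimizationProgress`. A rough sketch of the difference, assuming the import path used in this diff:

```python
from lightning.pytorch.loops.progress import (
    _OptimizationProgress,
    _Progress,
    _ReadyCompletedTracker,
)

# Automatic optimization: nested per-optimizer progress with step/zero_grad sub-trackers.
full = _OptimizationProgress()
full.optimizer.step.increment_completed()
print(full.optimizer_steps)  # 1 completed optimizer.step call

# Manual optimization: a flat ready/completed pair is enough.
simple = _Progress.from_defaults(_ReadyCompletedTracker)
simple.increment_ready()
simple.increment_completed()
print(simple.total.completed)  # 1
```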
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/prediction_loop.py
@@ -22,7 +22,7 @@
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.loops.progress import _Progress
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
@@ -49,7 +49,7 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None:
# dataloaders x batches x samples. used by PredictionWriter
self.epoch_batch_indices: List[List[List[int]]] = []
self.current_batch_indices: List[int] = [] # used by PredictionWriter
self.batch_progress = Progress() # across dataloaders
self.batch_progress = _Progress() # across dataloaders
self.max_batches: List[Union[int, float]] = []

self._warning_cache = WarningCache()
55 changes: 20 additions & 35 deletions src/lightning/pytorch/loops/progress.py
@@ -16,7 +16,7 @@


@dataclass
class BaseProgress:
class _BaseProgress:
"""Mixin that implements state-loading utilities for dataclasses."""

def state_dict(self) -> dict:
@@ -26,7 +26,7 @@ def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)

@classmethod
def from_state_dict(cls, state_dict: dict) -> "BaseProgress":
def from_state_dict(cls, state_dict: dict) -> "_BaseProgress":
obj = cls()
obj.load_state_dict(state_dict)
return obj
@@ -37,7 +37,7 @@ def reset(self) -> None:


@dataclass
class ReadyCompletedTracker(BaseProgress):
class _ReadyCompletedTracker(_BaseProgress):
"""Track an event's progress.
Args:
@@ -65,7 +65,7 @@ def reset_on_restart(self) -> None:


@dataclass
class StartedTracker(ReadyCompletedTracker):
class _StartedTracker(_ReadyCompletedTracker):
"""Track an event's progress.
Args:
@@ -88,7 +88,7 @@ def reset_on_restart(self) -> None:


@dataclass
class ProcessedTracker(StartedTracker):
class _ProcessedTracker(_StartedTracker):
"""Track an event's progress.
Args:
@@ -112,16 +112,16 @@ def reset_on_restart(self) -> None:


@dataclass
class Progress(BaseProgress):
class _Progress(_BaseProgress):
"""Track aggregated and current progress.
Args:
total: Intended to track the total progress of an event.
current: Intended to track the current progress of an event.
"""

total: ReadyCompletedTracker = field(default_factory=ProcessedTracker)
current: ReadyCompletedTracker = field(default_factory=ProcessedTracker)
total: _ReadyCompletedTracker = field(default_factory=_ProcessedTracker)
current: _ReadyCompletedTracker = field(default_factory=_ProcessedTracker)

def __post_init__(self) -> None:
if type(self.total) is not type(self.current): # noqa: E721
@@ -132,13 +132,13 @@ def increment_ready(self) -> None:
self.current.ready += 1

def increment_started(self) -> None:
if not isinstance(self.total, StartedTracker):
if not isinstance(self.total, _StartedTracker):
raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `started` attribute")
self.total.started += 1
self.current.started += 1

def increment_processed(self) -> None:
if not isinstance(self.total, ProcessedTracker):
if not isinstance(self.total, _ProcessedTracker):
raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `processed` attribute")
self.total.processed += 1
self.current.processed += 1
@@ -148,7 +148,7 @@ def increment_completed(self) -> None:
self.current.completed += 1

@classmethod
def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int) -> "Progress":
def from_defaults(cls, tracker_cls: Type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress":
"""Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))

@@ -168,22 +168,7 @@ def load_state_dict(self, state_dict: dict) -> None:


@dataclass
class DataLoaderProgress(Progress):
"""Tracks dataloader progress.
These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
Args:
total: Tracks the total dataloader progress.
current: Tracks the current dataloader progress.
"""

total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)


@dataclass
class BatchProgress(Progress):
class _BatchProgress(_Progress):
"""Tracks batch progress.
These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
Expand All @@ -210,7 +195,7 @@ def load_state_dict(self, state_dict: dict) -> None:


@dataclass
class SchedulerProgress(Progress):
class _SchedulerProgress(_Progress):
"""Tracks scheduler progress.
These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
@@ -220,21 +205,21 @@ class SchedulerProgress(Progress):
current: Tracks the current scheduler progress.
"""

total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
total: _ReadyCompletedTracker = field(default_factory=_ReadyCompletedTracker)
current: _ReadyCompletedTracker = field(default_factory=_ReadyCompletedTracker)


@dataclass
class OptimizerProgress(BaseProgress):
class _OptimizerProgress(_BaseProgress):
"""Track optimizer progress.
Args:
step: Tracks ``optimizer.step`` calls.
zero_grad: Tracks ``optimizer.zero_grad`` calls.
"""

step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker))
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))
step: _Progress = field(default_factory=lambda: _Progress.from_defaults(_ReadyCompletedTracker))
zero_grad: _Progress = field(default_factory=lambda: _Progress.from_defaults(_StartedTracker))

def reset(self) -> None:
self.step.reset()
@@ -254,14 +239,14 @@ def load_state_dict(self, state_dict: dict) -> None:


@dataclass
class OptimizationProgress(BaseProgress):
class _OptimizationProgress(_BaseProgress):
"""Track optimization progress.
Args:
optimizer: Tracks optimizer progress.
"""

optimizer: OptimizerProgress = field(default_factory=OptimizerProgress)
optimizer: _OptimizerProgress = field(default_factory=_OptimizerProgress)

@property
def optimizer_steps(self) -> int:
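A short usage sketch of the now-protected tracker hierarchy defined above (assuming this module layout); the counter behaviour follows the `increment_*` methods and the `_BaseProgress` state-dict helpers shown in the hunks:

```python
from lightning.pytorch.loops.progress import _Progress, _ReadyCompletedTracker

# The default `_Progress` uses `_ProcessedTracker` for both `total` and
# `current`, so all four increment_* calls are valid.
p = _Progress()
p.increment_ready()
p.increment_started()
p.increment_processed()
p.increment_completed()
assert p.total.completed == 1 and p.current.completed == 1

# A simpler tracker pair only knows `ready`/`completed`; `increment_started`
# would raise TypeError here, matching the isinstance guard above.
simple = _Progress.from_defaults(_ReadyCompletedTracker)
simple.increment_ready()
simple.increment_completed()

# Round trip through the `_BaseProgress` state-dict helpers.
restored = _Progress.from_state_dict(p.state_dict())
assert restored.total.ready == 1
```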
6 changes: 3 additions & 3 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -21,7 +21,7 @@
from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization
from lightning.pytorch.loops.optimization.automatic import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.optimization.manual import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.progress import BatchProgress, SchedulerProgress
from lightning.pytorch.loops.progress import _BatchProgress, _SchedulerProgress
from lightning.pytorch.loops.utilities import _is_max_limit_reached
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
@@ -63,8 +63,8 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
self.min_steps = min_steps
self.max_steps = max_steps

self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()
self.batch_progress = _BatchProgress()
self.scheduler_progress = _SchedulerProgress()

self.automatic_optimization = _AutomaticOptimization(trainer)
self.manual_optimization = _ManualOptimization(trainer)
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/utilities.py
@@ -25,7 +25,7 @@
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher
from lightning.pytorch.loops.progress import BaseProgress
from lightning.pytorch.loops.progress import _BaseProgress
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.trainer.states import RunningStage
@@ -119,7 +119,7 @@ def _is_max_limit_reached(current: int, maximum: int = -1) -> bool:

def _reset_progress(loop: _Loop) -> None:
for v in vars(loop).values():
if isinstance(v, BaseProgress):
if isinstance(v, _BaseProgress):
v.reset()
elif isinstance(v, _Loop):
_reset_progress(v)
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/ddp.py
@@ -44,7 +44,7 @@
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.distributed import register_ddp_comm_hook
from lightning.pytorch.utilities.distributed import _register_ddp_comm_hook
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -204,7 +204,7 @@ def _register_ddp_hooks(self) -> None:
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if self.root_device.type == "cuda":
assert isinstance(self.model, DistributedDataParallel)
register_ddp_comm_hook(
_register_ddp_comm_hook(
model=self.model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/configuration_validator.py
@@ -22,7 +22,7 @@
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature


def verify_loop_configurations(trainer: "pl.Trainer") -> None:
def _verify_loop_configurations(trainer: "pl.Trainer") -> None:
r"""
Checks that the model is configured correctly before the run is started.
4 changes: 2 additions & 2 deletions src/lightning/pytorch/trainer/trainer.py
@@ -49,7 +49,7 @@
from lightning.pytorch.profilers import Profiler
from lightning.pytorch.strategies import ParallelStrategy, Strategy
from lightning.pytorch.trainer import call, setup
from lightning.pytorch.trainer.configuration_validator import verify_loop_configurations
from lightning.pytorch.trainer.configuration_validator import _verify_loop_configurations
from lightning.pytorch.trainer.connectors.accelerator_connector import (
_AcceleratorConnector,
_LITERAL_WARN,
@@ -861,7 +861,7 @@ def _run(
self._callback_connector._attach_model_callbacks()
self._callback_connector._attach_model_logging_functions()

verify_loop_configurations(self)
_verify_loop_configurations(self)

# hook
log.debug(f"{self.__class__.__name__}: preparing data")
