Skip to content

Commit

Permalink
Move max_batches definition to the Loops (#16820)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Feb 22, 2023
1 parent f969411 commit 565d611
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 94 deletions.
11 changes: 5 additions & 6 deletions src/lightning/pytorch/callbacks/stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,12 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

# There is no need to perform either backward or optimizer.step as we are
# performing only one pass over the train data-loader to compute activation statistics
# Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
assert isinstance(trainer.num_training_batches, int)
trainer.num_training_batches += 1
# Therefore, we will virtually increase the number of training batches by 1 and skip backward.
trainer.fit_loop.max_batches += 1
trainer.fit_loop._skip_backward = True
self._accumulate_grad_batches = trainer.accumulate_grad_batches

trainer.accumulate_grad_batches = trainer.num_training_batches
assert isinstance(trainer.fit_loop.max_batches, int), "Iterable-style datasets are not supported"
trainer.accumulate_grad_batches = trainer.fit_loop.max_batches

def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
trainer.fit_loop._skip_backward = False
Expand All @@ -266,7 +265,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.num_training_batches -= 1
trainer.fit_loop.max_batches -= 1
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs -= 1
self.reset_momenta()
Expand Down
38 changes: 15 additions & 23 deletions src/lightning/pytorch/loops/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
self.verbose = verbose
self.inference_mode = inference_mode
self.batch_progress = BatchProgress() # across dataloaders
self._max_batches: List[Union[int, float]] = []

self._results = _ResultCollection(training=False)
self._logged_outputs: List[_OUT_DICT] = []
Expand All @@ -55,6 +56,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
self._combined_loader: Optional[CombinedLoader] = None
self._data_fetcher: Optional[_DataFetcher] = None
self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
self._last_val_dl_reload_epoch = float("-inf")

@property
def num_dataloaders(self) -> int:
Expand All @@ -66,19 +68,22 @@ def num_dataloaders(self) -> int:
@property
def max_batches(self) -> List[Union[int, float]]:
"""The max number of batches this loop will run for each dataloader."""
if self.trainer.testing:
return self.trainer.num_test_batches
elif self.trainer.sanity_checking:
return self.trainer.num_sanity_val_batches
elif self.trainer.validating:
return self.trainer.num_val_batches
raise RuntimeError(f"Unexpected stage: {self.trainer.state.stage}")
max_batches = self._max_batches
if self.trainer.sanity_checking:
return [min(self.trainer.num_sanity_val_steps, batches) for batches in max_batches]
return max_batches

@property
def skip(self) -> bool:
"""Returns whether the evaluation should be skipped."""
return sum(self.max_batches) == 0

@property
def _should_reload_val_dl(self) -> bool:
"""Check if validation dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return bool(n_epochs and self.trainer.current_epoch - self._last_val_dl_reload_epoch >= n_epochs)

@_no_grad_context
def run(self) -> List[_OUT_DICT]:
self.setup_data()
Expand Down Expand Up @@ -110,11 +115,7 @@ def run(self) -> List[_OUT_DICT]:
def setup_data(self) -> None:
trainer = self.trainer

if (
self._combined_loader is not None
and trainer.state.fn == "fit"
and not trainer._data_connector._should_reload_val_dl
):
if self._combined_loader is not None and trainer.state.fn == "fit" and not self._should_reload_val_dl:
return

source = self._data_source
Expand All @@ -130,20 +131,11 @@ def setup_data(self) -> None:
(trainer.sanity_checking and trainer.fit_loop.epoch_loop._should_check_val_epoch())
or not trainer.sanity_checking
):
trainer._last_val_dl_reload_epoch = trainer.current_epoch
self._last_val_dl_reload_epoch = trainer.current_epoch

stage = trainer.state.stage
assert stage is not None
num_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)
if trainer.testing:
trainer.num_test_batches = num_batches
elif trainer.sanity_checking:
trainer.num_val_batches = num_batches
trainer.num_sanity_val_batches = [
min(trainer.num_sanity_val_steps, val_batches) for val_batches in num_batches
]
else:
trainer.num_val_batches = num_batches
self._max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)

if trainer.state.fn != "fit": # if we are fitting, we need to do this in the loop
for dl in combined_loader.flattened:
Expand Down
43 changes: 25 additions & 18 deletions src/lightning/pytorch/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import Optional, Union

import lightning.pytorch as pl
from lightning.fabric.utilities.data import _auto_add_worker_init_fn
Expand Down Expand Up @@ -79,10 +79,12 @@ def __init__(
self.min_epochs = min_epochs
self.epoch_loop = _TrainingEpochLoop(trainer)
self.epoch_progress = Progress()
self.max_batches: Union[int, float] = float("inf")

self._data_source = _DataLoaderSource(None, "train_dataloader")
self._combined_loader: Optional[CombinedLoader] = None
self._data_fetcher: Optional[_DataFetcher] = None
self._last_train_dl_reload_epoch = float("-inf")

@property
def total_batch_idx(self) -> int:
Expand Down Expand Up @@ -136,10 +138,16 @@ def _can_stop_early(self) -> bool:
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
return met_min_epochs and met_min_steps

@property
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return n_epochs and self.trainer.current_epoch - self._last_train_dl_reload_epoch >= n_epochs

@property
def done(self) -> bool:
"""Evaluates when to leave the loop."""
if self.trainer.num_training_batches == 0:
if self.max_batches == 0:
rank_zero_info("`Trainer.fit` stopped: No training batches.")
return True

Expand Down Expand Up @@ -168,8 +176,8 @@ def done(self) -> bool:
@property
def skip(self) -> bool:
"""Whether we should skip the training and immediately return from the call to :meth:`run`."""
# since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
# until `on_run_start`, we use `limit_train_batches` instead
# if `limit_train_batches == 0` then `setup_data` won't set the `self.max_batches` attribute (checked in `done`)
# so we cannot use it solely
return self.done or self.trainer.limit_train_batches == 0

def run(self) -> None:
Expand All @@ -190,11 +198,10 @@ def run(self) -> None:
self.on_run_end()

def setup_data(self, shuffle: bool = True) -> None:
trainer = self.trainer

if self._combined_loader is not None and not trainer._data_connector._should_reload_train_dl:
if self._combined_loader is not None and not self._should_reload_train_dl:
return

trainer = self.trainer
source = self._data_source
pl_module = trainer.lightning_module
if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
Expand Down Expand Up @@ -227,7 +234,7 @@ def setup_data(self, shuffle: bool = True) -> None:
self._combined_loader = combined_loader

module = pl_module or trainer.datamodule
orig_train_batches = trainer.num_training_batches = (
orig_train_batches = self.max_batches = (
len(self._combined_loader)
if has_len_all_ranks(self._combined_loader, trainer.strategy, module)
else float("inf")
Expand All @@ -236,12 +243,12 @@ def setup_data(self, shuffle: bool = True) -> None:
return

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
trainer._last_train_dl_reload_epoch = trainer.current_epoch
self._last_train_dl_reload_epoch = trainer.current_epoch

if isinstance(trainer.limit_train_batches, int):
trainer.num_training_batches = min(orig_train_batches, trainer.limit_train_batches)
elif trainer.num_training_batches != float("inf"):
trainer.num_training_batches = int(orig_train_batches * trainer.limit_train_batches)
self.max_batches = min(orig_train_batches, trainer.limit_train_batches)
elif self.max_batches != float("inf"):
self.max_batches = int(orig_train_batches * trainer.limit_train_batches)
elif trainer.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int."
Expand All @@ -250,10 +257,10 @@ def setup_data(self, shuffle: bool = True) -> None:

if isinstance(trainer.val_check_interval, int):
trainer.val_check_batch = trainer.val_check_interval
if trainer.val_check_batch > trainer.num_training_batches and trainer.check_val_every_n_epoch is not None:
if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
raise ValueError(
f" `val_check_interval` ({trainer.val_check_interval}) must be less than or equal"
f" to the number of the training batches ({trainer.num_training_batches})."
f" to the number of the training batches ({self.max_batches})."
" If you want to disable validation set `limit_val_batches` to 0.0 instead."
" If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
)
Expand All @@ -268,19 +275,19 @@ def setup_data(self, shuffle: bool = True) -> None:
" checking validation every k training batches."
)
else:
trainer.val_check_batch = int(trainer.num_training_batches * trainer.val_check_interval)
trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
trainer.val_check_batch = max(1, trainer.val_check_batch)

if trainer.loggers and trainer.num_training_batches < trainer.log_every_n_steps:
if trainer.loggers and self.max_batches < trainer.log_every_n_steps:
rank_zero_warn(
f"The number of training batches ({trainer.num_training_batches}) is smaller than the logging interval"
f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
f" Trainer(log_every_n_steps={trainer.log_every_n_steps}). Set a lower value for log_every_n_steps if"
" you want to see logs for the training epoch.",
category=PossibleUserWarning,
)

if (
trainer.num_training_batches == 0
self.max_batches == 0
and trainer.limit_train_batches > 0.0
and isinstance(trainer.limit_train_batches, float)
and orig_train_batches != float("inf")
Expand Down
8 changes: 2 additions & 6 deletions src/lightning/pytorch/loops/prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None:
self.epoch_batch_indices: List[List[List[int]]] = []
self.current_batch_indices: List[int] = [] # used by PredictionWriter
self.batch_progress = Progress() # across dataloaders
self.max_batches: List[Union[int, float]] = []

self._warning_cache = WarningCache()
self._data_source = _DataLoaderSource(None, "predict_dataloader")
Expand Down Expand Up @@ -71,11 +72,6 @@ def num_dataloaders(self) -> int:
assert combined_loader is not None
return len(combined_loader.flattened)

@property
def max_batches(self) -> List[Union[int, float]]:
"""The max number of batches this loop will run for each dataloader."""
return self.trainer.num_predict_batches

@property
def skip(self) -> bool:
return sum(self.max_batches) == 0
Expand Down Expand Up @@ -109,7 +105,7 @@ def setup_data(self) -> None:
if not source.is_defined() or trainer.limit_predict_batches == 0:
return

trainer.num_predict_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
self.max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)
for dl in combined_loader.flattened:
Expand Down
14 changes: 0 additions & 14 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,6 @@ def __init__(self, trainer: "pl.Trainer"):
self.trainer = trainer
self._datahook_selector: Optional[_DataHookSelector] = None

@property
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return n_epochs and self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs

@property
def _should_reload_val_dl(self) -> bool:
"""Check if validation dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return bool(n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs)

def on_trainer_init(
self,
val_check_interval: Optional[Union[int, float]],
Expand Down Expand Up @@ -83,7 +71,6 @@ def on_trainer_init(
)

self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def prepare_data(self) -> None:
trainer = self.trainer
Expand All @@ -107,7 +94,6 @@ def prepare_data(self) -> None:
lm_prepare_data_per_node = lightning_module.prepare_data_per_node
if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
call._call_lightning_module_hook(trainer, "prepare_data")
trainer._is_data_prepared = True

def attach_data(
self,
Expand Down
47 changes: 30 additions & 17 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,11 @@ def __init__(
)

self._detect_anomaly: bool = detect_anomaly
self._setup_on_init()

setup._log_device_info(self)

self.should_stop = False
self.state = TrainerState()

# configure profiler
setup._init_profiler(self, profiler)
Expand Down Expand Up @@ -378,22 +382,6 @@ def __init__(
num_sanity_val_steps,
)

def _setup_on_init(self) -> None:
setup._log_device_info(self)

self.should_stop = False
self.state = TrainerState()

# TODO(carmocca): move these to the loops
self.num_training_batches = float("inf")
self.num_sanity_val_batches: List[Union[int, float]] = []
self.num_test_batches: List[Union[int, float]] = []
self.num_val_batches: List[Union[int, float]] = []
self.num_predict_batches: List[Union[int, float]] = []

self._last_train_dl_reload_epoch = float("-inf")
self._last_val_dl_reload_epoch = float("-inf")

def fit(
self,
model: "pl.LightningModule",
Expand Down Expand Up @@ -1305,6 +1293,31 @@ def predict_dataloaders(self) -> EVAL_DATALOADERS:
if (combined_loader := self.predict_loop._combined_loader) is not None:
return combined_loader.iterables

@property
def num_training_batches(self) -> Union[int, float]:
return self.fit_loop.max_batches

@property
def num_sanity_val_batches(self) -> List[Union[int, float]]:
max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
return [min(self.num_sanity_val_steps, batches) for batches in max_batches]

@property
def num_val_batches(self) -> List[Union[int, float]]:
if self.state.fn == TrainerFn.VALIDATING:
return self.validate_loop.max_batches
# if no trainer.fn is set, assume fit's validation
# use the protected access, because it shouldn't return the sanity_val batches
return self.fit_loop.epoch_loop.val_loop._max_batches

@property
def num_test_batches(self) -> List[Union[int, float]]:
return self.test_loop.max_batches

@property
def num_predict_batches(self) -> List[Union[int, float]]:
return self.predict_loop.max_batches

@property
def _evaluation_loop(self) -> _EvaluationLoop:
if self.state.fn == TrainerFn.FITTING:
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,9 +1366,9 @@ def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs):

def test_train_epoch_end_ckpt_with_no_validation():
trainer = Trainer(val_check_interval=0.5)
trainer.num_val_batches = [0]
trainer.fit_loop.epoch_loop.val_loop._max_batches = [0]
assert trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
trainer.num_val_batches = [1]
trainer.fit_loop.epoch_loop.val_loop._max_batches = [1]
assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
trainer.val_check_interval = 0.8
assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/test_training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_should_stop_early_stopping_conditions_not_met(
"""Test that checks that info message is logged when users sets `should_stop` but min conditions are not
met."""
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
trainer.num_training_batches = 10
trainer.fit_loop.max_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = global_step
trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
Expand Down
Loading

0 comments on commit 565d611

Please sign in to comment.