diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
index 4b7c1cf3f3c53..6468e3959b527 100644
--- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
+++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -250,13 +250,12 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
             # There is no need to perform either backward or optimizer.step as we are
             # performing only one pass over the train data-loader to compute activation statistics
-            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
-            assert isinstance(trainer.num_training_batches, int)
-            trainer.num_training_batches += 1
+            # Therefore, we will virtually increase the number of training batches by 1 and skip backward.
+            trainer.fit_loop.max_batches += 1
             trainer.fit_loop._skip_backward = True
             self._accumulate_grad_batches = trainer.accumulate_grad_batches
-
-            trainer.accumulate_grad_batches = trainer.num_training_batches
+            assert isinstance(trainer.fit_loop.max_batches, int), "Iterable-style datasets are not supported"
+            trainer.accumulate_grad_batches = trainer.fit_loop.max_batches
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
         trainer.fit_loop._skip_backward = False
@@ -266,7 +265,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
-            trainer.num_training_batches -= 1
+            trainer.fit_loop.max_batches -= 1
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
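For illustration, a minimal sketch of the bookkeeping that `on_train_epoch_start` now performs for the extra BatchNorm-update pass. The stand-in classes below are hypothetical and only mirror the two attributes the callback touches; they are not the real Lightning objects.

    # Hypothetical stand-ins, only to trace the arithmetic of the hunk above.
    class FakeFitLoop:
        def __init__(self, max_batches: int) -> None:
            self.max_batches = max_batches
            self._skip_backward = False

    class FakeTrainer:
        def __init__(self, max_batches: int, accumulate_grad_batches: int) -> None:
            self.fit_loop = FakeFitLoop(max_batches)
            self.accumulate_grad_batches = accumulate_grad_batches

    trainer = FakeTrainer(max_batches=100, accumulate_grad_batches=4)

    # One extra "virtual" batch is added so a single pass over the train data can update the
    # BatchNorm statistics, and the optimizer is never stepped because the accumulation window
    # is made as large as the whole (extended) epoch.
    trainer.fit_loop.max_batches += 1
    trainer.fit_loop._skip_backward = True
    saved = trainer.accumulate_grad_batches
    trainer.accumulate_grad_batches = trainer.fit_loop.max_batches

    assert trainer.fit_loop.max_batches == 101
    assert trainer.accumulate_grad_batches == 101
    # on_train_epoch_end / on_train_end later restore `saved` and subtract the extra batch.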
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 630bff5e26331..06c3283e0ddb0 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -47,6 +47,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
         self.verbose = verbose
         self.inference_mode = inference_mode
         self.batch_progress = BatchProgress()  # across dataloaders
+        self._max_batches: List[Union[int, float]] = []
 
         self._results = _ResultCollection(training=False)
         self._logged_outputs: List[_OUT_DICT] = []
@@ -55,6 +56,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
         self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
+        self._last_val_dl_reload_epoch = float("-inf")
 
     @property
     def num_dataloaders(self) -> int:
@@ -66,19 +68,22 @@ def num_dataloaders(self) -> int:
     @property
     def max_batches(self) -> List[Union[int, float]]:
         """The max number of batches this loop will run for each dataloader."""
-        if self.trainer.testing:
-            return self.trainer.num_test_batches
-        elif self.trainer.sanity_checking:
-            return self.trainer.num_sanity_val_batches
-        elif self.trainer.validating:
-            return self.trainer.num_val_batches
-        raise RuntimeError(f"Unexpected stage: {self.trainer.state.stage}")
+        max_batches = self._max_batches
+        if self.trainer.sanity_checking:
+            return [min(self.trainer.num_sanity_val_steps, batches) for batches in max_batches]
+        return max_batches
 
     @property
     def skip(self) -> bool:
         """Returns whether the evaluation should be skipped."""
         return sum(self.max_batches) == 0
 
+    @property
+    def _should_reload_val_dl(self) -> bool:
+        """Check if validation dataloader should be reloaded."""
+        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
+        return bool(n_epochs and self.trainer.current_epoch - self._last_val_dl_reload_epoch >= n_epochs)
+
     @_no_grad_context
     def run(self) -> List[_OUT_DICT]:
         self.setup_data()
@@ -110,11 +115,7 @@ def run(self) -> List[_OUT_DICT]:
 
     def setup_data(self) -> None:
         trainer = self.trainer
-        if (
-            self._combined_loader is not None
-            and trainer.state.fn == "fit"
-            and not trainer._data_connector._should_reload_val_dl
-        ):
+        if self._combined_loader is not None and trainer.state.fn == "fit" and not self._should_reload_val_dl:
             return
 
         source = self._data_source
@@ -130,20 +131,11 @@ def setup_data(self) -> None:
             (trainer.sanity_checking and trainer.fit_loop.epoch_loop._should_check_val_epoch())
             or not trainer.sanity_checking
         ):
-            trainer._last_val_dl_reload_epoch = trainer.current_epoch
+            self._last_val_dl_reload_epoch = trainer.current_epoch
 
         stage = trainer.state.stage
         assert stage is not None
-        num_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)
-        if trainer.testing:
-            trainer.num_test_batches = num_batches
-        elif trainer.sanity_checking:
-            trainer.num_val_batches = num_batches
-            trainer.num_sanity_val_batches = [
-                min(trainer.num_sanity_val_steps, val_batches) for val_batches in num_batches
-            ]
-        else:
-            trainer.num_val_batches = num_batches
+        self._max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)
 
         if trainer.state.fn != "fit":  # if we are fitting, we need to do this in the loop
             for dl in combined_loader.flattened:
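The `max_batches` property above is what changes behaviour during the sanity check: the per-dataloader counts reported by `_reset_eval_dataloader` are capped by `num_sanity_val_steps` on the fly instead of being stored in a separate `num_sanity_val_batches` list on the Trainer. A small worked example with made-up batch counts:

    # Two validation dataloaders with 50 and 3 batches, and Trainer(num_sanity_val_steps=2).
    num_sanity_val_steps = 2
    max_batches = [50, 3]  # what `_reset_eval_dataloader` reported per dataloader

    sanity_checking = True
    effective = [min(num_sanity_val_steps, b) for b in max_batches] if sanity_checking else max_batches

    assert effective == [2, 2]  # the sanity check runs at most 2 batches per dataloader
    assert sum([0, 0]) == 0     # and `skip` is True only when every dataloader has 0 batches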
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 4f930e7df2a79..e0ce408821330 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import Optional, Union
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.data import _auto_add_worker_init_fn
@@ -79,10 +79,12 @@ def __init__(
         self.min_epochs = min_epochs
         self.epoch_loop = _TrainingEpochLoop(trainer)
         self.epoch_progress = Progress()
+        self.max_batches: Union[int, float] = float("inf")
 
         self._data_source = _DataLoaderSource(None, "train_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
+        self._last_train_dl_reload_epoch = float("-inf")
 
     @property
     def total_batch_idx(self) -> int:
@@ -136,10 +138,16 @@ def _can_stop_early(self) -> bool:
         met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
         return met_min_epochs and met_min_steps
 
+    @property
+    def _should_reload_train_dl(self) -> bool:
+        """Check if train dataloader should be reloaded."""
+        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
+        return n_epochs and self.trainer.current_epoch - self._last_train_dl_reload_epoch >= n_epochs
+
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
-        if self.trainer.num_training_batches == 0:
+        if self.max_batches == 0:
             rank_zero_info("`Trainer.fit` stopped: No training batches.")
             return True
 
@@ -168,8 +176,8 @@ def done(self) -> bool:
     @property
     def skip(self) -> bool:
         """Whether we should skip the training and immediately return from the call to :meth:`run`."""
-        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
-        # until `on_run_start`, we use `limit_train_batches` instead
+        # if `limit_train_batches == 0` then `setup_data` won't set the `self.max_batches` attribute (checked in `done`)
+        # so we cannot use it solely
         return self.done or self.trainer.limit_train_batches == 0
 
     def run(self) -> None:
@@ -190,11 +198,10 @@ def run(self) -> None:
         self.on_run_end()
 
     def setup_data(self, shuffle: bool = True) -> None:
-        trainer = self.trainer
-
-        if self._combined_loader is not None and not trainer._data_connector._should_reload_train_dl:
+        if self._combined_loader is not None and not self._should_reload_train_dl:
             return
 
+        trainer = self.trainer
         source = self._data_source
         pl_module = trainer.lightning_module
         if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
             return
@@ -227,7 +234,7 @@ def setup_data(self, shuffle: bool = True) -> None:
         self._combined_loader = combined_loader
 
         module = pl_module or trainer.datamodule
-        orig_train_batches = trainer.num_training_batches = (
+        orig_train_batches = self.max_batches = (
             len(self._combined_loader)
             if has_len_all_ranks(self._combined_loader, trainer.strategy, module)
             else float("inf")
@@ -236,12 +243,12 @@ def setup_data(self, shuffle: bool = True) -> None:
             return
 
         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-        trainer._last_train_dl_reload_epoch = trainer.current_epoch
+        self._last_train_dl_reload_epoch = trainer.current_epoch
 
         if isinstance(trainer.limit_train_batches, int):
-            trainer.num_training_batches = min(orig_train_batches, trainer.limit_train_batches)
-        elif trainer.num_training_batches != float("inf"):
-            trainer.num_training_batches = int(orig_train_batches * trainer.limit_train_batches)
+            self.max_batches = min(orig_train_batches, trainer.limit_train_batches)
+        elif self.max_batches != float("inf"):
+            self.max_batches = int(orig_train_batches * trainer.limit_train_batches)
         elif trainer.limit_train_batches != 1.0:
             raise MisconfigurationException(
                 "When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int."
@@ -250,10 +257,10 @@ def setup_data(self, shuffle: bool = True) -> None:
 
         if isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
-            if trainer.val_check_batch > trainer.num_training_batches and trainer.check_val_every_n_epoch is not None:
+            if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
                 raise ValueError(
                     f" `val_check_interval` ({trainer.val_check_interval}) must be less than or equal"
-                    f" to the number of the training batches ({trainer.num_training_batches})."
+                    f" to the number of the training batches ({self.max_batches})."
                     " If you want to disable validation set `limit_val_batches` to 0.0 instead."
                     " If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
                 )
@@ -268,19 +275,19 @@ def setup_data(self, shuffle: bool = True) -> None:
                     " checking validation every k training batches."
                 )
             else:
-                trainer.val_check_batch = int(trainer.num_training_batches * trainer.val_check_interval)
+                trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
                 trainer.val_check_batch = max(1, trainer.val_check_batch)
 
-        if trainer.loggers and trainer.num_training_batches < trainer.log_every_n_steps:
+        if trainer.loggers and self.max_batches < trainer.log_every_n_steps:
             rank_zero_warn(
-                f"The number of training batches ({trainer.num_training_batches}) is smaller than the logging interval"
+                f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
                 f" Trainer(log_every_n_steps={trainer.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                 " you want to see logs for the training epoch.",
                 category=PossibleUserWarning,
             )
 
         if (
-            trainer.num_training_batches == 0
+            self.max_batches == 0
             and trainer.limit_train_batches > 0.0
             and isinstance(trainer.limit_train_batches, float)
             and orig_train_batches != float("inf")
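A worked example of the `limit_train_batches` arithmetic in `setup_data` above, with illustrative numbers:

    orig_train_batches = 100  # len(combined_loader) when the dataset has a length

    # An int limit keeps at most that many batches:
    assert min(orig_train_batches, 10) == 10

    # A float limit keeps that fraction of the batches:
    assert int(orig_train_batches * 0.25) == 25

    # With an IterableDataset, orig_train_batches is float("inf"); only 1.0 or an int is
    # accepted, otherwise the MisconfigurationException in the hunk above is raised.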
diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index 4ea3000135b26..2a3c1ef3dc35c 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -31,6 +31,7 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None:
         self.epoch_batch_indices: List[List[List[int]]] = []
         self.current_batch_indices: List[int] = []  # used by PredictionWriter
         self.batch_progress = Progress()  # across dataloaders
+        self.max_batches: List[Union[int, float]] = []
 
         self._warning_cache = WarningCache()
         self._data_source = _DataLoaderSource(None, "predict_dataloader")
@@ -71,11 +72,6 @@ def num_dataloaders(self) -> int:
         assert combined_loader is not None
         return len(combined_loader.flattened)
 
-    @property
-    def max_batches(self) -> List[Union[int, float]]:
-        """The max number of batches this loop will run for each dataloader."""
-        return self.trainer.num_predict_batches
-
     @property
     def skip(self) -> bool:
         return sum(self.max_batches) == 0
@@ -109,7 +105,7 @@ def setup_data(self) -> None:
         if not source.is_defined() or trainer.limit_predict_batches == 0:
             return
 
-        trainer.num_predict_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
+        self.max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
             RunningStage.PREDICTING, model=pl_module
         )
         for dl in combined_loader.flattened:
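Like the evaluation loop, the prediction loop now owns a per-dataloader `max_batches` list, and its `skip` property is simply "no batches anywhere". Illustrative values:

    max_batches = [0, 0]          # e.g. both predict dataloaders limited to zero batches
    assert sum(max_batches) == 0  # `skip` is True, so run() returns immediately

    max_batches = [7, 0]
    assert sum(max_batches) != 0  # at least one dataloader has work, so the loop runs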
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 10f6eec96ae3f..80ea81b98a4fa 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -44,18 +44,6 @@ def __init__(self, trainer: "pl.Trainer"):
         self.trainer = trainer
         self._datahook_selector: Optional[_DataHookSelector] = None
 
-    @property
-    def _should_reload_train_dl(self) -> bool:
-        """Check if train dataloader should be reloaded."""
-        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
-        return n_epochs and self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs
-
-    @property
-    def _should_reload_val_dl(self) -> bool:
-        """Check if validation dataloader should be reloaded."""
-        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
-        return bool(n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs)
-
     def on_trainer_init(
         self,
         val_check_interval: Optional[Union[int, float]],
@@ -83,7 +71,6 @@ def on_trainer_init(
         )
 
         self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
-        self.trainer._is_data_prepared = False
 
     def prepare_data(self) -> None:
         trainer = self.trainer
@@ -107,7 +94,6 @@ def prepare_data(self) -> None:
             lm_prepare_data_per_node = lightning_module.prepare_data_per_node
             if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
                 call._call_lightning_module_hook(trainer, "prepare_data")
-                trainer._is_data_prepared = True
 
     def attach_data(
         self,
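The reload decision that these deleted properties used to compute now lives on the loops themselves (`_should_reload_train_dl` on the fit loop, `_should_reload_val_dl` on the evaluation loop). A sketch of the rule, assuming `Trainer(reload_dataloaders_every_n_epochs=2)`:

    reload_every_n_epochs = 2
    last_reload_epoch = float("-inf")  # the loop records this whenever it reloads

    def should_reload(current_epoch: int) -> bool:
        n = reload_every_n_epochs
        return bool(n and current_epoch - last_reload_epoch >= n)

    assert should_reload(0)       # never reloaded before -> reload
    last_reload_epoch = 0         # setup_data stores the epoch of the reload
    assert not should_reload(1)   # only one epoch has passed
    assert should_reload(2)       # two epochs have passed -> reload again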
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index eca416bbb8d72..8c2291a8fa705 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -349,7 +349,11 @@ def __init__(
         )
 
         self._detect_anomaly: bool = detect_anomaly
-        self._setup_on_init()
+
+        setup._log_device_info(self)
+
+        self.should_stop = False
+        self.state = TrainerState()
 
         # configure profiler
         setup._init_profiler(self, profiler)
@@ -378,22 +382,6 @@ def __init__(
             num_sanity_val_steps,
         )
 
-    def _setup_on_init(self) -> None:
-        setup._log_device_info(self)
-
-        self.should_stop = False
-        self.state = TrainerState()
-
-        # TODO(carmocca): move these to the loops
-        self.num_training_batches = float("inf")
-        self.num_sanity_val_batches: List[Union[int, float]] = []
-        self.num_test_batches: List[Union[int, float]] = []
-        self.num_val_batches: List[Union[int, float]] = []
-        self.num_predict_batches: List[Union[int, float]] = []
-
-        self._last_train_dl_reload_epoch = float("-inf")
-        self._last_val_dl_reload_epoch = float("-inf")
-
     def fit(
         self,
         model: "pl.LightningModule",
@@ -1305,6 +1293,31 @@ def predict_dataloaders(self) -> EVAL_DATALOADERS:
         if (combined_loader := self.predict_loop._combined_loader) is not None:
             return combined_loader.iterables
 
+    @property
+    def num_training_batches(self) -> Union[int, float]:
+        return self.fit_loop.max_batches
+
+    @property
+    def num_sanity_val_batches(self) -> List[Union[int, float]]:
+        max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
+        return [min(self.num_sanity_val_steps, batches) for batches in max_batches]
+
+    @property
+    def num_val_batches(self) -> List[Union[int, float]]:
+        if self.state.fn == TrainerFn.VALIDATING:
+            return self.validate_loop.max_batches
+        # if no trainer.fn is set, assume fit's validation
+        # use the protected access, because it shouldn't return the sanity_val batches
+        return self.fit_loop.epoch_loop.val_loop._max_batches
+
+    @property
+    def num_test_batches(self) -> List[Union[int, float]]:
+        return self.test_loop.max_batches
+
+    @property
+    def num_predict_batches(self) -> List[Union[int, float]]:
+        return self.predict_loop.max_batches
+
     @property
     def _evaluation_loop(self) -> _EvaluationLoop:
         if self.state.fn == TrainerFn.FITTING:
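With these properties, `Trainer.num_training_batches` and friends become read-only views over loop state, so code that previously assigned to them has to set the loop attribute instead, which is exactly what the test updates below do. A hedged sketch (the Trainer arguments are illustrative):

    from lightning.pytorch import Trainer

    trainer = Trainer(limit_train_batches=10)

    # Reads are forwarded to the loops:
    assert trainer.num_training_batches == trainer.fit_loop.max_batches
    assert trainer.num_predict_batches == trainer.predict_loop.max_batches

    # Writes now have to target the loop state; assigning to the property itself would
    # raise AttributeError because no setter is defined.
    trainer.fit_loop.max_batches = 10
    trainer.fit_loop.epoch_loop.val_loop._max_batches = [1]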
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 9e82120f05d68..225bfaf1d9dcf 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1366,9 +1366,9 @@ def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs):
 
 def test_train_epoch_end_ckpt_with_no_validation():
     trainer = Trainer(val_check_interval=0.5)
-    trainer.num_val_batches = [0]
+    trainer.fit_loop.epoch_loop.val_loop._max_batches = [0]
     assert trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
-    trainer.num_val_batches = [1]
+    trainer.fit_loop.epoch_loop.val_loop._max_batches = [1]
     assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
     trainer.val_check_interval = 0.8
     assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py
index 8de5e45350546..2d7929e11de1c 100644
--- a/tests/tests_pytorch/loops/test_training_epoch_loop.py
+++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py
@@ -66,7 +66,7 @@ def test_should_stop_early_stopping_conditions_not_met(
     """Test that checks that info message is logged when users sets `should_stop` but min conditions are not met."""
     trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
-    trainer.num_training_batches = 10
+    trainer.fit_loop.max_batches = 10
     trainer.should_stop = True
     trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = global_step
     trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py
index 1d73e9f49faa2..6e484438ed839 100644
--- a/tests/tests_pytorch/loops/test_training_loop.py
+++ b/tests/tests_pytorch/loops/test_training_loop.py
@@ -139,15 +139,15 @@ def test_fit_loop_done_log_messages(caplog):
     fit_loop = _FitLoop(trainer, max_epochs=1)
     trainer.should_stop = False
-    trainer.num_training_batches = 5
+    fit_loop.max_batches = 5
     assert not fit_loop.done
     assert not caplog.messages
 
-    trainer.num_training_batches = 0
+    fit_loop.max_batches = 0
     assert fit_loop.done
     assert "No training batches" in caplog.text
     caplog.clear()
 
-    trainer.num_training_batches = 5
+    fit_loop.max_batches = 5
     epoch_loop = Mock()
     epoch_loop.global_step = 10
@@ -191,7 +191,7 @@ def test_should_stop_early_stopping_conditions_met(
 ):
     """Test that checks that debug message is logged when users sets `should_stop` and min conditions are met."""
     trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
-    trainer.num_training_batches = 10
+    trainer.fit_loop.max_batches = 10
     trainer.should_stop = True
     trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
         current_epoch * trainer.num_training_batches
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index ed138b8b94fbe..553fbd49259c4 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -1114,9 +1114,7 @@ def val_dataloader(self):
         wraps=trainer.fit_loop.epoch_loop.val_loop._evaluation_step,
     ) as mocked:
         trainer.fit(model)
-        assert mocked.call_count == sum(
-            min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
-        )
+        assert mocked.call_count == sum(trainer.num_sanity_val_batches)
 
 
 @pytest.mark.parametrize("limit_val_batches", [0.0, 1, 1.0, 0.3])