Move max_batches definition to the Loops #16820

Merged (9 commits) on Feb 22, 2023
Changes from 4 commits
11 changes: 5 additions & 6 deletions src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -250,13 +250,12 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

# There is no need to perform either backward or optimizer.step as we are
# performing only one pass over the train data-loader to compute activation statistics
# Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
assert isinstance(trainer.num_training_batches, int)
trainer.num_training_batches += 1
# Therefore, we will virtually increase the number of training batches by 1 and skip backward.
trainer.fit_loop.max_batches += 1
trainer.fit_loop._skip_backward = True
self._accumulate_grad_batches = trainer.accumulate_grad_batches

trainer.accumulate_grad_batches = trainer.num_training_batches
assert isinstance(trainer.fit_loop.max_batches, int), "Iterable-style datasets are not supported"
trainer.accumulate_grad_batches = trainer.fit_loop.max_batches

def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
trainer.fit_loop._skip_backward = False
@@ -266,7 +265,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.num_training_batches -= 1
trainer.fit_loop.max_batches -= 1
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs -= 1
self.reset_momenta()
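For context, a minimal sketch of how a callback reads the per-epoch batch budget after this change. The callback below is hypothetical and not part of the PR; only `trainer.fit_loop.max_batches` is taken from the diff above.

```python
import lightning.pytorch as pl


class BatchBudgetLogger(pl.Callback):
    """Hypothetical callback: reports the training-batch budget per epoch."""

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # After this PR, the source of truth lives on the fit loop ...
        max_batches = trainer.fit_loop.max_batches
        # ... and can be `float("inf")` for iterable-style datasets without `__len__`.
        if max_batches != float("inf"):
            print(f"epoch {trainer.current_epoch}: up to {max_batches} training batches")
```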
38 changes: 11 additions & 27 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -47,6 +47,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
self.verbose = verbose
self.inference_mode = inference_mode
self.batch_progress = BatchProgress() # across dataloaders
self.max_batches: List[Union[int, float]] = []

self._results = _ResultCollection(training=False)
self._logged_outputs: List[_OUT_DICT] = []
@@ -55,6 +56,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
self._combined_loader: Optional[CombinedLoader] = None
self._data_fetcher: Optional[_DataFetcher] = None
self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
self._last_val_dl_reload_epoch = float("-inf")

@property
def num_dataloaders(self) -> int:
@@ -63,22 +65,17 @@ def num_dataloaders(self) -> int:
assert combined_loader is not None
return len(combined_loader.flattened)

@property
def max_batches(self) -> List[Union[int, float]]:
"""The max number of batches this loop will run for each dataloader."""
if self.trainer.testing:
return self.trainer.num_test_batches
elif self.trainer.sanity_checking:
return self.trainer.num_sanity_val_batches
elif self.trainer.validating:
return self.trainer.num_val_batches
raise RuntimeError(f"Unexpected stage: {self.trainer.state.stage}")

@property
def skip(self) -> bool:
"""Returns whether the evaluation should be skipped."""
return sum(self.max_batches) == 0

@property
def _should_reload_val_dl(self) -> bool:
"""Check if validation dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return bool(n_epochs and self.trainer.current_epoch - self._last_val_dl_reload_epoch >= n_epochs)

@_no_grad_context
def run(self) -> List[_OUT_DICT]:
self.setup_data()
@@ -110,11 +107,7 @@ def run(self) -> List[_OUT_DICT]:
def setup_data(self) -> None:
trainer = self.trainer

if (
self._combined_loader is not None
and trainer.state.fn == "fit"
and not trainer._data_connector._should_reload_val_dl
):
if self._combined_loader is not None and trainer.state.fn == "fit" and not self._should_reload_val_dl:
return

source = self._data_source
@@ -130,20 +123,11 @@ def setup_data(self) -> None:
(trainer.sanity_checking and trainer.fit_loop.epoch_loop._should_check_val_epoch())
or not trainer.sanity_checking
):
trainer._last_val_dl_reload_epoch = trainer.current_epoch
self._last_val_dl_reload_epoch = trainer.current_epoch

stage = trainer.state.stage
assert stage is not None
num_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)
if trainer.testing:
trainer.num_test_batches = num_batches
elif trainer.sanity_checking:
trainer.num_val_batches = num_batches
trainer.num_sanity_val_batches = [
min(trainer.num_sanity_val_steps, val_batches) for val_batches in num_batches
]
else:
trainer.num_val_batches = num_batches
self.max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)

if trainer.state.fn != "fit": # if we are fitting, we need to do this in the loop
for dl in combined_loader.flattened:
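With `max_batches` now a plain attribute populated by `setup_data`, the `skip` property reduces to summing the per-dataloader entries. A self-contained illustration of that invariant (the helper function is illustrative only, not code from the PR):

```python
from typing import List, Union


def should_skip_evaluation(max_batches: List[Union[int, float]]) -> bool:
    # Mirrors `_EvaluationLoop.skip`: one entry per dataloader, where an entry
    # may be an int or `float("inf")` for unsized iterables.
    return sum(max_batches) == 0


assert should_skip_evaluation([0, 0]) is True
assert should_skip_evaluation([5, float("inf")]) is False
```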
43 changes: 25 additions & 18 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import Optional, Union

import lightning.pytorch as pl
from lightning.fabric.utilities.data import _auto_add_worker_init_fn
@@ -79,10 +79,12 @@ def __init__(
self.min_epochs = min_epochs
self.epoch_loop = _TrainingEpochLoop(trainer)
self.epoch_progress = Progress()
self.max_batches: Union[int, float] = float("inf")

self._data_source = _DataLoaderSource(None, "train_dataloader")
self._combined_loader: Optional[CombinedLoader] = None
self._data_fetcher: Optional[_DataFetcher] = None
self._last_train_dl_reload_epoch = float("-inf")

@property
def total_batch_idx(self) -> int:
@@ -136,10 +138,16 @@ def _can_stop_early(self) -> bool:
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
return met_min_epochs and met_min_steps

@property
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return bool(n_epochs and self.trainer.current_epoch - self._last_train_dl_reload_epoch >= n_epochs)

@property
def done(self) -> bool:
"""Evaluates when to leave the loop."""
if self.trainer.num_training_batches == 0:
if self.max_batches == 0:
rank_zero_info("`Trainer.fit` stopped: No training batches.")
return True

@@ -168,8 +176,8 @@ def done(self) -> bool:
@property
def skip(self) -> bool:
"""Whether we should skip the training and immediately return from the call to :meth:`run`."""
# since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
# until `on_run_start`, we use `limit_train_batches` instead
# if `limit_train_batches == 0`, `setup_data` returns early and never sets `self.max_batches`,
# so the `max_batches == 0` check in `done` is not sufficient on its own
return self.done or self.trainer.limit_train_batches == 0

def run(self) -> None:
@@ -190,11 +198,10 @@ def run(self) -> None:
self.on_run_end()

def setup_data(self, shuffle: bool = True) -> None:
trainer = self.trainer

if self._combined_loader is not None and not trainer._data_connector._should_reload_train_dl:
if self._combined_loader is not None and not self._should_reload_train_dl:
return

trainer = self.trainer
source = self._data_source
pl_module = trainer.lightning_module
if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
@@ -227,7 +234,7 @@ def setup_data(self, shuffle: bool = True) -> None:
self._combined_loader = combined_loader

module = pl_module or trainer.datamodule
orig_train_batches = trainer.num_training_batches = (
orig_train_batches = self.max_batches = (
len(self._combined_loader)
if has_len_all_ranks(self._combined_loader, trainer.strategy, module)
else float("inf")
@@ -236,12 +243,12 @@ def setup_data(self, shuffle: bool = True) -> None:
return

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
trainer._last_train_dl_reload_epoch = trainer.current_epoch
self._last_train_dl_reload_epoch = trainer.current_epoch

if isinstance(trainer.limit_train_batches, int):
trainer.num_training_batches = min(orig_train_batches, trainer.limit_train_batches)
elif trainer.num_training_batches != float("inf"):
trainer.num_training_batches = int(orig_train_batches * trainer.limit_train_batches)
self.max_batches = min(orig_train_batches, trainer.limit_train_batches)
elif self.max_batches != float("inf"):
self.max_batches = int(orig_train_batches * trainer.limit_train_batches)
elif trainer.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int."
@@ -250,10 +257,10 @@ def setup_data(self, shuffle: bool = True) -> None:

if isinstance(trainer.val_check_interval, int):
trainer.val_check_batch = trainer.val_check_interval
if trainer.val_check_batch > trainer.num_training_batches and trainer.check_val_every_n_epoch is not None:
if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
raise ValueError(
f" `val_check_interval` ({trainer.val_check_interval}) must be less than or equal"
f" to the number of the training batches ({trainer.num_training_batches})."
f" to the number of the training batches ({self.max_batches})."
" If you want to disable validation set `limit_val_batches` to 0.0 instead."
" If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
)
@@ -268,19 +275,19 @@ def setup_data(self, shuffle: bool = True) -> None:
" checking validation every k training batches."
)
else:
trainer.val_check_batch = int(trainer.num_training_batches * trainer.val_check_interval)
trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
trainer.val_check_batch = max(1, trainer.val_check_batch)

if trainer.loggers and trainer.num_training_batches < trainer.log_every_n_steps:
if trainer.loggers and self.max_batches < trainer.log_every_n_steps:
rank_zero_warn(
f"The number of training batches ({trainer.num_training_batches}) is smaller than the logging interval"
f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
f" Trainer(log_every_n_steps={trainer.log_every_n_steps}). Set a lower value for log_every_n_steps if"
" you want to see logs for the training epoch.",
category=PossibleUserWarning,
)

if (
trainer.num_training_batches == 0
self.max_batches == 0
and trainer.limit_train_batches > 0.0
and isinstance(trainer.limit_train_batches, float)
and orig_train_batches != float("inf")
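The `limit_train_batches` handling above boils down to a small piece of arithmetic over the dataloader length. A hedged sketch of that logic as a standalone function (names and error message are illustrative, not the loop's actual code):

```python
from typing import Union


def resolve_max_batches(
    dataloader_len: Union[int, float], limit_train_batches: Union[int, float]
) -> Union[int, float]:
    if isinstance(limit_train_batches, int):
        # an int limit is a hard cap on the number of batches
        return min(dataloader_len, limit_train_batches)
    if dataloader_len != float("inf"):
        # a float limit on a sized dataloader takes a fraction of its batches
        return int(dataloader_len * limit_train_batches)
    if limit_train_batches != 1.0:
        # unsized (iterable-style) dataloaders only support 1.0 or an int limit
        raise ValueError("With an `IterableDataset`, `limit_train_batches` must be 1.0 or an int.")
    return dataloader_len


assert resolve_max_batches(100, 10) == 10
assert resolve_max_batches(100, 0.25) == 25
assert resolve_max_batches(float("inf"), 1.0) == float("inf")
```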
8 changes: 2 additions & 6 deletions src/lightning/pytorch/loops/prediction_loop.py
@@ -31,6 +31,7 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None:
self.epoch_batch_indices: List[List[List[int]]] = []
self.current_batch_indices: List[int] = [] # used by PredictionWriter
self.batch_progress = Progress() # across dataloaders
self.max_batches: List[Union[int, float]] = []

self._warning_cache = WarningCache()
self._data_source = _DataLoaderSource(None, "predict_dataloader")
@@ -71,11 +72,6 @@ def num_dataloaders(self) -> int:
assert combined_loader is not None
return len(combined_loader.flattened)

@property
def max_batches(self) -> List[Union[int, float]]:
"""The max number of batches this loop will run for each dataloader."""
return self.trainer.num_predict_batches

@property
def skip(self) -> bool:
return sum(self.max_batches) == 0
@@ -109,7 +105,7 @@ def setup_data(self) -> None:
if not source.is_defined() or trainer.limit_predict_batches == 0:
return

trainer.num_predict_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
self.max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)
for dl in combined_loader.flattened:
17 changes: 3 additions & 14 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -44,18 +44,6 @@ def __init__(self, trainer: "pl.Trainer"):
self.trainer = trainer
self._datahook_selector: Optional[_DataHookSelector] = None

@property
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return n_epochs and self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs

@property
def _should_reload_val_dl(self) -> bool:
"""Check if validation dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return bool(n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs)

def on_trainer_init(
self,
val_check_interval: Optional[Union[int, float]],
@@ -83,7 +71,6 @@ def on_trainer_init(
)

self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def prepare_data(self) -> None:
trainer = self.trainer
@@ -107,7 +94,6 @@ def prepare_data(self) -> None:
lm_prepare_data_per_node = lightning_module.prepare_data_per_node
if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
call._call_lightning_module_hook(trainer, "prepare_data")
trainer._is_data_prepared = True

def attach_data(
self,
@@ -388,6 +374,9 @@ def _reset_eval_dataloader(
f" `limit_{mode.dataloader_prefix}_batches={min_percentage}`"
)

if mode == RunningStage.SANITY_CHECKING:
num_batches = min(trainer.num_sanity_val_steps, num_batches)

loader_num_batches.append(num_batches)
combined_loader.flattened = dataloaders

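The new `SANITY_CHECKING` branch clamps each dataloader's batch count to `num_sanity_val_steps` when the evaluation dataloaders are reset. A minimal illustration of the clamping (the helper is illustrative, not the connector's actual API):

```python
from typing import List, Union


def clamp_for_sanity_check(
    per_dataloader_batches: List[Union[int, float]], num_sanity_val_steps: int
) -> List[Union[int, float]]:
    # Each validation dataloader runs at most `num_sanity_val_steps` batches
    # during the sanity check, even if it could provide more (or is unsized).
    return [min(num_sanity_val_steps, n) for n in per_dataloader_batches]


assert clamp_for_sanity_check([100, float("inf")], 2) == [2, 2]
```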
49 changes: 32 additions & 17 deletions src/lightning/pytorch/trainer/trainer.py
@@ -349,7 +349,11 @@ def __init__(
)

self._detect_anomaly: bool = detect_anomaly
self._setup_on_init()

setup._log_device_info(self)

self.should_stop = False
self.state = TrainerState()

# configure profiler
setup._init_profiler(self, profiler)
@@ -378,22 +382,6 @@ def __init__(
num_sanity_val_steps,
)

def _setup_on_init(self) -> None:
setup._log_device_info(self)

self.should_stop = False
self.state = TrainerState()

# TODO(carmocca): move these to the loops
self.num_training_batches = float("inf")
self.num_sanity_val_batches: List[Union[int, float]] = []
self.num_test_batches: List[Union[int, float]] = []
self.num_val_batches: List[Union[int, float]] = []
self.num_predict_batches: List[Union[int, float]] = []

self._last_train_dl_reload_epoch = float("-inf")
self._last_val_dl_reload_epoch = float("-inf")

def fit(
self,
model: "pl.LightningModule",
@@ -939,6 +927,9 @@ def _run_sanity_check(self) -> None:
# because sanity check only runs when we are not restarting
_reset_progress(val_loop)

# reset the loaded data
val_loop._combined_loader = None

# restore the previous stage when the sanity check if finished
self.state.stage = stage

@@ -1305,6 +1296,30 @@ def predict_dataloaders(self) -> EVAL_DATALOADERS:
if (combined_loader := self.predict_loop._combined_loader) is not None:
return combined_loader.iterables

@property
def num_training_batches(self) -> Union[int, float]:
return self.fit_loop.max_batches

@property
def num_sanity_val_batches(self) -> List[Union[int, float]]:
max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
return [min(self.num_sanity_val_steps, batches) for batches in max_batches]

@property
def num_val_batches(self) -> List[Union[int, float]]:
if self.state.fn == TrainerFn.VALIDATING:
return self.validate_loop.max_batches
# if no trainer.fn is set, assume fit's validation
return self.fit_loop.epoch_loop.val_loop.max_batches

@property
def num_test_batches(self) -> List[Union[int, float]]:
return self.test_loop.max_batches

@property
def num_predict_batches(self) -> List[Union[int, float]]:
return self.predict_loop.max_batches

@property
def _evaluation_loop(self) -> _EvaluationLoop:
if self.state.fn == TrainerFn.FITTING:
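The old `trainer.num_*_batches` attributes survive as read-only properties that delegate to the loops, so user code that only reads them keeps working; code that used to assign to them must now write to the loop attribute instead. A hedged usage sketch (the helper function is hypothetical):

```python
import lightning.pytorch as pl


def log_batch_budgets(trainer: pl.Trainer) -> None:
    # Reads still work through the Trainer and are now derived from the loops:
    print(trainer.num_training_batches)  # == trainer.fit_loop.max_batches
    print(trainer.num_test_batches)      # == trainer.test_loop.max_batches
    print(trainer.num_predict_batches)   # == trainer.predict_loop.max_batches
    # Writes must target the loop, e.g. as the SWA callback above does:
    # trainer.fit_loop.max_batches += 1
```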
4 changes: 1 addition & 3 deletions tests/tests_pytorch/trainer/test_trainer.py
@@ -1114,9 +1114,7 @@ def val_dataloader(self):
wraps=trainer.fit_loop.epoch_loop.val_loop._evaluation_step,
) as mocked:
trainer.fit(model)
assert mocked.call_count == sum(
min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
)
assert mocked.call_count == sum(trainer.num_sanity_val_batches)


@pytest.mark.parametrize("limit_val_batches", [0.0, 1, 1.0, 0.3])