diff --git a/src/lightning/pytorch/callbacks/batch_size_finder.py b/src/lightning/pytorch/callbacks/batch_size_finder.py
index e348f4946e51b..a03251ec995b8 100644
--- a/src/lightning/pytorch/callbacks/batch_size_finder.py
+++ b/src/lightning/pytorch/callbacks/batch_size_finder.py
@@ -63,6 +63,10 @@ class BatchSizeFinder(Callback):
             - ``model.hparams``
             - ``trainer.datamodule`` (the datamodule passed to the tune method)
 
+        margin: Fraction by which to reduce the found batch size to provide a safety buffer. Only applied in
+            'binsearch' mode. Must be a float between 0 and 1. Defaults to 0.05 (a 5% reduction).
+        max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
+
     Example::
 
         # 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
@@ -118,6 +122,8 @@ def __init__(
         init_val: int = 2,
         max_trials: int = 25,
         batch_arg_name: str = "batch_size",
+        margin: float = 0.05,
+        max_val: Optional[int] = None,
     ) -> None:
         mode = mode.lower()
         if mode not in self.SUPPORTED_MODES:
@@ -129,6 +135,8 @@
         self._init_val = init_val
         self._max_trials = max_trials
         self._batch_arg_name = batch_arg_name
+        self._margin = margin
+        self._max_val = max_val
         self._early_exit = False
 
     @override
@@ -180,6 +188,8 @@ def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             self._init_val,
             self._max_trials,
             self._batch_arg_name,
+            self._margin,
+            self._max_val,
         )
 
         self.optimal_batch_size = new_size

diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index 78d2aa52f5725..510e07e674ec5 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -32,6 +32,8 @@ def _scale_batch_size(
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = "batch_size",
+    margin: float = 0.05,
+    max_val: Optional[int] = None,
 ) -> Optional[int]:
     """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
     error.
@@ -58,6 +60,10 @@
         - ``model.hparams``
        - ``trainer.datamodule`` (the datamodule passed to the tune method)
 
+        margin: Fraction by which to reduce the found batch size to provide a safety buffer. Only applied in
+            'binsearch' mode. Must be a float between 0 and 1. Defaults to 0.05 (a 5% reduction).
+        max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
+ """ if trainer.fast_dev_run: rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.") @@ -80,9 +86,9 @@ def _scale_batch_size( new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) if mode == "power": - new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params) + new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params, max_val) elif mode == "binsearch": - new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params) + new_size = _run_binsearch_scaling(trainer, new_size, batch_arg_name, max_trials, params, margin, max_val) garbage_collection_cuda() @@ -173,6 +179,7 @@ def _run_power_scaling( batch_arg_name: str, max_trials: int, params: dict[str, Any], + max_val: Optional[int], ) -> int: """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not @@ -186,7 +193,9 @@ def _run_power_scaling( try: _try_loop_run(trainer, params) - new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded") + new_size, changed = _adjust_batch_size( + trainer, batch_arg_name, factor=2.0, desc="succeeded", max_val=max_val + ) if not changed: break @@ -209,12 +218,14 @@ def _run_power_scaling( return new_size -def _run_binary_scaling( +def _run_binsearch_scaling( trainer: "pl.Trainer", new_size: int, batch_arg_name: str, max_trials: int, params: dict[str, Any], + margin: float, + max_val: Optional[int], ) -> int: """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. @@ -242,9 +253,13 @@ def _run_binary_scaling( if high - low <= 1: break midval = (high + low) // 2 - new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded") + new_size, changed = _adjust_batch_size( + trainer, batch_arg_name, value=midval, desc="succeeded", max_val=max_val + ) else: - new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded") + new_size, changed = _adjust_batch_size( + trainer, batch_arg_name, factor=2.0, desc="succeeded", max_val=max_val + ) if not changed: break @@ -270,6 +285,15 @@ def _run_binary_scaling( else: raise # some other error not memory related + # Apply margin reduction for binsearch mode + if margin > 0: + margin_reduced_size = max(1, int(new_size * (1 - margin))) + if margin_reduced_size != new_size: + rank_zero_info( + f"Applying margin of {margin:.1%}, reducing batch size from {new_size} to {margin_reduced_size}" + ) + new_size = margin_reduced_size + return new_size @@ -279,6 +303,7 @@ def _adjust_batch_size( factor: float = 1.0, value: Optional[int] = None, desc: Optional[str] = None, + max_val: Optional[int] = None, ) -> tuple[int, bool]: """Helper function for adjusting the batch size. @@ -289,6 +314,7 @@ def _adjust_batch_size( value: if a value is given, will override the batch size with this value. Note that the value of `factor` will not have an effect in this case desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging + max_val: Maximum batch size limit. If provided, the new batch size will not exceed this value. 
 
     Returns:
         The new batch size for the next trial and a bool that signals whether the
@@ -314,6 +340,12 @@
         pass
 
     new_size = value if value is not None else int(batch_size * factor)
+
+    # Apply max_val limit if provided
+    if max_val is not None and new_size > max_val:
+        if desc:
+            rank_zero_info(f"Batch size {new_size} exceeds max_val limit {max_val}, capping at {max_val}")
+        new_size = max_val
     if desc:
         rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
     changed = new_size != batch_size

diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py
index 8b9b423619bd2..1f0abfbe3a0ef 100644
--- a/src/lightning/pytorch/tuner/tuning.py
+++ b/src/lightning/pytorch/tuner/tuning.py
@@ -41,6 +41,8 @@ def scale_batch_size(
         init_val: int = 2,
         max_trials: int = 25,
         batch_arg_name: str = "batch_size",
+        margin: float = 0.05,
+        max_val: Optional[int] = None,
     ) -> Optional[int]:
         """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
         error.
@@ -75,6 +77,10 @@
             - ``model.hparams``
             - ``trainer.datamodule`` (the datamodule passed to the tune method)
 
+            margin: Fraction by which to reduce the found batch size to provide a safety buffer. Only applied in
+                'binsearch' mode. Must be a float between 0 and 1. Defaults to 0.05 (a 5% reduction).
+            max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
+
         """
         _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
         _check_scale_batch_size_configuration(self._trainer)
@@ -88,6 +94,8 @@
             init_val=init_val,
             max_trials=max_trials,
             batch_arg_name=batch_arg_name,
+            margin=margin,
+            max_val=max_val,
         )
         # do not continue with the loop in case Tuner is used
         batch_size_finder._early_exit = True

diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index f0e5fbe6a3c49..a6ea856839c7e 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -489,6 +489,51 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
     assert trainer.num_val_batches[0] != steps_per_trial
 
 
+@pytest.mark.parametrize("margin", [0.0, 0.1, 0.2])
+def test_scale_batch_size_margin(tmp_path, margin):
+    """Test the margin feature for batch size scaling by comparing results found with and without a margin."""
+    # First, find the batch size without a margin
+    model1 = BatchSizeModel(batch_size=2)
+    trainer1 = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=False, enable_checkpointing=False)
+    tuner1 = Tuner(trainer1)
+
+    result_without_margin = tuner1.scale_batch_size(
+        model1, mode="binsearch", max_trials=2, steps_per_trial=1, margin=0.0
+    )
+
+    model2 = BatchSizeModel(batch_size=2)
+    trainer2 = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=False, enable_checkpointing=False)
+    tuner2 = Tuner(trainer2)
+
+    result_with_margin = tuner2.scale_batch_size(
+        model2, mode="binsearch", max_trials=2, steps_per_trial=1, margin=margin
+    )
+
+    assert result_without_margin is not None
+    assert result_with_margin is not None
+
+    if margin == 0.0:
+        assert result_with_margin == result_without_margin
+    else:
+        expected_with_margin = max(1, int(result_without_margin * (1 - margin)))
+        assert result_with_margin == expected_with_margin
+        assert result_with_margin <= result_without_margin
+
+
+@pytest.mark.parametrize("mode", ["power", "binsearch"])
+def test_scale_batch_size_max_val_limit(tmp_path, mode):
+    """Test that max_val limits the batch size for both power and binsearch modes."""
+    model = BatchSizeModel(batch_size=2)
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    tuner = Tuner(trainer)
+
+    max_val = 8  # set a low max value
+    result = tuner.scale_batch_size(model, mode=mode, max_trials=5, steps_per_trial=1, max_val=max_val)
+
+    assert result is not None
+    assert result <= max_val
+
+
 def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
     """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""
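For reviewers who want to exercise the new options end to end, here is a minimal usage sketch (not part of the diff). `LitModel` is a placeholder for any `LightningModule` whose dataloaders read a `batch_size` attribute, like the `BatchSizeModel` used in the tests; `margin` and `max_val` are the arguments introduced above:

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import BatchSizeFinder
from lightning.pytorch.tuner import Tuner

# `LitModel` is a hypothetical LightningModule whose dataloaders read a
# `batch_size` attribute (the tests use `BatchSizeModel` for this purpose).
model = LitModel(batch_size=2)

# One-off tuning through the Tuner: binary-search the largest batch size
# that fits, back off 10% as a safety buffer, and never exceed 256.
trainer = pl.Trainer(max_epochs=1)
tuner = Tuner(trainer)
optimal = tuner.scale_batch_size(
    model,
    mode="binsearch",
    margin=0.1,   # safety buffer; only applied in 'binsearch' mode
    max_val=256,  # hard cap, enforced in both 'power' and 'binsearch' modes
)
print(f"Tuned batch size: {optimal}")

# Alternatively, tune as part of fitting via the callback.
trainer = pl.Trainer(
    max_epochs=4,
    callbacks=[BatchSizeFinder(mode="binsearch", margin=0.1, max_val=256)],
)
trainer.fit(model)
```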