10 changes: 10 additions & 0 deletions src/lightning/pytorch/callbacks/batch_size_finder.py
@@ -63,6 +63,10 @@ class BatchSizeFinder(Callback):
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)

margin: Margin by which to reduce the found batch size, providing a safety buffer. Only applied when using
'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (a 5% reduction).
max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.

Example::

# 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
@@ -118,6 +122,8 @@ def __init__(
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = "batch_size",
margin: float = 0.05,
max_val: Optional[int] = None,
) -> None:
mode = mode.lower()
if mode not in self.SUPPORTED_MODES:
@@ -129,6 +135,8 @@ def __init__(
self._init_val = init_val
self._max_trials = max_trials
self._batch_arg_name = batch_arg_name
self._margin = margin
self._max_val = max_val
self._early_exit = False

@override
@@ -180,6 +188,8 @@ def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule
self._init_val,
self._max_trials,
self._batch_arg_name,
self._margin,
self._max_val,
)

self.optimal_batch_size = new_size
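Reviewer note: a minimal usage sketch of the two new arguments (illustrative only; `MyModel` is a stand-in for any `LightningModule` that exposes the `batch_size` attribute the finder mutates):

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import BatchSizeFinder

# `MyModel` is a hypothetical LightningModule exposing `batch_size`.
# Search with binsearch, never exceed 256, and keep the default 5%
# safety buffer (margin is only applied in "binsearch" mode).
finder = BatchSizeFinder(mode="binsearch", margin=0.05, max_val=256)
trainer = pl.Trainer(callbacks=[finder], max_epochs=1)
trainer.fit(MyModel(batch_size=2))
```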
44 changes: 38 additions & 6 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -32,6 +32,8 @@ def _scale_batch_size(
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = "batch_size",
margin: float = 0.05,
max_val: Optional[int] = None,
) -> Optional[int]:
"""Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
error.
@@ -58,6 +60,10 @@
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)

margin: Margin by which to reduce the found batch size, providing a safety buffer. Only applied when using
'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (a 5% reduction).
max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.

"""
if trainer.fast_dev_run:
rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
@@ -80,9 +86,9 @@
new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)

if mode == "power":
- new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+ new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params, max_val)
elif mode == "binsearch":
- new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+ new_size = _run_binsearch_scaling(trainer, new_size, batch_arg_name, max_trials, params, margin, max_val)

garbage_collection_cuda()

@@ -173,6 +179,7 @@ def _run_power_scaling(
batch_arg_name: str,
max_trials: int,
params: dict[str, Any],
max_val: Optional[int],
) -> int:
"""Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
# this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -186,7 +193,9 @@

try:
_try_loop_run(trainer, params)
- new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+ new_size, changed = _adjust_batch_size(
+     trainer, batch_arg_name, factor=2.0, desc="succeeded", max_val=max_val
+ )

if not changed:
break
@@ -209,12 +218,14 @@
return new_size


- def _run_binary_scaling(
+ def _run_binsearch_scaling(
trainer: "pl.Trainer",
new_size: int,
batch_arg_name: str,
max_trials: int,
params: dict[str, Any],
margin: float,
max_val: Optional[int],
) -> int:
"""Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered.

@@ -242,9 +253,13 @@ def _run_binary_scaling(
if high - low <= 1:
break
midval = (high + low) // 2
- new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded")
+ new_size, changed = _adjust_batch_size(
+     trainer, batch_arg_name, value=midval, desc="succeeded", max_val=max_val
+ )
else:
- new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+ new_size, changed = _adjust_batch_size(
+     trainer, batch_arg_name, factor=2.0, desc="succeeded", max_val=max_val
+ )

if not changed:
break
@@ -270,6 +285,15 @@
else:
raise # some other error not memory related

# Apply margin reduction for binsearch mode
if margin > 0:
margin_reduced_size = max(1, int(new_size * (1 - margin)))
if margin_reduced_size != new_size:
rank_zero_info(
f"Applying margin of {margin:.1%}, reducing batch size from {new_size} to {margin_reduced_size}"
)
new_size = margin_reduced_size

return new_size


@@ -279,6 +303,7 @@ def _adjust_batch_size(
factor: float = 1.0,
value: Optional[int] = None,
desc: Optional[str] = None,
max_val: Optional[int] = None,
) -> tuple[int, bool]:
"""Helper function for adjusting the batch size.

@@ -289,6 +314,7 @@
value: if a value is given, will override the batch size with this value.
Note that the value of `factor` will not have an effect in this case
desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
max_val: Maximum batch size limit. If provided, the new batch size will not exceed this value.

Returns:
The new batch size for the next trial and a bool that signals whether the
@@ -314,6 +340,12 @@
pass

new_size = value if value is not None else int(batch_size * factor)

# Apply max_val limit if provided
if max_val is not None and new_size > max_val:
if desc:
rank_zero_info(f"Batch size {new_size} exceeds max_val limit {max_val}, capping at {max_val}")
new_size = max_val
if desc:
rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
changed = new_size != batch_size
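The post-processing added above is plain arithmetic; here is a standalone sketch of the combined capping-plus-margin behavior (the helper name is illustrative, not part of the module):

```python
from typing import Optional


def _cap_and_shave(size: int, max_val: Optional[int], margin: float) -> int:
    # Mirrors the two additions: _adjust_batch_size caps the candidate at
    # max_val, then the binsearch path shaves off `margin` and floors at 1.
    if max_val is not None:
        size = min(size, max_val)
    return max(1, int(size * (1 - margin)))


assert _cap_and_shave(96, None, 0.05) == 91   # int(96 * 0.95) = int(91.2)
assert _cap_and_shave(512, 256, 0.05) == 243  # capped at 256, then int(256 * 0.95)
assert _cap_and_shave(1, None, 0.99) == 1     # never drops below a batch size of 1
```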
8 changes: 8 additions & 0 deletions src/lightning/pytorch/tuner/tuning.py
@@ -41,6 +41,8 @@ def scale_batch_size(
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = "batch_size",
margin: float = 0.05,
max_val: Optional[int] = None,
) -> Optional[int]:
"""Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
error.
@@ -75,6 +77,10 @@
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)

margin: Margin by which to reduce the found batch size, providing a safety buffer. Only applied when using
'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (a 5% reduction).
max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.

"""
_check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
_check_scale_batch_size_configuration(self._trainer)
@@ -88,6 +94,8 @@
init_val=init_val,
max_trials=max_trials,
batch_arg_name=batch_arg_name,
margin=margin,
max_val=max_val,
)
# do not continue with the loop in case Tuner is used
batch_size_finder._early_exit = True
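For completeness, the `Tuner` entry point with the new keywords; a sketch assuming `model` and `trainer` are already set up and the model carries a `batch_size` attribute:

```python
from lightning.pytorch.tuner import Tuner

# Assumes an existing `trainer` and a `model` with a `batch_size` attribute.
tuner = Tuner(trainer)
# Binsearch for the largest working size, cap it at 512, then return
# 95% of the result as a safety buffer.
new_size = tuner.scale_batch_size(model, mode="binsearch", margin=0.05, max_val=512)
```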
45 changes: 45 additions & 0 deletions tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -489,6 +489,51 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
assert trainer.num_val_batches[0] != steps_per_trial


@pytest.mark.parametrize("margin", [0.0, 0.1, 0.2])
def test_scale_batch_size_margin(tmp_path, margin):
"""Test margin feature for batch size scaling by comparing results with and without margin."""
# First, find the batch size without margin
model1 = BatchSizeModel(batch_size=2)
trainer1 = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=False, enable_checkpointing=False)
tuner1 = Tuner(trainer1)

result_without_margin = tuner1.scale_batch_size(
model1, mode="binsearch", max_trials=2, steps_per_trial=1, margin=0.0
)

model2 = BatchSizeModel(batch_size=2)
trainer2 = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=False, enable_checkpointing=False)
tuner2 = Tuner(trainer2)

result_with_margin = tuner2.scale_batch_size(
model2, mode="binsearch", max_trials=2, steps_per_trial=1, margin=margin
)

assert result_without_margin is not None
assert result_with_margin is not None

if margin == 0.0:
assert result_with_margin == result_without_margin
else:
expected_with_margin = max(1, int(result_without_margin * (1 - margin)))
assert result_with_margin == expected_with_margin
assert result_with_margin <= result_without_margin


@pytest.mark.parametrize("mode", ["power", "binsearch"])
def test_scale_batch_size_max_val_limit(tmp_path, mode):
"""Test that max_val limits the batch size for both power and binsearch modes."""
model = BatchSizeModel(batch_size=2)
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
tuner = Tuner(trainer)

max_val = 8 # Set a low max value
result = tuner.scale_batch_size(model, mode=mode, max_trials=5, steps_per_trial=1, max_val=max_val)

assert result is not None
assert result <= max_val


def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
"""Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""

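To exercise only the new tests locally, a sketch using pytest's Python entry point (equivalent to the usual CLI invocation from the repository root):

```python
import pytest

# Select the margin and max_val tests by keyword.
pytest.main(["tests/tests_pytorch/tuner/test_scale_batch_size.py", "-k", "margin or max_val"])
```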