Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change trainer.should_stop to not stop in between an epoch and run until min_steps/min_epochs only #13890

Merged
merged 18 commits into from
Aug 27, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/source-pytorch/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1761,3 +1761,33 @@ execution within that function, and the status of the Trainer.
trainer.state.status
# stage in ("train", "sanity_check", "validate", "test", "predict", "tune")
trainer.state.stage

should_stop
***********

If you want to terminate the training during ``.fit``, you can set ``trainer.should_stop=True`` and Lightning will terminate the training
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
immediately. Note that it will respect the arguments ``min_steps`` and ``min_epochs`` to check whether to stop it or not. If these
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
arguments are set and the ``current_epoch`` or ``global_step`` doesn't meet these minimum conditions, training will continue until
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
both conditions are met. If any of these arguments is not set, it won't be considered for the final decision.


.. code-block:: python

trainer = Trainer()
    # setting ``trainer.should_stop`` at any point of training will terminate
trainer.should_stop = True
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

trainer = Trainer(min_epochs=5, max_epochs=100)
    # setting ``trainer.should_stop`` at any point before the 5th epoch is completed
# will not terminate the training until 5th epoch is completed
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
trainer.should_stop = True

trainer = Trainer(min_steps=5, max_epochs=100)
    # setting ``trainer.should_stop`` at any point before the 5th step is completed
# will not terminate the training until 5th step is completed
trainer.should_stop = True

trainer = Trainer(min_steps=5, min_epochs=5, max_epochs=100, limit_train_batches=10)
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
    # setting ``trainer.should_stop`` at any point before/after the 5th step is completed
# will not terminate the training until 5th epoch is completed
trainer.should_stop = True
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disallowed using `BatchSampler` when running on multiple IPUs ([#13854](https://github.com/PyTorchLightning/pytorch-lightning/pull/13854))


- Refactored `trainer.should_stop` to not stop in between an epoch and run until `min_steps/min_epochs` only ([#13890](https://github.com/Lightning-AI/lightning/pull/13890))

rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

### Deprecated

- Deprecated `pytorch_lightning.accelerators.gpu.GPUAccelerator` in favor of `pytorch_lightning.accelerators.cuda.CUDAAccelerator` ([#13636](https://github.com/Lightning-AI/lightning/pull/13636))
Expand Down
18 changes: 17 additions & 1 deletion src/pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,23 @@ def _is_validation_done(self) -> bool:
@property
def done(self) -> bool:
"""Evaluates when to leave the loop."""
return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop
if self._is_training_done and self._is_validation_done:
return True

if self.trainer.should_stop:
# early stopping
min_epochs = self.trainer.fit_loop.min_epochs
met_min_epochs = self.trainer.current_epoch >= min_epochs if min_epochs else True
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
met_min_conditions = met_min_epochs and met_min_steps
if not met_min_conditions:
self._warning_cache.info(
f"Trainer was signaled to stop but the required `min_epochs={min_epochs!r}` or"
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
)
return met_min_conditions

return False

def connect( # type: ignore[override]
self,
Expand Down
8 changes: 1 addition & 7 deletions src/pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,9 @@ def done(self) -> bool:
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
if met_min_epochs and met_min_steps:
self.trainer.should_stop = True
carmocca marked this conversation as resolved.
Show resolved Hide resolved
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
return True
else:
rank_zero_info(
f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
)
self.trainer.should_stop = False

return False

@property
Expand Down
6 changes: 6 additions & 0 deletions src/pytorch_lightning/utilities/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning as NewLightningDeprecationWarning
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation as new_rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info
from pytorch_lightning.utilities.rank_zero import rank_zero_warn as new_rank_zero_warn

# enable our warnings
Expand All @@ -39,6 +40,11 @@ def deprecation(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
self.add(message)
new_rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)

def info(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
if message not in self:
self.add(message)
new_rank_zero_info(message, stacklevel=stacklevel, **kwargs)


def rank_zero_warn(*args: Any, **kwargs: Any) -> Any:
new_rank_zero_deprecation(
Expand Down
20 changes: 11 additions & 9 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,25 +265,28 @@ def validation_epoch_end(self, outputs):
assert early_stopping.stopped_epoch == expected_stop_epoch


@pytest.mark.parametrize("limit_train_batches", (3, 5))
@pytest.mark.parametrize(
["min_epochs", "min_steps"],
"limit_train_batches,min_epochs,min_steps,stop_step",
[
# IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
# triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
(0, 10),
# (3, 0, 10, 10),
# (5, 0, 10, 10),
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
# IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
# being triggered, THEN the trainer should continue until reaching
# `trainer.global_step` == `min_epochs * len(train_dataloader)`
(2, 0),
(3, 2, 0, 6),
(5, 2, 0, 10),
# IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
# `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
# `min_steps` would be reached
(1, 10),
(3, 10),
(3, 1, 10, 10),
(5, 1, 10, 10),
(3, 3, 10, 10),
carmocca marked this conversation as resolved.
Show resolved Hide resolved
(5, 3, 10, 15),
],
)
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps, stop_step):
if min_steps:
assert limit_train_batches < min_steps

Expand Down Expand Up @@ -317,8 +320,7 @@ def training_step(self, batch, batch_idx):
# epochs continue until min steps are reached
assert trainer.current_epoch == expected_epochs
# steps continue until min steps are reached AND the epoch is exhausted
# stopping mid-epoch is not supported
assert trainer.global_step == limit_train_batches * expected_epochs
assert trainer.global_step == stop_step


def test_early_stopping_mode_options():
Expand Down
36 changes: 36 additions & 0 deletions tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest import mock
from unittest.mock import patch

Expand Down Expand Up @@ -265,3 +266,38 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir):
) as advance_mocked:
trainer.fit(model, ckpt_path=ckpt_path)
assert advance_mocked.call_count == 1


@pytest.mark.parametrize(
"min_epochs, min_steps, current_epoch, global_step, epoch_loop_done, raise_info_msg",
[
(None, None, 1, 4, True, False),
(4, None, 1, 4, False, True),
(4, 2, 1, 4, False, True),
(4, None, 1, 10, True, False),
(4, 3, 1, 3, False, True),
(4, 10, 1, 10, True, False),
(None, 4, 1, 4, False, False),
(None, 10, 1, 10, True, False),
],
)
def test_should_stop_early_stopping_conditions(
carmocca marked this conversation as resolved.
Show resolved Hide resolved
caplog, min_epochs, min_steps, current_epoch, global_step, epoch_loop_done, raise_info_msg
):
def get_trainer():
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
trainer.num_training_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
global_step
)
trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1
return trainer

trainer = get_trainer()
message = f"min_epochs={min_epochs}` or `min_steps={min_steps}` has not been met. Training will continue"
with caplog.at_level(logging.INFO, logger="pytorch_lightning.loops"):
assert trainer.fit_loop.epoch_loop.done is epoch_loop_done

assert (message in caplog.text) is raise_info_msg
32 changes: 31 additions & 1 deletion tests/tests_pytorch/loops/test_training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def test_fit_loop_done_log_messages(caplog):

fit_loop.epoch_loop.min_steps = 100
assert not fit_loop.done
assert "was signaled to stop but" in caplog.text


def test_warning_valid_train_step_end(tmpdir):
Expand All @@ -198,3 +197,34 @@ def training_step_end(self, outputs):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

trainer.fit(model)


@pytest.mark.parametrize(
"min_epochs, min_steps, current_epoch, fit_loop_done, raise_debug_msg",
[
(4, None, 100, True, False),
(4, None, 3, False, False),
(4, 10, 3, False, False),
(None, 10, 4, True, True),
(4, None, 4, True, True),
(4, 10, 4, True, True),
],
)
def test_should_stop_early_stopping_conditions(
caplog, min_epochs, min_steps, current_epoch, fit_loop_done, raise_debug_msg
):
def get_trainer():
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
trainer.num_training_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = 10
trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10
trainer.fit_loop.epoch_progress.current.processed = current_epoch
return trainer

trainer = get_trainer()
message = "`Trainer.fit` stopped: `trainer.should_stop` was set."
with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
assert trainer.fit_loop.done is fit_loop_done

assert (message in caplog.text) is raise_debug_msg
7 changes: 5 additions & 2 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,10 @@ def training_step(self, batch, batch_idx):
output["loss"] = output["loss"] * 0.0 # force minimal loss to trigger early stopping
self.log("loss", output["loss"])
self.training_step_invoked += 1
assert not self.trainer.should_stop
if self.current_epoch < 2:
assert not self.trainer.should_stop
else:
assert self.trainer.should_stop
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return output

model = TestModel()
Expand All @@ -618,7 +621,7 @@ def training_step(self, batch, batch_idx):

message = f"min_epochs={min_epochs}` or `min_steps=None` has not been met. Training will continue"
num_messages = sum(1 for record in caplog.records if message in record.message)
assert num_messages == min_epochs - 2
assert num_messages == 1
assert model.training_step_invoked == min_epochs * 2


Expand Down