
Commit 49a4a36

carmocca and awaelchli authored
Have the outputs match the loops format (#12182)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
1 parent 821ca7e commit 49a4a36

10 files changed: +350 additions, -153 deletions


CHANGELOG.md

Lines changed: 7 additions & 1 deletion
@@ -30,7 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
  * Broadcast the `_terminate_gracefully` to all processes and add support for DDP ([#10638](https://github.com/PyTorchLightning/pytorch-lightning/pull/10638))

- - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/pull/10639))
+ - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/pull/10680))

  - Added a function to validate if fault tolerant training is supported. ([#10465](https://github.com/PyTorchLightning/pytorch-lightning/pull/10465))
@@ -410,6 +410,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
  - Deprecated `TrainerOptimizersMixin` and moved functionality to `core/optimizer.py`([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))

+ - Deprecated the `on_train_batch_end(outputs)` format when multiple optimizers are used and TBPTT is enabled ([#12182](https://github.com/PyTorchLightning/pytorch-lightning/pull/12182))
+
+ - Deprecated the `training_epoch_end(outputs)` format when multiple optimizers are used and TBPTT is enabled ([#12182](https://github.com/PyTorchLightning/pytorch-lightning/pull/12182))
+
  - Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))
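To illustrate the two new deprecation entries, a minimal sketch of the nesting involved (not part of this commit; the placeholder dicts and the sizes, two optimizers and three truncated-BPTT splits, are assumptions for illustration):

# Hypothetical placeholder for a single `training_step` output dict.
out = {"loss": ...}

# Deprecated `on_train_batch_end(outputs, ...)` nesting: (n_optimizers=2, tbptt_steps=3).
outputs_old = [[out, out, out], [out, out, out]]
# Nesting planned for v1.8: (tbptt_steps=3, n_optimizers=2), matching the order the loops produce.
outputs_new = [[out, out], [out, out], [out, out]]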

pytorch_lightning/core/lightning.py

Lines changed: 3 additions & 4 deletions
@@ -705,10 +705,9 @@ def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
              training_epoch_end(train_outs)

          Args:
-             outputs: List of outputs you defined in :meth:`training_step`.
-                 If there are multiple optimizers, it is a list containing a list of outputs for each optimizer.
-                 If using ``truncated_bptt_steps > 1``, each element is a list of outputs corresponding to the outputs
-                 of each processed split batch.
+             outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers or when
+                 using ``truncated_bptt_steps > 0``, the lists have the dimensions
+                 (n_batches, tbptt_steps, n_optimizers). Dimensions of length 1 are squeezed.

          Return:
              None
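As a brief illustration of the updated docstring (this example is not part of the diff; the module and metric names are hypothetical), in the common single-optimizer case without truncated BPTT the length-1 dimensions are squeezed away, so `outputs` is simply a flat list of the dicts returned by `training_step`:

import pytorch_lightning as pl
import torch


class ExampleModule(pl.LightningModule):  # hypothetical module, for illustration only
    def training_epoch_end(self, outputs):
        # single optimizer, no TBPTT: `outputs` is a flat list of `training_step` dicts,
        # assuming `training_step` returned a dict with a "loss" key
        epoch_loss = torch.stack([out["loss"] for out in outputs]).mean()
        self.log("train/epoch_loss", epoch_loss)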

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 47 additions & 27 deletions
@@ -18,10 +18,11 @@
  import numpy as np
  import torch

+ import pytorch_lightning as pl
  from pytorch_lightning import loops  # import as loops to avoid circular imports
  from pytorch_lightning.loops.batch import TrainingBatchLoop
  from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
- from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
+ from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _v1_8_output_format
  from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
  from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
  from pytorch_lightning.trainer.supporters import CombinedLoader
@@ -216,7 +217,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
          batch_end_outputs = self._prepare_outputs_training_batch_end(
              batch_output,
-             automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization,
+             lightning_module=self.trainer.lightning_module,
              num_optimizers=len(self.trainer.optimizers),
          )
@@ -337,26 +338,38 @@ def _should_accumulate(self) -> bool:
  @staticmethod
  def _prepare_outputs_training_batch_end(
      batch_output: _BATCH_OUTPUTS_TYPE,
-     automatic: bool,
+     lightning_module: "pl.LightningModule",
      num_optimizers: int,
  ) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
-     """Processes the outputs from the batch loop into the format passed to the ``training_batch_end`` hook.
-
-     ``(tbptt_steps, n_opt) -> (n_opt, tbptt_steps)``. The optimizer dimension might have been squeezed.
-     """
+     """Processes the outputs from the batch loop into the format passed to the ``on_train_batch_end`` hook."""
      if not batch_output:
          return []

      # convert optimizer dicts to list
-     if automatic:
+     if lightning_module.automatic_optimization:
          batch_output = apply_to_collection(
              batch_output, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers
          )
-     array = np.array(batch_output, dtype=object)
-     if array.ndim == 1:
-         array = np.expand_dims(array, 1)

-     array = array.transpose((1, 0))
+     array = np.array(batch_output, dtype=object)
+     # TODO: remove in v1.8
+     if (
+         num_optimizers > 1
+         and lightning_module.truncated_bptt_steps > 0
+         and not _v1_8_output_format(lightning_module.on_train_batch_end)
+     ):
+         rank_zero_deprecation(
+             "You are training with multiple optimizers AND truncated backpropagation through time enabled."
+             " The current format of the `on_train_batch_end(outputs, ...)` is a 2d list with sizes"
+             " (n_optimizers, tbptt_steps), however, this has been deprecated and will change in version v1.8 to"
+             " (tbptt_steps, n_optimizers). You can update your code by adding the following parameter to your"
+             " hook signature: `on_train_batch_end(outputs, ..., new_format=True)`."
+         )
+         # (tbptt_steps, n_opt) -> (n_opt, tbptt_steps)
+         if array.ndim == 1:
+             array = np.expand_dims(array, 1)
+         array = array.transpose((1, 0))
+     # squeeze all single-element dimensions
      array = array.squeeze()
      array = array.tolist()
      array = _recursive_unpad(array)
@@ -365,35 +378,42 @@ def _prepare_outputs_training_batch_end(
  @staticmethod
  def _prepare_outputs_training_epoch_end(
      batch_outputs: _OUTPUTS_TYPE,
-     automatic: bool,
+     lightning_module: "pl.LightningModule",
      num_optimizers: int,
  ) -> Union[List[List[List[Dict[str, Any]]]], List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
-     """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook.
-
-     ``(n_batches, tbptt_steps, n_opt) -> (n_opt, n_batches, tbptt_steps)``.
-     All single-element dimensions might have been squeezed.
-
-     This processing is necessary because the format of the inputs to the ``training_epoch_end`` hook does not
-     match the loop structure and because empty dimensions are squeezed. This could break with loop customization.
-     """
+     """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook."""
      # `batch_outputs` (plural) is the same as `epoch_end_output` (singular)
      if not batch_outputs:
          return []

      # convert optimizer dicts to list
-     if automatic:
+     if lightning_module.automatic_optimization:
          batch_outputs = apply_to_collection(
              batch_outputs, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers
          )

      array = _recursive_pad(batch_outputs)
-     if array.ndim == 2:
-         array = np.expand_dims(array, 2)
-     array = array.transpose((2, 0, 1))
+     # TODO: remove in v1.8
+     if (
+         num_optimizers > 1
+         and lightning_module.truncated_bptt_steps > 0
+         and not _v1_8_output_format(lightning_module.on_train_epoch_end)
+     ):
+         rank_zero_deprecation(
+             "You are training with multiple optimizers AND truncated backpropagation through time enabled."
+             " The current format of the `training_epoch_end(outputs)` is a 3d list with sizes"
+             " (n_optimizers, n_batches, tbptt_steps), however, this has been deprecated and will change in version"
+             " v1.8 to (n_batches, tbptt_steps, n_optimizers). You can update your code by adding the following"
+             " parameter to your hook signature: `training_epoch_end(outputs, new_format=True)`."
+         )
+         # (n_batches, tbptt_steps, n_opt) -> (n_opt, n_batches, tbptt_steps)
+         if array.ndim == 2:
+             array = np.expand_dims(array, 2)
+         array = array.transpose((2, 0, 1))
+     # squeeze all single-element dimensions
      array = array.squeeze()
      array = array.tolist()
      array = _recursive_unpad(array)
-
      # in case we squeezed from 1-element array to a 0-dim array
      array = array if isinstance(array, list) else [array]
      # remove residual empty lists
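To make the array handling above concrete, here is a standalone toy reproduction with plain numpy (not the loop code itself; the dict values are made up): the loops now produce outputs in (tbptt_steps, n_optimizers) order, and only the deprecated branch transposes them back to the legacy (n_optimizers, tbptt_steps) layout before single-element dimensions are squeezed:

import numpy as np

# Toy stand-ins for `training_step` output dicts: 3 TBPTT splits x 2 optimizers.
batch_output = [
    [{"loss": 0.1}, {"loss": 0.2}],
    [{"loss": 0.3}, {"loss": 0.4}],
    [{"loss": 0.5}, {"loss": 0.6}],
]

array = np.array(batch_output, dtype=object)  # shape (tbptt_steps=3, n_optimizers=2)
legacy = array.transpose((1, 0))              # deprecated layout: (n_optimizers=2, tbptt_steps=3)
print(legacy.shape)                           # (2, 3)

# With a single optimizer, squeezing removes the length-1 dimension entirely.
single = np.array([[{"loss": 0.1}], [{"loss": 0.2}]], dtype=object).squeeze()
print(single.tolist())                        # [{'loss': 0.1}, {'loss': 0.2}]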
@@ -519,7 +539,7 @@ def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> No
          self._dataloader_state_dict = None


- def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Dict[str, Any]]:
+ def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
      """Converts an optimizer dict to a list in which the key of the dict determines the position of the element.

      Example::

pytorch_lightning/loops/fit_loop.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ def on_advance_end(self) -> None:
          if is_overridden("training_epoch_end", model) and self._outputs:
              epoch_end_outputs = self.epoch_loop._prepare_outputs_training_epoch_end(
                  self._outputs,
-                 automatic=model.automatic_optimization,
+                 lightning_module=model,
                  num_optimizers=len(self.trainer.optimizers),
              )
              # run lightning module hook training_epoch_end

pytorch_lightning/loops/utilities.py

Lines changed: 8 additions & 1 deletion
@@ -11,11 +11,12 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import inspect
  from collections import OrderedDict
  from contextlib import contextmanager
  from datetime import timedelta
  from functools import lru_cache
- from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union
+ from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union

  import numpy as np
  import torch
@@ -221,3 +222,9 @@ def _reset_progress(loop: Loop) -> None:
              v.reset()
          elif isinstance(v, Loop):
              _reset_progress(v)
+
+
+ # TODO: remove in v1.8
+ def _v1_8_output_format(fx: Callable) -> bool:
+     parameters = inspect.signature(fx).parameters
+     return "new_format" in parameters and parameters["new_format"].default is True

tests/deprecated_api/test_remove_1-8.py

Lines changed: 50 additions & 0 deletions
@@ -42,6 +42,7 @@
  from pytorch_lightning.utilities.enums import DeviceType, DistributedType
  from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
  from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
+ from tests.deprecated_api import no_deprecated_call
  from tests.helpers.boring_model import BoringDataModule, BoringModel
  from tests.helpers.runif import RunIf
  from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
@@ -652,6 +653,55 @@ def test_v1_8_0_weights_save_path(tmpdir):
          _ = trainer.weights_save_path


+ def test_deprecated_epoch_outputs_format(tmpdir):
+     class DeprecationModel(BoringModel):
+         def __init__(self):
+             super().__init__()
+             self.truncated_bptt_steps = 1
+
+         def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
+             output = super().training_step(batch, batch_idx)
+             output["hiddens"] = hiddens
+             return output
+
+         def tbptt_split_batch(self, batch, split_size):
+             return [batch, batch]
+
+         def training_epoch_end(self, outputs):
+             ...
+
+         def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
+             ...
+
+         def configure_optimizers(self):
+             return [torch.optim.Adam(self.parameters()), torch.optim.Adam(self.parameters())]
+
+     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+     model = DeprecationModel()
+     batch_match = r"on_train_batch_end.*will change in version v1.8 to \(tbptt_steps, n_optimizers\)"
+     with pytest.deprecated_call(match=batch_match):
+         trainer.fit(model)
+
+     class DeprecationModel2(DeprecationModel):
+         def on_train_batch_end(self, *args, new_format=True):
+             ...
+
+     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+     model = DeprecationModel2()
+     epoch_match = r"training_epoch_end.*will change in version v1.8 to \(n_batches, tbptt_steps, n_optimizers\)"
+     with pytest.deprecated_call(match=epoch_match):
+         trainer.fit(model)
+
+     class NoDeprecationModel(DeprecationModel2):
+         def training_epoch_end(self, outputs, new_format=True):
+             ...
+
+     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+     model = NoDeprecationModel()
+     with no_deprecated_call(match="will change in version v1.8.*new_format=True"):
+         trainer.fit(model)
+
+
  @pytest.mark.flaky(reruns=3)
  @pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
  def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list):
