Skip to content

Commit

Permalink
Remove the deprecated get_progress_bar_dict (#12839)
Browse files Browse the repository at this point in the history
Co-authored-by: Raymond G Schireman <raymond.schireman@uvm.edu>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
3 people authored Apr 22, 2022
1 parent c4bb078 commit f931e27
Show file tree
Hide file tree
Showing 7 changed files with 5 additions and 84 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed deprecated `dataloader_idx` argument from `on_train_batch_start/end` hooks `Callback` and `LightningModule` ([#12769](https://github.com/PyTorchLightning/pytorch-lightning/pull/12769))


- Removed deprecated `get_progress_bar_dict` property from `LightningModule` ([#12839](https://github.com/PyTorchLightning/pytorch-lightning/pull/12839))

### Fixed


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def get_metrics(self, trainer, model):
Return:
Dictionary with the items to be displayed in the progress bar.
"""
standard_metrics = pl_module.get_progress_bar_dict()
standard_metrics = get_standard_metrics(trainer, pl_module)
pbar_metrics = trainer.progress_bar_metrics
duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
if duplicates:
Expand Down
30 changes: 0 additions & 30 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
from pytorch_lightning.core.optimizer import LightningOptimizer
Expand Down Expand Up @@ -1731,35 +1730,6 @@ def unfreeze(self) -> None:

self.train()

def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
r"""
.. deprecated:: v1.5
This method was deprecated in v1.5 in favor of
`pytorch_lightning.callbacks.progress.base.get_metrics` and will be removed in v1.7.
Implement this to override the default items displayed in the progress bar.
By default it includes the average loss value, split index of BPTT (if used)
and the version of the experiment when using a logger.
.. code-block::
Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]
Here is an example how to override the defaults:
.. code-block:: python
def get_progress_bar_dict(self):
# don't show the version number
items = super().get_progress_bar_dict()
items.pop("v_num", None)
return items
Return:
Dictionary with the items to be displayed in the progress bar.
"""
return progress_base.get_standard_metrics(self.trainer, self)

def _verify_is_manual_optimization(self, fn_name):
if self.automatic_optimization:
raise MisconfigurationException(
Expand Down
16 changes: 0 additions & 16 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:

__verify_dp_batch_transfer_support(trainer, model)
_check_add_get_queue(model)
# TODO: Delete _check_progress_bar in v1.7
_check_progress_bar(model)
# TODO: Delete _check_on_post_move_to_device in v1.7
_check_on_post_move_to_device(model)
_check_deprecated_callback_hooks(trainer)
Expand Down Expand Up @@ -143,20 +141,6 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
)


def _check_progress_bar(model: "pl.LightningModule") -> None:
r"""
Checks if get_progress_bar_dict is overridden and sends a deprecation warning.
Args:
model: The model to check the get_progress_bar_dict method.
"""
if is_overridden("get_progress_bar_dict", model):
rank_zero_deprecation(
"The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7."
" Please use the `ProgressBarBase.get_metrics` instead."
)


def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
r"""
Checks if `on_post_move_to_device` method is overridden and sends a deprecation warning.
Expand Down
15 changes: 1 addition & 14 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, cast, Dict, Generator, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union
from weakref import proxy

import torch
Expand Down Expand Up @@ -2191,19 +2191,6 @@ def distributed_sampler_kwargs(self) -> Optional[dict]:
def data_parallel(self) -> bool:
return isinstance(self.strategy, ParallelStrategy)

@property
def progress_bar_dict(self) -> dict:
"""Read-only for progress bar metrics."""
rank_zero_deprecation(
"`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7."
" Use `ProgressBarBase.get_metrics` instead."
)
ref_model = self.lightning_module
ref_model = cast(pl.LightningModule, ref_model)
if self.progress_bar_callback:
return self.progress_bar_callback.get_metrics(self, ref_model)
return self.progress_bar_metrics

@property
def enable_validation(self) -> bool:
"""Check if we should run validation during training."""
Expand Down
22 changes: 0 additions & 22 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,6 @@
from tests.plugins.environments.test_lsf_environment import _make_rankfile


def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):
class TestModel(BoringModel):
def get_progress_bar_dict(self):
items = super().get_progress_bar_dict()
items.pop("v_num", None)
return items

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)
test_model = TestModel()
with pytest.deprecated_call(match=r"`LightningModule.get_progress_bar_dict` method was deprecated in v1.5"):
trainer.fit(test_model)
standard_metrics_postfix = trainer.progress_bar_callback.main_progress_bar.postfix
assert "loss" in standard_metrics_postfix
assert "v_num" not in standard_metrics_postfix

with pytest.deprecated_call(match=r"`trainer.progress_bar_dict` is deprecated in v1.5"):
_ = trainer.progress_bar_dict


def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
class CustomBoringModel(BoringModel):
def on_train_dataloader(self):
Expand Down
1 change: 0 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def __init__(self, not_supported):
"on_before_batch_transfer",
"transfer_batch_to_device",
"on_after_batch_transfer",
"get_progress_bar_dict",
}
)
# remove `nn.Module` hooks
Expand Down

0 comments on commit f931e27

Please sign in to comment.