Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the deprecated get_progress_bar_dict #12839

Merged
merged 20 commits into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def get_metrics(self, trainer, model):
Return:
Dictionary with the items to be displayed in the progress bar.
"""
standard_metrics = pl_module.get_progress_bar_dict()
standard_metrics = get_standard_metrics(trainer, pl_module)
pbar_metrics = trainer.progress_bar_metrics
duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
if duplicates:
Expand Down
29 changes: 0 additions & 29 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1731,35 +1731,6 @@ def unfreeze(self) -> None:

self.train()

def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
r"""
.. deprecated:: v1.5
This method was deprecated in v1.5 in favor of
`pytorch_lightning.callbacks.progress.base.get_metrics` and will be removed in v1.7.

Implement this to override the default items displayed in the progress bar.
By default it includes the average loss value, split index of BPTT (if used)
and the version of the experiment when using a logger.

.. code-block::

Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]

Here is an example how to override the defaults:

.. code-block:: python

def get_progress_bar_dict(self):
# don't show the version number
items = super().get_progress_bar_dict()
items.pop("v_num", None)
return items

Return:
Dictionary with the items to be displayed in the progress bar.
"""
return progress_base.get_standard_metrics(self.trainer, self)

def _verify_is_manual_optimization(self, fn_name):
if self.automatic_optimization:
raise MisconfigurationException(
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
def _check_progress_bar(model: "pl.LightningModule") -> None:
r"""
Checks if get_progress_bar_dict is overridden and sends a deprecation warning.

Args:
model: The model to check the get_progress_bar_dict method.
"""
Expand Down
22 changes: 0 additions & 22 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,6 @@ def test_v1_7_0_datamodule_transform_properties(tmpdir):
_ = LightningDataModule(val_transforms="b")


def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):
class TestModel(BoringModel):
def get_progress_bar_dict(self):
items = super().get_progress_bar_dict()
items.pop("v_num", None)
return items

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)
test_model = TestModel()
with pytest.deprecated_call(match=r"`LightningModule.get_progress_bar_dict` method was deprecated in v1.5"):
trainer.fit(test_model)
standard_metrics_postfix = trainer.progress_bar_callback.main_progress_bar.postfix
assert "loss" in standard_metrics_postfix
assert "v_num" not in standard_metrics_postfix

with pytest.deprecated_call(match=r"`trainer.progress_bar_dict` is deprecated in v1.5"):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
_ = trainer.progress_bar_dict
rschireman marked this conversation as resolved.
Show resolved Hide resolved


def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
class CustomBoringModel(BoringModel):
def on_train_dataloader(self):
Expand Down
1 change: 0 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def __init__(self, not_supported):
"on_before_batch_transfer",
"transfer_batch_to_device",
"on_after_batch_transfer",
"get_progress_bar_dict",
}
)
# remove `nn.Module` hooks
Expand Down