Skip to content

Commit db9f11d

Browse files
truncate long version number in progress bar (Lightning-AI#2594)
* truncate version number * add docs and example * extend docs * docs * docs * changelog * show last * Update pytorch_lightning/core/lightning.py * Update pytorch_lightning/core/lightning.py Co-authored-by: William Falcon <waf2107@columbia.edu>
1 parent c047676 commit db9f11d

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616

1717
### Changed
1818

19+
- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
1920

2021
### Deprecated
2122

docs/source/experiment_reporting.rst

+14-6
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ Control log writing frequency
3030
Writing to a logger can be expensive. In Lightning you can set the interval at which you
3131
want to log using this trainer flag.
3232

33-
.. seealso::
34-
:class:`~pytorch_lightning.trainer.trainer.Trainer`
35-
3633
.. testcode::
3734

3835
k = 100
3936
trainer = Trainer(log_save_interval=k)
4037

38+
.. seealso::
39+
:class:`~pytorch_lightning.trainer.trainer.Trainer`
40+
4141
----------
4242

4343
Log metrics
@@ -94,10 +94,14 @@ For instance, here we log images using tensorboard.
9494
Modify progress bar
9595
^^^^^^^^^^^^^^^^^^^
9696

97-
Each return dict from the training_end, validation_end, testing_end and training_step also has
98-
a key called "progress_bar".
97+
Each return dict from the
98+
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
99+
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`,
100+
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` and
101+
:meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`
102+
can also contain a key called `progress_bar`.
99103

100-
Here we show the validation loss in the progress bar
104+
Here we show the validation loss in the progress bar:
101105

102106
.. testcode::
103107

@@ -109,6 +113,10 @@ Here we show the validation loss in the progress bar
109113
results = {'progress_bar': logs}
110114
return results
111115

116+
The progress bar by default already includes the training loss and version number of the experiment
117+
if you are using a logger. These defaults can be customized by overriding the
118+
:meth:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.
119+
112120

113121
----------
114122

pytorch_lightning/core/lightning.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -1544,7 +1544,6 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
15441544
Example:
15451545
.. code-block:: python
15461546
1547-
15481547
def on_save_checkpoint(self, checkpoint):
15491548
# 99% of use cases you don't need to implement this method
15501549
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
@@ -1558,7 +1557,23 @@ def on_save_checkpoint(self, checkpoint):
15581557

15591558
def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
15601559
r"""
1561-
Additional items to be displayed in the progress bar.
1560+
Implement this to override the default items displayed in the progress bar.
1561+
By default it includes the average loss value, split index of BPTT (if used)
1562+
and the version of the experiment when using a logger.
1563+
1564+
.. code-block::
1565+
1566+
Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]
1567+
1568+
Here is an example how to override the defaults:
1569+
1570+
.. code-block:: python
1571+
1572+
def get_progress_bar_dict(self):
1573+
# don't show the version number
1574+
items = super().get_progress_bar_dict()
1575+
items.pop("v_num", None)
1576+
return items
15621577
15631578
Return:
15641579
Dictionary with the items to be displayed in the progress bar.
@@ -1572,7 +1587,10 @@ def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
15721587
tqdm_dict['split_idx'] = self.trainer.split_idx
15731588

15741589
if self.trainer.logger is not None and self.trainer.logger.version is not None:
1575-
tqdm_dict['v_num'] = self.trainer.logger.version
1590+
version = self.trainer.logger.version
1591+
# show last 4 places of long version strings
1592+
version = version[-4:] if isinstance(version, str) else version
1593+
tqdm_dict['v_num'] = version
15761594

15771595
return tqdm_dict
15781596

0 commit comments

Comments
 (0)