Skip to content

Commit

Permalink
Fix TQDMProgressBar reset and update to show correct time estimation (
Browse files Browse the repository at this point in the history
#12889)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
2 people authored and lexierule committed May 3, 2022
1 parent ab7ad37 commit 55f5e2d
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891))
- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))
- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))


## [1.6.2] - 2022-04-27
Expand Down
25 changes: 14 additions & 11 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
Expand Up @@ -262,13 +262,13 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + total_val_batches
-        self.main_progress_bar.total = convert_inf(total_batches)
+        self.main_progress_bar.reset(convert_inf(total_batches))
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
current = self.train_batch_idx + self._val_processed
if self._should_update(current, self.main_progress_bar.total):
-            _update_n(self.main_progress_bar, current)
+            _update_n(self.main_progress_bar, current, self.refresh_rate)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand All @@ -288,17 +288,17 @@ def on_validation_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

-        self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader)
+        self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
-            _update_n(self.val_progress_bar, self.val_batch_idx)
+            _update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate)

current = self.train_batch_idx + self._val_processed
if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
-            _update_n(self.main_progress_bar, current)
+            _update_n(self.main_progress_bar, current, self.refresh_rate)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._main_progress_bar is not None and trainer.state.fn == "fit":
Expand All @@ -315,12 +315,12 @@ def on_test_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

-        self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader)
+        self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

def on_test_batch_end(self, *_: Any) -> None:
if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
-            _update_n(self.test_progress_bar, self.test_batch_idx)
+            _update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate)

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.test_progress_bar.close()
Expand All @@ -335,12 +335,12 @@ def on_predict_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

-        self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader)
+        self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

def on_predict_batch_end(self, *_: Any) -> None:
if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
-            _update_n(self.predict_progress_bar, self.predict_batch_idx)
+            _update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate)

def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.predict_progress_bar.close()
Expand Down Expand Up @@ -384,7 +384,10 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
return x


-def _update_n(bar: _tqdm, value: int) -> None:
+def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None:
     if not bar.disable:
-        bar.n = value
+        total = bar.total
+        leftover = current % refresh_rate
+        advance = leftover if (current == total and leftover != 0) else refresh_rate
+        bar.update(advance)
         bar.refresh()
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -2,7 +2,7 @@

numpy>=1.17.2
torch>=1.8.*
-tqdm>=4.41.0
+tqdm>=4.57.0
PyYAML>=5.4
fsspec[http]>=2021.05.0, !=2021.06.0
tensorboard>=2.2.0
Expand Down
27 changes: 14 additions & 13 deletions tests/callbacks/test_tqdm_progress_bar.py
Expand Up @@ -53,6 +53,7 @@ def n(self):
@n.setter
def n(self, value):
self.__n = value

# track the changes in the `n` value
if not len(self.n_values) or value != self.n_values[-1]:
self.n_values.append(value)
Expand Down Expand Up @@ -158,7 +159,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
assert not pbar.val_progress_bar.leave
assert trainer.num_sanity_val_batches == expected_sanity_steps
assert pbar.val_progress_bar.total_values == expected_sanity_steps
-    assert pbar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl
+    assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl
assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]

# fit
Expand All @@ -177,7 +178,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):

# check val progress bar total
assert pbar.val_progress_bar.total_values == m
-    assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl
+    assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
assert not pbar.val_progress_bar.leave

Expand All @@ -186,7 +187,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
trainer.validate(model)
assert trainer.num_val_batches == m
assert pbar.val_progress_bar.total_values == m
-    assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl
+    assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]

# test
Expand All @@ -195,7 +196,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
assert pbar.test_progress_bar.leave
k = trainer.num_test_batches
assert pbar.test_progress_bar.total_values == k
-    assert pbar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl
+    assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)]
assert pbar.test_progress_bar.leave

Expand All @@ -205,7 +206,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
assert pbar.predict_progress_bar.leave
k = trainer.num_predict_batches
assert pbar.predict_progress_bar.total_values == k
-    assert pbar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl
+    assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)]
assert pbar.predict_progress_bar.leave

Expand Down Expand Up @@ -359,13 +360,13 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
@pytest.mark.parametrize(
"train_batches,val_batches,refresh_rate,train_updates,val_updates",
[
-        [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]],
+        [2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]],
         [0, 0, 3, None, None],
-        [1, 0, 3, [1], None],
-        [1, 1, 3, [2], [1]],
-        [5, 0, 3, [3, 5], None],
-        [5, 2, 3, [3, 6, 7], [2]],
-        [5, 2, 6, [6, 7], [2]],
+        [1, 0, 3, [0, 1], None],
+        [1, 1, 3, [0, 2], [0, 1]],
+        [5, 0, 3, [0, 3, 5], None],
+        [5, 2, 3, [0, 3, 6, 7], [0, 2]],
+        [5, 2, 6, [0, 6, 7], [0, 2]],
],
)
def test_main_progress_bar_update_amount(
Expand Down Expand Up @@ -395,7 +396,7 @@ def test_main_progress_bar_update_amount(
assert progress_bar.val_progress_bar.n_values == val_updates


@pytest.mark.parametrize("test_batches,refresh_rate,updates", [[1, 3, [1]], [3, 1, [1, 2, 3]], [5, 3, [3, 5]]])
@pytest.mark.parametrize("test_batches,refresh_rate,updates", [(1, 3, [0, 1]), (3, 1, [0, 1, 2, 3]), (5, 3, [0, 3, 5])])
def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list):
"""Test that test progress updates with the correct amount."""
model = BoringModel()
Expand Down Expand Up @@ -566,7 +567,7 @@ def test_tqdm_progress_bar_can_be_pickled():

@pytest.mark.parametrize(
["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
-    [(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])],
+    [(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, [0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])],
)
def test_progress_bar_max_val_check_interval(
tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
Expand Down

0 comments on commit 55f5e2d

Please sign in to comment.