Skip to content

Commit

Permalink
Fix resuming the tqdm progress bar (#13962)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Aug 2, 2022
1 parent ff2e329 commit f576ed3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
5 changes: 4 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed Python 3.10 compatibility for truncated back-propagation through time (TBPTT) ([#13973](https://github.com/Lightning-AI/lightning/pull/13973))


- Fixed `TQDMProgressBar` reset and update to show correct time estimation (2/2) ([#13962](https://github.com/Lightning-AI/lightning/pull/13962))



## [1.6.5] - 2022-07-13

Expand Down Expand Up @@ -463,7 +466,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/Lightning-AI/lightning/pull/12891))
- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/Lightning-AI/lightning/pull/12653))
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/Lightning-AI/lightning/pull/12965))
- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/Lightning-AI/lightning/pull/12889))
- Fixed `TQDMProgressBar` reset and update to show correct time estimation (1/2) ([#12889](https://github.com/Lightning-AI/lightning/pull/12889))
- Fixed fit loop restart logic to enable resume using the checkpoint ([#12821](https://github.com/Lightning-AI/lightning/pull/12821))


Expand Down
22 changes: 12 additions & 10 deletions src/pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,13 @@ def on_train_start(self, *_: Any) -> None:
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
total_batches = self.total_batches_current_epoch
self.main_progress_bar.reset(convert_inf(total_batches))
self.main_progress_bar.initial = 0
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
current = self.train_batch_idx + self._val_processed
if self._should_update(current, self.main_progress_bar.total):
_update_n(self.main_progress_bar, current, self.refresh_rate)
_update_n(self.main_progress_bar, current)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand All @@ -280,16 +281,17 @@ def on_validation_batch_start(
return

self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
self.val_progress_bar.initial = 0
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
_update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate)
_update_n(self.val_progress_bar, self.val_batch_idx)

current = self.train_batch_idx + self._val_processed
if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
_update_n(self.main_progress_bar, current, self.refresh_rate)
_update_n(self.main_progress_bar, current)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._main_progress_bar is not None and trainer.state.fn == "fit":
Expand All @@ -307,11 +309,12 @@ def on_test_batch_start(
return

self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
self.test_progress_bar.initial = 0
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

def on_test_batch_end(self, *_: Any) -> None:
if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
_update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate)
_update_n(self.test_progress_bar, self.test_batch_idx)

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.test_progress_bar.close()
Expand All @@ -327,11 +330,12 @@ def on_predict_batch_start(
return

self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
self.predict_progress_bar.initial = 0
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

def on_predict_batch_end(self, *_: Any) -> None:
if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
_update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate)
_update_n(self.predict_progress_bar, self.predict_batch_idx)

def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.predict_progress_bar.close()
Expand Down Expand Up @@ -375,9 +379,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
return x


def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None:
def _update_n(bar: _tqdm, value: int) -> None:
if not bar.disable:
total = bar.total
leftover = current % refresh_rate
advance = leftover if (current == total and leftover != 0) else refresh_rate
bar.update(advance)
bar.n = value
bar.refresh()

0 comments on commit f576ed3

Please sign in to comment.