From 2a1f7d56446523cdbd18314c40e44acfed8aef95 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 5 Apr 2022 16:37:21 +0530 Subject: [PATCH 1/3] Fix rich main progress bar update --- pytorch_lightning/callbacks/progress/base.py | 5 ++++ .../callbacks/progress/rich_progress.py | 23 ++++++++++--------- .../callbacks/progress/tqdm_progress.py | 5 ---- tests/callbacks/test_rich_progress_bar.py | 2 +- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 234d62a68c308..42ad957235f2c 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -77,6 +77,11 @@ def test_description(self) -> str: def predict_description(self) -> str: return "Predicting" + @property + def _val_processed(self) -> int: + # use total in case validation runs more than once per training epoch + return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed + @property def train_batch_idx(self) -> int: """The number of batches processed during training. diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 4d7c4b7864055..0b0a19c6d7acd 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -368,8 +368,12 @@ def _add_task(self, total_batches: int, description: str, visible: bool = True) f"[{self.theme.description}]{description}", total=total_batches, visible=visible ) - def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None: - if self.progress is not None and self._should_update(current, total): + def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> None: + if self.progress is not None and self.is_enabled: + total = self.progress.tasks[progress_bar_id].total + if not self._should_update(current, total): + return + leftover = current % self.refresh_rate advance = leftover if (current == total and leftover != 0) else self.refresh_rate self.progress.update(progress_bar_id, advance=advance, visible=visible) @@ -419,7 +423,7 @@ def on_predict_batch_start( self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches) + self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed) self._update_metrics(trainer, pl_module) self.refresh() @@ -428,23 +432,20 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): if trainer.sanity_checking: - self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader) + self._update(self.val_sanity_progress_bar_id, self.val_batch_idx) elif self.val_progress_bar_id is not None: # check to see if we should update the main training progress bar if self.main_progress_bar_id is not None: - # TODO: Use total val_processed here just like TQDM in a follow-up - self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader) - self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader) + self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed) + self._update(self.val_progress_bar_id, self.val_batch_idx) self.refresh() def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader) + self._update(self.test_progress_bar_id, self.test_batch_idx) self.refresh() def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._update( - self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader - ) + self._update(self.predict_progress_bar_id, self.predict_batch_idx) self.refresh() def _get_train_description(self, current_epoch: int) -> str: diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 5d5b17bfb9922..dcca487620584 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -168,11 +168,6 @@ def is_enabled(self) -> bool: def is_disabled(self) -> bool: return not self.is_enabled - @property - def _val_processed(self) -> int: - # use total in case validation runs more than once per training epoch - return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed - def disable(self) -> None: self._enabled = False diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 29ef3aa98f89b..b0dddc385539d 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -205,7 +205,7 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir): @RunIf(rich=True) -@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)])) +@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 4 + 3), (4, 3 + 3), (7, 2 + 2)])) def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count): model = BoringModel() trainer = Trainer( From 55ca9bf31b5312b0fa56145765e62862291360aa Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 5 Apr 2022 18:03:03 +0530 Subject: [PATCH 2/3] pick tests cases from tqdm tests --- tests/callbacks/test_rich_progress_bar.py | 40 ++++++++++++++++------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index b0dddc385539d..8fdcb6c99e331 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -205,14 +205,28 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir): @RunIf(rich=True) -@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 4 + 3), (4, 3 + 3), (7, 2 + 2)])) -def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count): +@pytest.mark.parametrize( + "refresh_rate,train_batches,val_batches,expected_call_count", + [ + (3, 6, 6, 4 + 3), + (4, 6, 6, 3 + 3), + (7, 6, 6, 2 + 2), + (1, 2, 3, 5 + 4), + (1, 0, 0, 0 + 0), + (3, 1, 0, 1 + 0), + (3, 1, 1, 1 + 2), + (3, 5, 0, 2 + 0), + (3, 5, 2, 3 + 2), + (6, 5, 2, 2 + 2), + ], +) +def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count): model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, - limit_train_batches=6, - limit_val_batches=6, + limit_train_batches=train_batches, + limit_val_batches=val_batches, max_epochs=1, callbacks=RichProgressBar(refresh_rate=refresh_rate), ) @@ -224,14 +238,16 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call trainer.fit(model) assert progress_update.call_count == expected_call_count - fit_main_bar = trainer.progress_bar_callback.progress.tasks[0] - fit_val_bar = trainer.progress_bar_callback.progress.tasks[1] - assert fit_main_bar.completed == 12 - assert fit_main_bar.total == 12 - assert fit_main_bar.visible - assert fit_val_bar.completed == 6 - assert fit_val_bar.total == 6 - assert not fit_val_bar.visible + if train_batches > 0: + fit_main_bar = trainer.progress_bar_callback.progress.tasks[0] + assert fit_main_bar.completed == train_batches + val_batches + assert fit_main_bar.total == train_batches + val_batches + assert fit_main_bar.visible + if val_batches > 0: + fit_val_bar = trainer.progress_bar_callback.progress.tasks[1] + assert fit_val_bar.completed == val_batches + assert fit_val_bar.total == val_batches + assert not fit_val_bar.visible @RunIf(rich=True) From 0a9129cf4581d00ed0d99ed275f3b2f66122f337 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 5 Apr 2022 22:34:24 +0530 Subject: [PATCH 3/3] remove check --- pytorch_lightning/callbacks/progress/rich_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 0b0a19c6d7acd..cc475634ff6ea 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -380,7 +380,7 @@ def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> N self.refresh() def _should_update(self, current: int, total: Union[int, float]) -> bool: - return self.is_enabled and (current % self.refresh_rate == 0 or current == total) + return current % self.refresh_rate == 0 or current == total def on_validation_epoch_end(self, trainer, pl_module): if self.val_progress_bar_id is not None and trainer.state.fn == "fit":