From 81b4e3a7353e3c301eaff1b1deaeb0f2197b7419 Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Thu, 31 Oct 2019 12:39:27 +0300
Subject: [PATCH 1/6] Split progress bars

---
 .../trainer/evaluation_loop_mixin.py          | 35 ++++++++++++++---
 pytorch_lightning/trainer/train_loop_mixin.py | 39 ++++++++++---------
 pytorch_lightning/trainer/trainer.py          | 27 +++++++------
 3 files changed, 65 insertions(+), 36 deletions(-)

diff --git a/pytorch_lightning/trainer/evaluation_loop_mixin.py b/pytorch_lightning/trainer/evaluation_loop_mixin.py
index c3cc51efe1a97..30d5edbbedcee 100644
--- a/pytorch_lightning/trainer/evaluation_loop_mixin.py
+++ b/pytorch_lightning/trainer/evaluation_loop_mixin.py
@@ -1,4 +1,5 @@
 import torch
+import tqdm
 
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
@@ -52,8 +53,11 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
                 dl_outputs.append(output)
 
                 # batch done
-                if self.show_progress_bar:
-                    self.progress_bar.update(1)
+                if test:
+                    self.test_progress_bar.update(1)
+                else:
+                    self.val_progress_bar.update(1)
+                    self.main_progress_bar.update(1)
             outputs.append(dl_outputs)
 
         eval_results = {}
@@ -110,6 +114,19 @@ def run_evaluation(self, test=False):
         if self.fast_dev_run:
             max_batches = 1
 
+        # init validation or test progress bar
+        if test:
+            # main progress bar will already be closed when testing so initial position is free
+            pbar = tqdm.tqdm(desc='Testing', total=max_batches,
+                             leave=True, position=2 * self.process_position,
+                             disable=not self.show_progress_bar)
+            self.test_progress_bar = pbar
+        else:
+            pbar = tqdm.tqdm(desc='Validating', total=max_batches,
+                             leave=False, position=2 * self.process_position + 1,
+                             disable=not self.show_progress_bar)
+            self.val_progress_bar = pbar
+
         # run evaluation
         eval_results = self.evaluate(self.model,
                                      dataloaders,
@@ -129,10 +146,16 @@ def run_evaluation(self, test=False):
         # hook
         model.on_post_performance_check()
 
-        if self.show_progress_bar:
-            # add model specific metrics
-            tqdm_metrics = self.training_tqdm_dict
-            self.progress_bar.set_postfix(**tqdm_metrics)
+        # add model specific metrics
+        tqdm_metrics = self.training_tqdm_dict
+        if not test:
+            self.main_progress_bar.set_postfix(**tqdm_metrics)
+
+        # close progress bar
+        if test:
+            self.test_progress_bar.close()
+        else:
+            self.val_progress_bar.close()
 
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py
index 2683ae66094eb..3c9a2e30c98d6 100644
--- a/pytorch_lightning/trainer/train_loop_mixin.py
+++ b/pytorch_lightning/trainer/train_loop_mixin.py
@@ -1,4 +1,5 @@
 import numpy as np
+import tqdm
 
 try:
     from apex import amp
@@ -34,18 +35,20 @@ def train(self):
                                   self.nb_val_batches * val_checks_per_epoch)
             self.batch_loss_value = 0  # accumulated grads
 
-            # limit the number of batches to 1 in fast_dev_run
-            if self.fast_dev_run:
-                self.total_batches = 1
+            nb_iterations = self.total_batches
+
+            # for iterable train loader, the progress bar never ends
+            if self.is_iterable_train_dataloader:
+                nb_iterations = float('inf')
 
-            # init progress_bar when requested
-            if self.show_progress_bar:
-                nb_iterations = self.total_batches
+            # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+            if self.fast_dev_run:
+                nb_iterations = 2
 
-                # for iterable train loader, the progress bar never ends
-                if self.is_iterable_train_dataloader:
-                    nb_iterations = float('inf')
-                self.progress_bar.reset(nb_iterations)
+            # reset progress bar
+            self.main_progress_bar.reset(nb_iterations)
+            desc = f'Epoch {epoch_nb}' if not self.is_iterable_train_dataloader else ''
+            self.main_progress_bar.set_description(desc)
 
             # changing gradient according accumulation_scheduler
             self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
@@ -68,8 +71,11 @@ def train(self):
             # stop training
             stop = should_stop and met_min_epochs
             if stop:
+                self.main_progress_bar.close()
                 return
 
+        self.main_progress_bar.close()
+
         if self.logger is not None:
             self.logger.finalize("success")
@@ -158,9 +164,6 @@ def run_training_batch(self, batch, batch_nb):
             if response == -1:
                 return -1, grad_norm_dic
 
-        if self.show_progress_bar:
-            self.progress_bar.update(1)
-
         # call training_step once per optimizer
         for opt_idx, optimizer in enumerate(self.optimizers):
 
@@ -226,17 +229,15 @@ def optimizer_closure():
                 self.batch_loss_value = 0
                 self.avg_loss = np.mean(self.running_loss[-100:])
 
-                # update progress bar
-                if self.show_progress_bar:
-                    # add model specific metrics
-                    tqdm_metrics = self.training_tqdm_dict
-                    self.progress_bar.set_postfix(**tqdm_metrics)
-
         # activate batch end hook
         if self.is_function_implemented('on_batch_end'):
             model = self.get_model()
             model.on_batch_end()
 
+        # update progress bar
+        self.main_progress_bar.update(1)
+        self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
+
         # collapse all metrics into one dict
         all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a5beecc4b2b9f..f1bbfc8000cbc 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -293,7 +293,6 @@ def training_tqdm_dict(self):
         """
         tqdm_dict = {
             'loss': '{0:.3f}'.format(self.avg_loss),
-            'epoch': '{}'.format(self.current_epoch),
             'batch_nb': '{}'.format(self.batch_nb),
         }
 
@@ -426,15 +425,8 @@ def run_pretrain_routine(self, model):
         # restore training and model before hpc call
         self.restore_weights(model)
 
-        # progress bar init
-        if self.show_progress_bar:
-            self.progress_bar = tqdm.tqdm(0, position=self.process_position)
-
         # when testing requested only run test and return
         if self.testing:
-            if self.show_progress_bar:
-                self.progress_bar.reset(self.nb_test_batches)
-
             self.run_evaluation(test=True)
             return
 
@@ -442,12 +434,25 @@ def run_pretrain_routine(self, model):
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
         if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
-            # reset progress_bar limit for sanity check
-            if self.show_progress_bar:
-                self.progress_bar.reset(self.nb_sanity_val_steps)
+            # init progress bars for validation sanity check
+            pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
+                             leave=False, position=2 * self.process_position,
+                             disable=not self.show_progress_bar)
+            self.main_progress_bar = pbar
+            # dummy validation progress bar
+            self.val_progress_bar = tqdm.tqdm(disable=True)
 
             self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)
 
+            # close progress bars
+            self.main_progress_bar.close()
+            self.val_progress_bar.close()
+
+        # init progress bar
+        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
+                         disable=not self.show_progress_bar)
+        self.main_progress_bar = pbar
+
         # clear cache before training
         if self.on_gpu:
             torch.cuda.empty_cache()
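Aside for readers (an illustration, not part of the patch series): the split above relies on tqdm's `position` argument, which pins each bar to its own terminal row so a transient validation bar can draw beneath the persistent main bar. The sketch below reproduces that arrangement standalone; the batch counts and sleeps are invented for illustration.

    import time
    import tqdm

    # main bar spans train + val batches, as in the patch
    main_bar = tqdm.tqdm(desc='Epoch 1', total=12, position=0, leave=True)
    for batch_nb in range(8):
        time.sleep(0.05)  # stand-in for a training step
        main_bar.update(1)

        if batch_nb == 3:  # run "validation" mid-epoch on a nested, transient bar
            val_bar = tqdm.tqdm(desc='Validating', total=4, position=1, leave=False)
            for _ in range(4):
                time.sleep(0.05)  # stand-in for a validation step
                val_bar.update(1)
                main_bar.update(1)  # val batches also advance the main bar
            val_bar.close()
    main_bar.close()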
From 751b776590d9d42907580dedf4f7061c9af994e5 Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Thu, 31 Oct 2019 12:45:56 +0300
Subject: [PATCH 2/6] Iterable dataset total batches fix

---
 pytorch_lightning/trainer/train_loop_mixin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py
index 3c9a2e30c98d6..90ed1f184f753 100644
--- a/pytorch_lightning/trainer/train_loop_mixin.py
+++ b/pytorch_lightning/trainer/train_loop_mixin.py
@@ -39,7 +39,7 @@ def train(self):
 
             # for iterable train loader, the progress bar never ends
             if self.is_iterable_train_dataloader:
-                nb_iterations = float('inf')
+                nb_iterations = None
 
             # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
             if self.fast_dev_run:
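A note on the one-line fix above (a sketch, assuming any reasonably recent tqdm): `total=None` is tqdm's documented way to mark an unknown length, under which the bar degrades to a plain counter with elapsed time and rate instead of a percentage, which is exactly what an unbounded iterable dataloader needs.

    import tqdm

    # with an unknown total, tqdm shows a counter instead of a percentage,
    # e.g. "Epoch 1: 42it [00:03, 13.9it/s]"
    bar = tqdm.tqdm(desc='Epoch 1', total=None)
    for _ in range(42):
        bar.update(1)
    bar.close()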
From bacafe6dbec6e84389c245dba990edaa236c1624 Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Fri, 1 Nov 2019 13:07:05 +0300
Subject: [PATCH 3/6] Use dynamic ncols and use batch as units

---
 pytorch_lightning/trainer/evaluation_loop_mixin.py | 6 ++++--
 pytorch_lightning/trainer/trainer.py               | 4 ++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/evaluation_loop_mixin.py b/pytorch_lightning/trainer/evaluation_loop_mixin.py
index 30d5edbbedcee..f9bbe5d96cfb0 100644
--- a/pytorch_lightning/trainer/evaluation_loop_mixin.py
+++ b/pytorch_lightning/trainer/evaluation_loop_mixin.py
@@ -119,12 +119,14 @@ def run_evaluation(self, test=False):
             # main progress bar will already be closed when testing so initial position is free
             pbar = tqdm.tqdm(desc='Testing', total=max_batches,
                              leave=True, position=2 * self.process_position,
-                             disable=not self.show_progress_bar)
+                             disable=not self.show_progress_bar, dynamic_ncols=True,
+                             unit='batch')
             self.test_progress_bar = pbar
         else:
             pbar = tqdm.tqdm(desc='Validating', total=max_batches,
                              leave=False, position=2 * self.process_position + 1,
-                             disable=not self.show_progress_bar)
+                             disable=not self.show_progress_bar, dynamic_ncols=True,
+                             unit='batch')
             self.val_progress_bar = pbar
 
         # run evaluation
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f1bbfc8000cbc..88f0c01ef8bf4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -437,7 +437,7 @@ def run_pretrain_routine(self, model):
             # init progress bars for validation sanity check
             pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
                              leave=False, position=2 * self.process_position,
-                             disable=not self.show_progress_bar)
+                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
             self.main_progress_bar = pbar
             # dummy validation progress bar
             self.val_progress_bar = tqdm.tqdm(disable=True)
@@ -450,7 +450,7 @@ def run_pretrain_routine(self, model):
 
         # init progress bar
         pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
-                         disable=not self.show_progress_bar)
+                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
         self.main_progress_bar = pbar
 
         # clear cache before training
From 8a6b5835ae747686ba9f85c2995760d0cb916a20 Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Fri, 1 Nov 2019 15:09:06 +0300
Subject: [PATCH 4/6] Count epochs from 1 in progress bar

---
 pytorch_lightning/trainer/train_loop_mixin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py
index 90ed1f184f753..70b69af8426fe 100644
--- a/pytorch_lightning/trainer/train_loop_mixin.py
+++ b/pytorch_lightning/trainer/train_loop_mixin.py
@@ -47,7 +47,7 @@ def train(self):
 
             # reset progress bar
             self.main_progress_bar.reset(nb_iterations)
-            desc = f'Epoch {epoch_nb}' if not self.is_iterable_train_dataloader else ''
+            desc = f'Epoch {epoch_nb + 1}' if not self.is_iterable_train_dataloader else ''
             self.main_progress_bar.set_description(desc)
 
             # changing gradient according accumulation_scheduler
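The two tqdm keywords added in patch 3 are worth a standalone look (a sketch only, not part of the series): `dynamic_ncols=True` re-measures the terminal width on each refresh so the bar survives window resizes, and `unit='batch'` replaces the default 'it' in the rate display. Patch 4 is purely cosmetic on top of this: the internal `epoch_nb` stays zero-based while the description shows `epoch_nb + 1`.

    import tqdm

    # rate reads "batch/s" instead of "it/s"; the bar re-fits the terminal width
    bar = tqdm.tqdm(desc='Validating', total=100, dynamic_ncols=True, unit='batch')
    for _ in range(100):
        bar.update(1)
    bar.close()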
From 12b7ada2c0c606666814b9f8af1d1a42ccfa687b Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Fri, 1 Nov 2019 21:35:41 +0300
Subject: [PATCH 5/6] Fix for disabled progress bar

---
 pytorch_lightning/trainer/train_loop_mixin.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py
index 70b69af8426fe..43025790c37a7 100644
--- a/pytorch_lightning/trainer/train_loop_mixin.py
+++ b/pytorch_lightning/trainer/train_loop_mixin.py
@@ -46,7 +46,9 @@ def train(self):
                 nb_iterations = 2
 
             # reset progress bar
-            self.main_progress_bar.reset(nb_iterations)
+            # .reset() doesn't work on disabled progress bar so we should check
+            if not self.main_progress_bar.disable:
+                self.main_progress_bar.reset(nb_iterations)
             desc = f'Epoch {epoch_nb + 1}' if not self.is_iterable_train_dataloader else ''
             self.main_progress_bar.set_description(desc)
 
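The guard added above can be exercised in isolation; the sketch below assumes nothing beyond tqdm itself. A bar constructed with `disable=True` skips most of its setup (which is why the patch comment says `.reset()` doesn't work on it), so state-changing calls are wrapped in a check of the public `disable` attribute:

    import tqdm

    bar = tqdm.tqdm(total=10, disable=True)
    # only call .reset() on a live bar; a disabled one never initialised its state
    if not bar.disable:
        bar.reset(total=20)
    bar.close()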
From 139a03bc548c2da8418c0bc53a7b5f3cd14bdcbe Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Sat, 2 Nov 2019 17:26:33 +0300
Subject: [PATCH 6/6] Code simplifications

---
 .../trainer/evaluation_loop_mixin.py          | 20 +++++++------------
 pytorch_lightning/trainer/train_loop_mixin.py | 13 ++++++------
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/pytorch_lightning/trainer/evaluation_loop_mixin.py b/pytorch_lightning/trainer/evaluation_loop_mixin.py
index 730770a5949cb..c2ba5d16fa97a 100644
--- a/pytorch_lightning/trainer/evaluation_loop_mixin.py
+++ b/pytorch_lightning/trainer/evaluation_loop_mixin.py
@@ -115,19 +115,13 @@ def run_evaluation(self, test=False):
             max_batches = 1
 
         # init validation or test progress bar
-        if test:
-            # main progress bar will already be closed when testing so initial position is free
-            pbar = tqdm.tqdm(desc='Testing', total=max_batches,
-                             leave=True, position=2 * self.process_position,
-                             disable=not self.show_progress_bar, dynamic_ncols=True,
-                             unit='batch')
-            self.test_progress_bar = pbar
-        else:
-            pbar = tqdm.tqdm(desc='Validating', total=max_batches,
-                             leave=False, position=2 * self.process_position + 1,
-                             disable=not self.show_progress_bar, dynamic_ncols=True,
-                             unit='batch')
-            self.val_progress_bar = pbar
+        # main progress bar will already be closed when testing so initial position is free
+        position = 2 * self.process_position + (not test)
+        desc = 'Testing' if test else 'Validating'
+        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True,
+                         unit='batch')
+        setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
 
         # run evaluation
         eval_results = self.evaluate(self.model,
diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py
index 11eb15f069501..d41fc7c16fee6 100644
--- a/pytorch_lightning/trainer/train_loop_mixin.py
+++ b/pytorch_lightning/trainer/train_loop_mixin.py
@@ -35,15 +35,14 @@ def train(self):
                                   self.nb_val_batches * val_checks_per_epoch)
             self.batch_loss_value = 0  # accumulated grads
 
-            nb_iterations = self.total_batches
-
-            # for iterable train loader, the progress bar never ends
-            if self.is_iterable_train_dataloader:
-                nb_iterations = None
-
-            # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
             if self.fast_dev_run:
+                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                 nb_iterations = 2
+            elif self.is_iterable_train_dataloader:
+                # for iterable train loader, the progress bar never ends
+                nb_iterations = None
+            else:
+                nb_iterations = self.total_batches
 
             # reset progress bar
             # .reset() doesn't work on disabled progress bar so we should check
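A worked illustration of the simplified position arithmetic above (assumption: a single-process run, so `process_position` is 0). Python booleans are integers, so `(not test)` contributes one extra row for validation, placing it under the live main bar, and zero for testing, where the main bar has already been closed and its row is free:

    for test in (False, True):
        process_position = 0  # assumed single-process run
        position = 2 * process_position + (not test)  # bool arithmetic: True == 1
        print('Testing' if test else 'Validating', '-> row', position)
    # prints: Validating -> row 1
    #         Testing -> row 0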