diff --git a/pytorch_lightning/trainer/evaluation_loop_mixin.py b/pytorch_lightning/trainer/evaluation_loop_mixin.py
index ad4b6ad818f0f..c2ba5d16fa97a 100644
--- a/pytorch_lightning/trainer/evaluation_loop_mixin.py
+++ b/pytorch_lightning/trainer/evaluation_loop_mixin.py
@@ -1,4 +1,5 @@
 import torch
+import tqdm
 
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
@@ -52,8 +53,11 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
                 dl_outputs.append(output)
 
                 # batch done
-                if self.show_progress_bar:
-                    self.progress_bar.update(1)
+                if test:
+                    self.test_progress_bar.update(1)
+                else:
+                    self.val_progress_bar.update(1)
+                    self.main_progress_bar.update(1)
 
             outputs.append(dl_outputs)
 
         eval_results = {}
@@ -110,6 +114,15 @@ def run_evaluation(self, test=False):
         if self.fast_dev_run:
             max_batches = 1
 
+        # init validation or test progress bar
+        # main progress bar will already be closed when testing so initial position is free
+        position = 2 * self.process_position + (not test)
+        desc = 'Testing' if test else 'Validating'
+        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True,
+                         unit='batch')
+        setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
+
         # run evaluation
         eval_results = self.evaluate(self.model,
                                      dataloaders,
@@ -130,10 +143,16 @@ def run_evaluation(self, test=False):
         # hook
         model.on_post_performance_check()
 
-        if self.show_progress_bar:
-            # add model specific metrics
-            tqdm_metrics = self.training_tqdm_dict
-            self.progress_bar.set_postfix(**tqdm_metrics)
+        # add model specific metrics
+        tqdm_metrics = self.training_tqdm_dict
+        if not test:
+            self.main_progress_bar.set_postfix(**tqdm_metrics)
+
+        # close progress bar
+        if test:
+            self.test_progress_bar.close()
+        else:
+            self.val_progress_bar.close()
 
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py
index a10b8fc1eb49e..d41fc7c16fee6 100644
--- a/pytorch_lightning/trainer/train_loop_mixin.py
+++ b/pytorch_lightning/trainer/train_loop_mixin.py
@@ -1,4 +1,5 @@
 import numpy as np
+import tqdm
 
 try:
     from apex import amp
@@ -34,18 +35,21 @@ def train(self):
                                   self.nb_val_batches * val_checks_per_epoch)
             self.batch_loss_value = 0  # accumulated grads
 
-            # limit the number of batches to 1 in fast_dev_run
             if self.fast_dev_run:
-                self.total_batches = 1
-
-            # init progress_bar when requested
-            if self.show_progress_bar:
+                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+                nb_iterations = 2
+            elif self.is_iterable_train_dataloader:
+                # for iterable train loader, the progress bar never ends
+                nb_iterations = None
+            else:
                 nb_iterations = self.total_batches
 
-                # for iterable train loader, the progress bar never ends
-                if self.is_iterable_train_dataloader:
-                    nb_iterations = float('inf')
-                self.progress_bar.reset(nb_iterations)
+            # reset progress bar
+            # .reset() doesn't work on disabled progress bar so we should check
+            if not self.main_progress_bar.disable:
+                self.main_progress_bar.reset(nb_iterations)
+            desc = f'Epoch {epoch_nb + 1}' if not self.is_iterable_train_dataloader else ''
+            self.main_progress_bar.set_description(desc)
 
             # changing gradient according accumulation_scheduler
             self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
@@ -68,8 +72,11 @@ def train(self):
                 # stop training
                 stop = should_stop and met_min_epochs
                 if stop:
+                    self.main_progress_bar.close()
                     return
 
+        self.main_progress_bar.close()
+
         if self.logger is not None:
             self.logger.finalize("success")
 
@@ -158,9 +165,6 @@ def run_training_batch(self, batch, batch_nb):
             if response == -1:
                 return -1, grad_norm_dic
 
-        if self.show_progress_bar:
-            self.progress_bar.update(1)
-
         splits = [batch]
         if self.truncated_bptt_steps is not None:
             model_ref = self.get_model()
@@ -241,17 +245,15 @@ def optimizer_closure():
                     self.batch_loss_value = 0
                     self.avg_loss = np.mean(self.running_loss[-100:])
 
-                    # update progress bar
-                    if self.show_progress_bar:
-                        # add model specific metrics
-                        tqdm_metrics = self.training_tqdm_dict
-                        self.progress_bar.set_postfix(**tqdm_metrics)
-
         # activate batch end hook
         if self.is_function_implemented('on_batch_end'):
             model = self.get_model()
             model.on_batch_end()
 
+        # update progress bar
+        self.main_progress_bar.update(1)
+        self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
+
         # collapse all metrics into one dict
         all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a61f1abb00121..c53aeb7bdc464 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -296,7 +296,6 @@ def training_tqdm_dict(self):
         """
         tqdm_dict = {
             'loss': '{0:.3f}'.format(self.avg_loss),
-            'epoch': '{}'.format(self.current_epoch),
             'batch_nb': '{}'.format(self.batch_nb),
         }
 
@@ -432,15 +431,8 @@ def run_pretrain_routine(self, model):
         # restore training and model before hpc call
         self.restore_weights(model)
 
-        # progress bar init
-        if self.show_progress_bar:
-            self.progress_bar = tqdm.tqdm(0, position=self.process_position)
-
         # when testing requested only run test and return
         if self.testing:
-            if self.show_progress_bar:
-                self.progress_bar.reset(self.nb_test_batches)
-
             self.run_evaluation(test=True)
             return
 
@@ -448,12 +440,25 @@ def run_pretrain_routine(self, model):
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
         if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
-            # reset progress_bar limit for sanity check
-            if self.show_progress_bar:
-                self.progress_bar.reset(self.nb_sanity_val_steps)
+            # init progress bars for validation sanity check
+            pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
+                             leave=False, position=2 * self.process_position,
+                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+            self.main_progress_bar = pbar
+            # dummy validation progress bar
+            self.val_progress_bar = tqdm.tqdm(disable=True)
 
             self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)
 
+            # close progress bars
+            self.main_progress_bar.close()
+            self.val_progress_bar.close()
+
+        # init progress bar
+        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+        self.main_progress_bar = pbar
+
         # clear cache before training
         if self.on_gpu:
             torch.cuda.empty_cache()
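
Not part of the patch: the snippet below is a minimal, standalone sketch of the two-row tqdm layout the patch sets up, with a persistent main bar (leave=True) and a transient validation bar directly beneath it (leave=False), both silenced through the disable flag. The names show_progress_bar and process_position only mirror the Trainer attributes for illustration.

import time

import tqdm

# hypothetical stand-ins for the Trainer attributes of the same name
show_progress_bar = True
process_position = 0

# persistent main bar: stays on screen after the run (leave=True)
main_bar = tqdm.tqdm(desc='Epoch 1', total=8, leave=True,
                     position=2 * process_position,
                     disable=not show_progress_bar, dynamic_ncols=True, unit='batch')

# "training" batches advance only the main bar
for batch_nb in range(6):
    time.sleep(0.05)
    main_bar.update(1)
    main_bar.set_postfix(loss='{0:.3f}'.format(1.0 / (batch_nb + 1)))

# transient validation bar one row below; it disappears once closed (leave=False)
val_bar = tqdm.tqdm(desc='Validating', total=2, leave=False,
                    position=2 * process_position + 1,
                    disable=not show_progress_bar, dynamic_ncols=True, unit='batch')
for _ in range(2):
    time.sleep(0.05)
    val_bar.update(1)
    main_bar.update(1)  # validation batches also count toward the main bar's total

val_bar.close()
main_bar.close()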