31 changes: 25 additions & 6 deletions pytorch_lightning/trainer/evaluation_loop_mixin.py
@@ -1,4 +1,5 @@
 import torch
+import tqdm

 from pytorch_lightning.utilities.debugging import MisconfigurationException

@@ -52,8 +53,11 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
                 dl_outputs.append(output)

                 # batch done
-                if self.show_progress_bar:
-                    self.progress_bar.update(1)
+                if test:
+                    self.test_progress_bar.update(1)
+                else:
+                    self.val_progress_bar.update(1)
+                self.main_progress_bar.update(1)
             outputs.append(dl_outputs)

         eval_results = {}
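Note: the update pattern above keeps two bars in sync, so each evaluation batch advances both the dedicated val/test bar and the main epoch bar. A minimal standalone sketch of that pattern in plain tqdm (the totals, sleep, and bar names here are illustrative, not Lightning code):

```python
import time
import tqdm

# main bar counts every batch in the epoch; the eval bar counts only eval batches
main_bar = tqdm.tqdm(desc='Epoch 1', total=20, position=0, unit='batch')
val_bar = tqdm.tqdm(desc='Validating', total=10, position=1, leave=False, unit='batch')

for _ in range(10):
    time.sleep(0.01)    # stand-in for evaluating one batch
    val_bar.update(1)   # advance the evaluation bar
    main_bar.update(1)  # the epoch bar advances on eval batches too

val_bar.close()
main_bar.close()
```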
@@ -110,6 +114,15 @@ def run_evaluation(self, test=False):
         if self.fast_dev_run:
             max_batches = 1

+        # init validation or test progress bar
+        # main progress bar will already be closed when testing so initial position is free
+        position = 2 * self.process_position + (not test)
+        desc = 'Testing' if test else 'Validating'
+        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True,
+                         unit='batch')
+        setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
+
         # run evaluation
         eval_results = self.evaluate(self.model,
                                      dataloaders,
@@ -130,10 +143,16 @@ def run_evaluation(self, test=False):
         # hook
         model.on_post_performance_check()

-        if self.show_progress_bar:
-            # add model specific metrics
-            tqdm_metrics = self.training_tqdm_dict
-            self.progress_bar.set_postfix(**tqdm_metrics)
+        # add model specific metrics
+        tqdm_metrics = self.training_tqdm_dict
+        if not test:
+            self.main_progress_bar.set_postfix(**tqdm_metrics)
+
+        # close progress bar
+        if test:
+            self.test_progress_bar.close()
+        else:
+            self.val_progress_bar.close()

         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
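A note on the `position` arithmetic earlier in this file: tqdm's `position` argument pins a bar to a fixed terminal row, so each process reserves a pair of rows, `2 * process_position` for its main bar and `2 * process_position + 1` for its validation bar. `(not test)` contributes 1 while validating and 0 while testing; since the main bar is already closed by test time, the test bar can reuse the main bar's row. A small self-contained check of that arithmetic (the helper name is hypothetical):

```python
def eval_bar_position(process_position: int, test: bool) -> int:
    # row 2*p belongs to the process's main bar; row 2*p + 1 to its val bar.
    # during testing the main bar is closed, so the test bar takes row 2*p.
    return 2 * process_position + (not test)

assert eval_bar_position(process_position=0, test=False) == 1  # below the main bar
assert eval_bar_position(process_position=0, test=True) == 0   # reuses the main bar's row
assert eval_bar_position(process_position=1, test=False) == 3  # next process, next row pair
```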
38 changes: 20 additions & 18 deletions pytorch_lightning/trainer/train_loop_mixin.py
@@ -1,4 +1,5 @@
 import numpy as np
+import tqdm

 try:
     from apex import amp
@@ -34,18 +35,21 @@ def train(self):
                                   self.nb_val_batches * val_checks_per_epoch)
             self.batch_loss_value = 0  # accumulated grads

-            # limit the number of batches to 1 in fast_dev_run
             if self.fast_dev_run:
-                self.total_batches = 1
-
-            # init progress_bar when requested
-            if self.show_progress_bar:
+                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+                nb_iterations = 2
+            elif self.is_iterable_train_dataloader:
+                # for iterable train loader, the progress bar never ends
+                nb_iterations = None
+            else:
+                nb_iterations = self.total_batches

-                # for iterable train loader, the progress bar never ends
-                if self.is_iterable_train_dataloader:
-                    nb_iterations = float('inf')
-                self.progress_bar.reset(nb_iterations)
+            # reset progress bar
+            # .reset() doesn't work on disabled progress bar so we should check
+            if not self.main_progress_bar.disable:
+                self.main_progress_bar.reset(nb_iterations)
+            desc = f'Epoch {epoch_nb + 1}' if not self.is_iterable_train_dataloader else ''
+            self.main_progress_bar.set_description(desc)

             # changing gradient according accumulation_scheduler
             self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
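Two details in this hunk are worth spelling out. Passing `total=None` gives tqdm an open-ended bar that shows only a count and rate, which fits an iterable dataloader of unknown length better than the old `float('inf')`. And because the single main bar is now reused across epochs, each epoch calls `reset()` and `set_description()` instead of constructing a new bar; the `disable` guard reflects the diff's own comment that `reset()` misbehaves on a disabled bar. A standalone sketch of that reuse pattern (epoch and batch counts are made up):

```python
import tqdm

num_epochs, num_batches = 3, 100  # illustrative sizes
pbar = tqdm.tqdm(dynamic_ncols=True, unit='batch')

for epoch_nb in range(num_epochs):
    if not pbar.disable:       # guard mirrors the diff's comment about reset()
        pbar.reset(num_batches)
    pbar.set_description(f'Epoch {epoch_nb + 1}')
    for _ in range(num_batches):
        pbar.update(1)

pbar.close()
```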
@@ -68,8 +72,11 @@ def train(self):
                 # stop training
                 stop = should_stop and met_min_epochs
                 if stop:
+                    self.main_progress_bar.close()
                     return

+        self.main_progress_bar.close()
+
         if self.logger is not None:
             self.logger.finalize("success")
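The invariant behind the two `close()` calls: every path out of `train()` must close the main bar, because an open bar keeps a claim on its terminal row and can interleave with later output. A hedged sketch of the same invariant expressed with `try`/`finally` (not how the diff does it, just an equivalent guard):

```python
import tqdm

def train_sketch(num_epochs: int, num_batches: int) -> None:
    pbar = tqdm.tqdm(total=num_epochs * num_batches, unit='batch')
    try:
        for epoch_nb in range(num_epochs):
            for _ in range(num_batches):
                pbar.update(1)
            if epoch_nb == 1:  # stand-in for an early-stopping decision
                return         # finally still runs, so the bar is closed
    finally:
        pbar.close()
```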

@@ -158,9 +165,6 @@ def run_training_batch(self, batch, batch_nb):
             if response == -1:
                 return -1, grad_norm_dic

-        if self.show_progress_bar:
-            self.progress_bar.update(1)
-
         splits = [batch]
         if self.truncated_bptt_steps is not None:
             model_ref = self.get_model()
@@ -241,17 +245,15 @@ def optimizer_closure():
                     self.batch_loss_value = 0
                     self.avg_loss = np.mean(self.running_loss[-100:])

-                    # update progress bar
-                    if self.show_progress_bar:
-                        # add model specific metrics
-                        tqdm_metrics = self.training_tqdm_dict
-                        self.progress_bar.set_postfix(**tqdm_metrics)
-
                     # activate batch end hook
                     if self.is_function_implemented('on_batch_end'):
                         model = self.get_model()
                         model.on_batch_end()

+        # update progress bar
+        self.main_progress_bar.update(1)
+        self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
+
         # collapse all metrics into one dict
         all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
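The bar maintenance now runs unconditionally at the end of `run_training_batch`: on a bar constructed with `disable=True`, `update()` and `set_postfix()` return immediately, so the old `show_progress_bar` guard is redundant. A minimal sketch of the per-batch pattern with made-up metrics:

```python
import random
import tqdm

pbar = tqdm.tqdm(total=50, unit='batch', dynamic_ncols=True)

for batch_nb in range(50):
    loss = random.random()  # stand-in for the real training step
    pbar.update(1)
    pbar.set_postfix(loss=f'{loss:.3f}', batch_nb=batch_nb)  # metrics shown after the bar

pbar.close()
```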

27 changes: 16 additions & 11 deletions pytorch_lightning/trainer/trainer.py
@@ -296,7 +296,6 @@ def training_tqdm_dict(self):
"""
tqdm_dict = {
'loss': '{0:.3f}'.format(self.avg_loss),
'epoch': '{}'.format(self.current_epoch),
'batch_nb': '{}'.format(self.batch_nb),
}

@@ -432,28 +431,34 @@ def run_pretrain_routine(self, model):
         # restore training and model before hpc call
         self.restore_weights(model)

-        # progress bar init
-        if self.show_progress_bar:
-            self.progress_bar = tqdm.tqdm(0, position=self.process_position)
-
         # when testing requested only run test and return
         if self.testing:
-            if self.show_progress_bar:
-                self.progress_bar.reset(self.nb_test_batches)
-
             self.run_evaluation(test=True)
             return

         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
         if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
-            # reset progress_bar limit for sanity check
-            if self.show_progress_bar:
-                self.progress_bar.reset(self.nb_sanity_val_steps)
+            # init progress bars for validation sanity check
+            pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
+                             leave=False, position=2 * self.process_position,
+                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+            self.main_progress_bar = pbar
+            # dummy validation progress bar
+            self.val_progress_bar = tqdm.tqdm(disable=True)

             self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)

+            # close progress bars
+            self.main_progress_bar.close()
+            self.val_progress_bar.close()
+
+        # init progress bar
+        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+        self.main_progress_bar = pbar
+
         # clear cache before training
         if self.on_gpu:
             torch.cuda.empty_cache()
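The sanity-check block relies on a small trick: `evaluate()` always touches both a main bar and a `val_progress_bar`, so instead of adding existence checks to that shared loop, the pretrain routine installs a throwaway main bar plus a fully disabled dummy bar whose methods are all no-ops. A standalone sketch of the idea (the `run_batches` helper is hypothetical):

```python
import tqdm

def run_batches(main_bar, val_bar, n):
    # shared loop that assumes both bars exist, mirroring evaluate()
    for _ in range(n):
        val_bar.update(1)
        main_bar.update(1)

# temporary bar for the sanity check; leave=False erases it on close
main_bar = tqdm.tqdm(desc='Validation sanity check', total=5, leave=False, unit='batch')
val_bar = tqdm.tqdm(disable=True)  # dummy: update()/close() are no-ops
run_batches(main_bar, val_bar, 5)
main_bar.close()
val_bar.close()

# the real training bar is created only after the sanity check passes
main_bar = tqdm.tqdm(leave=True, unit='batch', dynamic_ncols=True)
```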