Use generators tqdm progressbars (huggingface#6696)
sgugger authored and Zigur committed Oct 26, 2020
1 parent 3a00ee5 commit 01384f7
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/transformers/trainer.py
@@ -641,28 +641,30 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
         logging_loss = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
-        train_iterator = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
-        for epoch in train_iterator:
+        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
+        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
 
             if is_torch_tpu_available():
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                     self.args.device
                 )
-                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=disable_tqdm)
+                epoch_iterator = parallel_loader
             else:
-                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=disable_tqdm)
+                epoch_iterator = train_dataloader
 
             # Reset the past mems state at the beginning of each epoch if necessary.
             if self.args.past_index >= 0:
                 self._past = None
 
+            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
             for step, inputs in enumerate(epoch_iterator):
 
                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
+                    epoch_pbar.update(1)
                     continue
 
                 tr_loss += self.training_step(model, inputs)
@@ -745,11 +747,12 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                             torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
 
+                epoch_pbar.update(1)
                 if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                    epoch_iterator.close()
                     break
+            epoch_pbar.close()
+            train_pbar.update(1)
             if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                train_iterator.close()
                 break
             if self.args.tpu_metrics_debug or self.args.debug:
                 if is_torch_tpu_available():
@@ -761,6 +764,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                         "configured. Check your training configuration if this is unexpected."
                     )
 
+        train_pbar.close()
         if self.tb_writer:
             self.tb_writer.close()
         if self.args.past_index and hasattr(self, "_past"):
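For context, the change above stops looping over tqdm-wrapped iterators (the old train_iterator and the tqdm-wrapped epoch_iterator) and instead iterates over the plain dataloaders while driving standalone progress bars (train_pbar, epoch_pbar) with explicit update() and close() calls. Below is a minimal, self-contained sketch of that pattern with hypothetical names; it is not the Trainer code itself.

# Minimal sketch of the progress-bar pattern adopted in this commit
# (hypothetical function and variable names, not the actual Trainer code).
from tqdm.auto import tqdm, trange


def train(dataloader, num_epochs, disable_tqdm=False):
    # Epoch-level bar, advanced once per completed epoch.
    train_pbar = trange(num_epochs, desc="Epoch", disable=disable_tqdm)
    for epoch in range(num_epochs):
        # Passing the iterable gives tqdm a total when len() is available,
        # but the loop below consumes the plain dataloader/generator itself.
        epoch_pbar = tqdm(dataloader, desc="Iteration", disable=disable_tqdm)
        for step, batch in enumerate(dataloader):
            # ... forward/backward pass on `batch` would go here ...
            epoch_pbar.update(1)
        epoch_pbar.close()
        train_pbar.update(1)
    train_pbar.close()


# Any iterable of "batches" works, e.g. a DataLoader or a plain range:
train(dataloader=range(100), num_epochs=3)

Keeping the bars separate from the loop iterables is what lets the same loop body handle both a plain DataLoader and a TPU ParallelLoader per-device loader, and it also lets the bar advance on steps that are skipped when resuming from a checkpoint.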
