Only access loss tensor every logging_steps (huggingface#6802)
* Only access loss tensor every logging_steps

* tensor.item() was being called every step. This must not be done
for XLA:TPU tensors, as it forces a TPU<>CPU transfer at each step and is
terrible for performance. On RoBERTa MLM, for example, deferring the call
reduces step time by 30%; the gain should be larger for models/tasks with
shorter step times. (A sketch of the pattern follows after this list.)
* Train batch size was not correct when a user passes the
`per_gpu_train_batch_size` flag
* Avg reduce loss across eval shards
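
A minimal sketch of the pattern this change adopts, outside the actual Trainer code
(train_loader, logging_steps, and the model call below are placeholder assumptions):
keep the running loss as a device tensor and only call .item() on logging steps, so
XLA never has to sync the TPU with the host on ordinary steps.

import torch

def train_loop(model, optimizer, train_loader, device, logging_steps=100):
    # Running loss stays on the device; calling .item() on it every step would
    # force a TPU<>CPU transfer and stall the lazily executed XLA graph.
    tr_loss = torch.tensor(0.0).to(device)
    logging_loss_scalar = 0.0

    for step, batch in enumerate(train_loader, start=1):
        loss = model(**batch)[0]          # placeholder: assumes the first output is the loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        tr_loss += loss.detach()          # tensor-to-tensor add, no host sync

        if step % logging_steps == 0:
            tr_loss_scalar = tr_loss.item()   # the only device->host transfer
            print(f"step {step}: loss {(tr_loss_scalar - logging_loss_scalar) / logging_steps:.4f}")
            logging_loss_scalar = tr_loss_scalar

The same idea appears in the diff below: tr_loss becomes a tensor, and a scalar is
only materialized inside the logging branch and when building the final TrainOutput.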

* Fix style (huggingface#6803)

* t5 model should make decoder_attention_mask (huggingface#6800)

* [s2s] Test hub configs in self-scheduled CI (huggingface#6809)

* [s2s] round runtime in run_eval (huggingface#6798)

* Pegasus finetune script: add --adafactor (huggingface#6811)

* [bart] rename self-attention -> attention (huggingface#6708)

* [tests] fix typos in inputs (huggingface#6818)

* Fixed open in colab link (huggingface#6825)

* Add model card for singbert lite. Update widget for singbert and singbert-large. (huggingface#6827)

* BR_BERTo model card (huggingface#6793)

* clearly indicate shuffle=False (huggingface#6312)

* Clarify shuffle

* clarify shuffle

Co-authored-by: Kevin Canwen Xu <canwenxu@126.com>

* [s2s README] Add more dataset download instructions (huggingface#6737)

* Style

* Patch logging issue

* Set default logging level to `WARNING` instead of `INFO`

* TF Flaubert w/ pre-norm (huggingface#6841)

* Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (huggingface#6644)

* add datacollator and dataset for next sentence prediction task

* bug fix (numbers of special tokens & truncate sequences)

* bug fix (+ dict inputs support for data collator)

* add padding for nsp data collator; renamed cached files to avoid conflict.

* add test for nsp data collator

* Style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>

* Fix in Adafactor docstrings (huggingface#6845)

* Fix resuming training for Windows (huggingface#6847)

* Only access loss tensor every logging_steps

* tensor.item() was being called every step. This must not be done
for XLA:TPU tensors, as it forces a TPU<>CPU transfer at each step and is
terrible for performance. On RoBERTa MLM, for example, deferring the call
reduces step time by 30%; the gain should be larger for models/tasks with
shorter step times.
* Train batch size was not correct when a user passes the
`per_gpu_train_batch_size` flag
* Avg reduce loss across eval shards

* comments

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Thomas Ashish Cherian <6967017+PandaWhoCodes@users.noreply.github.com>
Co-authored-by: Zane Lim <zyuanlim@gmail.com>
Co-authored-by: Rodolfo De Nadai <rdenadai@gmail.com>
Co-authored-by: xujiaze13 <37360975+xujiaze13@users.noreply.github.com>
Co-authored-by: Kevin Canwen Xu <canwenxu@126.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Huang Lianzhe <hlz@pku.edu.cn>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
12 people authored and Zigur committed Oct 26, 2020
1 parent 1ab2fd1 commit e7e2c65
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/transformers/trainer.py
@@ -658,8 +658,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                 self.global_step = 0
                 logger.info(" Starting fine-tuning.")
 
-        tr_loss = 0.0
-        logging_loss = 0.0
+        tr_loss = torch.tensor(0.0).to(self.args.device)
+        logging_loss_scalar = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
         train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
@@ -720,14 +720,15 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                     self.global_step == 1 and self.args.logging_first_step
                 ):
                     logs: Dict[str, float] = {}
-                    logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
+                    tr_loss_scalar = tr_loss.item()
+                    logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
                     # backward compatibility for pytorch schedulers
                     logs["learning_rate"] = (
                         self.lr_scheduler.get_last_lr()[0]
                         if version.parse(torch.__version__) >= version.parse("1.4")
                         else self.lr_scheduler.get_lr()[0]
                     )
-                    logging_loss = tr_loss
+                    logging_loss_scalar = tr_loss_scalar
 
                     self.log(logs)

@@ -773,8 +774,6 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                     break
             epoch_pbar.close()
             train_pbar.update(1)
-            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                break
             if self.args.tpu_metrics_debug or self.args.debug:
                 if is_torch_tpu_available():
                     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
@@ -784,6 +783,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                         "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                         "configured. Check your training configuration if this is unexpected."
                     )
+            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
+                break
 
         train_pbar.close()
         if self.tb_writer:
@@ -793,7 +794,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
             delattr(self, "_past")
 
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
-        return TrainOutput(self.global_step, tr_loss / self.global_step)
+        return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
 
     def hyperparameter_search(
         self,
@@ -973,7 +974,7 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
 
         return inputs
 
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
@@ -989,7 +990,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
                 argument :obj:`labels`. Check your model's documentation for all accepted arguments.
         Return:
-            :obj:`float`: The training loss on this batch.
+            :obj:`torch.Tensor`: The tensor with training loss on this batch.
         """
         if hasattr(self, "_training_step"):
             warnings.warn(
@@ -1027,7 +1028,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         else:
             loss.backward()
 
-        return loss.item()
+        return loss
 
     def is_local_master(self) -> bool:
         """
@@ -1276,6 +1277,10 @@ def prediction_loop(
                 preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
             if label_ids is not None:
                 label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
+            if eval_losses is not None:
+                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
+            if samples_count is not None:
+                samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())
 
         # Finally, turn the aggregated tensors into numpy arrays.
         if preds is not None:
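
For context on the "avg reduce loss across eval shards" bullet, a toy CPU-only
illustration of what the mesh_reduce lines added above are for. The shard numbers
are invented, and weighting per-batch losses by batch size before dividing by the
reduced samples_count is an assumption about how prediction_loop combines the two
reduced values; the point is that each TPU shard sees only part of the eval set,
so losses are concatenated across processes and averaged globally, not per shard.

import torch

# Invented numbers: two eval shards with different numbers of batches/samples.
shard_weighted_losses = [[0.52 * 8, 0.47 * 8], [0.61 * 5]]   # per-batch loss * batch size
shard_sample_counts = [16, 5]

# xm.mesh_reduce("eval_losses", ..., torch.cat) effectively concatenates every
# shard's tensor, and the samples_count reduce sums the per-shard counts, so each
# process ends up computing the same global average:
all_losses = torch.cat([torch.tensor(x) for x in shard_weighted_losses])
total_samples = sum(shard_sample_counts)
eval_loss = (all_losses.sum() / total_samples).item()
print(eval_loss)   # one loss for the whole eval set, not one per shard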
