Skip to content

Commit

Permalink
Ensure metric results are JSON-serializable (huggingface#10632)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger authored and Iwontbecreative committed Jul 15, 2021
1 parent 063f08a commit d63d81e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
TrainOutput,
default_compute_objective,
default_hp_space,
denumpify_detensorize,
get_last_checkpoint,
set_seed,
speed_metrics,
Expand Down Expand Up @@ -1831,6 +1832,9 @@ def prediction_loop(
else:
metrics = {}

# To be JSON-serializable, we need to remove numpy types or zero-d tensors
metrics = denumpify_detensorize(metrics)

if eval_loss is not None:
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

Expand Down
26 changes: 22 additions & 4 deletions src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
)


if is_torch_available():
import torch

if is_tf_available():
import tensorflow as tf


def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
Expand All @@ -49,14 +56,10 @@ def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
if is_torch_available():
import torch

torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
if is_tf_available():
import tensorflow as tf

tf.random.set_seed(seed)


Expand Down Expand Up @@ -423,6 +426,21 @@ def stop_and_update_metrics(self, metrics=None):
self.update_metrics(stage, metrics)


def denumpify_detensorize(metrics):
"""
Recursively calls `.item()` on the element of the dictionary passed
"""
if isinstance(metrics, (list, tuple)):
return type(metrics)(denumpify_detensorize(m) for m in metrics)
elif isinstance(metrics, dict):
return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
elif isinstance(metrics, np.generic):
return metrics.item()
elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
return metrics.item()
return metrics


class ShardedDDPOption(ExplicitEnum):
SIMPLE = "simple"
ZERO_DP_2 = "zero_dp_2"
Expand Down

0 comments on commit d63d81e

Please sign in to comment.