This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
mean_absolute_error.py
74 lines (60 loc) · 2.27 KB
/
mean_absolute_error.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import Optional
from overrides import overrides
import torch
import torch.distributed as dist
from allennlp.common.util import is_distributed
from allennlp.training.metrics.metric import Metric
@Metric.register("mean_absolute_error")
class MeanAbsoluteError(Metric):
    """
    This `Metric` calculates the mean absolute error (MAE) between two tensors.

    State is accumulated across calls as plain Python floats
    (`_absolute_error`, `_total_count`) until `reset` is called.
    """

    def __init__(self) -> None:
        # Running sum of |prediction - gold| and of the number of
        # contributing elements. Kept as Python floats, never tensors.
        self._absolute_error = 0.0
        self._total_count = 0.0

    def __call__(
        self,
        predictions: torch.Tensor,
        gold_labels: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
    ):
        """
        # Parameters

        predictions : `torch.Tensor`, required.
            A tensor of predictions of shape (batch_size, ...).
        gold_labels : `torch.Tensor`, required.
            A tensor of the same shape as `predictions`.
        mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of the same shape as `predictions`; positions where the
            mask is False are excluded from both the error sum and the count.
        """
        predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
        device = gold_labels.device

        absolute_errors = torch.abs(predictions - gold_labels)
        if mask is not None:
            # Non-inplace multiply: zero out masked positions and count only
            # the unmasked ones.
            absolute_errors = absolute_errors * mask
            _total_count = torch.sum(mask).item()
        else:
            _total_count = gold_labels.numel()
        # `.item()` here fixes a bug in the original: accumulating a 0-dim
        # tensor into `self._absolute_error` turned the float accumulator
        # into a tensor, so `get_metric` returned a tensor instead of a float.
        _absolute_error = torch.sum(absolute_errors).item()

        if is_distributed():
            # The locals are plain Python numbers at this point, so
            # `torch.tensor(...)` does not trigger the copy-construct warning
            # that wrapping an existing tensor would.
            absolute_error = torch.tensor(_absolute_error, device=device)
            total_count = torch.tensor(_total_count, device=device)
            dist.all_reduce(absolute_error, op=dist.ReduceOp.SUM)
            dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
            _absolute_error = absolute_error.item()
            _total_count = total_count.item()

        self._absolute_error += _absolute_error
        self._total_count += _total_count

    def get_metric(self, reset: bool = False):
        """
        # Returns

        The accumulated mean absolute error, as `{"mae": float}`.
        Returns 0.0 when no observations have been accumulated yet
        (instead of raising `ZeroDivisionError`).
        """
        if self._total_count > 0:
            mean_absolute_error = float(self._absolute_error) / float(self._total_count)
        else:
            mean_absolute_error = 0.0
        if reset:
            self.reset()
        return {"mae": mean_absolute_error}

    @overrides
    def reset(self):
        # Restore the pristine accumulator state.
        self._absolute_error = 0.0
        self._total_count = 0.0