metric.py
from typing import Iterable, Optional

import torch

from allennlp.common.registrable import Registrable


class Metric(Registrable):
    """
    A very general abstract class representing a metric which can be
    accumulated.
    """

    supports_distributed = False

    def __call__(
        self, predictions: torch.Tensor, gold_labels: torch.Tensor, mask: Optional[torch.BoolTensor] = None
    ):
"""
# Parameters
predictions : `torch.Tensor`, required.
A tensor of predictions.
gold_labels : `torch.Tensor`, required.
A tensor corresponding to some gold label to evaluate against.
mask : `torch.BoolTensor`, optional (default = `None`).
A mask can be passed, in order to deal with metrics which are
computed over potentially padded elements, such as sequence labels.
"""
raise NotImplementedError
    def get_metric(self, reset: bool):
        """
        Compute and return the metric. Optionally also call `self.reset`.
        """
        raise NotImplementedError

    def reset(self) -> None:
        """
        Reset any accumulators or internal state.
        """
        raise NotImplementedError
    @staticmethod
    def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]:
        """
        If gradient-tracking tensors are stored on a `Metric`, they prevent the
        computation graph from being garbage collected, causing a huge memory
        leak. This method ensures the tensors are detached.
        """
        # Check if it's actually a tensor in case something else was passed.
        return (x.detach() if isinstance(x, torch.Tensor) else x for x in tensors)
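

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal concrete
# `Metric` subclass tracking categorical accuracy, to show how the abstract
# API above fits together. The name `ExampleAccuracy` and its internals are
# hypothetical; AllenNLP ships its own `CategoricalAccuracy` metric.
# ---------------------------------------------------------------------------
class ExampleAccuracy(Metric):
    def __init__(self) -> None:
        self._correct_count = 0.0
        self._total_count = 0.0

    def __call__(
        self, predictions: torch.Tensor, gold_labels: torch.Tensor, mask: Optional[torch.BoolTensor] = None
    ):
        # Detach first so the accumulated statistics do not keep the
        # computation graph alive (see `detach_tensors` above).
        predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
        # predictions: (batch, ..., num_classes); gold_labels: (batch, ...).
        predicted_classes = predictions.argmax(dim=-1)
        correct = predicted_classes.eq(gold_labels)
        if mask is not None:
            # Count only unmasked (non-padded) positions.
            correct = correct & mask
            self._total_count += mask.sum().item()
        else:
            self._total_count += gold_labels.numel()
        self._correct_count += correct.sum().item()

    def get_metric(self, reset: bool):
        accuracy = self._correct_count / self._total_count if self._total_count > 0 else 0.0
        if reset:
            self.reset()
        return accuracy

    def reset(self) -> None:
        self._correct_count = 0.0
        self._total_count = 0.0


# Usage sketch (with hypothetical tensors):
#     metric = ExampleAccuracy()
#     metric(logits, gold_label_ids, mask)
#     print(metric.get_metric(reset=True))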