forked from fastai/fastai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss_metrics.py
35 lines (28 loc) · 1.31 KB
/
loss_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from ..torch_core import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback
__all__ = ['LossMetrics']
class LossMetrics(LearnerCallback):
    "Add `loss_func.metrics` to metrics named by `loss_func.metric_names`."
    # Must run before the Recorder callback so the metric names/values are
    # registered before the Recorder formats its epoch output.
    _order = -20

    def on_train_begin(self, **kwargs):
        "Register the metric names from `loss_func.metric_names` with the `Recorder`."
        self.names = ifnone(self.learn.loss_func.metric_names, [])
        # Fixed typo in the warning message: "requested by no" -> "requested but no".
        if not self.names: warn('LossMetrics requested but no loss_func.metric_names provided')
        self.learn.recorder.add_metric_names(self.names)

    def on_epoch_begin(self, **kwargs):
        "Reset the running sums and the sample counter for this epoch."
        self.metrics = {name:0. for name in self.names}
        self.nums = 0

    def on_batch_end(self, last_target, train, **kwargs):
        "Accumulate the batch metrics, weighted by batch size, during validation only."
        if train: return
        # assumes last_target's first dimension is the batch size — TODO confirm
        bs = last_target.size(0)
        for name in self.names:
            # .detach().cpu() drops the autograd graph and frees GPU memory.
            self.metrics[name] += bs * self.learn.loss_func.metrics[name].detach().cpu()
        self.nums += bs

    def on_epoch_end(self, **kwargs):
        "Finish the computation and send the averaged metrics to the `Recorder`."
        # No validation batches seen (e.g. training-only epoch): nothing to report
        # and the division below would fail.
        if not self.nums: return
        metrics = [self.metrics[name]/self.nums for name in self.names]
        self.learn.recorder.add_metrics(metrics)