/
NEpochLogger.py
28 lines (26 loc) · 1.02 KB
/
NEpochLogger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from tensorflow.keras.callbacks import Callback
import math
class NEpochLogger(Callback):
    """Keras callback that prints metrics averaged over every `display` epochs.

    Although the printed line says ``step``, one "step" here is one completed
    epoch: the counter is advanced in ``on_epoch_end``.

    Args:
        display: Number of epochs to accumulate before printing the averaged
            metrics. Must be positive.

    Raises:
        ValueError: If ``display`` is not positive.
    """

    def __init__(self, display):
        # Initialise the Keras base-class state (model/params/etc.).
        super().__init__()
        if display <= 0:
            raise ValueError(
                'display must be a positive number of epochs, got %r' % (display,))
        self.step = 0
        self.display = display
        # metric name -> running sum of per-epoch values since the last print
        self.metric_cache = {}
        # metric name -> number of epochs that contributed to metric_cache[name];
        # used so a metric missing from some epochs is still averaged correctly.
        self._metric_counts = {}

    def on_epoch_end(self, batch, logs=None):
        """Accumulate this epoch's metrics; print their averages every `display` epochs.

        Args:
            batch: Index of the epoch that just ended, as passed positionally by
                Keras. The name is historical; the value is not used.
            logs: Mapping of metric name to value for this epoch, or None.
        """
        logs = logs or {}
        self.step += 1

        # Older Keras listed metric names in params['metrics']; newer tf.keras
        # versions no longer populate it, so fall back to the keys in `logs`.
        metric_names = (self.params or {}).get('metrics') or list(logs)
        for name in metric_names:
            if name in logs:
                self.metric_cache[name] = self.metric_cache.get(name, 0) + logs[name]
                self._metric_counts[name] = self._metric_counts.get(name, 0) + 1

        if self.step % self.display != 0:
            return

        n_epochs = self.params['epochs']
        # Round the running counter up to the next multiple of `epochs`; this
        # presumably keeps the "x/total" display sensible if the callback is
        # reused across several fit() calls.
        epoch = math.ceil(self.step / n_epochs) * n_epochs

        parts = []
        for name, running_sum in self.metric_cache.items():
            val = running_sum / self._metric_counts[name]
            # Very small magnitudes would print as 0.0000 in fixed-point.
            fmt = ' - %s: %.4f' if abs(val) > 1e-3 else ' - %s: %.4e'
            parts.append(fmt % (name, val))
        print('step: {}/{} ... {}'.format(self.step, epoch, ''.join(parts)))

        self.metric_cache.clear()
        self._metric_counts.clear()