# Web-page scrape residue (not part of the module), kept as comments so the file parses:
#   Skip to content / Switch branches/tags / Go to file
#   Cannot retrieve contributors at this time
#   58 lines (56 sloc) 2.96 KB
from losswise import Session
from keras.callbacks import Callback
class LosswiseKerasCallback(Callback):
def __init__(self, tag=None, params=None, track_git=True, display_interval=None, max_iter=None):
    """Validate and store the callback configuration.

    Nothing is sent to Losswise here; the session itself is created in
    ``on_train_begin``.

    Args:
        tag: Optional string label, later passed to ``Session`` as ``tag``.
        params: Model hyper-parameters, a JSON-serializable dict. Defaults
            to a fresh empty dict for every instance.
        track_git: Stored on the instance as ``self.track_git``.
        display_interval: Forwarded as ``display_interval`` to each graph
            created in ``on_train_begin``.
        max_iter: Total number of iterations reported to ``Session``. When
            None, ``on_train_begin`` tries to derive it from ``self.params``.

    Raises:
        TypeError: If ``params`` is not a dict, or ``tag`` is neither None
            nor a ``str``.
    """
    # A literal ``{}`` default would be created once and shared by every
    # instance constructed without ``params``; build a new dict per call.
    if params is None:
        params = {}
    if not isinstance(params, dict):
        raise TypeError("\"params\" argument must be a valid python dictionary")
    if tag is not None and not isinstance(tag, str):
        raise TypeError("\"tag\" argument must be a valid python string")
    self.tag = tag
    # Kept under ``params_data`` because the methods below read ``self.params``
    # for a different purpose (epochs / metrics / batch_size).
    self.params_data = params
    self.track_git = track_git
    self.graph_map = {}
    self.display_interval = display_interval
    self.max_iter = max_iter
    super(LosswiseKerasCallback, self).__init__()
def on_train_begin(self, logs=None):
    """Open the Losswise session and create one graph per training metric.

    If ``max_iter`` was not supplied to the constructor, it is derived from
    ``self.params`` when enough information is present; otherwise a warning
    is printed and the session is created with ``max_iter=None``.

    ``self.params`` is not assigned anywhere in this class.
    NOTE(review): presumably Keras fills it in through
    ``Callback.set_params`` before training starts -- confirm.

    Args:
        logs: Unused.

    Raises:
        KeyError: If ``self.params`` has no ``'metrics'`` entry.
    """
    run_params = self.params
    if self.max_iter is None:
        has_epochs = 'epochs' in run_params
        if (has_epochs
                and run_params.get('samples') is not None
                and run_params.get('batch_size') is not None):
            self.max_iter = int(run_params['epochs'] * run_params['samples'] / run_params['batch_size'])
        elif 'steps_per_epoch' in run_params and has_epochs:
            self.max_iter = run_params['steps_per_epoch'] * run_params['epochs']
        elif 'samples_per_epoch' in run_params and has_epochs:
            self.max_iter = run_params['samples_per_epoch'] * run_params['epochs']
        elif 'steps' in run_params and has_epochs:
            self.max_iter = run_params['steps'] * run_params['epochs']
        else:
            # Warn only when no rule above could derive a value.
            print("Warning: Please specify max_iter!")
            print("You have not set max_iter, for example do LosswiseKerasCallback(..., max_iter=10000)")

    # NOTE(review): the continuation line of this call was missing from the
    # file. ``track_git`` is stored by __init__ and used nowhere else, so it
    # is assumed to be the dropped argument -- confirm against the Session
    # signature.
    self.session = Session(tag=self.tag, max_iter=self.max_iter, params=self.params_data,
                           track_git=self.track_git)

    # One graph per training metric; the matching 'val_' series is added to
    # the same graph by on_epoch_end, so 'val_*' names are skipped here.
    self.metric_list = []
    for metric in run_params['metrics']:
        if not metric.startswith('val_') and metric not in self.metric_list:
            self.metric_list.append(metric)

    for metric in self.metric_list:
        # Metric names containing 'acc' get kind='max', all others kind='min'.
        kind = 'max' if 'acc' in metric else 'min'
        self.graph_map[metric] = self.session.graph(metric, kind=kind, display_interval=self.display_interval)

    # Shared x coordinate for all graphs; advanced by on_batch_end.
    self.x = 0
def on_epoch_end(self, epoch, logs=None):
    """Append validation values for the finished epoch to the metric graphs.

    For every name in ``self.metric_list`` the key ``"val_" + name`` is
    looked up in ``logs``; when present, its value is appended to that
    metric's graph at the current x position ``self.x``. Metrics without a
    validation entry are skipped. ``self.x`` is not modified.

    Args:
        epoch: Unused.
        logs: Mapping of metric name to value; None is treated as empty.
    """
    # None default instead of a shared mutable ``{}``; also guards against
    # ``in`` raising TypeError when logs is None.
    if not logs:
        return
    for metric in self.metric_list:
        metric_val = "val_" + metric
        if metric_val in logs:
            self.graph_map[metric].append(self.x, {metric_val: logs[metric_val]})
def on_batch_end(self, batch, logs=None):
    """Append each training metric for the finished batch and advance ``self.x``.

    Every name in ``self.metric_list`` gets one point appended to its graph
    at the current x position. A metric missing from ``logs`` is appended
    with the value None. ``self.x`` is then incremented by one.

    Args:
        batch: Unused.
        logs: Mapping of metric name to value; None is treated as empty.
    """
    # None default instead of a shared mutable ``{}``; also keeps ``.get``
    # from failing when logs is None.
    batch_logs = logs if logs is not None else {}
    for metric in self.metric_list:
        self.graph_map[metric].append(self.x, {metric: batch_logs.get(metric)})
    self.x += 1
def on_train_end(self, logs={}):