Skip to content

Commit

Permalink
Update callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
Hironsan committed Mar 3, 2018
1 parent 990fe86 commit 8a8ad78
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions anago/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,44 @@
"""
Custom callbacks.
"""
import os

import numpy as np
from keras.callbacks import Callback, TensorBoard, EarlyStopping, ModelCheckpoint
from seqeval.metrics import f1_score


def get_callbacks(log_dir=None, valid=(), tensorboard=True, eary_stopping=True):
def get_callbacks(log_dir=None, valid=(), checkpoint_dir=None, eary_stopping=True):
"""Get callbacks.
Args:
log_dir (str): the destination to save logs(for TensorBoard).
valid (tuple): data for validation.
tensorboard (bool): Whether to use tensorboard.
checkpoint_dir (bool): Whether to use checkpoint.
eary_stopping (bool): whether to use early stopping.
Returns:
list: list of callbacks
"""
callbacks = []

if log_dir and tensorboard:
if log_dir:
if not os.path.exists(log_dir):
print('Successfully made a directory: {}'.format(log_dir))
os.mkdir(log_dir)
print('Successfully made a directory: {}'.format(log_dir))
callbacks.append(TensorBoard(log_dir))

if valid:
callbacks.append(F1score(*valid))

if log_dir:
if not os.path.exists(log_dir):
print('Successfully made a directory: {}'.format(log_dir))
os.mkdir(log_dir)
if checkpoint_dir:
if not os.path.exists(checkpoint_dir):
os.mkdir(checkpoint_dir)
print('Successfully made a directory: {}'.format(checkpoint_dir))

file_name = '_'.join(['model_weights', '{epoch:02d}', '{f1:2.2f}']) + '.h5'
save_callback = ModelCheckpoint(os.path.join(log_dir, file_name),
monitor='f1',
save_weights_only=True)
save_callback = ModelCheckpoint(os.path.join(checkpoint_dir, file_name),
monitor='f1', save_weights_only=True)
callbacks.append(save_callback)

if eary_stopping:
Expand Down

0 comments on commit 8a8ad78

Please sign in to comment.