Skip to content

Commit

Permalink
Merge pull request #34 from alexwwang/develop
Browse files Browse the repository at this point in the history
Develop - Fix shape mismatch bug in KMaxPooling, improve save load robust and config flexibility
  • Loading branch information
BrikerMan committed Feb 27, 2019
2 parents 17eb3ca + 6f4c39d commit cb19aa9
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 62 deletions.
10 changes: 6 additions & 4 deletions kashgari/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def __init__(self, k=1, sorted=True, data_format='channels_last', **kwargs):
self.sorted = sorted
self.data_format = K.normalize_data_format(data_format)

def build(self, input_shape):
assert len(input_shape) == 3
super(KMaxPooling, self).build(input_shape)
# def build(self, input_shape):
# assert len(input_shape) == 3
# super(KMaxPooling, self).build(input_shape)

def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
Expand All @@ -149,7 +149,9 @@ def call(self, inputs):
return tf.transpose(top_k, [0, 2, 1])

def get_config(self):
config = {'data_format': self.data_format}
config = {'k': self.k,
'sorted': self.sorted,
'data_format': self.data_format}
base_config = super(KMaxPooling, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

Expand Down
84 changes: 78 additions & 6 deletions kashgari/tasks/base/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
"""
import os
import json
import pickle
import pathlib
import traceback
import logging
logger = logging.getLogger(__name__)
import numpy as np
from typing import Dict

import keras
from keras.models import Model
from keras import backend as K

from kashgari.utils import helper
from kashgari.embeddings import CustomEmbedding, BaseEmbedding
from kashgari.utils.crf import CRF, crf_loss, crf_accuracy
Expand Down Expand Up @@ -73,8 +78,24 @@ def save(self, model_path: str):
with open(os.path.join(model_path, 'model.json'), 'w', encoding='utf-8') as f:
f.write(json.dumps(model_info, indent=2, ensure_ascii=False))

with open(os.path.join(model_path, 'struct.json'), 'w', encoding='utf-8') as f:
f.write(self.model.to_json())

#self.model.save_weights(os.path.join(model_path, 'weights.h5'))
optimizer_weight_values = None
try:
symbolic_weights = getattr(self.model.optimizer, 'weights')
optimizer_weight_values = K.batch_get_value(symbolic_weights)
except Exception as e:
logger.warn('error occur: {}'.format(e))
traceback.print_tb(e.__traceback__)
logger.warn('No optimizer weights found.')
if optimizer_weight_values is not None:
with open(os.path.join(model_path, 'optimizer.pkl'), 'wb') as f:
pickle.dump(optimizer_weight_values, f)

self.model.save(os.path.join(model_path, 'model.model'))
logging.info('model saved to {}'.format(os.path.abspath(model_path)))
logger.info('model saved to {}'.format(os.path.abspath(model_path)))

@staticmethod
def create_custom_objects(model_info):
Expand Down Expand Up @@ -113,15 +134,66 @@ def load_model(cls, model_path: str):
custom_objects = cls.create_custom_objects(model_info)

if custom_objects:
logging.debug('prepared custom objects: {}'.format(custom_objects))

agent.model = keras.models.load_model(os.path.join(model_path, 'model.model'),
custom_objects=custom_objects)
logger.debug('prepared custom objects: {}'.format(custom_objects))

try:
agent.model = keras.models.load_model(os.path.join(model_path, 'model.model'),
custom_objects=custom_objects)
except Exception as e:
logger.warn('Error `{}` occured trying directly model loading. Try to rebuild.'.format(e))
logger.debug('Load model structure from json.')
with open(os.path.join(model_path, 'struct.json'), 'r', encoding='utf-8') as f:
model_struct = f.read()
agent.model = keras.models.model_from_json(model_struct,
custom_objects=custom_objects)
logger.debug('Build optimizer with model info.')
optimizer_conf = model_info['hyper_parameters'].get('optimizer', None)
optimizer = 'adam' #default
if optimizer_conf is not None and isinstance(optimizer_conf, dict):
module_str = optimizer_conf.get('module', 'None')
name_str = optimizer_conf.get('name', 'None')
params = optimizer_conf.get('params', None)
invalid_set = [None, 'None', '', {}]
if not any([module_str.strip() in invalid_set,
name_str.strip() in invalid_set,
params in invalid_set]):
try:
optimizer = getattr(eval(module_str), name_str)(**params)
except:
logger.warn('Invalid optimizer configuration in model info. Use `adam` as default.')
else:
logger.warn('No optimizer configuration found in model info. Use `adam` as default.')

default_compile_params = {'loss': 'categorical_crossentropy', 'metrics':['accuracy']}
compile_params = model_info['hyper_parameters'].get('compile_params', default_compile_params)
logger.debug('Compile model from scratch.')
try:
agent.model.compile(optimizer=optimizer, **compile_params)
except:
logger.warn('Failed to compile model. Compile params seems incorrect.')
logger.warn('Use default options `{}` to compile.'.format(default_compile_params))
agent.model.compile(optimizer=optimizer, **default_compile_params)
logger.debug('Load model weights.')
agent.model.summary()
agent.model.load_weights(os.path.join(model_path, 'model.model'))
agent.model._make_train_function()
optimizer_weight_values = None
logger.debug('Load optimizer weights.')
try:
with open(os.path.join(model_path, 'optimizer.pkl'), 'rb') as f:
optimizer_weight_values = pickle.load(f)
except Exception as e:
logger.warn('Try to load optimizer weights but no optimizer weights file found.')
if optimizer_weight_values is not None:
agent.model.optimizer.set_weights(optimizer_weight_values)
else:
logger.warn('Rebuild model but optimizer weights missed. Retrain needed.')
logger.info('Model rebuild finished.')
agent.embedding.update(model_info.get('embedding', {}))
agent.model.summary()
agent.label2idx = label2idx
agent.embedding.token2idx = token2idx
logging.info('loaded model from {}'.format(os.path.abspath(model_path)))
logger.info('loaded model from {}'.format(os.path.abspath(model_path)))
return agent


Expand Down
Loading

0 comments on commit cb19aa9

Please sign in to comment.