Skip to content

Commit

Permalink
Merge pull request #46 from DLHub-Argonne/update_keras
Browse files Browse the repository at this point in the history
Made KerasModel more flexible
  • Loading branch information
WardLT committed Feb 25, 2019
2 parents d4f787b + 6df0e98 commit 327303b
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 16 deletions.
71 changes: 65 additions & 6 deletions dlhub_sdk/models/servables/keras.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from keras import __version__ as keras_version
from keras.models import load_model
from keras.models import load_model, model_from_json, model_from_yaml
from keras.layers import Layer

from dlhub_sdk.models.servables.python import BasePythonServableModel
from dlhub_sdk.utils.types import compose_argument_block


_keras_version_tuple = tuple(int(i) for i in keras_version.split("."))


Expand All @@ -14,33 +14,65 @@ class KerasModel(BasePythonServableModel):
Assumes that the model has been saved to an hdf5 file"""

@classmethod
def create_model(cls, model_path, output_names):
def create_model(cls, model_path, output_names, arch_path=None,
custom_objects=None):
"""Initialize a Keras model.
Args:
model_path (string): Path to the hd5 file that describes a model and the weights
model_path (string): Path to the hd5 file that contains the weights and, optionally,
the architecture
output_names ([string] or [[string]]): Names of output classes.
If applicable, one list for each output layer.
arch_path (string): Path to the hd5 model containing the architecture, if not
available in the file at :code:`model_path`.
custom_objects (dict): Map of layer names to custom layers. See
`Keras Documentation
<https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model>`_
for more details.
"""
output = super(KerasModel, cls).create_model('predict')

# Add model as a file to be sent
output.add_file(model_path, 'model')
if arch_path is not None:
output.add_file(arch_path, 'arch')

# Store the list of custom objects
if custom_objects is not None:
for k, v in custom_objects.items():
output.add_custom_object(k, v)

# Get the model details
model = load_model(model_path)
if arch_path is None:
model = load_model(model_path, custom_objects=custom_objects)
else:
if arch_path.endswith('.h5') or arch_path.endswith('.hdf') \
or arch_path.endswith('.hdf5') or arch_path.endswith('.hd5'):
model = load_model(arch_path, custom_objects=custom_objects, compile=False)
elif arch_path.endswith('.json'):
with open(arch_path) as fp:
json_string = fp.read()
model = model_from_json(json_string, custom_objects=custom_objects)
elif arch_path.endswith('.yml') or arch_path.endswith('.yaml'):
with open(arch_path) as fp:
yaml_string = fp.read()
model = model_from_yaml(yaml_string, custom_objects=custom_objects)
else:
raise ValueError('File type for architecture not recognized')
model.load_weights(model_path)

# Get the inputs of the model
output['servable']['methods']['run']['input'] = output.format_layer_spec(model.input_shape)
output['servable']['methods']['run']['output'] = output.format_layer_spec(
model.output_shape)
model.output_shape)
output['servable']['methods']['run']['method_details']['classes'] = output_names

# Get a full description of the model
output.summary = ""

def capture_summary(x):
output.summary += x + "\n"

model.summary(print_fn=capture_summary)
output['servable']['model_summary'] = output.summary
output['servable']['model_type'] = 'Deep NN'
Expand All @@ -64,6 +96,33 @@ def format_layer_spec(self, layers):
return compose_argument_block("tuple", "Tuple of tensors",
element_types=[self.format_layer_spec(i) for i in layers])

def add_custom_object(self, name, custom_layer):
"""Add a custom layer to the model specification
See `Keras FAQs
<https://keras.io/getting-started/faq/#handling-custom-layers-or-other-custom-objects-in-saved-models>`
for details.
Args:
name (string): Name of the layer
custom_layer (class): Class of the custom layer
Return:
self
"""

# Get the class name for the custom layer
layer_name = custom_layer.__name__
if not issubclass(custom_layer, Layer):
raise ValueError("Custom layer ({}) must be a subclass of Layer".format(layer_name))
module = custom_layer.__module__

# Add the layer to the model definition
if 'options' not in self._output['servable']:
self['servable']['options'] = {}
if 'custom_objects' not in self['servable']['options']:
self['servable']['options']['custom_objects'] = {}
self['servable']['options']['custom_objects'][name] = '{}.{}'.format(module, layer_name)

def _get_handler(self):
return "keras.KerasServable"

Expand Down
86 changes: 77 additions & 9 deletions dlhub_sdk/models/servables/tests/test_keras.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from datetime import datetime
import os
import shutil
from tempfile import mkdtemp
import shutil
import os

from h5py import __version__ as h5py_version
from keras import __version__ as keras_version
from keras.layers import Dense, Input
from keras.models import Sequential, Model
from keras.layers import Dense, Input
from h5py import __version__ as h5py_version
from unittest import TestCase

from dlhub_sdk.models.servables.keras import KerasModel
Expand All @@ -17,16 +17,21 @@
_year = str(datetime.now().year)


def _make_simple_model():
model = Sequential()
model.add(Dense(16, input_shape=(1,), activation='relu', name='hidden'))
model.add(Dense(1, name='output'))
model.compile(optimizer='rmsprop', loss='mse')
return model


class TestKeras(TestCase):

maxDiff = 4096

def test_keras_single_input(self):
# Make a Keras model
model = Sequential()
model.add(Dense(16, input_shape=(1,), activation='relu', name='hidden'))
model.add(Dense(1, name='output'))
model.compile(optimizer='rmsprop', loss='mse')
model = _make_simple_model()

# Save it to disk
tempdir = mkdtemp()
Expand Down Expand Up @@ -78,7 +83,7 @@ def test_keras_single_input(self):
Trainable params: 49
Non-trainable params: 0
_________________________________________________________________
""", # noqa: W291 (trailing whitespace)
""", # noqa: W291 (trailing whitespace needed for text match)
"dependencies": {"python": {
'keras': keras_version,
'h5py': h5py_version
Expand Down Expand Up @@ -123,3 +128,66 @@ def test_keras_multioutput(self):
validate_against_dlhub_schema(output, 'servable')
finally:
shutil.rmtree(tempdir)

def test_custom_layers(self):
"""Test adding custom layers to the definition"""

# Make a simple model
model = _make_simple_model()

tmpdir = mkdtemp()
try:
# Save it
model_path = os.path.join(tmpdir, 'model.hd5')
model.save(model_path)

# Create the metadata
metadata = KerasModel.create_model(model_path, ['y'], custom_objects={'Dense': Dense})
metadata.set_title('test').set_name('test')

# Make sure it has the custom object definitions
self.assertEqual({'custom_objects': {'Dense': 'keras.layers.core.Dense'}},
metadata['servable']['options'])

# Validate it against DLHub schema
validate_against_dlhub_schema(metadata.to_dict(), 'servable')
finally:
shutil.rmtree(tmpdir)

# Test the errors
with self.assertRaises(ValueError) as exc:
metadata.add_custom_object('BadLayer', float)
self.assertIn('subclass', str(exc.exception))

def test_multi_file(self):
"""Test adding the architecture in a different file """

# Make a simple model
model = _make_simple_model()

tmpdir = mkdtemp()
try:
# Save it
model_path = os.path.join(tmpdir, 'model.hd5')
model.save(model_path, include_optimizer=False)
model_json = os.path.join(tmpdir, 'model.json')
with open(model_json, 'w') as fp:
print(model.to_json(), file=fp)
model_yaml = os.path.join(tmpdir, 'model.yml')
with open(model_yaml, 'w') as fp:
print(model.to_yaml(), file=fp)
weights_path = os.path.join(tmpdir, 'weights.hd5')
model.save_weights(weights_path)

# Create the metadata
metadata = KerasModel.create_model(weights_path, ['y'], arch_path=model_path)

# Make sure both files are included in the files list
self.assertEqual(metadata['dlhub']['files'],
{'arch': model_path, 'model': weights_path})

# Try it with the JSON and YAML versions
KerasModel.create_model(weights_path, ['y'], arch_path=model_json)
KerasModel.create_model(weights_path, ['y'], arch_path=model_yaml)
finally:
shutil.rmtree(tmpdir)
2 changes: 1 addition & 1 deletion dlhub_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# single source of truth for package version,
# see https://packaging.python.org/en/latest/single_source_version/
__version__ = "0.6.0"
__version__ = "0.6.1"

# app name to send as part of SDK requests
app_name = "DLHub SDK v{}".format(__version__)

0 comments on commit 327303b

Please sign in to comment.