Skip to content

Commit

Permalink
Merge pull request #81 from DLHub-Argonne/tf_keras
Browse files Browse the repository at this point in the history
Added support for tf.keras
  • Loading branch information
WardLT committed Aug 24, 2020
2 parents 6b88cb2 + 3399691 commit 7034b37
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 65 deletions.
30 changes: 18 additions & 12 deletions dlhub_sdk/models/servables/keras.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from keras import __version__ as keras_version
from keras.models import load_model, model_from_json, model_from_yaml
from keras.layers import Layer
from keras import backend
try:
# Attempt to use Tensorflow Keras first
from tensorflow import keras
except ImportError:
import keras


from dlhub_sdk.models.servables.python import BasePythonServableModel
from dlhub_sdk.utils.types import compose_argument_block

_keras_version_tuple = tuple(int(i) for i in keras_version.split("."))

keras_version = keras.__version__
_keras_version_tuple = tuple(int(i) for i in keras_version.rstrip("-tf").split("."))
_summary_limit = 10000


Expand All @@ -18,7 +22,7 @@ def _detect_backend(output):
"""

# Determine the name of the object
my_backend = backend.backend().lower()
my_backend = keras.backend.backend().lower()

# Add it as a requirement
output.add_requirement(my_backend, 'detect')
Expand Down Expand Up @@ -59,19 +63,20 @@ def create_model(cls, model_path, output_names=None, arch_path=None,

# Get the model details
if arch_path is None:
model = load_model(model_path, custom_objects=custom_objects)
model = keras.models.load_model(model_path, custom_objects=custom_objects)
else:
if arch_path.endswith('.h5') or arch_path.endswith('.hdf') \
or arch_path.endswith('.hdf5') or arch_path.endswith('.hd5'):
model = load_model(arch_path, custom_objects=custom_objects, compile=False)
model = keras.models.load_model(arch_path,
custom_objects=custom_objects, compile=False)
elif arch_path.endswith('.json'):
with open(arch_path) as fp:
json_string = fp.read()
model = model_from_json(json_string, custom_objects=custom_objects)
model = keras.models.model_from_json(json_string, custom_objects=custom_objects)
elif arch_path.endswith('.yml') or arch_path.endswith('.yaml'):
with open(arch_path) as fp:
yaml_string = fp.read()
model = model_from_yaml(yaml_string, custom_objects=custom_objects)
model = keras.models.model_from_yaml(yaml_string, custom_objects=custom_objects)
else:
raise ValueError('File type for architecture not recognized')
model.load_weights(model_path)
Expand All @@ -97,7 +102,8 @@ def capture_summary(x):
output['servable']['model_type'] = 'Deep NN'

# Add keras as a dependency
output.add_requirement('keras', keras_version)
if not keras_version.endswith("-tf"):
output.add_requirement('keras', keras_version)
output.add_requirement('h5py', 'detect')

# Detect backend and get its version
Expand Down Expand Up @@ -135,7 +141,7 @@ def add_custom_object(self, name, custom_layer):

# Get the class name for the custom layer
layer_name = custom_layer.__name__
if not issubclass(custom_layer, Layer):
if not issubclass(custom_layer, keras.layers.Layer):
raise ValueError("Custom layer ({}) must be a subclass of Layer".format(layer_name))
module = custom_layer.__module__

Expand Down
57 changes: 4 additions & 53 deletions dlhub_sdk/models/servables/tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
import shutil
import os

from tensorflow import __version__ as tf_version
from keras import __version__ as keras_version
from keras.models import Sequential, Model
from keras.layers import Dense, Input
from h5py import __version__ as h5py_version
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Input
from unittest import TestCase

from dlhub_sdk.models.servables.keras import KerasModel
from dlhub_sdk.utils.schemas import validate_against_dlhub_schema
from dlhub_sdk.version import __version__


_year = str(datetime.now().year)
Expand Down Expand Up @@ -45,52 +41,8 @@ def test_keras_single_input(self):
metadata.set_title('Keras Test')
metadata.set_name('mlp')

output = metadata.to_dict()
self.assertEqual(output, {
"datacite": {"creators": [], "titles": [{"title": "Keras Test"}],
"publisher": "DLHub", "publicationYear": _year,
"identifier": {"identifier": "10.YET/UNASSIGNED",
"identifierType": "DOI"},
"resourceType": {"resourceTypeGeneral": "InteractiveResource"},
"descriptions": [],
"fundingReferences": [],
"relatedIdentifiers": [],
"alternateIdentifiers": [],
"rightsList": []},
"dlhub": {"version": __version__, "domains": [],
"visible_to": ["public"],
'type': 'servable',
"name": "mlp", "files": {"model": model_path},
"dependencies": {"python": {
'keras': keras_version,
'h5py': h5py_version,
'tensorflow': tf_version
}}},
"servable": {"methods": {"run": {
"input": {"type": "ndarray", "description": "Tensor", "shape": [None, 1]},
"output": {"type": "ndarray", "description": "Tensor",
"shape": [None, 1]}, "parameters": {},
"method_details": {
"method_name": "predict",
"classes": ["y"]
}}},
"type": "Keras Model",
"shim": "keras.KerasServable",
"model_type": "Deep NN",
"model_summary": """_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
hidden (Dense) (None, 16) 32
_________________________________________________________________
output (Dense) (None, 1) 17
=================================================================
Total params: 49
Trainable params: 49
Non-trainable params: 0
_________________________________________________________________
"""}}) # noqa: W291 (trailing whitespace needed for text match)

# Validate against schema
output = metadata.to_dict()
validate_against_dlhub_schema(output, 'servable')
finally:
shutil.rmtree(tempdir)
Expand Down Expand Up @@ -147,8 +99,7 @@ def test_custom_layers(self):
metadata.set_title('test').set_name('test')

# Make sure it has the custom object definitions
self.assertEqual({'custom_objects': {'Dense': 'keras.layers.core.Dense'}},
metadata['servable']['options'])
self.assertIn('Dense', metadata['servable']['options']['custom_objects'])

# Validate it against DLHub schema
validate_against_dlhub_schema(metadata.to_dict(), 'servable')
Expand Down

0 comments on commit 7034b37

Please sign in to comment.