Skip to content

Commit

Permalink
Merge pull request #72 from DLHub-Argonne/keras_backend
Browse files Browse the repository at this point in the history
Autodetect Keras backend
  • Loading branch information
WardLT committed Jul 18, 2019
2 parents 40676b6 + 0ce9fa5 commit 7998042
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
19 changes: 19 additions & 0 deletions dlhub_sdk/models/servables/keras.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
from keras import __version__ as keras_version
from keras.models import load_model, model_from_json, model_from_yaml
from keras.layers import Layer
from keras import backend

from dlhub_sdk.models.servables.python import BasePythonServableModel
from dlhub_sdk.utils.types import compose_argument_block

_keras_version_tuple = tuple(int(i) for i in keras_version.split("."))


def _detect_backend(output):
"""Add the backend
Args:
output (KerasModel): Current description of Keras model, will be modified
"""

# Determine the name of the object
my_backend = backend.backend().lower()

# Add it as a requirement
output.add_requirement(my_backend, 'detect')


class KerasModel(BasePythonServableModel):
"""Servable based on a Keras Model object.
Expand Down Expand Up @@ -80,6 +95,10 @@ def capture_summary(x):
# Add keras as a dependency
output.add_requirement('keras', keras_version)
output.add_requirement('h5py', 'detect')

# Detect backend and get its version
_detect_backend(output)

return output

def format_layer_spec(self, layers):
Expand Down
4 changes: 3 additions & 1 deletion dlhub_sdk/models/servables/tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import os

from tensorflow import __version__ as tf_version
from keras import __version__ as keras_version
from keras.models import Sequential, Model
from keras.layers import Dense, Input
Expand Down Expand Up @@ -62,7 +63,8 @@ def test_keras_single_input(self):
"name": "mlp", "files": {"model": model_path},
"dependencies": {"python": {
'keras': keras_version,
'h5py': h5py_version
'h5py': h5py_version,
'tensorflow': tf_version
}}},
"servable": {"methods": {"run": {
"input": {"type": "ndarray", "description": "Tensor", "shape": [None, 1]},
Expand Down

0 comments on commit 7998042

Please sign in to comment.