Skip to content

Commit

Permalink
Provide an option to force the use of tf.keras
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Jul 14, 2021
1 parent 91b9483 commit c261b6a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
42 changes: 31 additions & 11 deletions dlhub_sdk/models/servables/keras.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import logging

# Attempt to load both tf.keras and plain-old keras
try:
# Attempt to use regular keras first, as we figure it's installed for a reason
import keras
import keras as keras_keras
except ImportError:
from tensorflow import keras

keras_keras = None
try:
from tensorflow import keras as tf_keras
except ImportError:
tf_keras = None

from dlhub_sdk.models.servables.python import BasePythonServableModel
from dlhub_sdk.utils.types import compose_argument_block


keras_version = keras.__version__
_keras_version_tuple = tuple(int(i) for i in keras_version.rstrip("-tf").split("."))
logger = logging.getLogger(__name__)
_summary_limit = 10000


def _detect_backend(output):
def _detect_backend(keras, output):
"""Add the backend
Args:
Expand All @@ -35,7 +38,7 @@ class KerasModel(BasePythonServableModel):

@classmethod
def create_model(cls, model_path, output_names=None, arch_path=None,
custom_objects=None):
custom_objects=None, force_tf_keras: bool = False) -> 'KerasModel':
"""Initialize a Keras model.
Args:
Expand All @@ -48,8 +51,23 @@ def create_model(cls, model_path, output_names=None, arch_path=None,
`Keras Documentation
<https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model>`_
for more details.
force_tf_keras (bool): Force the use of TF.Keras even if keras is installed
"""
output = super(KerasModel, cls).create_model('predict')
output: KerasModel = super(KerasModel, cls).create_model('predict')
if force_tf_keras:
if tf_keras is None:
raise ValueError('You forced tf_keras but do not have tensorflow.keras')
keras = tf_keras
elif keras_keras is not None:
# Use old keras by default, as users may have gone out of their way to install it
keras = keras_keras
if tf_keras is not None:
logging.warning('Model publication will use standalone keras, yet you have tf.keras installed. '
'If you don\'t want this, use ``force_tf_keras=True``.')
elif tf_keras is not None:
keras = tf_keras
else:
raise ValueError('You do not have any version of keras installed.')

# Add model as a file to be sent
output.add_file(model_path, 'model')
Expand Down Expand Up @@ -102,12 +120,14 @@ def capture_summary(x):
output['servable']['model_type'] = 'Deep NN'

# Add keras as a dependency
keras_version = keras.__version__
_keras_version_tuple = tuple(int(i) for i in keras_version.rstrip("-tf").split("."))
if not keras_version.endswith("-tf"):
output.add_requirement('keras', keras_version)
output.add_requirement('h5py', 'detect')

# Detect backend and get its version
_detect_backend(output)
_detect_backend(keras, output)

return output

Expand Down
5 changes: 0 additions & 5 deletions dlhub_sdk/models/servables/tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,6 @@ def test_custom_layers(self):
finally:
shutil.rmtree(tmpdir)

# Test the errors
with self.assertRaises(ValueError) as exc:
metadata.add_custom_object('BadLayer', float)
self.assertIn('subclass', str(exc.exception))

def test_multi_file(self):
"""Test adding the architecture in a different file """

Expand Down

0 comments on commit c261b6a

Please sign in to comment.