Skip to content

Commit

Permalink
Do not list keras as a requirement when using tf.keras
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Jul 19, 2021
1 parent c261b6a commit 5c3b562
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 5 additions & 2 deletions dlhub_sdk/models/servables/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ def create_model(cls, model_path, output_names=None, arch_path=None,
if tf_keras is None:
raise ValueError('You forced tf_keras but do not have tensorflow.keras')
keras = tf_keras
use_tf_keras = True
elif keras_keras is not None:
# Use old keras by default, as users may have gone out of their way to install it
keras = keras_keras
use_tf_keras = False
if tf_keras is not None:
logging.warning('Model publication will use standalone keras, yet you have tf.keras installed. '
'If you don\'t want this, use ``force_tf_keras=True``.')
'If you want your model to use tf.keras, use ``force_tf_keras=True``.')
elif tf_keras is not None:
keras = tf_keras
use_tf_keras = True
else:
raise ValueError('You do not have any version of keras installed.')

Expand Down Expand Up @@ -122,7 +125,7 @@ def capture_summary(x):
# Add keras as a dependency
keras_version = keras.__version__
_keras_version_tuple = tuple(int(i) for i in keras_version.rstrip("-tf").split("."))
if not keras_version.endswith("-tf"):
if not (keras_version.endswith("-tf") or use_tf_keras):
output.add_requirement('keras', keras_version)
output.add_requirement('h5py', 'detect')

Expand Down
6 changes: 6 additions & 0 deletions dlhub_sdk/models/servables/tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

try:
import keras
keras_installed = True
except ImportError:
from tensorflow import keras
keras_installed = False
from unittest import TestCase

from dlhub_sdk.models.servables.keras import KerasModel
Expand Down Expand Up @@ -79,6 +81,10 @@ def test_keras_multioutput(self):

output = metadata.to_dict()

# Make sure keras not used if it is not installed
if not keras_installed:
assert 'keras' not in metadata['dlhub']['dependencies']['python']

# Validate against schema
validate_against_dlhub_schema(output, 'servable')
finally:
Expand Down

0 comments on commit 5c3b562

Please sign in to comment.