Skip to content

Commit

Permalink
Merge pull request #115 from DLHub-Argonne/fix_tests
Browse files Browse the repository at this point in the history
Various fixes to the tests
  • Loading branch information
WardLT committed Sep 16, 2021
2 parents 4be25d1 + 774175b commit cfaa62e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 23 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ jobs:
pip install keras==${{matrix.cfg.keras-version}} "tensorflow<2" "h5py<3"
else
# Otherwise, use TF2
pip uninstall keras
pip install "tensorflow>2"
fi
Expand Down
2 changes: 1 addition & 1 deletion dlhub_sdk/models/servables/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _create_v1(self, export_directory: str):
'output_nodes': output_nodes})

# Check if there is a run method
if 'run' not in self['servable']['methods']:
if 'run' not in self.servable.methods:
raise ValueError('There is no default servable for this model.\n'
' Make sure to use '
'tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY '
Expand Down
40 changes: 19 additions & 21 deletions dlhub_sdk/models/servables/tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,23 @@
import pytest
import os



try:
import keras
keras_installed = True


keras_installed = True
except ImportError:
keras_installed = False

try:
from tensorflow import keras
keras_installed = True

keras_installed = True
except ImportError:
keras_installed = False

no_keras = pytest.mark.skipif(keras_installed == False, reason='keras not installed')


keras_installed = False

from dlhub_sdk.models.servables.keras import KerasModel
from dlhub_sdk.utils.schemas import validate_against_dlhub_schema

print(keras_installed)

no_keras = pytest.mark.skipif(not keras_installed, reason='keras not installed')
_year = str(datetime.now().year)


Expand Down Expand Up @@ -57,6 +48,7 @@ def test_keras_single_input(tmpdir):
output = metadata.to_dict()
validate_against_dlhub_schema(output, 'servable')


@no_keras
def test_keras_multioutput(tmpdir):
# Make a Keras model
Expand Down Expand Up @@ -89,6 +81,7 @@ def test_keras_multioutput(tmpdir):
# Validate against schema
validate_against_dlhub_schema(output, 'servable')


@no_keras
def test_custom_layers(tmpdir):
"""Test adding custom layers to the definition"""
Expand All @@ -111,6 +104,7 @@ def test_custom_layers(tmpdir):
# Validate it against DLHub schema
validate_against_dlhub_schema(metadata.to_dict(), 'servable')


@no_keras
def test_multi_file(tmpdir):
"""Test adding the architecture in a different file """
Expand All @@ -121,12 +115,6 @@ def test_multi_file(tmpdir):
# Save it
model_path = os.path.join(tmpdir, 'model.hd5')
model.save(model_path, include_optimizer=False)
model_json = os.path.join(tmpdir, 'model.json')
with open(model_json, 'w') as fp:
print(model.to_json(), file=fp)
model_yaml = os.path.join(tmpdir, 'model.yml')
with open(model_yaml, 'w') as fp:
print(model.to_yaml(), file=fp)
weights_path = os.path.join(tmpdir, 'weights.hd5')
model.save_weights(weights_path)

Expand All @@ -136,6 +124,16 @@ def test_multi_file(tmpdir):
# Make sure both files are included in the files list
assert metadata.dlhub.files == {'arch': model_path, 'model': weights_path}

# Try it with the JSON and YAML versions
# Try it with the JSON
model_json = os.path.join(tmpdir, 'model.json')
with open(model_json, 'w') as fp:
print(model.to_json(), file=fp)
KerasModel.create_model(weights_path, ['y'], arch_path=model_json)
KerasModel.create_model(weights_path, ['y'], arch_path=model_yaml)

# Try it with YAML in earlier versions
keras_major_version = tuple(int(x) for x in keras.__version__.split(".")[:2])
if keras_major_version < (2, 6):
model_yaml = os.path.join(tmpdir, 'model.yml')
with open(model_yaml, 'w') as fp:
print(model.to_yaml(), file=fp)
KerasModel.create_model(weights_path, ['y'], arch_path=model_yaml)
2 changes: 1 addition & 1 deletion dlhub_sdk/models/servables/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_torch_single_input(tmpdir):
output = metadata.to_dict()
assert output["dlhub"] == {
"version": __version__, "domains": [],
"visible_to": [],
"visible_to": ['public'],
'type': 'servable',
"name": "mlp", "files": {"model": model_path},
"dependencies": {"python": {
Expand Down

0 comments on commit cfaa62e

Please sign in to comment.