Skip to content

Commit

Permalink
Flake8 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Jul 25, 2019
1 parent 02dbbd1 commit a6cb035
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
3 changes: 2 additions & 1 deletion dlhub_sdk/models/servables/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_keras_version_tuple = tuple(int(i) for i in keras_version.split("."))
_summary_limit = 10000


def _detect_backend(output):
"""Add the backend
Expand Down Expand Up @@ -90,7 +91,7 @@ def capture_summary(x):

model.summary(print_fn=capture_summary)
output.summary = (output.summary[:_summary_limit] + '<<TRUNCATED>>') \
if len(output.summary) > _summary_limit else output.summary
if len(output.summary) > _summary_limit else output.summary

output['servable']['model_summary'] = output.summary
output['servable']['model_type'] = 'Deep NN'
Expand Down
12 changes: 8 additions & 4 deletions dlhub_sdk/models/servables/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class TorchModel(BasePythonServableModel):
Assumes that the model has been saved to a pt or a pth file"""

@classmethod
def create_model(cls, model_path, input_shape, output_shape, input_type='float', output_type='float'):
def create_model(cls, model_path, input_shape, output_shape,
input_type='float', output_type='float'):
"""Initialize a PyTorch model.
Args:
Expand All @@ -33,8 +34,10 @@ def create_model(cls, model_path, input_shape, output_shape, input_type='float',
raise ValueError('File type for architecture not recognized')

# Get the inputs of the model
output['servable']['methods']['run']['input'] = output.format_layer_spec(input_shape, input_type)
output['servable']['methods']['run']['output'] = output.format_layer_spec(output_shape, output_type)
output['servable']['methods']['run']['input'] =\
output.format_layer_spec(input_shape, input_type)
output['servable']['methods']['run']['output'] =\
output.format_layer_spec(output_shape, output_type)

output['servable']['model_summary'] = str(model)
output['servable']['model_type'] = 'Deep NN'
Expand All @@ -54,7 +57,8 @@ def format_layer_spec(self, layers, datatypes):
(dict) Description of the inputs / outputs
"""
if isinstance(layers, tuple):
return compose_argument_block("ndarray", "Tensor", shape=list(layers), item_type=datatypes)
return compose_argument_block("ndarray", "Tensor",
shape=list(layers), item_type=datatypes)
else:
if isinstance(datatypes, str):
datatypes = [datatypes] * len(layers)
Expand Down
13 changes: 8 additions & 5 deletions dlhub_sdk/models/servables/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,19 @@ def test_multinetwork(self):
# Test the output shapes
self.assertEqual(metadata['servable']['methods']['run']['input'],
{'type': 'tuple', 'description': 'Tuple of tensors',
'element_types': [{'type': 'ndarray', 'description': 'Tensor', 'shape': [None, 4],
'element_types': [{'type': 'ndarray',
'description': 'Tensor', 'shape': [None, 4],
'item_type': {'type': 'float'}},
{'type': 'ndarray', 'description': 'Tensor', 'shape': [None, 4],
{'type': 'ndarray',
'description': 'Tensor', 'shape': [None, 4],
'item_type': {'type': 'float'}}]})
self.assertEqual(metadata['servable']['methods']['run']['output'],
{'type': 'tuple', 'description': 'Tuple of tensors',
'element_types': [{'type': 'ndarray', 'description': 'Tensor', 'shape': [None, 1],
'element_types': [{'type': 'ndarray',
'description': 'Tensor', 'shape': [None, 1],
'item_type': {'type': 'float'}},
{'type': 'ndarray', 'description': 'Tensor', 'shape': [None, 1],
{'type': 'ndarray',
'description': 'Tensor', 'shape': [None, 1],
'item_type': {'type': 'float'}}]})

validate_against_dlhub_schema(metadata.to_dict(), 'servable')

0 comments on commit a6cb035

Please sign in to comment.