Skip to content

Commit

Permalink
Merge pull request #108 from DLHub-Argonne/allow_custom_objects
Browse files Browse the repository at this point in the history
Allow custom objects
  • Loading branch information
WardLT committed Jun 10, 2021
2 parents d920615 + de1c8ef commit 0096fe5
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions dlhub_sdk/models/servables/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,32 +125,30 @@ def format_layer_spec(self, layers):
return compose_argument_block("tuple", "Tuple of tensors",
element_types=[self.format_layer_spec(i) for i in layers])

def add_custom_object(self, name, custom_layer):
def add_custom_object(self, name, custom_object):
"""Add a custom layer to the model specification
See `Keras FAQs
<https://keras.io/getting-started/faq/#handling-custom-layers-or-other-custom-objects-in-saved-models>`
for details.
Args:
name (string): Name of the layer
custom_layer (class): Class of the custom layer
name (string): Name of the custom object
custom_object (class): Class of the custom object
Return:
self
"""

# Get the class name for the custom layer
layer_name = custom_layer.__name__
if not issubclass(custom_layer, keras.layers.Layer):
raise ValueError("Custom layer ({}) must be a subclass of Layer".format(layer_name))
module = custom_layer.__module__
# get the class name for the custom object
object_name = custom_object.__name__
module = custom_object.__module__

# Add the layer to the model definition
if 'options' not in self._output['servable']:
self['servable']['options'] = {}
if 'custom_objects' not in self['servable']['options']:
self['servable']['options']['custom_objects'] = {}
self['servable']['options']['custom_objects'][name] = '{}.{}'.format(module, layer_name)
self['servable']['options']['custom_objects'][name] = '{}.{}'.format(module, object_name)

def _get_handler(self):
return "keras.KerasServable"
Expand Down

0 comments on commit 0096fe5

Please sign in to comment.