Skip to content

Commit 20bc267

Browse files
authored
Fix export for subclass models with multiple inputs. (#19720)
The export now supports subclasses of `Model` for which the `call` method takes more than one input argument. Note that it is required for the model class to implement a `build` method with a signature that matches the `call` method.
1 parent 6e40533 commit 20bc267

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

keras/src/export/export_lib.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -621,18 +621,17 @@ def export_model(model, filepath):
621621
input_signature = [input_signature]
622622
export_archive.add_endpoint("serve", model.__call__, input_signature)
623623
else:
624-
save_spec = _get_save_spec(model)
625-
if not save_spec or not model._called:
624+
input_signature = _get_input_signature(model)
625+
if not input_signature or not model._called:
626626
raise ValueError(
627627
"The model provided has never called. "
628628
"It must be called at least once before export."
629629
)
630-
input_signature = [save_spec]
631630
export_archive.add_endpoint("serve", model.__call__, input_signature)
632631
export_archive.write_out(filepath)
633632

634633

635-
def _get_save_spec(model):
634+
def _get_input_signature(model):
636635
shapes_dict = getattr(model, "_build_shapes_dict", None)
637636
if not shapes_dict:
638637
return None
@@ -654,16 +653,7 @@ def make_tensor_spec(structure):
654653
f"Unsupported type {type(structure)} for {structure}"
655654
)
656655

657-
if len(shapes_dict) == 1:
658-
value = list(shapes_dict.values())[0]
659-
return make_tensor_spec(value)
660-
661-
specs = {}
662-
for key, value in shapes_dict.items():
663-
key = key.rstrip("_shape")
664-
specs[key] = make_tensor_spec(value)
665-
666-
return specs
656+
return [make_tensor_spec(value) for value in shapes_dict.values()]
667657

668658

669659
@keras_export("keras.layers.TFSMLayer")

keras/src/export/export_lib_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,31 @@ def call(self, inputs):
145145
)
146146
revived_model.serve(bigger_input)
147147

148+
def test_model_with_multiple_inputs(self):
149+
150+
class TwoInputsModel(models.Model):
151+
def call(self, x, y):
152+
return x + y
153+
154+
def build(self, y_shape, x_shape):
155+
self.built = True
156+
157+
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
158+
model = TwoInputsModel()
159+
ref_input_x = tf.random.normal((3, 10))
160+
ref_input_y = tf.random.normal((3, 10))
161+
ref_output = model(ref_input_x, ref_input_y)
162+
163+
export_lib.export_model(model, temp_filepath)
164+
revived_model = tf.saved_model.load(temp_filepath)
165+
self.assertAllClose(
166+
ref_output, revived_model.serve(ref_input_x, ref_input_y)
167+
)
168+
# Test with a different batch size
169+
revived_model.serve(
170+
tf.random.normal((6, 10)), tf.random.normal((6, 10))
171+
)
172+
148173
@parameterized.named_parameters(
149174
named_product(model_type=["sequential", "functional", "subclass"])
150175
)

0 commit comments

Comments
 (0)