diff --git a/tensorflow/core/protobuf/struct.proto b/tensorflow/core/protobuf/struct.proto index ee0f089f2a3729..c99eab5dd88e4a 100644 --- a/tensorflow/core/protobuf/struct.proto +++ b/tensorflow/core/protobuf/struct.proto @@ -136,6 +136,7 @@ message TypeSpecProto { PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py VARIABLE_SPEC = 9; // tf.VariableSpec ROW_PARTITION_SPEC = 10; // RowPartitionSpec from ragged/row_partition.py + NDARRAY_SPEC = 11; // TF Numpy NDarray spec } TypeSpecClass type_spec_class = 1; diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 45ee73de51c50d..4507118c17cb6a 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -587,6 +587,7 @@ py_strict_library( "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:optional_ops", "//tensorflow/python/distribute:values", + "//tensorflow/python/ops/numpy_ops:numpy", "//tensorflow/python/ops/ragged:ragged_tensor", "//tensorflow/python/ops/ragged:row_partition", "@six_archive//:six", diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 320182385f895d..54124df6eba22b 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -50,9 +50,11 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import numpy_ops as tnp from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.saved_model import load @@ -1810,6 +1812,34 @@ def lookup(key): self.assertEqual(self.evaluate(imported.lookup("foo")), 15) self.assertEqual(self.evaluate(imported.lookup("idk")), -1) + def test_saving_ndarray_specs(self, cycles): + class NdarrayModule(module.Module): + + @def_function.function + def plain(self, x): + return tnp.add(x, 1) + + @def_function.function(input_signature=[ + np_arrays.NdarraySpec(tensor_spec.TensorSpec([], dtypes.float32))]) + def with_signature(self, x): + return tnp.add(x, 1) + + m = NdarrayModule() + c = tnp.asarray(3.0, tnp.float32) + output_plain, output_with_signature = m.plain(c), m.with_signature(c) + + loaded_m = cycle(m, cycles) + + load_output_plain, load_output_with_signature = ( + loaded_m.plain(c), loaded_m.with_signature(c)) + + self.assertIsInstance(output_plain, tnp.ndarray) + self.assertIsInstance(load_output_plain, tnp.ndarray) + self.assertIsInstance(output_with_signature, tnp.ndarray) + self.assertIsInstance(load_output_with_signature, tnp.ndarray) + self.assertAllClose(output_plain, load_output_plain) + self.assertAllClose(output_with_signature, load_output_with_signature) + class SingleCycleTests(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/saved_model/nested_structure_coder.py b/tensorflow/python/saved_model/nested_structure_coder.py index 9c71b8536752f1..a7e5548ee06e72 100644 --- a/tensorflow/python/saved_model/nested_structure_coder.py +++ b/tensorflow/python/saved_model/nested_structure_coder.py @@ -48,6 +48,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import row_partition from tensorflow.python.util import compat @@ -516,6 +517,8 @@ class _TypeSpecCodec(object): resource_variable_ops.VariableSpec, struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC: row_partition.RowPartitionSpec, + struct_pb2.TypeSpecProto.NDARRAY_SPEC: + np_arrays.NdarraySpec, } # Mapping from type (TypeSpec subclass) to enum value. diff --git a/tensorflow/python/saved_model/nested_structure_coder_test.py b/tensorflow/python/saved_model/nested_structure_coder_test.py index 9951ea64a4979f..fb074f76eb0145 100644 --- a/tensorflow/python/saved_model/nested_structure_coder_test.py +++ b/tensorflow/python/saved_model/nested_structure_coder_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util +from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test from tensorflow.python.saved_model import nested_structure_coder @@ -331,6 +332,14 @@ def testEncodeDataSetSpec(self): decoded = self._coder.decode_proto(encoded) self.assertEqual(structure, decoded) + def testEncodeDecodeNdarraySpec(self): + structure = [np_arrays.NdarraySpec( + tensor_spec.TensorSpec([4, 2], dtypes.float32))] + self.assertTrue(self._coder.can_encode(structure)) + encoded = self._coder.encode_structure(structure) + decoded = self._coder.decode_proto(encoded) + self.assertEqual(structure, decoded) + def testNotEncodable(self): class NotEncodable(object):