Skip to content

Commit

Permalink
Allow NdarraySpec to be written in saved model.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 326121293
Change-Id: I7a4351a9ab3e0381ff5616f67d0e61880f3bb649
  • Loading branch information
Akshay Modi authored and tensorflower-gardener committed Aug 11, 2020
1 parent 6974852 commit b297140
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/protobuf/struct.proto
Expand Up @@ -136,6 +136,7 @@ message TypeSpecProto {
PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py
VARIABLE_SPEC = 9; // tf.VariableSpec
ROW_PARTITION_SPEC = 10; // RowPartitionSpec from ragged/row_partition.py
NDARRAY_SPEC = 11; // TF Numpy NDarray spec
}
TypeSpecClass type_spec_class = 1;

Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/saved_model/BUILD
Expand Up @@ -587,6 +587,7 @@ py_strict_library(
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:optional_ops",
"//tensorflow/python/distribute:values",
"//tensorflow/python/ops/numpy_ops:numpy",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:row_partition",
"@six_archive//:six",
Expand Down
30 changes: 30 additions & 0 deletions tensorflow/python/saved_model/load_test.py
Expand Up @@ -50,9 +50,11 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numpy_ops as tnp
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import load
Expand Down Expand Up @@ -1810,6 +1812,34 @@ def lookup(key):
self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
self.assertEqual(self.evaluate(imported.lookup("idk")), -1)

def test_saving_ndarray_specs(self, cycles):
class NdarrayModule(module.Module):

@def_function.function
def plain(self, x):
return tnp.add(x, 1)

@def_function.function(input_signature=[
np_arrays.NdarraySpec(tensor_spec.TensorSpec([], dtypes.float32))])
def with_signature(self, x):
return tnp.add(x, 1)

m = NdarrayModule()
c = tnp.asarray(3.0, tnp.float32)
output_plain, output_with_signature = m.plain(c), m.with_signature(c)

loaded_m = cycle(m, cycles)

load_output_plain, load_output_with_signature = (
loaded_m.plain(c), loaded_m.with_signature(c))

self.assertIsInstance(output_plain, tnp.ndarray)
self.assertIsInstance(load_output_plain, tnp.ndarray)
self.assertIsInstance(output_with_signature, tnp.ndarray)
self.assertIsInstance(load_output_with_signature, tnp.ndarray)
self.assertAllClose(output_plain, load_output_plain)
self.assertAllClose(output_with_signature, load_output_with_signature)


class SingleCycleTests(test.TestCase, parameterized.TestCase):

Expand Down
3 changes: 3 additions & 0 deletions tensorflow/python/saved_model/nested_structure_coder.py
Expand Up @@ -48,6 +48,7 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import row_partition
from tensorflow.python.util import compat
Expand Down Expand Up @@ -516,6 +517,8 @@ class _TypeSpecCodec(object):
resource_variable_ops.VariableSpec,
struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
row_partition.RowPartitionSpec,
struct_pb2.TypeSpecProto.NDARRAY_SPEC:
np_arrays.NdarraySpec,
}

# Mapping from type (TypeSpec subclass) to enum value.
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/python/saved_model/nested_structure_coder_test.py
Expand Up @@ -28,6 +28,7 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model import nested_structure_coder
Expand Down Expand Up @@ -331,6 +332,14 @@ def testEncodeDataSetSpec(self):
decoded = self._coder.decode_proto(encoded)
self.assertEqual(structure, decoded)

def testEncodeDecodeNdarraySpec(self):
structure = [np_arrays.NdarraySpec(
tensor_spec.TensorSpec([4, 2], dtypes.float32))]
self.assertTrue(self._coder.can_encode(structure))
encoded = self._coder.encode_structure(structure)
decoded = self._coder.decode_proto(encoded)
self.assertEqual(structure, decoded)

def testNotEncodable(self):

class NotEncodable(object):
Expand Down

0 comments on commit b297140

Please sign in to comment.