
Commit 55a1146

SeeForTwo authored and tensorflower-gardener committed
rollforward of cl/441571702: Make RemoteCall decide if its outputs are host memory types using fulltype.

NEW: Removed DT_TO_FT and full_type_from_spec from structure.py and related tests from structure_test.py. Added fulltype_list_to_product to type_utils.py. multi_device_iterator_ops.py now uses fulltypes_for_flat_tensors and fulltype_list_to_product from python/framework/type_utils.py.

PiperOrigin-RevId: 463976466
1 parent f0f17a1 commit 55a1146

7 files changed: +142 −7 lines changed

tensorflow/core/kernels/function_ops.cc

Lines changed: 26 additions & 2 deletions

@@ -25,6 +25,8 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/common_runtime/memory_types.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/full_type.pb.h"
+#include "tensorflow/core/framework/full_type_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/graph/algorithm.h"
@@ -273,7 +275,8 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_DEFAULT),
                         SymbolicGradientOp);
 
-RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx)
+    : AsyncOpKernel(ctx), return_type_(ctx->def().experimental_type()) {
   OP_REQUIRES_OK(ctx,
                  ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
@@ -358,9 +361,30 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
     opts.args_alloc_attrs.push_back(arg_alloc_attrs);
   }
   opts.rets_alloc_attrs.reserve(output_dtypes_.size());
+  DCHECK(!return_type_.IsInitialized() ||
+         (return_type_.type_id() == TFT_UNSET) ||
+         (output_dtypes_.size() == return_type_.args_size()))
+      << "RemoteCall op has full type information for "
+      << return_type_.args_size() << " outputs but the number of outputs is "
+      << output_dtypes_.size();
   for (const auto& dtype : output_dtypes_) {
     AllocatorAttributes ret_alloc_attrs;
-    ret_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
+    bool on_host = DataTypeAlwaysOnHost(dtype);
+    if (return_type_.IsInitialized() && (return_type_.type_id() != TFT_UNSET)) {
+      DCHECK(return_type_.type_id() == TFT_PRODUCT)
+          << return_type_.DebugString();
+      FullTypeDef ftd = full_type::GetArgDefaultUnset(
+          return_type_, opts.rets_alloc_attrs.size());
+      if (full_type::IsHostMemoryType(ftd)) {
+        on_host = true;
+      }
+      VLOG(5) << "FulltypeDef for RemoteCall output="
+              << opts.rets_alloc_attrs.size()
+              << ", IsHostMemoryType=" << full_type::IsHostMemoryType(ftd)
+              << ":\n"
+              << ftd.DebugString();
+    }
+    ret_alloc_attrs.set_on_host(on_host);
     opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
   }
   auto* rets = new std::vector<Tensor>;
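The hunk above is the heart of the change: each RemoteCall output is pinned to host memory either because its dtype always lives on host, or because its full type reveals a host-memory payload (for example, a string inside a variant-encoded ragged tensor). As a rough Python sketch of that per-output decision using the full_type_pb2 protos; is_host_memory_type here is a simplified stand-in for full_type::IsHostMemoryType (an assumption for illustration, the real helper covers more cases):

# Rough sketch only (not the TensorFlow API): mirrors the per-output
# host-memory decision made in RemoteCallOp::ComputeAsync.
from tensorflow.core.framework import full_type_pb2


def is_host_memory_type(ft):
  # Simplified stand-in for full_type::IsHostMemoryType: treat any full
  # type tree containing TFT_STRING as host memory.
  if ft.type_id == full_type_pb2.TFT_STRING:
    return True
  return any(is_host_memory_type(arg) for arg in ft.args)


def rets_on_host(dtype_always_on_host, return_type):
  """One bool per output: True if the output must stay in host memory."""
  decisions = []
  for i, on_host in enumerate(dtype_always_on_host):
    if return_type.type_id == full_type_pb2.TFT_PRODUCT:
      # Analogue of full_type::GetArgDefaultUnset(return_type, i).
      ftd = (return_type.args[i] if i < len(return_type.args)
             else full_type_pb2.FullTypeDef())
      on_host = on_host or is_host_memory_type(ftd)
    decisions.append(on_host)
  return decisions


# A variant-encoded ragged tensor of strings is not host-only by dtype
# (DT_VARIANT), but its full type exposes the string payload:
ragged_str = full_type_pb2.FullTypeDef(
    type_id=full_type_pb2.TFT_RAGGED,
    args=[full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_STRING)])
product = full_type_pb2.FullTypeDef(
    type_id=full_type_pb2.TFT_PRODUCT, args=[ragged_str])
assert rets_on_host([False], product) == [True]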

tensorflow/core/kernels/function_ops.h

Lines changed: 5 additions & 0 deletions

@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_
 #define TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_
 
+#include "tensorflow/core/framework/full_type_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op_kernel.h"
 
@@ -70,6 +71,10 @@ class RemoteCallOp : public AsyncOpKernel {
   NameAttrList func_;
   DataTypeVector input_dtypes_;
   DataTypeVector output_dtypes_;
+  // Note that in the future if all RemoteCall ops have full type
+  // information, the kernel will not need access to the "Tout" Attr and
+  // return_type_ will replace output_dtypes_.
+  FullTypeDef return_type_;
 
   mutex mu_;
   typedef std::pair<string, FunctionLibraryRuntime*> FunctionTarget;

tensorflow/python/BUILD

Lines changed: 18 additions & 0 deletions

@@ -3879,6 +3879,24 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "factory_ops_test",
+    size = "small",
+    srcs = ["ops/factory_ops_test.py"],
+    tags = [
+        "no_gpu",  # TODO(b/213596871): a similar test times out (delete
+        # the "no_gpu" tag once this bug is fully resolved)
+    ],
+    deps = [
+        ":sparse_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:mirrored_strategy",
+        "//tensorflow/python/eager:def_function",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 tf_gen_op_wrapper_private_py(
     name = "decode_proto_ops_gen",
     deps = [
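With this target registered, the new test should be runnable through the usual workflow, presumably bazel test //tensorflow/python:factory_ops_test (hedged: the exact invocation and GPU configuration depend on the local build; note the no_gpu tag above while b/213596871 remains open).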

tensorflow/python/data/ops/multi_device_iterator_ops.py

Lines changed: 13 additions & 1 deletion

@@ -25,6 +25,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import type_spec
+from tensorflow.python.framework import type_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -89,11 +90,22 @@ def _next_func(string_handle):
         attributes={"experimental_ints_on_device": True},
         autograph=False)  # Pure graph code.
     def _remote_next_func(string_handle):
-      return functional_ops.remote_call(
+      return_values = functional_ops.remote_call(
           target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=structure.get_flat_tensor_types(self._element_spec),
          f=next_func_concrete)
+      # Add full type information to the graph so that the RemoteCall op
+      # can determine for each of its outputs whether or not they are ragged
+      # tensors (or other types that use variants) that contain strings
+      # (or other host memory types). Then RemoteCall can
+      # appropriately set AllocatorAttributes to control copies so
+      # strings/host memory types stay on CPU.
+      fulltype_list = type_utils.fulltypes_for_flat_tensors(self._element_spec)
+      fulltype = type_utils.fulltype_list_to_product(fulltype_list)
+      for return_value in return_values:
+        return_value.op.experimental_set_type(fulltype)
+      return return_values
 
     self._next_func = _remote_next_func.get_concrete_function()
     self._next_captured_args = self._next_func.captured_inputs
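The annotation pattern in this hunk is not specific to iterators. A hedged sketch (assuming a TF build where Operation.experimental_set_type is available, as it is in this commit) of attaching a TFT_PRODUCT full type to an op's outputs:

import tensorflow as tf
from tensorflow.core.framework import full_type_pb2

g = tf.Graph()
with g.as_default():
  # One string output, so the TFT_PRODUCT carries exactly one arg.
  out = tf.constant(['a', 'b'])
  product = full_type_pb2.FullTypeDef(
      type_id=full_type_pb2.TFT_PRODUCT,
      args=[
          full_type_pb2.FullTypeDef(
              type_id=full_type_pb2.TFT_TENSOR,
              args=[full_type_pb2.FullTypeDef(
                  type_id=full_type_pb2.TFT_STRING)])
      ])
  out.op.experimental_set_type(product)
  # The annotation should now travel with the NodeDef that kernels see.
  print(out.op.node_def.experimental_type)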

tensorflow/python/framework/type_utils.py

Lines changed: 6 additions & 1 deletion

@@ -162,5 +162,10 @@ def fulltypes_for_flat_tensors(element_spec):
   specs = _specs_for_flat_tensors(element_spec)
   full_types_lists = [_translate_to_fulltype_for_flat_tensors(s) for s in specs]
   rval = nest.flatten(full_types_lists)  # flattens list-of-list to flat list.
-  assert len(rval) == len(element_spec._flat_tensor_specs)  # pylint: disable=protected-access
   return rval
+
+
+def fulltype_list_to_product(fulltype_list):
+  """Convert a list of FullTypeDef into a single TFT_PRODUCT FullTypeDef."""
+  return full_type_pb2.FullTypeDef(
+      type_id=full_type_pb2.TFT_PRODUCT, args=fulltype_list)
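For reference, a small usage sketch of the new helper (hedged; it builds the element full types by hand rather than via fulltypes_for_flat_tensors):

from tensorflow.core.framework import full_type_pb2
from tensorflow.python.framework import type_utils

int64_tensor = full_type_pb2.FullTypeDef(
    type_id=full_type_pb2.TFT_TENSOR,
    args=[full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_INT64)])
ragged_str = full_type_pb2.FullTypeDef(
    type_id=full_type_pb2.TFT_RAGGED,
    args=[full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_STRING)])

product = type_utils.fulltype_list_to_product([int64_tensor, ragged_str])
assert product.type_id == full_type_pb2.TFT_PRODUCT
assert len(product.args) == 2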
tensorflow/python/ops/factory_ops_test.py

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that sparse tensors work with GPU, such as placement of int and string.
+
+Test using sparse tensors with distributed dataset. Since GPU does
+not support strings, sparse tensors containing strings should always be placed
+on CPU.
+"""
+
+from absl.testing import parameterized
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+def sparse_int64():
+  return sparse_tensor.SparseTensor(
+      indices=[[0, 0], [1, 1], [2, 2], [3, 3], [4, 0], [5, 1], [6, 2], [7, 3]],
+      values=constant_op.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype=dtypes.int64),
+      dense_shape=[8, 4])
+
+
+def sparse_str():
+  return sparse_tensor.SparseTensor(
+      indices=[[0, 0], [1, 1], [2, 2], [3, 3], [4, 0], [5, 1], [6, 2], [7, 3]],
+      values=constant_op.constant(['1', '2', '3', '4', '5', '6', '7', '8']),
+      dense_shape=[8, 4])
+
+
+class FactoryOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  @parameterized.parameters(
+      (sparse_int64,),
+      (sparse_str,),
+  )
+  def testSparseWithDistributedDataset(self, sparse_factory):
+
+    @def_function.function
+    def distributed_dataset_producer(t):
+      strategy = mirrored_strategy.MirroredStrategy(['GPU:0', 'GPU:1'])
+      sparse_ds = dataset_ops.Dataset.from_tensor_slices(t).batch(2)
+      dist_dataset = strategy.experimental_distribute_dataset(sparse_ds)
+      ds = iter(dist_dataset)
+      return strategy.experimental_local_results(next(ds))[0]
+
+    t = sparse_factory()
+
+    result = distributed_dataset_producer(t)
+    self.assertAllEqual(
+        self.evaluate(sparse_ops.sparse_tensor_to_dense(t)[0]),
+        self.evaluate(sparse_ops.sparse_tensor_to_dense(result)[0]))
+
+
+if __name__ == '__main__':
+  test.main()

tensorflow/python/ops/ragged/ragged_factory_ops_test.py

Lines changed: 1 addition & 3 deletions

@@ -128,15 +128,13 @@ def distributed_dataset_producer(t):
       return strategy.experimental_local_results(next(ds))[0]
 
     t = ragged_factory()
-    if t.dtype == dtypes.string:
-      self.skipTest('b/194439197: fix ragged tensor of string')
 
     result = distributed_dataset_producer(t)
     self.assertAllEqual(self.evaluate(t[0]), self.evaluate(result[0]))
 
   @parameterized.parameters(
       (dense_str,),
-      # (ragged_str,),  # TODO(b/194439197) fix ragged tensor of string
+      (ragged_str,),
   )
   def testIntStringWithDistributedDataset(self, string_factory):
 