diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ef1ebcb6dcf..ce6850eb9da 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1237,6 +1237,7 @@ tf_gen_op_libs( "set_ops", "script_ops", "sendrecv_ops", + "slice_sendrecv_ops", "sparse_ops", "spectral_ops", "state_ops", @@ -1497,6 +1498,7 @@ cc_library( ":sdca_ops_op_lib", ":sendrecv_ops_op_lib", ":set_ops_op_lib", + ":slice_sendrecv_ops_op_lib", ":sparse_ops_op_lib", ":star_run_graph_op_op_lib", ":summary_ops_op_lib", diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index 255c0326e02..3c2b20379c8 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -80,6 +80,8 @@ class Rendezvous : public core::RefCounted { friend class SendOp; friend class RecvOp; friend class FuseRecvOp; + friend class SliceSendOp; + friend class SliceRecvOp; friend class RefSendOp; friend class RefRecvOp; string buf_; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 8ba5d345837..d9709d39f3f 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -69,11 +69,13 @@ const std::unordered_map& Node::kNodeClassTable = {"_Send", NC_SEND}, {"_HostSend", NC_HOST_SEND}, {"_RefSend", NC_REF_SEND}, + {"_SliceSend", NC_SLICE_SEND}, {"_Recv", NC_RECV}, {"_HostRecv", NC_HOST_RECV}, {"_RefRecv", NC_REF_RECV}, {"_FuseRecv", NC_FUSE_RECV}, {"_HostFuseRecv", NC_HOST_FUSE_RECV}, + {"_SliceRecv", NC_SLICE_RECV}, {"Const", NC_CONSTANT}, {"HostConst", NC_CONSTANT}, {"Variable", NC_VARIABLE}, diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 0e7e032c9a5..0baf8f257a9 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -219,12 +219,16 @@ class Node { bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; } bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND || - class_ == NC_REF_SEND; } + class_ == NC_REF_SEND || + class_ == NC_SLICE_SEND; } + bool IsSliceSend() const { return class_ == NC_SLICE_SEND; } bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV || - class_ == NC_REF_RECV; } + class_ == NC_REF_RECV || + class_ == NC_SLICE_RECV; } bool IsFuseRecv() const { return class_ == NC_FUSE_RECV || class_ == NC_HOST_FUSE_RECV; } + bool IsSliceRecv() const {return class_ == NC_SLICE_RECV; } bool IsConstant() const { return class_ == NC_CONSTANT; } bool IsStage() const { return class_ == NC_TENSOR_BUFFER_PUT; } bool IsUnstage() const { return class_ == NC_TENSOR_BUFFER_TAKE; } @@ -334,11 +338,13 @@ class Node { NC_SEND, NC_HOST_SEND, NC_REF_SEND, + NC_SLICE_SEND, NC_RECV, NC_HOST_RECV, NC_REF_RECV, NC_FUSE_RECV, NC_HOST_FUSE_RECV, + NC_SLICE_RECV, NC_CONSTANT, NC_VARIABLE, NC_KV_VAR_HANDLE, @@ -844,7 +850,9 @@ inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); } inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); } inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); } inline bool IsSend(const Node* node) { return node->IsSend(); } +inline bool IsSliceSend(const Node* node) { return node->IsSliceSend(); } inline bool IsRecv(const Node* node) { return node->IsRecv(); } +inline bool IsSliceRecv(const Node* node) { return node->IsSliceRecv(); } inline bool IsFuseRecv(const Node* node) { return node->IsFuseRecv(); } inline bool IsHostSend(const Node* node) { return node->IsHostSend(); } inline bool IsHostRecv(const Node* node) { return 
node->IsHostRecv(); } diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index a3a521fa123..1201623ffcd 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -454,7 +454,7 @@ bool IsReciprocalGrad(const NodeDef& node) { } bool IsRecv(const NodeDef& node) { - return node.op() == "_Recv" || node.op() == "_HostRecv"; + return node.op() == "_Recv" || node.op() == "_HostRecv" || IsSliceRecv(node); } bool IsFuseRecv(const NodeDef& node) { @@ -502,7 +502,7 @@ bool IsSelect(const NodeDef& node) { return node.op() == "Select"; } bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; } bool IsSend(const NodeDef& node) { - return node.op() == "_Send" || node.op() == "_HostSend"; + return node.op() == "_Send" || node.op() == "_HostSend" || IsSliceSend(node); } bool IsShape(const NodeDef& node) { return node.op() == "Shape"; } @@ -517,6 +517,10 @@ bool IsSize(const NodeDef& node) { return node.op() == "Size"; } bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; } +bool IsSliceRecv(const NodeDef& node) { return node.op() == "_SliceRecv"; } + +bool IsSliceSend(const NodeDef& node) { return node.op() == "_SliceSend"; } + bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; } bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 19699ccb933..737581fd412 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -167,6 +167,8 @@ bool IsShuffle(const NodeDef& node); bool IsSigmoidGrad(const NodeDef& node); bool IsSize(const NodeDef& node); bool IsSlice(const NodeDef& node); +bool IsSliceRecv(const NodeDef& node); +bool IsSliceSend(const NodeDef& node); bool IsSnapshot(const NodeDef& node); bool IsSoftmax(const NodeDef& node); bool IsSoftplusGrad(const NodeDef& node); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0c08c30c30a..36721527cc2 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5423,8 +5423,9 @@ cc_library( name = "required", deps = [ ":no_op", - ":sendrecv_ops", ":fuserecv_ops", + ":sendrecv_ops", + ":slice_sendrecv_ops", ], ) @@ -5445,6 +5446,12 @@ tf_kernel_library( deps = REQUIRED_DEPS, ) +tf_kernel_library( + name = "slice_sendrecv_ops", + prefix = "slice_sendrecv_ops", + deps = REQUIRED_DEPS, +) + tf_kernel_library( name = "group_embedding_ops", hdrs = ["group_embedding/group_embedding_lookup_sparse_forward_base_ops.h"], @@ -5509,6 +5516,24 @@ tf_cc_test( ], ) +tf_cc_test( + name = "slice_sendrecv_ops_test", + srcs = ["slice_sendrecv_ops_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":control_flow_ops", + ":cwise_op", + ":logging_ops", + ":ops_testutil", + ":ops_util", + ":slice_sendrecv_ops", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "fuserecv_ops", prefix = "fuserecv_ops", diff --git a/tensorflow/core/kernels/slice_sendrecv_ops.cc b/tensorflow/core/kernels/slice_sendrecv_ops.cc new file mode 100644 index 00000000000..f09f314ae10 --- /dev/null +++ b/tensorflow/core/kernels/slice_sendrecv_ops.cc @@ -0,0 +1,562 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/slice_sendrecv_ops.h" + +namespace tensorflow { + +//------------------------------------------------------------------------------ +// Utils. +static string GetSliceRendezvousKeyPrefix(const string& send_device, + const string& recv_device, + const uint64 send_device_incarnation, + const string& tensor_name) { + return strings::StrCat(send_device, ";", + strings::FpToString(send_device_incarnation), ";", + recv_device, ";", tensor_name); +} + +static void GetSliceRendezvousKey(const string& key_prefix, + const string& tensor_name_suffix, + const FrameAndIter& frame_iter, string* key) { + key->clear(); + strings::StrAppend(key, key_prefix, tensor_name_suffix, ";", + frame_iter.frame_id, ":", frame_iter.iter_id); +} + +static FrameAndIter GetFrameAndIter(OpKernelContext* ctx, + bool hostmem_sendrecv) { + if (hostmem_sendrecv && ctx->call_frame() != nullptr) { + // Host memory send/recv pairs are added by + // common_runtime/memory_types.cc. When the pair of nodes are + // added inside a function, we need to use the function call frame + // to formulate the unique rendezvous key. + return FrameAndIter(reinterpret_cast(ctx->call_frame()), 0); + } else { + return ctx->frame_iter(); + } +} + +//------------------------------------------------------------------------------ +// Functions of SliceSendOp. + +SliceSendOp::SliceSendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string send_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); + string recv_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); + uint64 send_device_incarnation; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("send_device_incarnation", + reinterpret_cast(&send_device_incarnation))); + string tensor_name; + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + key_prefix_ = \ + GetSliceRendezvousKeyPrefix(send_device, recv_device, + send_device_incarnation, tensor_name); + if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { + hostmem_sendrecv_ = false; + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("slice_size", &slice_size_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); +} + +void SliceSendOp::Compute(OpKernelContext* ctx) { + OP_REQUIRES( + ctx, ctx->rendezvous() != nullptr, + errors::Internal("Op kernel context needs to provide a rendezvous.")); + + const Tensor& input_t = ctx->input(0); + FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_); + + // send total_bytes. + OP_REQUIRES_OK(ctx, SendTotalBytes(ctx, frame_iter, input_t)); + // if input is dead, only send total_bytes dead tensor. + if (ctx->is_input_dead()) { + return; + } + + // if total bytes is smaller than slice size, send directly. 
+ if (input_t.TotalBytes() <= slice_size_) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->input_alloc_attr(0); + + Rendezvous::ParsedKey parsed_key; + GetSliceRendezvousKey(key_prefix_, "_transfer_data", frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceSend " << parsed_key.buf_; + OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed_key, args, input_t, + ctx->is_input_dead())); + return; + } + + // send shape. + OP_REQUIRES_OK(ctx, SendShape(ctx, frame_iter, input_t)); + + // send data. + if (dtype_ == DT_STRING) { + OP_REQUIRES_OK(ctx, SendString(ctx, frame_iter, input_t)); + } else { + OP_REQUIRES_OK(ctx, SendBasicType(ctx, frame_iter, input_t)); + } +} + +Status SliceSendOp::SendTotalBytes(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const Tensor& input_t) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + + Rendezvous::ParsedKey parsed_key; + Tensor total_bytes_t; + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, TensorShape({}), + &total_bytes_t)); + total_bytes_t.scalar()() = input_t.TotalBytes(); + GetSliceRendezvousKey(key_prefix_, "_slice_transfer_totalbytes", frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + return ctx->rendezvous()->Send(parsed_key, args, total_bytes_t, + ctx->is_input_dead()); +} + +Status SliceSendOp::SendShape(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const Tensor& input_t) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + Rendezvous::ParsedKey parsed_key; + + Tensor shape_t; + TensorShape shape = input_t.shape(); + const int rank = shape.dims(); + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, TensorShape({rank}), + &shape_t)); + auto shape_vec = shape_t.vec(); + for (int i = 0; i < rank; i++) { + shape_vec(i) = shape.dim_size(i); + } + GetSliceRendezvousKey(key_prefix_, "_slice_transfer_shape", frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + return ctx->rendezvous()->Send(parsed_key, args, shape_t, + ctx->is_input_dead()); +} + +Status SliceSendOp::SendString(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const Tensor& input_t) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + Rendezvous::ParsedKey parsed_key; + + // send elements size. + Tensor elements_size_t; + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, input_t.shape(), + &elements_size_t)); + int64 num_elements = input_t.NumElements(); + auto input_flat = input_t.flat(); + auto elements_size_flat = elements_size_t.flat(); + for (int64 i = 0; i < num_elements; i++) { + elements_size_flat(i) = input_flat(i).size(); + } + GetSliceRendezvousKey(key_prefix_, "_slice_transfer_elements_size", + frame_iter, &parsed_key.buf_); + VLOG(2) << "SliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, elements_size_t, + ctx->is_input_dead())); + + // send data. 
+ args.alloc_attrs = ctx->input_alloc_attr(0); + Tensor data_t; + for (int64 i = 0; i < num_elements; i++) { + const std::string& elem = input_flat(i); + if (elem.size() <= slice_size_) { + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_STRING, TensorShape({}), + &data_t)); + data_t.scalar()() = elem; + std::string tensor_name_suffix = \ + strings::StrCat("_slice_transfer_data_", std::to_string(i)); + GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, + ctx->is_input_dead())); + } else { + TF_RETURN_IF_ERROR(SendStringSlice(ctx, frame_iter, elem, i)); + } + } + + return Status::OK(); +} + +Status SliceSendOp::SendStringSlice(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const std::string& elem, int64 index) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->input_alloc_attr(0); + Rendezvous::ParsedKey parsed_key; + + int64 slice_num = (elem.size() + slice_size_ - 1) / slice_size_; + Tensor data_t; + for (int64 i = 0; i < slice_num; i++) { + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_STRING, TensorShape({}), &data_t)); + size_t start = i * slice_size_; + size_t copy_size = slice_size_; + if (start > elem.size() - slice_size_) { + copy_size = elem.size() - start; + } + data_t.scalar()() = elem.substr(start, copy_size); + std::string tensor_name_suffix = \ + strings::StrCat("_slice_transfer_data_", std::to_string(index), "_", + std::to_string(i)); + GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, + ctx->is_input_dead())); + } + + return Status::OK(); +} + +Status SliceSendOp::SendBasicType(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const Tensor& input_t) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->input_alloc_attr(0); + Rendezvous::ParsedKey parsed_key; + + // send data. + Tensor data_t; + int64 bytes_num = input_t.TotalBytes(); + int64 slice_num = (bytes_num + slice_size_ - 1) / slice_size_; + unsigned char* input_base = reinterpret_cast(input_t.data()); + for (int64 i = 0; i < slice_num; i++) { + int64 start = i * slice_size_; + int64 copy_size = slice_size_; + if (start > bytes_num - slice_size_) { + copy_size = bytes_num - start; + } + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT8, TensorShape({copy_size}), + &data_t)); + auto data_base = data_t.data(); + std::memcpy(data_base, input_base+start, copy_size); + std::string tensor_name_suffix = \ + strings::StrCat("_slice_transfer_data_", std::to_string(i)); + GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, + ctx->is_input_dead())); + } + + return Status::OK(); +} + +REGISTER_KERNEL_BUILDER(Name("_SliceSend").Device(DEVICE_CPU), SliceSendOp); +REGISTER_KERNEL_BUILDER(Name("_SliceSend").Device(DEVICE_DEFAULT), SliceSendOp); + +//------------------------------------------------------------------------------ +// Functions of SliceRecvOp. 
+ +SliceRecvOp::SliceRecvOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string send_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); + string recv_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); + uint64 send_device_incarnation; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("send_device_incarnation", + reinterpret_cast(&send_device_incarnation))); + string tensor_name; + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + key_prefix_ = \ + GetSliceRendezvousKeyPrefix(send_device, recv_device, + send_device_incarnation, tensor_name); + if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { + hostmem_sendrecv_ = false; + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("slice_size", &slice_size_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_type", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("timeout_ms", &timeout_ms_)); +} + +void SliceRecvOp::Compute(OpKernelContext* ctx) { + OP_REQUIRES( + ctx, ctx->rendezvous() != nullptr, + errors::Internal("Op kernel context needs to provide a rendezvous.")); + + FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_); + bool is_dead; + + // recv total_bytes. + int64 total_bytes; + OP_REQUIRES_OK(ctx, RecvTotalBytes(ctx, frame_iter, is_dead, total_bytes)); + if (is_dead) { + return; + } + + // if total bytes is smaller than slice size, recv directly. + if (total_bytes <= slice_size_) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->output_alloc_attr(0); + if (ctx->is_eager()) { + // NOTE(fishx): Only set cancellation_manager in eager mode. Because in + // Tensorflow 1.x, session (or graph_mgr) will abort the underlying + // rendezvous if it encounters any error. + args.cancellation_manager = ctx->cancellation_manager(); + } + + Rendezvous::ParsedKey parsed_key; + GetSliceRendezvousKey(key_prefix_, "_transfer_data", frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceRecv " << parsed_key.buf_; + OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + Tensor data_t; + OP_REQUIRES_OK(ctx, ctx->rendezvous()->Recv(parsed_key, args, &data_t, + &is_dead, timeout_ms_)); + + // This shouldn't be a dead tensor. + CHECK_EQ(is_dead, false); + ctx->set_output(0, data_t); + return; + } + + // recv shape. + TensorShape shape; + OP_REQUIRES_OK(ctx, RecvShape(ctx, frame_iter, shape)); + + // recv data + Tensor* output_t = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output_t)); + if (dtype_ == DT_STRING) { + OP_REQUIRES_OK(ctx, RecvString(ctx, frame_iter, shape, output_t)); + } else { + OP_REQUIRES_OK(ctx, RecvBasicType(ctx, frame_iter, total_bytes, output_t)); + } +} + +Status SliceRecvOp::RecvTotalBytes(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + bool& is_dead, int64& total_bytes) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + if (ctx->is_eager()) { + // NOTE(fishx): Only set cancellation_manager in eager mode. Because in + // Tensorflow 1.x, session (or graph_mgr) will abort the underlying + // rendezvous if it encounters any error. 
+ args.cancellation_manager = ctx->cancellation_manager(); + } + + Rendezvous::ParsedKey parsed_key; + Tensor total_bytes_t; + GetSliceRendezvousKey(key_prefix_, "_slice_transfer_totalbytes", frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceRecv " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &total_bytes_t, + &is_dead, timeout_ms_)); + if (!is_dead) { + total_bytes = total_bytes_t.scalar()(); + } + + return Status::OK(); +} + +Status SliceRecvOp::RecvShape(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + TensorShape& shape) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + if (ctx->is_eager()) { + // NOTE(fishx): Only set cancellation_manager in eager mode. Because in + // Tensorflow 1.x, session (or graph_mgr) will abort the underlying + // rendezvous if it encounters any error. + args.cancellation_manager = ctx->cancellation_manager(); + } + + Rendezvous::ParsedKey parsed_key; + GetSliceRendezvousKey(key_prefix_, "_slice_transfer_shape", frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceRecv " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + + Tensor shape_t; + bool is_dead; + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &shape_t, + &is_dead, timeout_ms_)); + // This shouldn't be a dead tensor. + CHECK_EQ(is_dead, false); + auto shape_vec = shape_t.vec(); + const int64 num_elements = shape_t.NumElements(); + for (int64 i = 0; i < num_elements; i++) { + shape.AddDim(shape_vec(i)); + } + + return Status::OK(); +} + +Status SliceRecvOp::RecvString(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const TensorShape& shape, Tensor*& output_t) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + if (ctx->is_eager()) { + // NOTE(fishx): Only set cancellation_manager in eager mode. Because in + // Tensorflow 1.x, session (or graph_mgr) will abort the underlying + // rendezvous if it encounters any error. + args.cancellation_manager = ctx->cancellation_manager(); + } + Rendezvous::ParsedKey parsed_key; + bool is_dead; + + // recv elements size. + GetSliceRendezvousKey(key_prefix_, "_slice_transfer_elements_size", + frame_iter, &parsed_key.buf_); + VLOG(2) << "SliceRecv " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + Tensor elements_size_t; + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &elements_size_t, + &is_dead, timeout_ms_)); + // This shouldn't be a dead tensor. + CHECK_EQ(is_dead, false); + auto elements_size_flat = elements_size_t.flat(); + int64 num_elements = shape.num_elements(); + args.alloc_attrs = ctx->output_alloc_attr(0); + Tensor data_t; + auto output_flat = output_t->flat(); + for (int64 i = 0; i < num_elements; i++) { + if (elements_size_flat(i) <= slice_size_) { + std::string tensor_name_suffix = \ + strings::StrCat("_slice_transfer_data_", std::to_string(i)); + GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, + &parsed_key.buf_); + VLOG(2) << "SliceRecv " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, + &is_dead, timeout_ms_)); + // This shouldn't be a dead tensor. 
+      CHECK_EQ(is_dead, false);
+      output_flat(i) = data_t.scalar()();
+    } else {
+      TF_RETURN_IF_ERROR(RecvStringSlice(ctx, frame_iter, i,
+                                         elements_size_flat(i), output_flat));
+    }
+  }
+
+  return Status::OK();
+}
+
+Status SliceRecvOp::RecvStringSlice(OpKernelContext* ctx,
+                                    const FrameAndIter& frame_iter,
+                                    const int64 index, const int64 element_size,
+                                    TTypes::Flat& output_flat) {
+  Rendezvous::Args args;
+  args.device_context = ctx->op_device_context();
+  args.alloc_attrs = ctx->output_alloc_attr(0);
+  if (ctx->is_eager()) {
+    // NOTE(fishx): Only set cancellation_manager in eager mode. Because in
+    // Tensorflow 1.x, session (or graph_mgr) will abort the underlying
+    // rendezvous if it encounters any error.
+    args.cancellation_manager = ctx->cancellation_manager();
+  }
+  Rendezvous::ParsedKey parsed_key;
+
+  int64 slice_num = (element_size + slice_size_ - 1) / slice_size_;
+  Tensor data_t;
+  bool is_dead = false;
+  for (int64 i = 0; i < slice_num; i++) {
+    std::string tensor_name_suffix = \
+      strings::StrCat("_slice_transfer_data_", std::to_string(index), "_",
+                      std::to_string(i));
+    GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter,
+                          &parsed_key.buf_);
+    VLOG(2) << "SliceRecv " << parsed_key.buf_;
+    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key));
+    TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t,
+                                               &is_dead, timeout_ms_));
+    // This shouldn't be a dead tensor.
+    CHECK_EQ(is_dead, false);
+    output_flat(index) += data_t.scalar()();
+  }
+
+  return Status::OK();
+}
+
+Status SliceRecvOp::RecvBasicType(OpKernelContext* ctx,
+                                  const FrameAndIter& frame_iter,
+                                  const int64 total_bytes,
+                                  Tensor*& output_t) {
+  Rendezvous::Args args;
+  args.device_context = ctx->op_device_context();
+  args.alloc_attrs = ctx->output_alloc_attr(0);
+  if (ctx->is_eager()) {
+    // NOTE(fishx): Only set cancellation_manager in eager mode. Because in
+    // Tensorflow 1.x, session (or graph_mgr) will abort the underlying
+    // rendezvous if it encounters any error.
+    args.cancellation_manager = ctx->cancellation_manager();
+  }
+  Rendezvous::ParsedKey parsed_key;
+
+  Tensor data_t;
+  bool is_dead = false;
+  int64 slice_num = (total_bytes + slice_size_ - 1) / slice_size_;
+  unsigned char* output_base = \
+    reinterpret_cast(output_t->data());
+  for (int64 i = 0; i < slice_num; i++) {
+    int64 start = i * slice_size_;
+    int64 copy_size = slice_size_;
+    if (start > total_bytes - slice_size_) {
+      copy_size = total_bytes - start;
+    }
+    std::string tensor_name_suffix = \
+      strings::StrCat("_slice_transfer_data_", std::to_string(i));
+    GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter,
+                          &parsed_key.buf_);
+    VLOG(2) << "SliceRecv " << parsed_key.buf_;
+    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key));
+    TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t,
+                                               &is_dead, timeout_ms_));
+    // This shouldn't be a dead tensor.
+    CHECK_EQ(is_dead, false);
+    auto data_base = data_t.data();
+    std::memcpy(output_base+start, data_base, copy_size);
+  }
+
+  return Status::OK();
+
+}
+
+REGISTER_KERNEL_BUILDER(Name("_SliceRecv").Device(DEVICE_CPU), SliceRecvOp);
+REGISTER_KERNEL_BUILDER(Name("_SliceRecv").Device(DEVICE_DEFAULT), SliceRecvOp);
+
+} // End of namespace tensorflow
diff --git a/tensorflow/core/kernels/slice_sendrecv_ops.h b/tensorflow/core/kernels/slice_sendrecv_ops.h
new file mode 100644
index 00000000000..df55c080aa1
--- /dev/null
+++ b/tensorflow/core/kernels/slice_sendrecv_ops.h
@@ -0,0 +1,89 @@
+/* Copyright 2023 The DeepRec Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SLICE_SENDRECV_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SLICE_SENDRECV_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+class SliceSendOp : public OpKernel {
+ public:
+  explicit SliceSendOp(OpKernelConstruction* ctx);
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  // Variables.
+  string key_prefix_;
+  bool hostmem_sendrecv_;
+  int32 slice_size_;
+  DataType dtype_;
+
+  // Functions.
+  Status SendTotalBytes(OpKernelContext* ctx, const FrameAndIter& frame_iter,
+                        const Tensor& input_t);
+
+  Status SendShape(OpKernelContext* ctx, const FrameAndIter& frame_iter,
+                   const Tensor& input_t);
+  Status SendString(OpKernelContext* ctx, const FrameAndIter& frame_iter,
+                    const Tensor& input_t);
+
+  Status SendStringSlice(OpKernelContext* ctx, const FrameAndIter& frame_iter,
+                         const std::string& elem, int64 index);
+
+  Status SendBasicType(OpKernelContext* ctx, const FrameAndIter& frame_iter,
+                       const Tensor& input_t);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SliceSendOp);
+};
+
+class SliceRecvOp : public OpKernel {
+ public:
+  explicit SliceRecvOp(OpKernelConstruction* ctx);
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  // Variables.
+  string key_prefix_;
+  bool hostmem_sendrecv_;
+  int32 slice_size_;
+  int64 timeout_ms_;
+  DataType dtype_;
+
+  // Functions.
+ Status RecvTotalBytes(OpKernelContext* ctx, const FrameAndIter& frame_iter, + bool& is_dead, int64& total_bytes); + + Status RecvShape(OpKernelContext* ctx, const FrameAndIter& frame_iter, + TensorShape& shape); + + Status RecvString(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const TensorShape& shape, Tensor*& output_t); + + Status RecvStringSlice(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const int64 index, const int64 element_size, + TTypes::Flat& output_flat); + + Status RecvBasicType(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const int64 total_bytes, Tensor*& output_t); + + TF_DISALLOW_COPY_AND_ASSIGN(SliceRecvOp); +}; + +} // End of namespace tensorflow + +#endif // End of TENSORFLOW_CORE_KERNELS_SLICE_SENDRECV_OPS_H_ diff --git a/tensorflow/core/kernels/slice_sendrecv_ops_test.cc b/tensorflow/core/kernels/slice_sendrecv_ops_test.cc new file mode 100644 index 00000000000..5693ed57918 --- /dev/null +++ b/tensorflow/core/kernels/slice_sendrecv_ops_test.cc @@ -0,0 +1,339 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +namespace { +// Implement a trivial version of the Rendezvous interface, to avoid +// clouding the benchmark results with the time spent in the various +// implementations, and to avoid the duplicate-send or duplicate-recv +// errors that would arise from running either benchmark in a loop. +class DummyRendezvous : public Rendezvous { + // Functions. + Status Send(const ParsedKey& key, const Args& args, const Tensor& val, + const bool is_dead) override { + std::string key_str = { key.FullKey().data(), key.FullKey().size() }; + mutex_lock l(mu_); + // consumer does not reach. + if (kv_.count(key_str) == 0) { + struct Var var; + var.type = send; + var.args = args; + var.data = val; + var.is_dead = is_dead; + + kv_[key_str] = var; + return Status::OK(); + } + + auto var = kv_[key_str]; + CHECK_EQ(var.type, recv); + var.done(Status::OK(), args, var.args, val, is_dead); + kv_.erase(key_str); + return Status::OK(); + } + void RecvAsync(const ParsedKey& key, const Args& args, + DoneCallback done) override { + std::string key_str = { key.FullKey().data(), key.FullKey().size() }; + + mutex_lock l(mu_); + // producer does not reach. 
+ if (kv_.count(key_str) == 0) { + struct Var var; + var.type = recv; + var.args = args; + var.done = done; + + kv_[key_str] = var; + return; + } + + // auto var = kv_[key_str]; + auto var = kv_[key_str]; + CHECK_EQ(var.type, send); + done(Status::OK(), var.args, args, var.data, var.is_dead); + kv_.erase(key_str); + } + void StartAbort(const Status& status) override {} + + private: + enum RendezvousType { + send, + recv + }; + // Type define. + struct Var { + RendezvousType type; + Args args; + Tensor data; + bool is_dead; + DoneCallback done; + }; + + // Variables. + mutex mu_; + std::unordered_map kv_ GUARDED_BY(mu_); +}; + +Node* SliceSend(Graph* g, Node* input, const string& tensor, + const string& sender, const uint64 sender_incarnation, + const string& receiver, const int32 slice_size) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_SliceSend") + .Input(input, 0) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast(sender_incarnation)) + .Attr("recv_device", receiver) + .Attr("slice_size", slice_size) + .Finalize(g, &ret)); + return ret; +} + +Node* SliceRecv(Graph* g, const string& tensor, const string& type, + const string& sender, const uint64 sender_incarnation, + const string& receiver, const int32 slice_size, + const int64 timeout_ms) { + Node* ret; + DataType dtype; + CHECK(DataTypeFromString(type, &dtype)); + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_SliceRecv") + .Attr("tensor_type", dtype) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast(sender_incarnation)) + .Attr("recv_device", receiver) + .Attr("slice_size", slice_size) + .Attr("timeout_ms", timeout_ms) + .Finalize(g, &ret)); + return ret; +} + +Node* Equal(Graph* g, Node* x, Node* y) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Equal") + .Input(x) + .Input(y) + .Finalize(g, &ret)); + return ret; +} + +Node* ReduceAll(Graph* g, Node* input, Node* axes) { + return test::graph::Reduce(g, "All", input, axes); +} + +Node* Assert(Graph* g, Node* condition, + std::vector& data) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assert") + .Input(condition) + .Input(data) + .Finalize(g, &ret)); + return ret; +} + +static Graph* TransferStringTensor() { + Graph* g = new Graph(OpRegistry::Global()); + const int32 slice_size = 1024; + const int64 timeout_ms = 5000; + std::string str = "The quick brown fox jumps over the lazy dog."; // 44 chars. + + Tensor input_t(DT_STRING, TensorShape({2, 4})); + input_t.flat().setConstant(str); // total bytes: 44*8=352 bytes. + Node* input_n = test::graph::Constant(g, input_t); + SliceSend(g, input_n, "T", "/cpu:0", 1, "/cpu:0", slice_size); + Node* recv_n = \ + SliceRecv(g, "T", "string", "/cpu:0", 1, "/cpu:0", slice_size, timeout_ms); + + Node* equal_n = Equal(g, input_n, recv_n); + + Tensor axes_t(DT_INT32, TensorShape({input_t.dims()})); + auto axes_flat = axes_t.flat(); + for (int i = 0; i < input_t.dims(); i++) { + axes_flat(i) = i; + } + Node* reduce_all_n = ReduceAll(g, equal_n, test::graph::Constant(g, axes_t)); + + std::vector data_out; + data_out.emplace_back(input_n, 0); + data_out.emplace_back(recv_n, 0); + Assert(g, reduce_all_n, data_out); + + return g; +} + +static Graph* TransferBasicTypeTensor() { + Graph* g = new Graph(OpRegistry::Global()); + const int32 slice_size = 1024; + const int64 timeout_ms = 5000; + + Tensor input_t(DT_FLOAT, TensorShape({2, 8})); + input_t.flat().setConstant(2); // total bytes = 4*2*8=64 bytes. 
+ Node* input_n = test::graph::Constant(g, input_t); + SliceSend(g, input_n, "T", "/cpu:0", 1, "/cpu:0", slice_size); + Node* recv_n = \ + SliceRecv(g, "T", "float32", "/cpu:0", 1, "/cpu:0", slice_size, timeout_ms); + + Node* equal_n = Equal(g, input_n, recv_n); + + Tensor axes_t(DT_INT32, TensorShape({input_t.dims()})); + auto axes_flat = axes_t.flat(); + for (int i = 0; i < input_t.dims(); i++) { + axes_flat(i) = i; + } + Node* reduce_all_n = ReduceAll(g, equal_n, test::graph::Constant(g, axes_t)); + + std::vector data_out; + data_out.emplace_back(input_n, 0); + data_out.emplace_back(recv_n, 0); + Assert(g, reduce_all_n, data_out); + + return g; +} + +static Graph* TransferBigStringTensor() { + Graph* g = new Graph(OpRegistry::Global()); + const int32 slice_size = 16; + const int64 timeout_ms = 5000; + std::string str = "The quick brown fox jumps over the lazy dog."; // 44 chars. + + Tensor input_t(DT_STRING, TensorShape({2, 4})); + input_t.flat().setConstant(str); + input_t.flat()(0) = "short str"; + Node* input_n = \ + test::graph::Constant(g, input_t); // total bytes: 44*7+9=317 bytes. + SliceSend(g, input_n, "T", "/cpu:0", 1, "/cpu:0", slice_size); + Node* recv_n = \ + SliceRecv(g, "T", "string", "/cpu:0", 1, "/cpu:0", slice_size, timeout_ms); + + Node* equal_n = Equal(g, input_n, recv_n); + + Tensor axes_t(DT_INT32, TensorShape({input_t.dims()})); + auto axes_flat = axes_t.flat(); + for (int i = 0; i < input_t.dims(); i++) { + axes_flat(i) = i; + } + Node* reduce_all_n = ReduceAll(g, equal_n, test::graph::Constant(g, axes_t)); + + std::vector data_out; + data_out.emplace_back(input_n, 0); + data_out.emplace_back(recv_n, 0); + Assert(g, reduce_all_n, data_out); + + return g; +} + +static Graph* TransferBigBasicTypeTensor() { + Graph* g = new Graph(OpRegistry::Global()); + const int32 slice_size = 16; + const int64 timeout_ms = 5000; + + Tensor input_t(DT_FLOAT, TensorShape({2, 8})); + input_t.flat().setConstant(2); // total bytes: 4*2*8=64 + Node* input_n = test::graph::Constant(g, input_t); + SliceSend(g, input_n, "T", "/cpu:0", 1, "/cpu:0", slice_size); + Node* recv_n = \ + SliceRecv(g, "T", "float32", "/cpu:0", 1, "/cpu:0", slice_size, timeout_ms); + + Node* equal_n = Equal(g, input_n, recv_n); + + Tensor axes_t(DT_INT32, TensorShape({input_t.dims()})); + auto axes_flat = axes_t.flat(); + for (int i = 0; i < input_t.dims(); i++) { + axes_flat(i) = i; + } + Node* reduce_all_n = ReduceAll(g, equal_n, test::graph::Constant(g, axes_t)); + + std::vector data_out; + data_out.emplace_back(input_n, 0); + data_out.emplace_back(recv_n, 0); + Assert(g, reduce_all_n, data_out); + + return g; +} + +static Graph* TransferDeadTensor() { + Graph* g = new Graph(OpRegistry::Global()); + const int32 slice_size = 1024; + const int64 timeout_ms = 5000; + + // val + Tensor val_t(DT_FLOAT, TensorShape({})); + val_t.scalar()() = 2; + Node* val_n = test::graph::Constant(g, val_t); + + Tensor pred_t(DT_BOOL, TensorShape({})); + pred_t.scalar()() = true; + Node* pred_n = test::graph::Constant(g, pred_t); + + Node* switch_n = test::graph::Switch(g, val_n, pred_n); + SliceSend(g, switch_n, "T", "/cpu:0", 1, "/cpu:0", slice_size); + SliceRecv(g, "T", "float32", "/cpu:0", 1, "/cpu:0", slice_size, timeout_ms); + + return g; +} + +static void BM_TransferStringTensor(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", TransferStringTensor(), nullptr, nullptr, + new DummyRendezvous).Run(iters); +} + +static void BM_TransferBasicTypeTensor(int iters) { + 
  testing::UseRealTime();
+  testing::ItemsProcessed(static_cast<int64>(iters));
+  test::Benchmark("cpu", TransferBasicTypeTensor(), nullptr, nullptr,
+                  new DummyRendezvous).Run(iters);
+}
+
+static void BM_TransferBigStringTensor(int iters) {
+  testing::UseRealTime();
+  testing::ItemsProcessed(static_cast<int64>(iters));
+  test::Benchmark("cpu", TransferBigStringTensor(), nullptr, nullptr,
+                  new DummyRendezvous).Run(iters);
+}
+
+static void BM_TransferBigBasicTypeTensor(int iters) {
+  testing::UseRealTime();
+  testing::ItemsProcessed(static_cast<int64>(iters));
+  test::Benchmark("cpu", TransferBigBasicTypeTensor(), nullptr, nullptr,
+                  new DummyRendezvous).Run(iters);
+}
+
+static void BM_TransferDeadTensor(int iters) {
+  testing::UseRealTime();
+  testing::ItemsProcessed(static_cast<int64>(iters));
+  test::Benchmark("cpu", TransferDeadTensor(), nullptr, nullptr,
+                  new DummyRendezvous).Run(iters);
+}
+
+BENCHMARK(BM_TransferStringTensor);
+BENCHMARK(BM_TransferBasicTypeTensor);
+BENCHMARK(BM_TransferBigStringTensor);
+BENCHMARK(BM_TransferBigBasicTypeTensor);
+BENCHMARK(BM_TransferDeadTensor);
+
+} // End of anonymous namespace
+
+} // End of namespace tensorflow
diff --git a/tensorflow/core/ops/slice_sendrecv_ops.cc b/tensorflow/core/ops/slice_sendrecv_ops.cc
new file mode 100644
index 00000000000..11905712410
--- /dev/null
+++ b/tensorflow/core/ops/slice_sendrecv_ops.cc
@@ -0,0 +1,78 @@
+/* Copyright 2023 The DeepRec Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+namespace tensorflow {
+
+REGISTER_OP("_SliceSend")
+    .Input("tensor: T")
+    .Attr("T: type")
+    .Attr("tensor_name: string")
+    .Attr("send_device: string")
+    .Attr("send_device_incarnation: int")
+    .Attr("recv_device: string")
+    .Attr("client_terminated: bool = false")
+    .Attr("slice_size: int >= 1")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnknownShape)
+    .Doc(R"doc(
+Sends the named tensor from send_device to recv_device.
+Supports sending a tensor of any size.
+
+tensor: The tensor to send.
+tensor_name: The name of the tensor to send.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+  to the graph as a result of a client-side feed or fetch of Tensor data,
+  in which case the corresponding send or recv is expected to be managed
+  locally by the caller.
+slice_size: The maximum number of bytes transferred at one time.
+)doc"); + +REGISTER_OP("_SliceRecv") + .Output("tensor: tensor_type") + .Attr("tensor_type: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .Attr("send_device_incarnation: int") + .Attr("recv_device: string") + .Attr("client_terminated: bool = false") + .Attr("slice_size: int >= 1") + .Attr("timeout_ms: int >= 0 = 300000") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Receives the named tensor from send_device on recv_device. +Supports recving the tensor of any size. + +tensor: The tensor to receive. +tensor_name: The name of the tensor to receive. +send_device: The name of the device sending the tensor. +send_device_incarnation: The current incarnation of send_device. +recv_device: The name of the device receiving the tensor. +client_terminated: If set to true, this indicates that the node was added + to the graph as a result of a client-side feed or fetch of Tensor data, + in which case the corresponding send or recv is expected to be managed + locally by the caller. +slice_size: The maximum number of bytes transferred at one time. +timeout_ms: The maximum wait time for receiving a tensor. +)doc"); + +} // End of namespace tensorflow