diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index ec2375b6a8e2e..2b19f0448955d 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,8 +1,8 @@ if(WITH_DISTRIBUTE) grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc - grpc_server.cc tensor_parser.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - cc_test(serde_test SRCS test_serde.cc tensor_parser.cc DEPS grpc++_unsecure grpc_unsecure gpr + cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc) endif() diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index b88883d7296bb..9691d1e86b111 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -57,7 +57,7 @@ class RequestSend final : public RequestBase { framework::Scope* scope, ReceivedQueue* queue, const platform::DeviceContext* dev_ctx) : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { - request_.reset(new TensorResponse(scope, dev_ctx_)); + request_.reset(new VariableResponse(scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); @@ -76,7 +76,7 @@ class RequestSend final : public RequestBase { } protected: - std::shared_ptr request_; + std::shared_ptr request_; ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; }; diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index f79ea26752a8c..9c21a07432031 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -35,7 +35,8 @@ namespace paddle { namespace operators { namespace detail { -typedef std::pair> ReceivedMessage; +typedef std::pair> + ReceivedMessage; typedef SimpleBlockQueue ReceivedQueue; typedef std::pair MessageWithName; diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h index 27bafd4183f31..ae6f9db3bd31a 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -23,7 +23,7 @@ #include #include #include -#include "paddle/fluid/operators/detail/tensor_parser.h" +#include "paddle/fluid/operators/detail/variable_response.h" // NOTE: This method was originally created by tensorflow // (https://github.com/tensorflow/tensorflow/) we borrow this @@ -37,18 +37,19 @@ class RpcService; class ServerCompletionQueue; class ServerContext; -// Support parsing/unparsing of tensorflow::TensorResponse. -// Wire-format is identical to RecvTensorResponse. +// Support parsing/unparsing of tensorflow::VariableResponse. +// Wire-format is identical to RecvVariableResponse. template <> -class SerializationTraits { +class SerializationTraits { public: - static Status Serialize(const paddle::operators::detail::TensorResponse& msg, - grpc_byte_buffer** bp, bool* own_buffer) { + static Status Serialize( + const paddle::operators::detail::VariableResponse& msg, + grpc_byte_buffer** bp, bool* own_buffer) { PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); return Status(); } static Status Deserialize(grpc_byte_buffer* buffer, - paddle::operators::detail::TensorResponse* msg, + paddle::operators::detail::VariableResponse* msg, int max_message_size = INT_MAX) { if (buffer == nullptr) { return Status(StatusCode::INTERNAL, "No payload"); @@ -59,7 +60,7 @@ class SerializationTraits { paddle::operators::detail::GrpcByteSource source(buffer); int ret = msg->Parse(&source); if (ret != 0) { - result = Status(StatusCode::INTERNAL, "TensorResponse parse error"); + result = Status(StatusCode::INTERNAL, "VariableResponse parse error"); } } g_core_codegen_interface->grpc_byte_buffer_destroy(buffer); diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 52de5d1465f53..598aaa4c51a6c 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -34,7 +34,7 @@ enum VarType { // NOTICE(gongwb):don't modify this proto if you are not // not familar with how we serialize in sendrecvop_utils.h -// and deserilize it in tensor_parser.h. +// and deserilize it in variable_response.h. message VariableMessage { enum Type { // Pod Types diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 35c9cc1d27934..d7bbf79c50651 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/detail/proto_encoder_helper.h" -#include "paddle/fluid/operators/detail/tensor_parser.h" +#include "paddle/fluid/operators/detail/variable_response.h" namespace paddle { namespace operators { @@ -181,7 +181,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, const framework::Scope* scope, framework::Variable*& var) { - operators::detail::TensorResponse resp(scope, &ctx); + operators::detail::VariableResponse resp(scope, &ctx); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); var = resp.GetVar(); } diff --git a/paddle/fluid/operators/detail/test_serde.cc b/paddle/fluid/operators/detail/test_serde.cc index ce3ada14fa038..4be5963794717 100644 --- a/paddle/fluid/operators/detail/test_serde.cc +++ b/paddle/fluid/operators/detail/test_serde.cc @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" -#include "paddle/fluid/operators/detail/tensor_parser.h" +#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/printf.h" diff --git a/paddle/fluid/operators/detail/tensor_parser.cc b/paddle/fluid/operators/detail/variable_response.cc similarity index 96% rename from paddle/fluid/operators/detail/tensor_parser.cc rename to paddle/fluid/operators/detail/variable_response.cc index 65655cc46f7bd..12e8eb0b4da22 100644 --- a/paddle/fluid/operators/detail/tensor_parser.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "tensor_parser.h" +#include "paddle/fluid/operators/detail/variable_response.h" #include #include "paddle/fluid/operators/detail/send_recv.pb.h" -#include "sendrecvop_utils.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" namespace paddle { namespace operators { @@ -94,7 +94,7 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, return true; } -bool TensorResponse::CopyLodTensorData( +bool VariableResponse::CopyLodTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, framework::DDim& dims, int length) { auto var = scope_->FindVar(meta_.varname()); @@ -130,7 +130,7 @@ inline framework::DDim GetDims( return framework::make_ddim(vecdims); } -bool TensorResponse::CopySelectRowsTensorData( +bool VariableResponse::CopySelectRowsTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, framework::DDim& dims, int length) { auto var = scope_->FindVar(meta_.varname()); @@ -148,7 +148,7 @@ bool TensorResponse::CopySelectRowsTensorData( return true; } -bool TensorResponse::CopySelectRowsData( +bool VariableResponse::CopySelectRowsData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, int length) { auto var = scope_->FindVar(meta_.varname()); @@ -211,7 +211,7 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input, return true; } -int TensorResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) { +int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) { GrpcByteBufferSource source; source.Init(byte_buffer); GrpcByteBufferSourceWrapper r(&source); @@ -219,7 +219,7 @@ int TensorResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) { return Parse(&r); } -int TensorResponse::Parse(Source* source) { +int VariableResponse::Parse(Source* source) { ::google::protobuf::io::ZeroCopyInputStream* input_stream = source->contents(); ::google::protobuf::io::CodedInputStream input(input_stream); diff --git a/paddle/fluid/operators/detail/tensor_parser.h b/paddle/fluid/operators/detail/variable_response.h similarity index 93% rename from paddle/fluid/operators/detail/tensor_parser.h rename to paddle/fluid/operators/detail/variable_response.h index 7ae3a3edb0470..c7bc7a46e7bc8 100644 --- a/paddle/fluid/operators/detail/tensor_parser.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -32,13 +32,13 @@ namespace paddle { namespace operators { namespace detail { -class TensorResponse { +class VariableResponse { public: - TensorResponse(const framework::Scope* scope, - const platform::DeviceContext* dev_ctx) + VariableResponse(const framework::Scope* scope, + const platform::DeviceContext* dev_ctx) : scope_(scope), dev_ctx_(dev_ctx){}; - virtual ~TensorResponse(){}; + virtual ~VariableResponse(){}; // return: // 0:ok.