Skip to content

Commit

Permalink
rename tensorparser
Browse files Browse the repository at this point in the history
  • Loading branch information
gongweibao committed Mar 21, 2018
1 parent 0d36059 commit 011c909
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 28 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/operators/detail/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc tensor_parser.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS test_serde.cc tensor_parser.cc DEPS grpc++_unsecure grpc_unsecure gpr
cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc)
endif()
4 changes: 2 additions & 2 deletions paddle/fluid/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class RequestSend final : public RequestBase {
framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
request_.reset(new TensorResponse(scope, dev_ctx_));
request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
Expand All @@ -76,7 +76,7 @@ class RequestSend final : public RequestBase {
}

protected:
std::shared_ptr<TensorResponse> request_;
std::shared_ptr<VariableResponse> request_;
ReceivedQueue* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ namespace paddle {
namespace operators {
namespace detail {

typedef std::pair<std::string, std::shared_ptr<TensorResponse>> ReceivedMessage;
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef SimpleBlockQueue<ReceivedMessage> ReceivedQueue;

typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
Expand Down
17 changes: 9 additions & 8 deletions paddle/fluid/operators/detail/grpc_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/tensor_parser.h"
#include "paddle/fluid/operators/detail/variable_response.h"

// NOTE: This method was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
Expand All @@ -37,18 +37,19 @@ class RpcService;
class ServerCompletionQueue;
class ServerContext;

// Support parsing/unparsing of tensorflow::TensorResponse.
// Wire-format is identical to RecvTensorResponse.
// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template <>
class SerializationTraits<paddle::operators::detail::TensorResponse> {
class SerializationTraits<paddle::operators::detail::VariableResponse> {
public:
static Status Serialize(const paddle::operators::detail::TensorResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) {
static Status Serialize(
const paddle::operators::detail::VariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status();
}
static Status Deserialize(grpc_byte_buffer* buffer,
paddle::operators::detail::TensorResponse* msg,
paddle::operators::detail::VariableResponse* msg,
int max_message_size = INT_MAX) {
if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload");
Expand All @@ -59,7 +60,7 @@ class SerializationTraits<paddle::operators::detail::TensorResponse> {
paddle::operators::detail::GrpcByteSource source(buffer);
int ret = msg->Parse(&source);
if (ret != 0) {
result = Status(StatusCode::INTERNAL, "TensorResponse parse error");
result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
}
}
g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/send_recv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ enum VarType {

// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in tensor_parser.h.
// and deserilize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/detail/sendrecvop_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/tensor_parser.h"
#include "paddle/fluid/operators/detail/variable_response.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -181,7 +181,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable*& var) {
operators::detail::TensorResponse resp(scope, &ctx);
operators::detail::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
var = resp.GetVar();
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/test_serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/tensor_parser.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensor_parser.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include <string.h>
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -94,7 +94,7 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
return true;
}

bool TensorResponse::CopyLodTensorData(
bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
Expand Down Expand Up @@ -130,7 +130,7 @@ inline framework::DDim GetDims(
return framework::make_ddim(vecdims);
}

bool TensorResponse::CopySelectRowsTensorData(
bool VariableResponse::CopySelectRowsTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
Expand All @@ -148,7 +148,7 @@ bool TensorResponse::CopySelectRowsTensorData(
return true;
}

bool TensorResponse::CopySelectRowsData(
bool VariableResponse::CopySelectRowsData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
Expand Down Expand Up @@ -211,15 +211,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
return true;
}

int TensorResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
GrpcByteBufferSource source;
source.Init(byte_buffer);
GrpcByteBufferSourceWrapper r(&source);

return Parse(&r);
}

int TensorResponse::Parse(Source* source) {
int VariableResponse::Parse(Source* source) {
::google::protobuf::io::ZeroCopyInputStream* input_stream =
source->contents();
::google::protobuf::io::CodedInputStream input(input_stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ namespace paddle {
namespace operators {
namespace detail {

class TensorResponse {
class VariableResponse {
public:
TensorResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
VariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
: scope_(scope), dev_ctx_(dev_ctx){};

virtual ~TensorResponse(){};
virtual ~VariableResponse(){};

// return:
// 0:ok.
Expand Down

0 comments on commit 011c909

Please sign in to comment.