Skip to content

Commit

Permalink
Add brpc serialization support. (#11430)
Browse files Browse the repository at this point in the history
  • Loading branch information
gongweibao committed Dec 14, 2018
1 parent 37c2e24 commit 0b1c7d8
Show file tree
Hide file tree
Showing 28 changed files with 1,422 additions and 153 deletions.
4 changes: 3 additions & 1 deletion benchmark/fluid/fluid_benchmark.py
Expand Up @@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")

config = distribute_transpiler.DistributeTranspilerConfig()
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = not args.no_split_var
config.min_block_size = 1048576
t = distribute_transpiler.DistributeTranspiler(config=config)

t.transpile(
trainer_id,
# NOTE: *MUST* use train_prog, for we are using with guard to
Expand Down
20 changes: 12 additions & 8 deletions cmake/external/brpc.cmake
Expand Up @@ -14,14 +14,16 @@

INCLUDE(ExternalProject)

find_library(SSL_LIBRARY NAMES ssl)
find_package(OpenSSL REQUIRED)

message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})

ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${SSL_LIBRARY})
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY})

find_library(CRYPTO_LIBRARY NAMES crypto)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${CRYPTO_LIBRARY})

SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})

SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
Expand All @@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})

# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib")
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")

# If a minimal .a is needed, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(gongwb): change to the newest repo when they have changed.
GIT_REPOSITORY "https://github.com/gongweibao/brpc"
GIT_TAG "7dc04defad1fd4173aae170c3fcbde131b65155a"
GIT_TAG "e9b67ec1b7458f2af5fae76451afe1e27e01b4b4"
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
Expand All @@ -50,7 +53,7 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_PREFIX_PATH=${prefix_path}
-DBRPC_WITH_GLOG=ON
-DWITH_GLOG=ON
-DIOBUF_WITH_HUGE_BLOCK=ON
-DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
${EXTERNAL_OPTIONAL_ARGS}
Expand All @@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc)

add_definitions(-DBRPC_WITH_GLOG)

LIST(APPEND external_project_dependencies brpc)
10 changes: 7 additions & 3 deletions cmake/external/gtest.cmake
Expand Up @@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

IF(WITH_TESTING)
ENABLE_TESTING()
#FIXME:(gongwb) Move brpc's gtest dependency.
IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
IF(WITH_TESTING)
ENABLE_TESTING()
ENDIF(WITH_TESTING)

INCLUDE(ExternalProject)

SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest)
Expand Down Expand Up @@ -76,4 +80,4 @@ IF(WITH_TESTING)
ADD_DEPENDENCIES(gtest_main extern_gtest)

LIST(APPEND external_project_dependencies gtest gtest_main)
ENDIF(WITH_TESTING)
ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
4 changes: 2 additions & 2 deletions cmake/external/leveldb.cmake
Expand Up @@ -24,8 +24,8 @@ ExternalProject_Add(
extern_leveldb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LEVELDB_SOURCES_DIR}
URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
GIT_REPOSITORY "https://github.com/google/leveldb"
GIT_TAG v1.18
CONFIGURE_COMMAND ""
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/framework/CMakeLists.txt
Expand Up @@ -169,9 +169,12 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)

if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper)

set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

else()
if(WITH_NGRAPH)
if(NOT WIN32)
Expand Down
11 changes: 9 additions & 2 deletions paddle/fluid/framework/details/CMakeLists.txt
Expand Up @@ -12,12 +12,19 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc

cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

# Distributed builds that do not use gRPC: compile reduce_op_handle.cc with
# the non-virtual-destructor warnings (and their -Werror promotion) relaxed.
if(WITH_DISTRIBUTE AND NOT WITH_GRPC)
  set(DISTRIBUTE_COMPILE_FLAGS
      "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
  set_source_files_properties(reduce_op_handle.cc
                              PROPERTIES COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS}")
endif()

if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
if(WITH_DISTRIBUTE)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
else()
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor)
Expand All @@ -30,7 +37,7 @@ else()
variable_visitor)
if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_grpc)
ddim selected_rows_functor sendrecvop_rpc)
else()
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor)
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/executor.cc
Expand Up @@ -157,9 +157,9 @@ void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>(0)
->SendComplete();
auto client =
paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
client->SendComplete();
#endif
}

Expand Down
31 changes: 19 additions & 12 deletions paddle/fluid/operators/distributed/CMakeLists.txt
Expand Up @@ -12,44 +12,51 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")

if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library(sendrecvop_rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory)

set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_rpc scope profiler math_function SERIAL)

cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)

cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)

if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
selected_rows_functor scope math_function SERIAL)
endif()

cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
else()
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_server.cc parameter_prefetch.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc collective_server.cc collective_server_test.cc
collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
brpc_library(sendrecvop_rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)

cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)

set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
set(brpc_test_depends sendrecvop_rpc brpc ssl crypto protobuf leveldb gflags glog executor
proto_desc lookup_sparse_table_op snappystream snappy zlib)

cc_test(brpc_server_test SRCS rpc_server_test.cc
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${brpc_test_depends} SERIAL)

cc_test(brpc_serde_test SRCS brpc_serde_test.cc
DEPS ${brpc_test_depends} SERIAL)

if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS ${brpc_test_depends} selected_rows_functor scope math_function SERIAL)
endif()
endif()

0 comments on commit 0b1c7d8

Please sign in to comment.