
deal with conflict
YuanRisheng committed Apr 7, 2022
2 parents dfc61f7 + c31386e commit 420d9f1
Showing 94 changed files with 3,292 additions and 943 deletions.
7 changes: 0 additions & 7 deletions cmake/inference_lib.cmake
@@ -199,13 +199,6 @@ IF(WITH_XPU)
DSTS ${dst_dir} ${dst_dir})
ENDIF()

IF(WITH_IPU)
set(dst_dir "${PADDLE_INFERENCE_INSTALL_DIR}/third_party/install/ipu")
copy(inference_lib_dist
SRCS ${CMAKE_BINARY_DIR}/paddle/fluid/platform/device/ipu/libpaddle_ipu.so
DSTS ${dst_dir})
ENDIF()

# CMakeCache Info
copy(inference_lib_dist
SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
@@ -194,6 +194,16 @@ class {} : public egr::GradNodeBase {{
// Get Input AutoGradMeta
{}
// Set Device Id
auto place = egr::Controller::Instance().GetExpectedPlace();
if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::backends::gpu::SetDeviceId(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
}}
// Forward API Call
{}
// Get Outputs
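
Before the forward call, the generated code now pins the expected GPU so the kernel launches on the device recorded in the eager controller. A minimal standalone sketch of the same guard, with is_gpu_place and set_device_id as hypothetical stand-ins for paddle::platform::is_gpu_place and phi::backends::gpu::SetDeviceId:

#include <stdexcept>

// Hypothetical stand-ins for paddle::platform::is_gpu_place and
// phi::backends::gpu::SetDeviceId; -1 models a CPU place here.
bool is_gpu_place(int device_id) { return device_id >= 0; }
void set_device_id(int device_id) { /* cudaSetDevice(device_id) in a GPU build */ }

void forward_with_device_guard(int expected_device) {
  if (is_gpu_place(expected_device)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    set_device_id(expected_device);  // ensure the launch targets the expected GPU
#else
    throw std::runtime_error(
        "PaddlePaddle should be compiled with GPU support when CUDAPlace is used.");
#endif
  }
  // ... the Forward API call then runs on the selected device ...
}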
@@ -284,6 +294,7 @@ class {} : public egr::GradNodeBase {{
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
{}
{}
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -150,7 +150,7 @@ if(WITH_IPU)
pass_library(ipu_runtime_replacer_pass base DIR ipu)
pass_library(inference_process_pass base DIR ipu)
pass_library(inference_postprocess_pass base DIR ipu)
pass_library(popart_canonicalization_pass base DIR ipu DEPS paddle_ipu)
pass_library(popart_canonicalization_pass base DIR ipu)
pass_library(ipu_inplace_pass base DIR ipu)
pass_library(infer_shape_pass base DIR ipu)
pass_library(delete_scale_op_pass base DIR ipu)
2 changes: 0 additions & 2 deletions paddle/fluid/inference/CMakeLists.txt
@@ -53,8 +53,6 @@ endif()
#TODO(wilber, T8T9): Do we still need to support windows gpu static library?
if(WIN32 AND WITH_GPU)
cc_library(paddle_inference DEPS ${fluid_modules} ${phi_modules} ${STATIC_INFERENCE_API} ${utils_modules})
elseif(WITH_IPU)
cc_library(paddle_inference DEPS ${fluid_modules} ${phi_modules} ${STATIC_INFERENCE_API} ${utils_modules} paddle_ipu)
else()
create_static_lib(paddle_inference ${fluid_modules} ${phi_modules} ${STATIC_INFERENCE_API} ${utils_modules})
endif()
4 changes: 4 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -74,6 +74,10 @@
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif

#ifdef PADDLE_WITH_IPU
#include "paddle/fluid/platform/device/ipu/paddle_ipu_handler.h"
#endif

namespace paddle {

using inference::Singleton;
8 changes: 8 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -1007,6 +1007,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto* y_var_desc = block->FindVar(desc.Input("Y")[0]);
const auto x_shape = x_var_desc->GetShape();
const auto y_shape = y_var_desc->GetShape();
if (op_type == "elementwise_add" && y_var_desc->Persistable()) {
if (y_shape.size() != 1) {
return false;
}
if (y_shape[0] != x_shape[1]) {
return false;
}
}
if (x_shape.size() == 1 && y_shape.size() == 1) {
VLOG(3) << "Now trt may not support two 1d tensor elementwise op.";
return false;
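
The new branch only lets TensorRT take an elementwise_add whose persistable (weight) Y input is a per-channel bias: a 1-D tensor whose length equals X's second dimension. A standalone sketch of the rule as a predicate (the name and signature are illustrative, not op_teller's actual interface):

#include <cstdint>
#include <vector>

// Illustrative predicate mirroring the check above: reject a persistable Y
// unless it is 1-D and matches x_shape[1] (a per-channel bias), and reject
// the two-1-D-tensor case TRT may not handle.
bool elementwise_add_supported(const std::vector<int64_t>& x_shape,
                               const std::vector<int64_t>& y_shape,
                               bool y_persistable) {
  if (y_persistable) {
    if (y_shape.size() != 1) return false;
    if (y_shape[0] != x_shape[1]) return false;
  }
  if (x_shape.size() == 1 && y_shape.size() == 1) return false;
  return true;
}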
21 changes: 17 additions & 4 deletions paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu
@@ -30,6 +30,11 @@ template <typename T>
struct Mul {
__device__ T operator()(const T &a, const T &b) const { return a * b; }
};

template <typename T>
struct Div {
__device__ T operator()(const T &a, const T &b) const { return a / b; }
};
} // namespace details

template <typename T, typename Operator>
@@ -130,6 +135,10 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs,
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
details::Mul<float>());
} else if (type_ == "div") {
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
details::Div<float>());
} else {
PADDLE_THROW(platform::errors::Fatal(
"The %s type elementwise is not implemented in trt plugin.", type_));
@@ -242,11 +251,15 @@ int ElementwisePluginDynamic::enqueue(
} else if (type_ == "mul") {
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size, midd_size, post_size, details::Mul<float>());
} else if (type_ == "div") {
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size, midd_size, post_size, details::Div<float>());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Paddle-TRT only support elementwise operation: {add, mul} currently, "
"but got %s.",
type_));
PADDLE_THROW(
platform::errors::Unimplemented("Paddle-TRT only supports elementwise "
"operations {add, mul, div} currently, "
"but got %s.",
type_));
}

return cudaGetLastError() != cudaSuccess;
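
Both the static and dynamic enqueue paths dispatch a single templated kernel specialized by a functor, which is why division only needs the new details::Div plus one dispatch branch. A simplified, self-contained sketch of that pattern (the real kernel additionally handles broadcasting via the prev/midd/post sizes):

namespace sketch {

template <typename T>
struct Div {
  __device__ T operator()(const T &a, const T &b) const { return a / b; }
};

// Simplified elementwise kernel: out[i] = op(x[i], y[i]), no broadcasting.
template <typename T, typename Operator>
__global__ void elementwise_kernel(int n, const T *x, const T *y, T *out,
                                   Operator op) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = op(x[i], y[i]);
}

}  // namespace sketch

// Dispatch, e.g.:
//   sketch::elementwise_kernel<<<blocks, threads, 0, stream>>>(
//       n, x, y, out, sketch::Div<float>());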
4 changes: 3 additions & 1 deletion paddle/fluid/operators/inplace_abn_op.cc
@@ -324,10 +324,12 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {

namespace ops = paddle::operators;

DECLARE_INPLACE_OP_INFERER(InplaceAbnOpInplaceInferer, {"X", "Y"});
REGISTER_OPERATOR(inplace_abn, ops::InplaceABNOp, ops::InplaceABNOpMaker,
ops::BatchNormOpInferVarType,
ops::InplaceABNOpGradMaker<paddle::framework::OpDesc>,
ops::InplaceABNOpGradMaker<paddle::imperative::OpBase>)
ops::InplaceABNOpGradMaker<paddle::imperative::OpBase>,
InplaceAbnOpInplaceInferer)
REGISTER_OPERATOR(inplace_abn_grad, ops::InplaceABNGradOp)

REGISTER_OP_CPU_KERNEL(
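
DECLARE_INPLACE_OP_INFERER(InplaceAbnOpInplaceInferer, {"X", "Y"}) declares that output Y may reuse the storage of input X, which is what makes the op actually in-place once the inferer is passed to REGISTER_OPERATOR. A toy model of what such an inferer amounts to (plain C++, not Paddle's macro machinery):

#include <map>
#include <string>

// Toy model: an inplace inferer is a declared input -> output reuse map;
// the framework may then write Y into X's buffer instead of allocating.
std::map<std::string, std::string> InplaceAbnReusePairs() {
  return {{"X", "Y"}};
}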
31 changes: 22 additions & 9 deletions paddle/fluid/platform/device/ipu/CMakeLists.txt
@@ -1,6 +1,22 @@
IF(WITH_IPU)
FILE(GLOB POPART_CANONICALIZATION_SRC ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/device/ipu/popart_canonicalization/*.cc)
list(APPEND PADDLE_IPU_SRC ${POPART_CANONICALIZATION_SRC})
if(WITH_IPU)
set(paddle_ipu_handler ${CMAKE_CURRENT_BINARY_DIR}/paddle_ipu_handler.h.tmp)
set(paddle_ipu_handler_final ${CMAKE_CURRENT_BINARY_DIR}/paddle_ipu_handler.h)
file(WRITE ${paddle_ipu_handler} "// Auto generated from CMake. DO NOT EDIT!\n\n")
file(APPEND ${paddle_ipu_handler} "\#pragma once\n")
file(APPEND ${paddle_ipu_handler} "\#include \"paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h\"\n\n")
file(GLOB POPART_CANONICALIZATION_SRC ${CMAKE_CURRENT_SOURCE_DIR}/popart_canonicalization/*.cc)

foreach(file_path ${POPART_CANONICALIZATION_SRC})
file(READ ${file_path} file_content)
string(REGEX MATCHALL "(REGISTER_HANDLER)(\\()([A-Za-z0-9_]+)(,)" op_handlers ${file_content})
string(REPLACE "REGISTER_HANDLER(" "" op_handlers "${op_handlers}")
string(REPLACE "," "" op_handlers "${op_handlers}")
foreach(op_handler ${op_handlers})
file(APPEND ${paddle_ipu_handler} "USE_HANDLER(${op_handler});\n")
endforeach()
endforeach()

copy_if_different(${paddle_ipu_handler} ${paddle_ipu_handler_final})
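
(Note: copy_if_different must run after the append loop so the final header sees every handler; it sat before the loop in the conflicted version.) The loop scrapes every REGISTER_HANDLER(op, ...) from the canonicalization sources and appends a matching USE_HANDLER(op); line, so the generated paddle_ipu_handler.h comes out roughly as follows (the op names below are examples drawn from the handlers registered in this commit):

// Auto generated from CMake. DO NOT EDIT!

#pragma once
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"

USE_HANDLER(relu);
USE_HANDLER(elementwise_add);
USE_HANDLER(top_k);
// ... one line per REGISTER_HANDLER found in popart_canonicalization/*.cc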

set(IPU_BACKEND_SRC
"ipu_strategy.cc"
"ipu_executor.cc"
@@ -13,10 +29,7 @@ IF(WITH_IPU)
"ipu_device.cc"
)

cc_library(ipu_backend SRCS ${IPU_BACKEND_SRC} DEPS popart-only graph graph_helper popdist)
cc_library(popart_canonicalization SRCS ${POPART_CANONICALIZATION_SRC} DEPS graph)
cc_library(ipu_backend SRCS ${IPU_BACKEND_SRC} DEPS popart-only graph graph_helper popdist popart_canonicalization)
cc_library(ipu_info SRCS ${IPU_INFO_SRC} DEPS popart-only enforce)
add_library(paddle_ipu SHARED ${PADDLE_IPU_SRC})
add_dependencies(paddle_ipu ipu_backend)
set(PADDLE_IPU_LIB "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_ipu.so" CACHE STRING "")
set(PADDLE_IPU_LIB_DIR "${CMAKE_CURRENT_BINARY_DIR}" CACHE STRING "")
ENDIF()
endif()
@@ -88,15 +88,15 @@ Node *log_softmax_handler(Graph *graph, Node *node) {
node->outputs);
}

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(relu, relu_handler);
REGISTER_HANDLER(tanh, tanh_handler);
REGISTER_HANDLER(log, log_handler);
REGISTER_HANDLER(sigmoid, sigmoid_handler);
REGISTER_HANDLER(sqrt, sqrt_handler);
REGISTER_HANDLER(gelu, gelu_handler);
REGISTER_HANDLER(log_softmax, log_softmax_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -23,9 +23,36 @@ namespace paddle {
namespace platform {
namespace ipu {

#define REGISTER_HANDLER(name, func) \
static bool __UNUSED_##name = \
paddle::platform::ipu::RegisterHandler(#name, func)
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)

#define REGISTER_HANDLER(op_type, handler) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_ipu_op_handler__##op_type, \
"REGISTER_HANDLER must be called in global namespace"); \
struct __PaddleRegisterIpuOpHandler_##op_type { \
__PaddleRegisterIpuOpHandler_##op_type() { \
::paddle::platform::ipu::RegisterHandler( \
#op_type, paddle::platform::ipu::handler); \
} \
int Touch() const { return 0; } \
}; \
static __PaddleRegisterIpuOpHandler_##op_type \
__PaddleRegisterIpuOpHandler_instance##op_type; \
int TouchPaddleIpuOpHandlerRegister_##op_type() { \
return __PaddleRegisterIpuOpHandler_instance##op_type.Touch(); \
}

#define USE_HANDLER(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_ipu_op_handler__##op_type, \
"USE_HANDLER must be called in global namespace"); \
extern int TouchPaddleIpuOpHandlerRegister_##op_type(); \
UNUSED static int use_handler__itself_##op_type##_ = \
TouchPaddleIpuOpHandlerRegister_##op_type()
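
Registration is now a global object whose constructor runs at static-initialization time, plus a Touch function that USE_HANDLER calls to force the linker to keep that object even when the handler sits in another static library. This is also why every canonicalization file in this commit moves its REGISTER_HANDLER calls below the closing namespace braces: the static_assert only passes at global scope. Expanded for the relu handler registered above (simplified; the static_assert and the UNUSED attribute are omitted):

// REGISTER_HANDLER(relu, relu_handler) expands to roughly:
struct __PaddleRegisterIpuOpHandler_relu {
  __PaddleRegisterIpuOpHandler_relu() {
    ::paddle::platform::ipu::RegisterHandler("relu",
                                             paddle::platform::ipu::relu_handler);
  }
  int Touch() const { return 0; }
};
static __PaddleRegisterIpuOpHandler_relu __PaddleRegisterIpuOpHandler_instancerelu;
int TouchPaddleIpuOpHandlerRegister_relu() {
  return __PaddleRegisterIpuOpHandler_instancerelu.Touch();
}

// USE_HANDLER(relu) in another translation unit then expands to roughly:
extern int TouchPaddleIpuOpHandlerRegister_relu();
static int use_handler__itself_relu_ = TouchPaddleIpuOpHandlerRegister_relu();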

using SymbolHandler = std::function<Node *(Graph *, Node *)>;

@@ -93,6 +93,11 @@ Node *elementwise_mod_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_mod");
}

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(elementwise_add, elementwise_add_handler);
REGISTER_HANDLER(elementwise_sub, elementwise_sub_handler);
REGISTER_HANDLER(elementwise_div, elementwise_div_handler);
@@ -101,8 +106,3 @@ REGISTER_HANDLER(elementwise_min, elementwise_min_handler);
REGISTER_HANDLER(elementwise_max, elementwise_max_handler);
REGISTER_HANDLER(elementwise_pow, elementwise_pow_handler);
REGISTER_HANDLER(elementwise_mod, elementwise_mod_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -58,14 +58,14 @@ Node *less_than_handler(Graph *graph, Node *node) {
{GetOutputVarNode("Out", node)}, {});
}

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(equal, equal_handler);
REGISTER_HANDLER(logical_not, logical_not_handler);
REGISTER_HANDLER(logical_or, logical_or_handler);
REGISTER_HANDLER(logical_and, logical_and_handler);
REGISTER_HANDLER(greater_than, greater_than_handler);
REGISTER_HANDLER(less_than, less_than_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -366,6 +366,11 @@ Node *arg_max_handler(Graph *graph, Node *node) {
{{"axis", axis}, {"keepdims", int64_t{0}}});
}

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(mean, mean_handler);
REGISTER_HANDLER(pow, pow_handler);
REGISTER_HANDLER(mul, mul_handler);
@@ -377,8 +382,3 @@ REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
REGISTER_HANDLER(cumsum, cumsum_handler);
REGISTER_HANDLER(matmul_v2, matmul_v2_handler);
REGISTER_HANDLER(arg_max, arg_max_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -298,15 +298,15 @@ Node *dropout_handler(Graph *graph, Node *node) {
}
}

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(pool2d, pool2d_handler);
REGISTER_HANDLER(batch_norm, batch_norm_handler);
REGISTER_HANDLER(group_norm, group_norm_handler);
REGISTER_HANDLER(instance_norm, instance_norm_handler);
REGISTER_HANDLER(layer_norm, layer_norm_handler);
REGISTER_HANDLER(conv2d, conv2d_handler);
REGISTER_HANDLER(dropout, dropout_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -77,15 +77,15 @@ Node *detach_handler(Graph *graph, Node *node) {
node->outputs);
}

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(custom_op, custom_op_handler);
REGISTER_HANDLER(print, print_handler);
REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler);
REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler);
REGISTER_HANDLER(custom_nll_loss, custom_nll_loss_handler);
REGISTER_HANDLER(identity, identity_handler);
REGISTER_HANDLER(detach, detach_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -56,13 +56,13 @@ Node *reduce_prod_handler(Graph *graph, Node *node) {
return reduce_op_handler(graph, node, "popart_reduceprod");
}

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(reduce_mean, reduce_mean_handler);
REGISTER_HANDLER(reduce_min, reduce_min_handler);
REGISTER_HANDLER(reduce_sum, reduce_sum_handler);
REGISTER_HANDLER(reduce_max, reduce_max_handler);
REGISTER_HANDLER(reduce_prod, reduce_prod_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -86,10 +86,10 @@ Node *topk_handler(Graph *graph, Node *node) {
static_cast<int>(framework::proto::VarType::INT32));
}

REGISTER_HANDLER(top_k, topk_handler);
REGISTER_HANDLER(top_k_v2, topk_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(top_k, topk_handler);
REGISTER_HANDLER(top_k_v2, topk_handler);
@@ -570,6 +570,11 @@ Node *split_handler(Graph *graph, Node *node) {
{"split", std::vector<int64_t>{sections.begin(), sections.end()}}});
}

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

REGISTER_HANDLER(fill_constant, fill_constant_handler);
REGISTER_HANDLER(gaussian_random, gaussian_random_handler);
REGISTER_HANDLER(uniform_random, uniform_random_handler);
@@ -593,8 +598,3 @@ REGISTER_HANDLER(lookup_table_v2, lookup_table_v2_handler);
REGISTER_HANDLER(split, split_handler);
REGISTER_HANDLER(one_hot, one_hot_handler);
REGISTER_HANDLER(one_hot_v2, one_hot_v2_handler);

} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle

1 comment on commit 420d9f1

@paddle-bot-old

Congratulations! Your pull request passed all required CI checks. You can ask the reviewer(s) to approve and merge. 🎉
