[PHI] transpose2_grad op migration (PaddlePaddle#46139)
* op migrated, Copy(OneDNNContext, ...) added

* mutable_data & op registration in fluid removed

* refactoring

* OneDNNGetDataType to uppercase

* missing cpu check added, handler moved to .h file

* name changed to transpose_grad

* Copy changed back to TensorCopy

* Resizing corrected, Copy(OneDNNContext) removed
Silv3S committed Oct 11, 2022
2 parents b0f3b30 + 2190da2 commit 60b60e9
Showing 57 changed files with 2,782 additions and 2,552 deletions.
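
The bullet points above describe moving the oneDNN transpose2_grad kernel out of fluid and into PHI, renaming it to transpose_grad and dropping mutable_data and the fluid-side registration. As a minimal sketch only — the file layout, argument list, and dtype list are assumptions for illustration, not taken from this diff — a PHI oneDNN grad kernel of this kind is implemented and registered roughly like this:

// Illustrative sketch of a PHI oneDNN grad kernel; the signature and names
// are assumptions, not copied from this commit.
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void TransposeGradKernel(const Context& dev_ctx,
                         const DenseTensor& out_grad,
                         const std::vector<int>& axis,
                         DenseTensor* x_grad) {
  // Reverse the forward permutation and reorder out_grad into x_grad via the
  // oneDNN handler; output memory comes from dev_ctx.template Alloc<T>(x_grad)
  // rather than the removed mutable_data() call.
}

}  // namespace phi

// Registered under the PHI name "transpose_grad" (renamed from
// transpose2_grad); OneDNN / ONEDNN are the backend and layout tokens that the
// updated cmake/phi.cmake regex below parses out of this macro.
PD_REGISTER_KERNEL(transpose_grad,
                   OneDNN,
                   ONEDNN,
                   phi::TransposeGradKernel,
                   float,
                   phi::dtype::bfloat16) {}

With a registration of this shape in place, the corresponding REGISTER_OP_KERNEL entry in fluid becomes redundant, which is why the fluid-side registration is deleted in this commit.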
49 changes: 17 additions & 32 deletions cmake/phi.cmake
@@ -78,7 +78,7 @@ function(kernel_declare TARGET_LIST)
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[ \t\r\n\/]*[a-z0-9_]*"
"(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
if(NOT first_registry STREQUAL "")
@@ -89,38 +89,23 @@ function(kernel_declare TARGET_LIST)
continue()
endif()
endif()
# parse the first kernel name
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_name
"${kernel_name}")
string(REPLACE "," "" kernel_name "${kernel_name}")
string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}")
string(REGEX REPLACE "//cuda_only" "" kernel_name "${kernel_name}")
# parse the registered kernel message
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_msg "${first_registry}")
string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_msg
"${kernel_msg}")
string(REPLACE "," ";" kernel_msg "${kernel_msg}")
string(REGEX REPLACE "[ \\\t\r\n]+" "" kernel_msg "${kernel_msg}")
string(REGEX REPLACE "//cuda_only" "" kernel_msg "${kernel_msg}")

list(GET kernel_msg 0 kernel_name)
list(GET kernel_msg 1 kernel_backend)
list(GET kernel_msg 2 kernel_layout)

# append kernel declare into declarations.h
# TODO(chenweihang): default declare ALL_LAYOUT for each kernel
if(${kernel_path} MATCHES "./cpu\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./gpu\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, GPU, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./xpu\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./gpudnn\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, GPUDNN, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./kps\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, KPS, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./onednn\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, OneDNN, ALL_LAYOUT);\n")
else()
# deal with device-independent kernel, now we use CPU temporarily
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
endif()
file(
APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, ${kernel_backend}, ${kernel_layout});\n"
)
endif()
endforeach()
endfunction()
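
The widened regular expression and the kernel_msg parsing above pull the kernel name, backend, and layout straight out of the registration macro, instead of inferring the backend from the kernel's directory and hard-coding ALL_LAYOUT. For a hypothetical registration (illustrative only, not part of this diff) such as

PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}

kernel_declare now appends

PD_DECLARE_KERNEL(relu, CPU, ALL_LAYOUT);

to declarations.h, and a kernel registered with the OneDNN backend and ONEDNN layout is declared with exactly those tokens rather than with the ALL_LAYOUT default that the removed per-directory branches emitted.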
10 changes: 10 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1244,6 +1244,16 @@ void AnalysisPredictor::PrepareArgument() {
// NOTE All the members in AnalysisConfig should be copied to Argument.
void AnalysisPredictor::OptimizeInferenceProgram() {
PrepareArgument();

#ifdef PADDLE_WITH_TENSORRT
if (config_.tensorrt_engine_enabled()) {
inference::tensorrt::TensorRTEngine::predictor_id_per_thread =
predictor_id_;
VLOG(3) << "thread_local var predictor_id in TensorRTEngine is set to: "
<< inference::tensorrt::TensorRTEngine::predictor_id_per_thread;
}
#endif

Analyzer().Run(&argument_);

PADDLE_ENFORCE_EQ(
5 changes: 2 additions & 3 deletions paddle/fluid/inference/tensorrt/engine.cc
@@ -646,9 +646,8 @@ void TensorRTEngine::GetEngineInfo() {
LOG(INFO) << "====== engine info ======";
std::unique_ptr<nvinfer1::IEngineInspector> infer_inspector(
infer_engine_->createEngineInspector());
auto infer_context = infer_ptr<nvinfer1::IExecutionContext>(
infer_engine_->createExecutionContextWithoutDeviceMemory());
infer_inspector->setExecutionContext(infer_context.get());
auto infer_context = context();
infer_inspector->setExecutionContext(infer_context);
LOG(INFO) << infer_inspector->getEngineInformation(
nvinfer1::LayerInformationFormat::kONELINE);
LOG(INFO) << "====== engine info end ======";
10 changes: 10 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -534,6 +534,16 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass.";
return false;
}

auto index_var_name = desc.Input("Index")[0];
auto* index_var_desc = block->FindVar(index_var_name);

// The index input must be int32 datatype.
if (index_var_desc->GetDataType() !=
paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
VLOG(3) << "gather op Index input data type must be int32";
return false;
}
#if !IS_TRT_VERSION_GE(7000)
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape();
92 changes: 47 additions & 45 deletions paddle/fluid/operators/fake_quantize_op.cu.h
@@ -590,33 +590,29 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
const T *scale,
const int bin_cnt,
const int round_type,
const int n,
const int c,
const int wh_size,
const int num,
const int cout,
T *out) {
int tid = threadIdx.x;

int channel_size = n / c;
const T *in_c = in + blockIdx.x * channel_size;
T *out_c = out + blockIdx.x * channel_size;

T s = scale[blockIdx.x];
T inv_s = inverse(s);
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;

for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
T s = scale[(i / wh_size) % cout];
T inv_s = inverse(s);
T x = in[i];
if (round_type == 0) {
x = bin_cnt * inv_s * x;
x = roundWithTiesToEven(x);
T max_bound = bin_cnt;
T min_bound = -bin_cnt - static_cast<T>(1);
x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x;
out_c[i] = (x * s) / bin_cnt;
out[i] = (x * s) / bin_cnt;
} else {
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt;
out[i] = round(v) * s / bin_cnt;
}
}
}
@@ -627,32 +623,29 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in,
const T *scale,
const int bin_cnt,
const int round_type,
const int n,
const int cin,
const int wh_size,
const int num,
const int cout,
T *out) {
T s = scale[blockIdx.x % cout];
T inv_s = inverse(s);

int wh_size = n / (cin * cout);
const T *in_c = in + blockIdx.x * wh_size;
T *out_c = out + blockIdx.x * wh_size;
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;

for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
T x = in_c[i];
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
T s = scale[(i / wh_size) % cout];
T inv_s = inverse(s);
T x = in[i];
if (round_type == 0) {
x = bin_cnt * inv_s * x;
x = roundWithTiesToEven(x);
T max_bound = bin_cnt;
T min_bound = -bin_cnt - static_cast<T>(1);
x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x;
out_c[i] = (x * s) / bin_cnt;
out[i] = (x * s) / bin_cnt;
} else {
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt;
out[i] = round(v) * s / bin_cnt;
}
}
}
@@ -682,30 +675,39 @@ struct ChannelClipFakeQuantDequantFunctor<phi::GPUContext, T> {
const T *scale_data = scale.data<T>();
T *out_data = out->mutable_data<T>(ctx.GetPlace());

int64_t block_size =
std::min(static_cast<int64_t>(num),
static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));

int64_t max_threads = ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM
const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size);

if (quant_axis == 0) {
int grid = in_dims[0];
int block = 1024;
const int window_size = num / in_dims[0];
ChannelClipAndQuantDequantKernelQuantAxis0<T>
<<<grid, block, 0, ctx.stream()>>>(in_data,
scale_data,
bin_cnt,
round_type,
num,
in_dims[0],
out_data);
<<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
scale_data,
bin_cnt,
round_type,
window_size,
num,
in_dims[0],
out_data);
} else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1];
int block = 1024;
const int window_size = num / (in_dims[0] * in_dims[1]);

ChannelClipAndQuantDequantKernelQuantAxis1<T>
<<<grid, block, 0, ctx.stream()>>>(in_data,
scale_data,
bin_cnt,
round_type,
num,
in_dims[0],
in_dims[1],
out_data);
<<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
scale_data,
bin_cnt,
round_type,
window_size,
num,
in_dims[1],
out_data);
}
}
};
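
Both kernels above now walk the whole tensor with a grid-stride loop and look up the per-channel scale as scale[(i / wh_size) % cout], instead of binding one CUDA block to one channel slice. A small host-side check of that index arithmetic for an NCHW tensor — an illustrative sketch under assumed dimension names, not part of the diff:

#include <cassert>
#include <cstdint>

// Assumed NCHW flattening: i = ((n * C + c) * H + h) * W + w.
// For quant_axis == 1 the kernel is launched with wh_size = H * W and
// cout = C, so (i / wh_size) % cout recovers the channel c; for
// quant_axis == 0 it is launched with wh_size = C * H * W and cout = N,
// recovering the outermost index n.
int main() {
  const int64_t N = 2, C = 3, H = 4, W = 5;
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t h = 0; h < H; ++h)
        for (int64_t w = 0; w < W; ++w) {
          const int64_t i = ((n * C + c) * H + h) * W + w;
          assert((i / (H * W)) % C == c);      // quant_axis == 1
          assert((i / (C * H * W)) % N == n);  // quant_axis == 0
        }
  return 0;
}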
59 changes: 0 additions & 59 deletions paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -42,62 +42,13 @@ class MKLDNNActivationKernel
}
};

template <typename Functor>
class MKLDNNActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
Functor functor;
functor(ctx);
}
};

template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();

const auto *x = ctx.Input<Tensor>("X");
const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, dout);

auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();

auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();

dx->set_mem_desc(diff_src_memory_p->get_desc());
}

template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
eltwise_grad<T>(ctx, algorithm);
}
};

template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
custom_softplus_eltwise_forward<T>(ctx);
}
};

template <typename T>
using Relu6MKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_bounded_relu>;

} // namespace operators
} // namespace paddle

@@ -111,14 +62,4 @@ namespace ops = paddle::operators;
ops::MKLDNNActivationKernel<ops::functor<float>>, \
ops::MKLDNNActivationKernel<ops::functor<paddle::platform::bfloat16>>);

#define REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(act_type, grad_functor) \
REGISTER_OP_KERNEL( \
act_type##_grad, \
MKLDNN, \
::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>, \
ops::MKLDNNActivationGradKernel< \
ops::grad_functor<paddle::platform::bfloat16>>);

REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor);
REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(relu6, Relu6MKLDNNGradFunctor);
