Add fc padding to improve MKL GEMM's performance when N and K are multiples of 128. #20972

Merged: 9 commits, Nov 26, 2019. Changes from all commits shown below.
29 changes: 29 additions & 0 deletions paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -89,6 +89,35 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
std::string activation_type = with_relu ? "relu" : "";
desc.SetAttr("activation_type", activation_type);

// Pad the fc weights to improve MKL GEMM performance when both dimensions
// are multiples of 128.
auto* scope = param_scope();
auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
auto place = weight->place();
bool use_gpu = Get<bool>("use_gpu");
auto weight_data = weight->data<float>();
auto weight_dims = weight->dims();
int weight_num = product(weight_dims);
int w_h = weight_dims[0];
int w_w = weight_dims[1];
if (!use_gpu) {
if (w_h % 128 == 0 && w_w % 128 == 0) {
float* weight_data_tmp = new float[weight_num];
for (int i = 0; i < w_h; i++) {
memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w,
w_w * sizeof(float));
}
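// Resize + mutable_data below reallocate the tensor's storage, which is
// why the original values are staged in weight_data_tmp first.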
weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4});
auto weight_data_new =
weight->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < w_h; i++) {
memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w,
w_w * sizeof(float));
}
delete[] weight_data_tmp;
desc.SetAttr("padding_weights", true);
}
}

// For anakin subgraph int8
// When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
// fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass
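For reference, a minimal standalone sketch of the repacking performed by the pass above (illustrative names, not PR code): every row of the w_h x w_w weight is copied into a buffer with row stride w_w + 4, so the leading dimension handed to GEMM is no longer a multiple of 128.

#include <cstring>
#include <vector>

// Copy a row-major h x w matrix into an (h + 4) x (w + 4) buffer.
// The 4 trailing pad columns and rows are zero-filled and are never
// part of the logical product.
std::vector<float> PadWeights(const float* src, int h, int w) {
  std::vector<float> dst((h + 4) * (w + 4), 0.0f);
  for (int i = 0; i < h; ++i) {
    std::memcpy(dst.data() + i * (w + 4), src + i * w, w * sizeof(float));
  }
  return dst;
}

The pass itself repacks in place into the resized LoDTensor and sets the padding_weights attribute, so downstream kernels know to subtract the pad of 4 when reading the logical dimensions.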
20 changes: 20 additions & 0 deletions paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
@@ -21,6 +21,24 @@ namespace paddle {
namespace framework {
namespace ir {

void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}

Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "conv2d_filters_0", {});
AddVarToScope(param_scope, "conv2d_bias_0", {});
AddVarToScope(param_scope, "weights_0", {});
AddVarToScope(param_scope, "weights_1", {});
AddVarToScope(param_scope, "bias_1", {});
AddVarToScope(param_scope, "bias_2", {});
return param_scope;
}

TEST(FCFusePass, basic) {
// inputs operator output
// --------------------------------------------------------
@@ -50,6 +68,8 @@ TEST(FCFusePass, basic) {

std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("fc_fuse_pass");
pass->Set("use_gpu", new bool(true));
graph->Set("__param_scope__", CreateParamScope());
int num_nodes_before = graph->Nodes().size();
int num_mul_nodes_before = GetNumOpNodes(graph, "mul");
VLOG(3) << DebugString(graph);
3 changes: 3 additions & 0 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -147,6 +147,9 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("auto_config_layout",
new bool(argument->anakin_auto_config_layout()));
}
if (pass_name == "fc_fuse_pass") {
pass->Set("use_gpu", new bool(argument->use_gpu()));
}

pre_pass = pass_name;

2 changes: 0 additions & 2 deletions paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
@@ -153,7 +153,6 @@ void profile(bool use_mkldnn = false, bool use_ngraph = false) {

if (use_mkldnn) {
config.EnableMKLDNN();
config.pass_builder()->AppendPass("fc_mkldnn_pass");
}

if (use_ngraph) {
@@ -193,7 +192,6 @@ void compare(bool use_mkldnn = false, bool use_ngraph = false) {
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}

if (use_ngraph) {
34 changes: 28 additions & 6 deletions paddle/fluid/operators/fc_op.cc
@@ -32,17 +32,33 @@ class FCOp : public framework::OperatorWithKernel {

auto in_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("W");
bool padding_weights = ctx->Attrs().Get<bool>("padding_weights");

if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
if (bias_dims.size() == 2) {
PADDLE_ENFORCE_EQ(bias_dims[0], 1,
platform::errors::InvalidArgument(
"The shape of Bias is invalid. "
"The height of Bias should be 1. "
"But received height of Bias is %d.",
bias_dims[0]));
PADDLE_ENFORCE_EQ(
bias_dims[1], w_dims1,
platform::errors::InvalidArgument(
"The shape of Bias is invalid. "
"The width of Bias should be equal to the width of Weight. "
"But received width of Bias is %d and width of Weight is %d.",
bias_dims[1], w_dims1));
} else if (bias_dims.size() == 1) {
PADDLE_ENFORCE_EQ(
bias_dims[0], w_dims1,
platform::errors::InvalidArgument(
"The shape of Bias is invalid. "
"The height of Bias should be equal to the width of Weight. "
"But received height of Bias is %d and width of Weight is %d.",
bias_dims[0], w_dims1));
}
}

@@ -65,7 +81,8 @@ class FCOp : public framework::OperatorWithKernel {
"in_num_col_dims.");

std::vector<int64_t> output_dims;
FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims,
padding_weights);

ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("Input", "Out");
@@ -107,6 +124,11 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>(
"padding_weights",
"(bool, default false) When padding weights in the fc fuse pass, "
"the 'padding_weights' attribute is set as true.")
.SetDefault(false);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true);
26 changes: 17 additions & 9 deletions paddle/fluid/operators/fc_op.h
@@ -27,17 +27,21 @@ using Tensor = framework::Tensor;
inline void FCOutputSize(const framework::DDim& in_dims,
const framework::DDim& w_dims,
std::vector<int64_t>& out_dims, // NOLINT
int in_num_col_dims, bool padding_weights) {
auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
PADDLE_ENFORCE_EQ(in_mat_dims[1], w_dims0,
platform::errors::InvalidArgument(
"Fully Connected input and weight size do not match. "
"Input width: %d, weight height: %d.",
in_mat_dims[1], w_dims0));

out_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
for (int i = 0; i < in_num_col_dims; ++i) {
out_dims.push_back(in_dims[i]);
}
out_dims.push_back(w_dims1);
}

template <typename DeviceContext, typename T>
@@ -53,23 +57,27 @@ class FCOpKernel : public framework::OpKernel<T> {
bool with_relu =
(ctx.Attr<std::string>("activation_type") == "relu") ? true : false;

auto w_dims = w->dims();
bool padding_weights = ctx.Attr<bool>("padding_weights");

std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims,
padding_weights);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());

auto out_dims = output->dims();
auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
int M = framework::product(out_dims) / w_dims1;

const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());

auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, M, w_dims1, w_dims0, input_data, w_data, output_data,
bias ? bias->data<T>() : NULL, with_relu, padding_weights);
}
};

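A quick worked example of the minus-4 bookkeeping above, assuming the fuse pass padded a 128 x 128 weight to 132 x 132 (shapes are illustrative, not from the PR):

// Hypothetical shapes; padding_weights = true strips the pad of 4.
framework::DDim in_dims = framework::make_ddim({8, 128});   // batch x K
framework::DDim w_dims = framework::make_ddim({132, 132});  // (K+4) x (N+4)
std::vector<int64_t> out_dims;
FCOutputSize(in_dims, w_dims, out_dims, /*in_num_col_dims=*/1,
             /*padding_weights=*/true);
// out_dims is now {8, 128}: the logical N, unaffected by the padding.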
53 changes: 49 additions & 4 deletions paddle/fluid/operators/math/fc.cc
@@ -25,10 +25,53 @@ class FCFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false,
bool padding_weights = false) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
framework::Tensor Y1;
T* Y1_data = nullptr;
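// Y1 receives the padded M x (N + 4) product; its rows are copied (or
// bias-added) back into the caller's unpadded Y below.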
if (N % 128 == 0 && K % 128 == 0) {
const int NN = N + 4;
const int KK = K + 4;
framework::Tensor X1;
T* X1_data = X1.Resize({M * KK}).mutable_data<T>(platform::CPUPlace());
Y1_data = Y1.Resize({M * (N + 4)}).mutable_data<T>(platform::CPUPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
memcpy(X1_data + i * KK, X + i * K, K * sizeof(X[0]));
}
framework::Tensor W1;
T* W1_data = nullptr;
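// If the fuse pass already padded W (padding_weights == true), it is used
// as-is; otherwise W is repacked here into W1 with the widened row stride.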
if (!padding_weights) {
W1_data = W1.Resize({(K + 4) * (N + 4)})
.mutable_data<T>(platform::CPUPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < K; i++) {
memcpy(W1_data + i * NN, W + i * N, N * sizeof(W[0]));
}
}
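// The logical problem size is still M x N x K; the pads only enter through
// the leading dimensions KK and NN passed to GEMM.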
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK,
(padding_weights ? W : W1_data), NN, static_cast<T>(0.0),
Y1_data, NN);
} else {
blas.MatMul(M, N, K, X, W, Y);
}
if (B == NULL) {
if (N % 128 == 0 && K % 128 == 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(Y[0]));
}
}
PADDLE_ENFORCE_EQ(relu, false,
platform::errors::PermissionDenied(
"When bias is NULL, relu cannot be true."));
return;
}
if (relu) {
auto compute =
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache()
.At(N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N);
}
} else {
auto compute =
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N);
}
}
}
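To make the leading-dimension trick concrete, here is a self-contained CBLAS sketch of the padded product (assumes an MKL/OpenBLAS-style cblas_sgemm; a sketch, not PR code):

#include <cblas.h>  // assumption: any CBLAS provider (MKL, OpenBLAS)

// X is logically M x K with row stride K + 4, W is K x N with row stride
// N + 4, and Y is M x N written with row stride N + 4 (so the Y buffer
// must hold M * (N + 4) floats). The pads ride along via lda/ldb/ldc.
void PaddedGemm(int M, int N, int K, const float* X, const float* W,
                float* Y) {
  const int KK = K + 4;
  const int NN = N + 4;
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
              1.0f, X, KK, W, NN, 0.0f, Y, NN);
}

The pad elements never influence the result; per the PR title, the gain is simply that MKL GEMM runs faster when the leading dimensions are not multiples of 128.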
7 changes: 6 additions & 1 deletion paddle/fluid/operators/math/fc.cu
@@ -41,7 +41,12 @@ class FCFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false,
bool padding_weights = false) {
PADDLE_ENFORCE_EQ(
padding_weights, false,
platform::errors::PermissionDenied(
"Weight padding in fc can not be used in GPU scope."));
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N,
static_cast<T>(0.0), Y, N);
3 changes: 2 additions & 1 deletion paddle/fluid/operators/math/fc.h
@@ -26,7 +26,8 @@ class FCFunctor {
public:
void operator()(const DeviceContext& context, const int M, const int N,
const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false,
bool padding_weights = false);
};

} // namespace math
7 changes: 6 additions & 1 deletion paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -207,8 +207,13 @@ class FCPrimitiveFactory {
void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input,
const Tensor* w, LoDTensor* output) {
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
bool padding_weights = ctx.Attr<bool>("padding_weights");
PADDLE_ENFORCE_EQ(padding_weights, false,
platform::errors::PermissionDenied(
"Weight padding in fc can not be used in MKLDNN."));
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims);
FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims,
padding_weights);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
}
7 changes: 7 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fc_op.py
@@ -124,6 +124,13 @@ def config(self):
self.matrix = MatrixGenerate(1, 64, 32, 3, 3, 1)


class TestFCOpWithPadding(TestFCOp):
def config(self):
self.with_bias = True
self.with_relu = True
self.matrix = MatrixGenerate(1, 4, 3, 128, 128, 2)


class TestFCOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):