
Remove reshape and transpose operators from attention module #16342

Closed

Conversation

@yihuaxu (Contributor) commented Mar 21, 2019

Based on profiling of the BERT/Transformer model, this PR fuses the matmul/reshape/transpose operators to reduce memory copies.

Platform: Intel(R) Xeon(R) CPU E5-2699 v3 @ 2.30GHz
Model Path: third_party/inference_demo/bert_emb128/model
Batch Size: 1
Command: ./paddle/fluid/inference/tests/api/test_analyzer_bert --infer_model=third_party/inference_demo/bert_emb128/model/ --infer_data=third_party/inference_demo/bert_emb128/data.txt --gtest_filter=Analyzer_bert.profile --paddle_num_threads=1 --repeat=10 --batch_size=1 --test_all_data
Data Source: third_party/inference_demo/bert_emb128/data.txt.

The following is a comparison of the different scenarios.

[performance comparison image]

Model Comparison:
(a) Before Optimization:
[model graph before optimization]

(b) After Optimization:
[model graph after optimization]

Reference:
Can we avoid head split_merge in Transformer.pdf
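For context, here is a hedged C++ sketch of the idea behind the fusion (illustrative only; the helper name and shapes are hypothetical, not this PR's code): instead of materializing reshape/transpose copies of the [batch, seq_len, num_heads * head_dim] activations, each (batch, head) sub-matrix can be addressed in place through a pointer array plus a leading dimension, which a batched GEMM can then consume directly.

#include <cstdint>
#include <vector>

// Hypothetical helper: build per-head pointers into a
// [batch, seq_len, num_heads * head_dim] buffer. Each (batch, head)
// sub-matrix is seq_len x head_dim with leading dimension
// num_heads * head_dim, so no reshape/transpose copy is needed.
std::vector<const float*> HeadPointers(const float* data, int batch, int seq_len,
                                       int num_heads, int head_dim) {
  std::vector<const float*> ptrs;
  ptrs.reserve(static_cast<size_t>(batch) * num_heads);
  const int64_t row_stride = static_cast<int64_t>(num_heads) * head_dim;
  for (int b = 0; b < batch; ++b) {
    for (int h = 0; h < num_heads; ++h) {
      ptrs.push_back(data + b * seq_len * row_stride + h * head_dim);
    }
  }
  return ptrs;
}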

@yihuaxu force-pushed the develop_7fbf52daa_matmul_fuse_pass branch from 89136c8 to ef3cd38 on March 21, 2019 06:36
@luotao1 added the Intel label on Mar 21, 2019
@yihuaxu force-pushed the develop_7fbf52daa_matmul_fuse_pass branch from ef3cd38 to 0eea941 on March 21, 2019 11:15
@yihuaxu force-pushed the develop_7fbf52daa_matmul_fuse_pass branch 2 times, most recently from 62a9e02 to dc40dfb on March 21, 2019 13:20
@yihuaxu force-pushed the develop_7fbf52daa_matmul_fuse_pass branch from dc40dfb to b4359f9 on March 21, 2019 22:48
@yihuaxu (Contributor, Author) commented Mar 24, 2019

start a review

@jianhang-liu (Contributor):
@tensor-tang Please help to review this PR. This is a critical patch for BERT (and it also applies to Transformer). Thanks!

@yihuaxu changed the title from "Fuse matmul/reshape/transpose operators to reduce memory's copy" to "Remove reshape and transpose operators from attention module" on Apr 16, 2019
@yihuaxu (Contributor, Author) commented Apr 24, 2019

@tensor-tang Please help us review this PR and give some suggestions. Thanks a lot!

@luotao1 (Contributor) commented Apr 25, 2019

The following is a comparison of the different scenarios.

Do you have a model-level comparison before and after this PR?

@yihuaxu (Contributor, Author) commented Apr 27, 2019

The following is a comparison of the different scenarios.

Do you have a model-level comparison before and after this PR?

Just updated the description to include the model-level comparison.

@@ -137,7 +137,8 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
 // following two passes should be located in the last, since
 // they will work on all fused ops.
 "expected_kernel_cache_pass", //
-"runtime_context_cache_pass"});
+"runtime_context_cache_pass", //
+"fuse_reshape_transpose_scale_matmul_pass"});
Contributor:

See line 137: put fuse_reshape_transpose_scale_matmul_pass before expected_kernel_cache_pass.

Contributor Author:

done
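To make the requested ordering concrete, here is a hedged sketch (a hypothetical fragment, assuming the pass list is kept in a std::vector<std::string> member such as passes_, as the fragment's closing "});" suggests; not the merged code):

passes_.assign({/* ... earlier CPU passes ... */
                "fuse_reshape_transpose_scale_matmul_pass",  // run the fusion first
                // following two passes should be located in the last, since
                // they will work on all fused ops.
                "expected_kernel_cache_pass",   //
                "runtime_context_cache_pass"});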

this->template GEMM<T>(transA == CblasTrans, transB == CblasTrans, M, N, K,
alpha, Ak, lda, Bk, ldb, beta, Ck, ldc);
}
}
Contributor:

  • What's the difference between the old BatchedGEMM and the new BatchedGEMM in your PR?
  • I see the difference is the input format, from const T* to std::vector<const T*>* a_array.
  • Could you reuse the old one or unify them? Or why did you create a new one?

Same for the Matmul.
There is already

void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                                 const MatDescriptor &dim_a,
                                 const framework::Tensor &mat_b,
                                 const MatDescriptor &dim_b, T alpha,
                                 framework::Tensor *mat_out, T beta) 

@yihuaxu (Contributor, Author) commented May 10, 2019:

  • What's the difference between the old BatchedGEMM and the new BatchedGEMM in your PR?
    The new one passes the input and output pointer arrays into BatchedGEMM directly.
  • I see the difference is the input format, from const T* to std::vector<const T*>* a_array.
    Because the transpose dimensions and the stride requirements differ, the pointer arrays for MKL's BatchedGEMM have to be built with a special calculation, so we moved that calculation inside the matmul operator.
  • Could you reuse the old one or unify them? Or why did you create a new one?
    The initial idea was to avoid complicating the common BLAS implementation. If we put this and other pointer-array calculations into blas_impl.h, it would not keep the code clean.

Same for the Matmul.
There is already

void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                                 const MatDescriptor &dim_a,
                                 const framework::Tensor &mat_b,
                                 const MatDescriptor &dim_b, T alpha,
                                 framework::Tensor *mat_out, T beta) 
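To illustrate why the pointer-array form is attractive, here is a hedged sketch (not this PR's code; the function name is hypothetical) that assumes MKL's cblas_sgemm_batch: every (batch, head) sub-matrix is passed through a_array/b_array/c_array together with its leading dimension, so no transposed or reshaped copy has to be materialized before the batched GEMM.

#include <mkl.h>
#include <vector>

// Hedged sketch: one group of equally sized GEMMs driven by pointer arrays.
void BatchedGemmViaPointerArrays(std::vector<const float*>& a_array,
                                 std::vector<const float*>& b_array,
                                 std::vector<float*>& c_array,
                                 MKL_INT M, MKL_INT N, MKL_INT K,
                                 MKL_INT lda, MKL_INT ldb, MKL_INT ldc) {
  const CBLAS_TRANSPOSE trans_a = CblasNoTrans;
  const CBLAS_TRANSPOSE trans_b = CblasNoTrans;
  const float alpha = 1.0f;
  const float beta = 0.0f;
  const MKL_INT group_size = static_cast<MKL_INT>(a_array.size());
  // A single group: all GEMMs share M/N/K, alpha/beta and the leading
  // dimensions; only the per-matrix pointers differ.
  cblas_sgemm_batch(CblasRowMajor, &trans_a, &trans_b, &M, &N, &K, &alpha,
                    a_array.data(), &lda, b_array.data(), &ldb, &beta,
                    c_array.data(), &ldc, /*group_count=*/1, &group_size);
}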

.SetDefault(std::vector<int>{-1, -1, -1});
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
Contributor:

Why does matmul need an is_test attribute?
Why add a last_dim attribute?

Contributor Author:

Why does matmul need an is_test attribute?
The "is_test" attribute is added for the inference mode so that other use cases are not affected.
Why add a last_dim attribute?
To reduce the number of the matmul operator's attributes; the trade-off is that it only works for the specific reshape and transpose dimensions.

Contributor:

The "is_test" attribute is added for the inference mode so that other use cases are not affected.

Matmul is a common, basic op; it should not behave differently between training and inference.

Contributor:

@luotao1 Scoring (inference) and the forward part of training sometimes differ. The "is_test" attribute is used to distinguish between them, and it is already widely used in many OPs. For example:

  • Batch Norm: use fixed mean/variance instead of computing them per batch
  • Softmax: skip the epsilon term for better performance
  • sequence_pool: do not create the index buffer, for better performance

The attention optimization here (i.e. removing transpose/reshape via the enhanced MatMul) only needs to apply to inference at this time, so we add "is_test" to gate our code.

If this optimization needs to be applied to training as well, we can add the backward part and remove "is_test" from the forward pass.
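As a hedged illustration of this point (a hypothetical fragment, assuming Paddle's framework::ExecutionContext::Attr API; not the PR's actual kernel), the attribute simply gates the inference-only path:

// Hypothetical fragment: gate the fused inference-only path on "is_test".
template <typename T>
void MatMulComputeSketch(const framework::ExecutionContext& ctx) {
  const bool is_test = ctx.Attr<bool>("is_test");
  if (is_test) {
    // Inference: consume the folded reshape/transpose shapes and call the
    // pointer-array BatchedGEMM, so no intermediate copies are made.
  } else {
    // Training: keep the original matmul path so backward stays unaffected.
  }
}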

@jianhang-liu jianhang-liu added this to the v1.5 for Intel milestone May 13, 2019
@bingyanghuang (Contributor):
@jianhang-liu This PR should be moved to Release 1.6

template <typename T>
void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
            const framework::Tensor& mat_b, const MatDescriptor& dim_b,
            T alpha, framework::Tensor* mat_out, T beta) const;

template <typename T>
void MatMul(std::vector<const T*>* a_array, const MatDescriptor& dim_a,
            const int ld_a, std::vector<const T*>* b_array,
Contributor:

Why should the matmul input be a vector?

Contributor Author:

Bingyang is now ready to re-implement the pass in the future. This PR will be abandoned.

@@ -176,11 +177,24 @@ class Blas {
                   int K, T alpha, const T* A, const T* B, T beta, T* C,
                   int batchCount, int64_t strideA, int64_t strideB) const;

  template <typename T>
  void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                   int K, T alpha, std::vector<const T*>* a_array,
Contributor:

const std::vector<const T*>*?

Contributor Author:

Bingyang is now ready to re-implement the pass in the future. This PR will be abandoned.

paddle/fluid/operators/matmul_op.cc (review thread resolved)
@tensor-tang (Contributor):
Such a good point, and thanks for the foresighted GEMM design!

@@ -136,7 +136,8 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
 "is_test_pass", //
 // following two passes should be located in the last, since
 // they will work on all fused ops.
-"expected_kernel_cache_pass", //
+"expected_kernel_cache_pass", //
+"fuse_reshape_transpose_scale_matmul_pass", //
Contributor:

Why add it at the end?

This is a very big fuse; maybe it should run earlier.

Contributor Author:

Bingyang is now ready to re-implement the pass in the future. This PR will be abandoned.

@GaoWei8 (Contributor) commented Aug 26, 2019

I tested this PR on Ernie on CPU with the number of threads set to 20.
The original time after fusion is 72.5432 ms (without this PR), and the time decreases to 69.9925 ms with this PR merged.
Is this result expected?
