
MatMul operator #4856

Merged: 5 commits into PaddlePaddle:develop on Oct 18, 2017
Conversation

@mkliegl (Contributor) commented Oct 17, 2017

This closes issues #4683 and #4696.

The MatMul operator is used to perform (batched) matrix multiplication
over the last two dimensions of the input tensors X and Y.

If a transpose flag is set, the last two dimensions of the corresponding
tensor are transposed. A rank-1 tensor of shape [D] is promoted: X is
treated as [1, D] in non-transposed form and as [D, 1] in transposed
form, whereas Y is the opposite: it is treated as [D, 1] in
non-transposed form and as [1, D] in transposed form.

Examples without transpose:

  • X: [K], Y: [K] => Out: [1]
  • X: [K], Y: [K, N] => Out: [N]
  • X: [B, M, K], Y: [K] => Out: [B, M]
  • X: [M, K], Y: [B, K, N] => Out: [B, M, N]
  • X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]

The behavior is designed to be similar to the numpy.matmul function.
The differences are:

  • Currently only rank 1 to rank 3 input tensors are supported.
  • We add transpose_X and transpose_Y flags, similar to BLAS routines.
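The promotion and transpose rules above can be emulated with numpy as a sketch for intuition (this is not the actual Paddle kernel; in particular, the output-squeezing behavior for transposed rank-1 inputs is an assumption, since the examples above only cover the non-transposed case):

```python
import numpy as np

def matmul_op(x, y, transpose_x=False, transpose_y=False):
    """Sketch of the MatMul operator's shape semantics (not the Paddle code)."""
    x, y = np.asarray(x), np.asarray(y)
    x_vec, y_vec = x.ndim == 1, y.ndim == 1
    # Rank-1 X: [D] -> [1, D] normally, [D, 1] when transposed.
    if x_vec:
        x = x[:, None] if transpose_x else x[None, :]
    elif transpose_x:
        x = np.swapaxes(x, -1, -2)
    # Rank-1 Y: [D] -> [D, 1] normally, [1, D] when transposed.
    if y_vec:
        y = y[None, :] if transpose_y else y[:, None]
    elif transpose_y:
        y = np.swapaxes(y, -1, -2)
    out = np.matmul(x, y)
    # Drop the axes introduced by rank-1 promotion, keeping the result
    # at least rank 1 (so X: [K], Y: [K] yields [1], not a scalar).
    if y_vec and not transpose_y:
        out = out[..., 0]              # remove the appended column axis
    if x_vec and not transpose_x and out.ndim > 1:
        out = out[..., 0, :]           # remove the prepended row axis
    return out
```

This reproduces all five example shapes listed above, including the broadcasting case X: [M, K], Y: [B, K, N] => Out: [B, M, N].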

If there is interest, I could add support for rank 4 and higher tensors in a future PR. Essentially this should just involve adding some code to reshape to rank 3 and then undoing the reshape.
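The reshape approach mentioned here can be sketched as follows (a hypothetical helper, assuming both inputs carry identical leading batch dimensions, i.e. no broadcasting):

```python
import numpy as np

def matmul_high_rank(x, y):
    """Collapse leading batch dims to one, do a rank-3 matmul, then restore.
    Sketch only; assumes x and y share the same leading (batch) dimensions."""
    batch_dims = x.shape[:-2]
    assert y.shape[:-2] == batch_dims, "sketch does not handle broadcasting"
    x3 = x.reshape((-1,) + x.shape[-2:])   # [prod(batch), M, K]
    y3 = y.reshape((-1,) + y.shape[-2:])   # [prod(batch), K, N]
    out = np.matmul(x3, y3)                # plain rank-3 batched matmul
    return out.reshape(batch_dims + out.shape[-2:])
```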

Markus Kliegl and others added 4 commits October 16, 2017 17:59
Similar to np.matmul, but also has transpose_X and transpose_Y flags,
and only supports tensors from rank 1 to 3 inclusive.

For GPU, uses cublas?gemmStridedBatched. For CPU, uses
cblas_?gemm_batch if available via MKL; otherwise a simple serial
implementation that loops over the batch dimension is employed for now.
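The serial CPU fallback described above amounts to one GEMM per batch element. A minimal numpy sketch of that loop (the real implementation calls the cblas_?gemm routines; this is illustration only):

```python
import numpy as np

def batched_gemm_fallback(A, B, trans_a=False, trans_b=False, alpha=1.0):
    """Serial fallback: loop over the batch dimension, one GEMM per slice."""
    out = []
    for a, b in zip(A, B):             # iterate over the batch dimension
        a = a.T if trans_a else a
        b = b.T if trans_b else b
        out.append(alpha * (a @ b))    # each iteration is a plain GEMM
    return np.stack(out)
```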
@wangkuiyi wangkuiyi requested review from lcy-seso and removed request for wangkuiyi October 17, 2017 02:14
@wangkuiyi (Collaborator) left a comment:

Thanks for this PR!

@@ -130,6 +130,87 @@ void matmul<platform::CPUPlace, double>(
matrix_b.data<double>(), beta, matrix_out->data<double>());
}

#ifdef PADDLE_USE_MKLML
Collaborator:

Would it be a good idea to move batched_gemm into gemm.{cc,cu} or batched_gemm.{cc,cu}, so as to keep math_function.cc from exploding?

mkliegl (Contributor, Author):

I was looking into this just now, but I think the batched_gemm code is very closely related to the gemm code: they need the same includes, and if one needs to be modified, very likely the other does too. I think it makes sense to keep them together, but please let me know if you still prefer that I split them.

// Both a & b can be 1- to 3-dimensional. Higher rank tensors are not supported
// yet.
template <typename Place, typename T>
class MatMulFunctor {
Collaborator:

Maybe move this class template into a new header file, matmul.h?

mkliegl (Contributor, Author):

Sounds good - I'll do that.

@zhouxiao-coder (Contributor):

It all looks good to me :)

On the necessity of rank >= 4 matmul: the only time I have found it really useful is when implementing the "multi-head attention" from the paper Attention Is All You Need. It involves extra dimensions like heads or blocks, so rank 3 is not enough.

That being said, I agree we can merge this PR first.
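For concreteness, the multi-head attention mentioned above relies on rank-4 batched matmuls; a small numpy sketch (the shapes and sizes here are made up for illustration):

```python
import numpy as np

B, H, L, d = 2, 4, 8, 16          # batch, heads, sequence length, head dim (hypothetical)
Q = np.random.rand(B, H, L, d)
K = np.random.rand(B, H, L, d)
V = np.random.rand(B, H, L, d)

# Attention scores need a rank-4 batched matmul: [B, H, L, d] x [B, H, d, L].
scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d)   # [B, H, L, L]
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                    # softmax over keys
out = np.matmul(weights, V)                                  # [B, H, L, d]
```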

@mkliegl (Contributor, Author) commented Oct 17, 2017

@zhouxiao-coder Thank you for the pointer! I'll try to work on rank >= 4 soon then.
@wangkuiyi Thank you for the suggestions! I moved MatMulFunctor into a separate matmul.h file, but left gemm and batched_gemm together for now (see comment above).

@mkliegl mkliegl merged commit 1648982 into PaddlePaddle:develop Oct 18, 2017
@mkliegl mkliegl deleted the generalize_matmul_squash_rebase branch October 18, 2017 02:29