MatMul operator #4856
Conversation
Similar to `np.matmul`, but also has `transpose_X` and `transpose_Y` flags, and only supports tensors from rank 1 to 3 inclusive. For GPU, uses `cublas?gemmStridedBatched`. For CPU, uses `cblas_?gemm_batch` if available via MKL; otherwise a simple serial implementation that loops over the batch dimension is employed for now.
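As a sketch of what the serial CPU fallback computes, here is a hypothetical numpy reference (the function name and signature are illustrative, not the PR's actual C++ API): one 2-D GEMM per slice of the batch dimension, with the transpose flags applied to the last two dimensions only.

```python
import numpy as np

def batched_matmul_reference(x, y, transpose_x=False, transpose_y=False):
    """Illustrative reference for the serial fallback path described in
    the PR: loop over the leading batch dimension of rank-3 inputs and
    perform an ordinary 2-D matrix multiply per slice."""
    if transpose_x:
        x = x.transpose(0, 2, 1)  # transpose the last two dims only
    if transpose_y:
        y = y.transpose(0, 2, 1)
    batch = x.shape[0]
    out = np.empty((batch, x.shape[1], y.shape[2]), dtype=x.dtype)
    for b in range(batch):  # the simple serial loop over the batch dim
        out[b] = x[b] @ y[b]
    return out
```

For rank-3 inputs without transposition this agrees with `np.matmul(x, y)`.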
Thanks for this PR!
```cpp
@@ -130,6 +130,87 @@ void matmul<platform::CPUPlace, double>(
      matrix_b.data<double>(), beta, matrix_out->data<double>());
}

#ifdef PADDLE_USE_MKLML
```
Is it a good idea to move `batched_gemm` into `gemm.{cc,cu}` or `batched_gemm.{cc,cu}`, so as to prevent math_function.cc from exploding?
I was looking into this just now, but I think the `batched_gemm` code is very closely related to the `gemm` code. They need the same includes, etc., and if one needs to be modified, very likely the other does too. I think it makes sense to keep them together, but please let me know if you still prefer that I split them.
```cpp
// Both a & b can be 1- to 3-dimensional. Higher-rank tensors are not
// supported yet.
template <typename Place, typename T>
class MatMulFunctor {
```
Maybe move this class template into a new header file `matmul.h`?
Sounds good - I'll do that.
It looks all good to me :)

On the necessity of rank >= 4: that being said, I agree we can merge this PR first.

@zhouxiao-coder Thank you for the pointer! I'll try to work on rank >= 4 soon then.
This closes issues #4683 and #4696.

The MatMul operator is used to perform (batched) matrix multiplication over the last two dimensions of the input tensors `X` and `Y`.

If a transpose flag is specified, the last two dimensions of the tensor are transposed. If the tensor is rank-1 of shape [D], then for `X` it is treated as [1, D] in nontransposed form and as [D, 1] in transposed form, whereas for `Y` it is the opposite: it is treated as [D, 1] in nontransposed form and as [1, D] in transposed form.
Examples without transpose:
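The original example listing did not survive extraction; the following numpy snippets are illustrative equivalents of the shape conventions described above (assuming the `numpy.matmul` semantics the PR compares itself to), not the PR's own examples.

```python
import numpy as np

# Rank-2 x rank-2: ordinary matrix multiplication.
A = np.ones((2, 3))
B = np.ones((3, 4))
assert np.matmul(A, B).shape == (2, 4)

# Rank-3 x rank-3: multiplication batched over the leading dimension.
X = np.ones((5, 2, 3))
Y = np.ones((5, 3, 4))
assert np.matmul(X, Y).shape == (5, 2, 4)

# A rank-1 X of shape [D] acts like [1, D] and a rank-1 Y of shape [D]
# acts like [D, 1], so a vector-vector product yields a scalar.
v = np.ones(3)
w = np.ones(3)
assert float(np.matmul(v, w)) == 3.0
```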
The behavior is designed to be similar to the `numpy.matmul` function. The differences are the `transpose_X` and `transpose_Y` flags, similar to BLAS routines.

If there is interest, I could add support for rank 4 and higher tensors in a future PR. Essentially this should just involve adding some code to reshape to rank 3 and then undoing the reshape.
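The reshape idea mentioned above can be sketched in numpy as follows (a hypothetical helper, assuming the batch dimensions of both inputs already match, i.e. no broadcasting): fold all leading batch dimensions into one, run the existing rank-3 path, then unfold.

```python
import numpy as np

def matmul_high_rank(x, y):
    """Illustrative sketch of the rank >= 4 reshape trick: collapse all
    leading batch dims into a single batch dim, multiply as rank-3
    tensors, then restore the original batch dims on the result."""
    batch_dims = x.shape[:-2]          # leading dims, assumed equal for x and y
    x3 = x.reshape(-1, x.shape[-2], x.shape[-1])
    y3 = y.reshape(-1, y.shape[-2], y.shape[-1])
    out3 = np.matmul(x3, y3)           # stands in for the rank-3 MatMul kernel
    return out3.reshape(*batch_dims, x.shape[-2], y.shape[-1])
```

For matching batch dimensions this agrees with `np.matmul` applied directly to the high-rank inputs.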