From 79c2ee0a8fa723f2363c438a6338f2bf289f472e Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Sat, 15 Aug 2020 21:43:35 +0800 Subject: [PATCH 01/18] add matmul_v2 --- paddle/fluid/operators/math/blas.h | 5 + paddle/fluid/operators/math/blas_impl.cu.h | 11 + paddle/fluid/operators/math/blas_impl.h | 21 ++ paddle/fluid/operators/matmul_v2_op.cc | 124 ++++++ paddle/fluid/operators/matmul_v2_op.cu | 22 ++ paddle/fluid/operators/matmul_v2_op.h | 304 +++++++++++++++ .../tests/unittests/test_matmul_v2_op.py | 355 ++++++++++++++++++ 7 files changed, 842 insertions(+) create mode 100644 paddle/fluid/operators/matmul_v2_op.cc create mode 100644 paddle/fluid/operators/matmul_v2_op.cu create mode 100644 paddle/fluid/operators/matmul_v2_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_matmul_v2_op.py diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index f8c971954fc4c..42a60e9220cf8 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -198,6 +198,11 @@ class Blas { int K, T alpha, const T* A, const T* B, T beta, T* C, int batchCount, int64_t strideA, int64_t strideB) const; + template + void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, + int K, T alpha, const T** A, const T** B, T beta, T** C, + int batchCount) const; + #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) template void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 64b35cfeaecd1..d0c5f74d4efb8 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -458,6 +458,17 @@ void Blas::BatchedGEMM( #endif // CUDA_VERSION >= 9010 } +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, + C[k]); + } +} + template <> template void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index cdaf53fea3008..892bf15738141 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include #include #include #include @@ -655,6 +656,26 @@ void Blas::BatchedGEMM( #endif } +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const { +#ifdef PADDLE_WITH_MKLML + const int lda = std::max((transA == CblasNoTrans) ? K : M, 1); + const int ldb = std::max((transB == CblasNoTrans) ? 
N : K, 1); + const int ldc = std::max(N, 1); + CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A, + &lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */, + &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, + C[k]); + } +#endif +} + #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) template <> template diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc new file mode 100644 index 0000000000000..f9b72b5db6e9f --- /dev/null +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/matmul_v2_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class MatMulV2Op : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matmul_v2"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "matmul_v2"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matmul_v2"); + bool trans_x = ctx->Attrs().Get("trans_x"); + bool trans_y = ctx->Attrs().Get("trans_y"); + + std::vector dims_x = + paddle::framework::vectorize(ctx->GetInputDim("X")); + std::vector dims_y = + paddle::framework::vectorize(ctx->GetInputDim("Y")); + auto ndims_x = dims_x.size(); + auto ndims_y = dims_y.size(); + + bool x_broadcasted = false, y_broadcasted = false; + if (ndims_x == 1) { + dims_x.insert(dims_x.begin(), 1); + ndims_x = 2; + x_broadcasted = true; + } + + if (ndims_y == 1) { + dims_y.push_back(1); + ndims_y = 2; + y_broadcasted = true; + } + + size_t M, N; + if (trans_x) { + M = dims_x[ndims_x - 1]; + } else { + M = dims_x[ndims_x - 2]; + } + if (trans_y) { + N = dims_y[ndims_y - 2]; + } else { + N = dims_y[ndims_y - 1]; + } + + std::vector new_dims; + if (ndims_x >= ndims_y) { + new_dims.assign(dims_x.begin(), dims_x.end() - 2); + } else { + new_dims.assign(dims_y.begin(), dims_y.end() - 2); + } + if (!x_broadcasted) { + new_dims.push_back(M); + } + if (!y_broadcasted) { + new_dims.push_back(N); + } + if (x_broadcasted && y_broadcasted) { + new_dims.push_back(1); + } + + auto out_dims = framework::make_ddim(new_dims); + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /* --> */ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "tensor of shape (dim0, dim1 ... M, K)"); + AddInput("Y", "tensor of shape (dim0, dim1 ... K, N)"); + AddOutput("Out", "tensor of shape (dim0, dim1 ... 
M, N)"); + AddAttr("trans_x", + "Set true to transpose the last two dimensions of X before " + "doing multiplication") + .SetDefault(false); + AddAttr("trans_y", + "Set true to transpose the last two dimensions of Y before " + "doing multiplication") + .SetDefault(false); + AddComment(R"DOC( + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + matmul_v2, ops::MatMulV2Kernel, + ops::MatMulV2Kernel); diff --git a/paddle/fluid/operators/matmul_v2_op.cu b/paddle/fluid/operators/matmul_v2_op.cu new file mode 100644 index 0000000000000..bce4f0de85504 --- /dev/null +++ b/paddle/fluid/operators/matmul_v2_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/matmul_v2_op.h" + +namespace ops = paddle::operators; +namespace plf = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(matmul_v2, + ops::MatMulV2Kernel, + ops::MatMulV2Kernel); diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h new file mode 100644 index 0000000000000..6b98b4d46f7dd --- /dev/null +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -0,0 +1,304 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +void ComputeBroadcastBinaryOpDims(const int A_ndim, const std::int64_t* A_dims, + const int B_ndim, const std::int64_t* B_dims, + std::int64_t* A_broadcast_dims, + std::int64_t* B_broadcast_dims, + std::int64_t* C_broadcast_dims) { + const int ndim = std::max(A_ndim, B_ndim); + std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1); + std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1); + std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim); + std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim); + for (int i = 0; i < ndim; ++i) { + PADDLE_ENFORCE_EQ(A_broadcast_dims[i] == B_broadcast_dims[i] || + A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1, + true, platform::errors::InvalidArgument( + "Input(X) and Input(Y) has error dim.")); + if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) { + C_broadcast_dims[i] = 0; + } else { + C_broadcast_dims[i] = std::max(A_broadcast_dims[i], B_broadcast_dims[i]); + } + } +} + +int64_t GetIndexFromDims(const int n, const int64_t* dims, + const int64_t* index) { + int64_t sum = 0; + for (int i = 0; i < n; ++i) { + if (dims[i] > 1) { + sum = sum * dims[i] + index[i]; + } + } + return sum; +} + +void IncreaseIndexInDims(const int ndim, const int64_t* dims, int64_t* index) { + for (int i = ndim - 1; i >= 0; --i) { + ++index[i]; + if (index[i] >= dims[i]) { + index[i] -= dims[i]; + } else { + break; + } + } +} + +template +void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, + bool trans_y, + const paddle::framework::ExecutionContext& ctx) { + // get dims + const std::vector x_dims = vectorize(X->dims()); + const std::vector y_dims = vectorize(Y->dims()); + const int x_ndim = x_dims.size(); + const int y_ndim = y_dims.size(); + + // get data ptr + const T* x_data = X->data(); + const T* y_data = Y->data(); + + if (x_ndim == 1 && y_ndim == 1) { + VLOG(0) << "MatMul's case 1"; + Out->Resize({1}); + Out->mutable_data(ctx.GetPlace()); + auto out_eigen = framework::EigenScalar::From(*Out); + auto x_eigen = framework::EigenVector::Flatten(*X); + auto y_eigen = framework::EigenVector::Flatten(*Y); + + auto& dev = *ctx.template device_context().eigen_device(); + out_eigen.device(dev) = (x_eigen * y_eigen).sum(); + return; + } + + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + + if (x_ndim == 1) { + const int N = X->numel(); + if (trans_y) { + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 1], N, + platform::errors::InvalidArgument("Input(Y) has error dim.")); + } else { + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 2], N, + platform::errors::InvalidArgument("Input(Y) has error dim.")); + } + std::vector out_dims(y_ndim - 1); + if (trans_y) { + std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); + } else { + std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); + out_dims.back() = y_dims.back(); + } + Out->Resize(framework::make_ddim(out_dims)); + Out->mutable_data(ctx.GetPlace()); + if (trans_y) { + VLOG(0) << "haha"; + } + if (trans_y) { + const int M = Y->numel() / N; + VLOG(0) << "MatMul's case 2"; + blas.GEMV(false, M, N, 1., y_data, x_data, 0., Out->data()); + } else { + const int M = y_dims[y_ndim - 1]; + const int batch_size = Y->numel() / (M * N); + if 
(batch_size == 1) { + VLOG(0) << "MatMul's case 3"; + blas.GEMV(true, N, M, 1., y_data, x_data, 0., Out->data()); + } else { + VLOG(0) << "MatMul's case 4"; + blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, y_data, + x_data, 0, Out->data(), batch_size, M * N, 0); + } + } + return; + } + + if (y_ndim == 1) { + const int N = Y->numel(); + if (trans_x) { + PADDLE_ENFORCE_EQ( + x_dims[x_ndim - 2], N, + platform::errors::InvalidArgument("Input(X) has error dim.")); + } else { + PADDLE_ENFORCE_EQ( + x_dims[x_ndim - 1], N, + platform::errors::InvalidArgument("Input(X) has error dim.")); + } + std::vector out_dims(x_ndim - 1); + if (trans_x) { + std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); + out_dims.back() = x_dims.back(); + } else { + std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); + } + Out->Resize(framework::make_ddim(out_dims)); + Out->mutable_data(ctx.GetPlace()); + + if (trans_x) { + const int M = x_dims[x_ndim - 1]; + const int batch_size = X->numel() / (M * N); + if (batch_size == 1) { + VLOG(0) << "MatMul's case 5"; + blas.GEMV(true, N, M, 1.0f, x_data, y_data, 0.0f, Out->data()); + } else { + VLOG(0) << "MatMul's case 6"; + blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, x_data, + y_data, 0, Out->data(), batch_size, M * N, 0); + } + } else { + const int M = X->numel() / N; + VLOG(0) << "MatMul's case 7"; + blas.GEMV(false, M, N, 1.0f, x_data, y_data, 0.0f, Out->data()); + } + return; + } + + const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; + const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; + if (trans_y) { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K, platform::errors::InvalidArgument( + "Input(X) has error dim.")); + } else { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K, platform::errors::InvalidArgument( + "Input(X) has error dim.")); + } + const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; + const int ndim = std::max(x_ndim, y_ndim); + std::vector x_broadcast_dims(ndim); + std::vector y_broadcast_dims(ndim); + std::vector out_broadcast_dims(ndim); + + ComputeBroadcastBinaryOpDims(x_ndim - 2, x_dims.data(), y_ndim - 2, + y_dims.data(), x_broadcast_dims.data(), + y_broadcast_dims.data(), + out_broadcast_dims.data()); + + out_broadcast_dims[ndim - 2] = M; + out_broadcast_dims[ndim - 1] = N; + + Out->Resize(framework::make_ddim(out_broadcast_dims)); + Out->mutable_data(ctx.GetPlace()); + + const int batch_dim = ndim - 2; + // broadcast message + const bool is_broadcast_dims = !std::equal( + x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, + y_broadcast_dims.cbegin()); + + const std::int64_t x_batch_size = std::accumulate( + x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, 1LL, + std::multiplies()); + const std::int64_t y_batch_size = std::accumulate( + y_broadcast_dims.cbegin(), y_broadcast_dims.cbegin() + batch_dim, 1LL, + std::multiplies()); + const std::int64_t out_batch_size = std::accumulate( + out_broadcast_dims.cbegin(), out_broadcast_dims.cbegin() + batch_dim, 1LL, + std::multiplies()); + if (out_batch_size == 0) { + return; + } + + if (x_batch_size == 1 && y_batch_size == 1) { + VLOG(0) << "MatMul's case 8"; + blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? 
CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data, + y_data, 0.0f, Out->data()); + } else if (x_batch_size == 1) { + if (M == 1 && trans_y) { + VLOG(0) << "MatMul's case 9"; + blas.GEMV(false, y_batch_size * N, K, 1.0f, y_data, x_data, 0.0f, + Out->data()); + } else { + VLOG(0) << "MatMul's case 10"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, + x_data, y_data, 0, Out->data(), out_batch_size, 0, + K * N); + } + } else if (y_batch_size == 1) { + if (!trans_x) { + VLOG(0) << "MatMul's case 11"; + blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, + x_batch_size * M, N, K, 1.0f, x_data, y_data, 0.0f, + Out->data()); + } else { + VLOG(0) << "MatMul's case 12"; + blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, + 1.0f, x_data, y_data, 0, Out->data(), out_batch_size, + M * K, 0); + } + } else if (!is_broadcast_dims) { + VLOG(0) << "MatMul's case 13"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data, + y_data, 0, Out->data(), out_batch_size, M * K, K * N); + } else { + std::vector x_ptr(out_batch_size); + std::vector y_ptr(out_batch_size); + std::vector out_ptr(out_batch_size); + std::vector index(batch_dim); + for (std::int64_t i = 0; i < out_batch_size; ++i) { + const std::int64_t x_index = + GetIndexFromDims(batch_dim, x_broadcast_dims.data(), index.data()); + const std::int64_t y_index = + GetIndexFromDims(batch_dim, y_broadcast_dims.data(), index.data()); + + x_ptr[i] = x_data + x_index * M * K; + y_ptr[i] = y_data + y_index * K * N; + out_ptr[i] = Out->data() + i * M * N; + IncreaseIndexInDims(batch_dim, out_broadcast_dims.data(), index.data()); + } + VLOG(0) << "MatMul's case 14"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, + x_ptr.data(), y_ptr.data(), 0.0f, out_ptr.data(), + out_batch_size); + } +} + +template +class MatMulV2Kernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* Out = ctx.Output("Out"); + bool trans_x = ctx.Attr("trans_x"); + bool trans_y = ctx.Attr("trans_y"); + + MatMulFunction(X, Y, Out, trans_x, trans_y, ctx); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py new file mode 100644 index 0000000000000..74093d2622598 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -0,0 +1,355 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + +import paddle +import paddle.fluid as fluid +import paddle.fluid.framework as framework + + +def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): + """Reference forward implementation using np.matmul.""" + # np.matmul does not support the transpose flags, so we manually + # transpose X and Y appropriately. + if transpose_X: + if X.ndim == 1: + X = X.reshape((X.size, )) + elif X.ndim == 2: + X = X.T + else: + dim = [i for i in range(len(X.shape))] + dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1] + X = np.transpose(X, tuple(dim)) + if transpose_Y: + if Y.ndim == 1: + Y = Y.reshape((Y.size, )) + else: + dim = [i for i in range(len(Y.shape))] + dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1] + Y = np.transpose(Y, tuple(dim)) + + Out = np.matmul(X, Y) + if not Out.shape: + # We do not support 0-dimensional Tensors (scalars). So where + # np.matmul outputs a scalar, we must convert to a Tensor of + # shape (1, ) instead. + # Everywhere else, we are compatible with np.matmul. + Out = np.array([Out], dtype="float64") + return Out + + +class TestMatMulV2Op(OpTest): + """ + case 1 + """ + + def config(self): + self.x_shape = (100, ) + self.y_shape = (100, ) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + def setUp(self): + self.config() + self.op_type = "matmul_v2" + x = np.random.random(self.x_shape).astype(self.dtype) + y = np.random.random(self.y_shape).astype(self.dtype) + result = reference_matmul(x, y, self.trans_x, self.trans_y) + + self.inputs = { + 'X': x, + 'Y': y, + } + self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y} + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + +class TestMatMuklOp2(TestMatMulV2Op): + """ + case 2 + """ + + def config(self): + self.x_shape = (100, ) + self.y_shape = (1, 3, 2, 100) + self.trans_x = False + self.trans_y = True + self.dtype = "float64" + + +class TestMatMuklOp3(TestMatMulV2Op): + """ + case 3 + """ + + def config(self): + self.x_shape = (100, ) + self.y_shape = (1, 1, 100, 2) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp4(TestMatMulV2Op): + """ + case 4 + """ + + def config(self): + self.x_shape = (100, ) + self.y_shape = (1, 2, 100, 2) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp5(TestMatMulV2Op): + """ + case 5 + """ + + def config(self): + self.x_shape = (1, 1, 100, 2) + self.y_shape = (100, ) + self.trans_x = True + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp6(TestMatMulV2Op): + """ + case 6 + """ + + def config(self): + self.x_shape = (1, 2, 100, 2) + self.y_shape = (100, ) + self.trans_x = True + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp7(TestMatMulV2Op): + """ + case 7 + """ + + def config(self): + self.x_shape = (1, 2, 2, 100) + self.y_shape = (100, ) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp8(TestMatMulV2Op): + """ + case 8 + """ + + def config(self): + self.x_shape = (1, 1, 2, 100) + self.y_shape = (1, 1, 100, 2) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp9(TestMatMulV2Op): + """ + case 9 + """ + + def config(self): + self.x_shape = (1, 1, 1, 100) + self.y_shape = (2, 1, 2, 100) + self.trans_x = False + self.trans_y = True + self.dtype = "float64" 
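+
+# The broadcast cases in this file follow numpy-style broadcasting of the
+# batch dimensions. As a rough cross-check (assuming np.matmul's broadcasting
+# matches the kernel, and reusing this module's `import numpy as np`), case 9
+# above can be reproduced directly:
+#
+#   x = np.random.random((1, 1, 1, 100))
+#   y = np.random.random((2, 1, 2, 100))
+#   out = np.matmul(x, np.swapaxes(y, -1, -2))   # trans_y=True
+#   assert out.shape == (2, 1, 1, 2)  # batch dims (1, 1) and (2, 1) -> (2, 1)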
+ + +class TestMatMuklOp10(TestMatMulV2Op): + """ + case 10 + """ + + def config(self): + self.x_shape = (1, 1, 2, 100) + self.y_shape = (1, 2, 100, 2) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp11(TestMatMulV2Op): + """ + case 11 + """ + + def config(self): + self.x_shape = (2, 1, 2, 100) + self.y_shape = (1, 1, 100, 2) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp12(TestMatMulV2Op): + """ + case 12 + """ + + def config(self): + self.x_shape = (2, 1, 100, 2) + self.y_shape = (1, 1, 100, 2) + self.trans_x = True + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp13(TestMatMulV2Op): + """ + case 13 + """ + + def config(self): + self.x_shape = (2, 2, 100, 2) + self.y_shape = (2, 2, 100, 2) + self.trans_x = True + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp14(TestMatMulV2Op): + """ + case 14_1 + """ + + def config(self): + self.x_shape = (3, 1, 1, 100, 2) + self.y_shape = (1, 2, 2, 100, 2) + self.trans_x = True + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp15(TestMatMulV2Op): + """ + case 14_2 + """ + + def config(self): + self.x_shape = (3, 1, 1, 2, 100) + self.y_shape = (1, 2, 2, 100, 1) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +# class TestMatMuklOp2(TestMatMulV2Op): +# """ +# """ +# def config(self): +# self.x_shape = (10,) +# self.y_shape = (1, 10, 5) +# self.trans_x = False +# self.trans_y = False +# self.dtype = "float64" + +# class TestMatMuklOp3(TestMatMulV2Op): +# """ +# """ +# def config(self): +# self.x_shape = (10,) +# self.y_shape = (10, 10, 5) +# self.trans_x = False +# self.trans_y = False +# self.dtype = "float64" + +# class Generator(object): +# def setUp(self): +# self.op_type = "matmul_v2" +# X = np.random.random(self.shape_X).astype("float64") +# Y = np.random.random(self.shape_Y).astype("float64") +# Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y) +# #print(X.shape,Y.shape,Out.shape,self.transpose_X,self.transpose_X) +# #print(Out) +# self.inputs = {'X': X, 'Y': Y} +# self.attrs = { +# 'trans_x': self.transpose_X, +# 'trans_y': self.transpose_Y +# } +# self.outputs = {'Out': Out} + +# def test_check_output(self): +# self.check_output() + +# def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y, batchsize): +# global shape_x, shape_y +# if dim_X == 1 and dim_Y == 1: +# return [100], [100] + +# if dim_X == 1: +# shape_x = [100] +# if transpose_Y: +# shape_y = [2, 100] +# else: +# if batchsize == -1: +# shape_y = [100, 2] +# else: +# shape_y = [batchsize, 100, 2] +# return shape_x, shape_y + +# if dim_Y == 1: +# shape_y = [100] +# if transpose_X: +# shape_x = [100, 2] +# else: +# if batchsize == -1: +# shape_x = [2, 100] +# else: +# shape_x = [batchsize, 2, 100] +# return shape_x, shape_y + +# # Generate operators cases for all possibilities +# def inject_test(dim_x, dim_y, trans_x, trans_y, batchsize): +# test_name = ('TestMatMulV2Op_dimX_{}_dim_Y_{}_transX_{}_transY_{}_Batchsize{}'.format( +# dim_x, dim_y, trans_x, trans_y, batchsize)) +# shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x, +# trans_y, batchsize) +# print(shape_x, shape_y, trans_x, trans_y) +# globals()[test_name] = type(test_name, (Generator, OpTest), { +# 'shape_X': shape_x, +# 'shape_Y': shape_y, +# 'transpose_X': trans_x, +# 'transpose_Y': trans_y, +# }) + +# for dim_X in [1]: +# for dim_Y in [1, -1]: +# for batchsize in [-1, 1, 2]: +# for 
transose_x in [False, True]: +# for transose_y in [False, True]: +# inject_test(dim_X, dim_Y, transose_x, transose_y, batchsize) +if __name__ == "__main__": + unittest.main() From 42287279227f3fb625c33f7119b30ea0fc24ad92 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Sat, 15 Aug 2020 21:45:47 +0800 Subject: [PATCH 02/18] fix op --- paddle/fluid/operators/matmul_v2_op.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 6b98b4d46f7dd..8bc59c08cc993 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -121,9 +121,6 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, } Out->Resize(framework::make_ddim(out_dims)); Out->mutable_data(ctx.GetPlace()); - if (trans_y) { - VLOG(0) << "haha"; - } if (trans_y) { const int M = Y->numel() / N; VLOG(0) << "MatMul's case 2"; From bf694dea5528042d0fac415116128f76fffeae19 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Sat, 15 Aug 2020 23:50:52 +0800 Subject: [PATCH 03/18] fix op --- paddle/fluid/operators/dot_op.h | 158 +++++++++++++------------ paddle/fluid/operators/matmul_v2_op.cc | 57 ++++++++- paddle/fluid/operators/matmul_v2_op.cu | 4 + 3 files changed, 139 insertions(+), 80 deletions(-) diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index 2580b00d7c2bd..1b155ca8eacfe 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -26,6 +26,86 @@ template using EigenMatrix = framework::EigenMatrix; +template +void DotFunction(const Tensor* tensor_x, const Tensor* tensor_y, + const Tensor* tensor_dout, Tensor* tensor_dx, + Tensor* tensor_dy, + const paddle::framework::ExecutionContext& ctx) { +#ifdef __NVCC__ + if (1 == tensor_dout->dims().size()) { + auto dout = framework::EigenVector::Flatten(*tensor_dout); + + if (tensor_dx) { + auto y = framework::EigenVector::Flatten(*tensor_y); + auto dx = framework::EigenVector::Flatten(*tensor_dx); + auto& dev = *ctx.template device_context().eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = framework::EigenVector::Flatten(*tensor_x); + auto dy = framework::EigenVector::Flatten(*tensor_dy); + auto& dev = *ctx.template device_context().eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + dy.device(dev) = x * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(ctx.GetPlace()); + auto y = EigenMatrix::From(*tensor_y); + auto dx = EigenMatrix::From(*tensor_dx); + auto& dev = *ctx.template device_context().eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(ctx.GetPlace()); + auto x = EigenMatrix::From(*tensor_x); + auto dy = EigenMatrix::From(*tensor_dy); + auto& dev = *ctx.template device_context().eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + dy.device(dev) = x * dout.broadcast(size); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); + const auto* data_y = tensor_y->data(); + const framework::DDim& dim = tensor_x->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % 
step) ++s; + data_dx[i] = data_y[i] * data_dout[s]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); + const auto* data_x = tensor_x->data(); + const framework::DDim& dim = tensor_y->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = data_x[i] * data_dout[s]; + } + } +#endif +} + template class DotKernel : public framework::OpKernel { public: @@ -84,83 +164,9 @@ class DotGradKernel : public framework::OpKernel { if (tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); -#ifdef __NVCC__ - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); - - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - dx.device(dev) = y * dout.broadcast(size); - } - - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - dy.device(dev) = x * dout.broadcast(size); - } - } else { - auto dout = EigenMatrix::From(*tensor_dout); - - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = EigenMatrix::From(*tensor_y); - auto dx = EigenMatrix::From(*tensor_dx); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - dx.device(dev) = y * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = EigenMatrix::From(*tensor_x); - auto dy = EigenMatrix::From(*tensor_dy); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - dy.device(dev) = x * dout.broadcast(size); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_y = tensor_y->data(); - const framework::DDim& dim = tensor_x->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = data_y[i] * data_dout[s]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_x = tensor_x->data(); - const framework::DDim& dim = tensor_y->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = data_x[i] * data_dout[s]; - } - } -#endif + DotFunction(tensor_x, tensor_y, tensor_dout, tensor_dx, + tensor_dy, ctx); } }; diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index f9b72b5db6e9f..aeef6108b5ab6 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -111,14 +111,63 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { } }; +class MatMulV2OpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* context) const override { + 
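+    // Shape inference for the grad op is straightforward: dX mirrors X and
+    // dY mirrors Y, so the input dims are simply forwarded below.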
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul_v2"); + OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul_v2"); + OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "matmul_v2"); + auto x_dims = context->GetInputDim("X"); + auto y_dims = context->GetInputDim("Y"); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (context->HasOutput(x_grad_name)) { + context->SetOutputDim(x_grad_name, x_dims); + } + if (context->HasOutput(y_grad_name)) { + context->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +template +class MatMulV2GradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("matmul_v2_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR( - matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); +REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, + ops::MatMulV2GradOpMaker, + ops::MatMulV2GradOpMaker); + +REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad); + REGISTER_OP_CPU_KERNEL( matmul_v2, ops::MatMulV2Kernel, ops::MatMulV2Kernel); + +REGISTER_OP_CPU_KERNEL( + matmul_v2_grad, + ops::MatMulV2GradKernel, + ops::MatMulV2GradKernel); diff --git a/paddle/fluid/operators/matmul_v2_op.cu b/paddle/fluid/operators/matmul_v2_op.cu index bce4f0de85504..31cfcff222266 100644 --- a/paddle/fluid/operators/matmul_v2_op.cu +++ b/paddle/fluid/operators/matmul_v2_op.cu @@ -20,3 +20,7 @@ namespace plf = paddle::platform; REGISTER_OP_CUDA_KERNEL(matmul_v2, ops::MatMulV2Kernel, ops::MatMulV2Kernel); + +REGISTER_OP_CUDA_KERNEL( + matmul_v2_grad, ops::GatherNdGradOpKernel, + ops::GatherNdGradOpKernel); From b5c32372cc2550c33964d45da480ad8bbddf842a Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Sat, 15 Aug 2020 23:54:47 +0800 Subject: [PATCH 04/18] fix op --- paddle/fluid/operators/matmul_v2_op.h | 64 +++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 8bc59c08cc993..aec52703cfc99 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -20,6 +20,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { @@ -86,6 +87,10 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, const T* y_data = Y->data(); if (x_ndim == 1 && y_ndim == 1) { + PADDLE_ENFORCE_EQ(X->numel(), Y->numel(), + platform::errors::InvalidArgument( + "X's numbers is not equal to Y's numbers," + "when X/Y dims =1")); VLOG(0) << "MatMul's case 1"; Out->Resize({1}); Out->mutable_data(ctx.GetPlace()); @@ -297,5 +302,64 @@ class MatMulV2Kernel : public framework::OpKernel { } }; +template +class MatMulV2GradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); + bool trans_x = ctx.Attr("trans_x"); + bool trans_y = ctx.Attr("trans_y"); + // get dims + const std::vector x_dims = vectorize(X->dims()); + const std::vector y_dims = vectorize(Y->dims()); + const int x_ndim = x_dims.size(); + const int y_ndim = y_dims.size(); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + if (dx) dx->mutable_data(ctx.GetPlace()); + if (dy) dy->mutable_data(ctx.GetPlace()); + + // x's or y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + if (dOut->numel() == 1) { + DotFunction(X, Y, dOut, dx, dy, ctx); + return; + } + } + + if (trans_x) { + if (trans_y) { + // X'Y' + // dA = Y'G', dB = G'X' + if (dx) MatMulFunction(Y, dOut, dx, true, true, ctx); + if (dy) MatMulFunction(dOut, X, dy, true, true, ctx); + } else { + // X'Y: + // dX = YG', dY = XG + if (dx) MatMulFunction(Y, dOut, dx, false, true, ctx); + if (dy) + MatMulFunction(X, dOut, dy, false, false, ctx); + } + } else { + if (trans_y) { + // XY': + // dX = GY, dY = G'X + if (dx) + MatMulFunction(dOut, Y, dx, false, false, ctx); + if (dy) MatMulFunction(dOut, X, dy, true, false, ctx); + } else { + // XY: + // dX = GY', dY = X'G + if (dx) MatMulFunction(dOut, Y, dx, false, true, ctx); + if (dy) MatMulFunction(X, dOut, dy, true, false, ctx); + } + } + // reduce sum to get grad ReduceKernelFunctor + } +}; + } // namespace operators } // namespace paddle From abafccdf645486968f7096a3cb4a081d8463cc18 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Sun, 16 Aug 2020 12:43:00 +0800 Subject: [PATCH 05/18] add grad --- paddle/fluid/operators/dot_op.h | 12 ++-- paddle/fluid/operators/matmul_v2_op.h | 66 ++++++++++++++++--- .../tests/unittests/test_matmul_v2_op.py | 3 + 3 files changed, 65 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index 1b155ca8eacfe..cec706300d77b 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -27,10 +27,10 @@ template ; template -void DotFunction(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { +void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y, + const Tensor* tensor_dout, Tensor* tensor_dx, + Tensor* tensor_dy, + const paddle::framework::ExecutionContext& ctx) { #ifdef __NVCC__ if (1 == tensor_dout->dims().size()) { auto dout = framework::EigenVector::Flatten(*tensor_dout); @@ -165,8 +165,8 @@ class DotGradKernel : public framework::OpKernel { if 
(tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); - DotFunction(tensor_x, tensor_y, tensor_dout, tensor_dx, - tensor_dy, ctx); + DotGradFunction(tensor_x, tensor_y, tensor_dout, + tensor_dx, tensor_dy, ctx); } }; diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index aec52703cfc99..0e3c4b5648c04 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" namespace paddle { namespace operators { @@ -325,39 +326,84 @@ class MatMulV2GradKernel : public framework::OpKernel { // x's or y's dim = 1 if (x_ndim == 1 && y_ndim == 1) { if (dOut->numel() == 1) { - DotFunction(X, Y, dOut, dx, dy, ctx); + DotGradFunction(X, Y, dOut, dx, dy, ctx); return; } } + // FIXME + assert(x_ndim >= 2 && y_ndim >= 2); + // the normal case + Tensor dx_help, dy_help; if (trans_x) { if (trans_y) { // X'Y' // dA = Y'G', dB = G'X' - if (dx) MatMulFunction(Y, dOut, dx, true, true, ctx); - if (dy) MatMulFunction(dOut, X, dy, true, true, ctx); + if (dx) + MatMulFunction(Y, dOut, &dx_help, true, true, ctx); + if (dy) + MatMulFunction(dOut, X, &dy_help, true, true, ctx); } else { // X'Y: // dX = YG', dY = XG - if (dx) MatMulFunction(Y, dOut, dx, false, true, ctx); + if (dx) + MatMulFunction(Y, dOut, &dx_help, false, true, ctx); if (dy) - MatMulFunction(X, dOut, dy, false, false, ctx); + MatMulFunction(X, dOut, &dy_help, false, false, + ctx); } } else { if (trans_y) { // XY': // dX = GY, dY = G'X if (dx) - MatMulFunction(dOut, Y, dx, false, false, ctx); - if (dy) MatMulFunction(dOut, X, dy, true, false, ctx); + MatMulFunction(dOut, Y, &dx_help, false, false, + ctx); + if (dy) + MatMulFunction(dOut, X, &dy_help, true, false, ctx); } else { // XY: // dX = GY', dY = X'G - if (dx) MatMulFunction(dOut, Y, dx, false, true, ctx); - if (dy) MatMulFunction(X, dOut, dy, true, false, ctx); + if (dx) + MatMulFunction(dOut, Y, &dx_help, false, true, ctx); + if (dy) + MatMulFunction(X, dOut, &dy_help, true, false, ctx); } } - // reduce sum to get grad ReduceKernelFunctor + + // cal the broadcast message + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + const int dx_help_ndim = x_help_dims.size(); + const int dy_help_ndim = y_help_dims.size(); + + assert(dx_help_ndim == dy_help_ndim); + // get reduce dim + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (std::int64_t idx = dx_help_ndim - 1 - 2; idx >= 0; idx--) { + if (x_help_dims) } + // ComputeReudceSumBinaryOpDims(); + + const int ndim = std::max(x_ndim, y_ndim); + std::vector x_broadcast_dims(ndim); + std::vector y_broadcast_dims(ndim); + std::vector out_broadcast_dims(ndim); + ComputeBroadcastBinaryOpDims(x_ndim - 2, x_dims.data(), y_ndim - 2, + y_dims.data(), x_broadcast_dims.data(), + y_broadcast_dims.data(), + out_broadcast_dims.data()); + + // reduce sum to get grad by ReduceKernelFunctor + // if (dx) { + // } + // ReduceKernelFunctor( + // dx, output, dims, keep_dim, reduce_all, context) + + // y->mutable_data(ctx.GetPlace()); + // auto y_e = framework::EigenVector::Flatten(*y); + // auto y_arr = y_e.reshape(shape); } }; diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py 
b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index 74093d2622598..409756eba96c7 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -83,6 +83,9 @@ def setUp(self): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X', 'Y'], 'Out') + class TestMatMuklOp2(TestMatMulV2Op): """ From fde400c00c29b06e13ad9048a815376800d70469 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Mon, 17 Aug 2020 19:12:56 +0800 Subject: [PATCH 06/18] fix op --- paddle/fluid/operators/matmul_v2_op.cu | 4 +- paddle/fluid/operators/matmul_v2_op.h | 181 +++++++++++++++++-------- 2 files changed, 129 insertions(+), 56 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.cu b/paddle/fluid/operators/matmul_v2_op.cu index 31cfcff222266..64ec65a234197 100644 --- a/paddle/fluid/operators/matmul_v2_op.cu +++ b/paddle/fluid/operators/matmul_v2_op.cu @@ -22,5 +22,5 @@ REGISTER_OP_CUDA_KERNEL(matmul_v2, ops::MatMulV2Kernel); REGISTER_OP_CUDA_KERNEL( - matmul_v2_grad, ops::GatherNdGradOpKernel, - ops::GatherNdGradOpKernel); + matmul_v2_grad, ops::MatMulV2GradKernel, + ops::MatMulV2GradKernel); diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 0e3c4b5648c04..f4113e040b18f 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -22,17 +22,49 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" +// #include "paddle/fluid/operators/reduce_ops/reduce_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" + +#ifdef __NVCC__ +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#endif namespace paddle { namespace operators { using framework::Tensor; -void ComputeBroadcastBinaryOpDims(const int A_ndim, const std::int64_t* A_dims, - const int B_ndim, const std::int64_t* B_dims, - std::int64_t* A_broadcast_dims, - std::int64_t* B_broadcast_dims, - std::int64_t* C_broadcast_dims) { + +template +struct IdentityFunctor { + HOSTDEVICE explicit inline IdentityFunctor() {} + + HOSTDEVICE inline T operator()(const T& x) const { return x; } +}; + +template +void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output, + const std::vector& reduce_dims, + const paddle::framework::ExecutionContext& ctx) { + if (reduce_dims.empty()) { + framework::TensorCopySync(*input, ctx.GetPlace(), output); + return; + } +#ifdef __NVCC__ + auto stream = ctx.cuda_device_context().stream(); + TensorReduce>( + *input, output, reduce_dims, static_cast(0), cub::Sum(), + IdentityFunctor(), stream); +#else + ReduceKernelFunctor( + input, output, reduce_dims, true, false, ctx) + .template apply(); +#endif +} + +static void ComputeBroadcastBinaryOpDims( + const int A_ndim, const std::int64_t* A_dims, const int B_ndim, + const std::int64_t* B_dims, std::int64_t* A_broadcast_dims, + std::int64_t* B_broadcast_dims, std::int64_t* C_broadcast_dims) { const int ndim = std::max(A_ndim, B_ndim); std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1); std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1); @@ -51,8 +83,8 @@ void ComputeBroadcastBinaryOpDims(const int A_ndim, const std::int64_t* A_dims, } } -int64_t GetIndexFromDims(const int n, const int64_t* dims, - const int64_t* index) { +static int64_t 
GetIndexMessage(const int n, const int64_t* dims, + const int64_t* index) { int64_t sum = 0; for (int i = 0; i < n; ++i) { if (dims[i] > 1) { @@ -62,7 +94,8 @@ int64_t GetIndexFromDims(const int n, const int64_t* dims, return sum; } -void IncreaseIndexInDims(const int ndim, const int64_t* dims, int64_t* index) { +static void IncreaseIndexInDims(const int ndim, const int64_t* dims, + int64_t* index) { for (int i = ndim - 1; i >= 0; --i) { ++index[i]; if (index[i] >= dims[i]) { @@ -92,7 +125,7 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, platform::errors::InvalidArgument( "X's numbers is not equal to Y's numbers," "when X/Y dims =1")); - VLOG(0) << "MatMul's case 1"; + VLOG(3) << "MatMul's case 1"; Out->Resize({1}); Out->mutable_data(ctx.GetPlace()); auto out_eigen = framework::EigenScalar::From(*Out); @@ -129,16 +162,16 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, Out->mutable_data(ctx.GetPlace()); if (trans_y) { const int M = Y->numel() / N; - VLOG(0) << "MatMul's case 2"; + VLOG(3) << "MatMul's case 2"; blas.GEMV(false, M, N, 1., y_data, x_data, 0., Out->data()); } else { const int M = y_dims[y_ndim - 1]; const int batch_size = Y->numel() / (M * N); if (batch_size == 1) { - VLOG(0) << "MatMul's case 3"; + VLOG(3) << "MatMul's case 3"; blas.GEMV(true, N, M, 1., y_data, x_data, 0., Out->data()); } else { - VLOG(0) << "MatMul's case 4"; + VLOG(3) << "MatMul's case 4"; blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, y_data, x_data, 0, Out->data(), batch_size, M * N, 0); } @@ -171,16 +204,16 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, const int M = x_dims[x_ndim - 1]; const int batch_size = X->numel() / (M * N); if (batch_size == 1) { - VLOG(0) << "MatMul's case 5"; + VLOG(3) << "MatMul's case 5"; blas.GEMV(true, N, M, 1.0f, x_data, y_data, 0.0f, Out->data()); } else { - VLOG(0) << "MatMul's case 6"; + VLOG(3) << "MatMul's case 6"; blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, x_data, y_data, 0, Out->data(), batch_size, M * N, 0); } } else { const int M = X->numel() / N; - VLOG(0) << "MatMul's case 7"; + VLOG(3) << "MatMul's case 7"; blas.GEMV(false, M, N, 1.0f, x_data, y_data, 0.0f, Out->data()); } return; @@ -232,17 +265,17 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, } if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(0) << "MatMul's case 8"; + VLOG(3) << "MatMul's case 8"; blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data, y_data, 0.0f, Out->data()); } else if (x_batch_size == 1) { if (M == 1 && trans_y) { - VLOG(0) << "MatMul's case 9"; + VLOG(3) << "MatMul's case 9"; blas.GEMV(false, y_batch_size * N, K, 1.0f, y_data, x_data, 0.0f, Out->data()); } else { - VLOG(0) << "MatMul's case 10"; + VLOG(3) << "MatMul's case 10"; blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data, y_data, 0, Out->data(), out_batch_size, 0, @@ -250,18 +283,18 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, } } else if (y_batch_size == 1) { if (!trans_x) { - VLOG(0) << "MatMul's case 11"; + VLOG(3) << "MatMul's case 11"; blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, x_batch_size * M, N, K, 1.0f, x_data, y_data, 0.0f, Out->data()); } else { - VLOG(0) << "MatMul's case 12"; + VLOG(3) << "MatMul's case 12"; blas.BatchedGEMM(CblasTrans, trans_y ? 
CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data, y_data, 0, Out->data(), out_batch_size, M * K, 0); } } else if (!is_broadcast_dims) { - VLOG(0) << "MatMul's case 13"; + VLOG(3) << "MatMul's case 13"; blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data, y_data, 0, Out->data(), out_batch_size, M * K, K * N); @@ -272,16 +305,16 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, std::vector index(batch_dim); for (std::int64_t i = 0; i < out_batch_size; ++i) { const std::int64_t x_index = - GetIndexFromDims(batch_dim, x_broadcast_dims.data(), index.data()); + GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); const std::int64_t y_index = - GetIndexFromDims(batch_dim, y_broadcast_dims.data(), index.data()); + GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); x_ptr[i] = x_data + x_index * M * K; y_ptr[i] = y_data + y_index * K * N; out_ptr[i] = Out->data() + i * M * N; IncreaseIndexInDims(batch_dim, out_broadcast_dims.data(), index.data()); } - VLOG(0) << "MatMul's case 14"; + VLOG(3) << "MatMul's case 14"; blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_ptr.data(), y_ptr.data(), 0.0f, out_ptr.data(), @@ -298,7 +331,6 @@ class MatMulV2Kernel : public framework::OpKernel { auto* Out = ctx.Output("Out"); bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); - MatMulFunction(X, Y, Out, trans_x, trans_y, ctx); } }; @@ -312,19 +344,25 @@ class MatMulV2GradKernel : public framework::OpKernel { auto* dOut = ctx.Input(framework::GradVarName("Out")); bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); + // get dims const std::vector x_dims = vectorize(X->dims()); const std::vector y_dims = vectorize(Y->dims()); + const std::vector dout_dims = vectorize(dOut->dims()); const int x_ndim = x_dims.size(); const int y_ndim = y_dims.size(); + const int ndim = dout_dims.size(); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - if (dx) dx->mutable_data(ctx.GetPlace()); - if (dy) dy->mutable_data(ctx.GetPlace()); + + // if (dx) dx->mutable_data(ctx.GetPlace()); + // if (dy) dy->mutable_data(ctx.GetPlace()); // x's or y's dim = 1 if (x_ndim == 1 && y_ndim == 1) { + if (dx) dx->mutable_data(ctx.GetPlace()); + if (dy) dy->mutable_data(ctx.GetPlace()); if (dOut->numel() == 1) { DotGradFunction(X, Y, dOut, dx, dy, ctx); return; @@ -335,6 +373,7 @@ class MatMulV2GradKernel : public framework::OpKernel { assert(x_ndim >= 2 && y_ndim >= 2); // the normal case Tensor dx_help, dy_help; + if (trans_x) { if (trans_y) { // X'Y' @@ -371,39 +410,73 @@ class MatMulV2GradKernel : public framework::OpKernel { } } - // cal the broadcast message // get help dims + assert(dx_help_ndim == dy_help_ndim); const std::vector dx_help_dims = vectorize(dx_help.dims()); const std::vector dy_help_dims = vectorize(dy_help.dims()); - const int dx_help_ndim = x_help_dims.size(); - const int dy_help_ndim = y_help_dims.size(); - assert(dx_help_ndim == dy_help_ndim); + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill(dx_broadcast_dims.data(), + dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill(dy_broadcast_dims.data(), + dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), y_dims.data() + y_ndim, + 
dy_broadcast_dims.data() + ndim - y_ndim); + + // dx_dims [1, 1, 3] + // dx_help_dims [4, 2 ,1, 3] + // dx_broadcast_dims [1, 1, 1, 3] + // dx_reduce_dims [0, 1] // get reduce dim - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (std::int64_t idx = dx_help_ndim - 1 - 2; idx >= 0; idx--) { - if (x_help_dims) } - // ComputeReudceSumBinaryOpDims(); - - const int ndim = std::max(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - ComputeBroadcastBinaryOpDims(x_ndim - 2, x_dims.data(), y_ndim - 2, - y_dims.data(), x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - - // reduce sum to get grad by ReduceKernelFunctor - // if (dx) { + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + + // T* vlog_x = dx_help.data(); + // T* vlog_y = dy_help.data(); + // VLOG(0) << "x data:"; + // for (int i =0; i < dx_help.numel(); ++i){ + // VLOG(0) << vlog_x[i]; + // } + // VLOG(0) << "y data:"; + + // for (int i =0; i < dy_help.numel(); ++i){ + // VLOG(0) << vlog_y[i]; + // } + + // VLOG(0) << "dx_reduce_dims"; + // for(auto ele : dx_reduce_dims){ + // VLOG(0) << ele; // } - // ReduceKernelFunctor( - // dx, output, dims, keep_dim, reduce_all, context) - // y->mutable_data(ctx.GetPlace()); - // auto y_e = framework::EigenVector::Flatten(*y); - // auto y_arr = y_e.reshape(shape); + // VLOG(0) << "dy_reduce_dims"; + // for(auto ele : dy_reduce_dims){ + // VLOG(0) << ele; + // } + + // reduce sum to get grad by ReduceSum + if (dx) { + ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, + ctx); + dx->Resize(X->dims()); + } + if (dy) { + ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, + ctx); + dy->Resize(Y->dims()); + } } }; From 795a68cf59445e66a2c55e80f65b8b7f16ff1241 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Tue, 18 Aug 2020 17:40:18 +0800 Subject: [PATCH 07/18] fix api --- paddle/fluid/operators/matmul_v2_op.cc | 7 +- paddle/fluid/operators/matmul_v2_op.h | 126 +++++++++++++++++-------- 2 files changed, 91 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index aeef6108b5ab6..c0673d4f23a31 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -105,8 +105,11 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { "Set true to transpose the last two dimensions of Y before " "doing multiplication") .SetDefault(false); - AddComment(R"DOC( - + AddComment( + R"DOC(Batch Matrix multiplication Yi = Ai * Bi, where A has shape (dim0, dim1, ... M, K), +B has shape (dim0, dim1, ... K, N), Y has shape (dim0, dim1, ... M, N) and i ranges +from 0 to (dim0 * dim1 ...) - 1. rank(A) == rank(B) >= 2. In case of A and B being +two dimensional, it behaves like normal matrix multiplication. 
)DOC"); } }; diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index f4113e040b18f..6ca3815531a16 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -107,12 +107,14 @@ static void IncreaseIndexInDims(const int ndim, const int64_t* dims, } template -void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, - bool trans_y, +void MatMulFunction(const Tensor* X, const Tensor* Y, + const std::vector& x_dims, + const std::vector& y_dims, Tensor* Out, + bool trans_x, bool trans_y, const paddle::framework::ExecutionContext& ctx) { // get dims - const std::vector x_dims = vectorize(X->dims()); - const std::vector y_dims = vectorize(Y->dims()); + // const std::vector x_dims = vectorize(X->dims()); + // const std::vector y_dims = vectorize(Y->dims()); const int x_ndim = x_dims.size(); const int y_ndim = y_dims.size(); @@ -322,6 +324,16 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, } } +template +void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, + bool trans_y, + const paddle::framework::ExecutionContext& ctx) { + const std::vector x_dims = vectorize(X->dims()); + const std::vector y_dims = vectorize(Y->dims()); + MatMulFunction(X, Y, x_dims, y_dims, Out, trans_x, trans_y, + ctx); +} + template class MatMulV2Kernel : public framework::OpKernel { public: @@ -346,12 +358,13 @@ class MatMulV2GradKernel : public framework::OpKernel { bool trans_y = ctx.Attr("trans_y"); // get dims - const std::vector x_dims = vectorize(X->dims()); - const std::vector y_dims = vectorize(Y->dims()); - const std::vector dout_dims = vectorize(dOut->dims()); - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const int ndim = dout_dims.size(); + std::vector x_dims = vectorize(X->dims()); + std::vector y_dims = vectorize(Y->dims()); + std::vector dout_dims = vectorize(dOut->dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); @@ -369,44 +382,77 @@ class MatMulV2GradKernel : public framework::OpKernel { } } - // FIXME + // It is very tricky. 
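The "very tricky" step that the added code below performs is plain shape bookkeeping: a 1-D X gets a leading size-1 axis, a 1-D Y gets a trailing one, and a matching size-1 axis is inserted into dOut's shape, so that the gradient helpers computed next have full rank and can later be reduced back to the original shapes. The Python sketch that follows only mirrors that bookkeeping for illustration; the helper name and standalone form are not part of the patch.

# Hypothetical helper mirroring the insert positions used by the C++ code below.
def promote_1d_for_grad(x_dims, y_dims, dout_dims, trans_x, trans_y):
    x_dims, y_dims, dout_dims = list(x_dims), list(y_dims), list(dout_dims)
    if len(x_dims) == 1:
        x_dims.insert(0, 1)                        # X: [K] -> [1, K]
        if trans_x:
            dout_dims.append(1)
        else:
            dout_dims.insert(len(dout_dims) - 1, 1)
    if len(y_dims) == 1:
        y_dims.append(1)                           # Y: [K] -> [K, 1]
        if trans_y:
            dout_dims.insert(len(dout_dims) - 1, 1)
        else:
            dout_dims.append(1)
    return x_dims, y_dims, dout_dims

# X: [5], Y: [3, 5, 4], dOut: [3, 4]  ->  [1, 5], [3, 5, 4], [3, 1, 4]
assert promote_1d_for_grad([5], [3, 5, 4], [3, 4], False, False) == \
    ([1, 5], [3, 5, 4], [3, 1, 4])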
+ if (x_ndim == 1) { + x_dims.insert(x_dims.begin() + 0, 1); + x_ndim += 1; + if (trans_x) + dout_dims.push_back(1); + else + dout_dims.insert(dout_dims.begin() + ndim - 1, 1); + ndim += 1; + } + + VLOG(0) << "x_dims"; + for (auto ele : x_dims) VLOG(0) << ele; + + if (y_ndim == 1) { + y_dims.push_back(1); + y_ndim += 1; + if (trans_y) + dout_dims.insert(dout_dims.begin() + ndim - 1, 1); + else + dout_dims.push_back(1); + ndim += 1; + } + VLOG(0) << "y_dims"; + for (auto ele : y_dims) VLOG(0) << ele; + + VLOG(0) << "dout_dims"; + for (auto ele : dout_dims) VLOG(0) << ele; + assert(x_ndim >= 2 && y_ndim >= 2); // the normal case Tensor dx_help, dy_help; - if (trans_x) { if (trans_y) { // X'Y' // dA = Y'G', dB = G'X' if (dx) - MatMulFunction(Y, dOut, &dx_help, true, true, ctx); + MatMulFunction(Y, dOut, y_dims, dout_dims, &dx_help, + true, true, ctx); if (dy) - MatMulFunction(dOut, X, &dy_help, true, true, ctx); + MatMulFunction(dOut, X, dout_dims, x_dims, &dy_help, + true, true, ctx); } else { // X'Y: // dX = YG', dY = XG if (dx) - MatMulFunction(Y, dOut, &dx_help, false, true, ctx); + MatMulFunction(Y, dOut, y_dims, dout_dims, &dx_help, + false, true, ctx); if (dy) - MatMulFunction(X, dOut, &dy_help, false, false, - ctx); + MatMulFunction(X, dOut, x_dims, dout_dims, &dy_help, + false, false, ctx); } } else { if (trans_y) { // XY': // dX = GY, dY = G'X if (dx) - MatMulFunction(dOut, Y, &dx_help, false, false, - ctx); + MatMulFunction(dOut, Y, dout_dims, y_dims, &dx_help, + false, false, ctx); if (dy) - MatMulFunction(dOut, X, &dy_help, true, false, ctx); + MatMulFunction(dOut, X, dout_dims, x_dims, &dy_help, + true, false, ctx); } else { // XY: // dX = GY', dY = X'G if (dx) - MatMulFunction(dOut, Y, &dx_help, false, true, ctx); + MatMulFunction(dOut, Y, dout_dims, y_dims, &dx_help, + false, true, ctx); if (dy) - MatMulFunction(X, dOut, &dy_help, true, false, ctx); + MatMulFunction(X, dOut, x_dims, dout_dims, &dy_help, + true, false, ctx); } } @@ -444,35 +490,35 @@ class MatMulV2GradKernel : public framework::OpKernel { } } - // T* vlog_x = dx_help.data(); - // T* vlog_y = dy_help.data(); - // VLOG(0) << "x data:"; - // for (int i =0; i < dx_help.numel(); ++i){ - // VLOG(0) << vlog_x[i]; - // } - // VLOG(0) << "y data:"; + VLOG(0) << "dx_reduce_dims"; + for (auto ele : dx_reduce_dims) { + VLOG(0) << ele; + } - // for (int i =0; i < dy_help.numel(); ++i){ - // VLOG(0) << vlog_y[i]; - // } + VLOG(0) << "dy_reduce_dims"; + for (auto ele : dy_reduce_dims) { + VLOG(0) << ele; + } - // VLOG(0) << "dx_reduce_dims"; - // for(auto ele : dx_reduce_dims){ - // VLOG(0) << ele; - // } + VLOG(0) << "dx_help_dims"; + for (auto ele : dx_help_dims) { + VLOG(0) << ele; + } - // VLOG(0) << "dy_reduce_dims"; - // for(auto ele : dy_reduce_dims){ - // VLOG(0) << ele; - // } + VLOG(0) << "dy_help_dims"; + for (auto ele : dy_help_dims) { + VLOG(0) << ele; + } // reduce sum to get grad by ReduceSum if (dx) { + dx->Resize(dx_help.dims()); ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, ctx); dx->Resize(X->dims()); } if (dy) { + dy->Resize(dy_help.dims()); ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, ctx); dy->Resize(Y->dims()); From e887e14625932bbe444e0a6334f1d059ff2252fb Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Tue, 18 Aug 2020 20:39:01 +0800 Subject: [PATCH 08/18] add api --- paddle/fluid/operators/matmul_v2_op.cc | 14 +- paddle/fluid/operators/matmul_v2_op.h | 121 ++++--------- .../tests/unittests/test_matmul_v2_op.py | 164 ++++++++---------- python/paddle/tensor/linalg.py | 
136 +++++++++++++++ 4 files changed, 258 insertions(+), 177 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index c0673d4f23a31..c79fffc211004 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -94,9 +94,9 @@ class MatMulV2Op : public framework::OperatorWithKernel { class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "tensor of shape (dim0, dim1 ... M, K)"); - AddInput("Y", "tensor of shape (dim0, dim1 ... K, N)"); - AddOutput("Out", "tensor of shape (dim0, dim1 ... M, N)"); + AddInput("X", "tensor of shape (d0, d1 ... M, K)"); + AddInput("Y", "tensor of shape (d0, d1 ... K, N)"); + AddOutput("Out", "tensor of shape (d0, d1 ... M, N)"); AddAttr("trans_x", "Set true to transpose the last two dimensions of X before " "doing multiplication") @@ -106,10 +106,10 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { "doing multiplication") .SetDefault(false); AddComment( - R"DOC(Batch Matrix multiplication Yi = Ai * Bi, where A has shape (dim0, dim1, ... M, K), -B has shape (dim0, dim1, ... K, N), Y has shape (dim0, dim1, ... M, N) and i ranges -from 0 to (dim0 * dim1 ...) - 1. rank(A) == rank(B) >= 2. In case of A and B being -two dimensional, it behaves like normal matrix multiplication. + R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K), + B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)). + In addition, it is also satisfied the broadcast rule which is similar as + numpy.matmul. )DOC"); } }; diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 6ca3815531a16..b13ed4065cc70 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include @@ -22,7 +23,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/math/blas.h" -// #include "paddle/fluid/operators/reduce_ops/reduce_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #ifdef __NVCC__ @@ -46,6 +46,7 @@ void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output, const std::vector& reduce_dims, const paddle::framework::ExecutionContext& ctx) { if (reduce_dims.empty()) { + // FIXME maybe reduce this copy operation framework::TensorCopySync(*input, ctx.GetPlace(), output); return; } @@ -61,24 +62,26 @@ void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output, #endif } -static void ComputeBroadcastBinaryOpDims( - const int A_ndim, const std::int64_t* A_dims, const int B_ndim, - const std::int64_t* B_dims, std::int64_t* A_broadcast_dims, - std::int64_t* B_broadcast_dims, std::int64_t* C_broadcast_dims) { - const int ndim = std::max(A_ndim, B_ndim); - std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1); - std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1); - std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim); - std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim); +static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims, + const int y_ndim, const std::int64_t* y_dims, + std::int64_t* x_bd_dims, + std::int64_t* y_bd_dims, + std::int64_t* out_bd_dims) { + const int ndim = std::max(x_ndim, y_ndim); + std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); + std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); + std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); + std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); + for (int i = 0; i < ndim; ++i) { - PADDLE_ENFORCE_EQ(A_broadcast_dims[i] == B_broadcast_dims[i] || - A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1, - true, platform::errors::InvalidArgument( - "Input(X) and Input(Y) has error dim.")); - if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) { - C_broadcast_dims[i] = 0; + PADDLE_ENFORCE_EQ( + x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, + true, platform::errors::InvalidArgument( + "Input(X) and Input(Y) has error dim.")); + if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { + out_bd_dims[i] = 0; } else { - C_broadcast_dims[i] = std::max(A_broadcast_dims[i], B_broadcast_dims[i]); + out_bd_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]); } } } @@ -94,8 +97,8 @@ static int64_t GetIndexMessage(const int n, const int64_t* dims, return sum; } -static void IncreaseIndexInDims(const int ndim, const int64_t* dims, - int64_t* index) { +static void IndexIncreaseFromDims(const int ndim, const int64_t* dims, + int64_t* index) { for (int i = ndim - 1; i >= 0; --i) { ++index[i]; if (index[i] >= dims[i]) { @@ -112,9 +115,6 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, const std::vector& y_dims, Tensor* Out, bool trans_x, bool trans_y, const paddle::framework::ExecutionContext& ctx) { - // get dims - // const std::vector x_dims = vectorize(X->dims()); - // const std::vector y_dims = vectorize(Y->dims()); const int x_ndim = x_dims.size(); const int y_ndim = y_dims.size(); @@ -126,7 +126,7 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, PADDLE_ENFORCE_EQ(X->numel(), Y->numel(), platform::errors::InvalidArgument( "X's numbers is not equal to Y's numbers," - "when X/Y dims =1")); + "when X/Y's dims =1")); VLOG(3) << "MatMul's case 1"; Out->Resize({1}); Out->mutable_data(ctx.GetPlace()); @@ -236,10 +236,9 
@@ void MatMulFunction(const Tensor* X, const Tensor* Y, std::vector y_broadcast_dims(ndim); std::vector out_broadcast_dims(ndim); - ComputeBroadcastBinaryOpDims(x_ndim - 2, x_dims.data(), y_ndim - 2, - y_dims.data(), x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); + GetBroadcastFromDims(x_ndim - 2, x_dims.data(), y_ndim - 2, y_dims.data(), + x_broadcast_dims.data(), y_broadcast_dims.data(), + out_broadcast_dims.data()); out_broadcast_dims[ndim - 2] = M; out_broadcast_dims[ndim - 1] = N; @@ -262,10 +261,7 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, const std::int64_t out_batch_size = std::accumulate( out_broadcast_dims.cbegin(), out_broadcast_dims.cbegin() + batch_dim, 1LL, std::multiplies()); - if (out_batch_size == 0) { - return; - } - + if (out_batch_size == 0) return; if (x_batch_size == 1 && y_batch_size == 1) { VLOG(3) << "MatMul's case 8"; blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, @@ -301,11 +297,13 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data, y_data, 0, Out->data(), out_batch_size, M * K, K * N); } else { + // in the case, can't use stridedgemm std::vector x_ptr(out_batch_size); std::vector y_ptr(out_batch_size); std::vector out_ptr(out_batch_size); std::vector index(batch_dim); for (std::int64_t i = 0; i < out_batch_size; ++i) { + // using the index to get offset const std::int64_t x_index = GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); const std::int64_t y_index = @@ -314,7 +312,7 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, x_ptr[i] = x_data + x_index * M * K; y_ptr[i] = y_data + y_index * K * N; out_ptr[i] = Out->data() + i * M * N; - IncreaseIndexInDims(batch_dim, out_broadcast_dims.data(), index.data()); + IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); } VLOG(3) << "MatMul's case 14"; blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, @@ -369,9 +367,6 @@ class MatMulV2GradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - // if (dx) dx->mutable_data(ctx.GetPlace()); - // if (dy) dy->mutable_data(ctx.GetPlace()); - // x's or y's dim = 1 if (x_ndim == 1 && y_ndim == 1) { if (dx) dx->mutable_data(ctx.GetPlace()); @@ -381,8 +376,8 @@ class MatMulV2GradKernel : public framework::OpKernel { return; } } - - // It is very tricky. + // It is very tricky. For this broadcast, currently using the reduce sum to + // get gradient. 
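Put concretely, when an operand was broadcast along some batch axes in the forward pass, its gradient is recovered by summing the full-rank helper gradient over exactly those axes and reshaping back. The NumPy sketch below is only an illustration of that idea (the function stands in for the dx_reduce_dims / ReduceSumForMatmulGrad pair used later in this kernel; it is not the kernel's code):

import numpy as np

def reduce_grad_to_input_shape(grad_help, input_shape):
    # Sum over batch axes where the input had size 1 but the helper does not,
    # mirroring how dx_reduce_dims / dy_reduce_dims are collected.
    ndim = grad_help.ndim
    padded = (1,) * (ndim - len(input_shape)) + tuple(input_shape)
    reduce_axes = tuple(i for i in range(ndim - 2)
                        if padded[i] == 1 and grad_help.shape[i] != 1)
    reduced = grad_help.sum(axis=reduce_axes, keepdims=True) if reduce_axes else grad_help
    return reduced.reshape(input_shape)

# X of shape (1, 3, 4) was broadcast against a batch of 5, so dX comes from a
# (5, 3, 4) helper summed over axis 0.
dx = reduce_grad_to_input_shape(np.ones((5, 3, 4)), (1, 3, 4))
assert dx.shape == (1, 3, 4) and float(dx[0, 0, 0]) == 5.0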
if (x_ndim == 1) { x_dims.insert(x_dims.begin() + 0, 1); x_ndim += 1; @@ -393,9 +388,6 @@ class MatMulV2GradKernel : public framework::OpKernel { ndim += 1; } - VLOG(0) << "x_dims"; - for (auto ele : x_dims) VLOG(0) << ele; - if (y_ndim == 1) { y_dims.push_back(1); y_ndim += 1; @@ -405,19 +397,14 @@ class MatMulV2GradKernel : public framework::OpKernel { dout_dims.push_back(1); ndim += 1; } - VLOG(0) << "y_dims"; - for (auto ele : y_dims) VLOG(0) << ele; - VLOG(0) << "dout_dims"; - for (auto ele : dout_dims) VLOG(0) << ele; - - assert(x_ndim >= 2 && y_ndim >= 2); + ASSERT_GE(x_ndim, 2); + ASSERT_GE(y_ndim, 2); // the normal case Tensor dx_help, dy_help; if (trans_x) { if (trans_y) { - // X'Y' - // dA = Y'G', dB = G'X' + // X'Y': dA = Y'G', dB = G'X' if (dx) MatMulFunction(Y, dOut, y_dims, dout_dims, &dx_help, true, true, ctx); @@ -425,8 +412,7 @@ class MatMulV2GradKernel : public framework::OpKernel { MatMulFunction(dOut, X, dout_dims, x_dims, &dy_help, true, true, ctx); } else { - // X'Y: - // dX = YG', dY = XG + // X'Y: dX = YG', dY = XG if (dx) MatMulFunction(Y, dOut, y_dims, dout_dims, &dx_help, false, true, ctx); @@ -436,8 +422,7 @@ class MatMulV2GradKernel : public framework::OpKernel { } } else { if (trans_y) { - // XY': - // dX = GY, dY = G'X + // XY': dX = GY, dY = G'X if (dx) MatMulFunction(dOut, Y, dout_dims, y_dims, &dx_help, false, false, ctx); @@ -445,8 +430,7 @@ class MatMulV2GradKernel : public framework::OpKernel { MatMulFunction(dOut, X, dout_dims, x_dims, &dy_help, true, false, ctx); } else { - // XY: - // dX = GY', dY = X'G + // XY: dX = GY', dY = X'G if (dx) MatMulFunction(dOut, Y, dout_dims, y_dims, &dx_help, false, true, ctx); @@ -455,9 +439,7 @@ class MatMulV2GradKernel : public framework::OpKernel { true, false, ctx); } } - // get help dims - assert(dx_help_ndim == dy_help_ndim); const std::vector dx_help_dims = vectorize(dx_help.dims()); const std::vector dy_help_dims = vectorize(dy_help.dims()); @@ -473,12 +455,6 @@ class MatMulV2GradKernel : public framework::OpKernel { std::copy(y_dims.data(), y_dims.data() + y_ndim, dy_broadcast_dims.data() + ndim - y_ndim); - // dx_dims [1, 1, 3] - // dx_help_dims [4, 2 ,1, 3] - // dx_broadcast_dims [1, 1, 1, 3] - // dx_reduce_dims [0, 1] - // get reduce dim - std::vector dx_reduce_dims; std::vector dy_reduce_dims; for (int idx = 0; idx <= ndim - 3; idx++) { @@ -489,27 +465,6 @@ class MatMulV2GradKernel : public framework::OpKernel { dy_reduce_dims.push_back(idx); } } - - VLOG(0) << "dx_reduce_dims"; - for (auto ele : dx_reduce_dims) { - VLOG(0) << ele; - } - - VLOG(0) << "dy_reduce_dims"; - for (auto ele : dy_reduce_dims) { - VLOG(0) << ele; - } - - VLOG(0) << "dx_help_dims"; - for (auto ele : dx_help_dims) { - VLOG(0) << ele; - } - - VLOG(0) << "dy_help_dims"; - for (auto ele : dy_help_dims) { - VLOG(0) << ele; - } - // reduce sum to get grad by ReduceSum if (dx) { dx->Resize(dx_help.dims()); diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index 409756eba96c7..5ae675d8dcdb0 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -145,7 +145,7 @@ class TestMatMuklOp6(TestMatMulV2Op): """ def config(self): - self.x_shape = (1, 2, 100, 2) + self.x_shape = (1, 2, 100, 1) self.y_shape = (100, ) self.trans_x = True self.trans_y = False @@ -158,7 +158,7 @@ class TestMatMuklOp7(TestMatMulV2Op): """ def config(self): - self.x_shape = (1, 2, 2, 100) + self.x_shape = (1, 2, 1, 
100) self.y_shape = (100, ) self.trans_x = False self.trans_y = False @@ -269,90 +269,80 @@ def config(self): self.dtype = "float64" -# class TestMatMuklOp2(TestMatMulV2Op): -# """ -# """ -# def config(self): -# self.x_shape = (10,) -# self.y_shape = (1, 10, 5) -# self.trans_x = False -# self.trans_y = False -# self.dtype = "float64" - -# class TestMatMuklOp3(TestMatMulV2Op): -# """ -# """ -# def config(self): -# self.x_shape = (10,) -# self.y_shape = (10, 10, 5) -# self.trans_x = False -# self.trans_y = False -# self.dtype = "float64" - -# class Generator(object): -# def setUp(self): -# self.op_type = "matmul_v2" -# X = np.random.random(self.shape_X).astype("float64") -# Y = np.random.random(self.shape_Y).astype("float64") -# Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y) -# #print(X.shape,Y.shape,Out.shape,self.transpose_X,self.transpose_X) -# #print(Out) -# self.inputs = {'X': X, 'Y': Y} -# self.attrs = { -# 'trans_x': self.transpose_X, -# 'trans_y': self.transpose_Y -# } -# self.outputs = {'Out': Out} - -# def test_check_output(self): -# self.check_output() - -# def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y, batchsize): -# global shape_x, shape_y -# if dim_X == 1 and dim_Y == 1: -# return [100], [100] - -# if dim_X == 1: -# shape_x = [100] -# if transpose_Y: -# shape_y = [2, 100] -# else: -# if batchsize == -1: -# shape_y = [100, 2] -# else: -# shape_y = [batchsize, 100, 2] -# return shape_x, shape_y - -# if dim_Y == 1: -# shape_y = [100] -# if transpose_X: -# shape_x = [100, 2] -# else: -# if batchsize == -1: -# shape_x = [2, 100] -# else: -# shape_x = [batchsize, 2, 100] -# return shape_x, shape_y - -# # Generate operators cases for all possibilities -# def inject_test(dim_x, dim_y, trans_x, trans_y, batchsize): -# test_name = ('TestMatMulV2Op_dimX_{}_dim_Y_{}_transX_{}_transY_{}_Batchsize{}'.format( -# dim_x, dim_y, trans_x, trans_y, batchsize)) -# shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x, -# trans_y, batchsize) -# print(shape_x, shape_y, trans_x, trans_y) -# globals()[test_name] = type(test_name, (Generator, OpTest), { -# 'shape_X': shape_x, -# 'shape_Y': shape_y, -# 'transpose_X': trans_x, -# 'transpose_Y': trans_y, -# }) - -# for dim_X in [1]: -# for dim_Y in [1, -1]: -# for batchsize in [-1, 1, 2]: -# for transose_x in [False, True]: -# for transose_y in [False, True]: -# inject_test(dim_X, dim_Y, transose_x, transose_y, batchsize) +class TestMatMuklOp16(TestMatMulV2Op): + """ + case 16 : to check the gradient for special case + """ + + def config(self): + self.x_shape = (100) + self.y_shape = (1, 2, 2, 100, 1) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp17(TestMatMulV2Op): + """ + case 17 : to check the gradient for special case + """ + + def config(self): + self.x_shape = (2, 1, 100) + self.y_shape = (100) + self.trans_x = False + self.trans_y = False + self.dtype = "float64" + + +class TestMatMuklOp18(TestMatMulV2Op): + """ + case 17 : to check the gradient for special case + """ + + def config(self): + self.x_shape = (2, 100, 1) + self.y_shape = (100) + self.trans_x = False + self.trans_y = True + self.dtype = "float64" + + +class TestMatMulV2API(unittest.TestCase): + def setUp(self): + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input_x = fluid.data(name="input_x", shape=[4, 3], 
dtype="float64") + input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float64") + + result = paddle.matmul_v2(input_x, input_y) + + x_np = np.random([4, 3]).astype("float64") + y_np = np.random([3, 4]).astype("float64") + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input_x": x_np, + "input_y": y_np}, + fetch_list=[result]) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_x = np.random([4, 3]).astype("float64") + input_y = np.random([3, 4]).astype("float64") + x = paddle.to_tensor(input_x) + y = paddle.to_tensor(input_y) + result = paddle.matmul_v2(x, y) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 972c9fbce4d2a..adfe2c3455c13 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -21,6 +21,7 @@ __all__ = [ 'matmul', + 'matmul_v2', 'dot', # 'einsum', 'norm', @@ -171,6 +172,141 @@ def __check_input(x, y): return out +def matmul_v2(x, y, transpose_x=False, transpose_y=False, name=None): + """ + Applies matrix multiplication to two tensors. It is different from the + `matmul` in the broadcast rule. `matmul_v2` follows the complete broadcast + rules, and its behavior is consistent with `np.matmul`. + + Currently, the input tensors' rank can be any, `matmul_v2` can be used to + achieve the `dot`, `matmul` and `batchmatmul`. + + The actual behavior depends on the shapes of :math:`x`, :math:`y` and the + flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically: + + - If a transpose flag is specified, the last two dimensions of the tensor + are transposed. If the tensor is rank-1 of shape :math:`[D]`, then for + :math:`x` it is treated as :math:`[1, D]` in nontransposed form and as + :math:`[D, 1]` in transposed form, whereas for :math:`y` it is the + opposite: It is treated as :math:`[D, 1]` in nontransposed form and as + :math:`[1, D]` in transposed form. + + The multiplication behavior depends on the dimensions of `x` and `y`. Specifically: + + - If both tensors are 1-dimensional, the dot product result is obtained. + - If both tensors are 2-dimensional, the matrix-matrix product is obtained. + - If the `x` is 1-dimensional and the `y` is 2-dimensional, + a `1` is prepended to its dimension in order to conduct the matrix multiply. + After the matrix multiply, the prepended dimension is removed. + - If the `x` is 2-dimensional and `y` is 1-dimensional, + the matrix-vector product is obtained. + - If both arguments are at least 1-dimensional and at least one argument + is N-dimensional (where N > 2), then a batched matrix multiply is obtained. + If the first argument is 1-dimensional, a 1 is prepended to its dimension + in order to conduct the batched matrix multiply and removed after. + If the second argument is 1-dimensional, a 1 is appended to its + dimension for the purpose of the batched matrix multiple and removed after. + The non-matrix (exclude the last two dimensions) dimensions are + broadcasted according the broadcast rule. + For example, if input is a :math:`j \times 1 \times n \times m` (j×1×n×m) + tensor and the other is a :math:`1 \times m \times p` (k×m×p) tensor, + out will be an :math:`j \times k \times n \times p` (j×k×n×p) tensor. + + Args: + x (Tensor): The input variable which is a Tensor or LoDTensor. 
+ y (Tensor): The input variable which is a Tensor or LoDTensor. + transpose_x (bool): Whether to transpose :math:`x` before multiplication. + transpose_y (bool): Whether to transpose :math:`y` before multiplication. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Tensor: The output Tensor. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + # vector x vector + x_data = np.random.uniform([10]).astype(np.float32) + y_data = np.random.uniform([10]).astype(np.float32) + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + z = paddle.matmul_v2(x, y) + print(z.numpy().shape) + # [1] + + # matrix x vector + x_data = np.random.uniform([10, 5]).astype(np.float32) + y_data = np.random.uniform([5]).astype(np.float32) + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + z = paddle.matmul_v2(x, y) + print(z.numpy().shape) + # [10] + + # batched matrix x broadcasted vector + x_data = np.random.uniform([10, 5, 2]).astype(np.float32) + y_data = np.random.uniform([2]).astype(np.float32) + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + z = paddle.matmul_v2(x, y) + print(z.numpy().shape) + # [10, 5] + + # batched matrix x batched matrix + x_data = np.random.uniform([10, 5, 2]).astype(np.float32) + y_data = np.random.uniform([10, 2, 5]).astype(np.float32) + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + z = paddle.matmul_v2(x, y) + print(z.numpy().shape) + # [10, 5, 5] + + # batched matrix x broadcasted matrix + x_data = np.random.uniform([10, 1, 5, 2]).astype(np.float32) + y_data = np.random.uniform([1, 3, 2, 5]).astype(np.float32) + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + z = paddle.matmul_v2(x, y) + print(z.numpy().shape) + # [10, 3, 5, 5] + + """ + if in_dygraph_mode(): + out = _varbase_creator(dtype=x.dtype) + core.ops.matmul(x, y, out, 'trans_x', transpose_x, 'trans_y', + transpose_y) + return out + + attrs = { + 'trans_x': transpose_x, + 'trans_y': transpose_y, + } + + def __check_input(x, y): + var_names = {'x': x, 'y': y} + for name, val in var_names.items(): + check_variable_and_dtype(val, name, ['float32', 'float64'], + 'matmul_v2') + + __check_input(x, y) + + helper = LayerHelper('matmul_v2', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='matmul_v2', + inputs={'X': x, + 'Y': y}, + outputs={'Out': out}, + attrs=attrs) + return out + + def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None): """ :alias_main: paddle.norm From 3603425ad4cbe2c2d847badb68587fe90547fe83 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Tue, 18 Aug 2020 21:36:01 +0800 Subject: [PATCH 09/18] fix api --- paddle/fluid/operators/matmul_v2_op.cc | 2 +- .../tests/unittests/test_matmul_v2_op.py | 30 +-- python/paddle/tensor/linalg.py | 177 ++---------------- 3 files changed, 27 insertions(+), 182 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index c79fffc211004..21e8ca1c43d22 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -108,7 +108,7 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { AddComment( R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K), B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)). 
- In addition, it is also satisfied the broadcast rule which is similar as + In addition, it also satisfied the broadcast rule which is similar as numpy.matmul. )DOC"); } diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index 5ae675d8dcdb0..884139a23d51c 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -17,6 +17,7 @@ import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core import paddle import paddle.fluid as fluid @@ -295,19 +296,6 @@ def config(self): self.dtype = "float64" -class TestMatMuklOp18(TestMatMulV2Op): - """ - case 17 : to check the gradient for special case - """ - - def config(self): - self.x_shape = (2, 100, 1) - self.y_shape = (100) - self.trans_x = False - self.trans_y = True - self.dtype = "float64" - - class TestMatMulV2API(unittest.TestCase): def setUp(self): self.places = [fluid.CPUPlace()] @@ -316,13 +304,13 @@ def setUp(self): def check_static_result(self, place): with fluid.program_guard(fluid.Program(), fluid.Program()): - input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float64") - input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float64") + input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32") + input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32") - result = paddle.matmul_v2(input_x, input_y) + result = paddle.matmul(input_x, input_y) - x_np = np.random([4, 3]).astype("float64") - y_np = np.random([3, 4]).astype("float64") + x_np = np.random.random([4, 3]).astype("float32") + y_np = np.random.random([3, 4]).astype("float32") exe = fluid.Executor(place) fetches = exe.run(fluid.default_main_program(), @@ -337,11 +325,11 @@ def test_static(self): def test_dygraph(self): for place in self.places: with fluid.dygraph.guard(place): - input_x = np.random([4, 3]).astype("float64") - input_y = np.random([3, 4]).astype("float64") + input_x = np.random.random([4, 3]).astype("float64") + input_y = np.random.random([3, 4]).astype("float64") x = paddle.to_tensor(input_x) y = paddle.to_tensor(input_y) - result = paddle.matmul_v2(x, y) + result = paddle.matmul(x, y) if __name__ == "__main__": diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index adfe2c3455c13..3c6fd4e54e67a 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -21,7 +21,6 @@ __all__ = [ 'matmul', - 'matmul_v2', 'dot', # 'einsum', 'norm', @@ -36,160 +35,20 @@ ] -def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): +def matmul(x, y, transpose_x=False, transpose_y=False, name=None): """ - :alias_main: paddle.matmul - :alias: paddle.matmul,paddle.tensor.matmul,paddle.tensor.linalg.matmul + Applies matrix multiplication to two tensors. `matmul` follows + the complete broadcast rules, + and its behavior is consistent with `np.matmul`. - Applies matrix multiplication to two tensors. - - Currently, the input tensors' rank can be any, but when the rank of any - inputs is bigger than 3, this two inputs' rank should be equal. - - The actual behavior depends on the shapes of :math:`x`, :math:`y` and the - flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically: - - - If a transpose flag is specified, the last two dimensions of the tensor - are transposed. 
If the tensor is rank-1 of shape :math:`[D]`, then for - :math:`x` it is treated as :math:`[1, D]` in nontransposed form and as - :math:`[D, 1]` in transposed form, whereas for :math:`y` it is the - opposite: It is treated as :math:`[D, 1]` in nontransposed form and as - :math:`[1, D]` in transposed form. - - - After transpose, the two tensors are 2-D or n-D and matrix multiplication - performs in the following way. - - - If both are 2-D, they are multiplied like conventional matrices. - - If either is n-D, it is treated as a stack of matrices residing in the - last two dimensions and a batched matrix multiply supporting broadcast - applies on the two tensors. - - Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and - nontransposed, the prepended or appended dimension :math:`1` will be - removed after matrix multiplication. - - Args: - x (Variable): The input variable which is a Tensor or LoDTensor. - y (Variable): The input variable which is a Tensor or LoDTensor. - transpose_x (bool): Whether to transpose :math:`x` before multiplication. - transpose_y (bool): Whether to transpose :math:`y` before multiplication. - alpha (float): The scale of output. Default 1.0. - name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. - - Returns: - Variable: The product Tensor (or LoDTensor) variable. - - Examples: - .. code-block:: python - - # Examples to clarify shapes of the inputs and output - # x: [B, ..., M, K], y: [B, ..., K, N] - # paddle.matmul(x, y) # out: [B, ..., M, N] - - # x: [B, M, K], y: [B, K, N] - # paddle.matmul(x, y) # out: [B, M, N] - - # x: [B, M, K], y: [K, N] - # paddle.matmul(x, y) # out: [B, M, N] - - # x: [M, K], y: [K, N] - # paddle.matmul(x, y) # out: [M, N] - - # x: [B, M, K], y: [K] - # paddle.matmul(x, y) # out: [B, M] - - # x: [K], y: [K] - # paddle.matmul(x, y) # out: [1] - - # x: [M], y: [N] - # paddle.matmul(x, y, True, True) # out: [M, N] - - import paddle - import paddle.fluid as fluid - x = fluid.data(name='x', shape=[2, 3], dtype='float32') - y = fluid.data(name='y', shape=[3, 2], dtype='float32') - out = paddle.matmul(x, y, True, True) - """ - attrs = { - 'transpose_X': transpose_x, - 'transpose_Y': transpose_y, - 'alpha': float(alpha), - } - - if in_dygraph_mode(): - out = _varbase_creator(dtype=x.dtype) - core.ops.matmul(x, y, out, 'transpose_X', transpose_x, 'transpose_Y', - transpose_y, 'alpha', float(alpha)) - return out - - def __check_input(x, y): - var_names = {'x': x, 'y': y} - for name, val in var_names.items(): - check_variable_and_dtype( - val, name, ['float16', 'float32', 'float64'], 'matmul') - x_shape = list(x.shape) - y_shape = list(y.shape) - if len(x_shape) == 1: - x_shape = [1] + x_shape - if len(y_shape) == 1: - y_shape = y_shape + [1] - - # check the inner 2 dimensions - if transpose_x: - x_shape[-2], x_shape[-1] = x_shape[-1], x_shape[-2] - if transpose_y: - y_shape[-2], y_shape[-1] = y_shape[-1], y_shape[-2] - if x_shape[-1] != y_shape[-2]: - assert (x_shape[-1] == -1) or (y_shape[-2] == -1), \ - "After performing an optional transpose, Input X's width should be " \ - "equal to Y's width for multiplication " \ - "prerequisites. 
But received X's shape: %s, Y's shape: %s\n" % \ - (x_shape, y_shape) - - if len(y_shape) > 2 and len(x_shape) > 2: - for i, dim_x in enumerate(x_shape[:-2]): - # don't check neg shape - if dim_x < 0 or y_shape[i] < 0: - continue - if dim_x != y_shape[i]: - raise ValueError( - "When the matrix is larger than 2 dimensions, the higher " - "dimensional values of the two matrices need to be equal. " - "But received x_shape[%d] != y_shape[%d]. X's shape: %s, " - "Y's shape: %s.\n" % (i, i, x_shape, y_shape)) - - __check_input(x, y) - - helper = LayerHelper('matmul', **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type='matmul', - inputs={'X': x, - 'Y': y}, - outputs={'Out': out}, - attrs=attrs) - return out - - -def matmul_v2(x, y, transpose_x=False, transpose_y=False, name=None): - """ - Applies matrix multiplication to two tensors. It is different from the - `matmul` in the broadcast rule. `matmul_v2` follows the complete broadcast - rules, and its behavior is consistent with `np.matmul`. - - Currently, the input tensors' rank can be any, `matmul_v2` can be used to + Currently, the input tensors' rank can be any, `matmul` can be used to achieve the `dot`, `matmul` and `batchmatmul`. The actual behavior depends on the shapes of :math:`x`, :math:`y` and the flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically: - If a transpose flag is specified, the last two dimensions of the tensor - are transposed. If the tensor is rank-1 of shape :math:`[D]`, then for - :math:`x` it is treated as :math:`[1, D]` in nontransposed form and as - :math:`[D, 1]` in transposed form, whereas for :math:`y` it is the - opposite: It is treated as :math:`[D, 1]` in nontransposed form and as - :math:`[1, D]` in transposed form. + are transposed. If the tensor is rank-1 of shape, the transpose is invalid. The multiplication behavior depends on the dimensions of `x` and `y`. Specifically: @@ -208,9 +67,8 @@ def matmul_v2(x, y, transpose_x=False, transpose_y=False, name=None): dimension for the purpose of the batched matrix multiple and removed after. The non-matrix (exclude the last two dimensions) dimensions are broadcasted according the broadcast rule. - For example, if input is a :math:`j \times 1 \times n \times m` (j×1×n×m) - tensor and the other is a :math:`1 \times m \times p` (k×m×p) tensor, - out will be an :math:`j \times k \times n \times p` (j×k×n×p) tensor. + For example, if input is a (j×1×n×m) tensor and the other is a (k×m×p) tensor, + out will be a (j×k×n×p) tensor. Args: x (Tensor): The input variable which is a Tensor or LoDTensor. 
@@ -236,7 +94,7 @@ def matmul_v2(x, y, transpose_x=False, transpose_y=False, name=None): y_data = np.random.uniform([10]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) - z = paddle.matmul_v2(x, y) + z = paddle.matmul(x, y) print(z.numpy().shape) # [1] @@ -245,7 +103,7 @@ def matmul_v2(x, y, transpose_x=False, transpose_y=False, name=None): y_data = np.random.uniform([5]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) - z = paddle.matmul_v2(x, y) + z = paddle.matmul(x, y) print(z.numpy().shape) # [10] @@ -254,7 +112,7 @@ def matmul_v2(x, y, transpose_x=False, transpose_y=False, name=None): y_data = np.random.uniform([2]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) - z = paddle.matmul_v2(x, y) + z = paddle.matmul(x, y) print(z.numpy().shape) # [10, 5] @@ -263,7 +121,7 @@ def matmul_v2(x, y, transpose_x=False, transpose_y=False, name=None): y_data = np.random.uniform([10, 2, 5]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) - z = paddle.matmul_v2(x, y) + z = paddle.matmul(x, y) print(z.numpy().shape) # [10, 5, 5] @@ -272,16 +130,15 @@ def matmul_v2(x, y, transpose_x=False, transpose_y=False, name=None): y_data = np.random.uniform([1, 3, 2, 5]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) - z = paddle.matmul_v2(x, y) + z = paddle.matmul(x, y) print(z.numpy().shape) # [10, 3, 5, 5] """ + op_type = 'matmul_v2' if in_dygraph_mode(): - out = _varbase_creator(dtype=x.dtype) - core.ops.matmul(x, y, out, 'trans_x', transpose_x, 'trans_y', - transpose_y) - return out + op = getattr(core.ops, op_type) + return op(x, y, 'trans_x', transpose_x, 'trans_y', transpose_y) attrs = { 'trans_x': transpose_x, @@ -292,7 +149,7 @@ def __check_input(x, y): var_names = {'x': x, 'y': y} for name, val in var_names.items(): check_variable_and_dtype(val, name, ['float32', 'float64'], - 'matmul_v2') + 'matmul') __check_input(x, y) From 534d204ac2f3902e870f5c9303e7aa99e5a7a33d Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Tue, 18 Aug 2020 21:44:08 +0800 Subject: [PATCH 10/18] fix the example --- python/paddle/tensor/linalg.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 3c6fd4e54e67a..69c15d5524344 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -90,8 +90,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): paddle.disable_static() # vector x vector - x_data = np.random.uniform([10]).astype(np.float32) - y_data = np.random.uniform([10]).astype(np.float32) + x_data = np.random.random([10]).astype(np.float32) + y_data = np.random.random([10]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) z = paddle.matmul(x, y) @@ -99,8 +99,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): # [1] # matrix x vector - x_data = np.random.uniform([10, 5]).astype(np.float32) - y_data = np.random.uniform([5]).astype(np.float32) + x_data = np.random.random([10, 5]).astype(np.float32) + y_data = np.random.random([5]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) z = paddle.matmul(x, y) @@ -108,8 +108,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): # [10] # batched matrix x broadcasted vector - x_data = np.random.uniform([10, 5, 2]).astype(np.float32) - y_data = np.random.uniform([2]).astype(np.float32) + 
x_data = np.random.random([10, 5, 2]).astype(np.float32) + y_data = np.random.random([2]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) z = paddle.matmul(x, y) @@ -117,8 +117,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): # [10, 5] # batched matrix x batched matrix - x_data = np.random.uniform([10, 5, 2]).astype(np.float32) - y_data = np.random.uniform([10, 2, 5]).astype(np.float32) + x_data = np.random.random([10, 5, 2]).astype(np.float32) + y_data = np.random.random([10, 2, 5]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) z = paddle.matmul(x, y) @@ -126,8 +126,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): # [10, 5, 5] # batched matrix x broadcasted matrix - x_data = np.random.uniform([10, 1, 5, 2]).astype(np.float32) - y_data = np.random.uniform([1, 3, 2, 5]).astype(np.float32) + x_data = np.random.random([10, 1, 5, 2]).astype(np.float32) + y_data = np.random.random([1, 3, 2, 5]).astype(np.float32) x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) z = paddle.matmul(x, y) From 313df0fd51fe446d50762c5739ef0952f51d3c8f Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Tue, 18 Aug 2020 21:48:49 +0800 Subject: [PATCH 11/18] fix the comment --- python/paddle/tensor/linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 69c15d5524344..4c8b0826b7e88 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -71,8 +71,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): out will be a (j×k×n×p) tensor. Args: - x (Tensor): The input variable which is a Tensor or LoDTensor. - y (Tensor): The input variable which is a Tensor or LoDTensor. + x (Tensor): The input tensor which is a Tensor or LoDTensor. + y (Tensor): The input tensor which is a Tensor or LoDTensor. transpose_x (bool): Whether to transpose :math:`x` before multiplication. transpose_y (bool): Whether to transpose :math:`y` before multiplication. name(str|None): A name for this layer(optional). If set None, the layer From 475ae9c9256f1dbb424fcc4f28c4f45edd38701c Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Tue, 18 Aug 2020 23:14:13 +0800 Subject: [PATCH 12/18] fix api --- paddle/fluid/operators/matmul_v2_op.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index b13ed4065cc70..3d29844a7f123 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -15,7 +15,6 @@ limitations under the License. 
*/ #pragma once #include -#include #include #include #include @@ -301,7 +300,7 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, std::vector x_ptr(out_batch_size); std::vector y_ptr(out_batch_size); std::vector out_ptr(out_batch_size); - std::vector index(batch_dim); + std::vector index(batch_dim, 0); for (std::int64_t i = 0; i < out_batch_size; ++i) { // using the index to get offset const std::int64_t x_index = @@ -398,8 +397,6 @@ class MatMulV2GradKernel : public framework::OpKernel { ndim += 1; } - ASSERT_GE(x_ndim, 2); - ASSERT_GE(y_ndim, 2); // the normal case Tensor dx_help, dy_help; if (trans_x) { From f98108599753c132838f31bb54c1f46669940a8f Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Wed, 19 Aug 2020 10:47:53 +0800 Subject: [PATCH 13/18] fix the api --- python/paddle/fluid/layers/nn.py | 60 +++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 446510121e72a..25a1414c89b04 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5095,7 +5095,65 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): y = fluid.layers.data(name='y', shape=[3, 2], dtype='float32') out = fluid.layers.matmul(x, y, True, True) """ - return paddle.matmul(x, y, transpose_x, transpose_y, alpha, name) + attrs = { + 'transpose_X': transpose_x, + 'transpose_Y': transpose_y, + 'alpha': float(alpha), + } + + if in_dygraph_mode(): + out = _varbase_creator(dtype=x.dtype) + core.ops.matmul(x, y, out, 'transpose_X', transpose_x, 'transpose_Y', + transpose_y, 'alpha', float(alpha)) + return out + + def __check_input(x, y): + var_names = {'x': x, 'y': y} + for name, val in var_names.items(): + check_variable_and_dtype( + val, name, ['float16', 'float32', 'float64'], 'matmul') + x_shape = list(x.shape) + y_shape = list(y.shape) + if len(x_shape) == 1: + x_shape = [1] + x_shape + if len(y_shape) == 1: + y_shape = y_shape + [1] + + # check the inner 2 dimensions + if transpose_x: + x_shape[-2], x_shape[-1] = x_shape[-1], x_shape[-2] + if transpose_y: + y_shape[-2], y_shape[-1] = y_shape[-1], y_shape[-2] + if x_shape[-1] != y_shape[-2]: + assert (x_shape[-1] == -1) or (y_shape[-2] == -1), \ + "After performing an optional transpose, Input X's width should be " \ + "equal to Y's width for multiplication " \ + "prerequisites. But received X's shape: %s, Y's shape: %s\n" % \ + (x_shape, y_shape) + + if len(y_shape) > 2 and len(x_shape) > 2: + for i, dim_x in enumerate(x_shape[:-2]): + # don't check neg shape + if dim_x < 0 or y_shape[i] < 0: + continue + if dim_x != y_shape[i]: + raise ValueError( + "When the matrix is larger than 2 dimensions, the higher " + "dimensional values of the two matrices need to be equal. " + "But received x_shape[%d] != y_shape[%d]. 
X's shape: %s, " + "Y's shape: %s.\n" % (i, i, x_shape, y_shape)) + + __check_input(x, y) + + helper = LayerHelper('matmul', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='matmul', + inputs={'X': x, + 'Y': y}, + outputs={'Out': out}, + attrs=attrs) + return out def topk(input, k, name=None): From 052936926551331ab5b96659ca6568b7e34c0b30 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Wed, 19 Aug 2020 11:40:36 +0800 Subject: [PATCH 14/18] fix utest --- python/paddle/fluid/layers/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 25a1414c89b04..82b0bbad55af4 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -26,7 +26,7 @@ import paddle from ..layer_helper import LayerHelper from ..initializer import Normal, Constant, NumpyArrayInitializer -from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program +from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program, _varbase_creator from .. import dygraph_utils from ..param_attr import ParamAttr from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_ From cd9b27ca9fdd92eb9a5417bd047cd0a2b8654126 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Wed, 19 Aug 2020 13:51:23 +0800 Subject: [PATCH 15/18] fix cblas --- paddle/fluid/operators/matmul_v2_op.h | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 3d29844a7f123..dc83e4d964815 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include #include #include #include From 440d22e99f5fab8eec5966691081506bf49751d9 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Wed, 19 Aug 2020 15:11:28 +0800 Subject: [PATCH 16/18] fix comment --- python/paddle/tensor/linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 4c8b0826b7e88..deb818cf3075e 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -67,8 +67,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): dimension for the purpose of the batched matrix multiple and removed after. The non-matrix (exclude the last two dimensions) dimensions are broadcasted according the broadcast rule. - For example, if input is a (j×1×n×m) tensor and the other is a (k×m×p) tensor, - out will be a (j×k×n×p) tensor. + For example, if input is a (j, 1, n, m) tensor and the other is a (k, m, p) tensor, + out will be a (j, k, n, p) tensor. Args: x (Tensor): The input tensor which is a Tensor or LoDTensor. 
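The reworded shape example above can be checked directly against NumPy, whose broadcast semantics this operator is documented to follow; the sizes in the snippet below are arbitrary and purely illustrative:

import numpy as np

j, k, n, m, p = 4, 3, 5, 6, 2
x = np.random.random((j, 1, n, m)).astype(np.float32)   # (j, 1, n, m)
y = np.random.random((k, m, p)).astype(np.float32)      # (k, m, p)
out = np.matmul(x, y)   # batch dims broadcast: (j, 1) with (k,) -> (j, k)
assert out.shape == (j, k, n, p)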
From 6d4030d8770584ffa8fe07b65e0fe7e0372a5538 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Wed, 19 Aug 2020 20:49:49 +0800 Subject: [PATCH 17/18] fix comment --- paddle/fluid/operators/matmul_v2_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 21e8ca1c43d22..0254ad0a563d9 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -108,7 +108,7 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { AddComment( R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K), B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)). - In addition, it also satisfied the broadcast rule which is similar as + In addition, it also follows the broadcast rule which is similar as numpy.matmul. )DOC"); } From cfe2061b0d64097f1a3255be9dbdb1d4cbc323b0 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Fri, 21 Aug 2020 00:35:28 +0800 Subject: [PATCH 18/18] fix comment --- python/paddle/fluid/layers/nn.py | 1 + python/paddle/tensor/linalg.py | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 82b0bbad55af4..5b523a5892933 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5024,6 +5024,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None): return out +@deprecated(since="2.0.0", update_to="paddle.matmul") def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): """ Applies matrix multiplication to two tensors. diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index deb818cf3075e..5b7288e8fd8df 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -41,24 +41,30 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): the complete broadcast rules, and its behavior is consistent with `np.matmul`. - Currently, the input tensors' rank can be any, `matmul` can be used to + Currently, the input tensors' number of dimensions can be any, `matmul` can be used to achieve the `dot`, `matmul` and `batchmatmul`. The actual behavior depends on the shapes of :math:`x`, :math:`y` and the flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically: - If a transpose flag is specified, the last two dimensions of the tensor - are transposed. If the tensor is rank-1 of shape, the transpose is invalid. + are transposed. If the tensor is ndim-1 of shape, the transpose is invalid. If the tensor + is ndim-1 of shape :math:`[D]`, then for :math:`x` it is treated as :math:`[1, D]`, whereas + for :math:`y` it is the opposite: It is treated as :math:`[D, 1]`. The multiplication behavior depends on the dimensions of `x` and `y`. Specifically: - If both tensors are 1-dimensional, the dot product result is obtained. + - If both tensors are 2-dimensional, the matrix-matrix product is obtained. + - If the `x` is 1-dimensional and the `y` is 2-dimensional, a `1` is prepended to its dimension in order to conduct the matrix multiply. After the matrix multiply, the prepended dimension is removed. + - If the `x` is 2-dimensional and `y` is 1-dimensional, the matrix-vector product is obtained. + - If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is obtained. 
If the first argument is 1-dimensional, a 1 is prepended to its dimension @@ -71,8 +77,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): out will be a (j, k, n, p) tensor. Args: - x (Tensor): The input tensor which is a Tensor or LoDTensor. - y (Tensor): The input tensor which is a Tensor or LoDTensor. + x (Tensor): The input tensor which is a Tensor. + y (Tensor): The input tensor which is a Tensor. transpose_x (bool): Whether to transpose :math:`x` before multiplication. transpose_y (bool): Whether to transpose :math:`y` before multiplication. name(str|None): A name for this layer(optional). If set None, the layer @@ -89,7 +95,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): import numpy as np paddle.disable_static() - # vector x vector + # vector * vector x_data = np.random.random([10]).astype(np.float32) y_data = np.random.random([10]).astype(np.float32) x = paddle.to_tensor(x_data) @@ -98,7 +104,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): print(z.numpy().shape) # [1] - # matrix x vector + # matrix * vector x_data = np.random.random([10, 5]).astype(np.float32) y_data = np.random.random([5]).astype(np.float32) x = paddle.to_tensor(x_data) @@ -107,7 +113,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): print(z.numpy().shape) # [10] - # batched matrix x broadcasted vector + # batched matrix * broadcasted vector x_data = np.random.random([10, 5, 2]).astype(np.float32) y_data = np.random.random([2]).astype(np.float32) x = paddle.to_tensor(x_data) @@ -116,7 +122,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): print(z.numpy().shape) # [10, 5] - # batched matrix x batched matrix + # batched matrix * batched matrix x_data = np.random.random([10, 5, 2]).astype(np.float32) y_data = np.random.random([10, 2, 5]).astype(np.float32) x = paddle.to_tensor(x_data) @@ -125,7 +131,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): print(z.numpy().shape) # [10, 5, 5] - # batched matrix x broadcasted matrix + # batched matrix * broadcasted matrix x_data = np.random.random([10, 1, 5, 2]).astype(np.float32) y_data = np.random.random([1, 3, 2, 5]).astype(np.float32) x = paddle.to_tensor(x_data)