Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Matmul op #26411

Merged
merged 19 commits into from
Aug 22, 2020
158 changes: 82 additions & 76 deletions paddle/fluid/operators/dot_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,86 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

// Computes the gradients of the dot op with respect to its two inputs.
//
// Every element of a requested gradient is the matching element of the
// *other* forward input scaled by the upstream gradient:
//   dx = y * dout,   dy = x * dout
// with dout broadcast over the elements it covers.
//
// Args:
//   tensor_x, tensor_y:   forward inputs of the op.
//   tensor_dout:          upstream gradient.
//   tensor_dx, tensor_dy: outputs; either may be null, in which case that
//                         gradient is skipped. mutable_data<T>(ctx.GetPlace())
//                         is called on every non-null output before it is
//                         written, on every code path, so callers are not
//                         required to allocate them beforehand.
//   ctx:                  supplies the allocation place and, under NVCC, the
//                         Eigen device the expressions are evaluated on.
template <typename DeviceContext, typename T>
void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y,
                     const Tensor* tensor_dout, Tensor* tensor_dx,
                     Tensor* tensor_dy,
                     const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
  if (1 == tensor_dout->dims().size()) {
    // Rank-1 dout: every operand is viewed as a flat vector and dout is
    // broadcast numel() times to match the output length.
    auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

    if (tensor_dx) {
      // Allocate here as the other branches do, instead of relying on the
      // caller having done so before the output is flattened and written.
      tensor_dx->mutable_data<T>(ctx.GetPlace());
      auto y = framework::EigenVector<T>::Flatten(*tensor_y);
      auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      Eigen::DSizes<int, 1> size(tensor_dx->numel());
      dx.device(dev) = y * dout.broadcast(size);
    }

    if (tensor_dy) {
      // Same allocation guarantee as for dx above.
      tensor_dy->mutable_data<T>(ctx.GetPlace());
      auto x = framework::EigenVector<T>::Flatten(*tensor_x);
      auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      Eigen::DSizes<int, 1> size(tensor_dy->numel());
      dy.device(dev) = x * dout.broadcast(size);
    }
  } else {
    // Otherwise the operands are viewed as matrices and dout is repeated
    // dims()[1] times along the second axis.
    // NOTE(review): EigenMatrix::From and dims()[1] presumably require rank-2
    // tensors here -- confirm against the op's shape inference.
    auto dout = EigenMatrix<T>::From(*tensor_dout);

    if (tensor_dx) {
      tensor_dx->mutable_data<T>(ctx.GetPlace());
      auto y = EigenMatrix<T>::From(*tensor_y);
      auto dx = EigenMatrix<T>::From(*tensor_dx);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
      dx.device(dev) = y * dout.broadcast(size);
    }

    if (tensor_dy) {
      tensor_dy->mutable_data<T>(ctx.GetPlace());
      auto x = EigenMatrix<T>::From(*tensor_x);
      auto dy = EigenMatrix<T>::From(*tensor_dy);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
      dy.device(dev) = x * dout.broadcast(size);
    }
  }
#else
  // Host path: plain loops over the raw buffers.
  const auto* data_dout = tensor_dout->data<T>();

  if (tensor_dx) {
    auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
    const auto* data_y = tensor_y->data<T>();
    const framework::DDim& dim = tensor_x->dims();
    size_t N = static_cast<size_t>(framework::product(dim));

    // Extent of the last dimension: one dout entry covers `step` consecutive
    // elements.
    auto step = dim[dim.size() - 1];

    // s starts at -1 and is bumped at the first element of every slice, so
    // data_dout[s] is the upstream gradient of the slice element i lies in.
    int s = -1;
    for (size_t i = 0; i < N; ++i) {
      if (0 == i % step) ++s;
      data_dx[i] = data_y[i] * data_dout[s];
    }
  }

  if (tensor_dy) {
    auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
    const auto* data_x = tensor_x->data<T>();
    const framework::DDim& dim = tensor_y->dims();
    size_t N = static_cast<size_t>(framework::product(dim));

    // Same slice bookkeeping as for dx, driven by y's shape.
    auto step = dim[dim.size() - 1];

    int s = -1;
    for (size_t i = 0; i < N; ++i) {
      if (0 == i % step) ++s;
      data_dy[i] = data_x[i] * data_dout[s];
    }
  }
#endif
}

template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -84,83 +164,9 @@ class DotGradKernel : public framework::OpKernel<T> {

if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}

if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);

if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}

if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();

if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));

auto step = dim[dim.size() - 1];

int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
}
}

if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));

auto step = dim[dim.size() - 1];

int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
}
}
#endif
DotGradFunction<DeviceContext, T>(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
}
};

Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/operators/math/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ class Blas {
int K, T alpha, const T* A, const T* B, T beta, T* C,
int batchCount, int64_t strideA, int64_t strideB) const;

// Batched GEMM overload that takes arrays of matrix pointers rather than a
// base pointer plus strides: A, B and C are indexed per batch entry
// (A[k], B[k], C[k] for k in [0, batchCount)), and the same transA/transB,
// M/N/K, alpha and beta are applied to every entry.
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T** A, const T** B, T beta, T** C,
int batchCount) const;

#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <typename T>
void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,17 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#endif // CUDA_VERSION >= 9010
}

// Pointer-array BatchedGEMM for the CUDA context. The batch is processed by
// issuing one regular GEMM per entry, in index order, with identical
// transpose flags, sizes and scalars for all of them.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
  int batch_idx = 0;
  while (batch_idx < batchCount) {
    const T *lhs = A[batch_idx];
    const T *rhs = B[batch_idx];
    T *out = C[batch_idx];
    this->template GEMM<T>(transA, transB, M, N, K, alpha, lhs, rhs, beta,
                           out);
    ++batch_idx;
  }
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
Expand Down
21 changes: 21 additions & 0 deletions paddle/fluid/operators/math/blas_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
Expand Down Expand Up @@ -655,6 +656,26 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}

// Pointer-array BatchedGEMM for the CPU context.
//
// With MKLML the whole batch is handed to GEMM_BATCH as a single group of
// batchCount row-major products; otherwise each entry is multiplied with an
// ordinary GEMM call, in index order.
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
#ifdef PADDLE_WITH_MKLML
  // Row-major leading dimensions: the stored row length of each operand,
  // which depends on whether it is transposed. Clamped to at least 1.
  int a_row_len = M;
  if (transA == CblasNoTrans) {
    a_row_len = K;
  }
  int b_row_len = K;
  if (transB == CblasNoTrans) {
    b_row_len = N;
  }
  const int lda = std::max(a_row_len, 1);
  const int ldb = std::max(b_row_len, 1);
  const int ldc = std::max(N, 1);

  // Every per-group argument is passed by address because the batch API
  // expects one value per group; there is exactly one group here.
  CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A,
                       &lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */,
                       &batchCount);
#else
  int batch_idx = 0;
  while (batch_idx < batchCount) {
    this->template GEMM<T>(transA, transB, M, N, K, alpha, A[batch_idx],
                           B[batch_idx], beta, C[batch_idx]);
    ++batch_idx;
  }
#endif
}

#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <>
template <typename T>
Expand Down
Loading