Add float16 GEMM math function on GPU #8695

Merged: 17 commits, Mar 9, 2018
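This PR adds float16 specializations of gemm, matmul, and batched_gemm. On GPU they dispatch to cublasHgemm / cublasHgemmStridedBatched; the CPU specializations throw, since there is no half-precision BLAS path there. A minimal sketch of a call site follows (the device context and device pointers are assumed to be set up elsewhere; GemmFp16Example is a hypothetical helper, not part of this PR):

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/float16.h"

// Computes C = A * B for row-major float16 matrices resident on the GPU.
// `ctx` must be a live CUDADeviceContext; d_a, d_b, d_c are device
// pointers of sizes m*k, k*n, and m*n elements respectively.
void GemmFp16Example(const paddle::platform::CUDADeviceContext& ctx,
                     const paddle::platform::float16* d_a,
                     const paddle::platform::float16* d_b,
                     paddle::platform::float16* d_c, int m, int n, int k) {
  using paddle::platform::float16;
  paddle::operators::math::gemm<paddle::platform::CUDADeviceContext,
                                float16>(
      ctx, CblasNoTrans, CblasNoTrans, m, n, k,
      static_cast<float16>(1.0f), d_a, d_b, static_cast<float16>(0.0f),
      d_c);
}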
39 changes: 39 additions & 0 deletions paddle/fluid/operators/math/math_function.cc
@@ -15,11 +15,23 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

using float16 = paddle::platform::float16;

template <>
void gemm<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C) {
PADDLE_THROW("float16 GEMM not supported on CPU");
}

template <>
void gemm<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
@@ -46,6 +58,15 @@ void gemm<platform::CPUDeviceContext, double>(
beta, C, ldc);
}

template <>
void gemm<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const int lda, const float16* B,
const int ldb, const float16 beta, float16* C, const int ldc) {
PADDLE_THROW("float16 GEMM not supported on CPU");
}

template <>
void gemm<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const bool transA,
@@ -68,6 +89,15 @@ void gemm<platform::CPUDeviceContext, double>(
lda, B, ldb, beta, C, ldc);
}

template <>
void matmul<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
framework::Tensor* matrix_out, float16 beta) {
PADDLE_THROW("float16 matmul not supported on CPU");
}

template <>
void matmul<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context,
@@ -126,6 +156,15 @@ void matmul<platform::CPUDeviceContext, double>(
matrix_b.data<double>(), beta, matrix_out->data<double>());
}

template <>
void batched_gemm<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C, const int batchCount, const int strideA, const int strideB) {
PADDLE_THROW("float16 batched_gemm not supported on CPU");
}

#ifdef PADDLE_WITH_MKLML
// Use cblas_{s,d}gemm_batched if available: Run with 1 group of size batchSize.
template <>
108 changes: 108 additions & 0 deletions paddle/fluid/operators/math/math_function.cu
@@ -16,11 +16,40 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

using float16 = paddle::platform::float16;

template <>
void gemm<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
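// A row-major matrix is, viewed column-major, its own transpose. So the
// row-major product C = op(A) * op(B) is obtained by asking cuBLAS for
// the column-major product C^T = op(B)^T * op(A)^T: swap the A/B operands
// and the M/N extents, and keep the transpose flags as they are.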
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

const half h_alpha = static_cast<half>(alpha);
const half h_beta = static_cast<half>(beta);
const half* h_A = reinterpret_cast<const half*>(A);
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
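// platform::float16 is assumed to be bit-compatible with CUDA's half type
// (same size and layout), which is what makes the reinterpret_casts above
// safe.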

PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, N));
}

template <>
void gemm<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
@@ -60,6 +89,28 @@ void gemm<platform::CUDADeviceContext, double>(
lda, &beta, C, N));
}

template <>
void gemm<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const int lda, const float16* B,
const int ldb, const float16 beta, float16* C, const int ldc) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;

const half h_alpha = static_cast<half>(alpha);
const half h_beta = static_cast<half>(beta);
const half* h_A = reinterpret_cast<const half*>(A);
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);

PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, ldc));
}

template <>
void gemm<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const bool transA,
Expand Down Expand Up @@ -90,6 +141,35 @@ void gemm<platform::CUDADeviceContext, double>(
lda, &beta, C, ldc));
}

template <>
void matmul<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
framework::Tensor* matrix_out, float16 beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");

PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in CUDAPlace");

int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];

CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

gemm<platform::CUDADeviceContext, float16>(
context, transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
matrix_b.data<float16>(), beta, matrix_out->data<float16>());
}

template <>
void matmul<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context,
@@ -148,6 +228,34 @@ void matmul<platform::CUDADeviceContext, double>(
matrix_b.data<double>(), beta, matrix_out->data<double>());
}

template <>
void batched_gemm<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C, const int batchCount, const int strideA, const int strideB) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
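// The same row-major -> column-major swap used in the float16 gemm above
// applies here: cuBLAS computes C^T = op(B)^T * op(A)^T for each batch
// element.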
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
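// Each output matrix in the batch is a dense M x N row-major block, so
// consecutive C matrices are M * N elements apart.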
const int strideC = M * N;

const half h_alpha = static_cast<half>(alpha);
const half h_beta = static_cast<half>(beta);
const half* h_A = reinterpret_cast<const half*>(A);
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);

PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
}

template <>
void batched_gemm<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,