Skip to content

Commit

Permalink
[Bfloat16] Register bfloat16 datatype for squared L2 norm (#50908)
Browse files Browse the repository at this point in the history
* register bfloat16 datatype for squared l2 norm

* register bfloat16 datatype for softmax with upper triangular mask

* register bfloat16 for tril triu cuda kernel
  • Loading branch information
shaojiewang committed Feb 27, 2023
1 parent 5d322ce commit 3c12104
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 8 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ __device__ __inline__ void load_data_upper_tri(plat::float16* dst,
*(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
}

// Vectorized copy for bfloat16: moves four bf16 elements (8 bytes) in a
// single float2 transaction. NOTE(review): the reinterpret_cast requires
// both pointers to be 8-byte aligned — caller must guarantee this.
__device__ __inline__ void load_data_upper_tri(plat::bfloat16* dst,
                                               const plat::bfloat16* src) {
  const float2* vec_src = reinterpret_cast<const float2*>(src);
  float2* vec_dst = reinterpret_cast<float2*>(dst);
  *vec_dst = *vec_src;
}

// Vectorized copy for float: moves four float elements (16 bytes) in a
// single float4 transaction. NOTE(review): the reinterpret_cast requires
// both pointers to be 16-byte aligned — caller must guarantee this.
__device__ __inline__ void load_data_upper_tri(float* dst, const float* src) {
  const float4* vec_src = reinterpret_cast<const float4*>(src);
  float4* vec_dst = reinterpret_cast<float4*>(dst);
  *vec_dst = *vec_src;
}
Expand All @@ -75,6 +80,10 @@ __device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) {
*(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
}

// Zero-fill for bfloat16: writes 8 zero bytes (four bf16 zeros, since the
// all-zero bit pattern is +0.0 in bf16) via one float2 store.
// NOTE(review): dst must be 8-byte aligned for the float2 store.
__device__ __inline__ void load_zero_vector_upper_tri(plat::bfloat16* dst) {
  float2* vec_dst = reinterpret_cast<float2*>(dst);
  *vec_dst = make_float2(0.0f, 0.0f);
}

// Zero-fill for float: writes four float zeros (16 bytes) via one float4
// store. NOTE(review): dst must be 16-byte aligned for the float4 store.
__device__ __inline__ void load_zero_vector_upper_tri(float* dst) {
  float4* vec_dst = reinterpret_cast<float4*>(dst);
  *vec_dst = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
}
Expand Down Expand Up @@ -596,8 +605,11 @@ namespace plat = paddle::platform;
// Register the forward CUDA kernel for fp16, bf16, and fp32.
REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask_upper_triangle,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::float16>,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::bfloat16>,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, float>);
// Register the backward (grad) CUDA kernel for the same dtype set, keeping
// forward/backward dtype support consistent.
REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask_upper_triangle_grad,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, plat::float16>,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext,
plat::bfloat16>,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, float>);
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ PD_REGISTER_KERNEL(squared_l2_norm_grad,
phi::SquaredL2NormGradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ PD_REGISTER_KERNEL(squared_l2_norm,
phi::SquaredL2NormKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
9 changes: 6 additions & 3 deletions paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(tril_grad,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(triu_grad,
GPU,
Expand All @@ -36,7 +37,8 @@ PD_REGISTER_KERNEL(triu_grad,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(tril_triu_grad,
GPU,
Expand All @@ -47,4 +49,5 @@ PD_REGISTER_KERNEL(tril_triu_grad,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
9 changes: 6 additions & 3 deletions paddle/phi/kernels/gpu/tril_triu_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(tril_triu,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(triu,
GPU,
Expand All @@ -36,7 +37,8 @@ PD_REGISTER_KERNEL(triu,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(tril,
GPU,
Expand All @@ -47,4 +49,5 @@ PD_REGISTER_KERNEL(tril,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

0 comments on commit 3c12104

Please sign in to comment.