Revert "【Hackathon No.52】为 Paddle dist 算子实现 float16 数据类型支持" #53527

Merged 1 commit on May 5, 2023
9 changes: 2 additions & 7 deletions paddle/phi/kernels/dist_grad_kernel.cc
@@ -98,11 +98,6 @@ PD_REGISTER_KERNEL(
dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(dist_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::DistGradKernel,
-                   phi::dtype::float16,
-                   float,
-                   double) {}
+PD_REGISTER_KERNEL(
+    dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
#endif
27 changes: 17 additions & 10 deletions paddle/phi/kernels/funcs/math_cuda_utils.h
@@ -23,9 +23,6 @@ limitations under the License. */

#include <algorithm>

#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/common/data_type.h"

namespace phi {
namespace funcs {

@@ -173,7 +170,11 @@ struct KeyValuePair<half> {
template <typename T>
__inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
-    val += phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask);
+#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
+#else
+    val += __shfl_xor(val, mask, warpSize);
+#endif
return val;
}
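For context on the hunk above: WarpReduceSum is a butterfly (XOR-shuffle) reduction, where each lane repeatedly adds the partial sum held by a lane a power-of-two offset away; after log2(32) rounds every lane holds the warp-wide total. A minimal standalone sketch of the pattern (illustration only, not this PR's code; it assumes CUDA 9+ and a full 32-lane warp):

```cuda
// Butterfly warp sum: 5 exchange rounds for a 32-lane warp.
__device__ float WarpSumSketch(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    // XOR shuffle pairs lane i with lane i ^ offset and returns that lane's val.
    val += __shfl_xor_sync(0xffffffff, val, offset, 32);
  }
  return val;  // every lane now holds the sum of all 32 inputs
}
```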

@@ -242,8 +243,11 @@ __inline__ __device__ T BlockReduceSumV2(T *val) {
template <typename T>
__inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
-    val = std::max(
-        val, phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask));
+#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
+#else
+    val = max(val, __shfl_xor(val, mask, warpSize));
+#endif
return val;
}

@@ -261,8 +265,11 @@ __inline__ __device__ T WarpReduceMaxV2(T *val) {
template <typename T>
__inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
-    val = std::min(
-        val, phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask));
+#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
+    val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
+#else
+    val = min(val, __shfl_xor(val, mask, warpSize));
+#endif
return val;
}

@@ -303,7 +310,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) {

// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
-  val = (lane < block_span) ? shared[lane] : std::numeric_limits<T>::min();
+  val = (lane < block_span) ? shared[lane] : -1e10f;
val = WarpReduceMax(val, mask);

return val;
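For reference, BlockReduceMax follows a two-stage pattern: warp-reduce, publish one partial per warp to shared memory, then reduce the per-warp partials with one more warp-level pass. The `-1e10f` restored above pads the shared-memory slots beyond the number of active warps. A minimal sketch of that shape (illustration only; the name and float-only signature are mine):

```cuda
// Two-stage block max: intra-warp reduce, then reduce the per-warp partials.
__device__ float BlockMaxSketch(float val) {
  static __shared__ float shared[32];   // one slot per possible warp in the block
  const int lane = threadIdx.x & 0x1f;  // lane index within the warp
  const int wid = threadIdx.x >> 5;     // warp index within the block

  for (int offset = 16; offset > 0; offset >>= 1)
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset, 32));
  if (lane == 0) shared[wid] = val;     // lane 0 publishes its warp's max
  __syncthreads();

  // Lanes with no corresponding warp read a very small sentinel instead.
  const int num_warps = (blockDim.x + 31) >> 5;
  val = (lane < num_warps) ? shared[lane] : -1e10f;
  for (int offset = 16; offset > 0; offset >>= 1)
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset, 32));
  return val;  // every thread now holds the block-wide max
}
```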
@@ -351,7 +358,7 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) {

// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
-  val = (lane < block_span) ? shared[lane] : std::numeric_limits<T>::max();
+  val = (lane < block_span) ? shared[lane] : 1e10f;
val = WarpReduceMin(val, mask);

return val;
84 changes: 33 additions & 51 deletions paddle/phi/kernels/gpu/dist_kernel.cu
@@ -12,12 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#include <algorithm>
-
-#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
-#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
@@ -27,56 +24,47 @@ namespace phi {

#define FULL_MASK 0xffffffff

-template <typename Tx, typename Ty = Tx>
+template <typename T>
struct ZeroOrderFunctor {
-  HOSTDEVICE explicit inline ZeroOrderFunctor() {}
-
-  HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
-    return static_cast<Ty>(x != y);
+ public:
+  __device__ T operator()(const T& x, const T& y) const {
+    return static_cast<T>((x - y) != 0);
}
};

-template <typename Tx, typename Ty = Tx>
+template <typename T>
struct OtherOrderFunctor {
-  HOSTDEVICE explicit inline OtherOrderFunctor(const Ty& _p_order)
-      : p_order(_p_order) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
-    return static_cast<Ty>(
-        pow(abs(static_cast<Ty>(x) - static_cast<Ty>(y)), p_order));
+  explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
+  __device__ T operator()(const T& x, const T& y) const {
+    return static_cast<T>(pow(abs(x - y), p_order_));
}

private:
-  Ty p_order;
+  T p_order_;
};

-template <typename Tx, typename Ty = Tx>
+template <typename T>
struct PowFunctor {
-  HOSTDEVICE explicit inline PowFunctor(const Ty& _p_order)
-      : p_order(_p_order) {}
-
-  HOSTDEVICE inline Tx operator()(const Tx x) const {
-    return static_cast<Tx>(pow(static_cast<Ty>(x), p_order));
+  explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
+  HOSTDEVICE inline T operator()(const T x) const {
+    return static_cast<T>(pow(x, p_order_));
}

private:
-  Ty p_order;
+  T p_order_;
};

template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
const T* x, const T* y, T* out, int64_t N, Functor func) {
-  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-  MT sum_val(0.0);
+  T sum_val = 0;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
-    sum_val += static_cast<MT>(func(x[i], y[i]));
+    sum_val += func(x[i], y[i]);
}

__syncthreads();
-  sum_val = phi::funcs::BlockReduceSum<MT>(sum_val, FULL_MASK);
+  sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
if (threadIdx.x == 0) {
-    out[blockIdx.x] = static_cast<T>(sum_val);
+    out[blockIdx.x] = sum_val;
}
}
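The kernel above pairs a grid-stride loop with the block reduction from math_cuda_utils.h, leaving one partial sum per block in `out[blockIdx.x]`, which DistKernel then reduces again with phi::funcs::ReduceKernel. The lines being removed here accumulated in a wider type: for float16 inputs, `phi::dtype::MPTypeTrait<T>::Type` resolves to float, so per-thread partials kept fp32 precision and were cast back to T only on the final store. A hedged sketch of that wider-accumulator idea, for context only (the kernel name and 64-bit index are mine, not Paddle's):

```cuda
// Accumulate float16 inputs in float, cast back once per block.
template <typename T, typename Functor>
__global__ void MixedPrecisionPartialSum(const T* x, const T* y, T* out,
                                         int64_t n, Functor func) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;  // float for float16
  MT sum = static_cast<MT>(0);
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    sum += static_cast<MT>(func(x[i], y[i]));  // e.g. |x - y|^p
  }
  sum = phi::funcs::BlockReduceSum<MT>(sum, 0xffffffff);  // FULL_MASK
  if (threadIdx.x == 0) out[blockIdx.x] = static_cast<T>(sum);
}
```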

@@ -85,10 +73,10 @@ __global__ void ReduceMaxWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
-  T max_val = std::numeric_limits<T>::min();
+  T max_val = -1e10f;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
-    max_val = std::max(max_val, abs(x[i] - y[i]));
+    max_val = max(max_val, abs(x[i] - y[i]));
}

__syncthreads();
@@ -103,10 +91,10 @@ __global__ void ReduceMinWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
-  T min_val = std::numeric_limits<T>::max();
+  T min_val = 1e10f;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
-    min_val = std::min(min_val, abs(x[i] - y[i]));
+    min_val = min(min_val, abs(x[i] - y[i]));
}

__syncthreads();
@@ -122,7 +110,6 @@ void DistKernel(const Context& dev_ctx,
const DenseTensor& y,
float p,
DenseTensor* out) {
-  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
DenseTensor intermediate;
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
@@ -144,8 +131,9 @@ void DistKernel(const Context& dev_ctx,
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
-    phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
-        dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
+    phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
+
} else if (p == INFINITY) {
ReduceMaxWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
@@ -162,19 +150,19 @@ void DistKernel(const Context& dev_ctx,
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

} else {
-    MT p_order = static_cast<MT>(p);
+    T p_order = static_cast<T>(p);
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
-            x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T, MT>(p_order));
-    phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
-        dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
+            x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
+    phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

const DenseTensor* tmp_norm = out;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<DenseTensor*> outs = {out};
-    MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);
+    T p_order_ = static_cast<T>(1. / p_order);
phi::funcs::ElementwiseKernel<T>(
-        dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
+        dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
}

} else {
@@ -185,10 +173,4 @@ void DistKernel(const Context& dev_ctx,

} // namespace phi

-PD_REGISTER_KERNEL(dist,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::DistKernel,
-                   phi::dtype::float16,
-                   float,
-                   double) {}
+PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
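For reference, the quantity DistKernel computes (restating the p-norm definition from the paddle.dist docstring, with the special cases matching the branches above; z denotes the broadcasted difference):

```latex
\operatorname{dist}(x, y, p) =
\begin{cases}
  \#\{\, i : z_i \neq 0 \,\}            & p = 0, \\
  \max_i |z_i|                          & p = +\infty, \\
  \min_i |z_i|                          & p = -\infty, \\
  \left( \sum_i |z_i|^{p} \right)^{1/p} & \text{otherwise},
\end{cases}
\qquad z = x - y .
```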
40 changes: 0 additions & 40 deletions python/paddle/fluid/tests/unittests/test_dist_op.py
@@ -158,46 +158,6 @@ def init_case(self):
self.p = 1.5


-class TestDistFP16Op(OpTest):
-    def init_data_type(self):
-        self.data_type = 'float16'
-
-
-class TestDistFP16OpCase1(TestDistFP16Op):
-    def init_case(self):
-        self.x_shape = (3, 5, 5, 6)
-        self.y_shape = (5, 5, 6)
-        self.p = 1.0
-
-
-class TestDistFP16OpCase2(TestDistFP16Op):
-    def init_case(self):
-        self.x_shape = (10, 10)
-        self.y_shape = (4, 10, 10)
-        self.p = 2.0
-
-
-class TestDistFP16OpCase3(TestDistFP16Op):
-    def init_case(self):
-        self.x_shape = (15, 10)
-        self.y_shape = (15, 10)
-        self.p = float("inf")
-
-
-class TestDistFP16OpCase4(TestDistFP16Op):
-    def init_case(self):
-        self.x_shape = (2, 3, 4, 5, 8)
-        self.y_shape = (3, 1, 5, 8)
-        self.p = float("-inf")
-
-
-class TestDistFP16OpCase5(TestDistFP16Op):
-    def init_case(self):
-        self.x_shape = (4, 1, 4, 8)
-        self.y_shape = (2, 2, 1, 4, 4, 8)
-        self.p = 1.5
-
-
class TestDistAPI(unittest.TestCase):
def init_data_type(self):
self.data_type = (
12 changes: 4 additions & 8 deletions python/paddle/tensor/linalg.py
@@ -675,8 +675,8 @@ def dist(x, y, p=2, name=None):
||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\\frac{1}{p}}

Args:
-        x (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64.
-        y (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64.
+        x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
+        y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2.
name (str, optional): The default value is `None`. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
@@ -706,12 +706,8 @@ def dist(x, y, p=2, name=None):
if in_dygraph_mode():
return _C_ops.dist(x, y, p)

-    check_variable_and_dtype(
-        x, 'dtype', ['float16', 'float32', 'float64'], 'dist'
-    )
-    check_variable_and_dtype(
-        y, 'dtype', ['float16', 'float32', 'float64'], 'dist'
-    )
+    check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'dist')
+    check_variable_and_dtype(y, 'dtype', ['float32', 'float64'], 'dist')
check_type(p, 'p', (float, int), 'dist')
helper = LayerHelper("dist", **locals())
out = helper.create_variable_for_type_inference(x.dtype)