Skip to content

Commit

Permalink
optimize flip op, removing duplicated computation when dim size is one (
Browse files Browse the repository at this point in the history
  • Loading branch information
sljlp committed Dec 9, 2021
1 parent 18aca3f commit 890638c
Showing 1 changed file with 0 additions and 41 deletions.
41 changes: 0 additions & 41 deletions paddle/fluid/operators/flip_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,6 @@ namespace operators {
using Tensor = framework::Tensor;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;

template <typename T>
__global__ void kernel_pointwise_flip_apply(const int N, const T* in_data,
T* out_data, int dim0, int stride0,
int dim1, int flip_dim) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
idx += gridDim.x * blockDim.x) {
int dst_offset = 0;
if (flip_dim == 0) {
// flip 1st dim
dst_offset = (dim0 - 1 - idx / stride0) * stride0 + idx % stride0;
} else {
// flip last dim
dst_offset = idx / stride0 * stride0 + (dim1 - 1 - idx % stride0);
}
out_data[dst_offset] = in_data[idx];
}
}

template <typename T>
__global__ void flip_cuda_kernel(const int N, const T* in_data, T* out_data,
int64_t* x_shape, int64_t* x_stride,
Expand Down Expand Up @@ -103,29 +85,6 @@ class FlipKernel<platform::CUDADeviceContext, T>
std::vector<int64_t> x_dims_v = framework::vectorize(x_dims);
std::vector<int64_t> x_stride_v = framework::vectorize(x_stride);

// wrap high-dims to 2-dims
if (flip_dims_size == 1 &&
(flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
int dim0 = 1, dim1 = 1;
int stride0 = 1;
if (flip_dims[0] == 0) {
dim0 = x_dims_v[0];
stride0 = x_stride_v[0];
for (size_t i = 1; i < total_dims; ++i) {
dim1 *= x_dims_v[i];
}
} else {
dim1 = x_dims_v[total_dims - 1];
for (size_t i = 0; i < total_dims - 1; ++i) {
dim0 *= x_dims_v[i];
}
stride0 *= x_dims_v[total_dims - 1];
}
kernel_pointwise_flip_apply<
T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
N, in_data, out_data, dim0, stride0, dim1, flip_dims[0]);
}

int bytes = total_dims * sizeof(int64_t);
auto x_strides_array_tmp = memory::Alloc(dev_ctx, bytes);
int64_t* x_strides_array_gpu =
Expand Down

0 comments on commit 890638c

Please sign in to comment.