diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 4da46e94c5cd9..266303b4cb582 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -184,6 +184,7 @@ op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor) +op_library(concat_op DEPS concat_functor) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index bdce8f0a6f5f8..0eedd8ee51ebf 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad, ops::ConcatOpGrad, false) -REGISTER_OP_CPU_KERNEL(concat, - ops::ConcatKernel) -REGISTER_OP_CPU_KERNEL(concat_grad, - ops::ConcatGradKernel) +REGISTER_OP_CPU_KERNEL( + concat, ops::ConcatKernel) +REGISTER_OP_CPU_KERNEL( + concat_grad, + ops::ConcatGradKernel) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 208a4481c6afe..92c8ab6d9ff11 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/strided_memcpy.h" namespace paddle { @@ -27,54 +28,30 @@ class ConcatKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); + framework::Tensor* out = ctx.Output("Out"); int64_t axis = static_cast(ctx.Attr("axis")); auto place = ctx.GetPlace(); out->mutable_data(place); - auto out_stride = framework::stride_numel(out->dims()); - - size_t output_offset = 0; - - // If axis >=1, copy to out immediately need to call many times - // of cuda memcpy. Copy the input to cpu and do the stride copy, - // then copy to gpu output. - - if (platform::is_gpu_place(place) && axis >= 1) { - platform::CPUPlace copy_place; - auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place); - framework::Tensor cpu_out; - cpu_out.Resize(out->dims()); - cpu_out.mutable_data(copy_place); - auto& dev_ctx = ctx.device_context(); - std::vector> cpu_ins; - for (auto* in : ins) { - std::unique_ptr cpu_in(new framework::Tensor); - framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get()); - cpu_ins.emplace_back(std::move(cpu_in)); - } - // TODO(dzhwinter): overlap copy and compute stream - // https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/ - dev_ctx.Wait(); - - for (auto& in : cpu_ins) { - auto& cpu_in = *in.get(); - auto in_stride = framework::stride_numel(cpu_in.dims()); - - StridedNumelCopyWithAxis( - cpu_ctx, axis, cpu_out.data() + output_offset, out_stride, - cpu_in.data(), in_stride, in_stride[axis]); - output_offset += in_stride[axis]; - } - framework::TensorCopy(cpu_out, place, dev_ctx, out); - } else { + // Sometimes direct copies will be faster, this maybe need deeply analysis. 
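The rewritten kernel keeps per-input strided copies only when axis == 0 and there are few inputs, and otherwise defers to the new math::ConcatFunctor; the "< 10" cutoff is a heuristic that the comment above flags for later profiling. Axis 0 is the cheap case because every input occupies one contiguous block of the output. A minimal host-side sketch of that fast path, with std::vector standing in for framework::Tensor (illustrative names only, not the Paddle API):

#include <cstring>
#include <vector>

// Concatenation along the leading axis: each input is a contiguous slice of
// the output, so one bulk copy per input suffices.
std::vector<float> ConcatAxis0(const std::vector<std::vector<float>>& ins) {
  size_t total = 0;
  for (const auto& in : ins) total += in.size();
  std::vector<float> out(total);
  size_t offset = 0;
  for (const auto& in : ins) {
    std::memcpy(out.data() + offset, in.data(), in.size() * sizeof(float));
    offset += in.size();
  }
  return out;
}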
+ if (axis == 0 && ins.size() < 10) { + size_t output_offset = 0; for (auto* in : ins) { auto in_stride = framework::stride_numel(in->dims()); + auto out_stride = framework::stride_numel(out->dims()); StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data() + output_offset, out_stride, in->data(), in_stride, in_stride[axis]); output_offset += in_stride[axis]; } + } else { + std::vector inputs(ins.size()); + for (size_t j = 0; j < ins.size(); ++j) { + inputs[j] = *ins[j]; + } + auto& dev_ctx = ctx.template device_context(); + paddle::operators::math::ConcatFunctor concat_functor; + concat_functor(dev_ctx, inputs, static_cast(axis), out); } } }; @@ -86,16 +63,31 @@ class ConcatGradKernel : public framework::OpKernel { auto* in = ctx.Input(framework::GradVarName("Out")); auto outs = ctx.MultiOutput(framework::GradVarName("X")); int64_t axis = static_cast(ctx.Attr("axis")); - size_t input_offset = 0; - auto in_stride = framework::stride_numel(in->dims()); - for (auto& out : outs) { - out->mutable_data(ctx.GetPlace()); - auto out_stride = framework::stride_numel(out->dims()); - StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), - out_stride, in->data() + input_offset, - in_stride, out_stride[axis]); - input_offset += out_stride[axis]; + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + size_t input_offset = 0; + auto in_stride = framework::stride_numel(in->dims()); + + for (auto& out : outs) { + out->mutable_data(ctx.GetPlace()); + auto out_stride = framework::stride_numel(out->dims()); + StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), + out_stride, in->data() + input_offset, + in_stride, out_stride[axis]); + input_offset += out_stride[axis]; + } + } else { + std::vector outputs(outs.size()); + for (size_t j = 0; j < outs.size(); ++j) { + outs[j]->mutable_data(ctx.GetPlace()); + outputs[j] = *outs[j]; + } + + auto& dev_ctx = ctx.template device_context(); + paddle::operators::math::ConcatGradFunctor + concat_grad_functor; + concat_grad_functor(dev_ctx, *in, static_cast(axis), outputs); } } }; diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 768106fadf355..751e69b1c89dd 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -20,6 +20,7 @@ if(WITH_GPU) nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context) + nv_library(concat_functor SRCS concat.cc concat.cu DEPS device_context tensor) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -37,6 +38,7 @@ else() cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context) + cc_library(concat_functor SRCS concat.cc DEPS device_context tensor) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) @@ -44,3 +46,4 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) 
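Both the CPU and CUDA functors added below (concat.cc / concat.cu) reduce an N-D concatenation along `axis` to a 2-D copy: `rows` is the product of the dimensions before `axis`, each input contributes numel / rows columns, and the output row width is the sum of those column counts. A standalone C++ sketch of that flattening (hypothetical helper, not the Paddle API):

#include <cstring>
#include <functional>
#include <numeric>
#include <vector>

// Concatenate data[i] (shape dims[i]) along `axis` into `out`, which must
// already hold rows * sum(cols) elements.
void ConcatFlattened(const std::vector<const float*>& data,
                     const std::vector<std::vector<int>>& dims, int axis,
                     float* out) {
  int rows = 1;
  for (int i = 0; i < axis; ++i) rows *= dims[0][i];

  std::vector<int> cols(data.size());
  int out_cols = 0;
  for (size_t i = 0; i < data.size(); ++i) {
    int numel = std::accumulate(dims[i].begin(), dims[i].end(), 1,
                                std::multiplies<int>());
    cols[i] = numel / rows;
    out_cols += cols[i];
  }

  // Row-wise copy: row r of input i lands at column offset sum(cols[0..i)).
  for (int r = 0; r < rows; ++r) {
    int col_offset = 0;
    for (size_t i = 0; i < data.size(); ++i) {
      std::memcpy(out + r * out_cols + col_offset, data[i] + r * cols[i],
                  cols[i] * sizeof(float));
      col_offset += cols[i];
    }
  }
}

For example, concatenating shapes [2, 3, 4] and [2, 4, 4] along axis 1 gives rows = 2, cols = {12, 16}, out_cols = 28, which matches case 2 of concat_test.cc below.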
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) +cc_test(concat_test SRCS concat_test.cc DEPS concat_functor tensor) diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc new file mode 100644 index 0000000000000..b542143419e05 --- /dev/null +++ b/paddle/fluid/operators/math/concat.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/concat.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. + */ +template +class ConcatFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const std::vector& input, const int axis, + framework::Tensor* output) { + // TODO(zcd): Add input data validity checking + int num = input.size(); + + int rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (int i = 0; i < num; ++i) { + int t_cols = input[i].numel() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + auto& cpu_place = boost::get(context.GetPlace()); + + // computation + for (int k = 0; k < out_rows; ++k) { + T* dst_ptr = output->data() + k * out_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + const T* src_prt = input[j].data() + k * col_len; + memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt, + sizeof(T) * col_len); + col_idx += col_len; + } + } + } +}; + +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. 
+ */ +template +class ConcatGradFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, const int axis, + std::vector& outputs) { + // TODO(zcd): Add input data validity checking + int num = outputs.size(); + + int input_rows = 1; + auto dim_0 = outputs[0].dims(); + for (int i = 0; i < axis; ++i) { + input_rows *= dim_0[i]; + } + int input_cols = 0; + + std::vector output_cols(outputs.size()); + for (int i = 0; i < num; ++i) { + int t_cols = outputs[i].numel() / input_rows; + input_cols += t_cols; + output_cols[i] = t_cols; + } + auto& cpu_place = boost::get(context.GetPlace()); + + // computation + for (int k = 0; k < input_rows; ++k) { + const T* src_ptr = input.data() + k * input_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = output_cols[j]; + T* dst_ptr = outputs[j].data() + k * col_len; + memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx, + sizeof(T) * col_len); + col_idx += col_len; + } + } + } +}; + +template class ConcatFunctor; +template class ConcatFunctor; +template class ConcatFunctor; +template class ConcatFunctor; + +template class ConcatGradFunctor; +template class ConcatGradFunctor; +template class ConcatGradFunctor; +template class ConcatGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu new file mode 100644 index 0000000000000..60b266f08fb2d --- /dev/null +++ b/paddle/fluid/operators/math/concat.cu @@ -0,0 +1,280 @@ +/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__device__ T upper_bound(const T* first, T count, T val) { + const T* orig = first; + const T* it = nullptr; + T step = 0; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (!(val < *it)) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + return first - orig; +} + +template +__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, + const int output_rows, const int output_cols, + T* output) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int segment = upper_bound(input_cols, col_size, tid_x) - 1; + + int curr_offset = input_cols[segment]; + int curr_segment = segment; + for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { + T curr_col_offset; + while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + T* input_ptr = inputs[curr_segment]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) + output[tid_y * output_cols + tid_x] = + input_ptr[tid_y * segment_width + local_col]; + } +} + +template +__global__ void KernelConcat(T** inputs, const int input_col, + const int output_rows, const int output_cols, + T* output) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + double inv_input_col = 1.0 / input_col; + for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * inv_input_col; + int in_offset = tid_x - split * input_col; + T* input_ptr = inputs[split]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) { + output[tid_y * output_cols + tid_x] = + input_ptr[tid_y * input_col + in_offset]; + } + } +} + +template +__global__ void KernelConcatGrad(const T* input, const int input_row, + const int input_col, const int* output_cols, + int col_size, T** outputs) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int segment = upper_bound(output_cols, col_size, tid_x) - 1; + int curr_offset = output_cols[segment]; + int curr_segment = segment; + for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { + T curr_col_offset; + while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + T* output_ptr = outputs[curr_segment]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * segment_width + local_col] = + input[tid_y * input_col + tid_x]; + } +} + +template +__global__ void KernelConcatGrad(const T* input, const int input_row, + const int input_col, const int output_cols, + T** outputs) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + double inv_input_col = 1.0 / input_col; + for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * inv_input_col; + int in_offset = tid_x - split * input_col; + T* output_ptr = outputs[split]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * output_cols + in_offset] = + 
input[tid_y * input_col + tid_x]; + } +} + +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. + */ +template +class ConcatFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const std::vector& input, const int axis, + framework::Tensor* output) { + // TODO(zcd): Add input data validity checking + int num = input.size(); + int rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int cols = input[0].numel() / rows; + int out_rows = rows, out_cols = 0; + + framework::Vector inputs_data(num * sizeof(T*) / 2); + framework::Vector inputs_cols(num + 1); + inputs_cols[0] = 0; + T** inputs_ptr = reinterpret_cast(inputs_data.data()); + + bool sameShape = true; + for (int i = 0; i < num; ++i) { + int t_cols = input[i].numel() / rows; + if (sameShape) { + if (t_cols != cols) sameShape = false; + } + out_cols += t_cols; + inputs_cols[i + 1] = out_cols; + inputs_ptr[i] = const_cast(input[i].data()); + } + + T** ins_gpu = + reinterpret_cast(inputs_data.CUDAMutableData(context.GetPlace())); + const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace()); + + // computation + // set the thread block and grid according to CurrentDeviceId + const int kThreadsPerBlock = 1024; + int block_cols = kThreadsPerBlock; + if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32. + block_cols = ((out_cols + 31) >> 5) << 5; + } + int block_rows = kThreadsPerBlock / block_cols; + dim3 block_size = dim3(block_cols, block_rows, 1); + + int max_threads = context.GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + + int grid_cols = + std::min((out_cols + block_cols - 1) / block_cols, max_blocks); + int grid_rows = + std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1)); + dim3 grid_size = dim3(grid_cols, grid_rows, 1); + + if (sameShape) { + KernelConcat<<>>( + ins_gpu, cols, out_rows, out_cols, output->data()); + } else { + KernelConcat<<>>( + ins_gpu, ins_col_gpu, static_cast(inputs_cols.size()), out_rows, + out_cols, output->data()); + } + } +}; + +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. + */ +template +class ConcatGradFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, const int axis, + std::vector& outputs) { + // TODO(zcd): Add input data validity checking + int num = outputs.size(); + int input_row = 1; + auto dim_0 = outputs[0].dims(); + for (int i = 0; i < axis; ++i) { + input_row *= dim_0[i]; + } + + int output_col_0 = outputs[0].numel() / input_row; + int input_col = 0; + bool sameShape = true; + + framework::Vector outputs_data(num * sizeof(T*) / 2); + framework::Vector outputs_cols(num + 1); + outputs_cols[0] = 0; + T** outputs_ptr = reinterpret_cast(outputs_data.data()); + + for (int i = 0; i < num; ++i) { + int t_col = outputs[i].numel() / input_row; + if (sameShape) { + if (t_col != output_col_0) sameShape = false; + } + input_col += t_col; + outputs_cols[i + 1] = input_col; + outputs_ptr[i] = outputs[i].data(); + } + + T** outs_gpu = + reinterpret_cast(outputs_data.CUDAMutableData(context.GetPlace())); + const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace()); + + // computation + const int kThreadsPerBlock = 1024; + int block_cols = kThreadsPerBlock; + if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32. 
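+      // Round block_cols up to the next multiple of 32 so the block always
+      // holds whole warps; the rest of the 1024-thread budget then becomes
+      // block_rows = kThreadsPerBlock / block_cols below.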
+ block_cols = ((input_col + 31) >> 5) << 5; + } + int block_rows = kThreadsPerBlock / block_cols; + dim3 block_size = dim3(block_cols, block_rows, 1); + + int max_threads = context.GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + + int grid_cols = + std::min((input_col + block_cols - 1) / block_cols, max_blocks); + int grid_rows = + std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1)); + dim3 grid_size = dim3(grid_cols, grid_rows, 1); + + if (sameShape) { + KernelConcatGrad<<>>( + input.data(), input_row, input_col, output_col_0, outs_gpu); + } else { + KernelConcatGrad<<>>( + input.data(), input_row, input_col, outs_col_gpu, + static_cast(outputs_cols.size()), outs_gpu); + } + } +}; + +template class ConcatFunctor; +template class ConcatFunctor; +template class ConcatFunctor; +template class ConcatFunctor; + +template class ConcatGradFunctor; +template class ConcatGradFunctor; +template class ConcatGradFunctor; +template class ConcatGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h new file mode 100644 index 0000000000000..22147d79e4b1e --- /dev/null +++ b/paddle/fluid/operators/math/concat.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * \brief Concatenate the input tensors along the dimension axis. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input[0] = [[1,2],[3,4]] + * Input[1] = [[5,6]] + * axis = 0 + * + * Output = [[1,2], + * [3,4], + * [5,6]] + */ +template +class ConcatFunctor { + public: + void operator()(const DeviceContext& context, + const std::vector& input, const int axis, + framework::Tensor* output); +}; + +/* + * \brief Split the input tensors along the dimension axis into outputs. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input = [[1,2], + * [3,4], + * [5,6]] + * axis = 0 + * + * Output[0] = [[1,2],[3,4]] + * Output[1] = [[5,6]] + */ +template +class ConcatGradFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& input, + const int axis, std::vector& outputs); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat_test.cc b/paddle/fluid/operators/math/concat_test.cc new file mode 100644 index 0000000000000..1741af8148bb9 --- /dev/null +++ b/paddle/fluid/operators/math/concat_test.cc @@ -0,0 +1,336 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
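When the inputs have different widths, the CUDA kernels above (the KernelConcat / KernelConcatGrad variants that take a column table) find the source tensor of each output column through a device-side upper_bound over the prefix-sum offsets stored in inputs_cols / outputs_cols. A host-side C++ equivalent of that lookup, for reference (illustrative, not part of the patch):

#include <vector>

// Given prefix-sum column offsets {0, c0, c0+c1, ...}, find which input a
// global column index belongs to and its local column inside that input.
// Mirrors the device-side upper_bound() search in concat.cu.
struct ColumnLocation {
  int input_index;
  int local_col;
};

ColumnLocation LocateColumn(const std::vector<int>& col_offsets,
                            int global_col) {
  // The first offset strictly greater than global_col, minus one, is the
  // segment that contains the column.
  int lo = 0, hi = static_cast<int>(col_offsets.size());
  while (lo < hi) {
    int mid = lo + (hi - lo) / 2;
    if (col_offsets[mid] <= global_col) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  int segment = lo - 1;
  return {segment, global_col - col_offsets[segment]};
}

With offsets {0, 12, 28}, column 15 resolves to input 1, local column 3.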
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/concat.h" +#include +#include +#include "paddle/fluid/framework/tensor_util.h" + +using namespace paddle::framework; +using namespace paddle::platform; + +template +void testConcat() { + Tensor input_a_cpu; + Tensor input_b_cpu; + Tensor out_cpu; + Tensor input_a; + Tensor input_b; + Tensor out; + + DeviceContext* context = new DeviceContext(Place()); + // DeviceContext context(Place()); + + /** + * cast1: + * inputs: + * t_a.shape: [2, 3, 4] + * t_b.shape: [3, 3, 4] + * output: + * out.shape: [5, 3, 4] + */ + auto dim_a = make_ddim({2, 3, 4}); + auto dim_b = make_ddim({3, 3, 4}); + auto dim_out = make_ddim({5, 3, 4}); + + input_a.mutable_data(dim_a, Place()); + input_b.mutable_data(dim_b, Place()); + out.mutable_data(dim_out, Place()); + + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.mutable_data(dim_a, CPUPlace()); + input_b_cpu.mutable_data(dim_b, CPUPlace()); + out_cpu.mutable_data(dim_out, CPUPlace()); + } + + int* a_ptr; + int* b_ptr; + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data(); + b_ptr = input_b_cpu.data(); + } else { + a_ptr = input_a.data(); + b_ptr = input_b.data(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 3 * 3 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + std::vector input; + input.push_back(input_a); + input.push_back(input_b); + + paddle::operators::math::ConcatFunctor concat_functor; + concat_functor(*context, input, 0, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + int* out_ptr; + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data(); + } else { + out_ptr = out.data(); + } + + int cols = 2 * 3 * 4; + int idx_a = 0, idx_b = 0; + for (int j = 0; j < 5 * 3 * 4; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[j], a_ptr[idx_a]); + ++idx_a; + } + } + // + /** + * cast2: + * inputs: + * t_a.shape: [2, 3, 4] + * t_b.shape: [2, 4, 4] + * output: + * out.shape: [2, 7, 4] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 4, 4}); + dim_out = make_ddim({2, 7, 4}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data(); + b_ptr = input_b_cpu.data(); + } else { + a_ptr = input_a.data(); + b_ptr = input_b.data(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 4 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + 
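+  // The inputs were resized and refilled for case 2 above; rebuild the input
+  // list and concatenate along axis 1 this time.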
input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 1, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data(); + } else { + out_ptr = out.data(); + } + + cols = 3 * 4; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 28; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } + + /** + * cast3: + * inputs: + * t_a.shape: [2, 3, 5] + * t_b.shape: [2, 3, 4] + * output: + * out.shape: [2, 3, 9] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 3, 5}); + dim_out = make_ddim({2, 3, 9}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data(); + b_ptr = input_b_cpu.data(); + } else { + a_ptr = input_a.data(); + b_ptr = input_b.data(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 3 * 5; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 2, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data(); + } else { + out_ptr = out.data(); + } + + // check the data + cols = 4; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 6; ++i) { + for (int j = 0; j < 9; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } + + /** + * cast4: + * inputs: + * axis = 1 + * t_a.shape: [2, 3, 4] + * t_b.shape: [2, 3, 4] + * output: + * out.shape: [2, 6, 4] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 3, 4}); + dim_out = make_ddim({2, 6, 4}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data(); + b_ptr = input_b_cpu.data(); + } else { + a_ptr = input_a.data(); + b_ptr = input_b.data(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 3 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 1, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, 
&out_cpu); + out_ptr = out_cpu.data(); + } else { + out_ptr = out.data(); + } + + // check the data + cols = 12; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 24; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } +} + +TEST(math, concat) { + testConcat(); +#ifdef PADDLE_WITH_CUDA + testConcat(); +#endif +} diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 7da6e04d0a8b8..583a3e740e10e 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -121,6 +121,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { SetDeviceId(place_.device); + multi_process = GetCUDAMultiProcessors(place_.device); + max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); @@ -154,6 +156,10 @@ void CUDADeviceContext::Wait() const { PADDLE_ENFORCE(cudaGetLastError()); } +int CUDADeviceContext::GetMaxPhysicalThreadCount() const { + return multi_process * max_threads_per_mp; +} + Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { return eigen_device_.get(); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index a294ba5101528..918243ccfe359 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -79,6 +79,9 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return place in the device context. */ Place GetPlace() const override; + /*! \brief Return the max physical thread count in the device context */ + int GetMaxPhysicalThreadCount() const; + /*! \brief Return eigen device in the device context. */ Eigen::GpuDevice* eigen_device() const; @@ -100,6 +103,9 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream_; cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; + + int multi_process; + int max_threads_per_mp; }; template <> diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 05e1eae853e20..da4041bad0d82 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -33,6 +33,26 @@ int GetCUDADeviceCount() { return count; } +int GetCUDAMultiProcessors(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); + int count; + PADDLE_ENFORCE( + cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id), + "cudaDeviceGetAttribute failed in " + "paddle::platform::GetCUDAMultiProcessors"); + return count; +} + +int GetCUDAMaxThreadsPerMultiProcessor(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); + int count; + PADDLE_ENFORCE(cudaDeviceGetAttribute( + &count, cudaDevAttrMaxThreadsPerMultiProcessor, id), + "cudaDeviceGetAttribute failed in " + "paddle::platform::GetCUDAMaxThreadsPerMultiProcessor"); + return count; +} + int GetCurrentDeviceId() { int device_id; PADDLE_ENFORCE( diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index 3d4883d8078da..c38ccf0f2ade1 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse = //! 
Get the total number of GPU devices in system.
 int GetCUDADeviceCount();
 
+//! Get the number of multiprocessors of the i-th GPU.
+int GetCUDAMultiProcessors(int i);
+
+//! Get the maximum number of threads per multiprocessor of the i-th GPU.
+int GetCUDAMaxThreadsPerMultiProcessor(int i);
+
 //! Get the current GPU device id in system.
 int GetCurrentDeviceId();
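
The new GPU queries exist so the concat kernels can size their grids to roughly one wave of resident threads: GetMaxPhysicalThreadCount() multiplies GetCUDAMultiProcessors() by GetCUDAMaxThreadsPerMultiProcessor(), and concat.cu divides that by its 1024-thread block to cap the grid. A condensed host-side sketch of that launch-shape arithmetic as used in concat.cu (written as a free-standing helper for illustration):

#include <algorithm>

// Reproduce the block/grid sizing from concat.cu: 1024 threads per block,
// block columns rounded up to a multiple of 32 (warp size), and the grid
// capped at max_physical_threads / 1024 blocks. Assumes out_cols > 0.
struct LaunchShape {
  int block_cols, block_rows, grid_cols, grid_rows;
};

LaunchShape ConcatLaunchShape(int out_rows, int out_cols,
                              int max_physical_threads) {
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (out_cols < kThreadsPerBlock) {
    block_cols = ((out_cols + 31) >> 5) << 5;  // round up to warp size
  }
  int block_rows = kThreadsPerBlock / block_cols;

  int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1);
  int grid_cols =
      std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
  return {block_cols, block_rows, grid_cols, grid_rows};
}

For a device with 20 multiprocessors and 2048 threads each, max_physical_threads = 40960, so at most 40 blocks are spread across grid_cols and grid_rows.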