Skip to content

Commit

Permalink
Merge pull request #8669 from chengduoZH/feature/concat_op
Browse files Browse the repository at this point in the history
Refine concat_op
  • Loading branch information
chengduo committed Mar 7, 2018
2 parents 8c71ada + c3864ea commit 84aea8a
Show file tree
Hide file tree
Showing 12 changed files with 883 additions and 50 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat_functor)

list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/operators/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
ops::ConcatOpGrad, false)
REGISTER_OP_CPU_KERNEL(concat,
ops::ConcatKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>)
REGISTER_OP_CPU_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>)
84 changes: 38 additions & 46 deletions paddle/fluid/operators/concat_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/operators/strided_memcpy.h"

namespace paddle {
Expand All @@ -27,54 +28,30 @@ class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);

auto out_stride = framework::stride_numel(out->dims());

size_t output_offset = 0;

// If axis >=1, copy to out immediately need to call many times
// of cuda memcpy. Copy the input to cpu and do the stride copy,
// then copy to gpu output.

if (platform::is_gpu_place(place) && axis >= 1) {
platform::CPUPlace copy_place;
auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
framework::Tensor cpu_out;
cpu_out.Resize(out->dims());
cpu_out.mutable_data<T>(copy_place);
auto& dev_ctx = ctx.device_context();
std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
for (auto* in : ins) {
std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
cpu_ins.emplace_back(std::move(cpu_in));
}
// TODO(dzhwinter): overlap copy and compute stream
// https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
dev_ctx.Wait();

for (auto& in : cpu_ins) {
auto& cpu_in = *in.get();
auto in_stride = framework::stride_numel(cpu_in.dims());

StridedNumelCopyWithAxis<T>(
cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
cpu_in.data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
framework::TensorCopy(cpu_out, place, dev_ctx, out);
} else {
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && ins.size() < 10) {
size_t output_offset = 0;
for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
} else {
std::vector<framework::Tensor> inputs(ins.size());
for (size_t j = 0; j < ins.size(); ++j) {
inputs[j] = *ins[j];
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
}
}
};
Expand All @@ -86,16 +63,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
size_t input_offset = 0;
auto in_stride = framework::stride_numel(in->dims());

for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride, out_stride[axis]);
input_offset += out_stride[axis];
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
size_t input_offset = 0;
auto in_stride = framework::stride_numel(in->dims());

for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride, out_stride[axis]);
input_offset += out_stride[axis];
}
} else {
std::vector<framework::Tensor> outputs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs[j] = *outs[j];
}

auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
concat_grad_functor;
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
}
}
};
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ if(WITH_GPU)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
nv_library(concat_functor SRCS concat.cc concat.cu DEPS device_context tensor)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
Expand All @@ -37,10 +38,12 @@ else()
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
cc_library(concat_functor SRCS concat.cc DEPS device_context tensor)
endif()

cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
cc_test(concat_test SRCS concat_test.cc DEPS concat_functor tensor)
119 changes: 119 additions & 0 deletions paddle/fluid/operators/math/concat.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/concat.h"

namespace paddle {
namespace operators {
namespace math {

/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();

int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;

std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());

// computation
for (int k = 0; k < out_rows; ++k) {
T* dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
const T* src_prt = input[j].data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};

/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template <typename T>
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// TODO(zcd): Add input data validity checking
int num = outputs.size();

int input_rows = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
input_rows *= dim_0[i];
}
int input_cols = 0;

std::vector<int64_t> output_cols(outputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = outputs[i].numel() / input_rows;
input_cols += t_cols;
output_cols[i] = t_cols;
}
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());

// computation
for (int k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = output_cols[j];
T* dst_ptr = outputs[j].data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};

template class ConcatFunctor<platform::CPUDeviceContext, int>;
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatFunctor<platform::CPUDeviceContext, float>;
template class ConcatFunctor<platform::CPUDeviceContext, double>;

template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
Loading

0 comments on commit 84aea8a

Please sign in to comment.