Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speed Compiling]: Reduce NVCC compiling files. #5573

Merged
merged 12 commits into from
Nov 16, 2017
19 changes: 15 additions & 4 deletions paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ function(op_library TARGET)
set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE)
set(cc_srcs)
set(cu_srcs)
set(cu_cc_srcs)
set(op_common_deps operator op_registry math_function)
set(options "")
set(oneValueArgs "")
Expand All @@ -22,13 +23,18 @@ function(op_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
list(APPEND cc_srcs ${TARGET}.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
list(APPEND cu_srcs ${TARGET}.cu)
endif()
else()
foreach(src ${op_library_SRCS})
if (${src} MATCHES ".*\\.cu$")
list(APPEND cu_srcs ${src})
elseif(${src} MATCHES ".*\\.cu.cc$")
list(APPEND cu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cc$")
list(APPEND cc_srcs ${src})
else()
Expand All @@ -43,7 +49,7 @@ function(op_library TARGET)
endif()

if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
else()
cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
Expand Down Expand Up @@ -140,7 +146,9 @@ function(op_library TARGET)

# pybind USE_CPU_ONLY_OP
list(LENGTH cu_srcs cu_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0)
list(LENGTH cu_cc_srcs cu_cc_srcs_len)

if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
Expand All @@ -160,11 +168,12 @@ set(DEPS_OPS
recurrent_op
dynamic_recurrent_op
softmax_with_cross_entropy_op
softmax_op
sequence_softmax_op
sum_op
pool_op
pool_with_index_op
conv_op
lstm_op
conv_transpose_op
nccl_op
sequence_conv_op
Expand All @@ -179,6 +188,8 @@ set(DEPS_OPS
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(softmax_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax)
op_library(conv_op DEPS vol2col)
op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
Expand Down Expand Up @@ -219,6 +230,6 @@ cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc
rnn/recurrent_op_utils.cc
DEPS dynamic_recurrent_op)
if(WITH_GPU)
nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context)
cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
math::set_constant(ctx.device_context(), input_grad, 0);

PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, output_grad_data,
Expand All @@ -214,9 +212,8 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*filter_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
math::set_constant(ctx.device_context(), filter_grad, 0);

// Gradient with respect to the filter
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_constant_batch_size_like_op.h"
#include "paddle/framework/op_registry.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h"
#include "paddle/framework/op_registry.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/gru_op.h"

namespace ops = paddle::operators;
Expand Down
54 changes: 24 additions & 30 deletions paddle/operators/gru_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
Expand All @@ -57,19 +53,15 @@ class GRUKernel : public framework::OpKernel<T> {

bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);
auto& dev_ctx = context.device_context();
to_batch(dev_ctx, *input, *batch_gate, true, is_reverse);

int frame_size = hidden_dims[1];
int batch_size = hidden_dims[0];
auto g = EigenMatrix<T>::From(*batch_gate);
auto place = context.GetEigenDevice<Place>();
if (bias) {
auto b = EigenMatrix<T>::From(*bias);
g.device(place) = g +
b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
.broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
math::RowwiseAdd<Place, T> add_bias;
add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
}

int frame_size = hidden_dims[1];
math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight =
Expand All @@ -89,15 +81,15 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.gateValue = gate_t.data<T>();
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<Place, T>::compute(
context.device_context(), gru_value, frame_size, cur_batch_size,
dev_ctx, gru_value, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
gru_value.prevOutValue = gru_value.outputValue;
}

math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(context.device_context(), *batch_hidden, *hidden);
to_seq(dev_ctx, *batch_hidden, *hidden);
}

void Compute(const framework::ExecutionContext& context) const override {
Expand Down Expand Up @@ -138,15 +130,14 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_reset_hidden_prev_grad.mutable_data<T>(hidden_dims,
context.GetPlace());
math::SetConstant<Place, T> zero;
zero(context.device_context(), &batch_hidden_grad, static_cast<T>(0.0));
zero(context.device_context(), &batch_gate_grad, static_cast<T>(0.0));
zero(context.device_context(), &batch_reset_hidden_prev_grad,
static_cast<T>(0.0));
auto& dev_ctx = context.device_context();
zero(dev_ctx, &batch_hidden_grad, static_cast<T>(0.0));
zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0));
zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));

bool is_reverse = context.Attr<bool>("is_reverse");
batch_hidden_grad.set_lod(batch_hidden->lod());
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
is_reverse);
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);

math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);
Expand All @@ -157,7 +148,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
if (weight_grad) {
gru_grad.gateWeightGrad =
weight_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), weight_grad, static_cast<T>(0.0));
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
gru_grad.stateWeightGrad =
weight_grad->data<T>() + 2 * frame_size * frame_size;
} else {
Expand Down Expand Up @@ -188,7 +179,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
gru_value.prevOutValue = const_cast<T*>(h0_data);
if (h0_grad) {
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), h0_grad, static_cast<T>(0.0));
zero(dev_ctx, h0_grad, static_cast<T>(0.0));
gru_grad.prevOutGrad = h0_grad_data;
} else {
gru_grad.prevOutGrad = nullptr;
Expand All @@ -202,23 +193,26 @@ class GRUGradKernel : public framework::OpKernel<T> {
}

math::GRUUnitGradFunctor<Place, T>::compute(
context.device_context(), gru_value, gru_grad, frame_size,
cur_batch_size,
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_gate_grad.set_lod(batch_gate->lod());
to_seq(context.device_context(), batch_gate_grad, *input_grad);
to_seq(dev_ctx, batch_gate_grad, *input_grad);
}
if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace());
auto d_b = EigenMatrix<T>::From(*bias_grad);
auto d_g = EigenMatrix<T>::From(batch_gate_grad);
auto place = context.GetEigenDevice<Place>();
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
int m = static_cast<int>(batch_gate_grad.dims()[0]);
int n = static_cast<int>(batch_gate_grad.dims()[1]);
Tensor ones;
ones.mutable_data<T>({m}, context.GetPlace());
math::SetConstant<Place, T> set;
set(dev_ctx, &ones, static_cast<T>(1));
math::gemv<Place, T>(dev_ctx, true, m, n, 1., batch_gate_grad.data<T>(),
ones.data<T>(), 0., bias_grad->data<T>());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is used at many places, should we make a similar function as RowwiseAdd to do this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Add a ColwiseSum functor in paddle/operators/math/math_function.h to compute the gradient of bias.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/lstm_op.h"

namespace ops = paddle::operators;
Expand Down
19 changes: 5 additions & 14 deletions paddle/operators/lstm_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
const framework::Tensor& src, const size_t* index,
Expand Down Expand Up @@ -65,16 +61,11 @@ class LSTMKernel : public framework::OpKernel<T> {
framework::DDim dims({in_dims[0], frame_size});

if (bias) {
Eigen::array<int, 2> extents({{1, 4 * frame_size}});
Eigen::array<int, 2> offsets({{0, 0}});
auto b = EigenMatrix<T>::From(*bias);
auto gate = EigenMatrix<T>::From(*batch_gate);
gate.device(ctx.GetEigenDevice<Place>()) =
gate +
b.slice(offsets, extents)
.reshape(Eigen::array<int, 2>({{1, frame_size * 4}}))
.broadcast(
Eigen::array<int, 2>({{static_cast<int>(in_dims[0]), 1}}));
Tensor b = *bias;
b.Resize({bias->numel(), 1});
Tensor gate_bias = b.Slice(0, 4 * frame_size);
math::RowwiseAdd<Place, T> add_bias;
add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
}

math::LstmMetaValue<T> lstm_value;
Expand Down
16 changes: 8 additions & 8 deletions paddle/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
add_subdirectory(detail)

if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function)
nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(softmax SRCS softmax.cc DEPS device_context)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
Expand Down
Loading