Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… feature/insert_reduce_to_parallel_exe
  • Loading branch information
chengduoZH committed May 2, 2018
2 parents ed052f1 + ff99d94 commit c0a3746
Show file tree
Hide file tree
Showing 34 changed files with 498 additions and 118 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ RUN apt-get update && \
automake locales clang-format swig doxygen cmake \
liblapack-dev liblapacke-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools libtool && \
net-tools libtool ccache && \
apt-get clean -y

# Install Go and glide
Expand Down
24 changes: 18 additions & 6 deletions paddle/cuda/include/hl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef HL_BASE_H_
#define HL_BASE_H_
#pragma once

#include <cstddef>

Expand Down Expand Up @@ -207,8 +206,8 @@ typedef struct {

#ifdef __NVCC__

#include "cuda_runtime.h"
#include "hl_cuda.h"
#include <cuda_runtime.h>
#include "paddle/cuda/include/hl_cuda.h"
#include "paddle/utils/Logging.h"

extern __thread bool g_sync_flag;
Expand All @@ -228,6 +227,19 @@ extern __thread cudaStream_t default_stream;
<< "CUDA error: " << hl_get_device_error_string((size_t)err); \
}

#endif /* __NVCC__ */
// __shfl has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
// Compatibility wrapper exposing the four-argument __shfl_sync call shape on
// the branch guarded by the surrounding `#if CUDA_VERSION < 9000`.
//
// The first argument (the participation mask) is accepted only so call sites
// can use one spelling on every toolkit; it is ignored here. `val`,
// `src_line` and `width` are forwarded unchanged to the legacy __shfl
// intrinsic, and its result is returned as-is.
template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned /* mask: unused */,
                                         T val,
                                         int src_line,
                                         int width) {
  const T received = __shfl(val, src_line, width);
  return received;
}

#endif /* HL_BASE_H_ */
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

#endif // __NVCC__
14 changes: 9 additions & 5 deletions paddle/cuda/src/hl_cuda_lstm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,15 @@ void hl_lstm_parallel_forward(real *gateValue,
}

__device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
int addr = idx % 32;
const int warp_size = 32;
int addr = idx % warp_size;
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, addr < warp_size);
#pragma unroll
for (int k = 1; k < 32; k++) {
// rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32);
addr = __shfl_sync(addr, (idx + 1) % 32, 32);
a[k] = __shfl_sync(a[k], addr, 32);
addr = __shfl_sync(mask, addr, (idx + 1) % 32, 32);
a[k] = __shfl_sync(mask, a[k], addr, 32);
}

#pragma unroll
Expand All @@ -360,10 +363,11 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
}

addr = (32 - idx) % 32;
CREATE_SHFL_MASK(mask, idx % 32 < warp_size);
#pragma unroll
for (int k = 0; k < 32; k++) {
a[k] = __shfl_sync(a[k], addr, 32);
addr = __shfl_sync(addr, (idx + 31) % 32, 32);
a[k] = __shfl_sync(mask, a[k], addr, 32);
addr = __shfl_sync(mask, addr, (idx + 31) % 32, 32);
}
}

Expand Down
5 changes: 4 additions & 1 deletion paddle/cuda/src/hl_top_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,16 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
if (--beamSize == 0) break;
__syncthreads();

unsigned mask = 0u;
// CREATE_SHFL_MASK(mask, tid < len);

if (tid == maxId[0]) {
if (beam < maxLength) {
shTopK[tid] = topK[beam];
}
}
if (maxId[0] / 32 == warp) {
if (__shfl_sync(beam, (maxId[0]) % 32, 32) == maxLength) break;
if (__shfl_sync(mask, beam, (maxId[0]) % 32, 32) == maxLength) break;
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, bool skip_scale_loss,
platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
bool use_nccl_allreduce)
: loss_var_name_(loss_var_name),
places_(places),
Expand All @@ -50,7 +50,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, bool skip_scale_loss,
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
bool use_nccl_allreduce)
: loss_var_name_(loss_var_name),
places_(places),
Expand All @@ -60,7 +60,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
skip_scale_loss_ = skip_scale_loss;
use_default_grad_scale_ = use_default_grad_scale;
}

void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
Expand Down Expand Up @@ -141,8 +141,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if skip_scale_loss_
if (!skip_scale_loss_) {
// user can customize loss@grad if not use_default_grad_scale_
if (use_default_grad_scale_) {
CreateScaleLossGradOp(&result);
}
is_forwarding = false;
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/details/multi_devices_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs,
bool skip_scale_loss, bool use_nccl_allreduce);
bool use_default_grad_scale, bool use_nccl_allreduce);
#else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
bool skip_scale_loss, bool use_nccl_allreduce);
bool use_default_grad_scale, bool use_nccl_allreduce);
#endif

std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
Expand All @@ -61,7 +61,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
platform::NCCLContextMap *nccl_ctxs_;
#endif
bool use_nccl_allreduce_;
bool skip_scale_loss_;
bool use_default_grad_scale_;

bool IsScaleLossOp(const OpDesc &op) const;

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/lod_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ TEST(LoDTensor, RecordIO) {
std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/parallel_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
bool customize_scale_loss, bool use_nccl_allreduce)
bool use_default_grad_scale, bool use_nccl_allreduce)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;

Expand Down Expand Up @@ -93,11 +93,11 @@ ParallelExecutor::ParallelExecutor(
#ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder(
member_->places_, loss_var_name, params, member_->local_scopes_,
member_->nccl_ctxs_.get(), customize_scale_loss, use_nccl_allreduce);
member_->nccl_ctxs_.get(), use_default_grad_scale, use_nccl_allreduce);
#else
details::MultiDevSSAGraphBuilder builder(
member_->places_, loss_var_name, params, member_->local_scopes_,
customize_scale_loss, use_nccl_allreduce);
use_default_grad_scale, use_nccl_allreduce);
#endif
auto graph = builder.Build(main_program);

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/parallel_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ParallelExecutor {
const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes,
bool allow_op_delay, bool customize_scale_loss,
bool allow_op_delay, bool use_default_grad_scale,
bool use_nccl_allreduce);

~ParallelExecutor();
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/framework/selected_rows.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ bool SelectedRows::HasKey(int64_t key) const {
: true;
}

std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
framework::Tensor* value) const {
std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
std::vector<int64_t> keys, framework::Tensor* value) const {
PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized.");
std::vector<int64_t> non_keys;
std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table "
Expand All @@ -133,15 +133,15 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
for (size_t i = 0; i < keys.size(); ++i) {
int64_t index = Index(keys[i]);
if (index == -1) {
non_keys.push_back(keys[i]);
non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
}
return non_keys;
return non_keys_pair;
}

bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/framework/selected_rows.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
Expand Down Expand Up @@ -78,10 +79,11 @@ class SelectedRows {
/*
* @brief Get values by the key list; rows whose keys exist in the table are
* copied into the value tensor.
*
* @return a list of keys which does not exists in table
* @return a list of pairs, each holding a key that does not exist in the
* table and the index of that key in the value
*/
std::vector<int64_t> Get(std::vector<int64_t> keys,
framework::Tensor* tensor) const;
std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys,
framework::Tensor* value) const;

/*
* @brief Set a key-value pair into the table.
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/framework/selected_rows_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
}

TEST_F(SelectedRowsTester, Table) {
TEST_F(SelectedRowsTester, SparseTable) {
platform::CPUPlace cpu;
SelectedRows table;
// initialize a sparse table
Expand Down Expand Up @@ -87,11 +87,11 @@ TEST_F(SelectedRowsTester, Table) {
framework::Tensor get_value;
get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu);
std::vector<int64_t> keys({non_key, key});
auto non_keys = table.Get(keys, &get_value);
auto non_key_pairs = table.Get(keys, &get_value);

ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10));
ASSERT_EQ(non_keys.size(), static_cast<size_t>(1));
ASSERT_EQ(non_keys[0], non_key);
ASSERT_EQ(non_key_pairs.size(), static_cast<size_t>(1));
ASSERT_EQ(non_key_pairs[0].first, non_key);
}

} // namespace framework
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class TensorRTEngine : public EngineBase {
// Initialize the inference network, so that TensorRT layers can add to this
// network.
void InitNetwork() {
infer_builder_.reset(createInferBuilder(logger_));
infer_builder_.reset(createInferBuilder(&logger_));
infer_network_.reset(infer_builder_->createNetwork());
}
// After finishing adding ops, freeze this network and create the execution
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/inference/tensorrt/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ const int kDataTypeSize[] = {
// The following two APIs are implemented in TensorRT's header file and cannot
// be loaded from the dynamic library, so we create our own implementations
// that directly call the corresponding methods in the dynamic library.
static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IBuilder*>(
dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
}
static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IRuntime*>(
dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
}

// A logger for creating the TensorRT infer builder.
Expand Down Expand Up @@ -80,7 +80,7 @@ class NaiveLogger : public nvinfer1::ILogger {
return *x;
}

virtual ~NaiveLogger() override {}
~NaiveLogger() override {}
};

} // namespace tensorrt
Expand Down

0 comments on commit c0a3746

Please sign in to comment.