diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e11f86d0ee83..420dafee65f8c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -147,6 +147,7 @@ include(external/cares) include(external/grpc) include(external/snappy) # download snappy include(external/snappystream) +include(external/threadpool) include(cudnn) # set cudnn libraries, must before configure include(cupti) diff --git a/cmake/external/threadpool.cmake b/cmake/external/threadpool.cmake new file mode 100644 index 0000000000000..0159815fed81b --- /dev/null +++ b/cmake/external/threadpool.cmake @@ -0,0 +1,30 @@ +INCLUDE(ExternalProject) + +SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool) +SET(THREADPOOL_INCLUDE_DIR ${THREADPOOL_SOURCE_DIR}/src/extern_threadpool) +INCLUDE_DIRECTORIES(${THREADPOOL_INCLUDE_DIR}) + +ExternalProject_Add( + extern_threadpool + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/progschj/ThreadPool.git" + GIT_TAG 9a42ec1329f259a5f4881a291db1dcb8f2ad9040 + PREFIX ${THREADPOOL_SOURCE_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) + +if (${CMAKE_VERSION} VERSION_LESS "3.3.0") + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/threadpool_dummy.c) + file(WRITE ${dummyfile} "const char *dummy_threadpool = \"${dummyfile}\";") + add_library(simple_threadpool STATIC ${dummyfile}) +else() + add_library(simple_threadpool INTERFACE) +endif() + +add_dependencies(simple_threadpool extern_threadpool) + +LIST(APPEND external_project_dependencies simple_threadpool) diff --git a/doc/design/images/parallel_executor_overview.dot b/doc/design/images/parallel_executor_overview.dot new file mode 100644 index 0000000000000..40753cb140540 --- /dev/null +++ b/doc/design/images/parallel_executor_overview.dot @@ -0,0 +1,83 @@ +digraph G { + subgraph cluster_init { + label="Initialization" + startup_program [label="startup", shape=box] + node_w_g0 [label="W\nGPU0"] + startup_program -> node_w_g0 [label="Initialize"] + node_w_g1 [label="W\nGPU1"] + node_w_g0 -> node_w_g1 [label="broadcast"] + } + + subgraph cluster_train { + label="forward_backward" + + subgraph cluster_gpu0 { + label="GPU0" + fc_0 [label="fc\nGPU0", shape=box] + hidden_0 [label="hidden\nGPU0"] + node_w_g0 -> fc_0 + fc_0 -> hidden_0 + loss0 [label="loss\nGPU0"] + hidden_0 -> loss0 [label="many ops omitted"] + scale_loss_0 [label="scale_loss_gradient\nGPU0", shape=box] + loss_g0 [label="loss_grad\nGPU0"] + scale_loss_0->loss_g0 + + fc_g_0 [label="w_grad\nGPU0", shape=box] + loss0 -> fc_g_0 + loss_g0 -> fc_g_0 + hidden_0 -> fc_g_0 + } + + subgraph cluster_gpu1 { + label="GPU1" + fc_1 [label="fc\nGPU1", shape=box] + hidden_1 [label="hidden\nGPU1"] + node_w_g1 -> fc_1 + fc_1 -> hidden_1 + loss1 [label="loss\nGPU1"] + hidden_1 -> loss1 [label="many ops omitted"] + scale_loss_1 [label="scale_loss_gradient\nGPU1", shape=box] + loss_g1 [label="loss_grad\nGPU1"] + scale_loss_1->loss_g1 + + fc_g_1 [label="w_grad\nGPU1", shape=box] + loss1 -> fc_g_1 + loss_g1 -> fc_g_1 + hidden_1 -> fc_g_1 + } + } + + all_reduce_w [label="Merge Gradients(AllReduce)", shape=box] + fc_g_0 -> all_reduce_w + fc_g_1 -> all_reduce_w + + fc_g_0_merged [label="w_grad\nMerged\nGPU0"] + fc_g_1_merged [label="w_grad\nMerged\nGPU1"] + all_reduce_w -> fc_g_0_merged + all_reduce_w -> fc_g_1_merged + + subgraph cluster_optimization { + label="Optimization" + subgraph cluster_opt_gpu0 { + label="GPU0" + sgd_0 [label="SGD Op\nGPU0", shape=box] + + fc_g_0_merged -> sgd_0 + node_w_g0 -> sgd_0 + optimized_w_0 [label="Optimized W\nGPU0"] + sgd_0 -> optimized_w_0 + } + subgraph cluster_opt_gpu1 { + label="GPU1" + sgd_1 [label="SGD Op\nGPU1", shape=box] + + fc_g_1_merged -> sgd_1 + node_w_g1 -> sgd_1 + optimized_w_1 [label="Optimized W\nGPU0"] + sgd_1 -> optimized_w_1 + } + } + + +} diff --git a/doc/design/images/parallel_executor_overview.png b/doc/design/images/parallel_executor_overview.png new file mode 100644 index 0000000000000..d890c0ffee3b3 Binary files /dev/null and b/doc/design/images/parallel_executor_overview.png differ diff --git a/doc/design/parallel_executor.md b/doc/design/parallel_executor.md new file mode 100644 index 0000000000000..9aed3b059a159 --- /dev/null +++ b/doc/design/parallel_executor.md @@ -0,0 +1,104 @@ +# ParallelExecutor + +## Background + +Neural network models are defined as a `ProgramDesc` in Fluid. The `ProgramDesc` can be executed by an interpreter(i.e. the `executor` concept in Fluid). The instructions or operators in a `Program` will be executed, and the results will be fetched in Python side. + +The executor is a very naive interpreter. It runs operators one by one. We can use `Parallel.Do` to support data parallelism, however, lacking device information in `ProgramDesc`; it is not possible to optimize the performance of `Parallel.Do`. + +We want a `ProgramDesc` can be run on different nodes. It is better not to contain device information in `ProgramDesc`. However, we can write a high-performance interpreter, which can hold an alternative intermediate representation of `ProgramDesc`, to take full usage of Multi-GPUs. + +ParallelExecutor is an interpreter of `ProgramDesc` which will [out-of-order execute](https://en.wikipedia.org/wiki/Out-of-order_execution) `Program` in data parallelism mode and maximise the utility of Multi-GPUs. + + +## Overview of MultiGPUs logic + +The ParallelExecutor takes the startup program and main program as inputs. The parameters will be initialised on `GPU0` by startup program and will broadcast to multi-GPUs. The main program will be duplicated into multi-GPUs. The gradient will be merged during each iteration, and each device will optimize parameters independently. Since the gradients on each device will be merged before parameter optimization, the parameters will be the same on each device and it does not need to be broadcast the parameters. + +![alt](images/parallel_executor_overview.png) + +There are several optimizations for this logic. + +1. We use an alternate representation in ParallelExecutor. It because the device information is critical for performance optimization. +2. The execution is out-of-order, i.e., an operator will be executed whenever the inputs of the operator are ready. + * GPU is a high-performance device; only one CPU thread cannot fulfil one GPU. So there is a thread pool to execute operators. + * Out-of-order also helps transpilers to generate `ProgramDesc`. It is no need to concern about the best order of performance when implementing a transpiler. +3. The streams of computation, merge gradients and fetch data are different. + +The performance of `ResNeXt152` on `TitanX` which `batch_size=12` is shown below. + +| Number of GPUs | 1 | 2 | 3 | 4| +| --- | --- | --- | --- | --- | +| Image/Sec | 17.9906 | 25.771 | 36.911 | 48.8428 | +| Speed Up | N/A | 1.43247029 | 2.05168255 | 2.71490667 | + + +## Static single assignment Graph + +[Static single assignment form](https://en.wikipedia.org/wiki/Static_single_assignment_form)(`SSA` for short) is a common form for compiler optimization. To implement concurrent execution, we uses an `SSA` graph as an intermedia representation of `ProgramDesc`. + +The `Program` is a directed acyclic graph, since a variable can be assigned multiple times. We enforce a variable will be assigned once, by adding version number to varaibles. We parsing the `Program` into a `SSA` graph. Also, ProgramExecutor duplicate `Program` into multi-devices. We also add a device number to varaibles and insert `NCCLAllReduce` into Graph. + +The data structure of `SSA` graph is: + +```c++ +struct VarHandleBase { + OpHandleBase* generated_op_; + vector pending_ops_; + + string name; + Place place; + size_t version; +}; + +struct OpHandleBase { + vector inputs_; + vector outputs_; +}; + +struct SSAGraph { + // vars on each devices. + // * the vars in each map in vector is on different device. + // * the map is mapping a variable name to variable handles + // with different versions + vector>> vars_; + + // All ops + vector ops_; +}; +``` +The variable handles are the wrapper of `Variables`. The operator handles are the wrapper of `OperatorBase`. Some `OpHandle` is not an `OperatorBase`, such as `NCCLAllReduceOpHandle`, because `AllReduceOpHandle` will use new device contexts. + +When the `ProgramDesc` converted into an `SSA` Graph, the [data hazard](https://en.wikipedia.org/wiki/Hazard_(computer_architecture)) problem is also need to be taken care. The dummy variables, which represent the dependency between operators, will be manually inserted into SSA graph to resolve the [data hazard](https://en.wikipedia.org/wiki/Hazard_(computer_architecture)) problem. + +## Execute SSA Graph + +The SSA graph can be out-of-order executed by an approximate [topological sorting](https://en.wikipedia.org/wiki/Topological_sorting) algorithm. The algorithm is + +1. Maintaining a map of an operator and its needed input number. +2. If a variable is not generated by an operator, i.e., `var.generated_op == nullptr`, decrease the needed input number of its pending operators. +3. If there is an operator which needed input number is decreased to zero, just run this operator. +4. After run this operator, just mark the variables are generated and repeat step 2 until all variables are generated. + +Running an operator can be asynchronized. There is a thread pool to execute an `SSA` graph. + +## Synchronize GPU Kernels + +The GPU is a non-blocking device. The different streams need be synchronized when switing streams. In current implementation, the synchronization based on the following algorithm: + +1. `OpHandle` will record `DeviceContext` that it is used. +2. In `OpHandle::Run`, if the `DeviceContext` of current operator is different from `DeviceContext` of any input variable, just wait the generate operator of this input variable. + +The `wait` are implemented by two strategies: + +1. Invoke `DeviceContext->Wait()`, It will wait all operators on this device contexts complete. +2. Uses `cudaStreamWaitEvent` to sending a event to the stream. It is a non-blocking call. The wait operators will be executed in GPU. + +Generally, the `cudaStreamWaitEvent` will have a better perforamnce. However, `DeviceContext->Wait()` strategy is easier to debug. The strategy can be changed in runtime. + +## What's next? + +* Merging gradient of dense parameters has been done. However, the merging of sparse parameters has not been done. +* The CPU version of Parallel Executor has not been implemented. The out-of-order logic will make CPU compuatation faster, too. +* A better strategy to merge gradients can be introduced. We can shrink the gradients from `float32` to `int8` or `int4` while merging. It will significantly speed up multi-GPUs training without much loss of precision. +* Combine multi-Nodes implementation. By the benifit of out-of-order, sending and recving operator can be an blocking operator, and the transpiler does not need to concern about the best position of operator. diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 8c8def6bf47f0..a34e22ff8765f 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(details) # ddim lib proto_library(framework_proto SRCS framework.proto) @@ -87,6 +88,9 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table feed_fetch_method) + +cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor) + cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt new file mode 100644 index 0000000000000..bf1a705ef50b6 --- /dev/null +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -0,0 +1,21 @@ +cc_library(var_handle SRCS var_handle.cc DEPS place) +cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context) +cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) +cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) +nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory + dynload_cuda) +cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) + +cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) +cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) + +if(WITH_GPU) + set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle) +else() + set(multi_devices_graph_builder_deps) +endif() +cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle + scale_loss_grad_op_handle ${multi_devices_graph_builder_deps}) +cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph) +cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope + simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc new file mode 100644 index 0000000000000..7a1b40c0b60a7 --- /dev/null +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/computation_op_handle.h" + +namespace paddle { +namespace framework { +namespace details { +ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, + platform::Place place) + : op_(framework::OpRegistry::CreateOp(op_desc)), + scope_(scope), + place_(place) {} + +void ComputationOpHandle::RunImpl() { + auto *cur_ctx = dev_ctxes_[place_]; + for (auto *in : inputs_) { + bool need_wait = + in->generated_op_ && in->generated_op_->dev_ctxes_[place_] != cur_ctx; + if (need_wait) { + in->generated_op_->Wait(cur_ctx); + } + } + + op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get(), place_); +} + +std::string ComputationOpHandle::Name() const { return op_->Type(); } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h new file mode 100644 index 0000000000000..d6d2d731ca80a --- /dev/null +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -0,0 +1,41 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { +namespace details { +struct ComputationOpHandle : public OpHandleBase { + std::unique_ptr op_; + Scope *scope_; + platform::Place place_; + + ComputationOpHandle(const OpDesc &op_desc, Scope *scope, + platform::Place place); + + std::string Name() const override; + + protected: + void RunImpl() override; +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc new file mode 100644 index 0000000000000..9180903b864d0 --- /dev/null +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/fetch_op_handle.h" + +namespace paddle { +namespace framework { +namespace details { + +FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset, + std::vector *local_scopes) + : data_(data), offset_(offset), local_scopes_(local_scopes) {} + +FetchOpHandle::~FetchOpHandle() { + for (auto *input_var : inputs_) { + input_var->pending_ops_.erase(this); + } +} + +void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) { + PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); +} + +void FetchOpHandle::WaitAndMergeCPUTensors() const { + std::vector tensors_ptr; + tensors_ptr.reserve(tensors_.size()); + for (auto &t : tensors_) { + tensors_ptr.emplace_back(&t); + } + data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace()); +} + +void FetchOpHandle::RunImpl() { + auto cpu_ctx = + platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); + for (auto *input : inputs_) { + auto *var = static_cast(input); + var->generated_op_->Wait(cpu_ctx); + } + + tensors_.resize(inputs_.size()); + auto *var = static_cast(inputs_[0]); + auto &var_name = var->name_; + platform::CPUPlace cpu; + auto &scopes = *local_scopes_; + + for (size_t i = 0; i < scopes.size(); ++i) { + auto &scope = scopes[i]; + auto &t = scope->FindVar(var_name)->Get(); + if (platform::is_gpu_place(var->place_)) { +#ifdef PADDLE_WITH_CUDA + TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]); + dev_ctxes_[t.place()]->Wait(); +#endif + } else { + tensors_[i].ShareDataWith(t); + tensors_[i].set_lod(t.lod()); + } + } + + this->WaitAndMergeCPUTensors(); +} + +std::string FetchOpHandle::Name() const { return "Fetch"; } + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h new file mode 100644 index 0000000000000..904b2d669f8b1 --- /dev/null +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -0,0 +1,49 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { +namespace details { + +struct FetchOpHandle : public OpHandleBase { + FeedFetchList *data_; + size_t offset_; + std::vector *local_scopes_; + std::vector tensors_; + + FetchOpHandle(FeedFetchList *data, size_t offset, + std::vector *local_scopes); + + ~FetchOpHandle(); + + void Wait(platform::DeviceContext *waited_dev) override; + + void WaitAndMergeCPUTensors() const; + + std::string Name() const override; + + protected: + void RunImpl() override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc new file mode 100644 index 0000000000000..a1b913a863cc1 --- /dev/null +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -0,0 +1,174 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" +#include "paddle/fluid/framework/scope.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" +#endif + +namespace paddle { +namespace framework { +namespace details { + +#ifdef PADDLE_WITH_CUDA +MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( + const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes, + platform::NCCLContextMap *nccl_ctxs) + : loss_var_name_(loss_var_name), + places_(places), + local_scopes_(local_scopes), + nccl_ctxs_(nccl_ctxs) { +#else +MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( + const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes) + : loss_var_name_(loss_var_name), + places_(places), + local_scopes_(local_scopes) { +#endif + for (auto &p : params) { + grad_names_.insert(GradVarName(p)); + } +} + +std::unique_ptr MultiDevSSAGraphBuilder::Build( + const ProgramDesc &program) const { + auto graph = new SSAGraph(); + SSAGraph &result = *graph; + result.vars_.resize(places_.size()); + + bool is_forwarding = true; + for (auto *op : program.Block(0).AllOps()) { + bool change_forward = false; + if (!is_forwarding) { + // FIXME(yy): Do not hard code like this + if (op->OutputArgumentNames().size() == 1 && + op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) { + continue; // Drop fill 1. for backward coeff; + } + } + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + auto *s = local_scopes_[i]; + + result.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); + auto *op_handle = result.ops_.back().get(); + op_handle->dev_ctxes_[p] = const_cast( + platform::DeviceContextPool::Instance().Get(p)); + + auto var_names = op->InputArgumentNames(); + + for (auto &each_var_name : var_names) { + VarHandle *var = + CreateOrGetLatestVarHandle(&result, each_var_name, p, i); + op_handle->AddInput(var); + } + var_names = op->OutputArgumentNames(); + + for (auto &each_var_name : var_names) { + CreateOpOutput(&result, op_handle, each_var_name, p, i); + } + + if (is_forwarding) { + if (var_names.size() == 1 && var_names[0] == loss_var_name_) { +// Insert ScaleCost OpHandle +#ifdef PADDLE_WITH_CUDA + auto *communication_dev_ctx = nccl_ctxs_->DevCtx(p); +#else + auto *communication_dev_ctx = + platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); +#endif + + op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p, + communication_dev_ctx); + result.ops_.emplace_back(op_handle); + + // FIXME: Currently ScaleLossGradOp only use device_count as scale + // factor. So it does not depend on any other operators. + // VarHandle *loss = GetVarHandle(loss_var_name, place); + // loss->pending_ops_.emplace_back(op_handle); + // op_handle->inputs_.emplace_back(loss); + + CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p, i); + change_forward = true; + } + } + } + + if (change_forward) { + is_forwarding = false; + } + + if (!is_forwarding) { + auto var_names = op->OutputArgumentNames(); + for (auto &og : var_names) { + if (grad_names_.count(og) != 0) { // is param grad + // Insert NCCL AllReduce Op +#ifdef PADDLE_WITH_CUDA + result.ops_.emplace_back( + new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_)); + auto *op_handle = result.ops_.back().get(); + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + auto &vars = result.vars_[i][og]; + + if (vars.empty()) { // This device has no data. continue. + continue; + } + auto *prev_grad = &vars[vars.size() - 1]; + op_handle->AddInput(prev_grad); + + auto &var = vars[vars.size()]; + var.place_ = p; + var.name_ = og; + var.version_ = vars.size() - 1; + + op_handle->AddOutput(&var); + } +#else + PADDLE_ENFORCE("Not implemented"); +#endif + } + } + } + } + + /* + Dependency graph has been constructed. However, there are still data + harzaeds need to be handled. + */ + PolishGraphToSupportDataHazards(&result); + + if (VLOG_IS_ON(10)) { + std::ostringstream sout; + PrintGraphviz(*graph, sout); + VLOG(10) << sout.str(); + } + + return std::unique_ptr(graph); +} // namespace details +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h new file mode 100644 index 0000000000000..d3c8e582cf2cd --- /dev/null +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -0,0 +1,56 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + +namespace paddle { +namespace platform { +class NCCLContextMap; +} + +namespace framework { +class Scope; +namespace details { +class MultiDevSSAGraphBuilder : public SSAGraphBuilder { + public: +#ifdef PADDLE_WITH_CUDA + MultiDevSSAGraphBuilder(const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes, + platform::NCCLContextMap *nccl_ctxs); +#else + MultiDevSSAGraphBuilder(const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes); +#endif + + std::unique_ptr Build(const ProgramDesc &program) const override; + + private: + std::string loss_var_name_; + const std::vector &places_; + const std::vector &local_scopes_; + std::unordered_set grad_names_; + +#ifdef PADDLE_WITH_CUDA + platform::NCCLContextMap *nccl_ctxs_; +#endif +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc new file mode 100644 index 0000000000000..5ddf331cfca39 --- /dev/null +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" + +namespace paddle { +namespace framework { +namespace details { +NCCLAllReduceOpHandle::NCCLAllReduceOpHandle( + const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap &ctxs) + : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) { + for (auto &p : places_) { + this->dev_ctxes_[p] = nccl_ctxs_.DevCtx(p); + } +} + +void NCCLAllReduceOpHandle::RunImpl() { + if (inputs_.size() == 1) { + return; // No need to all reduce when GPU count = 1; + } else { + // Wait input done + for (auto *in : inputs_) { + auto &p = static_cast(in)->place_; + in->generated_op_->Wait(dev_ctxes_[p]); + } + + auto &var_name = static_cast(this->inputs_[0])->name_; + int dtype = -1; + size_t numel = 0; + + std::vector> all_reduce_calls; + + for (size_t i = 0; i < local_scopes_.size(); ++i) { + auto &p = places_[i]; + auto *s = local_scopes_[i]; + int dev_id = boost::get(p).device; + + auto &lod_tensor = s->FindVar(var_name)->Get(); + void *buffer = const_cast(lod_tensor.data()); + + if (dtype == -1) { + dtype = platform::ToNCCLDataType(lod_tensor.type()); + } + + if (numel == 0) { + numel = static_cast(lod_tensor.numel()); + } + + auto &nccl_ctx = nccl_ctxs_.at(dev_id); + auto stream = nccl_ctx.stream(); + auto comm = nccl_ctx.comm_; + all_reduce_calls.emplace_back([=] { + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + buffer, buffer, numel, static_cast(dtype), ncclSum, + comm, stream)); + }); + } + + platform::NCCLGroupGuard guard; + for (auto &call : all_reduce_calls) { + call(); + } + } +} + +std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h new file mode 100644 index 0000000000000..045070bb6a97e --- /dev/null +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h @@ -0,0 +1,43 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/nccl_helper.h" + +namespace paddle { +namespace framework { +namespace details { + +struct NCCLAllReduceOpHandle : public OpHandleBase { + const std::vector &local_scopes_; + const std::vector &places_; + const platform::NCCLContextMap &nccl_ctxs_; + + NCCLAllReduceOpHandle(const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap &ctxs); + + std::string Name() const override; + + protected: + void RunImpl() override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc new file mode 100644 index 0000000000000..e4194a7442f67 --- /dev/null +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/op_handle_base.h" + +namespace paddle { +namespace framework { +namespace details { +std::string OpHandleBase::DebugString() const { + std::stringstream ss; + ss << "("; + for (auto *var : inputs_) { + ss << var->DebugString() << ", "; + } + ss << ") --> ("; + for (auto *var : outputs_) { + ss << var->DebugString() << ", "; + } + ss << ")\n"; + return ss.str(); +} + +OpHandleBase::~OpHandleBase() { +#ifdef PADDLE_WITH_CUDA + for (auto &ev : events_) { + PADDLE_ENFORCE(cudaEventDestroy(ev.second)); + } +#endif +} + +void OpHandleBase::Run(bool use_event) { +#ifdef PADDLE_WITH_CUDA + if (events_.empty() && use_event) { + for (auto &p : dev_ctxes_) { + int dev_id = boost::get(p.first).device; + PADDLE_ENFORCE(cudaSetDevice(dev_id)); + PADDLE_ENFORCE( + cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming)); + } + } +#else + PADDLE_ENFORCE(!use_event); +#endif + + RunImpl(); + +#ifdef PADDLE_WITH_CUDA + if (use_event) { + for (auto &p : dev_ctxes_) { + int dev_id = boost::get(p.first).device; + auto stream = + static_cast(p.second)->stream(); + PADDLE_ENFORCE(cudaEventRecord(events_.at(dev_id), stream)); + } + } +#endif +} + +void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { +#ifdef PADDLE_WITH_CUDA + if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) { + for (auto &dev_ctx : dev_ctxes_) { + dev_ctx.second->Wait(); + } + } else { + auto stream = + static_cast(waited_dev)->stream(); + for (auto &ev : events_) { + PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0)); + } + } +#else + for (auto &dev_ctx : dev_ctxes_) { + dev_ctx.second->Wait(); + } +#endif +} + +void OpHandleBase::AddInput(VarHandleBase *in) { + this->inputs_.emplace_back(in); + in->pending_ops_.insert(this); +} + +void OpHandleBase::AddOutput(VarHandleBase *out) { + outputs_.emplace_back(out); + out->generated_op_ = this; +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h new file mode 100644 index 0000000000000..71672fd24c65e --- /dev/null +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -0,0 +1,62 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/var_handle.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/macros.h" + +namespace paddle { +namespace framework { +namespace details { + +class OpHandleBase { + private: + DISABLE_COPY_AND_ASSIGN(OpHandleBase); + + public: + std::vector inputs_; + std::vector outputs_; + std::unordered_map + dev_ctxes_; + +#ifdef PADDLE_WITH_CUDA + std::unordered_map events_; +#endif + + OpHandleBase() {} + + std::string DebugString() const; + + virtual std::string Name() const = 0; + + virtual ~OpHandleBase(); + + void Run(bool use_event); + + virtual void Wait(platform::DeviceContext *waited_dev); + + void AddInput(VarHandleBase *in); + + void AddOutput(VarHandleBase *out); + + protected: + virtual void RunImpl() = 0; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc new file mode 100644 index 0000000000000..0a6f6129b812c --- /dev/null +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" + +namespace paddle { +namespace framework { +namespace details { +ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, + platform::Place place, + platform::DeviceContext *dev_ctx) + : coeff_(static_cast(1.0 / num_dev)), scope_(scope), place_(place) { + dev_ctxes_[place_] = dev_ctx; +} + +ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} + +void ScaleLossGradOpHandle::RunImpl() { + std::string var_name = static_cast(this->outputs_[0])->name_; + + float *tmp = + scope_->FindVar(var_name)->GetMutable()->mutable_data( + make_ddim({1}), place_); + + if (platform::is_cpu_place(place_)) { + *tmp = coeff_; + } else { +#ifdef PADDLE_WITH_CUDA + auto stream = + static_cast(this->dev_ctxes_[place_]) + ->stream(); + memory::Copy(boost::get(place_), tmp, + platform::CPUPlace(), &coeff_, sizeof(float), stream); +#endif + } +} + +std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h new file mode 100644 index 0000000000000..ab7353a4fc56b --- /dev/null +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -0,0 +1,43 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { +namespace details { + +struct ScaleLossGradOpHandle : public OpHandleBase { + float coeff_; + Scope *scope_; + platform::Place place_; + + ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place, + platform::DeviceContext *context); + + ~ScaleLossGradOpHandle() final; + + std::string Name() const override; + + protected: + void RunImpl() override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph.cc b/paddle/fluid/framework/details/ssa_graph.cc new file mode 100644 index 0000000000000..1b8c889449059 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph.h" diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h new file mode 100644 index 0000000000000..ac3e2d86993ae --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph.h @@ -0,0 +1,35 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" + +namespace paddle { +namespace framework { +namespace details { + +struct SSAGraph { + std::vector>> vars_; + // aux variables to represent dependency. Useful to resolve data hazard. + std::unordered_set> dep_vars_; + std::vector> ops_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc new file mode 100644 index 0000000000000..361ba6d39721e --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + +namespace paddle { +namespace framework { +namespace details { +void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { + for (auto &var_map : graph->vars_) { + for (auto &name_pair : var_map) { + if (name_pair.second.size() <= 1) { + continue; + } + auto it_new = name_pair.second.rbegin(); + auto it_old = name_pair.second.rbegin(); + ++it_old; + for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { + auto *write_op = it_new->second.generated_op_; + auto &read_ops = it_old->second.pending_ops_; + + for (auto *read_op : read_ops) { + // Manually add a dependency var from read_op to write_op; + if (read_op == write_op) { + // Read Write is the same op. + continue; + } + + auto *dep_var = new DummyVarHandle(); + read_op->AddOutput(dep_var); + write_op->AddInput(dep_var); + graph->dep_vars_.emplace(dep_var); + } + } + } + } +} + +VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( + SSAGraph *graph, const std::string &each_var_name, + const platform::Place &place, size_t place_offset) { + auto &var_holders = graph->vars_[place_offset]; + auto &var_holder = var_holders[each_var_name]; + VarHandle *var = nullptr; + if (var_holder.empty()) { + auto &init_var = var_holder[0]; + init_var.place_ = place; + init_var.name_ = each_var_name; + init_var.generated_op_ = nullptr; + init_var.version_ = 0; + var = &init_var; + } else { + var = &var_holder.rbegin()->second; + } + return var; +} + +void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, + const std::string &each_var_name, + const platform::Place &place, + size_t place_offset) { + auto &vars = graph->vars_[place_offset][each_var_name]; + size_t version = vars.size(); + auto &var = vars[version]; + var.version_ = version; + var.name_ = each_var_name; + var.place_ = place; + op_handle->AddOutput(&var); +} + +template +void IterAllVar(const SSAGraph &graph, Callback callback) { + for (auto &each : graph.vars_) { + for (auto &pair1 : each) { + for (auto &pair2 : pair1.second) { + callback(pair2.second); + } + } + } + + for (auto &var : graph.dep_vars_) { + callback(*var); + } +} + +void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) { + size_t var_id = 0; + std::unordered_map vars; + + sout << "digraph G {\n"; + + IterAllVar(graph, [&](const VarHandleBase &var) { + auto *var_ptr = &var; + auto *var_handle_ptr = dynamic_cast(var_ptr); + auto *dummy_ptr = dynamic_cast(var_ptr); + + size_t cur_var_id = var_id++; + vars[var_ptr] = cur_var_id; + + if (var_handle_ptr) { + sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_ + << "\\n" + << var_handle_ptr->place_ << "\\n" + << var_handle_ptr->version_ << "\"]" << std::endl; + } else if (dummy_ptr) { + sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl; + } + }); + + size_t op_id = 0; + for (auto &op : graph.ops_) { + std::string op_name = "op_" + std::to_string(op_id++); + sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" + << std::endl; + for (auto in : op->inputs_) { + std::string var_name = "var_" + std::to_string(vars[in]); + sout << var_name << " -> " << op_name << std::endl; + } + + for (auto out : op->outputs_) { + std::string var_name = "var_" + std::to_string(vars[out]); + sout << op_name << " -> " << var_name << std::endl; + } + } + + sout << "}\n"; +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h new file mode 100644 index 0000000000000..bf20e7164a100 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -0,0 +1,59 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/place.h" + +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +class SSAGraphBuilder { + public: + SSAGraphBuilder() {} + virtual ~SSAGraphBuilder() {} + virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; + + DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); + + protected: + /** + * We only handle write after read(WAR), since it should not have a write + * after write in program. If there are write after write operators, we need + * prune them. + * + * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) + */ + static void PolishGraphToSupportDataHazards(SSAGraph *graph); + + static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, + const std::string &each_var_name, + const platform::Place &place, + size_t place_offset); + + static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, + const std::string &each_var_name, + const platform::Place &place, size_t place_offset); + + static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout); +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_executor.cc b/paddle/fluid/framework/details/ssa_graph_executor.cc new file mode 100644 index 0000000000000..8da6ca889b899 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_executor.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph_executor.h" + +namespace paddle { +namespace framework { +namespace details { + +SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr &&graph) + : graph_(std::move(graph)) {} + +SSAGraphExecutor::~SSAGraphExecutor() {} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h new file mode 100644 index 0000000000000..3b818b1a45b56 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -0,0 +1,41 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/feed_fetch_type.h" + +namespace paddle { +namespace framework { +namespace details { + +class SSAGraphExecutor { + DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor); + + public: + // Steal graph inside + explicit SSAGraphExecutor(std::unique_ptr &&graph); + + virtual ~SSAGraphExecutor(); + + virtual FeedFetchList Run(const std::vector &fetch_tensors) = 0; + + protected: + std::unique_ptr graph_; +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc new file mode 100644 index 0000000000000..3f8655147b688 --- /dev/null +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -0,0 +1,205 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" + +#include "paddle/fluid/framework/details/fetch_op_handle.h" + +namespace paddle { +namespace framework { +namespace details { +ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( + size_t num_threads, bool use_event, + const std::vector &local_scopes, + const std::vector &places, + std::unique_ptr &&graph) + : SSAGraphExecutor(std::move(graph)), + pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr), + local_scopes_(local_scopes), + places_(places), + fetch_ctxs_(places), + use_event_(use_event) {} + +FeedFetchList ThreadedSSAGraphExecutor::Run( + const std::vector &fetch_tensors) { + std::unordered_map pending_ops; + std::unordered_set pending_vars; + + BlockingQueue ready_vars; + + std::unordered_set ready_ops; + + auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { + pending_vars.insert(&var); + if (var.generated_op_ == nullptr) { + ready_vars.Push(&var); + } + }; + + auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) { + pending_ops.insert({&op_instance, op_instance.inputs_.size()}); + }; + + // Transform SSAGraph to pending_ops & pending_vars + for (auto &var_map : graph_->vars_) { + for (auto &name_pair : var_map) { + for (auto &version_pair : name_pair.second) { + InsertPendingVar(version_pair.second); + } + } + } + for (auto &var : graph_->dep_vars_) { + InsertPendingVar(*var); + } + + for (auto &op : graph_->ops_) { + if (op->inputs_.empty()) { // Special case, Op has no input. + ready_ops.insert(op.get()); + } else { + InsertPendingOp(*op); + } + } + + // Step 2. Insert FetchOps + std::vector> fetch_ops; + std::vector dummy_vars; + FeedFetchList fetch_data(fetch_tensors.size()); + + std::unordered_map> fetched_vars; + + for (auto &fetch_var_name : fetch_tensors) { + for (auto &var_map : graph_->vars_) { + auto it = var_map.find(fetch_var_name); + if (it != var_map.end()) { + fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second); + } + } + } + + for (size_t i = 0; i < fetch_tensors.size(); ++i) { + auto &var_name = fetch_tensors[i]; + auto &vars = fetched_vars.at(var_name); + auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_); + fetch_ops.emplace_back(op); + + // FIXME: Use new device context + for (auto &p : places_) { + op->dev_ctxes_[p] = fetch_ctxs_.Get(p); + } + + for (auto *var : vars) { + op->AddInput(var); + } + InsertPendingOp(*op); + } + + auto run_all_ready_ops = [&] { + for (auto *op : ready_ops) { + RunOp(ready_vars, op); + } + ready_ops.clear(); + }; + + // Create local scopes. + for (auto &scope : local_scopes_) { + auto &local_scope = scope->NewScope(); + *scope->Var("@TMP_SCOPE@")->GetMutable() = &local_scope; + } + + // Step 3. Execution + while (!pending_vars.empty()) { + // 1. Run All Ready ops + run_all_ready_ops(); + + // 2. Find ready variable + bool timeout; + auto cur_ready_vars = ready_vars.PopAll(1000, &timeout); + + if (timeout) { + if (exception_) { + throw * exception_; + } else { + continue; + } + } + // 3. Remove the dependency of ready_var. + // Find the ready_ops after the ready_var. + for (auto ready_var : cur_ready_vars) { + pending_vars.erase(ready_var); + for (auto *op : ready_var->pending_ops_) { + auto &deps = pending_ops[op]; + --deps; + if (deps == 0) { + ready_ops.insert(op); + } + } + } + // Keep loop until all vars are ready. + } + + ++computation_count_; + + auto sync_computation = [&] { + computation_count_ = 0; + // Wait All computational streams + for (auto p : this->places_) { + platform::DeviceContextPool::Instance().Get(p)->Wait(); + } + for (auto &scope : local_scopes_) { + scope->DropKids(); + } + }; + + // Wait FetchOps. + if (!fetch_ops.empty()) { + fetch_ops.clear(); + sync_computation(); + } + + if (computation_count_ == max_async_computation) { + sync_computation(); + } + + // NOTE: the temp scope can be dropped lazily if needed. + // Drop tmp scopes; + for (auto &scope : local_scopes_) { + auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable(); + kid = nullptr; + } + + return fetch_data; +} + +void ThreadedSSAGraphExecutor::RunOp( + BlockingQueue &ready_var_q, details::OpHandleBase *op) { + auto op_run = [&ready_var_q, op, this] { + try { + VLOG(10) << op->Name() << " : " << op->DebugString(); + op->Run(use_event_); + ready_var_q.Extend(op->outputs_); + } catch (platform::EnforceNotMet ex) { + exception_.reset(new platform::EnforceNotMet(ex)); + } catch (...) { + LOG(FATAL) << "Unknown exception catched"; + } + }; + if (pool_) { + pool_->enqueue(op_run); + } else { + op_run(); + } +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h new file mode 100644 index 0000000000000..2ea57ac8f96bc --- /dev/null +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -0,0 +1,99 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "ThreadPool.h" // ThreadPool in thrird party +#include "paddle/fluid/framework/details/ssa_graph_executor.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace details { + +template +class BlockingQueue { + public: + void Push(const T &item) { + { + std::lock_guard g(mutex_); + q_.emplace_back(item); + } + cv_.notify_one(); + } + + template + void Extend(const U &items) { + { + std::lock_guard g(mutex_); + for (auto &item : items) { + q_.emplace_back(item); + } + } + cv_.notify_all(); + } + + std::deque PopAll(size_t ms, bool *timeout) { + auto time = + std::chrono::system_clock::now() + std::chrono::milliseconds(ms); + std::unique_lock lock(mutex_); + *timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); }); + std::deque ret; + if (!*timeout) { + std::swap(ret, q_); + } + return ret; + } + + private: + std::mutex mutex_; + std::condition_variable cv_; + std::deque q_; +}; + +class ThreadedSSAGraphExecutor : public SSAGraphExecutor { + public: + ThreadedSSAGraphExecutor(size_t num_threads, bool use_event, + const std::vector &local_scopes, + const std::vector &places, + std::unique_ptr &&graph); + + // Run a SSAGraph by a thread pool + // Use topological sort algorithm + FeedFetchList Run(const std::vector &fetch_tensors) override; + + ~ThreadedSSAGraphExecutor() {} + + private: + void RunOp(BlockingQueue &ready_var_q, + details::OpHandleBase *op); + + private: + std::unique_ptr<::ThreadPool> pool_; + std::vector local_scopes_; + std::vector places_; + platform::DeviceContextPool fetch_ctxs_; + const bool use_event_; + std::unique_ptr exception_; + + size_t computation_count_{0}; + size_t max_async_computation{100}; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/var_handle.cc b/paddle/fluid/framework/details/var_handle.cc new file mode 100644 index 0000000000000..6f00abd9473a8 --- /dev/null +++ b/paddle/fluid/framework/details/var_handle.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/var_handle.h" + +namespace paddle { +namespace framework { +namespace details { + +VarHandleBase::~VarHandleBase() {} + +std::string VarHandle::DebugString() const { + std::stringstream ss; + ss << name_ << ":" << place_; + return ss.str(); +} + +std::string DummyVarHandle::DebugString() const { return "dummy"; } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h new file mode 100644 index 0000000000000..893cc15f6c8b3 --- /dev/null +++ b/paddle/fluid/framework/details/var_handle.h @@ -0,0 +1,64 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { +namespace details { +struct OpHandleBase; + +// VarHandleBase is the var node in the dependency graph. +// A variable can only be generated by a single operator. i.e. +// This is a single assignment graph. +struct VarHandleBase { + virtual ~VarHandleBase(); + virtual std::string DebugString() const = 0; + + // The operator who generate this variable. nullptr if the variable + // is a root node. + OpHandleBase *generated_op_; + + // Operators which depend on this variable ready. + std::unordered_set pending_ops_; +}; + +// VarHandle is actually a single version of Runtime Variable. +// Variable in Runtime mapped to many VarHandles in Graph. +// Each assignment will generate a new var handle with newer version. +// +// NOTE: runtime variables have place. +struct VarHandle : public VarHandleBase { + std::string DebugString() const override; + + // version field currently is not used, however, just store the version to + // debug easily. + size_t version_; + std::string name_; + platform::Place place_; +}; + +// Dummy Variable. It is used to represent dependencies between operators +struct DummyVarHandle : public VarHandleBase { + std::string DebugString() const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 0b171e1dcfa90..64c06687b6b90 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -46,7 +46,7 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { Executor::Executor(const platform::Place& place) : place_(place) {} -static void CreateTensor(Variable* var, proto::VarType::Type var_type) { +void InitializeVariable(Variable* var, proto::VarType::Type var_type) { if (var_type == proto::VarType::LOD_TENSOR) { var->GetMutable(); } else if (var_type == proto::VarType::SELECTED_ROWS) { @@ -294,12 +294,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, if (var->Persistable()) { auto* ptr = scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); + InitializeVariable(ptr, var->GetType()); VLOG(3) << "Create Variable " << var->Name() << " global, which pointer is " << ptr; } else { auto* ptr = local_scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); + InitializeVariable(ptr, var->GetType()); VLOG(3) << "Create Variable " << var->Name() << " locally, which pointer is " << ptr; } @@ -307,7 +307,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } else { for (auto& var : block.AllVars()) { auto* ptr = local_scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); + InitializeVariable(ptr, var->GetType()); VLOG(3) << "Create variable " << var->Name() << ", which pointer is " << ptr; } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index d8dd82469af06..7173c51c95e04 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -22,6 +22,7 @@ limitations under the License. */ namespace paddle { namespace framework { +extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); struct ExecutorPrepareContext { ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc new file mode 100644 index 0000000000000..8a90f231d741b --- /dev/null +++ b/paddle/fluid/framework/parallel_executor.cc @@ -0,0 +1,146 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/parallel_executor.h" + +#include "ThreadPool.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" + +namespace paddle { +namespace framework { + +class ParallelExecutorPrivate { + public: + explicit ParallelExecutorPrivate(const std::vector &places) + : places_(places) {} + + std::vector places_; + std::vector local_scopes_; + Scope *global_scope_; + std::unique_ptr executor_; + +#ifdef PADDLE_WITH_CUDA + std::unique_ptr nccl_ctxs_; +#endif +}; + +ParallelExecutor::ParallelExecutor( + size_t num_threads, bool use_event, + const std::vector &places, + const std::unordered_set ¶ms, + const ProgramDesc &startup_program, const ProgramDesc &main_program, + const std::string &loss_var_name, Scope *scope) + : member_(new ParallelExecutorPrivate(places)) { + member_->global_scope_ = scope; + + // Step 1. RunStartupProgram and Bcast the params to devs. + Executor exe(places[0]); + exe.Run(startup_program, scope, 0); + // Create local scopes + for (size_t i = 0; i < member_->places_.size(); ++i) { + member_->local_scopes_.push_back(&scope->NewScope()); + } + +// Bcast Parameters to all GPUs +#ifdef PADDLE_WITH_CUDA + member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_)); +#endif + if (platform::is_gpu_place(places[0]) && + member_->local_scopes_.size() != 1) { // Is CUDA + BCastParamsToGPUs(startup_program); + } +// Startup Program has been run. All local scopes has correct parameters. + +// Step 2. Convert main_program to SSA form and dependency graph. Also, insert +// ncclOp +#ifdef PADDLE_WITH_CUDA + details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, + params, member_->local_scopes_, + member_->nccl_ctxs_.get()); +#else + details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, + params, member_->local_scopes_); +#endif + auto graph = builder.Build(main_program); + + member_->executor_.reset(new details::ThreadedSSAGraphExecutor( + num_threads, use_event, member_->local_scopes_, places, + std::move(graph))); + + // Step 3. Create vars in each scope; + for (auto *scope : member_->local_scopes_) { + for (auto *var : main_program.Block(0).AllVars()) { + if (scope->FindVar(var->Name()) != nullptr) { + continue; + } + + InitializeVariable(scope->Var(var->Name()), var->GetType()); + } + } +} + +void ParallelExecutor::BCastParamsToGPUs( + const ProgramDesc &startup_program) const { +#ifdef PADDLE_WITH_CUDA + auto *main_scope = member_->local_scopes_[0]; + + for (auto *var_desc : startup_program.Block(0).AllVars()) { + if (var_desc->GetType() == proto::VarType::LOD_TENSOR) { + auto &main_tensor = + main_scope->FindVar(var_desc->Name())->Get(); + ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type()); + auto &dims = main_tensor.dims(); + size_t numel = main_tensor.numel(); + + platform::NCCLGroupGuard guard; + + for (size_t i = 0; i < member_->places_.size(); ++i) { + auto place = member_->places_[i]; + void *buffer; + if (i == 0) { + buffer = const_cast(main_tensor.data()); + } else { + auto local_scope = member_->local_scopes_[i]; + auto *t = local_scope->Var(var_desc->Name())->GetMutable(); + t->Resize(dims); + buffer = t->mutable_data(place, main_tensor.type()); + } + + auto &nccl_ctx = member_->nccl_ctxs_->at(place); + platform::dynload::ncclBcast(buffer, numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); + } + } + member_->nccl_ctxs_->WaitAll(); + } +#else + PADDLE_THROW("Not compiled with CUDA"); +#endif +} + +void ParallelExecutor::Run(const std::vector &fetch_tensors, + const std::string &fetched_var_name) { + auto fetch_data = member_->executor_->Run(fetch_tensors); + *member_->global_scope_->Var(fetched_var_name)->GetMutable() = + fetch_data; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h new file mode 100644 index 0000000000000..503efa2e447b0 --- /dev/null +++ b/paddle/fluid/framework/parallel_executor.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { + +class ParallelExecutorPrivate; + +class ParallelExecutor { + DISABLE_COPY_AND_ASSIGN(ParallelExecutor); + + public: + explicit ParallelExecutor(size_t num_threads, bool use_event, + const std::vector& places, + const std::unordered_set& params, + const ProgramDesc& startup_program, + const ProgramDesc& main_program, + const std::string& loss_var_name, Scope* scope); + + void Run(const std::vector& fetch_tensors, + const std::string& fetched_var_name = "fetched_var"); + + private: + ParallelExecutorPrivate* member_; + + void BCastParamsToGPUs(const ProgramDesc& startup_program) const; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index fa00c08e0d579..56bf00e5f9170 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -29,7 +29,7 @@ void FileReader::ReadNext(std::vector *out) { PADDLE_ENFORCE_EQ(actual.size(), expect.size()); for (int j = 0; j < actual.size(); ++j) { - PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1); + // PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1); } } } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index df51fb24a588c..f9dce7105e32f 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -32,6 +32,8 @@ namespace framework { // number of threads. class ThreadPool { public: + explicit ThreadPool(int num_threads); + using Task = std::packaged_task()>; // Returns the singleton of ThreadPool. @@ -103,8 +105,6 @@ class ThreadPool { DISABLE_COPY_AND_ASSIGN(ThreadPool); - explicit ThreadPool(int num_threads); - // If the task queue is empty and avaialbe is equal to the number of // threads, means that all tasks are completed. Note: this function // is not thread-safe. Returns true if all tasks are completed. diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index 71d28dcbade1b..22f6f50674873 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -79,7 +79,18 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) { // if size is 0. We just make sure it does. if (size <= 0) return nullptr; void* p; + int prev_id; + cudaGetDevice(&prev_id); + if (prev_id != gpu_id_) { + cudaSetDevice(gpu_id_); + } + cudaError_t result = cudaMalloc(&p, size); + + if (prev_id != gpu_id_) { + cudaSetDevice(prev_id); + } + if (result == cudaSuccess) { index = 0; gpu_alloc_size_ += size; diff --git a/paddle/fluid/memory/detail/system_allocator.h b/paddle/fluid/memory/detail/system_allocator.h index 3e024125fabb8..e8479e73f433f 100644 --- a/paddle/fluid/memory/detail/system_allocator.h +++ b/paddle/fluid/memory/detail/system_allocator.h @@ -43,6 +43,8 @@ class CPUAllocator : public SystemAllocator { #ifdef PADDLE_WITH_CUDA class GPUAllocator : public SystemAllocator { public: + explicit GPUAllocator(int gpu_id) : gpu_id_(gpu_id) {} + virtual void* Alloc(size_t& index, size_t size); virtual void Free(void* p, size_t size, size_t index); virtual bool UseGpu() const; @@ -50,6 +52,7 @@ class GPUAllocator : public SystemAllocator { private: size_t gpu_alloc_size_ = 0; size_t fallback_alloc_size_ = 0; + int gpu_id_; }; class CUDAPinnedAllocator : public SystemAllocator { diff --git a/paddle/fluid/memory/detail/system_allocator_test.cc b/paddle/fluid/memory/detail/system_allocator_test.cc index d5df9e6897e9e..3e1926f632c57 100644 --- a/paddle/fluid/memory/detail/system_allocator_test.cc +++ b/paddle/fluid/memory/detail/system_allocator_test.cc @@ -58,7 +58,7 @@ TEST(CPUAllocator, LockMem) { #ifdef PADDLE_WITH_CUDA TEST(GPUAllocator, Alloc) { - paddle::memory::detail::GPUAllocator a; + paddle::memory::detail::GPUAllocator a(0); TestAllocator(a, 2048); TestAllocator(a, 0); } diff --git a/paddle/fluid/memory/memory.cc b/paddle/fluid/memory/memory.cc index f2d5f250bfb56..56593653a622b 100644 --- a/paddle/fluid/memory/memory.cc +++ b/paddle/fluid/memory/memory.cc @@ -71,7 +71,7 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { } platform::SetDeviceId(gpu_id); if (!as[gpu_id]) { - as[gpu_id] = new BuddyAllocator(new detail::GPUAllocator, + as[gpu_id] = new BuddyAllocator(new detail::GPUAllocator(gpu_id), platform::GpuMinChunkSize(), platform::GpuMaxChunkSize()); VLOG(10) << "\n\nNOTE: each GPU device use " diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h index 22147d79e4b1e..c0e983e4aa7ab 100644 --- a/paddle/fluid/operators/math/concat.h +++ b/paddle/fluid/operators/math/concat.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/tensor.h" namespace paddle { diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 2a5605e0d378a..2925b8a85da1b 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/operators/detail/safe_ref.h" namespace paddle { namespace operators { @@ -59,7 +60,9 @@ class ReadOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { framework::ReaderHolder* reader = - scope.FindVar(Input("Reader"))->GetMutable(); + detail::Ref(scope.FindVar(Input("Reader")), + "Cannot find reader variable %s", Input("Reader")) + .GetMutable(); std::vector out_arg_names = Outputs("Out"); std::vector ins; reader->ReadNext(&ins); diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index c4aa29c7206db..adaa0b9e5f1ff 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/recordio/scanner.h" namespace paddle { namespace operators { namespace reader { +template class RecordIOFileReader : public framework::FileReader { public: explicit RecordIOFileReader(const std::string& filename, @@ -25,7 +28,12 @@ class RecordIOFileReader : public framework::FileReader { : FileReader(dims), scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( - platform::CPUPlace())) {} + platform::CPUPlace())) { + if (ThreadSafe) { + mutex_.reset(new std::mutex()); + } + LOG(INFO) << "Creating file reader" << filename; + } bool HasNext() const override { return scanner_.HasNext(); } @@ -33,10 +41,16 @@ class RecordIOFileReader : public framework::FileReader { protected: void ReadNextImpl(std::vector* out) override { - *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + if (ThreadSafe) { + std::lock_guard guard(*mutex_); + *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + } else { + *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + } } private: + std::unique_ptr mutex_; recordio::Scanner scanner_; const platform::DeviceContext& dev_ctx_; }; @@ -59,8 +73,9 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset( - new RecordIOFileReader(filename, RestoreShapes(shape_concat, ranks))); + + out->Reset(new RecordIOFileReader( + filename, RestoreShapes(shape_concat, ranks))); } }; @@ -87,4 +102,4 @@ REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader, reader::CreateRecordIOReaderOp, reader::CreateRecordIOReaderOpMaker); -REGISTER_FILE_READER(recordio, reader::RecordIOFileReader); +REGISTER_FILE_READER(recordio, reader::RecordIOFileReader); diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h new file mode 100644 index 0000000000000..2999004320650 --- /dev/null +++ b/paddle/fluid/platform/nccl_helper.h @@ -0,0 +1,137 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/platform/dynload/nccl.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +inline ncclDataType_t ToNCCLDataType(std::type_index type) { + if (type == typeid(float)) { // NOLINT + return ncclFloat; + } else if (type == typeid(double)) { // NOLINT + return ncclDouble; + } else if (type == typeid(int)) { // NOLINT + return ncclInt; + } else { + PADDLE_THROW("Not supported"); + } +} + +class NCCLGroupGuard { + public: + inline NCCLGroupGuard() { + mutex().lock(); + PADDLE_ENFORCE(dynload::ncclGroupStart()); + } + + inline ~NCCLGroupGuard() { + PADDLE_ENFORCE(dynload::ncclGroupEnd()); + mutex().unlock(); + } + + private: + static std::mutex &mutex() { + static std::mutex mtx; + return mtx; + } +}; + +struct NCCLContext { + std::unique_ptr ctx_; + ncclComm_t comm_; + + explicit NCCLContext(int dev_id) + : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {} + + cudaStream_t stream() const { return ctx_->stream(); } + + int device_id() const { + return boost::get(ctx_->GetPlace()).device; + } + + static void InitNCCLContext(std::unordered_map &contexts, + const std::vector &places) { + std::vector comms; + std::vector devs; + comms.resize(contexts.size()); + devs.reserve(contexts.size()); + + for (auto &p : places) { + devs.push_back(boost::get(p).device); + } + + PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( + &comms[0], static_cast(contexts.size()), &devs[0])); + + int i = 0; + for (auto &dev_id : devs) { + contexts.at(dev_id).comm_ = comms[i++]; + } + } +}; + +struct NCCLContextMap { + std::unordered_map contexts_; + std::vector order_; + + NCCLContextMap(const std::vector &places) { + order_.reserve(places.size()); + for (auto &p : places) { + int dev_id = boost::get(p).device; + order_.emplace_back(dev_id); + contexts_.emplace(dev_id, NCCLContext(dev_id)); + } + PADDLE_ENFORCE_EQ( + order_.size(), contexts_.size(), + "NCCL Context Map does not support contain two or more same device"); + + std::vector comms; + comms.resize(order_.size()); + + PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( + &comms[0], static_cast(order_.size()), &order_[0])); + + int i = 0; + for (auto &dev_id : order_) { + contexts_.at(dev_id).comm_ = comms[i++]; + } + } + + CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); } + + CUDADeviceContext *DevCtx(platform::Place p) const { + return DevCtx(boost::get(p).device); + } + + const NCCLContext &at(platform::Place p) const { + return this->at(boost::get(p).device); + } + + const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); } + + void WaitAll() { + for (auto &p : contexts_) { + p.second.ctx_->Wait(); + } + } +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler_test.cc b/paddle/fluid/platform/profiler_test.cc index fc77e0f3213da..366c82bf96e41 100644 --- a/paddle/fluid/platform/profiler_test.cc +++ b/paddle/fluid/platform/profiler_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/profiler.h" +#include "cuda_runtime.h" #include "gtest/gtest.h" TEST(Event, CpuElapsedTime) { @@ -157,3 +158,11 @@ TEST(RecordEvent, RecordEvent) { // Will remove parsing-related code from test later DisableProfiler(EventSortingKey::kTotal, "/tmp/profiler"); } + +TEST(TMP, stream_wait) { + cudaStream_t stream; + cudaStreamCreate(&stream); + cudaStreamSynchronize(stream); + cudaStreamSynchronize(stream); + cudaStreamSynchronize(stream); +} diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index fe991033dfc2a..ada69ea4a425f 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -3,11 +3,13 @@ if(WITH_PYTHON) hip_library(paddle_pybind SHARED SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method + parallel_executor ${GLOB_OP_LIB}) else() cc_library(paddle_pybind SHARED SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method + parallel_executor ${GLOB_OP_LIB}) if(NOT APPLE AND NOT ANDROID) target_link_libraries(paddle_pybind rt) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 6c05442466f5f..e1b1bbec97985 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/prune.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/selected_rows.h" @@ -496,6 +497,20 @@ All parameter, weight, gradient are variables in Paddle. m.def("disable_profiler", platform::DisableProfiler); m.def("reset_profiler", platform::ResetProfiler); + py::class_(m, "ParallelExecutor") + .def("__init__", + [](ParallelExecutor &self, size_t num_threads, bool use_event, + const std::vector &places, + const std::unordered_set ¶ms, + const ProgramDesc &startup_program, + const ProgramDesc &main_program, const std::string &loss_var_name, + Scope *scope) { + new (&self) ParallelExecutor(num_threads, use_event, places, + params, startup_program, main_program, + loss_var_name, scope); + }) + .def("run", &ParallelExecutor::Run); + BindRecordIOWriter(m); return m.ptr(); } diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index fcea282204850..5ea4d977f4d8d 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -41,6 +41,7 @@ import profiler import unique_name import recordio_writer +from parallel_executor import ParallelExecutor Tensor = LoDTensor @@ -68,6 +69,7 @@ 'profiler', 'unique_name', 'recordio_writer', + 'ParallelExecutor', ] diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py new file mode 100644 index 0000000000000..5e0588fa73241 --- /dev/null +++ b/python/paddle/fluid/parallel_executor.py @@ -0,0 +1,62 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import core +import multiprocessing +import framework +import executor + +__all__ = ['ParallelExecutor'] + + +class ParallelExecutor(object): + def __init__(self, loss_name, use_cuda, num_threads=None): + places = [] + if use_cuda: + for i in xrange(core.get_cuda_device_count()): + p = core.Place() + p.set_place(core.CUDAPlace(i)) + places.append(p) + else: + for i in xrange(multiprocessing.cpu_count()): + p = core.Place() + p.set_place(core.CPUPlace()) + places.append(p) + + if num_threads is None: + num_threads = min(len(places) * 2, multiprocessing.cpu_count()) + + startup = framework.default_startup_program() + main = framework.default_main_program() + scope = executor.global_scope() + + self.executor = core.ParallelExecutor( + num_threads, + True if use_cuda else False, # use_event + places, + set([ + p.name for p in main.global_block().iter_parameters() + if not p.stop_gradient + ]), + startup.desc, + main.desc, + loss_name, + scope) + self.scope = scope + + def run(self, fetch_list): + fetch_var_name = '@FETCHED_VAR_NAME@' + self.executor.run(fetch_list, fetch_var_name) + arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() + return [arr[i] for i in range(len(arr))] diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index ad02bdecf436b..3538a9c2009bb 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -2,3 +2,5 @@ mnist.recordio mnist_0.recordio mnist_1.recordio mnist_2.recordio +flowers.recordio +wmt16.recordio diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py new file mode 100644 index 0000000000000..bbfd03c638dac --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -0,0 +1,429 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy +import unittest + +import paddle.fluid as fluid +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist +import paddle.v2.dataset.wmt16 as wmt16 + + +def simple_fc_net(): + reader = fluid.layers.open_recordio_file( + filename='./mnist.recordio', + shapes=[[-1, 784], [-1, 1]], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + img, label = fluid.layers.read_file(reader) + hidden = img + for _ in xrange(4): + hidden = fluid.layers.fc( + hidden, + size=200, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +def fc_with_batchnorm(): + reader = fluid.layers.open_recordio_file( + filename='./mnist.recordio', + shapes=[[-1, 784], [-1, 1]], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + img, label = fluid.layers.read_file(reader) + hidden = img + for _ in xrange(1): + hidden = fluid.layers.fc( + hidden, + size=200, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + hidden = fluid.layers.batch_norm(input=hidden) + + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +def squeeze_excitation(input, num_channels, reduction_ratio): + # pool = fluid.layers.pool2d( + # input=input, pool_size=0, pool_type='avg', global_pooling=True) + conv = input + shape = conv.shape + reshape = fluid.layers.reshape( + x=conv, shape=[-1, shape[1], shape[2] * shape[3]]) + pool = fluid.layers.reduce_mean(input=reshape, dim=2) + + squeeze = fluid.layers.fc(input=pool, + size=num_channels / reduction_ratio, + act='relu') + excitation = fluid.layers.fc(input=squeeze, + size=num_channels, + act='sigmoid') + scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + return scale + + +def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, + act=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) / 2, + groups=groups, + act=None, + bias_attr=False) + return fluid.layers.batch_norm(input=conv, act=act, momentum=0.1) + + +def shortcut(input, ch_out, stride): + ch_in = input.shape[1] + if ch_in != ch_out: + if stride == 1: + filter_size = 1 + else: + filter_size = 3 + return conv_bn_layer(input, ch_out, filter_size, stride) + else: + return input + + +def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio): + # The number of first 1x1 convolutional channels for each bottleneck build block + # was halved to reduce the compution cost. + conv0 = conv_bn_layer( + input=input, num_filters=num_filters, filter_size=1, act='relu') + conv1 = conv_bn_layer( + input=conv0, + num_filters=num_filters * 2, + filter_size=3, + stride=stride, + groups=cardinality, + act='relu') + conv2 = conv_bn_layer( + input=conv1, num_filters=num_filters * 2, filter_size=1, act=None) + scale = squeeze_excitation( + input=conv2, + num_channels=num_filters * 2, + reduction_ratio=reduction_ratio) + + short = shortcut(input, num_filters * 2, stride) + + return fluid.layers.elementwise_add(x=short, y=scale, act='relu') + + +def SE_ResNeXt152(batch_size=4): + img = fluid.layers.fill_constant( + shape=[batch_size, 3, 224, 224], dtype='float32', value=0.0) + label = fluid.layers.fill_constant( + shape=[batch_size, 1], dtype='int64', value=0.0) + + conv = conv_bn_layer( + input=img, num_filters=64, filter_size=3, stride=2, act='relu') + conv = conv_bn_layer( + input=conv, num_filters=64, filter_size=3, stride=1, act='relu') + conv = conv_bn_layer( + input=conv, num_filters=128, filter_size=3, stride=1, act='relu') + conv = fluid.layers.pool2d( + input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') + + cardinality = 64 + reduction_ratio = 16 + depth = [3, 8, 36, 3] + num_filters = [128, 256, 512, 1024] + + for block in range(len(depth)): + for i in range(depth[block]): + conv = bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + cardinality=cardinality, + reduction_ratio=reduction_ratio) + + shape = conv.shape + reshape = fluid.layers.reshape( + x=conv, shape=[-1, shape[1], shape[2] * shape[3]]) + pool = fluid.layers.reduce_mean(input=reshape, dim=2) + dropout = fluid.layers.dropout(x=pool, dropout_prob=0.2) + # Classifier layer: + prediction = fluid.layers.fc(input=dropout, size=1000, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +import time + + +class TestParallelExecutorBase(unittest.TestCase): + def check_network_convergence(self, + method, + memory_opt=True, + iter=10, + batch_size=None): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = method() + adam = fluid.optimizer.Adam() + adam.minimize(loss) + if memory_opt: + fluid.memory_optimize(main) + + exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True) + if batch_size is not None: + batch_size *= fluid.core.get_cuda_device_count() + begin = time.time() + first_loss, = exe.run([loss.name]) + first_loss = numpy.array(first_loss) + + for i in xrange(iter): + exe.run([]) + + last_loss, = exe.run([loss.name]) + end = time.time() + + if batch_size is not None: + print "%.4f Instance per second" % ( + (batch_size * iter + 2) / (end - begin)) + + last_loss = numpy.array(last_loss) + + print first_loss, last_loss + # self.assertGreater(first_loss[0], last_loss[0]) + + +class TestMNIST(TestParallelExecutorBase): + @classmethod + def setUpClass(cls): + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch(mnist.train(), batch_size=32) + feeder = fluid.DataFeeder( + feed_list=[ # order is image and label + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist.recordio', reader, feeder) + + def test_simple_fc(self): + self.check_network_convergence(simple_fc_net) + + def test_batchnorm_fc(self): + self.check_network_convergence(fc_with_batchnorm) + + +class TestResnet(TestParallelExecutorBase): + # @classmethod + # def setUpClass(cls): + # # import os + # # if os.path.exists('./flowers.recordio'): + # # return + # with fluid.program_guard(fluid.Program(), fluid.Program()): + # reader = paddle.batch(flowers.train(), batch_size=4) + # feeder = fluid.DataFeeder( + # feed_list=[ + # fluid.layers.data( + # name='image', shape=[3, 224, 224]), + # fluid.layers.data( + # name='label', shape=[1], dtype='int64'), + # ], + # place=fluid.CPUPlace()) + # fluid.recordio_writer.convert_reader_to_recordio_file( + # "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress) + + def test_resnet(self): + import functools + batch_size = 4 + self.check_network_convergence( + functools.partial( + SE_ResNeXt152, batch_size=batch_size), + iter=20, + batch_size=batch_size) + + +class ModelHyperParams(object): + # Dictionary size for source and target language. This model directly uses + # paddle.dataset.wmt16 in which , and token has + # alreay been added, but the token is not added. Transformer requires + # sequences in a mini-batch are padded to have the same length. A token is + # added into the original dictionary in paddle.dateset.wmt16. + + # size of source word dictionary. + src_vocab_size = 10000 + # index for token in source language. + src_pad_idx = src_vocab_size + + # size of target word dictionay + trg_vocab_size = 10000 + # index for token in target language. + trg_pad_idx = trg_vocab_size + + # position value corresponding to the token. + pos_pad_idx = 0 + + # max length of sequences. It should plus 1 to include position + # padding token for position encoding. + max_length = 50 + + # the dimension for word embeddings, which is also the last dimension of + # the input and output of multi-head attention, position-wise feed-forward + # networks, encoder and decoder. + + d_model = 512 + # size of the hidden layer in position-wise feed-forward networks. + d_inner_hid = 1024 + # the dimension that keys are projected to for dot-product attention. + d_key = 64 + # the dimension that values are projected to for dot-product attention. + d_value = 64 + # number of head used in multi-head attention. + n_head = 8 + # number of sub-layers to be stacked in the encoder and decoder. + n_layer = 6 + # dropout rate used by all dropout layers. + dropout = 0.1 + + +import numpy as np + + +def prepare_batch_input(insts, src_pad_idx, trg_pad_idx, n_head): + """ + Pad the instances to the max sequence length in batch, and generate the + corresponding position data and attention bias. Then, convert the numpy + data to tensors and return a dict mapping names to tensors. + """ + + def __pad_batch_data(insts, + pad_idx, + is_target=False, + return_pos=True, + return_attn_bias=True, + return_max_len=True): + """ + Pad the instances to the max sequence length in batch, and generate the + corresponding position data and attention bias. + """ + return_list = [] + max_len = max(len(inst) for inst in insts) + inst_data = np.array( + [inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) + return_list += [inst_data.astype("int64").reshape([-1, 1])] + if return_pos: + inst_pos = np.array([[ + pos_i + 1 if w_i != pad_idx else 0 + for pos_i, w_i in enumerate(inst) + ] for inst in inst_data]) + + return_list += [inst_pos.astype("int64").reshape([-1, 1])] + if return_attn_bias: + if is_target: + # This is used to avoid attention on paddings and subsequent + # words. + slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, + max_len)) + slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape( + [-1, 1, max_len, max_len]) + slf_attn_bias_data = np.tile(slf_attn_bias_data, + [1, n_head, 1, 1]) * [-1e9] + else: + # This is used to avoid attention on paddings. + slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] * + (max_len - len(inst)) + for inst in insts]) + slf_attn_bias_data = np.tile( + slf_attn_bias_data.reshape([-1, 1, 1, max_len]), + [1, n_head, max_len, 1]) + return_list += [slf_attn_bias_data.astype("float32")] + if return_max_len: + return_list += [max_len] + return return_list if len(return_list) > 1 else return_list[0] + + def data_to_tensor(data_list, name_list, input_dict, place): + assert len(data_list) == len(name_list) + for i in range(len(name_list)): + tensor = fluid.LoDTensor() + tensor.set(data_list[i], place) + input_dict[name_list[i]] = tensor + + src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data( + [inst[0] for inst in insts], src_pad_idx, is_target=False) + trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data( + [inst[1] for inst in insts], trg_pad_idx, is_target=True) + trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], + [1, 1, trg_max_len, 1]).astype("float32") + lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False, + False, False, False) + lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1]) + + return [ + src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias, + trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight + ] + + +import transformer_model + + +def transformer(): + return transformer_model.transformer( + ModelHyperParams.src_vocab_size + 1, + ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1, + ModelHyperParams.n_layer, ModelHyperParams.n_head, + ModelHyperParams.d_key, ModelHyperParams.d_value, + ModelHyperParams.d_model, ModelHyperParams.d_inner_hid, + ModelHyperParams.dropout, ModelHyperParams.src_pad_idx, + ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx) + + +class TestTransformer(TestParallelExecutorBase): + @classmethod + def setUpClass(cls): + reader = paddle.batch( + wmt16.train(ModelHyperParams.src_vocab_size, + ModelHyperParams.trg_vocab_size), + batch_size=transformer_model.batch_size) + + with fluid.recordio_writer.create_recordio_writer( + "./wmt16.recordio") as writer: + for batch in reader(): + for tensor in prepare_batch_input( + batch, ModelHyperParams.src_pad_idx, + ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head): + t = fluid.LoDTensor() + t.set(tensor, fluid.CPUPlace()) + writer.append_tensor(t) + writer.complete_append_tensor() + + @unittest.skip("transformer is buggy in multi gpu") + def test_main(self): + self.check_network_convergence(transformer) diff --git a/python/paddle/fluid/tests/unittests/transformer_model.py b/python/paddle/fluid/tests/unittests/transformer_model.py new file mode 100644 index 0000000000000..c62792face3c3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/transformer_model.py @@ -0,0 +1,487 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import numpy as np + +import paddle.fluid as fluid +import paddle.fluid.layers as layers + +pos_enc_param_names = ( + "src_pos_enc_table", + "trg_pos_enc_table", ) + +batch_size = 64 + + +def position_encoding_init(n_position, d_pos_vec): + """ + Generate the initial values for the sinusoid position encoding table. + """ + position_enc = np.array([[ + pos / np.power(10000, 2 * (j // 2) / d_pos_vec) + for j in range(d_pos_vec) + ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)]) + position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i + position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1 + return position_enc.astype("float32") + + +def multi_head_attention(queries, + keys, + values, + attn_bias, + d_key, + d_value, + d_model, + n_head=1, + dropout_rate=0.): + """ + Multi-Head Attention. Note that attn_bias is added to the logit before + computing softmax activiation to mask certain selected positions so that + they will not considered in attention weights. + """ + if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3): + raise ValueError( + "Inputs: quries, keys and values should all be 3-D tensors.") + + def __compute_qkv(queries, keys, values, n_head, d_key, d_value): + """ + Add linear projection to queries, keys, and values. + """ + q = layers.fc(input=queries, + size=d_key * n_head, + param_attr=fluid.initializer.Xavier( + uniform=False, + fan_in=d_model * d_key, + fan_out=n_head * d_key), + bias_attr=False, + num_flatten_dims=2) + k = layers.fc(input=keys, + size=d_key * n_head, + param_attr=fluid.initializer.Xavier( + uniform=False, + fan_in=d_model * d_key, + fan_out=n_head * d_key), + bias_attr=False, + num_flatten_dims=2) + v = layers.fc(input=values, + size=d_value * n_head, + param_attr=fluid.initializer.Xavier( + uniform=False, + fan_in=d_model * d_value, + fan_out=n_head * d_value), + bias_attr=False, + num_flatten_dims=2) + return q, k, v + + def __split_heads(x, n_head): + """ + Reshape the last dimension of inpunt tensor x so that it becomes two + dimensions and then transpose. Specifically, input a tensor with shape + [bs, max_sequence_length, n_head * hidden_dim] then output a tensor + with shape [bs, n_head, max_sequence_length, hidden_dim]. + """ + if n_head == 1: + return x + + hidden_size = x.shape[-1] + # FIXME(guosheng): Decouple the program desc with batch_size. + reshaped = layers.reshape( + x=x, shape=[batch_size, -1, n_head, hidden_size // n_head]) + + # permuate the dimensions into: + # [batch_size, n_head, max_sequence_len, hidden_size_per_head] + return layers.transpose(x=reshaped, perm=[0, 2, 1, 3]) + + def __combine_heads(x): + """ + Transpose and then reshape the last two dimensions of inpunt tensor x + so that it becomes one dimension, which is reverse to __split_heads. + """ + if len(x.shape) == 3: return x + if len(x.shape) != 4: + raise ValueError("Input(x) should be a 4-D Tensor.") + + trans_x = layers.transpose(x, perm=[0, 2, 1, 3]) + # FIXME(guosheng): Decouple the program desc with batch_size. + return layers.reshape( + x=trans_x, + shape=map(int, + [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]])) + + def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate): + """ + Scaled Dot-Product Attention + """ + + # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op. + + # The current implementation of softmax_op only supports 2D tensor, + # consequently it cannot be directly used here. + # If to use the reshape_op, Besides, the shape of product inferred in + # compile-time is not the actual shape in run-time. It cann't be used + # to set the attribute of reshape_op. + # So, here define the softmax for temporary solution. + + def __softmax(x, eps=1e-9): + exp_out = layers.exp(x=x) + sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False) + return layers.elementwise_div(x=exp_out, y=sum_out, axis=0) + + scaled_q = layers.scale(x=q, scale=d_model**-0.5) + product = layers.matmul(x=scaled_q, y=k, transpose_y=True) + weights = __softmax(layers.elementwise_add(x=product, y=attn_bias)) + if dropout_rate: + weights = layers.dropout( + weights, dropout_prob=dropout_rate, is_test=False) + out = layers.matmul(weights, v) + return out + + q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value) + + q = __split_heads(q, n_head) + k = __split_heads(k, n_head) + v = __split_heads(v, n_head) + + ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model, + dropout_rate) + + out = __combine_heads(ctx_multiheads) + + # Project back to the model size. + proj_out = layers.fc(input=out, + size=d_model, + param_attr=fluid.initializer.Xavier(uniform=False), + bias_attr=False, + num_flatten_dims=2) + return proj_out + + +def positionwise_feed_forward(x, d_inner_hid, d_hid): + """ + Position-wise Feed-Forward Networks. + This module consists of two linear transformations with a ReLU activation + in between, which is applied to each position separately and identically. + """ + hidden = layers.fc(input=x, + size=d_inner_hid, + num_flatten_dims=2, + param_attr=fluid.initializer.Uniform( + low=-(d_hid**-0.5), high=(d_hid**-0.5)), + act="relu") + out = layers.fc(input=hidden, + size=d_hid, + num_flatten_dims=2, + param_attr=fluid.initializer.Uniform( + low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5))) + return out + + +def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.): + """ + Add residual connection, layer normalization and droput to the out tensor + optionally according to the value of process_cmd. + + This will be used before or after multi-head attention and position-wise + feed-forward networks. + """ + for cmd in process_cmd: + if cmd == "a": # add residual connection + out = out + prev_out if prev_out else out + elif cmd == "n": # add layer normalization + out = layers.layer_norm( + out, + begin_norm_axis=len(out.shape) - 1, + param_attr=fluid.initializer.Constant(1.), + bias_attr=fluid.initializer.Constant(0.)) + elif cmd == "d": # add dropout + if dropout: + out = layers.dropout(out, dropout_prob=dropout, is_test=False) + return out + + +pre_process_layer = partial(pre_post_process_layer, None) +post_process_layer = pre_post_process_layer + + +def prepare_encoder(src_word, + src_pos, + src_vocab_size, + src_emb_dim, + src_pad_idx, + src_max_len, + dropout=0., + pos_pad_idx=0, + pos_enc_param_name=None): + """Add word embeddings and position encodings. + The output tensor has a shape of: + [batch_size, max_src_length_in_batch, d_model]. + + This module is used at the bottom of the encoder stacks. + """ + src_word_emb = layers.embedding( + src_word, + size=[src_vocab_size, src_emb_dim], + padding_idx=src_pad_idx, + param_attr=fluid.initializer.Normal(0., 1.)) + src_pos_enc = layers.embedding( + src_pos, + size=[src_max_len, src_emb_dim], + padding_idx=pos_pad_idx, + param_attr=fluid.ParamAttr( + name=pos_enc_param_name, trainable=False)) + enc_input = src_word_emb + src_pos_enc + + # FIXME(guosheng): Decouple the program desc with batch_size. + enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim]) + return layers.dropout( + enc_input, dropout_prob=dropout, + is_test=False) if dropout else enc_input + + +prepare_encoder = partial( + prepare_encoder, pos_enc_param_name=pos_enc_param_names[0]) +prepare_decoder = partial( + prepare_encoder, pos_enc_param_name=pos_enc_param_names[1]) + + +def encoder_layer(enc_input, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate=0.): + """The encoder layers that can be stacked to form a deep encoder. + + This module consits of a multi-head (self) attention followed by + position-wise feed-forward networks and both the two components companied + with the post_process_layer to add residual connection, layer normalization + and droput. + """ + attn_output = multi_head_attention(enc_input, enc_input, enc_input, + attn_bias, d_key, d_value, d_model, + n_head, dropout_rate) + attn_output = post_process_layer(enc_input, attn_output, "dan", + dropout_rate) + ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model) + return post_process_layer(attn_output, ffd_output, "dan", dropout_rate) + + +def encoder(enc_input, + attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate=0.): + """ + The encoder is composed of a stack of identical layers returned by calling + encoder_layer. + """ + for i in range(n_layer): + enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value, + d_model, d_inner_hid, dropout_rate) + enc_input = enc_output + return enc_output + + +def decoder_layer(dec_input, + enc_output, + slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate=0.): + """ The layer to be stacked in decoder part. + + The structure of this module is similar to that in the encoder part except + a multi-head attention is added to implement encoder-decoder attention. + """ + slf_attn_output = multi_head_attention( + dec_input, + dec_input, + dec_input, + slf_attn_bias, + d_key, + d_value, + d_model, + n_head, + dropout_rate, ) + slf_attn_output = post_process_layer( + dec_input, + slf_attn_output, + "dan", # residual connection + dropout + layer normalization + dropout_rate, ) + enc_attn_output = multi_head_attention( + slf_attn_output, + enc_output, + enc_output, + dec_enc_attn_bias, + d_key, + d_value, + d_model, + n_head, + dropout_rate, ) + enc_attn_output = post_process_layer( + slf_attn_output, + enc_attn_output, + "dan", # residual connection + dropout + layer normalization + dropout_rate, ) + ffd_output = positionwise_feed_forward( + enc_attn_output, + d_inner_hid, + d_model, ) + dec_output = post_process_layer( + enc_attn_output, + ffd_output, + "dan", # residual connection + dropout + layer normalization + dropout_rate, ) + return dec_output + + +def decoder(dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate=0.): + """ + The decoder is composed of a stack of identical decoder_layer layers. + """ + for i in range(n_layer): + dec_output = decoder_layer( + dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, ) + dec_input = dec_output + return dec_output + + +def transformer( + src_vocab_size, + trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, + src_pad_idx, + trg_pad_idx, + pos_pad_idx, ): + file_obj = fluid.layers.open_recordio_file( + filename='./wmt16.recordio', + shapes=[ + [batch_size * max_length, 1], + [batch_size * max_length, 1], + [batch_size * max_length, 1], + [batch_size * max_length, 1], + [batch_size, n_head, max_length, max_length], + [batch_size, n_head, max_length, max_length], + [batch_size, n_head, max_length, max_length], + [batch_size * max_length, 1], + [batch_size * max_length, 1], + ], + dtypes=[ + 'int64', + 'int64', + 'int64', + 'int64', + 'float32', + 'float32', + 'float32', + 'int64', + 'float32', + ], + lod_levels=[0] * 9) + + src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias, trg_slf_attn_bias, trg_src_attn_bias, gold, weights = fluid.layers.read_file( + file_obj) + + enc_input = prepare_encoder( + src_word, + src_pos, + src_vocab_size, + d_model, + src_pad_idx, + max_length, + dropout_rate, ) + enc_output = encoder( + enc_input, + src_slf_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, ) + + dec_input = prepare_decoder( + trg_word, + trg_pos, + trg_vocab_size, + d_model, + trg_pad_idx, + max_length, + dropout_rate, ) + dec_output = decoder( + dec_input, + enc_output, + trg_slf_attn_bias, + trg_src_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, ) + + # TODO(guosheng): Share the weight matrix between the embedding layers and + # the pre-softmax linear transformation. + predict = layers.reshape( + x=layers.fc(input=dec_output, + size=trg_vocab_size, + param_attr=fluid.initializer.Xavier(uniform=False), + bias_attr=False, + num_flatten_dims=2), + shape=[-1, trg_vocab_size], + act="softmax") + + cost = layers.cross_entropy(input=predict, label=gold) + weighted_cost = cost * weights + return layers.reduce_sum(weighted_cost)