diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 15e5574ecfd40..a4ea74a6d2fbc 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -103,4 +103,5 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) cc_test(channel_test SRCS channel_test.cc) cc_test(tuple_test SRCS tuple_test.cc ) cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op - channel_send_op channel_recv_op sum_op elementwise_add_op executor proto_desc) + channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op + conditional_block_op while_op assign_op print_op executor proto_desc) diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h index 51e2b03f9cb9a..adfaba26ace78 100644 --- a/paddle/fluid/framework/channel.h +++ b/paddle/fluid/framework/channel.h @@ -162,24 +162,12 @@ class ChannelHolder { } } - template void RemoveFromSendQ(const void* referrer) { - if (IsInitialized()) { - Channel* channel = static_cast*>(holder_->Ptr()); - if (channel != nullptr) { - channel->RemoveFromSendQ(referrer); - } - } + if (IsInitialized()) holder_->RemoveFromSendQ(referrer); } - template void RemoveFromReceiveQ(const void* referrer) { - if (IsInitialized()) { - Channel* channel = static_cast*>(holder_->Ptr()); - if (channel != nullptr) { - channel->RemoveFromReceiveQ(referrer); - } - } + if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer); } inline bool IsInitialized() const { return holder_ != nullptr; } @@ -201,6 +189,8 @@ class ChannelHolder { virtual bool IsClosed() = 0; virtual bool CanSend() = 0; virtual bool CanReceive() = 0; + virtual void RemoveFromSendQ(const void* referrer) = 0; + virtual void RemoveFromReceiveQ(const void* referrer) = 0; virtual void Close() = 0; virtual void Lock() = 0; virtual void Unlock() = 0; @@ -238,6 +228,18 @@ class ChannelHolder { return false; } + virtual void RemoveFromSendQ(const void* referrer) { + if (channel_) { + channel_->RemoveFromSendQ(referrer); + } + } + + virtual void RemoveFromReceiveQ(const void* referrer) { + if (channel_) { + channel_->RemoveFromReceiveQ(referrer); + } + } + virtual void Close() { if (channel_) channel_->Close(); } diff --git a/paddle/fluid/framework/channel_impl.h b/paddle/fluid/framework/channel_impl.h index c194c03e264cc..457abbf373d45 100644 --- a/paddle/fluid/framework/channel_impl.h +++ b/paddle/fluid/framework/channel_impl.h @@ -151,7 +151,7 @@ bool ChannelImpl::Send(T *item) { // We do not care about notifying other // because they would have been notified // by the executed select case. - return Send(item); + return send_return(Send(item)); // Wake up the blocked process and unlock m->Notify(); @@ -214,7 +214,7 @@ bool ChannelImpl::Receive(T *item) { // We do not care about notifying other // because they would have been notified // by the executed select case. - return Receive(item); + return recv_return(Receive(item)); // Wake up the blocked process and unlock m->Notify(); @@ -331,7 +331,6 @@ void ChannelImpl::RemoveFromSendQ(const void *referrer) { if (sendMsg->referrer == referrer) { it = sendq.erase(it); - send_ctr--; } else { ++it; } @@ -347,7 +346,6 @@ void ChannelImpl::RemoveFromReceiveQ(const void *referrer) { if (recvMsg->referrer == referrer) { it = recvq.erase(it); - recv_ctr--; } else { ++it; } diff --git a/paddle/fluid/framework/concurrency_test.cc b/paddle/fluid/framework/concurrency_test.cc index 5770b0a5a1865..25152054eb845 100644 --- a/paddle/fluid/framework/concurrency_test.cc +++ b/paddle/fluid/framework/concurrency_test.cc @@ -19,7 +19,6 @@ limitations under the License. */ #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/program_desc.h" USE_NO_KERNEL_OP(go); USE_NO_KERNEL_OP(channel_close); @@ -27,6 +26,12 @@ USE_NO_KERNEL_OP(channel_create); USE_NO_KERNEL_OP(channel_recv); USE_NO_KERNEL_OP(channel_send); USE_NO_KERNEL_OP(elementwise_add); +USE_NO_KERNEL_OP(select); +USE_NO_KERNEL_OP(conditional_block); +USE_NO_KERNEL_OP(equal); +USE_NO_KERNEL_OP(assign); +USE_NO_KERNEL_OP(while); +USE_NO_KERNEL_OP(print); namespace f = paddle::framework; namespace p = paddle::platform; @@ -35,27 +40,15 @@ namespace paddle { namespace framework { template -void CreateIntVariable(Scope &scope, p::CPUPlace &place, std::string name, - T value) { - // Create LoDTensor of dim [1,1] +LoDTensor *CreateVariable(Scope &scope, p::CPUPlace &place, std::string name, + T value) { + // Create LoDTensor of dim [1] auto var = scope.Var(name); auto tensor = var->GetMutable(); - tensor->Resize({1, 1}); + tensor->Resize({1}); T *expect = tensor->mutable_data(place); expect[0] = value; -} - -void InitTensorsInScope(Scope &scope, p::CPUPlace &place) { - p::CPUDeviceContext ctx(place); - - // Create channel variable - scope.Var("Channel"); - - // Create Variables, x0 will be put into channel, - // result will be pulled from channel - CreateIntVariable(scope, place, "Status", false); - CreateIntVariable(scope, place, "x0", 99); - CreateIntVariable(scope, place, "result", 0); + return tensor; } void AddOp(const std::string &type, const VariableNameMap &inputs, @@ -73,12 +66,116 @@ void AddOp(const std::string &type, const VariableNameMap &inputs, op->SetAttrMap(attrs); } +void AddCase(ProgramDesc *program, Scope *scope, p::CPUPlace *place, + BlockDesc *casesBlock, int caseId, int caseType, + std::string caseChannel, std::string caseVarName, + std::function func) { + std::string caseCondName = std::string("caseCond") + std::to_string(caseId); + std::string caseCondXVarName = + std::string("caseCondX") + std::to_string(caseId); + + BlockDesc *caseBlock = program->AppendBlock(*casesBlock); + func(caseBlock, scope); + + CreateVariable(*scope, *place, caseCondName, false); + CreateVariable(*scope, *place, caseCondXVarName, caseId); + CreateVariable(*scope, *place, caseVarName, caseId); + + scope->Var("step_scope"); + + AddOp("equal", {{"X", {caseCondXVarName}}, {"Y", {"caseToExecute"}}}, + {{"Out", {caseCondName}}}, {}, casesBlock); + + AddOp("conditional_block", {{"X", {caseCondName}}, {"Params", {}}}, + {{"Out", {}}, {"Scope", {"step_scope"}}}, + {{"sub_block", caseBlock}, {"is_scalar_condition", true}}, casesBlock); +} + +void AddFibonacciSelect(Scope *scope, p::CPUPlace *place, ProgramDesc *program, + BlockDesc *parentBlock, std::string dataChanName, + std::string quitChanName) { + BlockDesc *whileBlock = program->AppendBlock(*parentBlock); + + CreateVariable(*scope, *place, "whileExitCond", true); + CreateVariable(*scope, *place, "caseToExecute", -1); + CreateVariable(*scope, *place, "case1var", 0); + + CreateVariable(*scope, *place, "xtemp", 0); + + // TODO(thuan): Need to create fibXToSend, since channel send moves the actual + // data, + // which causes the data to be no longer accessible to do the fib calculation + // TODO(abhinav): Change channel send to do a copy instead of a move! + CreateVariable(*scope, *place, "fibXToSend", 0); + + CreateVariable(*scope, *place, "fibX", 0); + CreateVariable(*scope, *place, "fibY", 1); + CreateVariable(*scope, *place, "quitVar", 0); + + BlockDesc *casesBlock = program->AppendBlock(*whileBlock); + std::function f = [](BlockDesc *caseBlock) {}; + + // TODO(thuan): Remove this once we change channel send to do a copy instead + // of move + AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"fibXToSend"}}}, {}, whileBlock); + + // Case 0: Send to dataChanName + std::function case0Func = [&]( + BlockDesc *caseBlock, Scope *scope) { + AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"xtemp"}}}, {}, caseBlock); + AddOp("assign", {{"X", {"fibY"}}}, {{"Out", {"fibX"}}}, {}, caseBlock); + AddOp("elementwise_add", {{"X", {"xtemp"}}, {"Y", {"fibY"}}}, + {{"Out", {"fibY"}}}, {}, caseBlock); + }; + AddCase(program, scope, place, casesBlock, 0, 1, dataChanName, "fibXToSend", + case0Func); + std::string case0Config = + std::string("0,1,") + dataChanName + std::string(",fibXToSend"); + + // Case 1: Receive from quitChanName + std::function case2Func = [&]( + BlockDesc *caseBlock, Scope *scope) { + // Exit the while loop after we receive from quit channel. + // We assign a false to "whileExitCond" variable, which will + // break out of while_op loop + CreateVariable(*scope, *place, "whileFalse", false); + AddOp("assign", {{"X", {"whileFalse"}}}, {{"Out", {"whileExitCond"}}}, {}, + caseBlock); + }; + AddCase(program, scope, place, casesBlock, 1, 2, quitChanName, "quitVar", + case2Func); + std::string case1Config = + std::string("1,2,") + quitChanName + std::string(",quitVar"); + + // Select block + AddOp("select", {{"X", {dataChanName, quitChanName}}, + {"case_to_execute", {"caseToExecute"}}}, + {}, {{"sub_block", casesBlock}, + {"cases", std::vector{case0Config, case1Config}}}, + whileBlock); + + scope->Var("stepScopes"); + AddOp("while", + {{"X", {dataChanName, quitChanName}}, {"Condition", {"whileExitCond"}}}, + {{"Out", {}}, {"StepScopes", {"stepScopes"}}}, + {{"sub_block", whileBlock}}, parentBlock); +} + TEST(Concurrency, Go_Op) { Scope scope; p::CPUPlace place; // Initialize scope variables - InitTensorsInScope(scope, place); + p::CPUDeviceContext ctx(place); + + // Create channel variable + scope.Var("Channel"); + + // Create Variables, x0 will be put into channel, + // result will be pulled from channel + CreateVariable(scope, place, "Status", false); + CreateVariable(scope, place, "x0", 99); + CreateVariable(scope, place, "result", 0); framework::Executor executor(place); ProgramDesc program; @@ -118,5 +215,78 @@ TEST(Concurrency, Go_Op) { auto *finalData = tensor.data(); EXPECT_EQ(finalData[0], 99); } + +/** + * This test implements the fibonacci function using go_op and select_op + */ +TEST(Concurrency, Select) { + Scope scope; + p::CPUPlace place; + + // Initialize scope variables + p::CPUDeviceContext ctx(place); + + CreateVariable(scope, place, "Status", false); + CreateVariable(scope, place, "result", 0); + CreateVariable(scope, place, "currentXFib", 0); + + framework::Executor executor(place); + ProgramDesc program; + BlockDesc *block = program.MutableBlock(0); + + // Create channel OP + std::string dataChanName = "Channel"; + scope.Var(dataChanName); + AddOp("channel_create", {}, {{"Out", {dataChanName}}}, + {{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block); + + std::string quitChanName = "Quit"; + scope.Var(quitChanName); + AddOp("channel_create", {}, {{"Out", {quitChanName}}}, + {{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block); + + // Create Go Op routine, which loops 10 times over fibonacci sequence + CreateVariable(scope, place, "xReceiveVar", 0); + + BlockDesc *goOpBlock = program.AppendBlock(program.Block(0)); + for (int i = 0; i < 10; ++i) { + AddOp("channel_recv", {{"Channel", {dataChanName}}}, + {{"Status", {"Status"}}, {"Out", {"currentXFib"}}}, {}, goOpBlock); + AddOp("print", {{"In", {"currentXFib"}}}, {{"Out", {"currentXFib"}}}, + {{"first_n", 100}, + {"summarize", -1}, + {"print_tensor_name", false}, + {"print_tensor_type", true}, + {"print_tensor_shape", false}, + {"print_tensor_lod", false}, + {"print_phase", std::string("FORWARD")}, + {"message", std::string("X: ")}}, + goOpBlock); + } + + CreateVariable(scope, place, "quitSignal", 0); + AddOp("channel_send", {{"Channel", {quitChanName}}, {"X", {"quitSignal"}}}, + {{"Status", {"Status"}}}, {}, goOpBlock); + + // Create Go Op + AddOp("go", {{"X", {dataChanName, quitChanName}}}, {}, + {{"sub_block", goOpBlock}}, block); + + AddFibonacciSelect(&scope, &place, &program, block, dataChanName, + quitChanName); + + // Create Channel Close Op + AddOp("channel_close", {{"Channel", {dataChanName}}}, {}, {}, block); + AddOp("channel_close", {{"Channel", {quitChanName}}}, {}, {}, block); + + executor.Run(program, &scope, 0, true, true); + + // After we call executor.run, "result" variable should be equal to 34 + // (which is 10 loops through fibonacci sequence) + const LoDTensor &tensor = (scope.FindVar("currentXFib"))->Get(); + auto *finalData = tensor.data(); + EXPECT_EQ(finalData[0], 34); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 625e0f7561899..84dc265575679 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -203,6 +203,11 @@ op_library(save_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor) op_library(concat_op DEPS concat) +# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency +add_subdirectory(concurrency) +op_library(channel_send_op DEPS concurrency) +op_library(channel_recv_op DEPS concurrency) + list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) op_library(${src}) diff --git a/paddle/fluid/operators/channel_recv_op.cc b/paddle/fluid/operators/channel_recv_op.cc index c12b88e7a91c4..844b3ae3b7bf8 100644 --- a/paddle/fluid/operators/channel_recv_op.cc +++ b/paddle/fluid/operators/channel_recv_op.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/concurrency/channel_util.h" #include "paddle/fluid/operators/math/math_function.h" static constexpr char Channel[] = "Channel"; @@ -36,25 +37,6 @@ void SetReceiveStatus(const platform::Place &dev_place, status_tensor[0] = status; } -bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var) { - // Get type of channel and use that to call mutable data for Variable - auto type = framework::ToVarType(ch->Type()); - if (type == framework::proto::VarType_Type_LOD_TENSOR) - return ch->Receive(var->GetMutable()); - else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) - return ch->Receive(var->GetMutable()); - else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) - return ch->Receive(var->GetMutable()); - else if (type == framework::proto::VarType_Type_SELECTED_ROWS) - return ch->Receive(var->GetMutable()); - else if (type == framework::proto::VarType_Type_READER) - return ch->Receive(var->GetMutable()); - else if (type == framework::proto::VarType_Type_CHANNEL) - return ch->Receive(var->GetMutable()); - else - PADDLE_THROW("ChannelReceive:Unsupported type"); -} - class ChannelRecvOp : public framework::OperatorBase { public: ChannelRecvOp(const std::string &type, @@ -81,7 +63,7 @@ class ChannelRecvOp : public framework::OperatorBase { scope.FindVar(Input(Channel))->GetMutable(); auto output_var = scope.FindVar(Output(Out)); // Receive the data from the channel. - bool ok = ChannelReceive(ch, output_var); + bool ok = concurrency::ChannelReceive(ch, output_var); // Set the status output of the `ChannelReceive` call. SetReceiveStatus(dev_place, *scope.FindVar(Output(Status)), ok); diff --git a/paddle/fluid/operators/channel_send_op.cc b/paddle/fluid/operators/channel_send_op.cc index 6d7715ad229e8..47cf7d7efc999 100644 --- a/paddle/fluid/operators/channel_send_op.cc +++ b/paddle/fluid/operators/channel_send_op.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/concurrency/channel_util.h" #include "paddle/fluid/operators/math/math_function.h" static constexpr char Channel[] = "Channel"; @@ -37,24 +38,6 @@ void SetSendStatus(const platform::Place &dev_place, status_tensor[0] = status; } -bool ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) { - auto type = framework::ToVarType(var->Type()); - if (type == framework::proto::VarType_Type_LOD_TENSOR) - return ch->Send(var->GetMutable()); - else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) - return ch->Send(var->GetMutable()); - else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) - return ch->Send(var->GetMutable()); - else if (type == framework::proto::VarType_Type_SELECTED_ROWS) - return ch->Send(var->GetMutable()); - else if (type == framework::proto::VarType_Type_READER) - return ch->Send(var->GetMutable()); - else if (type == framework::proto::VarType_Type_CHANNEL) - return ch->Send(var->GetMutable()); - else - PADDLE_THROW("ChannelSend:Unsupported type"); -} - class ChannelSendOp : public framework::OperatorBase { public: ChannelSendOp(const std::string &type, @@ -82,7 +65,7 @@ class ChannelSendOp : public framework::OperatorBase { auto input_var = scope.FindVar(Input(X)); // Send the input data through the channel. - bool ok = ChannelSend(ch, input_var); + bool ok = concurrency::ChannelSend(ch, input_var); // Set the status output of the `ChannelSend` call. SetSendStatus(dev_place, *scope.FindVar(Output(Status)), ok); diff --git a/paddle/fluid/operators/concurrency/CMakeLists.txt b/paddle/fluid/operators/concurrency/CMakeLists.txt new file mode 100644 index 0000000000000..e4617440d152b --- /dev/null +++ b/paddle/fluid/operators/concurrency/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(concurrency SRCS channel_util.cc DEPS device_context framework_proto boost eigen3) diff --git a/paddle/fluid/operators/concurrency/channel_util.cc b/paddle/fluid/operators/concurrency/channel_util.cc new file mode 100644 index 0000000000000..a483af7affd82 --- /dev/null +++ b/paddle/fluid/operators/concurrency/channel_util.cc @@ -0,0 +1,111 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "channel_util.h" +#include "paddle/fluid/framework/var_type.h" + +namespace poc = paddle::operators::concurrency; + +bool poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) { + auto type = framework::ToVarType(var->Type()); + if (type == framework::proto::VarType_Type_LOD_TENSOR) + return ch->Send(var->GetMutable()); + else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) + return ch->Send(var->GetMutable()); + else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) + return ch->Send(var->GetMutable()); + else if (type == framework::proto::VarType_Type_SELECTED_ROWS) + return ch->Send(var->GetMutable()); + else if (type == framework::proto::VarType_Type_READER) + return ch->Send(var->GetMutable()); + else if (type == framework::proto::VarType_Type_CHANNEL) + return ch->Send(var->GetMutable()); + else + PADDLE_THROW("ChannelSend:Unsupported type"); +} + +bool poc::ChannelReceive(framework::ChannelHolder *ch, + framework::Variable *var) { + // Get type of channel and use that to call mutable data for Variable + auto type = framework::ToVarType(ch->Type()); + if (type == framework::proto::VarType_Type_LOD_TENSOR) + return ch->Receive(var->GetMutable()); + else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) + return ch->Receive(var->GetMutable()); + else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) + return ch->Receive(var->GetMutable()); + else if (type == framework::proto::VarType_Type_SELECTED_ROWS) + return ch->Receive(var->GetMutable()); + else if (type == framework::proto::VarType_Type_READER) + return ch->Receive(var->GetMutable()); + else if (type == framework::proto::VarType_Type_CHANNEL) + return ch->Receive(var->GetMutable()); + else + PADDLE_THROW("ChannelReceive:Unsupported type"); +} + +void poc::ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer, + framework::Variable *var, + std::shared_ptr cond, + std::function cb) { + auto type = framework::ToVarType(var->Type()); + if (type == framework::proto::VarType_Type_LOD_TENSOR) { + ch->AddToSendQ(referrer, var->GetMutable(), cond, cb); + } else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) { + ch->AddToSendQ(referrer, var->GetMutable(), cond, + cb); + } else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) { + ch->AddToSendQ(referrer, var->GetMutable(), cond, + cb); + } else if (type == framework::proto::VarType_Type_SELECTED_ROWS) { + ch->AddToSendQ(referrer, var->GetMutable(), cond, + cb); + } else if (type == framework::proto::VarType_Type_READER) { + ch->AddToSendQ(referrer, var->GetMutable(), cond, + cb); + } else if (type == framework::proto::VarType_Type_CHANNEL) { + ch->AddToSendQ(referrer, var->GetMutable(), cond, + cb); + } else { + PADDLE_THROW("ChannelAddToSendQ:Unsupported type"); + } +} + +void poc::ChannelAddToReceiveQ( + framework::ChannelHolder *ch, const void *referrer, + framework::Variable *var, std::shared_ptr cond, + std::function cb) { + auto type = framework::ToVarType(var->Type()); + if (type == framework::proto::VarType_Type_LOD_TENSOR) { + ch->AddToReceiveQ(referrer, var->GetMutable(), cond, + cb); + } else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) { + ch->AddToReceiveQ(referrer, var->GetMutable(), + cond, cb); + } else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) { + ch->AddToReceiveQ(referrer, var->GetMutable(), + cond, cb); + } else if (type == framework::proto::VarType_Type_SELECTED_ROWS) { + ch->AddToReceiveQ(referrer, var->GetMutable(), + cond, cb); + } else if (type == framework::proto::VarType_Type_READER) { + ch->AddToReceiveQ(referrer, var->GetMutable(), + cond, cb); + } else if (type == framework::proto::VarType_Type_CHANNEL) { + ch->AddToReceiveQ(referrer, var->GetMutable(), + cond, cb); + } else { + PADDLE_THROW("ChannelAddToReceiveQ:Unsupported type"); + } +} diff --git a/paddle/fluid/operators/concurrency/channel_util.h b/paddle/fluid/operators/concurrency/channel_util.h new file mode 100644 index 0000000000000..c3674bd9815df --- /dev/null +++ b/paddle/fluid/operators/concurrency/channel_util.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/variable.h" + +namespace paddle { +namespace operators { +namespace concurrency { + +bool ChannelSend(framework::ChannelHolder *ch, framework::Variable *var); +bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var); + +void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer, + framework::Variable *var, + std::shared_ptr cond, + std::function cb); +void ChannelAddToReceiveQ(framework::ChannelHolder *ch, const void *referrer, + framework::Variable *var, + std::shared_ptr cond, + std::function cb); + +} // namespace concurrency +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/select_op.cc b/paddle/fluid/operators/select_op.cc new file mode 100644 index 0000000000000..8344a239df7b3 --- /dev/null +++ b/paddle/fluid/operators/select_op.cc @@ -0,0 +1,414 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/concurrency/channel_util.h" + +namespace paddle { +namespace operators { + +static constexpr char kX[] = "X"; +static constexpr char kCaseToExecute[] = "case_to_execute"; + +static constexpr char kCases[] = "cases"; +static constexpr char kCasesBlock[] = "sub_block"; + +class SelectOp : public framework::OperatorBase { + public: + SelectOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + + private: + enum class SelectOpCaseType { + DEFAULT = 0, + SEND = 1, + RECEIVE = 2, + }; + + struct SelectOpCase { + int caseIndex; + SelectOpCaseType caseType; + std::string channelName; + std::string varName; + + SelectOpCase() {} + + SelectOpCase(int caseIndex, SelectOpCaseType caseType, + std::string channelName, std::string varName) + : caseIndex(caseIndex), + caseType(caseType), + channelName(channelName), + varName(varName) {} + }; + + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + std::vector casesConfigs = + Attr>(kCases); + + framework::BlockDesc *casesBlock = + Attr(kCasesBlock); + + framework::Scope &casesBlockScope = scope.NewScope(); + + std::string caseToExecuteVarName = Input(kCaseToExecute); + framework::Variable *caseToExecuteVar = + casesBlockScope.FindVar(caseToExecuteVarName); + + // Construct cases from "conditional_block_op"(s) in the casesBlock + std::vector> cases = + ParseAndShuffleCases(&casesConfigs); + + // Get all unique channels involved in select + std::set channelsSet; + for (auto c : cases) { + if (!c->channelName.empty()) { + auto channelVar = scope.FindVar(c->channelName); + framework::ChannelHolder *ch = + channelVar->GetMutable(); + + if (channelsSet.find(ch) == channelsSet.end()) { + channelsSet.insert(ch); + } + } + } + + // Order all channels by their pointer address + std::vector channels(channelsSet.begin(), + channelsSet.end()); + std::sort(channels.begin(), channels.end()); + + // Poll all cases + int32_t caseToExecute = pollCases(&scope, &cases, channels); + + // At this point, the case to execute has already been determined, + // so we can proceed with executing the cases block + framework::LoDTensor *caseToExecuteTensor = + caseToExecuteVar->GetMutable(); + caseToExecuteTensor->data()[0] = caseToExecute; + + // Execute the cases block, only one case will be executed since we set the + // case_to_execute value to the index of the case we want to execute + framework::Executor executor(dev_place); + framework::ProgramDesc *program = casesBlock->Program(); + executor.Run(*program, &casesBlockScope, casesBlock->ID(), + false /*create_local_scope*/); + } + + /** + * Goes through all operators in the casesConfigs and processes + * "conditional_block" operators. These operators are mapped to our + * SelectOpCase objects. We randomize the case orders, and set the + * default case (if any exists) as the last case) + * @param casesBlock + * @return + */ + std::vector> ParseAndShuffleCases( + std::vector *casesConfigs) const { + std::vector> cases; + std::shared_ptr defaultCase; + + if (casesConfigs != nullptr) { + boost::char_delimiters_separator sep(false, ",", ""); + for (std::vector::iterator itr = casesConfigs->begin(); + itr < casesConfigs->end(); ++itr) { + std::string caseConfig = *itr; + boost::tokenizer<> tokens(caseConfig, sep); + + boost::tokenizer<>::iterator tok_iter = tokens.begin(); + PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case index"); + std::string caseIndexString = *tok_iter; + int caseIndex = std::stoi(caseIndexString); + + ++tok_iter; + PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case type"); + std::string caseTypeString = *tok_iter; + SelectOpCaseType caseType = (SelectOpCaseType)std::stoi(caseTypeString); + + std::string caseChannel; + std::string caseChannelVar; + + ++tok_iter; + if (caseType != SelectOpCaseType::DEFAULT) { + PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case channel"); + caseChannel = *tok_iter; + + ++tok_iter; + PADDLE_ENFORCE(tok_iter != tokens.end(), + "Cannot get case channel variable"); + caseChannelVar = *tok_iter; + } + + auto c = std::make_shared(caseIndex, caseType, + caseChannel, caseChannelVar); + + if (caseType == SelectOpCaseType::DEFAULT) { + PADDLE_ENFORCE(defaultCase == nullptr, + "Select can only contain one default case."); + defaultCase = c; + } else { + cases.push_back(c); + } + } + } + + // Randomly sort cases, with default case being last + std::random_shuffle(cases.begin(), cases.end()); + if (defaultCase != nullptr) { + cases.push_back(defaultCase); + } + + return cases; + } + + /** + * This method will recursively poll the cases and determines if any case + * condition is true. + * If none of the cases conditions are true (and there is no default case), + * then block + * the thread. The thread may be woken up by a channel operation, at which + * point we + * execute the case. + * @param scope + * @param cases + * @param channels + * @return + */ + int32_t pollCases(const framework::Scope *scope, + std::vector> *cases, + std::vector channels) const { + // Lock all involved channels + lockChannels(channels); + + std::atomic caseToExecute(-1); + + std::vector>::iterator it = cases->begin(); + while (it != cases->end()) { + std::shared_ptr c = *it; + + auto chVar = scope->FindVar(c->channelName); + framework::ChannelHolder *ch = + chVar->GetMutable(); + + switch (c->caseType) { + case SelectOpCaseType::SEND: + PADDLE_ENFORCE(!ch->IsClosed(), "Cannot send to a closed channel"); + if (ch->CanSend()) { + // We can send to channel directly, send the data to channel + // and execute case + auto chVar = scope->FindVar(c->varName); + concurrency::ChannelSend(ch, chVar); + caseToExecute = c->caseIndex; + } + break; + case SelectOpCaseType::RECEIVE: + if (ch->CanReceive()) { + // We can receive from channel directly, send the data to channel + // and execute case + auto chVar = scope->FindVar(c->varName); + concurrency::ChannelReceive(ch, chVar); + caseToExecute = c->caseIndex; + } + break; + case SelectOpCaseType::DEFAULT: + caseToExecute = c->caseIndex; + break; + } + + if (caseToExecute != -1) { + // We found a case to execute, stop looking at other case statements + break; + } + + ++it; + } + + if (caseToExecute == -1) { + // None of the cases are eligible to execute, enqueue current thread + // into all the sending/receiving queue of each involved channel + std::atomic completed(false); + std::recursive_mutex mutex; + std::unique_lock lock{mutex}; + // std::condition_variable_any selectCond; + auto selectCond = std::make_shared(); + + std::recursive_mutex callbackMutex; + pushThreadOnChannelQueues(scope, cases, selectCond, caseToExecute, + completed, callbackMutex); + + // TODO(thuan): Atomically unlock all channels and sleep current thread + unlockChannels(channels); + selectCond->wait(lock, [&completed]() { return completed.load(); }); + + // Select has been woken up by case operation + lockChannels(channels); + removeThreadOnChannelQueues(scope, cases); + + if (caseToExecute == -1) { + // Recursively poll cases, since we were woken up by a channel close + // TODO(thuan): Need to test if this is a valid case + unlockChannels(channels); + return pollCases(scope, cases, channels); + } + } + + // At this point, caseToExecute != -1, and we can proceed with executing + // the case block + unlockChannels(channels); + + return caseToExecute; + } + + void lockChannels(std::vector chs) const { + std::vector::iterator it = chs.begin(); + while (it != chs.end()) { + framework::ChannelHolder *ch = *it; + ch->Lock(); + ++it; + } + } + + void unlockChannels(std::vector chs) const { + std::vector::reverse_iterator it = chs.rbegin(); + while (it != chs.rend()) { + framework::ChannelHolder *ch = *it; + ch->Unlock(); + ++it; + } + } + + void pushThreadOnChannelQueues( + const framework::Scope *scope, + std::vector> *cases, + std::shared_ptr rCond, + std::atomic &caseToExecute, std::atomic &completed, + std::recursive_mutex &callbackMutex) const { + std::vector>::iterator it = cases->begin(); + while (it != cases->end()) { + std::shared_ptr c = *it; + + auto chVar = scope->FindVar(c->channelName); + framework::ChannelHolder *ch = + chVar->GetMutable(); + + std::function cb = + [&caseToExecute, &completed, &callbackMutex, + c](framework::ChannelAction channelAction) { + std::lock_guard lock{callbackMutex}; + + bool canProcess = false; + if (!completed) { + // If the channel wasn't closed, we set the caseToExecute index + // as this current case + if (channelAction != framework::ChannelAction::CLOSE) { + caseToExecute = c->caseIndex; + } + // This will allow our conditional variable to break out of wait + completed = true; + canProcess = true; + } + + return canProcess; + }; + + switch (c->caseType) { + case SelectOpCaseType::SEND: { + auto chOutputVar = scope->FindVar(c->varName); + concurrency::ChannelAddToSendQ(ch, this, chOutputVar, rCond, cb); + break; + } + case SelectOpCaseType::RECEIVE: { + auto chOutputVar = scope->FindVar(c->varName); + concurrency::ChannelAddToReceiveQ(ch, this, chOutputVar, rCond, cb); + break; + } + default: + break; + } + ++it; + } + } + + void removeThreadOnChannelQueues( + const framework::Scope *scope, + std::vector> *cases) const { + std::vector>::iterator it = cases->begin(); + while (it != cases->end()) { + std::shared_ptr c = *it; + + auto chVar = scope->FindVar(c->channelName); + framework::ChannelHolder *ch = + chVar->GetMutable(); + switch (c->caseType) { + case SelectOpCaseType::SEND: { + ch->RemoveFromSendQ(this); + break; + } + case SelectOpCaseType::RECEIVE: { + ch->RemoveFromReceiveQ(this); + break; + } + default: + break; + } + ++it; + } + } +}; + +class SelectOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SelectOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput(kX, + "A set of variables, which are required by operators inside the " + "cases of Select Op") + .AsDuplicable(); + AddInput(kCaseToExecute, + "(Int) The variable the sets the index of the case to execute, " + "after evaluating the channels being sent to and received from") + .AsDuplicable(); + AddAttr>(kCases, + "(String vector) Serialized list of" + "all cases in the select op. Each" + "case is serialized as: " + "',,,'" + "where type is 0 for default, 1 for" + "send, and 2 for receive" + "No channel and values are needed for" + "default cases."); + AddAttr(kCasesBlock, + "The cases block inside select_op"); + AddComment(R"DOC( +)DOC"); + } +}; + +// TODO(thuan): Implement Gradient Operator for SELECT_OP + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(select, paddle::operators::SelectOp, + paddle::framework::EmptyGradOpMaker, + paddle::operators::SelectOpMaker); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index dcde08632a6bb..fcea282204850 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -35,7 +35,7 @@ from distribute_transpiler import DistributeTranspiler from distribute_transpiler_simple import SimpleDistributeTranspiler from concurrency import (Go, make_channel, channel_send, channel_recv, - channel_close) + channel_close, Select) import clip from memory_optimization_transpiler import memory_optimize, release_memory import profiler diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py index dec224fc886cd..535e881c42f67 100644 --- a/python/paddle/fluid/concurrency.py +++ b/python/paddle/fluid/concurrency.py @@ -12,17 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from layers.control_flow import BlockGuard +from layers.control_flow import BlockGuard, Select from layer_helper import LayerHelper, unique_name from layers import fill_constant import core __all__ = [ - 'Go', - 'make_channel', - 'channel_send', - 'channel_recv', - 'channel_close', + 'Go', 'make_channel', 'channel_send', 'channel_recv', 'channel_close', + 'Select' ] @@ -198,7 +195,7 @@ def channel_recv(channel, return_value): ch = fluid.make_channel(dtype='int32', capacity=10) with fluid.Go(): - returned_value = fluid.channel_recv(ch, 'int32') + returned_value, return_status = fluid.channel_recv(ch, 'int32') # Code to send data through the channel. """ diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d14d6349b1bcf..70ecffd910a46 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -487,7 +487,7 @@ def find_name(var_list, name): 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'ncclInit', 'channel_create', 'channel_close', - 'channel_send', 'channel_recv' + 'channel_send', 'channel_recv', 'select' } if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 1bb1aa30ee101..02cd0a05a11d8 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -16,7 +16,7 @@ from layer_function_generator import autodoc from tensor import assign, fill_constant from .. import core -from ..framework import Program, Variable, Operator +from ..framework import Program, Variable, Operator, Block from ..layer_helper import LayerHelper, unique_name from ops import logical_and, logical_not, logical_or @@ -29,6 +29,7 @@ 'WhileGuard', 'While', 'Switch', + 'Select', 'lod_rank_table', 'max_sequence_len', 'topk', @@ -1211,6 +1212,186 @@ def __exit__(self, exc_type, exc_val, exc_tb): return True +class SelectCase(object): + DEFAULT = 0 + SEND = 1 + RECEIVE = 2 + + def __init__(self, + case_idx, + case_to_execute, + channel_action_fn=None, + channel=None, + value=None): + self.helper = LayerHelper('conditional_block') + self.main_program = self.helper.main_program + self.is_scalar_condition = True + + self.case_to_execute = case_to_execute + self.idx = case_idx + + # Since we aren't going to use the `channel_send` or `channel_recv` + # functions directly, we just need to capture the name. + self.action = (self.SEND + if channel_action_fn.__name__ == ('channel_send') else + self.RECEIVE) if channel_action_fn else (self.DEFAULT) + self.value = value + self.channel = channel + + def __enter__(self): + self.block = self.main_program.create_block() + + def construct_op(self): + main_program = self.helper.main_program + cases_block = main_program.current_block() + + inner_outputs = set() + input_set = set() + params = set() + + for op in self.block.ops: + # Iterate over all operators, get all the inputs + # and add as input to the SelectCase operator. + for iname in op.input_names: + for in_var_name in op.input(iname): + if in_var_name not in inner_outputs: + input_set.add(in_var_name) + + for oname in op.output_names: + for out_var_name in op.output(oname): + inner_outputs.add(out_var_name) + + param_list = [ + cases_block.var(each_name) for each_name in params + if each_name not in input_set + ] + + # Iterate over all operators, get all the outputs + # add to the output list of SelectCase operator only if + # they exist in the parent block. + out_vars = [] + for inner_out_name in inner_outputs: + if inner_out_name in cases_block.vars: + out_vars.append(cases_block.var(inner_out_name)) + + # First, create an op that will determine whether or not this is the + # conditional variable to execute. + should_execute_block = equal( + fill_constant( + shape=[1], dtype=core.VarDesc.VarType.INT32, value=self.idx), + self.case_to_execute) + + step_scope = cases_block.create_var( + type=core.VarDesc.VarType.STEP_SCOPES) + + cases_block.append_op( + type='conditional_block', + inputs={'X': [should_execute_block], + 'Params': param_list}, + outputs={'Out': out_vars, + 'Scope': [step_scope]}, + attrs={ + 'sub_block': self.block, + 'is_scalar_condition': self.is_scalar_condition + }) + + return '%s,%s,%s,%s' % (self.idx, self.action, self.channel.name + if self.channel else '', self.value.name + if self.value else '') + + def __exit__(self, exc_type, exc_val, exc_tb): + self.main_program.rollback() + if exc_type is not None: + return False # re-raise exception + return True + + +class Select(BlockGuard): + def __init__(self, name=None): + self.helper = LayerHelper('select', name=name) + self.cases = [] + + super(Select, self).__init__(self.helper.main_program) + self.case_to_execute = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.INT32, value=-1) + + def __enter__(self): + super(Select, self).__enter__() + return self + + def case(self, channel_action_fn, channel, value): + """Create a new block for this condition. + """ + select_case = SelectCase( + len(self.cases), self.case_to_execute, channel_action_fn, channel, + value) + + self.cases.append(select_case) + + return select_case + + def default(self): + """Create a default case block for this condition. + """ + default_case = SelectCase(len(self.cases), self.case_to_execute) + + self.cases.append(default_case) + + return default_case + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + return False + + # Create a select op and another block to wrap its + # case blocks. + select_block = self.helper.main_program.current_block() + parent_block = self.helper.main_program.block(select_block.parent_idx) + + # Construct each case op, inside the newly created select block. + serialized_cases = [] + for case in self.cases: + serialized_cases.append(case.construct_op()) + + intermediate = set() + params = set() + + for case_block in select_block.ops: + if case_block.attrs and 'sub_block' in case_block.attrs: + for each_op in case_block.attrs['sub_block'].ops: + assert isinstance(each_op, Operator) + for iname in each_op.input_names: + for in_var_name in each_op.input(iname): + if in_var_name not in intermediate: + params.add(in_var_name) + + for oname in each_op.output_names: + for out_var_name in each_op.output(oname): + intermediate.add(out_var_name) + + # TODO(varunarora): Figure out if defining output is needed. + out_list = [ + parent_block.var(var_name) for var_name in parent_block.vars + if var_name in intermediate + ] + + X = [select_block.var_recursive(x_name) for x_name in params] + + # Needs to be used by `equal` inside the cases block. + X.append(self.case_to_execute) + + # Construct the select op. + parent_block.append_op( + type='select', + inputs={'X': X, + 'case_to_execute': self.case_to_execute}, + attrs={'sub_block': select_block, + 'cases': serialized_cases}, + outputs={}) + + return super(Select, self).__exit__(exc_type, exc_val, exc_tb) + + class IfElseBlockGuard(object): def __init__(self, is_true, ifelse): if not isinstance(ifelse, IfElse): diff --git a/python/paddle/fluid/tests/test_concurrency.py b/python/paddle/fluid/tests/test_concurrency.py index 9f7bf63c5e017..3aa51610cd954 100644 --- a/python/paddle/fluid/tests/test_concurrency.py +++ b/python/paddle/fluid/tests/test_concurrency.py @@ -15,9 +15,9 @@ import unittest import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.fluid import framework, unique_name +from paddle.fluid import framework, unique_name, layer_helper from paddle.fluid.executor import Executor -from paddle.fluid.layers import fill_constant +from paddle.fluid.layers import fill_constant, assign, While, elementwise_add, Print class TestRoutineOp(unittest.TestCase): @@ -86,8 +86,7 @@ def test_daisy_chain(self): self.assertEqual(leftmost_data[0][0], n + 1) def _create_one_dim_tensor(self, value): - one_dim_tensor = fill_constant( - shape=[1], dtype=core.VarDesc.VarType.INT64, value=value) + one_dim_tensor = fill_constant(shape=[1], dtype='int', value=value) one_dim_tensor.stop_gradient = True return one_dim_tensor @@ -95,6 +94,129 @@ def _create_tensor(self, name, type, dtype): return framework.default_main_program().current_block().create_var( name=unique_name.generate(name), type=type, dtype=dtype) + def _create_persistable_tensor(self, name, type, dtype): + return framework.default_main_program().current_block().create_var( + name=unique_name.generate(name), + type=type, + dtype=dtype, + persistable=True) + + def test_select(self): + with framework.program_guard(framework.Program()): + ch1 = fluid.make_channel( + dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1) + + result1 = self._create_tensor('return_value', + core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.FP64) + + input_value = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.FP64, value=10) + + with fluid.Select() as select: + with select.case(fluid.channel_send, ch1, input_value): + # Execute something. + pass + + with select.default(): + pass + + # This should not block because we are using a buffered channel. + result1, status = fluid.channel_recv(ch1, result1) + fluid.channel_close(ch1) + + cpu = core.CPUPlace() + exe = Executor(cpu) + + result = exe.run(fetch_list=[result1]) + self.assertEqual(result[0][0], 10) + + def test_fibonacci(self): + """ + Mimics Fibonacci Go example: https://tour.golang.org/concurrency/5 + """ + with framework.program_guard(framework.Program()): + quit_ch_input_var = self._create_persistable_tensor( + 'quit_ch_input', core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.INT32) + quit_ch_input = fill_constant( + shape=[1], + dtype=core.VarDesc.VarType.INT32, + value=0, + out=quit_ch_input_var) + + result = self._create_persistable_tensor( + 'result', core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.INT32) + fill_constant( + shape=[1], + dtype=core.VarDesc.VarType.INT32, + value=0, + out=result) + + x = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) + y = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.INT32, value=1) + + while_cond = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.BOOL, value=True) + + while_false = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.BOOL, value=False) + + x_tmp = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) + + def fibonacci(channel, quit_channel): + while_op = While(cond=while_cond) + with while_op.block(): + result2 = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) + x_to_send_tmp = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) + + # TODO(abhinav): Need to perform copy when doing a channel send. + # Once this is complete, we can remove these lines + assign(input=x, output=x_to_send_tmp) + + with fluid.Select() as select: + with select.case(fluid.channel_send, channel, + x_to_send_tmp): + assign(input=x, output=x_tmp) + assign(input=y, output=x) + assign(elementwise_add(x=x_tmp, y=y), output=y) + + with select.case(fluid.channel_recv, quit_channel, + result2): + # Quit + helper = layer_helper.LayerHelper('assign') + helper.append_op( + type='assign', + inputs={'X': [while_false]}, + outputs={'Out': [while_cond]}) + + ch1 = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR) + quit_ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR) + + with fluid.Go(): + for i in xrange(10): + fluid.channel_recv(ch1, result) + Print(result) + + fluid.channel_send(quit_ch, quit_ch_input) + + fibonacci(ch1, quit_ch) + + fluid.channel_close(ch1) + fluid.channel_close(quit_ch) + + cpu = core.CPUPlace() + exe = Executor(cpu) + + exe_result = exe.run(fetch_list=[result]) + self.assertEqual(exe_result[0][0], 34) + if __name__ == '__main__': unittest.main()