Implement Select OP (#9088)
* Fix old documentation for channel_recv

* Initial design of CSP select

* Redesign channel implementation for Select Op

* Remove unnecessary header

* Initial check-in of select op; it currently reads all the conditional_block ops in the cases block and also pulls out all channels involved in the select

* Init Python select op API

* Fix Python select bug when checking whether an op creates a block

* Add case_to_execute as (a) an input to select and (b) one of the inputs passed into the select op

* Add additional code for select op

* Init Fibonacci test from Python

* Implement Fibonacci sequence test

* Update fib unit test

* Improve select test cases

* Shorten non-PEP-8-compliant lines

* Add methods on channel needed by select op

* Fix compile issues, finish implementation, still need to debug code

* Fix issue with Fibonacci test, it works now!

* Change QueueMessage callback to take in a ChannelAction enum, fix select unit test

* Fix case attributes

* Fix issue with select control flow

* Make cases (previously attributes on each select case's conditional_block) attributes of select

* Use class constants for type of channel

* Change select op to take in a "cases" attribute (encoded as shown in the sketch after this list)

* Return a boolean from the select callback function to tell the Channel whether this RECV or SEND should be executed

* Improve attributes and inputs comments on select op

* Fix issues with python unit test

* Assert Fibonacci final output

* Fix issue when channel name / channel var is null for "default" case in select op

* Assert base select test output

* Make QueueMessage use a shared pointer and modify the order of the callback

* Fix the order in which the callback is called

* Move channel utility methods to paddle/fluid/operators/concurrency/channel_util

* Create channel_util and move channel util methods

* Fix crash when calling select_op

* Fix deadlock

* Fix issue of channel destructor deadlock

* Fix precommit issues

* Revert accidentally checked-in changes to beam_search_op

* Fix dependency issue in concurrency cmake

* Add device_context dependency for concurrency target
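For context on the cases encoding above: judging from the strings the new unit test builds (e.g. "0,1," + dataChanName + ",fibXToSend"), each entry of the "cases" attribute encodes case_index,case_type,channel_name,var_name, where type 1 appears to mean a channel send and type 2 a receive; the "default" case mentioned above carries no channel or variable. A minimal sketch of a decoder for that format, as an illustration only (the enum, struct, and function names are assumptions, not the actual select_op code):

#include <sstream>
#include <string>

// Hypothetical decoder for one "cases" attribute entry such as
// "0,1,Channel,fibXToSend". SEND = 1 and RECV = 2 follow the strings the
// unit test below builds; everything else here is an illustrative guess.
enum CaseType { SEND = 1, RECV = 2 };

struct SelectCase {
  int index;            // which conditional_block in the cases sub-block
  CaseType type;        // send to, or receive from, the channel
  std::string channel;  // channel variable name (empty for a default case)
  std::string var;      // variable to send, or to receive into
};

SelectCase ParseCaseConfig(const std::string &config) {
  std::istringstream ss(config);
  std::string field;
  SelectCase c;
  std::getline(ss, field, ',');
  c.index = std::stoi(field);
  std::getline(ss, field, ',');
  c.type = static_cast<CaseType>(std::stoi(field));
  std::getline(ss, c.channel, ',');
  std::getline(ss, c.var, ',');
  return c;
}

The case_to_execute input listed above would then carry the index of whichever case the select op picks, and the callback's boolean return tells the Channel whether to go through with that case's send or receive.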
cs2be authored and abhinavarora committed Mar 15, 2018
1 parent 45073b7 commit 1e4c504
Showing 16 changed files with 1,096 additions and 91 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/framework/CMakeLists.txt
@@ -103,4 +103,5 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
 cc_test(channel_test SRCS channel_test.cc)
 cc_test(tuple_test SRCS tuple_test.cc )
 cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
-        channel_send_op channel_recv_op sum_op elementwise_add_op executor proto_desc)
+        channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
+        conditional_block_op while_op assign_op print_op executor proto_desc)
30 changes: 16 additions & 14 deletions paddle/fluid/framework/channel.h
@@ -162,24 +162,12 @@ class ChannelHolder {
     }
   }

-  template <typename T>
   void RemoveFromSendQ(const void* referrer) {
-    if (IsInitialized()) {
-      Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
-      if (channel != nullptr) {
-        channel->RemoveFromSendQ(referrer);
-      }
-    }
+    if (IsInitialized()) holder_->RemoveFromSendQ(referrer);
   }

-  template <typename T>
   void RemoveFromReceiveQ(const void* referrer) {
-    if (IsInitialized()) {
-      Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
-      if (channel != nullptr) {
-        channel->RemoveFromReceiveQ(referrer);
-      }
-    }
+    if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer);
   }

   inline bool IsInitialized() const { return holder_ != nullptr; }
@@ -201,6 +189,8 @@ class ChannelHolder {
virtual bool IsClosed() = 0;
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual void RemoveFromSendQ(const void* referrer) = 0;
virtual void RemoveFromReceiveQ(const void* referrer) = 0;
virtual void Close() = 0;
virtual void Lock() = 0;
virtual void Unlock() = 0;
@@ -238,6 +228,18 @@ class ChannelHolder {
return false;
}

virtual void RemoveFromSendQ(const void* referrer) {
if (channel_) {
channel_->RemoveFromSendQ(referrer);
}
}

virtual void RemoveFromReceiveQ(const void* referrer) {
if (channel_) {
channel_->RemoveFromReceiveQ(referrer);
}
}

virtual void Close() {
if (channel_) channel_->Close();
}
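The net effect of this file's changes: RemoveFromSendQ and RemoveFromReceiveQ move from templated ChannelHolder methods, which had to static_cast the erased pointer back to Channel<T>, onto the type-erased Placeholder interface, so callers no longer need to know T. A minimal standalone sketch of that type-erasure pattern, under illustrative names rather than the actual Paddle classes:

#include <memory>

template <typename T>
struct Channel {
  void RemoveFromSendQ(const void* /*referrer*/) { /* drop queued sender */ }
};

// Erased interface: one non-template virtual per operation the holder
// needs to forward.
struct Placeholder {
  virtual ~Placeholder() = default;
  virtual void RemoveFromSendQ(const void* referrer) = 0;
};

template <typename T>
struct PlaceholderImpl : Placeholder {
  explicit PlaceholderImpl(Channel<T>* ch) : channel_(ch) {}
  void RemoveFromSendQ(const void* referrer) override {
    if (channel_) channel_->RemoveFromSendQ(referrer);
  }
  Channel<T>* channel_;
};

struct Holder {
  // No template argument or static_cast at the call site; dynamic
  // dispatch finds the right Channel<T>.
  void RemoveFromSendQ(const void* referrer) {
    if (holder_) holder_->RemoveFromSendQ(referrer);
  }
  std::unique_ptr<Placeholder> holder_;
};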
6 changes: 2 additions & 4 deletions paddle/fluid/framework/channel_impl.h
@@ -151,7 +151,7 @@ bool ChannelImpl<T>::Send(T *item) {
       // We do not care about notifying other
       // because they would have been notified
       // by the executed select case.
-      return Send(item);
+      return send_return(Send(item));

     // Wake up the blocked process and unlock
     m->Notify();
@@ -214,7 +214,7 @@ bool ChannelImpl<T>::Receive(T *item) {
       // We do not care about notifying other
       // because they would have been notified
       // by the executed select case.
-      return Receive(item);
+      return recv_return(Receive(item));

     // Wake up the blocked process and unlock
     m->Notify();
@@ -331,7 +331,6 @@ void ChannelImpl<T>::RemoveFromSendQ(const void *referrer) {

     if (sendMsg->referrer == referrer) {
       it = sendq.erase(it);
-      send_ctr--;
     } else {
       ++it;
     }
@@ -347,7 +346,6 @@ void ChannelImpl<T>::RemoveFromReceiveQ(const void *referrer) {

     if (recvMsg->referrer == referrer) {
       it = recvq.erase(it);
-      recv_ctr--;
     } else {
       ++it;
     }
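Both early returns above now funnel through send_return / recv_return, and the manual counter decrements disappear from RemoveFromSendQ / RemoveFromReceiveQ. Those helpers are outside this hunk; what follows is a minimal sketch of the likely shape, assuming they centralize the in-flight bookkeeping that the destructor waits on (the counter names come from the diff and the "channel destructor deadlock" fix in the commit message, but the bodies here are an assumption):

#include <condition_variable>
#include <mutex>

// Not the actual ChannelImpl: a standalone sketch of the single-exit-path
// pattern that keeps the destructor from deadlocking.
class ChannelSketch {
 public:
  void BeginSend() {
    std::lock_guard<std::mutex> lock(mu_);
    ++send_ctr_;  // a Send is now in flight
  }

  bool send_return(bool value) {
    std::lock_guard<std::mutex> lock(mu_);
    --send_ctr_;                    // this Send has finished
    destructor_cond_.notify_all();  // let a waiting destructor re-check
    return value;
  }

  ~ChannelSketch() {
    // Destruction blocks until every in-flight Send has exited through
    // send_return(); an early return that skipped the decrement would
    // deadlock here.
    std::unique_lock<std::mutex> lock(mu_);
    destructor_cond_.wait(lock, [this] { return send_ctr_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable destructor_cond_;
  int send_ctr_ = 0;
};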
208 changes: 189 additions & 19 deletions paddle/fluid/framework/concurrency_test.cc
@@ -19,14 +19,19 @@ limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"

USE_NO_KERNEL_OP(go);
USE_NO_KERNEL_OP(channel_close);
USE_NO_KERNEL_OP(channel_create);
USE_NO_KERNEL_OP(channel_recv);
USE_NO_KERNEL_OP(channel_send);
USE_NO_KERNEL_OP(elementwise_add);
USE_NO_KERNEL_OP(select);
USE_NO_KERNEL_OP(conditional_block);
USE_NO_KERNEL_OP(equal);
USE_NO_KERNEL_OP(assign);
USE_NO_KERNEL_OP(while);
USE_NO_KERNEL_OP(print);

namespace f = paddle::framework;
namespace p = paddle::platform;
@@ -35,27 +40,15 @@ namespace paddle {
 namespace framework {

 template <typename T>
-void CreateIntVariable(Scope &scope, p::CPUPlace &place, std::string name,
-                       T value) {
-  // Create LoDTensor<int> of dim [1,1]
+LoDTensor *CreateVariable(Scope &scope, p::CPUPlace &place, std::string name,
+                          T value) {
+  // Create LoDTensor<int> of dim [1]
   auto var = scope.Var(name);
   auto tensor = var->GetMutable<LoDTensor>();
-  tensor->Resize({1, 1});
+  tensor->Resize({1});
   T *expect = tensor->mutable_data<T>(place);
   expect[0] = value;
-}
-
-void InitTensorsInScope(Scope &scope, p::CPUPlace &place) {
-  p::CPUDeviceContext ctx(place);
-
-  // Create channel variable
-  scope.Var("Channel");
-
-  // Create Variables, x0 will be put into channel,
-  // result will be pulled from channel
-  CreateIntVariable(scope, place, "Status", false);
-  CreateIntVariable(scope, place, "x0", 99);
-  CreateIntVariable(scope, place, "result", 0);
+  return tensor;
 }

 void AddOp(const std::string &type, const VariableNameMap &inputs,
@@ -73,12 +66,116 @@ void AddOp(const std::string &type, const VariableNameMap &inputs,
op->SetAttrMap(attrs);
}

void AddCase(ProgramDesc *program, Scope *scope, p::CPUPlace *place,
BlockDesc *casesBlock, int caseId, int caseType,
std::string caseChannel, std::string caseVarName,
std::function<void(BlockDesc *, Scope *)> func) {
std::string caseCondName = std::string("caseCond") + std::to_string(caseId);
std::string caseCondXVarName =
std::string("caseCondX") + std::to_string(caseId);

BlockDesc *caseBlock = program->AppendBlock(*casesBlock);
func(caseBlock, scope);

CreateVariable(*scope, *place, caseCondName, false);
CreateVariable(*scope, *place, caseCondXVarName, caseId);
CreateVariable(*scope, *place, caseVarName, caseId);

scope->Var("step_scope");

AddOp("equal", {{"X", {caseCondXVarName}}, {"Y", {"caseToExecute"}}},
{{"Out", {caseCondName}}}, {}, casesBlock);

AddOp("conditional_block", {{"X", {caseCondName}}, {"Params", {}}},
{{"Out", {}}, {"Scope", {"step_scope"}}},
{{"sub_block", caseBlock}, {"is_scalar_condition", true}}, casesBlock);
}

void AddFibonacciSelect(Scope *scope, p::CPUPlace *place, ProgramDesc *program,
BlockDesc *parentBlock, std::string dataChanName,
std::string quitChanName) {
BlockDesc *whileBlock = program->AppendBlock(*parentBlock);

CreateVariable(*scope, *place, "whileExitCond", true);
CreateVariable(*scope, *place, "caseToExecute", -1);
CreateVariable(*scope, *place, "case1var", 0);

CreateVariable(*scope, *place, "xtemp", 0);

// TODO(thuan): Need to create fibXToSend, since channel send moves the actual
// data,
// which causes the data to be no longer accessible to do the fib calculation
// TODO(abhinav): Change channel send to do a copy instead of a move!
CreateVariable(*scope, *place, "fibXToSend", 0);

CreateVariable(*scope, *place, "fibX", 0);
CreateVariable(*scope, *place, "fibY", 1);
CreateVariable(*scope, *place, "quitVar", 0);

BlockDesc *casesBlock = program->AppendBlock(*whileBlock);
std::function<void(BlockDesc * caseBlock)> f = [](BlockDesc *caseBlock) {};

// TODO(thuan): Remove this once we change channel send to do a copy instead
// of move
AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"fibXToSend"}}}, {}, whileBlock);

// Case 0: Send to dataChanName
std::function<void(BlockDesc * caseBlock, Scope * scope)> case0Func = [&](
BlockDesc *caseBlock, Scope *scope) {
AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"xtemp"}}}, {}, caseBlock);
AddOp("assign", {{"X", {"fibY"}}}, {{"Out", {"fibX"}}}, {}, caseBlock);
AddOp("elementwise_add", {{"X", {"xtemp"}}, {"Y", {"fibY"}}},
{{"Out", {"fibY"}}}, {}, caseBlock);
};
AddCase(program, scope, place, casesBlock, 0, 1, dataChanName, "fibXToSend",
case0Func);
std::string case0Config =
std::string("0,1,") + dataChanName + std::string(",fibXToSend");

// Case 1: Receive from quitChanName
std::function<void(BlockDesc * caseBlock, Scope * scope)> case2Func = [&](
BlockDesc *caseBlock, Scope *scope) {
// Exit the while loop after we receive from quit channel.
// We assign a false to "whileExitCond" variable, which will
// break out of while_op loop
CreateVariable(*scope, *place, "whileFalse", false);
AddOp("assign", {{"X", {"whileFalse"}}}, {{"Out", {"whileExitCond"}}}, {},
caseBlock);
};
AddCase(program, scope, place, casesBlock, 1, 2, quitChanName, "quitVar",
case2Func);
std::string case1Config =
std::string("1,2,") + quitChanName + std::string(",quitVar");

// Select block
AddOp("select", {{"X", {dataChanName, quitChanName}},
{"case_to_execute", {"caseToExecute"}}},
{}, {{"sub_block", casesBlock},
{"cases", std::vector<std::string>{case0Config, case1Config}}},
whileBlock);

scope->Var("stepScopes");
AddOp("while",
{{"X", {dataChanName, quitChanName}}, {"Condition", {"whileExitCond"}}},
{{"Out", {}}, {"StepScopes", {"stepScopes"}}},
{{"sub_block", whileBlock}}, parentBlock);
}

TEST(Concurrency, Go_Op) {
Scope scope;
p::CPUPlace place;

   // Initialize scope variables
-  InitTensorsInScope(scope, place);
+  p::CPUDeviceContext ctx(place);
+
+  // Create channel variable
+  scope.Var("Channel");
+
+  // Create Variables, x0 will be put into channel,
+  // result will be pulled from channel
+  CreateVariable(scope, place, "Status", false);
+  CreateVariable(scope, place, "x0", 99);
+  CreateVariable(scope, place, "result", 0);

framework::Executor executor(place);
ProgramDesc program;
@@ -118,5 +215,78 @@ TEST(Concurrency, Go_Op) {
auto *finalData = tensor.data<int>();
EXPECT_EQ(finalData[0], 99);
}

/**
* This test implements the fibonacci function using go_op and select_op
*/
TEST(Concurrency, Select) {
Scope scope;
p::CPUPlace place;

// Initialize scope variables
p::CPUDeviceContext ctx(place);

CreateVariable(scope, place, "Status", false);
CreateVariable(scope, place, "result", 0);
CreateVariable(scope, place, "currentXFib", 0);

framework::Executor executor(place);
ProgramDesc program;
BlockDesc *block = program.MutableBlock(0);

// Create channel OP
std::string dataChanName = "Channel";
scope.Var(dataChanName);
AddOp("channel_create", {}, {{"Out", {dataChanName}}},
{{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);

std::string quitChanName = "Quit";
scope.Var(quitChanName);
AddOp("channel_create", {}, {{"Out", {quitChanName}}},
{{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);

// Create Go Op routine, which loops 10 times over fibonacci sequence
CreateVariable(scope, place, "xReceiveVar", 0);

BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
for (int i = 0; i < 10; ++i) {
AddOp("channel_recv", {{"Channel", {dataChanName}}},
{{"Status", {"Status"}}, {"Out", {"currentXFib"}}}, {}, goOpBlock);
AddOp("print", {{"In", {"currentXFib"}}}, {{"Out", {"currentXFib"}}},
{{"first_n", 100},
{"summarize", -1},
{"print_tensor_name", false},
{"print_tensor_type", true},
{"print_tensor_shape", false},
{"print_tensor_lod", false},
{"print_phase", std::string("FORWARD")},
{"message", std::string("X: ")}},
goOpBlock);
}

CreateVariable(scope, place, "quitSignal", 0);
AddOp("channel_send", {{"Channel", {quitChanName}}, {"X", {"quitSignal"}}},
{{"Status", {"Status"}}}, {}, goOpBlock);

// Create Go Op
AddOp("go", {{"X", {dataChanName, quitChanName}}}, {},
{{"sub_block", goOpBlock}}, block);

AddFibonacciSelect(&scope, &place, &program, block, dataChanName,
quitChanName);

// Create Channel Close Op
AddOp("channel_close", {{"Channel", {dataChanName}}}, {}, {}, block);
AddOp("channel_close", {{"Channel", {quitChanName}}}, {}, {}, block);

executor.Run(program, &scope, 0, true, true);

// After we call executor.run, "result" variable should be equal to 34
// (which is 10 loops through fibonacci sequence)
const LoDTensor &tensor = (scope.FindVar("currentXFib"))->Get<LoDTensor>();
auto *finalData = tensor.data<int>();
EXPECT_EQ(finalData[0], 34);
}

} // namespace framework
} // namespace paddle
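Reading the test above: the cases sub-block built by AddCase behaves like a switch on caseToExecute. The select op picks whichever case is ready, writes that case's index into caseToExecute, and each case's equal + conditional_block pair fires only on a match. In plain C++ the dispatch reduces to roughly the following (RunCases and case_blocks are hypothetical stand-ins, not framework APIs):

#include <functional>
#include <vector>

void RunCases(int case_to_execute,
              const std::vector<std::function<void()>>& case_blocks) {
  for (int i = 0; i < static_cast<int>(case_blocks.size()); ++i) {
    bool cond = (i == case_to_execute);  // the per-case "equal" op
    if (cond) {
      case_blocks[i]();  // that case's "conditional_block" op
    }
  }
}

The asserted value also checks out by hand: the go block receives ten Fibonacci values, 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, so currentXFib holds 34 when the quit signal is sent.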
5 changes: 5 additions & 0 deletions paddle/fluid/operators/CMakeLists.txt
@@ -203,6 +203,11 @@ op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat)

# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
add_subdirectory(concurrency)
op_library(channel_send_op DEPS concurrency)
op_library(channel_recv_op DEPS concurrency)

list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
op_library(${src})
