Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fluid channels should match the semantics of Go Channels #9265

Merged
merged 9 commits into from
Mar 27, 2018
Merged
69 changes: 34 additions & 35 deletions paddle/fluid/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Channel {
public:
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual bool Send(T*) = 0;
virtual void Send(T*) = 0;
virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0;
virtual void Lock() = 0;
Expand Down Expand Up @@ -84,90 +84,89 @@ class ChannelHolder {
}

template <typename T>
bool Send(T* data) {
if (!IsInitialized()) return false;
void Send(T* data) {
PADDLE_ENFORCE_EQ(IsInitialized(), true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe It is better that adding exception information for PADDLE_ENFORCE_EQ,
e.g.
PADDLE_ENFORCE_EQ(IsInitialized(), true, "The channel hasn't been initialized.");

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
// Static cast should be safe because we have ensured that types are same
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Send(data) : false;
PADDLE_ENFORCE_EQ(channel != nullptr, true);
channel->Send(data);
}

template <typename T>
bool Receive(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true);
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Receive(data) : false;
PADDLE_ENFORCE_EQ(channel != nullptr, true);
return channel->Receive(data);
}

bool IsClosed() {
if (IsInitialized()) {
return holder_->IsClosed();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true);
return holder_->IsClosed();
}

bool CanSend() {
if (IsInitialized()) {
return holder_->CanSend();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true);
return holder_->CanSend();
}

bool CanReceive() {
if (IsInitialized()) {
return holder_->CanReceive();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true);
return holder_->CanReceive();
}

void close() {
if (IsInitialized()) holder_->Close();
PADDLE_ENFORCE_EQ(IsInitialized(), true);
holder_->Close();
}

size_t Cap() {
if (IsInitialized()) return holder_->Cap();
return -1;
PADDLE_ENFORCE_EQ(IsInitialized(), true);
return holder_->Cap();
}

void Lock() {
if (IsInitialized()) holder_->Lock();
PADDLE_ENFORCE_EQ(IsInitialized(), true);
holder_->Lock();
}

void Unlock() {
if (IsInitialized()) holder_->Unlock();
PADDLE_ENFORCE_EQ(IsInitialized(), true);
holder_->Unlock();
}

template <typename T>
void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
PADDLE_ENFORCE_EQ(IsInitialized(), true);
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
}

template <typename T>
void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
PADDLE_ENFORCE_EQ(IsInitialized(), true);
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
}

void RemoveFromSendQ(const void* referrer) {
if (IsInitialized()) holder_->RemoveFromSendQ(referrer);
PADDLE_ENFORCE_EQ(IsInitialized(), true);
holder_->RemoveFromSendQ(referrer);
}

void RemoveFromReceiveQ(const void* referrer) {
if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer);
PADDLE_ENFORCE_EQ(IsInitialized(), true);
holder_->RemoveFromReceiveQ(referrer);
}

inline bool IsInitialized() const { return holder_ != nullptr; }
Expand Down
33 changes: 20 additions & 13 deletions paddle/fluid/framework/channel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ChannelImpl : public paddle::framework::Channel<T> {
public:
virtual bool CanSend();
virtual bool CanReceive();
virtual bool Send(T *);
virtual void Send(T *);
virtual bool Receive(T *);
virtual size_t Cap() { return cap_; }
virtual void Lock();
Expand Down Expand Up @@ -76,10 +76,9 @@ class ChannelImpl : public paddle::framework::Channel<T> {
}
};

bool send_return(bool value) {
void send_return() {
send_ctr--;
destructor_cond_.notify_all();
return value;
}

bool recv_return(bool value) {
Expand Down Expand Up @@ -118,15 +117,15 @@ bool ChannelImpl<T>::CanReceive() {
}

template <typename T>
bool ChannelImpl<T>::Send(T *item) {
void ChannelImpl<T>::Send(T *item) {
send_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to explicitly lock after constructor? lock->lock() ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we don't need to do that. The unique lock constructor automatically does that.


// If channel is closed, do nothing
// If channel is closed, throw exception
if (closed_) {
lock.unlock();
// TODO(abhinavarora) Should panic on closed channel
return send_return(false);
send_return();
PADDLE_THROW("Cannot send on closed channel");
}

// If there is a receiver, directly pass the value we want
Expand All @@ -143,20 +142,24 @@ bool ChannelImpl<T>::Send(T *item) {
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
if (do_send)
*(m->data) = std::move(*item);
else
else {
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Send function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return send_return(Send(item));
Send(item);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to "lock.unlock();" here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are right, it makes sense to release the lock there. With the new semantics if the nested method call leads to an exception then the outer lock will be held forever.

send_return();
return;
}

// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return send_return(true);
send_return();
return;
}

// Unbuffered channel will always bypass this
Expand All @@ -167,16 +170,20 @@ bool ChannelImpl<T>::Send(T *item) {
buf_.push_back(std::move(*item));
// Release lock and return true
lock.unlock();
return send_return(true);
send_return();
return;
}

// Block on channel, because some receiver will complete
// the operation for us
auto m = std::make_shared<QueueMessage>(item);
sendq.push_back(m);
m->Wait(lock);
// TODO(abhinavarora) Should panic on closed channel
return send_return(!m->chan_closed);
if (m->chan_closed) {
send_return();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should unlock before throwing exception

PADDLE_THROW("Cannot send on closed channel");
}
send_return();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to unlock here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, the lock needs to be unlocked here. Thank you for pointing this out.

}

template <typename T>
Expand Down
Loading