Skip to content

Commit

Permalink
Merge pull request #8181 from reyoung/feature/add_exception_for_threa…
Browse files Browse the repository at this point in the history
…d_pool

Add RunAndGetException in threadpool
  • Loading branch information
reyoung committed Feb 6, 2018
2 parents a07b751 + c966c28 commit 8499460
Showing 1 changed file with 40 additions and 4 deletions.
44 changes: 40 additions & 4 deletions paddle/framework/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ limitations under the License. */
#include <queue>
#include <thread>
#include <vector>

#include "glog/logging.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN

namespace paddle {
Expand All @@ -31,7 +32,7 @@ namespace framework {
// number of threads.
class ThreadPool {
public:
typedef std::packaged_task<void()> Task;
using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;

// Returns the singleton of ThreadPool.
static ThreadPool* GetInstance();
Expand All @@ -52,9 +53,28 @@ class ThreadPool {
// std::future::wait().
template <typename Callback>
std::future<void> Run(Callback fn) {
auto f = this->RunAndGetException(fn);
return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
}

template <typename Callback>
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
Callback fn) {
std::unique_lock<std::mutex> lock(mutex_);
Task task(std::bind(fn));
std::future<void> f = task.get_future();
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
try {
fn();
return nullptr;
} catch (platform::EnforceNotMet ex) {
return std::unique_ptr<platform::EnforceNotMet>(
new platform::EnforceNotMet(ex));
} catch (...) {
LOG(FATAL)
<< "Unexpected exception is catched in thread pool. All "
"throwable exception in Fluid should be an EnforceNotMet.";
}
});
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task));
lock.unlock();
scheduled_.notify_one();
Expand All @@ -65,6 +85,22 @@ class ThreadPool {
void Wait();

private:
struct ExceptionHandler {
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
explicit ExceptionHandler(
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
: future_(std::move(f)) {}
void operator()() const {
auto ex = this->future_.get();
if (ex != nullptr) {
LOG(FATAL) << "The exception is thrown inside the thread pool. You "
"should use RunAndGetException to handle the exception.\n"
"The default exception handler is LOG(FATAL)."
<< ex->what();
}
}
};

DISABLE_COPY_AND_ASSIGN(ThreadPool);

explicit ThreadPool(int num_threads);
Expand Down

0 comments on commit 8499460

Please sign in to comment.