Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #68 from tqchen/master
Browse files Browse the repository at this point in the history
Separate Dependency Tracking and Execution Policy in Engine, BugFix
  • Loading branch information
tqchen committed Sep 14, 2015
2 parents 7673db0 + ab64cc3 commit c18ea83
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 135 deletions.
125 changes: 45 additions & 80 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
/*!
* Copyright (c) 2015 by Contributors
* \file threaded_engine.cc
* \brief implements base threaded engine.
* \author Yutian Li
*/
#include "threaded_engine.h"
#include <dmlc/logging.h>
Expand Down Expand Up @@ -81,29 +84,35 @@ void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) {
template <typename Dispatcher>
bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
VersionedVarBlock *old_pending_write, *end_of_dispatch_chain;
int num_reads;
{
// this is lock scope
std::lock_guard<std::mutex> lock{m_};
assert(ready_to_read_ == false);
// detach pending write
old_pending_write = pending_write_;
pending_write_ = nullptr;
// search for chains to trigger
VersionedVarBlock *p = old_pending_write->next;
assert(num_pending_reads_ == 0);
num_reads = 0;
while (p->next != nullptr && p->write == false) {
++num_pending_reads_;
p = p->next;
++num_reads;
}
num_pending_reads_ = num_reads;
// mark end of dispatch chain
end_of_dispatch_chain = p;

if (p->next == nullptr) {
ready_to_read_ = true;
pending_write_ = nullptr;
assert(p->trigger == nullptr);
assert(p->write ==false);
} else {
assert(p->write == true);
pending_write_ = p;
if (num_pending_reads_ == 0) {
if (--pending_write_->trigger->wait == 0) {
dispatcher(pending_write_->trigger);
}
}
}
}
// this is outside of lock scope
Expand All @@ -122,14 +131,9 @@ bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
}
auto prev = cur_head;
cur_head = cur_head->next;
assert(cur_head != nullptr);
VersionedVarBlock::Delete(prev);
}
// trigger pending write, if any
if (pending_write_ != nullptr && num_reads == 0) {
if (--pending_write_->trigger->wait == 0) {
dispatcher(pending_write_->trigger);
}
}
return false;
}

Expand All @@ -143,33 +147,33 @@ bool ThreadedVar::ready_to_read() {
return ready_to_read_;
}

ThreadedEngine::ThreadedEngine()
: pending_{0},
thread_pool_{[this]() { ThreadWorker(&task_queue_); }},
io_thread_pool_{[this]() { ThreadWorker(&io_task_queue_); }} {}

ThreadedEngine::~ThreadedEngine() noexcept(false) {
task_queue_.SignalForKill();
io_task_queue_.SignalForKill();
}

// implementation of threaded engine
ThreadedVar* ThreadedEngine::NewVariable() {
return ThreadedVar::New(VersionedVarBlock::New());
}

ThreadedOpr* ThreadedEngine::NewOperator(
ThreadedEngine::AsyncFn fn, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars, FnProperty prop) {
ThreadedEngine::AsyncFn fn,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop) {
auto ret = ThreadedOpr::New();
ret->fn = fn;
ret->prop = prop;
ret->const_vars.resize(const_vars.size());
ret->mutable_vars.resize(mutable_vars.size());
std::transform(const_vars.begin(), const_vars.end(), ret->const_vars.begin(),
ThreadedVar::CastFromBase);
std::transform(const_vars.begin(), const_vars.end(),
ret->const_vars.begin(), ThreadedVar::CastFromBase);
std::transform(mutable_vars.begin(), mutable_vars.end(),
ret->mutable_vars.begin(), ThreadedVar::CastFromBase);
#if ENGINE_DEBUG
if (ENGINE_DEBUG != 0) {
CheckDuplicate(const_vars, mutable_vars);
}
return ret;
}

void ThreadedEngine::CheckDuplicate(std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars) {
// Check for duplicates.
auto use = const_vars;
auto mutate = mutable_vars;
Expand Down Expand Up @@ -200,12 +204,10 @@ ThreadedOpr* ThreadedEngine::NewOperator(
<< "duplicate items found between `const_vars` and `mutable_vars`";
}
}
#endif // ENGINE_DEBUG
return ret;
}

void ThreadedEngine::DeleteOperator(OprHandle op) {
auto&& threaded_opr = ThreadedOpr::CastFromBase(op);
ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
std::vector<VarHandle> deps;
deps.reserve(threaded_opr->const_vars.size() +
threaded_opr->mutable_vars.size());
Expand All @@ -221,8 +223,8 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
}

void ThreadedEngine::Push(OprHandle op, Context exec_ctx) {
auto&& threaded_opr = ThreadedOpr::CastFromBase(op);
auto&& opr_block = OprBlock::New();
ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
OprBlock* opr_block = OprBlock::New();
opr_block->opr = threaded_opr;
opr_block->wait.store(threaded_opr->const_vars.size() +
threaded_opr->mutable_vars.size() + 1);
Expand All @@ -237,19 +239,15 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx) {
i->AppendWriteDependency(opr_block);
}
if (--opr_block->wait == 0) {
if (opr_block->opr->prop == FnProperty::kAsync) {
DoExecute(opr_block);
} else {
DoPushToQueue(opr_block);
}
this->PushToExecute(opr_block, true);
}
}

void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop) {
auto&& opr = NewOperator(fn, const_vars, mutable_vars, prop);
ThreadedOpr *opr = NewOperator(fn, const_vars, mutable_vars, prop);
opr->temporary = true;
Push(opr, exec_ctx);
}
Expand Down Expand Up @@ -289,12 +287,16 @@ void ThreadedEngine::WaitForAll() {
inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
// Mark complete for read variables
for (auto&& i : threaded_opr->const_vars) {
i->CompleteReadDependency([this](OprBlock* opr) { DoPushToQueue(opr); });
i->CompleteReadDependency([this](OprBlock* opr) {
this->PushToExecute(opr, false);
});
}
// Mark complete for write variables.
for (auto&& i : threaded_opr->mutable_vars) {
bool to_delete = i->CompleteWriteDependency(
[this](OprBlock* opr) { DoPushToQueue(opr); });
[this](OprBlock* opr) {
this->PushToExecute(opr, false);
});
if (to_delete) {
ThreadedVar::Delete(i);
}
Expand All @@ -311,48 +313,11 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}
}

void ThreadedEngine::ThreadWorker(
dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue) {
OprBlock* opr_block;
while (task_queue->Pop(&opr_block)) {
DoExecute(opr_block);
}
}

void ThreadedEngine::DoPushToQueue(OprBlock* opr_block) {
switch (opr_block->opr->prop) {
case FnProperty::kCopy: {
io_task_queue_.Push(opr_block);
break;
}
default: {
task_queue_.Push(opr_block);
break;
}
}
}

void ThreadedEngine::DoExecute(OprBlock* opr_block) {
assert(opr_block->wait.load() == 0);
ThreadedOpr* threaded_opr = opr_block->opr;
if (opr_block->ctx.dev_mask == gpu::kDevMask) {
#if MXNET_USE_CUDA
CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id));
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
}
auto&& rctx = opr_block->opr->prop == FnProperty::kCopy
? streams_.GetIORunContext(opr_block->ctx)
: streams_.GetRunContext(opr_block->ctx);
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnComplete_, threaded_opr);
threaded_opr->fn(rctx, callback);
OprBlock::Delete(opr_block);
void ThreadedEngine::OnCompleteStatic(
Engine *engine, void *threaded_opr) {
static_cast<ThreadedEngine*>(engine)->OnComplete(
static_cast<ThreadedOpr*>(threaded_opr));
}

Engine *CreateThreadedEngine() {
return new ThreadedEngine();
}
} // namespace engine
} // namespace mxnet
101 changes: 46 additions & 55 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
/*!
* Copyright (c) 2015 by Contributors
* \file threaded_engine.h
* \brief Implementation of threaded engine that tracks the dependency
* and pushes actions to execute.
* \brief Implements base class of threaded engine
* that tracks the dependency and pushes actions to execute.
* \author Yutian Li
*/
#ifndef MXNET_ENGINE_THREADED_ENGINE_H_
#define MXNET_ENGINE_THREADED_ENGINE_H_

#include <dmlc/base.h>
#include <dmlc/concurrency.h>
#include <dmlc/logging.h>
#include <vector>
#include <functional>
#include <atomic>
#include <condition_variable>
#include <mutex>
#include "./engine_impl.h"
#include "./thread_pool.h"
#include "./stream_manager.h"
#include "../common/object_pool.h"

namespace mxnet {
Expand Down Expand Up @@ -200,18 +198,19 @@ struct ThreadedOpr final : public Opr,
}; // struct ThreadedOpr

/*!
* \brief Engine implementation.
* \brief Base class of all ThreadedEngine.
* This class implements a thread safe version of engine.
* The engine tracks the dependencies, and will call PushToExecute
* to execute a specific task.
*
* Subclass can implement PushToExecute to design specific
* execution policy for the tasks.
*/
class ThreadedEngine : public Engine {
public:
/*!
* \brief Constructor and destructor.
*/
ThreadedEngine();
~ThreadedEngine() noexcept(false);
/*!
* \brief Overriding methods.
*/
// constructor
ThreadedEngine() : pending_(0) {}
// implementing all the functions from Engine.
ThreadedVar* NewVariable() override;
ThreadedOpr* NewOperator(AsyncFn fn,
std::vector<VarHandle> const& const_vars,
Expand All @@ -226,65 +225,57 @@ class ThreadedEngine : public Engine {
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
void WaitForVar(VarHandle var) override;
void WaitForAll() override;

protected:
/*!
* \brief Worker.
* \param task_queue Queue to work on.
* \brief Push the opr block to execution queue to be executed.
* This function is implemented by the corresponding subclass
* for specific policy.
*
* The method to pass to thread pool to parallelize.
* \param opr_block The operator block.
* \param pusher_thread whether the caller is the thread that calls push
*/
virtual void PushToExecute(OprBlock* opr_block, bool pusher_thread) = 0;
/*!
* \brief Call this function to actually execute an opr_block
* This function also deletes the opr_block after execution.
* \param run_ctx runtime context used to execute the function.
* \param opr_block the opr_block to be executed and deleted.
*/
void ThreadWorker(dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue);
void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
ThreadedOpr* threaded_opr = opr_block->opr;
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnCompleteStatic, threaded_opr);
threaded_opr->fn(run_ctx, callback);
OprBlock::Delete(opr_block);
}

private:
/*!
* \brief check if thee is duplication in const_vars and mutable_vars.
* \param const_vars the variables to read from.
* \param mutable_vars the variables to mutate.
*/
void CheckDuplicate(std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars);
/*!
* \brief Callback on operation completion.
*
* On operation completion, this will trigger subsequent operations.
*/
inline void OnComplete(ThreadedOpr* threaded_opr);
// callback to the threaded engine
inline static void OnComplete_(Engine *engine, void *threaded_opr) {
static_cast<ThreadedEngine*>(engine)->OnComplete(
static_cast<ThreadedOpr*>(threaded_opr));
}

private:
/*! \brief Concurrency for thread pool */
static constexpr std::size_t kNumWorkingThreads = 16;
/*! \brief Maximum number of GPUs */
static constexpr std::size_t kMaxNumGpus = 16;
/*!\brief number of streams allocated for each GPU */
static constexpr std::size_t kNumStreamsPerGpu = 16;
static void OnCompleteStatic(Engine *engine, void *threaded_opr);
/*!
* \brief Number of pending operations.
*/
std::atomic<std::size_t> pending_;
/*!
* \brief Notify waits for single or all variables.
* \brief Mutex and condition_variable,
* used to Notify waits for single or all variables.
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
/*!
* \brief Streams.
*/
StreamManager<kMaxNumGpus, kNumStreamsPerGpu> streams_;
/*!
* \brief Task queues.
*/
dmlc::ConcurrentBlockingQueue<OprBlock*> task_queue_;
dmlc::ConcurrentBlockingQueue<OprBlock*> io_task_queue_;
/*!
* \brief Thread pools.
*/
ThreadPool<kNumWorkingThreads> thread_pool_;
ThreadPool<1> io_thread_pool_;
/*!
* \brief Push to corresponding task queue.
* \param opr_block The operator block.
*/
void DoPushToQueue(OprBlock* opr_block);
/*!
* \brief Execute an operation.
* \param opr_block The operator block.
*/
void DoExecute(OprBlock* opr_block);
/*!
* \brief Disallow copy construction and assignment.
*/
Expand Down
Loading

0 comments on commit c18ea83

Please sign in to comment.