Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Separate Dependency Tracking and Execution Policy in Engine, BugFix #68

Merged
merged 3 commits into from
Sep 14, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 45 additions & 80 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
/*!
* Copyright (c) 2015 by Contributors
* \file threaded_engine.cc
* \brief implements base threaded engine.
* \author Yutian Li
*/
#include "threaded_engine.h"
#include <dmlc/logging.h>
Expand Down Expand Up @@ -81,29 +84,35 @@ void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) {
template <typename Dispatcher>
bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
VersionedVarBlock *old_pending_write, *end_of_dispatch_chain;
int num_reads;
{
// this is lock scope
std::lock_guard<std::mutex> lock{m_};
assert(ready_to_read_ == false);
// detach pending write
old_pending_write = pending_write_;
pending_write_ = nullptr;
// search for chains to trigger
VersionedVarBlock *p = old_pending_write->next;
assert(num_pending_reads_ == 0);
num_reads = 0;
while (p->next != nullptr && p->write == false) {
++num_pending_reads_;
p = p->next;
++num_reads;
}
num_pending_reads_ = num_reads;
// mark end of dispatch chain
end_of_dispatch_chain = p;

if (p->next == nullptr) {
ready_to_read_ = true;
pending_write_ = nullptr;
assert(p->trigger == nullptr);
assert(p->write ==false);
} else {
assert(p->write == true);
pending_write_ = p;
if (num_pending_reads_ == 0) {
if (--pending_write_->trigger->wait == 0) {
dispatcher(pending_write_->trigger);
}
}
}
}
// this is outside of lock scope
Expand All @@ -122,14 +131,9 @@ bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
}
auto prev = cur_head;
cur_head = cur_head->next;
assert(cur_head != nullptr);
VersionedVarBlock::Delete(prev);
}
// trigger pending write, if any
if (pending_write_ != nullptr && num_reads == 0) {
if (--pending_write_->trigger->wait == 0) {
dispatcher(pending_write_->trigger);
}
}
return false;
}

Expand All @@ -143,33 +147,33 @@ bool ThreadedVar::ready_to_read() {
return ready_to_read_;
}

/*!
 * \brief Construct the engine.
 *
 * Starts the pending-operation counter at zero and hands each thread pool a
 * worker callable that runs ThreadWorker on that pool's own queue:
 * thread_pool_ serves task_queue_, io_thread_pool_ serves io_task_queue_.
 */
ThreadedEngine::ThreadedEngine()
    : pending_{0},
      thread_pool_{[this] { this->ThreadWorker(&task_queue_); }},
      io_thread_pool_{[this] { this->ThreadWorker(&io_task_queue_); }} {
}

/*!
 * \brief Destroy the engine.
 *
 * Sends the kill signal to the regular task queue first and to the IO task
 * queue second.
 * NOTE(review): presumably this makes the blocked Pop calls in ThreadWorker
 * return false so the workers exit - confirm against the queue implementation.
 */
ThreadedEngine::~ThreadedEngine() noexcept(false) {
  this->task_queue_.SignalForKill();
  this->io_task_queue_.SignalForKill();
}

// implementation of threaded engine
/*!
 * \brief Create a new engine variable.
 * \return a ThreadedVar built from a freshly created VersionedVarBlock.
 */
ThreadedVar* ThreadedEngine::NewVariable() {
  auto head_block = VersionedVarBlock::New();
  return ThreadedVar::New(head_block);
}

/*!
 * \brief Create a ThreadedOpr describing one asynchronous function and the
 *        variables it reads and mutates.
 *
 * The operator is obtained from ThreadedOpr::New, receives `fn` and `prop`,
 * and gets its const_vars / mutable_vars sized to the inputs and filled by
 * applying ThreadedVar::CastFromBase to every handle. When ENGINE_DEBUG is
 * non-zero the inputs are additionally passed to CheckDuplicate.
 *
 * NOTE(review): this span is diff output in which the pre-change and
 * post-change lines are interleaved (two parameter lists, two identical
 * const_vars transforms, `#if ENGINE_DEBUG` next to `if (ENGINE_DEBUG != 0)`).
 * Only one variant of each pair exists in the real file.
 *
 * \param fn the asynchronous function stored in the operator.
 * \param const_vars handles of the variables the function reads.
 * \param mutable_vars handles of the variables the function mutates.
 * \param prop property of the function, stored in the operator.
 * \return the newly created operator.
 */
ThreadedOpr* ThreadedEngine::NewOperator(
ThreadedEngine::AsyncFn fn, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars, FnProperty prop) {
ThreadedEngine::AsyncFn fn,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop) {
auto ret = ThreadedOpr::New();
ret->fn = fn;
ret->prop = prop;
// size the destination vectors up front so std::transform can write in place
ret->const_vars.resize(const_vars.size());
ret->mutable_vars.resize(mutable_vars.size());
// convert the base handles to ThreadedVar pointers
std::transform(const_vars.begin(), const_vars.end(), ret->const_vars.begin(),
ThreadedVar::CastFromBase);
std::transform(const_vars.begin(), const_vars.end(),
ret->const_vars.begin(), ThreadedVar::CastFromBase);
std::transform(mutable_vars.begin(), mutable_vars.end(),
ret->mutable_vars.begin(), ThreadedVar::CastFromBase);
#if ENGINE_DEBUG
if (ENGINE_DEBUG != 0) {
CheckDuplicate(const_vars, mutable_vars);
}
return ret;
}

void ThreadedEngine::CheckDuplicate(std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars) {
// Check for duplicates.
auto use = const_vars;
auto mutate = mutable_vars;
Expand Down Expand Up @@ -200,12 +204,10 @@ ThreadedOpr* ThreadedEngine::NewOperator(
<< "duplicate items found between `const_vars` and `mutable_vars`";
}
}
#endif // ENGINE_DEBUG
return ret;
}

void ThreadedEngine::DeleteOperator(OprHandle op) {
auto&& threaded_opr = ThreadedOpr::CastFromBase(op);
ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
std::vector<VarHandle> deps;
deps.reserve(threaded_opr->const_vars.size() +
threaded_opr->mutable_vars.size());
Expand All @@ -221,8 +223,8 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
}

void ThreadedEngine::Push(OprHandle op, Context exec_ctx) {
auto&& threaded_opr = ThreadedOpr::CastFromBase(op);
auto&& opr_block = OprBlock::New();
ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
OprBlock* opr_block = OprBlock::New();
opr_block->opr = threaded_opr;
opr_block->wait.store(threaded_opr->const_vars.size() +
threaded_opr->mutable_vars.size() + 1);
Expand All @@ -237,19 +239,15 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx) {
i->AppendWriteDependency(opr_block);
}
if (--opr_block->wait == 0) {
if (opr_block->opr->prop == FnProperty::kAsync) {
DoExecute(opr_block);
} else {
DoPushToQueue(opr_block);
}
this->PushToExecute(opr_block, true);
}
}

/*!
 * \brief Build an operator from the given function and variables and push it
 *        for execution in one step.
 *
 * NOTE(review): the two consecutive declarations of `opr` below are the
 * pre-change and post-change lines of the diff; only one exists in the real
 * file.
 *
 * \param fn the asynchronous function to run.
 * \param exec_ctx execution context forwarded to Push.
 * \param const_vars handles of the variables the function reads.
 * \param mutable_vars handles of the variables the function mutates.
 * \param prop property of the function, forwarded to NewOperator.
 */
void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop) {
auto&& opr = NewOperator(fn, const_vars, mutable_vars, prop);
ThreadedOpr *opr = NewOperator(fn, const_vars, mutable_vars, prop);
// mark the operator as temporary before handing it to Push
// NOTE(review): presumably lets the engine dispose of it after completion - confirm
opr->temporary = true;
Push(opr, exec_ctx);
}
Expand Down Expand Up @@ -289,12 +287,16 @@ void ThreadedEngine::WaitForAll() {
inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
// Mark complete for read variables
for (auto&& i : threaded_opr->const_vars) {
i->CompleteReadDependency([this](OprBlock* opr) { DoPushToQueue(opr); });
i->CompleteReadDependency([this](OprBlock* opr) {
this->PushToExecute(opr, false);
});
}
// Mark complete for write variables.
for (auto&& i : threaded_opr->mutable_vars) {
bool to_delete = i->CompleteWriteDependency(
[this](OprBlock* opr) { DoPushToQueue(opr); });
[this](OprBlock* opr) {
this->PushToExecute(opr, false);
});
if (to_delete) {
ThreadedVar::Delete(i);
}
Expand All @@ -311,48 +313,11 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}
}

/*!
 * \brief Worker loop handed to the thread pools.
 *
 * Keeps popping operator blocks from the queue and runs each one through
 * DoExecute; the loop ends as soon as Pop returns false.
 *
 * \param task_queue queue to take operator blocks from.
 */
void ThreadedEngine::ThreadWorker(
    dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue) {
  for (OprBlock* next_block = nullptr; task_queue->Pop(&next_block);) {
    DoExecute(next_block);
  }
}

/*!
 * \brief Enqueue an operator block on the queue matching its property.
 *
 * Blocks whose operator is marked FnProperty::kCopy go to io_task_queue_;
 * every other property goes to task_queue_.
 *
 * \param opr_block the operator block to enqueue.
 */
void ThreadedEngine::DoPushToQueue(OprBlock* opr_block) {
  const bool is_copy = (opr_block->opr->prop == FnProperty::kCopy);
  auto& target_queue = is_copy ? io_task_queue_ : task_queue_;
  target_queue.Push(opr_block);
}

void ThreadedEngine::DoExecute(OprBlock* opr_block) {
assert(opr_block->wait.load() == 0);
ThreadedOpr* threaded_opr = opr_block->opr;
if (opr_block->ctx.dev_mask == gpu::kDevMask) {
#if MXNET_USE_CUDA
CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id));
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
}
auto&& rctx = opr_block->opr->prop == FnProperty::kCopy
? streams_.GetIORunContext(opr_block->ctx)
: streams_.GetRunContext(opr_block->ctx);
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnComplete_, threaded_opr);
threaded_opr->fn(rctx, callback);
OprBlock::Delete(opr_block);
/*!
 * \brief Static trampoline for the completion callback.
 *
 * Recovers the typed engine and operator from the untyped arguments and
 * forwards to the member function OnComplete.
 *
 * \param engine the engine, statically cast to ThreadedEngine.
 * \param threaded_opr the finished operator, statically cast to ThreadedOpr.
 */
void ThreadedEngine::OnCompleteStatic(Engine *engine, void *threaded_opr) {
  ThreadedEngine *self = static_cast<ThreadedEngine*>(engine);
  ThreadedOpr *finished_opr = static_cast<ThreadedOpr*>(threaded_opr);
  self->OnComplete(finished_opr);
}

/*!
 * \brief Factory for the threaded engine.
 * \return a newly allocated ThreadedEngine, returned through the Engine
 *         base pointer.
 */
Engine *CreateThreadedEngine() {
  Engine *engine = new ThreadedEngine();
  return engine;
}
} // namespace engine
} // namespace mxnet
101 changes: 46 additions & 55 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
/*!
* Copyright (c) 2015 by Contributors
* \file threaded_engine.h
* \brief Implementation of threaded engine that tracks the dependency
* and pushes actions to execute.
* \brief Implements base class of threaded engine
* that tracks the dependency and pushes actions to execute.
* \author Yutian Li
*/
#ifndef MXNET_ENGINE_THREADED_ENGINE_H_
#define MXNET_ENGINE_THREADED_ENGINE_H_

#include <dmlc/base.h>
#include <dmlc/concurrency.h>
#include <dmlc/logging.h>
#include <vector>
#include <functional>
#include <atomic>
#include <condition_variable>
#include <mutex>
#include "./engine_impl.h"
#include "./thread_pool.h"
#include "./stream_manager.h"
#include "../common/object_pool.h"

namespace mxnet {
Expand Down Expand Up @@ -200,18 +198,19 @@ struct ThreadedOpr final : public Opr,
}; // struct ThreadedOpr

/*!
* \brief Engine implementation.
* \brief Base class of all ThreadedEngine.
* This class implements a thread safe version of engine.
* The engine tracks the dependencies, and will call PushToExecute
* to execute a specific task.
*
* Subclass can implement PushToExecute to design specific
* execution policy for the tasks.
*/
class ThreadedEngine : public Engine {
public:
/*!
* \brief Constructor and destructor.
*/
ThreadedEngine();
~ThreadedEngine() noexcept(false);
/*!
* \brief Overriding methods.
*/
// constructor
ThreadedEngine() : pending_(0) {}
// implementing all the functions from Engine.
ThreadedVar* NewVariable() override;
ThreadedOpr* NewOperator(AsyncFn fn,
std::vector<VarHandle> const& const_vars,
Expand All @@ -226,65 +225,57 @@ class ThreadedEngine : public Engine {
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
void WaitForVar(VarHandle var) override;
void WaitForAll() override;

protected:
/*!
* \brief Worker.
* \param task_queue Queue to work on.
* \brief Push the opr block to execution queue to be executed.
* This function is implemented by the corresponding subclass
* for specific policy.
*
* The method to pass to thread pool to parallelize.
* \param opr_block The operator block.
* \param pusher_thread whether the caller is the thread that calls push
*/
virtual void PushToExecute(OprBlock* opr_block, bool pusher_thread) = 0;
/*!
* \brief Call this function to actually execute an opr_block
* This function also deletes the opr_block after execution.
* \param run_ctx runtime context used to execute the function.
* \param opr_block the opr_block to be executed and deleted.
*/
void ThreadWorker(dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue);
void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
  // The operator outlives the block; fetch it before the block is deleted.
  ThreadedOpr* opr = opr_block->opr;
  // Completion callback that routes back into this engine via OnCompleteStatic.
  CallbackOnComplete on_complete =
      this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr);
  opr->fn(run_ctx, on_complete);
  // The block is released once the function call has returned.
  OprBlock::Delete(opr_block);
}

private:
/*!
* \brief check if there is duplication in const_vars and mutable_vars.
* \param const_vars the variables to read from.
* \param mutable_vars the variables to mutate.
*/
void CheckDuplicate(std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars);
/*!
* \brief Callback on operation completion.
*
* On operation completion, this will trigger subsequent operations.
*/
inline void OnComplete(ThreadedOpr* threaded_opr);
// callback to the threaded engine
inline static void OnComplete_(Engine *engine, void *threaded_opr) {
  // Recover the typed engine and operator, then forward to the member callback.
  ThreadedEngine *self = static_cast<ThreadedEngine*>(engine);
  ThreadedOpr *finished_opr = static_cast<ThreadedOpr*>(threaded_opr);
  self->OnComplete(finished_opr);
}

private:
/*! \brief Concurrency for thread pool */
static constexpr std::size_t kNumWorkingThreads = 16;
/*! \brief Maximum number of GPUs */
static constexpr std::size_t kMaxNumGpus = 16;
/*!\brief number of streams allocated for each GPU */
static constexpr std::size_t kNumStreamsPerGpu = 16;
static void OnCompleteStatic(Engine *engine, void *threaded_opr);
/*!
* \brief Number of pending operations.
*/
std::atomic<std::size_t> pending_;
/*!
* \brief Notify waits for single or all variables.
* \brief Mutex and condition_variable,
* used to Notify waits for single or all variables.
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
/*!
* \brief Streams.
*/
StreamManager<kMaxNumGpus, kNumStreamsPerGpu> streams_;
/*!
* \brief Task queues.
*/
dmlc::ConcurrentBlockingQueue<OprBlock*> task_queue_;
dmlc::ConcurrentBlockingQueue<OprBlock*> io_task_queue_;
/*!
* \brief Thread pools.
*/
ThreadPool<kNumWorkingThreads> thread_pool_;
ThreadPool<1> io_thread_pool_;
/*!
* \brief Push to corresponding task queue.
* \param opr_block The operator block.
*/
void DoPushToQueue(OprBlock* opr_block);
/*!
* \brief Execute an operation.
* \param opr_block The operator block.
*/
void DoExecute(OprBlock* opr_block);
/*!
* \brief Disallow copy construction and assignment.
*/
Expand Down
Loading