Skip to content

Commit

Permalink
fix #5428 redux
Browse files Browse the repository at this point in the history
  • Loading branch information
rtri committed Jan 14, 2017
1 parent 5a97ea9 commit 681fe52
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions rts/System/Threading/ThreadPool.h
Expand Up @@ -85,16 +85,19 @@ class ITaskGroup
{
public:
ITaskGroup(const bool getid = true) : id(getid ? lastId.fetch_add(1) : -1) {
remainingTasks = 0;
wantedThread = 0;
inTaskQueue = 1;
ResetState();
}

virtual ~ITaskGroup() {
assert(IsFinished());
assert(!IsInQueue());
}

virtual void ResetState() {
remainingTasks = 0;
wantedThread = 0;
inTaskQueue = 1;
}
virtual bool IsSingleTask() const { return false; }
virtual bool ExecuteStep() = 0;
virtual bool SelfDelete() const { return false; }
Expand All @@ -108,6 +111,7 @@ class ITaskGroup
const spring_time dt = t1 - t0;

if (!wffCall) {
// do not set this from WFF, defeats the purpose
assert(inTaskQueue.load() == 1);
inTaskQueue.store(0);
}
Expand Down Expand Up @@ -483,23 +487,26 @@ class ForTaskGroup : public ITaskGroup

template <template<typename> class TG, typename F>
struct TaskPool {
typedef TG<F> I;
typedef std::shared_ptr<I> P;
typedef TG<F> FuncTaskGroup;
typedef std::shared_ptr<FuncTaskGroup> FuncTaskGroupPtr;

std::array<P, 256> cp;
// more than 256 nested for_mt's or parallel's should be uncommon
std::array<FuncTaskGroupPtr, 256> tgPool;
std::atomic_int pos = {0};

TaskPool() {
for (int i = 0; i<256; ++i) {
cp[i] = P(new I);
for (int i = 0; i < tgPool.size(); ++i) {
tgPool[i] = FuncTaskGroupPtr(new FuncTaskGroup());
}
}


P Get() {
auto v = cp[pos.fetch_add(1) % 256];
assert(!v || v->IsFinished());
return v;
FuncTaskGroupPtr GetTaskGroup() {
auto tg = tgPool[pos.fetch_add(1) % tgPool.size()];

assert(tg->IsFinished());
tg->ResetState();
return tg;
}
};

Expand All @@ -521,8 +528,11 @@ static inline void for_mt(int start, int end, int step, F&& f)
}

SCOPED_MT_TIMER("::ThreadWorkers (real)");
static TaskPool<ForTaskGroup,F> pool;
auto taskGroup = pool.Get();

// static, so TaskGroup's are recycled
static TaskPool<ForTaskGroup, F> pool;
auto taskGroup = pool.GetTaskGroup();

taskGroup->Enqueue(start, end, step, f);
taskGroup->UpdateId();
ThreadPool::PushTaskGroup(taskGroup);
Expand All @@ -543,8 +553,11 @@ static inline void parallel(F&& f)
return f();

SCOPED_MT_TIMER("::ThreadWorkers (real)");

// static, so TaskGroup's are recycled
static TaskPool<Parallel2TaskGroup, F> pool;
auto taskGroup = pool.Get();
auto taskGroup = pool.GetTaskGroup();

taskGroup->Enqueue(f);
taskGroup->UpdateId();
ThreadPool::PushTaskGroup(taskGroup);
Expand Down

0 comments on commit 681fe52

Please sign in to comment.