-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
allreduce_threads.cc
64 lines (53 loc) · 1.59 KB
/
allreduce_threads.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD (revised)
license as described in the file LICENSE.
*/
/*
This implements the allreduce function using threads.
*/
#include "allreduce.h"
#include <future>
using namespace std;
// Builds the shared synchronization state for a group of `total` participants.
//
// Allocates the mutex and condition variable that waitForSynchronization()
// locks and waits on, and an array of `total` void* slots in `buffers`.
// The arrival counter starts at 0 and the run flag at true.
//
// total: number of arrivals waitForSynchronization() waits for before it
//        releases the waiting threads; also the length of `buffers`.
AllReduceSync::AllReduceSync(const size_t total) : m_total(total), m_count(0), m_run(true)
{
  m_mutex = new mutex;
  m_cv = new condition_variable;
  // Value-initialize the array (note the trailing "()") so every slot starts
  // out as nullptr instead of holding an indeterminate pointer value.
  buffers = new void*[total]();
}
// Releases the mutex, the condition variable and the buffer-slot array that
// the constructor allocated.
AllReduceSync::~AllReduceSync()
{
  delete m_mutex;
  delete m_cv;
  // buffers comes from an array new-expression (new void*[total]), so it must
  // be released with delete[]; a scalar delete here is undefined behavior.
  delete[] buffers;
}
void AllReduceSync::waitForSynchronization()
{ unique_lock<mutex> l(*m_mutex);
m_count++;
if (m_count >= m_total)
{ assert(m_count == m_total);
m_cv->notify_all();
// order of m_count before or after notify_all doesn't matter
// since the lock is still hold at this point in time.
m_count = 0;
// flip for the next run
m_run = !m_run;
}
else
{ bool current_run = m_run;
// this predicate cannot depend on m_count, as somebody can race ahead and m_count++
// FYI just wait can spuriously wake-up
m_cv->wait(l, [this, current_run] { return m_run != current_run; });
}
}
// Constructs a participant that shares the synchronization object of `root`
// instead of allocating its own. m_syncOwner is false, so this instance's
// destructor leaves the shared object alone.
//
// root:   existing instance whose m_sync pointer is copied; it is
//         dereferenced unconditionally, so it must not be null.
// ptotal: forwarded to the AllReduce base constructor.
// pnode:  forwarded to the AllReduce base constructor.
AllReduceThreads::AllReduceThreads(AllReduceThreads* root, const size_t ptotal, const size_t pnode)
    : AllReduce(ptotal, pnode)
    , m_sync(root->m_sync)
    , m_syncOwner(false)
{}
// Constructs a participant that allocates a fresh AllReduceSync sized for
// `ptotal` and owns it: m_syncOwner is true, so the destructor deletes it.
//
// ptotal: forwarded to the AllReduce base constructor and used as the
//         participant count of the new AllReduceSync.
// pnode:  forwarded to the AllReduce base constructor.
AllReduceThreads::AllReduceThreads(const size_t ptotal, const size_t pnode)
    : AllReduce(ptotal, pnode)
    , m_sync(new AllReduceSync(ptotal))
    , m_syncOwner(true)
{}
// Deletes the synchronization object only when this instance created it.
// Instances built from a root merely borrow the pointer and must not free it.
AllReduceThreads::~AllReduceThreads()
{
  if (!m_syncOwner)
  {
    return;
  }
  delete m_sync;
}