// MPIWrapper.h
#pragma once

// This uses mpi.h, which requires the Microsoft MPI SDK to be installed on Windows
// [cf. https://msdn.microsoft.com/en-us/library/bb524831(v=vs.85).aspx]
// (download msmpisdk.msi from https://www.microsoft.com/en-us/download/details.aspx?id=49926 and run it),
// and the MPI development package on Linux (sudo apt-get install libopenmpi-dev openmpi-bin openmpi-doc).
#include "mpi.h"
#pragma comment(lib, "msmpi.lib")

#include <string>
#include <array>
#include <vector>

namespace Microsoft { namespace MSR { namespace CNTK {
struct MpiFail : public std::string
{
    MpiFail(const std::string &what)
        : std::string(what)
    {
    }
};

static int operator||(int rc, const MpiFail &what)
{
    if (rc == MPI_SUCCESS)
    {
        return rc;
    }

    fprintf(stderr, "%s, MPI error %d\n", what.c_str(), rc);
    fflush(stderr);

    // (special case: we use that code to indicate a missing msmpi.dll...)
    if (rc != MPI_ERR_INTERN)
    {
        char errbuf[MPI_MAX_ERROR_STRING + 1] = {0};
        int len;
        MPI_Error_string(rc, &errbuf[0], &len);
        fprintf(stderr, "%s, MPI error %d: %s\n", what.c_str(), rc, errbuf);
        fflush(stderr);

        // we abort through this, so that the MPI system gets the memo
        MPI_Abort(MPI_COMM_WORLD, rc);

        // TODO: or does that only signal an issue, and we should still terminate ourselves?
        // BUGBUG: We'd also need to Abort through the other sub-set communicator
    }
    RuntimeError("%s", what.c_str());
}
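
// Usage pattern (illustrative note, not part of the original header): throughout this file, MPI calls
// are chained with the operator|| above, so any return code other than MPI_SUCCESS is logged and
// aborts the job, e.g.
//     MPI_Barrier(MPI_COMM_WORLD) || MpiFail("example: MPI_Barrier");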
class MPIWrapper
{
    int m_myRank;
    int m_numMPINodes;
    size_t m_numNodesInUse;

    // MPI communicator that reflects the current subset selection
    MPI_Comm m_currentComm;

    // MPI_Init() with delay-loading the msmpi.dll (possibly causing a failure if missing; we want to catch that)
    int MPI_Init_DL()
    {
#ifdef WIN32
        __try
#endif
        {
            int argc = 0;
            char **argv = NULL;
            int requiredThreadLevelSupport = MPI_THREAD_SERIALIZED;
            int provided;
            int ret = MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided);
            if (provided != requiredThreadLevelSupport)
                LogicError("Failed to initialize MPI with the desired level of thread support");
            return ret;
        }
#ifdef WIN32
        __except (EXCEPTION_EXECUTE_HANDLER)
        {
            fprintf(stderr, "mpihelper: msmpi.dll missing\n");
            return MPI_ERR_INTERN;
        }
#endif
    }
    // Workaround for MPI hanging when CNTK processes exit with non-zero exit codes.
    // OpenMPI has a confirmed race condition between killing child processes and handling their non-zero exit statuses,
    // resulting in a deadlock where all processes have been killed but MPI is still waiting.
    // This happens when several perfectly synchronized processes (for example, at an MPI barrier)
    // simultaneously exit with a non-zero exit code.
    // As a workaround, we simply sleep 50*rank milliseconds, effectively "de-synchronizing" the processes at exit
    // and allowing MPI to handle the terminations sequentially.
    static int s_myRank;
    static void MPIWorkaroundAtExit()
    {
        // Note: we can't use g_mpi, since the MPI stack is already down at this point
        Sleep(s_myRank * 50);
    }
public:
    MPIWrapper()
        : m_currentComm(MPI_COMM_WORLD)
    {
        static bool initialized = false;
        if (initialized)
        {
            LogicError("MPIWrapper: this is a singleton class that can only be instantiated once per process");
        }
        initialized = true;

        fprintf(stderr, "MPIWrapper: initializing MPI\n");
        fflush(stderr);

        int flag = 0;
        MPI_Initialized(&flag);
        if (!flag)
        {
            MPI_Init_DL() || MpiFail("mpiaggregator: MPI_Init");
        }
        MPI_Comm_rank(MPI_COMM_WORLD, &m_myRank);
        MPI_Comm_size(MPI_COMM_WORLD, &m_numMPINodes);
        m_numNodesInUse = m_numMPINodes;

        // Applying MPI workaround
        s_myRank = m_myRank;
        atexit(&MPIWrapper::MPIWorkaroundAtExit);

        // by default we use all of them
        RequestNodes("MPIWrapper");

        if (m_numMPINodes > 1)
            fprintf(stderr, "mpihelper: we are cog %d in a gearbox of %d\n", (int) m_myRank, (int) m_numMPINodes);
        else
            fprintf(stderr, "mpihelper: only one MPI process: MPI operation will be boring\n");
        fflush(stderr);

        // do an initial handshake
        Ping("mpihelper");

        // stagger the jobs just a little to get a sort-of deterministic order, e.g. in GPU allocation when running on one machine
        // continue 0.5 seconds apart
        ::Sleep((DWORD)(500 * CurrentNodeRank()));
    }
    // Note: we don't free the sub-communicator here although we should, because in case of a crash this would prevent the EXE from terminating.
    // It's OK since this class is a singleton anyway that gets instantiated exactly once at program startup.
    ~MPIWrapper()
    {
        fprintf(stderr, "~MPIWrapper\n");
        fflush(stderr);
        MPI_Finalize();
    }
    void Ping(const char *msg) const
    {
#undef USE2NDCOMM
#ifndef USE2NDCOMM
        if (NumNodesInUse() != m_numMPINodes)
        {
            fprintf(stderr, "ping [%s]: cannot be applied to subset (%d) of nodes, skipping\n", msg, (int) NumNodesInUse());
            fflush(stderr);
            return;
        }
#endif
        std::array<int, 1> handshake;
        handshake[0] = 1;

        fprintf(stderr, "ping [%s]: %d nodes pinging each other\n", msg, (int) NumNodesInUse());
        fflush(stderr);

        AllReduce(handshake);
        fprintf(stderr, "ping [%s]: all %d nodes responded\n", msg, handshake[0]);
        fflush(stderr);
    }
    void RequestNodes(const char *msg, size_t requestednodes = SIZE_MAX /*default: all*/)
    {
        Ping("requestnodes (before change)");

        // undo current split
#ifdef USE2NDCOMM
        if (m_currentComm != MPI_COMM_WORLD /*no subset*/ && m_currentComm != MPI_COMM_NULL /*idle nodes*/)
        {
            fprintf(stderr, "requestnodes: MPI_Comm_free %x\n", (int) m_currentComm);
            fflush(stderr);
            MPI_Comm_free(&m_currentComm) || MpiFail("requestnodes: MPI_Comm_free"); // will leave MPI_COMM_NULL here
        }
#endif
        // reset to MPI_COMM_WORLD
        m_currentComm = MPI_COMM_WORLD;

        // create a new split (unless all nodes were requested)
        if (requestednodes < (size_t) m_numMPINodes)
        {
#ifdef USE2NDCOMM
            fprintf(stderr, "requestnodes: MPI_Comm_split %d\n", (node() < requestednodes) ? 1 : MPI_UNDEFINED);
            fflush(stderr);
            MPI_Comm_split(communicator(), (node() < requestednodes) ? 1 : MPI_UNDEFINED, 0, &m_currentComm) || MpiFail("requestnodes: MPI_Comm_split");
            fprintf(stderr, "requestnodes: MPI_Comm_split -> %x\n", (int) m_currentComm);
            fflush(stderr);
#endif
        }
        else
        {
            // leave m_currentComm as MPI_COMM_WORLD
            // and clip to #nodes
            requestednodes = m_numMPINodes;
        }

        m_numNodesInUse = requestednodes;
        fprintf(stderr, "requestnodes [%s]: using %d out of %d MPI nodes (%d requested); we (%d) are %s\n",
                msg, (int) m_numNodesInUse, (int) m_numMPINodes, (int) requestednodes,
                (int) CurrentNodeRank(), IsIdle() ? "out (idle)" : "in (participating)");
        fflush(stderr);

        Ping("requestnodes (after change)");
    }
    MPI_Comm Communicator() const
    {
        return m_currentComm;
    }
    size_t NumNodesInUse() const
    {
        return m_numNodesInUse;
    }
    size_t CurrentNodeRank() const
    {
        return m_myRank;
    }
    bool IsMainNode() const
    {
        return m_myRank == 0;
    } // we are the chosen one--do extra stuff like saving the model to disk
    bool IsIdle() const
    {
        return CurrentNodeRank() >= NumNodesInUse();
    } // user had requested to not use this many nodes
    bool UsingAllNodes() const
    {
        return NumNodesInUse() == m_numMPINodes;
    } // all nodes participate (used to check whether we can use MPI_Allreduce directly)
    size_t MainNodeRank() const
    {
        return 0;
    }
    // -----------------------------------------------------------------------
    // data-exchange functions (wrappers around MPI functions)
    // -----------------------------------------------------------------------

    // helpers to determine the MPI_Datatype of a pointer
    static MPI_Datatype GetDataType(char *)
    {
        return MPI_CHAR;
    }
    static MPI_Datatype GetDataType(int *)
    {
        return MPI_INT;
    }
    static MPI_Datatype GetDataType(float *)
    {
        return MPI_FLOAT;
    }
    static MPI_Datatype GetDataType(double *)
    {
        return MPI_DOUBLE;
    }
    static MPI_Datatype GetDataType(size_t *)
    {
        return sizeof(size_t) == 4 ? MPI_UNSIGNED : MPI_LONG_LONG_INT;
    }
    // allreduce of a vector
    template <typename VECTORLIKEOBJECT>
    void AllReduce(VECTORLIKEOBJECT &accumulator) const
    {
        auto *dataptr = accumulator.data();
        size_t totalnumelements = accumulator.size();

        // use MPI to compute the sum over all elements in (dataptr, totalnumelements) and redistribute to all nodes
        if ((NumNodesInUse() > 1) && (Communicator() != MPI_COMM_NULL))
        {
            MPI_Allreduce(MPI_IN_PLACE, dataptr, (int) totalnumelements, GetDataType(dataptr), MPI_SUM, Communicator()) || MpiFail("allreduce: MPI_Allreduce");
        }
    }

    // for raw pointer
    template <class ElemType>
    void AllReduce(ElemType *pData, size_t nData)
    {
        if ((NumNodesInUse() > 1) && (Communicator() != MPI_COMM_NULL))
        {
            MPI_Allreduce(MPI_IN_PLACE, pData, (int) nData, GetDataType(pData), MPI_SUM, Communicator()) || MpiFail("Allreduce: MPI_Allreduce");
        }
    }
    template <class ElemType>
    void Bcast(ElemType *pData, size_t nData, size_t srcRank)
    {
        if ((NumNodesInUse() > 1) && (Communicator() != MPI_COMM_NULL))
        {
            MPI_Bcast(pData, (int) nData, GetDataType(pData), (int) srcRank, Communicator()) || MpiFail("Bcast: MPI_Bcast");
        }
    }

    // wait for all ranks to reach here
    void WaitAll()
    {
        MPI_Barrier(m_currentComm) || MpiFail("waitall: MPI_Barrier");
    }
};
}}}

extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi;
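
// Example usage (illustrative sketch only, not part of the original file): typical call sites in training
// code, assuming the g_mpi singleton declared above has been created at startup. SaveModelCheckpoint(),
// numParameters, and the gradient buffer are hypothetical placeholders.
//
//     std::vector<double> gradients(numParameters);         // per-rank partial gradients
//     if (g_mpi != nullptr && !g_mpi->IsIdle())
//     {
//         g_mpi->AllReduce(gradients);                      // in-place sum across all participating ranks
//         if (g_mpi->IsMainNode())
//             SaveModelCheckpoint();                        // e.g. only rank 0 writes the model to disk
//         g_mpi->WaitAll();                                 // keep ranks in lockstep before the next minibatch
//     }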