forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mpi_common.h
159 lines (140 loc) · 4.65 KB
/
mpi_common.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#ifndef CAFFE2_MPI_MPI_COMMON_H_
#define CAFFE2_MPI_MPI_COMMON_H_
#include <mpi.h>
#include <mutex>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
namespace caffe2 {
inline void CheckInitializedMPI() {
int flag;
MPI_Initialized(&flag);
CAFFE_ENFORCE(flag, "MPI does not seem to have been initialized.");
}
// Compile-time mapping from a C++ scalar type to the matching MPI_Datatype.
// Only the explicit specializations generated below exist; using an
// unsupported type is a (deliberate) compile error.
template <typename T>
class MPIDataTypeWrapper;

// Generates the specialization MPIDataTypeWrapper<c_type>::type() -> mpi_type.
// NOTE: no comments may appear inside the macro body — the line continuations
// would be broken by them.
#define MPI_DATATYPE_WRAPPER(c_type, mpi_type) \
template <> \
class MPIDataTypeWrapper<c_type> { \
public: \
inline static MPI_Datatype type() { \
return mpi_type; \
} \
};

// Supported scalar types: char, float, double.
MPI_DATATYPE_WRAPPER(char, MPI_CHAR)
MPI_DATATYPE_WRAPPER(float, MPI_FLOAT)
MPI_DATATYPE_WRAPPER(double, MPI_DOUBLE)
// Note(Yangqing): as necessary, add more specializations.
#undef MPI_DATATYPE_WRAPPER
// For all Caffe MPI calls, we will wrap it inside an MPI mutex lock guard.
// Returns the process-wide mutex that serializes Caffe2's MPI calls
// (definition lives in the corresponding .cc file).
CAFFE2_API std::mutex& MPIMutex();

// Runs an MPI call while holding MPIMutex(), and raises a Caffe2 enforce
// failure (with file/line and the MPI error code) if the call does not
// return MPI_SUCCESS. `condition` is evaluated exactly once, under the lock.
#define MPI_CHECK(condition) \
do { \
std::lock_guard<std::mutex> guard(::caffe2::MPIMutex()); \
int error = (condition); \
CAFFE_ENFORCE( \
error == MPI_SUCCESS, \
"Caffe2 MPI Error at: ", \
__FILE__, \
":", \
__LINE__, \
": ", \
error); \
} while (0)
/**
 * @brief Gets the global MPI communicator used by Caffe2. In default, this
 * is MPI_COMM_WORLD unless you call SetGlobalMPIComm().
 */
CAFFE2_API MPI_Comm GlobalMPIComm();
/**
 * @brief Sets the global MPI communicator. Caffe2 takes over the ownership
 * of the passed in communicator.
 */
CAFFE2_API void SetGlobalMPIComm(MPI_Comm new_comm);
/**
 * @brief A helper function to return the size (number of processes) of the
 * given communicator.
 */
CAFFE2_API int MPICommSize(MPI_Comm comm);
/**
 * @brief A helper function to return the rank of the calling process in the
 * given communicator.
 */
CAFFE2_API int MPICommRank(MPI_Comm comm);
/**
 * @brief A simple RAII wrapper over an MPI common world.
 *
 * The wrapper uniquely owns the communicator it creates with MPI_Comm_split
 * and frees it on destruction (unless MPI has already been finalized, in
 * which case freeing would be invalid).
 */
class MPICommonWorldWrapper {
 public:
  /**
   * @brief Creates a common world wrapper.
   *
   * The new common world is created by taking the existing communicator
   * passed in as src_comm, and splitting it using the color and the rank
   * specified. In default, we will split from Caffe2's global communicator,
   * and use color 0 as well as rank implicitly given by src_comm. As a result,
   * the default constructor basically creates a comm identical to the source
   * comm world.
   */
  explicit MPICommonWorldWrapper(
      MPI_Comm src_comm = MPI_COMM_NULL,
      int color = 0,
      int rank = -1) {
    if (src_comm == MPI_COMM_NULL) {
      src_comm = GlobalMPIComm();
    }
    if (rank == -1) {
      // Default key: keep the caller's rank ordering in the new communicator.
      MPI_CHECK(MPI_Comm_rank(src_comm, &rank));
    }
    MPI_CHECK(MPI_Comm_split(src_comm, color, rank, &comm_));
    MPI_CHECK(MPI_Comm_size(comm_, &size_));
    MPI_CHECK(MPI_Comm_rank(comm_, &rank_));
  }

  // The wrapper uniquely owns comm_; copying would lead to a double
  // MPI_Comm_free in the destructor, so copy (and thereby implicit move)
  // operations are disallowed.
  MPICommonWorldWrapper(const MPICommonWorldWrapper&) = delete;
  MPICommonWorldWrapper& operator=(const MPICommonWorldWrapper&) = delete;

  ~MPICommonWorldWrapper() {
    int ret;
    // Only free the communicator if MPI is still live: freeing after
    // MPI_Finalize is not permitted.
    MPI_CHECK(MPI_Finalized(&ret));
    if (!ret) {
      MPI_Comm_free(&comm_);
    }
  }
  /**
   * @brief Returns the common world held by the wrapper.
   */
  inline MPI_Comm comm() const {
    return comm_;
  }
  /**
   * @brief Returns the size of the world.
   */
  inline int size() const {
    return size_;
  }
  /**
   * @brief Returns the rank of this process in the world.
   */
  inline int rank() const {
    return rank_;
  }

 private:
  MPI_Comm comm_;
  int size_ = 0;
  int rank_ = 0;
};
/**
 * A function used to perform peer setup so one does not need to use
 * mpirun / mpiexec to run the binary. Note that if you use mpirun or mpiexec
 * to set up the common world, do not use this function - MPI_Init would have
 * already set that up.
 *
 * This also assumes that you have a common path (like NFS) that multiple
 * instances can read from.
 *
 * Inputs:
 * replicas (int): the number of replicas that mpi will run with.
 * role (string): the role of this process, "server" or "client".
 * job_path (string): a file name that the server will write its port into
 * and the clients will read the server's port from.
 */
// NOTE(review): `string` is presumably caffe2's alias for std::string pulled
// in via caffe2/core/common.h — confirm before moving this declaration.
void MPISetupPeers(
const int replicas,
const string& role,
const string& job_path);
} // namespace caffe2
#endif // CAFFE2_MPI_MPI_COMMON_H_