Merge pull request #74 from Bluefog-Lib/dynamic_neighbor_allgather
Dynamic neighbor allgather
bichengying committed Feb 21, 2021
2 parents fb30a9c + 4a95892 commit ceb5073
Showing 12 changed files with 538 additions and 241 deletions.
1 change: 1 addition & 0 deletions bluefog/common/common.h
@@ -82,6 +82,7 @@ enum class Communicator {
  LOCAL = 1,
  CROSS = 2,
  GRAPH = 3,
  DYNAMIC = 4,
};

enum class DataType {
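A note on the new DYNAMIC value (an aside, not part of the commit): unlike GRAPH, it does not map to a dedicated MPI communicator. As the mpi_context.cc changes below show, dynamic exchanges are issued as point-to-point MPI_Isend/MPI_Irecv calls on the GLOBAL communicator, and the two sides of each message are matched by the symmetric tag rank_ + peer:

```cpp
// Illustrative only: the sender and the receiver of a message derive the
// same tag from the pair of ranks involved, so the two sides match.
int sender = 2, receiver = 5;
int tag_on_sender = sender + receiver;    // rank_ + dst_rank on rank 2 -> 7
int tag_on_receiver = receiver + sender;  // rank_ + src_rank on rank 5 -> 7
```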
123 changes: 122 additions & 1 deletion bluefog/common/mpi_context.cc
@@ -440,8 +440,120 @@ bool MPIContext::UnregisterAllWindowName() {
  return true;
}

std::string GenerateNeighborExchangeErrorMessage(const std::vector<MPI_Status>& statuses,
                                                 int nsend, int nrecv) {
  std::string error_message = "";
  bool error_encountered = false;
  for (int i = 0; i < nsend; ++i) {
    const auto& status = statuses[i];
    error_message += "MPI_Isend to Process " + std::to_string(status.MPI_SOURCE);
    error_message += "; with tag " + std::to_string(status.MPI_TAG);
    error_message += "; with error code " + std::to_string(status.MPI_ERROR) + "\n";
    if(status.MPI_ERROR != MPI_SUCCESS) error_encountered = true;
  }
  for (int i = 0; i < nrecv; ++i) {
    const auto& status = statuses[i+nsend];
    error_message += "MPI_Irecv from Process " + std::to_string(status.MPI_SOURCE);
    error_message += "; with tag " + std::to_string(status.MPI_TAG);
    error_message += "; with error code " + std::to_string(status.MPI_ERROR) + "\n";
    if(status.MPI_ERROR != MPI_SUCCESS) error_encountered = true;
  }
  if (!error_encountered) error_message = "";
  return error_message;
}

std::string MPIContext::NeighborValueExchangeWithConstantElements(
    const void* input_ptr, void* output_ptr, int num_elements, DataType dtype,
    const std::vector<int>* dst_ranks, const std::vector<int>* src_ranks) {
  int nsend = dst_ranks->size();
  int nrecv = src_ranks->size();
  std::vector<MPI_Request> requests(nsend + nrecv);
  std::vector<MPI_Status> statuses(nsend + nrecv);
  int element_size = GetMPITypeSize(dtype);
  for (int i = 0; i < nrecv; ++i) {
    void* recvbuf = (void*)(static_cast<const char*>(output_ptr) +
                            num_elements * i * element_size);
    int ret_code = MPI_Irecv(
        recvbuf, num_elements, GetMPIDataType(dtype), src_ranks->at(i),
        /*tag=*/rank_ + src_ranks->at(i),
        GetMPICommunicator(Communicator::GLOBAL),
        &requests[i + nsend]);
    if (ret_code != MPI_SUCCESS) {
      throw std::runtime_error(
          "MPI_Irecv (for dynamic neighbor_allreduce) failed, see MPI "
          "output for details.");
    }
  }
  for (int i = 0; i < nsend; ++i) {
    int ret_code = MPI_Isend(
        input_ptr, num_elements, GetMPIDataType(dtype), dst_ranks->at(i),
        /*tag=*/rank_ + dst_ranks->at(i),
        GetMPICommunicator(Communicator::GLOBAL), &requests[i]);
    if (ret_code != MPI_SUCCESS) {
      throw std::runtime_error(
          "MPI_Isend (for dynamic neighbor_allreduce) failed, see MPI "
          "output for details.");
    }
  }
  MPI_Waitall(nsend + nrecv, requests.data(), statuses.data());
  return GenerateNeighborExchangeErrorMessage(statuses, nsend, nrecv);
}
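For orientation, here is a minimal sketch (not part of the commit) of how this helper can be invoked to gather one int32 value from every source neighbor, mirroring the count exchange that AllocateOutput performs further down in this diff. The `mpi_ctx` instance and the rank lists are placeholders:

```cpp
// Hypothetical usage sketch. `mpi_ctx` stands for an MPIContext instance and
// the rank lists would normally come from the dynamic-topology arguments of
// the collective call.
std::vector<int> dst_ranks = {1, 3};          // ranks this process sends to
std::vector<int> src_ranks = {2, 5};          // ranks this process receives from
int send_count = 128;                         // local value to publish
std::vector<int> gathered(src_ranks.size());  // one slot per source neighbor
std::string error = mpi_ctx.NeighborValueExchangeWithConstantElements(
    &send_count, gathered.data(), /*num_elements=*/1,
    DataType::BLUEFOG_INT32, &dst_ranks, &src_ranks);
if (!error.empty()) {
  // The string lists each Isend/Irecv status, as built by
  // GenerateNeighborExchangeErrorMessage above.
}
```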

std::string MPIContext::NeighborValueExchangeWithVaryingElements(
    const void* input_ptr, void* output_ptr, const int sendcount,
    const int* recvcounts, const int* displcmnts, DataType dtype,
    const std::vector<int>* dst_ranks,
    const std::vector<int>* src_ranks
) {
  int nsend = dst_ranks->size();
  int nrecv = src_ranks->size();
  std::vector<MPI_Request> requests(nsend + nrecv);
  std::vector<MPI_Status> statuses(nsend + nrecv);
  int element_size = GetMPITypeSize(dtype);
  for (int i = 0; i < nrecv; ++i) {
    void* recvbuf = (void*)(static_cast<const char*>(output_ptr) +
                            displcmnts[i] * element_size);
    int ret_code = MPI_Irecv(
        recvbuf, recvcounts[i], GetMPIDataType(dtype), src_ranks->at(i),
        /*tag=*/rank_ + src_ranks->at(i),
        GetMPICommunicator(Communicator::GLOBAL),
        &requests[i + nsend]);
    if (ret_code != MPI_SUCCESS) {
      throw std::runtime_error(
          "MPI_Irecv (for dynamic neighbor_allreduce) failed, see MPI "
          "output for details.");
    }
  }
  for (int i = 0; i < nsend; ++i) {
    int ret_code = MPI_Isend(
        input_ptr, sendcount, GetMPIDataType(dtype), dst_ranks->at(i),
        /*tag=*/rank_ + dst_ranks->at(i),
        GetMPICommunicator(Communicator::GLOBAL), &requests[i]);
    if (ret_code != MPI_SUCCESS) {
      throw std::runtime_error(
          "MPI_Isend (for dynamic neighbor_allreduce) failed, see MPI "
          "output for details.");
    }
  }
  MPI_Waitall(nsend + nrecv, requests.data(), statuses.data());
  return GenerateNeighborExchangeErrorMessage(statuses, nsend, nrecv);
}
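Similarly, a hedged sketch of the varying-element case (again not part of the commit): once the per-neighbor receive counts are known, the payload from each source neighbor lands at its displacement in one flat output buffer. Rank lists and buffer sizes below are illustrative:

```cpp
// Hypothetical usage sketch with int32 payloads of different sizes.
std::vector<int> dst_ranks = {1, 3};
std::vector<int> src_ranks = {2, 5};
std::vector<int> input(200, 7);           // local tensor, 200 elements
std::vector<int> recvcounts = {128, 64};  // elements expected from ranks 2 and 5
std::vector<int> displcmnts = {0, 128};   // exclusive prefix sum of recvcounts
std::vector<int> output(128 + 64);        // flat receive buffer
std::string error = mpi_ctx.NeighborValueExchangeWithVaryingElements(
    input.data(), output.data(), /*sendcount=*/static_cast<int>(input.size()),
    recvcounts.data(), displcmnts.data(), DataType::BLUEFOG_INT32,
    &dst_ranks, &src_ranks);
```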

Status MPIContext::AllocateOutput(TensorTableEntry& entry, int*& recvcounts,
                                  Communicator comm_type) {
  if (comm_type == Communicator::DYNAMIC) {
    return Status::InvalidArgument(
        "Try to allocate the output with dynamic topology but do not provide "
        "the source and receive ranks.");
  } else {
    return AllocateOutput(entry, recvcounts, comm_type, nullptr, nullptr);
  }
}

Status MPIContext::AllocateOutput(TensorTableEntry& entry, int*& recvcounts,
                                  Communicator comm_type,
                                  const std::vector<int>* dst_ranks,
                                  const std::vector<int>* src_ranks) {
  Timeline* timeline_ptr;
  Status timeline_status = GetBluefogTimeline(timeline_ptr);
  timeline_ptr->ActivityStart(entry.tensor_name, "ALLOCATE_OUTPUT");
@@ -462,6 +574,8 @@ Status MPIContext::AllocateOutput(TensorTableEntry& entry, int*& recvcounts,
    cnt_size = size_;
  } else if (comm_type == Communicator::GRAPH) {
    cnt_size = neighbor_indgree_;
  } else if (comm_type == Communicator::DYNAMIC) {
    cnt_size = src_ranks->size();
  }

  int* send_count = new int[1];
@@ -475,6 +589,11 @@ Status MPIContext::AllocateOutput(TensorTableEntry& entry, int*& recvcounts,
    ret_code =
        MPI_Neighbor_allgather(send_count, 1, MPI_INT, gather_count, 1, MPI_INT,
                               GetMPICommunicator(Communicator::GRAPH));
  } else if (comm_type == Communicator::DYNAMIC) {
    std::string error_message = NeighborValueExchangeWithConstantElements(
        send_count, gather_count, 1, DataType::BLUEFOG_INT32, dst_ranks, src_ranks
    );
    ret_code = error_message == "" ? MPI_SUCCESS : -1;
  }

  if (ret_code != MPI_SUCCESS) {
@@ -508,12 +627,14 @@ Status MPIContext::AllocateOutput(TensorTableEntry& entry, int*& recvcounts,
}

void MPIContext::SetDisplacements(const int* recvcounts, int*& displcmnts,
                                  Communicator comm_type) {
                                  Communicator comm_type, int source_neighbor_cnt) {
  int cnt_size = 0;
  if (comm_type == Communicator::GLOBAL) {
    cnt_size = size_;
  } else if (comm_type == Communicator::GRAPH) {
    cnt_size = neighbor_indgree_;
  } else if (comm_type == Communicator::DYNAMIC) {
    cnt_size = source_neighbor_cnt;
  }
  for (int rc = 0; rc < cnt_size; ++rc) {
    if (rc == 0) {
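The loop body of SetDisplacements is collapsed in this view, but from the way the receive counts are used it presumably builds the exclusive prefix sum that turns per-neighbor counts into offsets into the flat output buffer. A small illustration (not the library code):

```cpp
// For three source neighbors contributing 4, 2 and 3 elements, the receive
// offsets become {0, 4, 6}: each neighbor's data starts where the previous
// neighbor's data ends.
std::vector<int> recvcounts = {4, 2, 3};
std::vector<int> displcmnts(recvcounts.size());
for (size_t rc = 0; rc < recvcounts.size(); ++rc) {
  displcmnts[rc] = (rc == 0) ? 0 : displcmnts[rc - 1] + recvcounts[rc - 1];
}
```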
21 changes: 19 additions & 2 deletions bluefog/common/mpi_context.h
@@ -156,8 +156,25 @@ class MPIContext {
  bool UnregisterWindowName(const std::string& name);
  bool UnregisterAllWindowName();

  Status AllocateOutput(TensorTableEntry& entries, int*& recvcounts, Communicator comm_type);
  void SetDisplacements(const int* recvcounts, int*& displcmnts, Communicator comm_type);
  std::string NeighborValueExchangeWithConstantElements(
      const void* input_ptr, void* output_ptr, int num_elements, DataType dtype,
      const std::vector<int>* dst_ranks,
      const std::vector<int>* src_ranks);
  std::string NeighborValueExchangeWithVaryingElements(
      const void* input_ptr, void* output_ptr, const int sendcount,
      const int* recvcounts, const int* displcmnts, DataType dtype,
      const std::vector<int>* dst_ranks,
      const std::vector<int>* src_ranks);

  Status AllocateOutput(TensorTableEntry& entry, int*& recvcounts,
                        Communicator comm_type);
  Status AllocateOutput(TensorTableEntry& entry, int*& recvcounts,
                        Communicator comm_type,
                        const std::vector<int>* dst_ranks,
                        const std::vector<int>* src_ranks);
  // source_neighbor_cnt is required only when Communicator is dynamic.
  void SetDisplacements(const int* recvcounts, int*& displcmnts,
                        Communicator comm_type, int source_neighbor_cnt = -1);

  // Flag indicating whether mpi is enabled.
  bool enabled_ = false;
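Taken together, the new declarations suggest the following calling sequence for one dynamic neighbor allgather. This is an editorial sketch, not code from the repository: the MPIContext instance, the TensorTableEntry, and the rank lists are placeholders, and whether recvcounts/displcmnts are allocated by the caller (as done here) or inside AllocateOutput is not visible in this diff:

```cpp
// Hypothetical flow for one dynamic neighbor allgather call.
std::vector<int> dst_ranks = {1, 3};   // per-call destination neighbors
std::vector<int> src_ranks = {2, 5};   // per-call source neighbors
int* recvcounts = new int[src_ranks.size()];
int* displcmnts = new int[src_ranks.size()];

// 1) Exchange element counts with the per-call neighbors and size the output.
Status status = mpi_ctx.AllocateOutput(entry, recvcounts, Communicator::DYNAMIC,
                                       &dst_ranks, &src_ranks);

// 2) Convert the counts into receive offsets; DYNAMIC needs the neighbor count
//    because there is no fixed in-degree to fall back on.
mpi_ctx.SetDisplacements(recvcounts, displcmnts, Communicator::DYNAMIC,
                         /*source_neighbor_cnt=*/static_cast<int>(src_ranks.size()));

// 3) Ship the payload with the point-to-point helper (see the sketch after
//    NeighborValueExchangeWithVaryingElements above).

delete[] recvcounts;
delete[] displcmnts;
```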
