Skip to content

Commit

Permalink
Add sorting of destination ranks for win_put/accum to better avoid collisions
Browse files Browse the repository at this point in the history
  • Loading branch information
bichengying committed May 18, 2020
1 parent 8ed5721 commit 0851737
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions bluefog/common/mpi_controller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,28 @@ Status MPIController::WinFence(const std::string& name) {
return Status::OK();
}

// Reshuffles the order of destinations to avoid network collisions.
//
// Each destination rank r is given the distance (r - self_rank), with `size`
// added when r is smaller than self_rank, and the entries are returned in
// ascending order of that distance. The weights are carried along untouched.
//
// self_rank:   rank the distances are measured from (callers in this file
//              pass rank_).
// size:        amount added to wrap ranks below self_rank (callers in this
//              file pass size_).
// dst_weights: map from destination rank to weight; taken by const reference
//              so the map is not copied on every call.
// Returns the (rank, weight) pairs of dst_weights sorted as described above.
std::vector<std::pair<int, double>> GetSortedDstWeights(
    const int self_rank, const int size,
    const std::unordered_map<int, double>& dst_weights) {
  // The range constructor sizes the vector once and converts each
  // pair<const int, double> map entry into a pair<int, double>.
  std::vector<std::pair<int, double>> sorted_dst_weights(dst_weights.begin(),
                                                         dst_weights.end());

  // Distance from self_rank to `rank`, wrapping ranks below self_rank by size.
  const auto wrapped_distance = [self_rank, size](const int rank) {
    const int distance = rank - self_rank;
    return rank < self_rank ? distance + size : distance;
  };

  std::sort(sorted_dst_weights.begin(), sorted_dst_weights.end(),
            [&wrapped_distance](const std::pair<int, double>& lhs,
                                const std::pair<int, double>& rhs) {
              return wrapped_distance(lhs.first) < wrapped_distance(rhs.first);
            });
  return sorted_dst_weights;
}

void MPIController::WinPut(TensorTableEntry& entry) {
// We need to explicitly set the device here.
with_device device_guard(entry.device);
Expand All @@ -536,7 +558,11 @@ void MPIController::WinPut(TensorTableEntry& entry) {

Timeline* timeline_ptr;
Status timeline_status = GetBluefogTimeline(timeline_ptr);
for (auto kv : entry.dst_weights) {

std::vector<std::pair<int, double>> sorted_dst_weights =
GetSortedDstWeights(rank_, size_, entry.dst_weights);

for (auto kv : sorted_dst_weights) {
int target_rank = kv.first;
double weight = kv.second;

Expand Down Expand Up @@ -601,7 +627,10 @@ void MPIController::WinAccumulate(TensorTableEntry& entry) {
Status timeline_status = GetBluefogTimeline(timeline_ptr);
std::vector<int> mutex_ranks = {}; // used in mutex only.

for (auto kv : entry.dst_weights) {
std::vector<std::pair<int, double>> sorted_dst_weights =
GetSortedDstWeights(rank_, size_, entry.dst_weights);

for (auto kv : sorted_dst_weights) {
int target_rank = kv.first;
double weight = kv.second;
// avoid putting the tensor for itself (NOT valid).
Expand Down

0 comments on commit 0851737

Please sign in to comment.