From b25f1b9c249a68cbd0c7fea93c1b2483aa3eb00a Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Fri, 1 Apr 2022 18:50:28 +0800 Subject: [PATCH 01/12] Try to speed up sbp collector. However, throughput drop --- oneflow/core/auto_parallel/sbp_collector.cpp | 117 ++++++++++--------- oneflow/core/auto_parallel/sbp_collector.h | 14 +-- 2 files changed, 67 insertions(+), 64 deletions(-) diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index d9f0bbbf275..6b71322e874 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -134,14 +134,15 @@ void SbpCollector::InitializeCopyCostFromNode2Proxy(SbpNode* sbp // Initialize copy cost from proxy of producer to consumers void SbpCollector::InitializeCopyCostFromProxy2Consumer( SbpNode* sbp_proxy, - HashMap, std::unordered_set>& consumer_bn2sbp_set, + const std::vector>& consumer_bns, HashMap*>& op_name2sbp_node) { // Connect sbp proxy and consumers - for (const auto& consumer_bn_group : consumer_bn2sbp_set) { + for (const auto& consumer_bn : consumer_bns) { // consumer in cost model - SbpNode* sbp_node_consumer = op_name2sbp_node[consumer_bn_group.first.first]; + SbpNode* sbp_node_consumer = + op_name2sbp_node[consumer_bn.first->op().op_name()]; // input blob name of logical blob in consumer - const std::string& ibn = consumer_bn_group.first.second; + const std::string& ibn = consumer_bn.second; // check is_mutable in consumer OpNode* consumer = sbp_node_consumer->op_node; @@ -190,6 +191,13 @@ void SbpCollector::ProxySbpCandidate( HashMap, HashMap, std::unordered_set>> producer_lbi2consumer_bn2sbp_set; + // mapping from a logical blob id to index + HashMap lbi2index; + // mapping from the index to producer, consuemr and corresponding input blob name, possible sbp + // sets + std::vector index2producer; + std::vector>> index2consumer_bns; + std::vector> index2sbp_set; for (auto* consumer_sbp_node : sbp_graph.NodeList) { auto* node = consumer_sbp_node->op_node; @@ -209,11 +217,27 @@ void SbpCollector::ProxySbpCandidate( // not building proxy for fixed opertors if (op_name2sbp_node.find(producer.op().op_name()) == op_name2sbp_node.end()) { return; } + // decide the index of a logical blob description + const auto& iterator_lbi = lbi2index.find(lbi); + int32_t index = 0; + if (iterator_lbi == lbi2index.end()) { + index = lbi2index.size(); + lbi2index[lbi] = index; + // map from lbi to the producer + index2producer.push_back(&producer); + // Initialize consumer_bns and the sbp sets + index2consumer_bns.resize(index + 1); + index2sbp_set.resize(index + 1); + } else { + index = iterator_lbi->second; + } + // Add the consumer and corresponding input blob name + index2consumer_bns[index].push_back({consumer_sbp_node->op_node, ibn}); + // a set to store the id of all possible SBP Parallel for a downstream op // should filter out B and other repeated SBP Parallel by pre-storing them into an // unordered_set - std::unordered_set& SbpParallelIDs = producer_lbi2consumer_bn2sbp_set[{ - producer.op().op_name(), lbi}][{node->op().op_name(), ibn}]; + std::unordered_set& SbpParallelIDs = index2sbp_set[index]; // TODO: use SbpSignatureList instead of SbpSignatureObjList for (auto& sbp_sig : consumer_sbp_node->SbpSignatureObjList) { const auto& map = sbp_sig.bn_in_op2nd_sbp(); @@ -227,28 +251,28 @@ void SbpCollector::ProxySbpCandidate( }; // A set of binary set with broadcast only - std::unordered_set ParallelCandidatesInitializer; + // std::unordered_set ParallelCandidatesInitializer; // BinarySet one_broadcast(SbpParallelUniverse.size()); // one_broadcast.AddEntry(0); // ParallelCandidatesInitializer.insert(std::move(one_broadcast)); // Decide if we should insert a proxy for each logical blob - for (auto& lbi7groups : producer_lbi2consumer_bn2sbp_set) { + for (auto& lbi_index : lbi2index) { + int32_t index = lbi_index.second; // Only insert proxy for those blobs with multiple downstream consumers. - if (lbi7groups.second.size() < 2) { continue; } - const std::string& producer_name = lbi7groups.first.first; + if (index2consumer_bns[index].size() < 2) { continue; } + // Maximum number of possible sbp in the proxy + int32_t max_num_sbp_proxy = std::min(max_num_sbp_proxy_, index2consumer_bns[index].size()); // producer in cost model + const std::string& producer_name = index2producer[index]->op().op_name(); SbpNode* sbp_node_producer = op_name2sbp_node[producer_name]; - const LogicalBlobId& lbi = lbi7groups.first.second; - HashMap, std::unordered_set>& consumer_bn2sbp_set = - lbi7groups.second; - HashMap, std::unordered_set>::iterator it_begin = - consumer_bn2sbp_set.begin(); + + const LogicalBlobId& lbi = lbi_index.first; // store all the binary sets of SBP Parallel into an unordered_set. - std::unordered_set ParallelCandidates( - ParallelCandidatesInitializer); + std::unordered_set ParallelCandidates; + + DfsSbpSet(0, max_num_sbp_proxy, index2sbp_set[index], ParallelCandidates); - DfsSbpSet(it_begin, consumer_bn2sbp_set, op_name2sbp_node, ParallelCandidates); // Initialize sbp proxy SbpNode* sbp_proxy = InitializePorxy(sbp_graph, ParallelCandidates); // Might be unnecessary @@ -261,12 +285,13 @@ void SbpCollector::ProxySbpCandidate( InitializeCopyCostFromNode2Proxy(sbp_proxy, lbi); // Build connection and compute copy cost between proxy and consumers - InitializeCopyCostFromProxy2Consumer(sbp_proxy, consumer_bn2sbp_set, op_name2sbp_node); + InitializeCopyCostFromProxy2Consumer(sbp_proxy, index2consumer_bns[index], op_name2sbp_node); // Unloading - for (const auto& consumer_bn_group : consumer_bn2sbp_set) { + for (const auto& consumer_bn : index2consumer_bns[index]) { // consumer in cost model - SbpNode* sbp_node_consumer = op_name2sbp_node[consumer_bn_group.first.first]; + SbpNode* sbp_node_consumer = + op_name2sbp_node[consumer_bn.first->op().op_name()]; // the sbp edge connecting producer and consumer SbpEdge* edge_found = FindEdgeBetweenNodes(sbp_node_producer, sbp_node_consumer); @@ -282,43 +307,23 @@ void SbpCollector::ProxySbpCandidate( } // Depth first search to collect Sbp Parallel information for different lbis -void SbpCollector::DfsSbpSet( - HashMap, std::unordered_set>::iterator it, - HashMap, std::unordered_set>& consumer_bn2sbp_set, - HashMap*>& op_name2sbp_node, - std::unordered_set& ParallelCandidates) { - if (it == consumer_bn2sbp_set.end()) { +void SbpCollector::DfsSbpSet(int32_t depth, int32_t max_depth, + const std::unordered_set& sbp_sets, + std::unordered_set& ParallelCandidates) { + if (depth > 0) { // store the binary set into an unordered_set ParallelCandidates.insert(bs_buffer); - } else { - const std::string& consumer_name = it->first.first; - const std::string& ibn = it->first.second; - SbpNode* consumer_sbp_node = op_name2sbp_node[consumer_name]; - // a set to store the id of all possible SBP Parallel for a downstream op - // should filter out B and other repeated SBP Parallel by pre-storing them into an - // unordered_set - std::unordered_set SbpParallelIDs; - for (auto& sbp_sig : consumer_sbp_node->SbpSignatureObjList) { - const auto& map = sbp_sig.bn_in_op2nd_sbp(); - const auto& iter = map.find(ibn); - CHECK(iter != map.end()) << "blob_name " << ibn << " not found in sbp signature"; - const NdSbp& consumer_sbp = iter->second; - SbpParallelIDs.insert(SbpParallelUniverse[consumer_sbp]); - } - // next iterator - HashMap, std::unordered_set>::iterator it_next = - it; - ++it_next; - // go through all the sbp parallel of different candidates - for (int32_t SbpParallelNum : SbpParallelIDs) { - if (++accumulator[SbpParallelNum] == 1) { - bs_buffer.AddEntry(SbpParallelNum); - DfsSbpSet(it_next, consumer_bn2sbp_set, op_name2sbp_node, ParallelCandidates); - bs_buffer.DeleteEntry(SbpParallelNum); - } else { - DfsSbpSet(it_next, consumer_bn2sbp_set, op_name2sbp_node, ParallelCandidates); - } - accumulator[SbpParallelNum]--; + } + if (depth >= max_depth) { return; } + + // go through all the sbp parallel of different candidates + for (int32_t SbpParallelNum : sbp_sets) { + if (accumulator[SbpParallelNum] == 0) { + bs_buffer.AddEntry(SbpParallelNum); + ++accumulator[SbpParallelNum]; + DfsSbpSet(depth + 1, max_depth, sbp_sets, ParallelCandidates); + bs_buffer.DeleteEntry(SbpParallelNum); + --accumulator[SbpParallelNum]; } } } diff --git a/oneflow/core/auto_parallel/sbp_collector.h b/oneflow/core/auto_parallel/sbp_collector.h index 73e424b6e20..68094774b22 100644 --- a/oneflow/core/auto_parallel/sbp_collector.h +++ b/oneflow/core/auto_parallel/sbp_collector.h @@ -67,8 +67,7 @@ class SbpCollector { // Initialize copy cost from proxy of producer to consumers void InitializeCopyCostFromProxy2Consumer( SbpNode* sbp_proxy, - HashMap, std::unordered_set>& - consumer_bn2sbp_set, + const std::vector>& consumer_bns, HashMap*>& op_name2sbp_node); // Export list of possible combination of Sbp Parallels @@ -77,13 +76,12 @@ class SbpCollector { SbpGraph& sbp_graph); private: + // Maximum number of possible sbp in the proxy + unsigned long max_num_sbp_proxy_ = 3; + // Depth first search to collect Sbp Parallel information for different lbis - void DfsSbpSet( - HashMap, std::unordered_set>::iterator it, - HashMap, std::unordered_set>& - consumer_bn2sbp_set, - HashMap*>& op_name2sbp_node, - std::unordered_set& ParallelCandidates); + void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, + std::unordered_set& ParallelCandidates); }; // class SbpCollector } // namespace auto_parallel From ff137c8fe322598d4623b4121f0d1a4dc7fc4027 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Fri, 1 Apr 2022 19:05:41 +0800 Subject: [PATCH 02/12] Shrink the parallel candidates for the proxy node --- oneflow/core/auto_parallel/sbp_collector.cpp | 15 +++++++++++---- oneflow/core/auto_parallel/sbp_collector.h | 1 + 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index 6b71322e874..9b6e0647c13 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -271,7 +271,8 @@ void SbpCollector::ProxySbpCandidate( // store all the binary sets of SBP Parallel into an unordered_set. std::unordered_set ParallelCandidates; - DfsSbpSet(0, max_num_sbp_proxy, index2sbp_set[index], ParallelCandidates); + DfsSbpSet(0, max_num_sbp_proxy, index2sbp_set[index], index2sbp_set[index].begin(), + ParallelCandidates); // Initialize sbp proxy SbpNode* sbp_proxy = InitializePorxy(sbp_graph, ParallelCandidates); @@ -309,6 +310,7 @@ void SbpCollector::ProxySbpCandidate( // Depth first search to collect Sbp Parallel information for different lbis void SbpCollector::DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, + const std::unordered_set::iterator start_it, std::unordered_set& ParallelCandidates) { if (depth > 0) { // store the binary set into an unordered_set @@ -316,12 +318,17 @@ void SbpCollector::DfsSbpSet(int32_t depth, int32_t max_depth, } if (depth >= max_depth) { return; } - // go through all the sbp parallel of different candidates - for (int32_t SbpParallelNum : sbp_sets) { + // go through the rest of the sbp parallel + std::unordered_set::iterator curr_it = start_it; + while (curr_it != sbp_sets.end()) { + // Take the value out + int32_t SbpParallelNum = *curr_it; + // Then move to the next pointer + ++curr_it; if (accumulator[SbpParallelNum] == 0) { bs_buffer.AddEntry(SbpParallelNum); ++accumulator[SbpParallelNum]; - DfsSbpSet(depth + 1, max_depth, sbp_sets, ParallelCandidates); + DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, ParallelCandidates); bs_buffer.DeleteEntry(SbpParallelNum); --accumulator[SbpParallelNum]; } diff --git a/oneflow/core/auto_parallel/sbp_collector.h b/oneflow/core/auto_parallel/sbp_collector.h index 68094774b22..53b19283d99 100644 --- a/oneflow/core/auto_parallel/sbp_collector.h +++ b/oneflow/core/auto_parallel/sbp_collector.h @@ -81,6 +81,7 @@ class SbpCollector { // Depth first search to collect Sbp Parallel information for different lbis void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, + const std::unordered_set::iterator sbp_set_it, std::unordered_set& ParallelCandidates); }; // class SbpCollector From 15cadfb4ab74298ffba87a95850159e638a53427 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Sat, 2 Apr 2022 22:24:14 +0800 Subject: [PATCH 03/12] Print out some information and then refine --- oneflow/core/auto_parallel/binary_set.cpp | 6 +++--- oneflow/core/auto_parallel/binary_set.h | 6 +++--- oneflow/core/auto_parallel/sbp_collector.cpp | 5 ----- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/oneflow/core/auto_parallel/binary_set.cpp b/oneflow/core/auto_parallel/binary_set.cpp index deed59a24fd..358d2884372 100644 --- a/oneflow/core/auto_parallel/binary_set.cpp +++ b/oneflow/core/auto_parallel/binary_set.cpp @@ -44,7 +44,7 @@ void BinarySet::Initialize(int32_t size_of_set) { } // Check if i-th element in this subset -int32_t BinarySet::CheckExistency(int32_t i) { +int32_t BinarySet::CheckExistency(int32_t i) const { int32_t k = i / bit_of_BinarySetEntryType; int32_t j = i % bit_of_BinarySetEntryType; return BinarySetValues[k] >> j & 1; @@ -91,7 +91,7 @@ int32_t BinarySet::Total() { } // Output all the elements in the subset -void BinarySet::OutPut(std::vector& out) { +void BinarySet::OutPut(std::vector& out) const { out.clear(); for (int32_t i = 0; i < SizeOfSet; i++) { if (CheckExistency(i)) { out.emplace_back(i); } @@ -99,7 +99,7 @@ void BinarySet::OutPut(std::vector& out) { } // Output all the elements in the subset -void BinarySet::QuickOutPut(std::vector& out) { +void BinarySet::QuickOutPut(std::vector& out) const { out.clear(); for (int32_t i = 0; i < BinarySetValues.size(); i++) { BinarySetEntryType x = BinarySetValues[i]; diff --git a/oneflow/core/auto_parallel/binary_set.h b/oneflow/core/auto_parallel/binary_set.h index d616be2b068..9bf4b00ec8f 100644 --- a/oneflow/core/auto_parallel/binary_set.h +++ b/oneflow/core/auto_parallel/binary_set.h @@ -47,7 +47,7 @@ class BinarySet { // Initialization void Initialize(int32_t size_of_set); // Check if i-th element in this subset - int32_t CheckExistency(int32_t i); + int32_t CheckExistency(int32_t i) const; // Add i-th element into this subset void AddEntry(int32_t i); // Take i-th element out from this subset @@ -59,9 +59,9 @@ class BinarySet { // Count number of elements in this subset int32_t Total(); // Output all the elements in the subset - void OutPut(std::vector& out); + void OutPut(std::vector& out) const; // Output all the elements in the subset - void QuickOutPut(std::vector& out); + void QuickOutPut(std::vector& out) const; // Add elements of input into this subset void AddEntrys(std::vector& in); // If two binary sets are equal to each other diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index 9b6e0647c13..c63667b9720 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -186,11 +186,6 @@ void SbpCollector::ProxySbpCandidate( // HashMap*>>& // op_name2lbi2sbp_proxy; - // mapping from a logical blob id to a group of consumers and corresponding input blob names. - // mapping from consumers and input blob names to an unordered_set of SBP Parallel. - HashMap, - HashMap, std::unordered_set>> - producer_lbi2consumer_bn2sbp_set; // mapping from a logical blob id to index HashMap lbi2index; // mapping from the index to producer, consuemr and corresponding input blob name, possible sbp From b6a829d5908fd8b1c19ce047ffd0990747a8f6f5 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Fri, 8 Apr 2022 18:33:00 +0800 Subject: [PATCH 04/12] Store the sbp set for each consumer --- oneflow/core/auto_parallel/sbp_collector.cpp | 42 +++++++++++--------- oneflow/core/auto_parallel/sbp_collector.h | 3 +- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index c63667b9720..5e981c89705 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -134,15 +134,14 @@ void SbpCollector::InitializeCopyCostFromNode2Proxy(SbpNode* sbp // Initialize copy cost from proxy of producer to consumers void SbpCollector::InitializeCopyCostFromProxy2Consumer( SbpNode* sbp_proxy, - const std::vector>& consumer_bns, + HashMap, std::unordered_set>& consumer_bn2sbp_set, HashMap*>& op_name2sbp_node) { // Connect sbp proxy and consumers - for (const auto& consumer_bn : consumer_bns) { + for (const auto& consumer_bn_group : consumer_bn2sbp_set) { // consumer in cost model - SbpNode* sbp_node_consumer = - op_name2sbp_node[consumer_bn.first->op().op_name()]; + SbpNode* sbp_node_consumer = op_name2sbp_node[consumer_bn_group.first.first]; // input blob name of logical blob in consumer - const std::string& ibn = consumer_bn.second; + const std::string& ibn = consumer_bn_group.first.second; // check is_mutable in consumer OpNode* consumer = sbp_node_consumer->op_node; @@ -191,8 +190,10 @@ void SbpCollector::ProxySbpCandidate( // mapping from the index to producer, consuemr and corresponding input blob name, possible sbp // sets std::vector index2producer; - std::vector>> index2consumer_bns; std::vector> index2sbp_set; + // mapping from consumers and input blob names to an unordered_set of SBP Parallel. + std::vector, std::unordered_set>> + index2consumer_bn2sbp_set; for (auto* consumer_sbp_node : sbp_graph.NodeList) { auto* node = consumer_sbp_node->op_node; @@ -221,18 +222,18 @@ void SbpCollector::ProxySbpCandidate( // map from lbi to the producer index2producer.push_back(&producer); // Initialize consumer_bns and the sbp sets - index2consumer_bns.resize(index + 1); + index2consumer_bn2sbp_set.resize(index + 1); index2sbp_set.resize(index + 1); } else { index = iterator_lbi->second; } - // Add the consumer and corresponding input blob name - index2consumer_bns[index].push_back({consumer_sbp_node->op_node, ibn}); // a set to store the id of all possible SBP Parallel for a downstream op - // should filter out B and other repeated SBP Parallel by pre-storing them into an - // unordered_set - std::unordered_set& SbpParallelIDs = index2sbp_set[index]; + // should filter out repeated SBP Parallel by pre-storing them into an unordered_set + std::unordered_set& SbpParallelIDs = + index2consumer_bn2sbp_set[index][{node->op().op_name(), ibn}]; + // The union sbp set of all the consumers + std::unordered_set& UnionSbpParallelIDs = index2sbp_set[index]; // TODO: use SbpSignatureList instead of SbpSignatureObjList for (auto& sbp_sig : consumer_sbp_node->SbpSignatureObjList) { const auto& map = sbp_sig.bn_in_op2nd_sbp(); @@ -240,7 +241,9 @@ void SbpCollector::ProxySbpCandidate( CHECK(iter != map.end()) << "blob_name " << ibn << " not found in sbp signature"; const NdSbp& consumer_sbp = iter->second; // filter out repeated SBP - SbpParallelIDs.insert(SbpParallelUniverse[consumer_sbp]); + int32_t sbp_universe_id = SbpParallelUniverse[consumer_sbp]; + SbpParallelIDs.insert(sbp_universe_id); + UnionSbpParallelIDs.insert(sbp_universe_id); } } }; @@ -255,9 +258,10 @@ void SbpCollector::ProxySbpCandidate( for (auto& lbi_index : lbi2index) { int32_t index = lbi_index.second; // Only insert proxy for those blobs with multiple downstream consumers. - if (index2consumer_bns[index].size() < 2) { continue; } + if (index2consumer_bn2sbp_set[index].size() < 2) { continue; } // Maximum number of possible sbp in the proxy - int32_t max_num_sbp_proxy = std::min(max_num_sbp_proxy_, index2consumer_bns[index].size()); + int32_t max_num_sbp_proxy = + std::min(max_num_sbp_proxy_, index2consumer_bn2sbp_set[index].size()); // producer in cost model const std::string& producer_name = index2producer[index]->op().op_name(); SbpNode* sbp_node_producer = op_name2sbp_node[producer_name]; @@ -281,13 +285,13 @@ void SbpCollector::ProxySbpCandidate( InitializeCopyCostFromNode2Proxy(sbp_proxy, lbi); // Build connection and compute copy cost between proxy and consumers - InitializeCopyCostFromProxy2Consumer(sbp_proxy, index2consumer_bns[index], op_name2sbp_node); + InitializeCopyCostFromProxy2Consumer(sbp_proxy, index2consumer_bn2sbp_set[index], + op_name2sbp_node); // Unloading - for (const auto& consumer_bn : index2consumer_bns[index]) { + for (const auto& consumer_bn_group : index2consumer_bn2sbp_set[index]) { // consumer in cost model - SbpNode* sbp_node_consumer = - op_name2sbp_node[consumer_bn.first->op().op_name()]; + SbpNode* sbp_node_consumer = op_name2sbp_node[consumer_bn_group.first.first]; // the sbp edge connecting producer and consumer SbpEdge* edge_found = FindEdgeBetweenNodes(sbp_node_producer, sbp_node_consumer); diff --git a/oneflow/core/auto_parallel/sbp_collector.h b/oneflow/core/auto_parallel/sbp_collector.h index 53b19283d99..b3b11ff62a4 100644 --- a/oneflow/core/auto_parallel/sbp_collector.h +++ b/oneflow/core/auto_parallel/sbp_collector.h @@ -67,7 +67,8 @@ class SbpCollector { // Initialize copy cost from proxy of producer to consumers void InitializeCopyCostFromProxy2Consumer( SbpNode* sbp_proxy, - const std::vector>& consumer_bns, + HashMap, std::unordered_set>& + consumer_bn2sbp_set, HashMap*>& op_name2sbp_node); // Export list of possible combination of Sbp Parallels From 22fb83f6ab300a52c3adaafa0c5ea7e7a4ab57ef Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Fri, 8 Apr 2022 19:04:07 +0800 Subject: [PATCH 05/12] Update binary set intersection --- oneflow/core/auto_parallel/binary_set.cpp | 10 ++++++++++ oneflow/core/auto_parallel/binary_set.h | 2 ++ 2 files changed, 12 insertions(+) diff --git a/oneflow/core/auto_parallel/binary_set.cpp b/oneflow/core/auto_parallel/binary_set.cpp index 358d2884372..789c4ef8f01 100644 --- a/oneflow/core/auto_parallel/binary_set.cpp +++ b/oneflow/core/auto_parallel/binary_set.cpp @@ -68,8 +68,18 @@ void BinarySet::UnionTo(BinarySet& bs, BinarySet& u) { u.BinarySetValues[k] = BinarySetValues[k] | bs.BinarySetValues[k]; } } +// If this binary set intersects another one +bool BinarySet::IfIntersect(const BinarySet& bs) const { + int32_t min_bs_size = std::min(BinarySetValues.size(), bs.BinarySetValues.size()); + for (int32_t k = 0; k < min_bs_size; k++) { + if (BinarySetValues[k] & bs.BinarySetValues[k]) { return true; } + } + return false; +} // Get the intersection with another subset and store it into i void BinarySet::IntersectionTo(BinarySet& bs, BinarySet& i) { + int32_t min_bs_size = std::min(BinarySetValues.size(), bs.BinarySetValues.size()); + if (min_bs_size > i.BinarySetValues.size()) { i.BinarySetValues.resize(min_bs_size, 0); } for (int32_t k = 0; k < BinarySetValues.size(); k++) { i.BinarySetValues[k] = BinarySetValues[k] & bs.BinarySetValues[k]; } diff --git a/oneflow/core/auto_parallel/binary_set.h b/oneflow/core/auto_parallel/binary_set.h index 9bf4b00ec8f..f670348d813 100644 --- a/oneflow/core/auto_parallel/binary_set.h +++ b/oneflow/core/auto_parallel/binary_set.h @@ -54,6 +54,8 @@ class BinarySet { void DeleteEntry(int32_t i); // Get the union with another subset and store it into u void UnionTo(BinarySet& bs, BinarySet& u); + // If this binary set intersects another one + bool IfIntersect(const BinarySet& bs) const; // Get the intersection with another subset and store it into i void IntersectionTo(BinarySet& bs, BinarySet& i); // Count number of elements in this subset From 11da2c2ce2562276db5cd55ce797d063edfa8a2d Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Fri, 8 Apr 2022 19:05:09 +0800 Subject: [PATCH 06/12] Remove impossible parallel candidates from sbp proxy --- oneflow/core/auto_parallel/sbp_collector.cpp | 70 ++++++++++---------- oneflow/core/auto_parallel/sbp_collector.h | 12 ++-- 2 files changed, 40 insertions(+), 42 deletions(-) diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index 5e981c89705..d75a1c5bd23 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "sbp_collector.h" #include +#include "oneflow/core/auto_parallel/binary_set.h" #include "oneflow/core/auto_parallel/sbp_util.h" #include "oneflow/core/job/sbp_parallel.cfg.h" #include "sbp_constructor.h" @@ -24,6 +25,19 @@ namespace oneflow { namespace auto_parallel { +namespace { +// Whether the given binary set intersects all the sbp sets of the consumers +bool IfIntersectAll( + const HashMap, BinarySet>& consumer_bn2sbp_set, + BinarySet& bs) { + for (const auto& sbp_set_group : consumer_bn2sbp_set) { + if (!bs.IfIntersect(sbp_set_group.second)) { return false; } + } + + return true; +} +} // namespace + // Default constructor for SbpCollector // Don't allow any special case for broadcast! SbpCollector::SbpCollector() { @@ -55,21 +69,6 @@ void SbpCollector::CollectUniverse(SbpGraph& sbp_graph) { accumulator.resize(SbpParallelUniverse.size(), 0); bs_buffer.Initialize(SbpParallelUniverse.size()); } -// Initialize sbp proxy with given parallel candidates of a blob -SbpNode* SbpCollector::InitializePorxy( - SbpGraph& sbp_graph, - std::unordered_set& ParallelCandidates) { - // Initialize sbp proxy - SbpNode* sbp_proxy = sbp_graph.GenerateNode(); - // move parallel candidates - for (const BinarySet& parallel_candidate : ParallelCandidates) { - sbp_proxy->ParallelCandidates.emplace_back(parallel_candidate); - } - // Initialize computation cost - sbp_proxy->Cost.resize(sbp_proxy->ParallelCandidates.size(), 0); - - return sbp_proxy; -} // TODO: Auto Placement! // It only collect the same sbp with the same parallel description @@ -134,7 +133,7 @@ void SbpCollector::InitializeCopyCostFromNode2Proxy(SbpNode* sbp // Initialize copy cost from proxy of producer to consumers void SbpCollector::InitializeCopyCostFromProxy2Consumer( SbpNode* sbp_proxy, - HashMap, std::unordered_set>& consumer_bn2sbp_set, + HashMap, BinarySet>& consumer_bn2sbp_set, HashMap*>& op_name2sbp_node) { // Connect sbp proxy and consumers for (const auto& consumer_bn_group : consumer_bn2sbp_set) { @@ -192,8 +191,7 @@ void SbpCollector::ProxySbpCandidate( std::vector index2producer; std::vector> index2sbp_set; // mapping from consumers and input blob names to an unordered_set of SBP Parallel. - std::vector, std::unordered_set>> - index2consumer_bn2sbp_set; + std::vector, BinarySet>> index2consumer_bn2sbp_set; for (auto* consumer_sbp_node : sbp_graph.NodeList) { auto* node = consumer_sbp_node->op_node; @@ -230,8 +228,8 @@ void SbpCollector::ProxySbpCandidate( // a set to store the id of all possible SBP Parallel for a downstream op // should filter out repeated SBP Parallel by pre-storing them into an unordered_set - std::unordered_set& SbpParallelIDs = - index2consumer_bn2sbp_set[index][{node->op().op_name(), ibn}]; + BinarySet& SbpParallelIDs = index2consumer_bn2sbp_set[index][{node->op().op_name(), ibn}]; + SbpParallelIDs.Initialize(SbpParallelUniverse.size()); // The union sbp set of all the consumers std::unordered_set& UnionSbpParallelIDs = index2sbp_set[index]; // TODO: use SbpSignatureList instead of SbpSignatureObjList @@ -242,7 +240,7 @@ void SbpCollector::ProxySbpCandidate( const NdSbp& consumer_sbp = iter->second; // filter out repeated SBP int32_t sbp_universe_id = SbpParallelUniverse[consumer_sbp]; - SbpParallelIDs.insert(sbp_universe_id); + SbpParallelIDs.AddEntry(sbp_universe_id); UnionSbpParallelIDs.insert(sbp_universe_id); } } @@ -268,15 +266,16 @@ void SbpCollector::ProxySbpCandidate( const LogicalBlobId& lbi = lbi_index.first; // store all the binary sets of SBP Parallel into an unordered_set. - std::unordered_set ParallelCandidates; + // std::vector ParallelCandidates; + // generate sbp proxy + SbpNode* sbp_proxy = sbp_graph.GenerateNode(); + // Depth first search to collect Sbp Parallel information for the whole sbp set DfsSbpSet(0, max_num_sbp_proxy, index2sbp_set[index], index2sbp_set[index].begin(), - ParallelCandidates); + index2consumer_bn2sbp_set[index], sbp_proxy->ParallelCandidates); - // Initialize sbp proxy - SbpNode* sbp_proxy = InitializePorxy(sbp_graph, ParallelCandidates); - // Might be unnecessary - // op_name2lbi2sbp_proxy[producer_name][lbi] = sbp_proxy; + // Initialize computation cost + sbp_proxy->Cost.resize(sbp_proxy->ParallelCandidates.size(), 0); // Transfer a logical blob from producer to a sbp proxy of this blob sbp_node_producer->PointTo(sbp_proxy); @@ -307,13 +306,16 @@ void SbpCollector::ProxySbpCandidate( } // Depth first search to collect Sbp Parallel information for different lbis -void SbpCollector::DfsSbpSet(int32_t depth, int32_t max_depth, - const std::unordered_set& sbp_sets, - const std::unordered_set::iterator start_it, - std::unordered_set& ParallelCandidates) { +void SbpCollector::DfsSbpSet( + int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, + const std::unordered_set::iterator start_it, + HashMap, BinarySet>& consumer_bn2sbp_set, + std::vector& ParallelCandidates) { if (depth > 0) { - // store the binary set into an unordered_set - ParallelCandidates.insert(bs_buffer); + if (IfIntersectAll(consumer_bn2sbp_set, bs_buffer)) { + // store the binary set into an unordered_set + ParallelCandidates.push_back(bs_buffer); + } } if (depth >= max_depth) { return; } @@ -327,7 +329,7 @@ void SbpCollector::DfsSbpSet(int32_t depth, int32_t max_depth, if (accumulator[SbpParallelNum] == 0) { bs_buffer.AddEntry(SbpParallelNum); ++accumulator[SbpParallelNum]; - DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, ParallelCandidates); + DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, consumer_bn2sbp_set, ParallelCandidates); bs_buffer.DeleteEntry(SbpParallelNum); --accumulator[SbpParallelNum]; } diff --git a/oneflow/core/auto_parallel/sbp_collector.h b/oneflow/core/auto_parallel/sbp_collector.h index b3b11ff62a4..bb8214eb485 100644 --- a/oneflow/core/auto_parallel/sbp_collector.h +++ b/oneflow/core/auto_parallel/sbp_collector.h @@ -55,10 +55,6 @@ class SbpCollector { void CollectUniverse(SbpNode* sbp_node); // Collect all the possible Sbp Parallel from a SbpGraph void CollectUniverse(SbpGraph& sbp_graph); - // Initialize sbp proxy with given parallel candidates of a blob - SbpNode* InitializePorxy( - SbpGraph& sbp_graph, - std::unordered_set& ParallelCandidates); // Initialize copy cost from producer to proxy of producer void InitializeCopyCostFromNode2Proxy(SbpNode* sbp_proxy, @@ -67,8 +63,7 @@ class SbpCollector { // Initialize copy cost from proxy of producer to consumers void InitializeCopyCostFromProxy2Consumer( SbpNode* sbp_proxy, - HashMap, std::unordered_set>& - consumer_bn2sbp_set, + HashMap, BinarySet>& consumer_bn2sbp_set, HashMap*>& op_name2sbp_node); // Export list of possible combination of Sbp Parallels @@ -80,10 +75,11 @@ class SbpCollector { // Maximum number of possible sbp in the proxy unsigned long max_num_sbp_proxy_ = 3; - // Depth first search to collect Sbp Parallel information for different lbis + // Depth first search to collect Sbp Parallel information for the whole sbp set void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, const std::unordered_set::iterator sbp_set_it, - std::unordered_set& ParallelCandidates); + HashMap, BinarySet>& consumer_bn2sbp_set, + std::vector& ParallelCandidates); }; // class SbpCollector } // namespace auto_parallel From 2632b8ae818ca204193de30239f2bfa1c7f1a4bf Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Fri, 8 Apr 2022 19:13:14 +0800 Subject: [PATCH 07/12] Refine binary set --- oneflow/core/auto_parallel/binary_set.cpp | 4 ++-- oneflow/core/auto_parallel/binary_set.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/oneflow/core/auto_parallel/binary_set.cpp b/oneflow/core/auto_parallel/binary_set.cpp index 789c4ef8f01..9c03a79b260 100644 --- a/oneflow/core/auto_parallel/binary_set.cpp +++ b/oneflow/core/auto_parallel/binary_set.cpp @@ -77,7 +77,7 @@ bool BinarySet::IfIntersect(const BinarySet& bs) const { return false; } // Get the intersection with another subset and store it into i -void BinarySet::IntersectionTo(BinarySet& bs, BinarySet& i) { +void BinarySet::IntersectionTo(const BinarySet& bs, BinarySet& i) const { int32_t min_bs_size = std::min(BinarySetValues.size(), bs.BinarySetValues.size()); if (min_bs_size > i.BinarySetValues.size()) { i.BinarySetValues.resize(min_bs_size, 0); } for (int32_t k = 0; k < BinarySetValues.size(); k++) { @@ -85,7 +85,7 @@ void BinarySet::IntersectionTo(BinarySet& bs, BinarySet& i) { } } // Count number of elements in this subset -int32_t BinarySet::Total() { +int32_t BinarySet::Total() const { int32_t t = 0; for (int32_t k = 0; k < BinarySetValues.size(); k++) { BinarySetEntryType bsv = BinarySetValues[k]; diff --git a/oneflow/core/auto_parallel/binary_set.h b/oneflow/core/auto_parallel/binary_set.h index f670348d813..0f36bf63faa 100644 --- a/oneflow/core/auto_parallel/binary_set.h +++ b/oneflow/core/auto_parallel/binary_set.h @@ -57,9 +57,9 @@ class BinarySet { // If this binary set intersects another one bool IfIntersect(const BinarySet& bs) const; // Get the intersection with another subset and store it into i - void IntersectionTo(BinarySet& bs, BinarySet& i); + void IntersectionTo(const BinarySet& bs, BinarySet& i) const; // Count number of elements in this subset - int32_t Total(); + int32_t Total() const; // Output all the elements in the subset void OutPut(std::vector& out) const; // Output all the elements in the subset From fadfe2c9687c71636cc668ea38dd71345311f328 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Sat, 9 Apr 2022 00:31:49 +0800 Subject: [PATCH 08/12] Add a Clear() in binary set --- oneflow/core/auto_parallel/binary_set.cpp | 3 +++ oneflow/core/auto_parallel/binary_set.h | 2 ++ 2 files changed, 5 insertions(+) diff --git a/oneflow/core/auto_parallel/binary_set.cpp b/oneflow/core/auto_parallel/binary_set.cpp index 9c03a79b260..64cf582d954 100644 --- a/oneflow/core/auto_parallel/binary_set.cpp +++ b/oneflow/core/auto_parallel/binary_set.cpp @@ -43,6 +43,9 @@ void BinarySet::Initialize(int32_t size_of_set) { BinarySetValues.resize(k, 0); } +// Clear all the elements in the set +void BinarySet::Clear() { BinarySetValues.assign(BinarySetValues.size(), 0); } + // Check if i-th element in this subset int32_t BinarySet::CheckExistency(int32_t i) const { int32_t k = i / bit_of_BinarySetEntryType; diff --git a/oneflow/core/auto_parallel/binary_set.h b/oneflow/core/auto_parallel/binary_set.h index 0f36bf63faa..b8d2c82a8bd 100644 --- a/oneflow/core/auto_parallel/binary_set.h +++ b/oneflow/core/auto_parallel/binary_set.h @@ -46,6 +46,8 @@ class BinarySet { // Initialization void Initialize(int32_t size_of_set); + // Clear all the elements in the set + void Clear(); // Check if i-th element in this subset int32_t CheckExistency(int32_t i) const; // Add i-th element into this subset From 7f1a0dcd2001a638d66c92e20212a6a49ff2b075 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Sat, 9 Apr 2022 00:35:21 +0800 Subject: [PATCH 09/12] Filter out those proxy candidates containing two sbps from the same unique group --- oneflow/core/auto_parallel/sbp_collector.cpp | 70 ++++++++++++++++++-- oneflow/core/auto_parallel/sbp_collector.h | 1 + 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index d75a1c5bd23..5d071df1770 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -36,6 +36,59 @@ bool IfIntersectAll( return true; } + +// Find unique sbp sets +void FindUniqueSbpSets( + const HashMap, BinarySet>& consumer_bn2sbp_set, + const std::unordered_set& all_sbp_set, std::vector& accumulator, + BinarySet& unique_sbps) { + std::vector sbp_ids; + // count the number of sbp + for (const auto& sbp_set_group : consumer_bn2sbp_set) { + sbp_set_group.second.QuickOutPut(sbp_ids); + for (int32_t sbp_id : sbp_ids) { accumulator[sbp_id]++; } + } + // find unique sbp and clear the accumulator + for (const auto& sbp_id : all_sbp_set) { + if (accumulator[sbp_id] == 1) { unique_sbps.AddEntry(sbp_id); } + accumulator[sbp_id] = 0; + } +} + +// Find unique sbp groups +void FindUniqueSbpGroups( + const HashMap, BinarySet>& consumer_bn2sbp_set, + const std::unordered_set& all_sbp_set, std::vector& accumulator, + BinarySet& bs_buffer, std::vector& unique_sbp_groups) { + // find the unique sbp sets + BinarySet unique_sbps(accumulator.size()); + FindUniqueSbpSets(consumer_bn2sbp_set, all_sbp_set, accumulator, unique_sbps); + + // A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0} + // {S1, S2, S3} show up only once, a parallel candidate should not contain two of them + for (const auto& sbp_set_group : consumer_bn2sbp_set) { + unique_sbps.IntersectionTo(sbp_set_group.second, bs_buffer); + // Find those unique sbp groups with more than two sbp + // For example {B, S1, S2} is an impossible proxy candidate, + // since {S1, S2} is only contained by A but not contained by C and D. + // A could be either S1 or S2. The tensor do not need to be transferred to both S1 and S2. + if (bs_buffer.Total() >= 2) { unique_sbp_groups.push_back(bs_buffer); } + } + bs_buffer.Clear(); +} + +// If not contains two sbp from a same unique group +bool No2SbpFromSameUniqueGroup(BinarySet& bs, const std::vector& unique_sbp_groups) { + BinarySet intersection(bs.SizeOfSet); + for (const auto& unique_sbp_group : unique_sbp_groups) { + bs.IntersectionTo(unique_sbp_group, intersection); + // For example {B, S1, S2} is an impossible proxy candidate, + // since {S1, S2} is only contained by A but not contained by C and D. + // A could be either S1 or S2. The tensor do not need to be transferred to both S1 and S2. + if (intersection.Total() >= 2) { return false; } + } + return true; +} } // namespace // Default constructor for SbpCollector @@ -270,9 +323,16 @@ void SbpCollector::ProxySbpCandidate( // generate sbp proxy SbpNode* sbp_proxy = sbp_graph.GenerateNode(); + + // A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0} + // {S1, S2, S3} show up only once, a parallel candidate should not contain two of them + std::vector unique_sbp_groups; + FindUniqueSbpGroups(index2consumer_bn2sbp_set[index], index2sbp_set[index], accumulator, + bs_buffer, unique_sbp_groups); + // Depth first search to collect Sbp Parallel information for the whole sbp set DfsSbpSet(0, max_num_sbp_proxy, index2sbp_set[index], index2sbp_set[index].begin(), - index2consumer_bn2sbp_set[index], sbp_proxy->ParallelCandidates); + index2consumer_bn2sbp_set[index], unique_sbp_groups, sbp_proxy->ParallelCandidates); // Initialize computation cost sbp_proxy->Cost.resize(sbp_proxy->ParallelCandidates.size(), 0); @@ -310,9 +370,10 @@ void SbpCollector::DfsSbpSet( int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, const std::unordered_set::iterator start_it, HashMap, BinarySet>& consumer_bn2sbp_set, - std::vector& ParallelCandidates) { + const std::vector& unique_sbp_groups, std::vector& ParallelCandidates) { if (depth > 0) { - if (IfIntersectAll(consumer_bn2sbp_set, bs_buffer)) { + if (IfIntersectAll(consumer_bn2sbp_set, bs_buffer) + && No2SbpFromSameUniqueGroup(bs_buffer, unique_sbp_groups)) { // store the binary set into an unordered_set ParallelCandidates.push_back(bs_buffer); } @@ -329,7 +390,8 @@ void SbpCollector::DfsSbpSet( if (accumulator[SbpParallelNum] == 0) { bs_buffer.AddEntry(SbpParallelNum); ++accumulator[SbpParallelNum]; - DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, consumer_bn2sbp_set, ParallelCandidates); + DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, consumer_bn2sbp_set, unique_sbp_groups, + ParallelCandidates); bs_buffer.DeleteEntry(SbpParallelNum); --accumulator[SbpParallelNum]; } diff --git a/oneflow/core/auto_parallel/sbp_collector.h b/oneflow/core/auto_parallel/sbp_collector.h index bb8214eb485..35f383d7d7d 100644 --- a/oneflow/core/auto_parallel/sbp_collector.h +++ b/oneflow/core/auto_parallel/sbp_collector.h @@ -79,6 +79,7 @@ class SbpCollector { void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, const std::unordered_set::iterator sbp_set_it, HashMap, BinarySet>& consumer_bn2sbp_set, + const std::vector& unique_sbp_groups, std::vector& ParallelCandidates); }; // class SbpCollector From adf4e88f675194c86b5963c0e5e4ecd5e4bc50ca Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 11 Apr 2022 14:54:59 +0800 Subject: [PATCH 10/12] refine --- oneflow/core/auto_parallel/sbp_collector.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/auto_parallel/sbp_collector.h b/oneflow/core/auto_parallel/sbp_collector.h index 35f383d7d7d..56ba693361e 100644 --- a/oneflow/core/auto_parallel/sbp_collector.h +++ b/oneflow/core/auto_parallel/sbp_collector.h @@ -73,7 +73,7 @@ class SbpCollector { private: // Maximum number of possible sbp in the proxy - unsigned long max_num_sbp_proxy_ = 3; + const unsigned long max_num_sbp_proxy_ = 3; // Depth first search to collect Sbp Parallel information for the whole sbp set void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, From 150d01b76384c74f42af4c660c29a33f3edc9525 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 11 Apr 2022 15:52:36 +0800 Subject: [PATCH 11/12] Check spells --- oneflow/core/auto_parallel/sbp_collector.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index 5d071df1770..d30201201e3 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -239,7 +239,7 @@ void SbpCollector::ProxySbpCandidate( // mapping from a logical blob id to index HashMap lbi2index; - // mapping from the index to producer, consuemr and corresponding input blob name, possible sbp + // mapping from the index to producer, consumer and corresponding input blob name, possible sbp // sets std::vector index2producer; std::vector> index2sbp_set; @@ -253,7 +253,7 @@ void SbpCollector::ProxySbpCandidate( // If not support boxing, just skip it. if (IsClassRegistered(op_type_case)) { return; } for (const std::string& ibn : node->op().input_bns()) { - // Skip those blobs who enforc same SBP. + // Skip those blobs who enforce same SBP. if (IsSameSbp(node, ibn)) { // Enforcing same SBP. Can not collect sbp from this blob. continue; @@ -262,7 +262,7 @@ void SbpCollector::ProxySbpCandidate( const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn); const OpNode& producer = node->ProducerOpNode4Lbi(lbi); - // not building proxy for fixed opertors + // not building proxy for fixed operators if (op_name2sbp_node.find(producer.op().op_name()) == op_name2sbp_node.end()) { return; } // decide the index of a logical blob description const auto& iterator_lbi = lbi2index.find(lbi); @@ -357,7 +357,7 @@ void SbpCollector::ProxySbpCandidate( // unload logical blob from sbp edges edge_found->UnloadLbi(lbi); // Do not clip this edge. Save it for wait time. - // clip this edge if it no longer carrys any blob + // clip this edge if it no longer carries any blob // We don't clip edges now since we have transfer cost // if (edge_found->EmptyLbi() && edge_found->WaitTime <= 0.0 && edge_found->WaitTime > -0.5) // sbp_graph.ClipEdge(edge_found); From bbcb80a1ed742d6a2cc62486b9560da0d3c1c2dc Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 27 Apr 2022 15:58:01 +0800 Subject: [PATCH 12/12] Clip useless edges --- oneflow/core/auto_parallel/sbp_collector.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index d30201201e3..51f1b4d4277 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -358,9 +358,12 @@ void SbpCollector::ProxySbpCandidate( edge_found->UnloadLbi(lbi); // Do not clip this edge. Save it for wait time. // clip this edge if it no longer carries any blob - // We don't clip edges now since we have transfer cost - // if (edge_found->EmptyLbi() && edge_found->WaitTime <= 0.0 && edge_found->WaitTime > -0.5) - // sbp_graph.ClipEdge(edge_found); + // We don't clip edges before since we have transfer cost + // Now we clip edges, which makes the topology simplier + if (edge_found->EmptyLbi() && edge_found->WaitTime <= 0.0 && edge_found->WaitTime > -0.5 + && sbp_graph.transfer_cost <= 0.0) { + sbp_graph.ClipEdge(edge_found); + } } } }