Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#7 from Fridge003/cinn_tmp
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
feifei-111 committed Mar 20, 2024
2 parents 4a767ec + a3fb6ff commit 6c64295
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
18 changes: 18 additions & 0 deletions paddle/cinn/frontend/cluster_ops/clustering_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,24 @@ common::BfsWalker<const StmtPattern*> ClusteringEngine::MakeAcyclicSameClusterBf
return common::BfsWalker<const StmtPattern*>(VisitAcyclicClusterNext);
}

ShardableAxes4ValueT ClusteringEngine::MakeInferedShardableAxes4Value(
const std::vector<const StmtPattern*>& stmt_ptrs) {
const OpSetPtr ops = [&] {
auto ops = std::make_shared<OpSet>();
for (const auto* stmt_ptr : stmt_ptrs) {
VisitStmtOp(*stmt_ptr, [&](const auto* op) { ops->insert(op); });
}
return ops;
}();
auto value2shardable_axes = shardable_axes_inferer_.InferShardableAxes(ops);
return [map = std::move(value2shardable_axes)](
pir::Value value) -> std::optional<const ShardableAxes*> {
const auto& iter = map.find(value);
if (iter == map.end()) return std::nullopt;
return &iter->second;
};
}

IsAcyclicConnectedT ClusteringEngine::MakePredicatorIsAcyclicConnected(
const common::TopoWalker<const StmtPattern*>& walker,
const std::vector<StmtPattern>& stmt_patterns,
Expand Down
11 changes: 8 additions & 3 deletions paddle/cinn/frontend/cluster_ops/clustering_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,19 @@ class ClusteringEngine {
}
}

common::BfsWalker<const StmtPattern*> MakeAcyclicSameClusterBfsWalker(
const std::vector<StmtPattern>& stmt_patterns);

using ShardableAxes4ValueT =
std::function<std::optional<const ShardableAxes*>(pir::Value)>;
using IsAcyclicConnectedT =
std::function<bool(const StmtPattern* src, const StmtPattern* dst)>;
using ClusterRoot4StmtT =
std::function<const StmtPattern*(const StmtPattern*)>;

ShardableAxes4ValueT MakeInferedShardableAxes4Value(
const std::vector<const StmtPattern*>& stmt_ptrs);

common::BfsWalker<const StmtPattern*> MakeAcyclicSameClusterBfsWalker(
const std::vector<StmtPattern>& stmt_patterns);

IsAcyclicConnectedT MakePredicatorIsAcyclicConnected(
const common::TopoWalker<const StmtPattern*>& walker,
const std::vector<StmtPattern>& stmt_patterns,
Expand Down

0 comments on commit 6c64295

Please sign in to comment.