From a3fb6ffa89e62ef1b03fa58d0e76ef877e188c98 Mon Sep 17 00:00:00 2001 From: zhangbaizhou Date: Wed, 20 Mar 2024 09:34:42 +0000 Subject: [PATCH] fix --- .../frontend/cluster_ops/clustering_engine.cc | 18 ++++++++++++++++++ .../frontend/cluster_ops/clustering_engine.h | 11 ++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/paddle/cinn/frontend/cluster_ops/clustering_engine.cc b/paddle/cinn/frontend/cluster_ops/clustering_engine.cc index 2e50b1b552a5a..4740a10d48298 100644 --- a/paddle/cinn/frontend/cluster_ops/clustering_engine.cc +++ b/paddle/cinn/frontend/cluster_ops/clustering_engine.cc @@ -95,6 +95,24 @@ common::BfsWalker ClusteringEngine::MakeAcyclicSameClusterBf return common::BfsWalker(VisitAcyclicClusterNext); } +ShardableAxes4ValueT ClusteringEngine::MakeInferedShardableAxes4Value( + const std::vector& stmt_ptrs) { + const OpSetPtr ops = [&] { + auto ops = std::make_shared(); + for (const auto* stmt_ptr : stmt_ptrs) { + VisitStmtOp(*stmt_ptr, [&](const auto* op) { ops->insert(op); }); + } + return ops; + }(); + auto value2shardable_axes = shardable_axes_inferer_.InferShardableAxes(ops); + return [map = std::move(value2shardable_axes)]( + pir::Value value) -> std::optional { + const auto& iter = map.find(value); + if (iter == map.end()) return std::nullopt; + return &iter->second; + }; +} + IsAcyclicConnectedT ClusteringEngine::MakePredicatorIsAcyclicConnected( const common::TopoWalker& walker, const std::vector& stmt_patterns, diff --git a/paddle/cinn/frontend/cluster_ops/clustering_engine.h b/paddle/cinn/frontend/cluster_ops/clustering_engine.h index 5bf88510aa81f..2710583b69475 100644 --- a/paddle/cinn/frontend/cluster_ops/clustering_engine.h +++ b/paddle/cinn/frontend/cluster_ops/clustering_engine.h @@ -52,14 +52,19 @@ class ClusteringEngine { } } - common::BfsWalker MakeAcyclicSameClusterBfsWalker( - const std::vector& stmt_patterns); - + using ShardableAxes4ValueT = + std::function(pir::Value)>; using IsAcyclicConnectedT = std::function; using ClusterRoot4StmtT = std::function; + ShardableAxes4ValueT MakeInferedShardableAxes4Value( + const std::vector& stmt_ptrs); + + common::BfsWalker MakeAcyclicSameClusterBfsWalker( + const std::vector& stmt_patterns); + IsAcyclicConnectedT MakePredicatorIsAcyclicConnected( const common::TopoWalker& walker, const std::vector& stmt_patterns,