Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#10 from Fridge003/multi-down
Browse files Browse the repository at this point in the history
Reorder add_store_in_fusion_op pass and group_cluster pass
  • Loading branch information
feifei-111 committed Apr 24, 2024
2 parents 06fb22c + 770f496 commit fec3505
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ void ApplyDivideGroupOpToFusionOpPass(
CreatePassManager) {
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
if (FLAGS_group_schedule_tiling_first) {
pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInGroupOpPass());
pass_manager->AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass());
pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass());
} else {
pass_manager->AddPass(
cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ class AddYieldStoreInFusionOpPattern
}
};

class AddStoreInFusionOpPass : public pir::Pass {
class AddStoreInGroupOpPass : public pir::Pass {
public:
AddStoreInFusionOpPass()
: pir::Pass("add_store_in_fusion_op", /*opt_level=*/1) {}
AddStoreInGroupOpPass()
: pir::Pass("add_store_in_group_op", /*opt_level=*/1) {}

bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
Expand All @@ -76,14 +76,7 @@ class AddStoreInFusionOpPass : public pir::Pass {
for (uint32_t i = 0; i < op->num_regions(); ++i) {
for (auto& block : op->region(i)) {
for (auto& op : block) {
if (op.isa<cinn::dialect::FusionOp>()) {
auto fusion_op = op.dyn_cast<cinn::dialect::FusionOp>();
if (fusion_op.GetOperators().size() == 2 &&
fusion_op.GetOperators()
.front()
->isa<cinn::dialect::ReshapeOp>()) {
continue;
}
if (op.isa<cinn::dialect::GroupOp>()) {
auto [_, num_rewrites] =
pir::ApplyPatternsGreedily(&op, patterns_, cfg);
AddStatistics(num_rewrites);
Expand All @@ -101,8 +94,8 @@ class AddStoreInFusionOpPass : public pir::Pass {
pir::FrozenRewritePatternSet patterns_;
};

std::unique_ptr<pir::Pass> CreateAddStoreInFusionOpPass() {
return std::make_unique<AddStoreInFusionOpPass>();
std::unique_ptr<pir::Pass> CreateAddStoreInGroupOpPass() {
return std::make_unique<AddStoreInGroupOpPass>();
}

} // namespace ir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace cinn {
namespace dialect {
namespace ir {

std::unique_ptr<pir::Pass> CreateAddStoreInFusionOpPass();
std::unique_ptr<pir::Pass> CreateAddStoreInGroupOpPass();

} // namespace ir
} // namespace dialect
Expand Down
14 changes: 6 additions & 8 deletions paddle/cinn/operator_fusion/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ void PatternGraph<T>::SinkTrivialPattern() {
GraphTransformer<
NodePattern,
T,
And<And<NonSinkNodeMatcher, StmtPatternGraphMatcher<TrivialPattern<T>>>,
IsNotOutputNodeMatcher>,
And<NonSinkNodeMatcher, StmtPatternGraphMatcher<TrivialPattern<T>>>,
MergeTrivialPatternOperation>(this);
}

Expand Down Expand Up @@ -135,17 +134,16 @@ template <typename T>
void PatternGraph<T>::ReduceTreeGrown() {
GraphTransformer<NodePattern,
T,
And<CanFuseReduceTreeMatcher, IsNotOutputNodeMatcher>,
CanFuseReduceTreeMatcher,
MergeReduceTreeOperation>(this);
}

template <typename T>
void PatternGraph<T>::ReduceTree_Trivial_Fusion() {
GraphTransformer<
NodePattern,
T,
And<CanFuseReduceTreeAndTrivialMatcher, IsNotOutputNodeMatcher>,
MergeReduceTreeAndTrivialOperation>(this);
GraphTransformer<NodePattern,
T,
CanFuseReduceTreeAndTrivialMatcher,
MergeReduceTreeAndTrivialOperation>(this);
}

template <typename T>
Expand Down
10 changes: 0 additions & 10 deletions paddle/cinn/operator_fusion/pattern_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,16 +406,6 @@ struct IsOutputNodeMatcher {
}
};

struct IsNotOutputNodeMatcher {
// TODO(@wuzhanfei) after move yield_store before group cluster, remove this
// matcher
template <typename T>
bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
bool res = !IsOutputNodeMatcher()(graph, node);
return res;
}
};

template <int N>
struct DownstreamSmallerThan {
template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/operator_fusion/policy/relative_judge_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ static std::optional<ValueDimRelation> CreateOpRelativenessForSpecialOps(
return CreateOpRelativenessForDefault(op);
}
if (op->name() == "cinn_op.yield_store") {
return CreateOpRelativenessForDefault(op);
return CreateOpRelativenessForElementWise(op);
}
return {};
}
Expand Down
40 changes: 20 additions & 20 deletions paddle/cinn/operator_fusion/policy/shardable_axes_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,26 +71,6 @@ ShardableAxesSignature CreateDefaultSignature(pir::Operation* op) {
return result;
}

std::optional<ShardableAxesSignature> CreateSignatureForSpecialOps(
pir::Operation* op) {
if (op->isa<cinn::dialect::ReshapeOp>()) {
return CreateDefaultSignature(op);
}
if (op->name() == "cinn_op.generate_shape") {
return CreateDefaultSignature(op);
}
if (op->name() == "cinn_op.yield_store") {
return CreateDefaultSignature(op);
}
if (op->name() == "cinn_op.reshape") {
return CreateDefaultSignature(op);
}
if (op->name() == "pd_op.reshape") {
return CreateDefaultSignature(op);
}
return std::nullopt;
}

ShardableAxesSignature CreateSignatureForReduce(pir::Operation* reduce_op) {
CHECK_EQ(reduce_op->num_operands(), 1);
CHECK_EQ(reduce_op->num_results(), 1);
Expand Down Expand Up @@ -178,6 +158,26 @@ ShardableAxesSignature CreateSignatureForBroadcast(
return result;
}

std::optional<ShardableAxesSignature> CreateSignatureForSpecialOps(
pir::Operation* op) {
if (op->isa<cinn::dialect::ReshapeOp>()) {
return CreateDefaultSignature(op);
}
if (op->name() == "cinn_op.generate_shape") {
return CreateDefaultSignature(op);
}
if (op->name() == "cinn_op.yield_store") {
return CreateSignatureForElementWise(op);
}
if (op->name() == "cinn_op.reshape") {
return CreateDefaultSignature(op);
}
if (op->name() == "pd_op.reshape") {
return CreateDefaultSignature(op);
}
return std::nullopt;
}

ShardableAxesSignature ShardableAxesInfoManager::CreateShardableSignature(
pir::Operation* op) {
auto special_result = CreateSignatureForSpecialOps(op);
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/cinn/pir_all_path_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ static void RunAndCheckResult(::pir::Program* program,
pm.AddPass(pir::CreateDeadCodeEliminationPass());
pm.AddPass(pir::CreateBuildCinnPass());
pm.AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass());
pm.AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass());
pm.AddPass(cinn::dialect::ir::CreateAddStoreInGroupOpPass());
pm.AddPass(pir::CreateDeadCodeEliminationPass());
pm.AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass());
pm.EnableIRPrinting();
Expand Down

0 comments on commit fec3505

Please sign in to comment.