Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#3 from Fridge003/cinn_tmp
Browse files Browse the repository at this point in the history
pattern utils
  • Loading branch information
feifei-111 committed Mar 20, 2024
2 parents bcbf191 + 1f343bf commit 3f16743
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 96 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/frontend/cluster_ops/common_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#prgama once

#include "paddle/cinn/common/bfs_walker.h"
#include "paddle/cinn/common/topo_walker.h"
Expand All @@ -21,7 +22,6 @@
#include <typeinfo>
#include <variant>


#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/op.h"
Expand Down
109 changes: 18 additions & 91 deletions paddle/cinn/frontend/cluster_ops/pattern_utils.cc
Original file line number Diff line number Diff line change
@@ -1,96 +1,21 @@
bool IsISPattern(const StmtPattern& pattern) {
return std::holds_alternative<IS>(pattern);
}

bool IsPSPattern(const StmtPattern& pattern) {
return std::holds_alternative<PS>(pattern);
}

bool IsRPattern(const StmtPattern& pattern) {
return std::holds_alternative<R>(pattern);
}


template <typename DoEachT>
void VisitStmtOpImpl(const IS& injective_source, const DoEachT& DoEach) {
for (const auto* op : injective_source.ops) {
DoEach(op);
}
}

template <typename DoEachT>
void VisitStmtOpImpl(const PS& partial_shardable, const DoEachT& DoEach) {
for (const auto* op : partial_shardable.ops) {
DoEach(op);
}
}

template <typename DoEachT>
void VisitStmtOpImpl(const R& reduce, const DoEachT& DoEach) {
std::visit(adt::match{
[](const std::monostate&) {
// do nothing.
},
[&](const IS& injective_source) {
VisitStmtOpImpl(injective_source, DoEach);
},
[&](const PS& partial_shardable) {
VisitStmtOpImpl(partial_shardable, DoEach);
},
},
reduce.input);
DoEach(reduce.reduce_op_pattern.reduce_op);
}

template <typename DoEachT>
void VisitStmtOp(const StmtPattern& stmt, const DoEachT& DoEach) {
std::visit([&](const auto& impl) { VisitStmtOpImpl(impl, DoEach); }, stmt);
}

int GetOutputShardableAxesResultIdx(const pir::Operation* op) { return 0; }

pir::Value GetStmtBigestShapeValueImpl(const IS& injective_source) {
const auto* sink_op = injective_source.sole_sink;
const int result_idx = GetOutputShardableAxesResultIdx(sink_op);
return sink_op->result(result_idx);
}

pir::Value GetStmtBigestShapeValueImpl(const R& reduce_pattern) {
const auto* sink_op = reduce_pattern.reduce_op_pattern.reduce_op;
CHECK_EQ(sink_op->num_operands(), 1);
return sink_op->operand_source(0);
}

pir::Value GetStmtBigestShapeValueImpl(const PS& partial_shardable) {
const auto* sink_op = partial_shardable.sole_sink;
const int result_idx = GetOutputShardableAxesResultIdx(sink_op);
return sink_op->result(result_idx);
}

pir::Value GetStmtBigestShapeValue(const StmtPattern& stmt) {
return std::visit(
[&](const auto& impl) { return GetStmtBigestShapeValueImpl(impl); },
stmt);
}

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

const pir::Operation* GetStmtSoleSinkImpl(const IS& injective_source) {
return injective_source.sole_sink;
}

const pir::Operation* GetStmtSoleSinkImpl(const PS& partial_shardable) {
return partial_shardable.sole_sink;
}

const pir::Operation* GetStmtSoleSinkImpl(const R& reduce) {
return reduce.reduce_op_pattern.reduce_op;
}

const pir::Operation* GetStmtSoleSinkOp(const StmtPattern& stmt) {
return std::visit([](const auto& impl) { return GetStmtSoleSinkImpl(impl); },
stmt);
}
#include "paddle/cinn/frontend/cluster_ops/pattern_utils.h"

namespace cinn::frontend::cluster_ops {

void SortStmtPtrs(
std::vector<const StmtPattern*>* stmt_ptrs,
Expand Down Expand Up @@ -232,4 +157,6 @@ std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
CHECK(iter != map.end());
return iter->second;
};
}
}

} // namespace cinn::frontend::cluster_ops
128 changes: 124 additions & 4 deletions paddle/cinn/frontend/cluster_ops/pattern_utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,125 @@
common::TopoWalker<const StmtPattern*> MakeTopoWalker(
const OpTopo& op_topo, const std::vector<StmtPattern>& stmt_patterns);
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo);
#prgama once

#include "paddle/cinn/frontend/cluster_ops/common_utils.h"
#include "paddle/cinn/frontend/cluster_ops/group_pattern.h"
#include "paddle/cinn/frontend/cluster_ops/shardable_axes_provider.h"

namespace cinn::frontend::cluster_ops {

bool IsISPattern(const StmtPattern& pattern) {
return std::holds_alternative<IS>(pattern);
}

bool IsPSPattern(const StmtPattern& pattern) {
return std::holds_alternative<PS>(pattern);
}

bool IsRPattern(const StmtPattern& pattern) {
return std::holds_alternative<R>(pattern);
}

template <typename DoEachT>
void VisitStmtOpImpl(const IS& injective_source, const DoEachT& DoEach) {
for (const auto* op : injective_source.ops) {
DoEach(op);
}
}

template <typename DoEachT>
void VisitStmtOpImpl(const PS& partial_shardable, const DoEachT& DoEach) {
for (const auto* op : partial_shardable.ops) {
DoEach(op);
}
}

template <typename DoEachT>
void VisitStmtOpImpl(const R& reduce, const DoEachT& DoEach) {
std::visit(adt::match{
[](const std::monostate&) {
// do nothing.
},
[&](const IS& injective_source) {
VisitStmtOpImpl(injective_source, DoEach);
},
[&](const PS& partial_shardable) {
VisitStmtOpImpl(partial_shardable, DoEach);
},
},
reduce.input);
DoEach(reduce.reduce_op_pattern.reduce_op);
}

template <typename DoEachT>
void VisitStmtOp(const StmtPattern& stmt, const DoEachT& DoEach) {
std::visit([&](const auto& impl) { VisitStmtOpImpl(impl, DoEach); }, stmt);
}

int GetOutputShardableAxesResultIdx(const pir::Operation* op) { return 0; }

pir::Value GetStmtBigestShapeValueImpl(const IS& injective_source) {
const auto* sink_op = injective_source.sole_sink;
const int result_idx = GetOutputShardableAxesResultIdx(sink_op);
return sink_op->result(result_idx);
}

pir::Value GetStmtBigestShapeValueImpl(const R& reduce_pattern) {
const auto* sink_op = reduce_pattern.reduce_op_pattern.reduce_op;
CHECK_EQ(sink_op->num_operands(), 1);
return sink_op->operand_source(0);
}

pir::Value GetStmtBigestShapeValueImpl(const PS& partial_shardable) {
const auto* sink_op = partial_shardable.sole_sink;
const int result_idx = GetOutputShardableAxesResultIdx(sink_op);
return sink_op->result(result_idx);
}

pir::Value GetStmtBigestShapeValue(const StmtPattern& stmt) {
return std::visit(
[&](const auto& impl) { return GetStmtBigestShapeValueImpl(impl); },
stmt);
}


const pir::Operation* GetStmtSoleSinkImpl(const IS& injective_source) {
return injective_source.sole_sink;
}

const pir::Operation* GetStmtSoleSinkImpl(const PS& partial_shardable) {
return partial_shardable.sole_sink;
}

const pir::Operation* GetStmtSoleSinkImpl(const R& reduce) {
return reduce.reduce_op_pattern.reduce_op;
}

const pir::Operation* GetStmtSoleSinkOp(const StmtPattern& stmt) {
return std::visit([](const auto& impl) { return GetStmtSoleSinkImpl(impl); },
stmt);
}

void SortStmtPtrs(
std::vector<const StmtPattern*>* stmt_ptrs,
const std::function<size_t(const pir::Operation*)>& OrderValue4Op);

common::TopoWalker<const StmtPattern*> MakeTopoWalker(
const OpTopo& op_topo, const std::vector<StmtPattern>& stmt_patterns);

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo);

} // namespace cinn::frontend::cluster_ops

0 comments on commit 3f16743

Please sign in to comment.