Skip to content

Commit

Permalink
[PIR] Add num_ops() func for block and program (#61046)
Browse files Browse the repository at this point in the history
* add GetOpNum func for block and program

* update

* open test

* update

* add IR_API & delete detail namespace

* fix

* fix

* Update visitors.h
  • Loading branch information
chen2016013 committed Jan 25, 2024
1 parent 4f53595 commit 5ad19e5
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 35 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op.*
paddle/fluid/pir/dialect/operator/ir/onednn_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
paddle/fluid/pir/dialect/operator/ir/onednn_op_info.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.*
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ void BindProgram(py::module *m) {
},
[](std::shared_ptr<Program> self, int64_t random_seed) {
SetProgramInt64Attr(self, "random_seed", random_seed);
});
})
.def("num_ops", [](Program &self) { return self.num_ops(); });
}

std::shared_ptr<Program> ParseProgram(const std::string &program_str) {
Expand Down Expand Up @@ -361,6 +362,7 @@ void BindBlock(py::module *m) {
}
return op_list;
})
.def("num_ops", [](Block &self) { return self.num_ops(); })
.def(
"__enter__",
[](Block &self) -> Block & {
Expand Down
8 changes: 7 additions & 1 deletion paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class IR_API Block {
template <WalkOrder Order = WalkOrder::PostOrder, typename FuncT>
void Walk(Block::Iterator begin, Block::Iterator end, FuncT &&callback) {
for (auto &op = begin; op != end; ++op) {
detail::Walk<Order>(&*op, callback);
pir::Walk<Order>(&*op, callback);
}
}

Expand All @@ -138,6 +138,12 @@ class IR_API Block {
return Walk<Order>(begin(), end(), std::forward<FuncT>(callback));
}

uint32_t num_ops() {
uint32_t num = 0;
Walk([&num](Operation *) { ++num; });
return num;
}

private:
Block(Block &) = delete;
Block &operator=(const Block &) = delete;
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class IR_API alignas(8) Operation final
///
template <WalkOrder Order = WalkOrder::PostOrder, typename FuncT>
void Walk(FuncT &&callback) {
return detail::Walk<Order>(this, std::forward<FuncT>(callback));
return pir::Walk<Order>(this, std::forward<FuncT>(callback));
}

///
Expand Down
1 change: 1 addition & 0 deletions paddle/pir/core/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class IR_API Program {
Block* block() { return &module_.block(); }
const Block* block() const { return &module_op().block(); }

uint32_t num_ops() { return block()->num_ops(); }
Parameter* GetParameter(const std::string& name) const;
void SetParameter(const std::string& name,
std::unique_ptr<Parameter>&& parameter);
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/visitors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "paddle/pir/core/visitors.h"
#include "paddle/pir/core/operation.h"

namespace pir::detail {
namespace pir {

// Defines utilities for walking and visiting operations.
void Walk(Operation *op,
Expand Down Expand Up @@ -66,4 +66,4 @@ void Walk(Operation *op,
if (order == WalkOrder::PostOrder) callback(op);
}

} // namespace pir::detail
} // namespace pir
20 changes: 9 additions & 11 deletions paddle/pir/core/visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,22 @@ class Block;
// Traversal order.
enum class WalkOrder { PreOrder, PostOrder };

namespace detail {
// Defines utilities for walking and visiting operations.
void Walk(Operation *op,
const std::function<void(Region *)> &callback,
WalkOrder order);
IR_API void Walk(Operation *op,
const std::function<void(Region *)> &callback,
WalkOrder order);

void Walk(Operation *op,
const std::function<void(Block *)> &callback,
WalkOrder order);
IR_API void Walk(Operation *op,
const std::function<void(Block *)> &callback,
WalkOrder order);

void Walk(Operation *op,
const std::function<void(Operation *)> &callback,
WalkOrder order);
IR_API void Walk(Operation *op,
const std::function<void(Operation *)> &callback,
WalkOrder order);

template <WalkOrder Order = WalkOrder::PostOrder, typename FuncTy>
void Walk(Operation *op, FuncTy &&callback) {
return Walk(op, callback, Order);
}
} // namespace detail

} // namespace pir
26 changes: 7 additions & 19 deletions test/dygraph_to_static/test_tensor_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _compute_op_num(self, program):
return op_num, shape_op_num, slice_op_num

def _compute_pir_op_num(self, program):
op_num = len(program.global_block().ops)
op_num = program.global_block().num_ops()
shape_op_num = 0
slice_op_num = 0

Expand Down Expand Up @@ -651,7 +651,7 @@ def _compute_op_num(self, program):
)

def _compute_pir_op_num(self, program):
op_num = len(program.global_block().ops)
op_num = program.global_block().num_ops()
shape_op_num = 0
slice_op_num = 0

Expand Down Expand Up @@ -732,15 +732,9 @@ def _set_expected_op_num(self):
self.expected_slice_op_num = 4

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 3
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0

@test_ast_only
@test_pir_only
def test_pir_op_num(self):
# Remove this after we support control flow
pass
self.pir_expected_op_num = 42
self.pir_expected_shape_op_num = 1
self.pir_expected_slice_op_num = 1


class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
Expand Down Expand Up @@ -774,15 +768,9 @@ def _set_expected_op_num(self):
self.expected_slice_op_num = 3

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 3
self.pir_expected_op_num = 28
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0

@test_ast_only
@test_pir_only
def test_pir_op_num(self):
# Remove this after we support control flow
pass
self.pir_expected_slice_op_num = 2


class TestChangeShapeAfterAssign(TestTensorShapeBasic):
Expand Down

0 comments on commit 5ad19e5

Please sign in to comment.