From 5ad19e5e0ec2c860d81d3dd6d761e9525893e6f6 Mon Sep 17 00:00:00 2001 From: chen2016013 <111894720+chen2016013@users.noreply.github.com> Date: Thu, 25 Jan 2024 10:30:12 +0800 Subject: [PATCH] [PIR] Add num_ops() func for block and program (#61046) * add GetOpNum func for block and program * update * open test * update * add IR_API & delete detail namespace * fix * fix * Update visitors.h --- .gitignore | 2 ++ paddle/fluid/pybind/pir.cc | 4 +++- paddle/pir/core/block.h | 8 ++++++- paddle/pir/core/operation.h | 2 +- paddle/pir/core/program.h | 1 + paddle/pir/core/visitors.cc | 4 ++-- paddle/pir/core/visitors.h | 20 +++++++--------- test/dygraph_to_static/test_tensor_shape.py | 26 ++++++--------------- 8 files changed, 32 insertions(+), 35 deletions(-) diff --git a/.gitignore b/.gitignore index 3d3bc709a75e9..b9ddfda1fd925 100644 --- a/.gitignore +++ b/.gitignore @@ -109,7 +109,9 @@ paddle/fluid/pir/dialect/operator/ir/op_decomp.cc paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc paddle/fluid/pir/dialect/operator/ir/pd_op.* paddle/fluid/pir/dialect/operator/ir/onednn_op.* +paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.* paddle/fluid/pir/dialect/operator/ir/onednn_op_info.* +paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.* paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.* paddle/fluid/pir/dialect/operator/ir/pd_op_fused.* paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.* diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index b550d5a9b35e9..316811923dc5c 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -299,7 +299,8 @@ void BindProgram(py::module *m) { }, [](std::shared_ptr self, int64_t random_seed) { SetProgramInt64Attr(self, "random_seed", random_seed); - }); + }) + .def("num_ops", [](Program &self) { return self.num_ops(); }); } std::shared_ptr ParseProgram(const std::string &program_str) { @@ -361,6 +362,7 @@ void BindBlock(py::module *m) { } return op_list; }) + .def("num_ops", [](Block &self) { return self.num_ops(); }) .def( "__enter__", [](Block &self) -> Block & { diff --git a/paddle/pir/core/block.h b/paddle/pir/core/block.h index eef1f60b7aaf1..c9d0c1f4bfc1a 100644 --- a/paddle/pir/core/block.h +++ b/paddle/pir/core/block.h @@ -127,7 +127,7 @@ class IR_API Block { template void Walk(Block::Iterator begin, Block::Iterator end, FuncT &&callback) { for (auto &op = begin; op != end; ++op) { - detail::Walk(&*op, callback); + pir::Walk(&*op, callback); } } @@ -138,6 +138,12 @@ class IR_API Block { return Walk(begin(), end(), std::forward(callback)); } + uint32_t num_ops() { + uint32_t num = 0; + Walk([&num](Operation *) { ++num; }); + return num; + } + private: Block(Block &) = delete; Block &operator=(const Block &) = delete; diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index b148347d4d846..7ea3555b3a9a2 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -185,7 +185,7 @@ class IR_API alignas(8) Operation final /// template void Walk(FuncT &&callback) { - return detail::Walk(this, std::forward(callback)); + return pir::Walk(this, std::forward(callback)); } /// diff --git a/paddle/pir/core/program.h b/paddle/pir/core/program.h index dabd871e519a5..2b232d91d88b7 100644 --- a/paddle/pir/core/program.h +++ b/paddle/pir/core/program.h @@ -60,6 +60,7 @@ class IR_API Program { Block* block() { return &module_.block(); } const Block* block() const { return &module_op().block(); } + uint32_t num_ops() { return block()->num_ops(); } Parameter* GetParameter(const std::string& name) const; void SetParameter(const std::string& name, std::unique_ptr&& parameter); diff --git a/paddle/pir/core/visitors.cc b/paddle/pir/core/visitors.cc index a5d9706518c10..f3498facb6f1f 100644 --- a/paddle/pir/core/visitors.cc +++ b/paddle/pir/core/visitors.cc @@ -15,7 +15,7 @@ #include "paddle/pir/core/visitors.h" #include "paddle/pir/core/operation.h" -namespace pir::detail { +namespace pir { // Defines utilities for walking and visiting operations. void Walk(Operation *op, @@ -66,4 +66,4 @@ void Walk(Operation *op, if (order == WalkOrder::PostOrder) callback(op); } -} // namespace pir::detail +} // namespace pir diff --git a/paddle/pir/core/visitors.h b/paddle/pir/core/visitors.h index 7d9e9eacf4394..3d43138a7769e 100644 --- a/paddle/pir/core/visitors.h +++ b/paddle/pir/core/visitors.h @@ -25,24 +25,22 @@ class Block; // Traversal order. enum class WalkOrder { PreOrder, PostOrder }; -namespace detail { // Defines utilities for walking and visiting operations. -void Walk(Operation *op, - const std::function &callback, - WalkOrder order); +IR_API void Walk(Operation *op, + const std::function &callback, + WalkOrder order); -void Walk(Operation *op, - const std::function &callback, - WalkOrder order); +IR_API void Walk(Operation *op, + const std::function &callback, + WalkOrder order); -void Walk(Operation *op, - const std::function &callback, - WalkOrder order); +IR_API void Walk(Operation *op, + const std::function &callback, + WalkOrder order); template void Walk(Operation *op, FuncTy &&callback) { return Walk(op, callback, Order); } -} // namespace detail } // namespace pir diff --git a/test/dygraph_to_static/test_tensor_shape.py b/test/dygraph_to_static/test_tensor_shape.py index 88b2425f45d12..d8755b842525a 100644 --- a/test/dygraph_to_static/test_tensor_shape.py +++ b/test/dygraph_to_static/test_tensor_shape.py @@ -296,7 +296,7 @@ def _compute_op_num(self, program): return op_num, shape_op_num, slice_op_num def _compute_pir_op_num(self, program): - op_num = len(program.global_block().ops) + op_num = program.global_block().num_ops() shape_op_num = 0 slice_op_num = 0 @@ -651,7 +651,7 @@ def _compute_op_num(self, program): ) def _compute_pir_op_num(self, program): - op_num = len(program.global_block().ops) + op_num = program.global_block().num_ops() shape_op_num = 0 slice_op_num = 0 @@ -732,15 +732,9 @@ def _set_expected_op_num(self): self.expected_slice_op_num = 4 def _set_pir_expected_op_num(self): - self.pir_expected_op_num = 3 - self.pir_expected_shape_op_num = 0 - self.pir_expected_slice_op_num = 0 - - @test_ast_only - @test_pir_only - def test_pir_op_num(self): - # Remove this after we support control flow - pass + self.pir_expected_op_num = 42 + self.pir_expected_shape_op_num = 1 + self.pir_expected_slice_op_num = 1 class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape): @@ -774,15 +768,9 @@ def _set_expected_op_num(self): self.expected_slice_op_num = 3 def _set_pir_expected_op_num(self): - self.pir_expected_op_num = 3 + self.pir_expected_op_num = 28 self.pir_expected_shape_op_num = 0 - self.pir_expected_slice_op_num = 0 - - @test_ast_only - @test_pir_only - def test_pir_op_num(self): - # Remove this after we support control flow - pass + self.pir_expected_slice_op_num = 2 class TestChangeShapeAfterAssign(TestTensorShapeBasic):