From 107c5ce82b9518c5869e189b3d4bd9030c582080 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 12:19:46 -0700 Subject: [PATCH 1/9] Fix bug in Pass.fixpoint that lead to unreachable break condition Since result was always recomputed from the join of result and result_, as soon as result_.has_done_something was set to True in the first loop iteration, result.has_done_something would never become false anymore. --- src/kirin/passes/abc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/kirin/passes/abc.py b/src/kirin/passes/abc.py index a43a28787..9c76e56c0 100644 --- a/src/kirin/passes/abc.py +++ b/src/kirin/passes/abc.py @@ -32,10 +32,8 @@ def __call__(self, mt: Method) -> RewriteResult: return result def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult: - result = RewriteResult() for _ in range(max_iter): - result_ = self.unsafe_run(mt) - result = result_.join(result) + result = self.unsafe_run(mt) if not result.has_done_something: break mt.verify() From 8cf550bffa7ffe95836901cc2cd0b24966a4e3d9 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 12:21:21 -0700 Subject: [PATCH 2/9] Fux bug where has_done_something was overwritten --- src/kirin/dialects/ilist/rewrite/const.py | 8 ++++++-- src/kirin/dialects/ilist/rewrite/list.py | 6 ++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/kirin/dialects/ilist/rewrite/const.py b/src/kirin/dialects/ilist/rewrite/const.py index 12b95614e..d49177088 100644 --- a/src/kirin/dialects/ilist/rewrite/const.py +++ b/src/kirin/dialects/ilist/rewrite/const.py @@ -27,11 +27,15 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: typ = result.type data = hint.data if isinstance(typ, types.PyClass) and typ.is_subseteq(types.PyClass(IList)): - has_done_something = self._rewrite_IList_type(result, data) + has_done_something = has_done_something or self._rewrite_IList_type( + result, data + ) elif isinstance(typ, types.Generic) and typ.body.is_subseteq( types.PyClass(IList) ): - has_done_something = self._rewrite_IList_type(result, data) + has_done_something = has_done_something or self._rewrite_IList_type( + result, data + ) return RewriteResult(has_done_something=has_done_something) def rewrite_Constant(self, node: Constant) -> RewriteResult: diff --git a/src/kirin/dialects/ilist/rewrite/list.py b/src/kirin/dialects/ilist/rewrite/list.py index 1b9963c98..4de60beef 100644 --- a/src/kirin/dialects/ilist/rewrite/list.py +++ b/src/kirin/dialects/ilist/rewrite/list.py @@ -13,7 +13,7 @@ class List2IList(RewriteRule): def rewrite_Block(self, node: ir.Block) -> RewriteResult: has_done_something = False for arg in node.args: - has_done_something = self._rewrite_SSAValue_type(arg) + has_done_something = has_done_something or self._rewrite_SSAValue_type(arg) return RewriteResult(has_done_something=has_done_something) def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: @@ -25,7 +25,9 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: ) for result in node.results: - has_done_something = self._rewrite_SSAValue_type(result) + has_done_something = has_done_something or self._rewrite_SSAValue_type( + result + ) return RewriteResult(has_done_something=has_done_something) From bbaa056b86095808e6f2686fae2e0b0e41aead81 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 12:24:28 -0700 Subject: [PATCH 3/9] Fix bug where passes would wrongly return `has_done_something=True`. This is important for fixed-point iterations because in that case, the fixed-point loop would just iterate until it reaches max iteration. --- src/kirin/dialects/ilist/rewrite/const.py | 11 +++++++++-- src/kirin/dialects/ilist/rewrite/hint_len.py | 7 ++++++- src/kirin/rewrite/apply_type.py | 12 ++++++++---- src/kirin/rewrite/wrap_const.py | 2 +- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/kirin/dialects/ilist/rewrite/const.py b/src/kirin/dialects/ilist/rewrite/const.py index d49177088..dedf58181 100644 --- a/src/kirin/dialects/ilist/rewrite/const.py +++ b/src/kirin/dialects/ilist/rewrite/const.py @@ -57,6 +57,13 @@ def _rewrite_IList_type(self, result: ir.SSAValue, data): for elem in data.data[1:]: elem_type = elem_type.join(types.PyClass(type(elem))) - result.type = IListType[elem_type, types.Literal(len(data.data))] - result.hints["const"] = const.Value(data) + new_type = IListType[elem_type, types.Literal(len(data.data))] + new_hint = const.Value(data) + + # Check if type and hint are already correct + if result.type == new_type and result.hints.get("const") == new_hint: + return False + + result.type = new_type + result.hints["const"] = new_hint return True diff --git a/src/kirin/dialects/ilist/rewrite/hint_len.py b/src/kirin/dialects/ilist/rewrite/hint_len.py index d0d0d5d26..c34bbc14d 100644 --- a/src/kirin/dialects/ilist/rewrite/hint_len.py +++ b/src/kirin/dialects/ilist/rewrite/hint_len.py @@ -32,6 +32,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if (coll_len := self._get_collection_len(node.value)) is None: return RewriteResult() - node.result.hints["const"] = const.Value(coll_len) + existing_hint = node.result.hints.get("const") + new_hint = const.Value(coll_len) + if new_hint.is_structurally_equal(existing_hint): + return RewriteResult() + + node.result.hints["const"] = new_hint return RewriteResult(has_done_something=True) diff --git a/src/kirin/rewrite/apply_type.py b/src/kirin/rewrite/apply_type.py index 1be7c3ada..9004da021 100644 --- a/src/kirin/rewrite/apply_type.py +++ b/src/kirin/rewrite/apply_type.py @@ -13,8 +13,10 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult: has_done_something = False for arg in node.args: if arg in self.results: - arg.type = self.results[arg] - has_done_something = True + arg_type = self.results[arg] + if arg.type != arg_type: + arg.type = arg_type + has_done_something = True return RewriteResult(has_done_something=has_done_something) @@ -22,8 +24,10 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: has_done_something = False for result in node._results: if result in self.results: - result.type = self.results[result] - has_done_something = True + arg_type = self.results[result] + if result.type != arg_type: + result.type = arg_type + has_done_something = True if (trait := node.get_trait(ir.HasSignature)) is not None and ( callable_trait := node.get_trait(ir.CallableStmtInterface) diff --git a/src/kirin/rewrite/wrap_const.py b/src/kirin/rewrite/wrap_const.py index 06a40a335..2015ab6c3 100644 --- a/src/kirin/rewrite/wrap_const.py +++ b/src/kirin/rewrite/wrap_const.py @@ -49,7 +49,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if ( trait := node.get_trait(ir.MaybePure) - ) and node in self.frame.should_be_pure: + ) and node in self.frame.should_be_pure and not trait.is_pure(node): trait.set_pure(node) has_done_something = True return RewriteResult(has_done_something=has_done_something) From 8d716200975e41ca6362aabb1395978d75ebcdd2 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 12:25:14 -0700 Subject: [PATCH 4/9] Fix bug where CSE rewrite would prematurely exit --- src/kirin/rewrite/cse.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/kirin/rewrite/cse.py b/src/kirin/rewrite/cse.py index 1740bcf54..d154eaf44 100644 --- a/src/kirin/rewrite/cse.py +++ b/src/kirin/rewrite/cse.py @@ -61,6 +61,7 @@ class CommonSubexpressionElimination(RewriteRule): def rewrite_Block(self, node: ir.Block) -> RewriteResult: seen: dict[Info, ir.Statement] = {} + has_done_something = False for stmt in node.stmts: if not stmt.has_trait(ir.Pure): @@ -81,10 +82,10 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult: for result, old_result in zip(stmt._results, old_stmt.results): result.replace_by(old_result) stmt.delete() - return RewriteResult(has_done_something=True) + has_done_something = True else: seen[info] = stmt - return RewriteResult() + return RewriteResult(has_done_something=has_done_something) def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if not node.regions: From cd15068e48e57fe9439ef0727bbd7a0621e255ff Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 12:28:52 -0700 Subject: [PATCH 5/9] Fix issue that led to repeated addition of const SSA statements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this fix, the following code would produce the following IR: rule = rewrite.Fixpoint( rewrite.Walk( rewrite.Chain( ilist.rewrite.HintLen(), rewrite.ConstantFold(), rewrite.DeadCodeElimination(), ) ) ) @basic def len_func(xs: ilist.IList[int, Literal[3]]): return len(xs) @basic def len_func3(xs: ilist.IList[int, Any]): return len(xs) rule.rewrite(len_func.code) len_func.print() func.func len_func(!py.IList[!py.int, Literal(3,int)]) -> !Any { ^0(%len_func_self, %xs): │ %0 = py.constant.constant 3 : !py.int │ %1 = py.constant.constant 3 : !py.int │ %2 = py.constant.constant 3 : !py.int │ %3 = py.constant.constant 3 : !py.int │ %4 = py.constant.constant 3 : !py.int │ %5 = py.constant.constant 3 : !py.int │ %6 = py.constant.constant 3 : !py.int │ %7 = py.constant.constant 3 : !py.int │ %8 = py.constant.constant 3 : !py.int │ %9 = py.constant.constant 3 : !py.int │ %10 = py.constant.constant 3 : !py.int │ %11 = py.constant.constant 3 : !py.int │ %12 = py.constant.constant 3 : !py.int │ %13 = py.constant.constant 3 : !py.int │ %14 = py.constant.constant 3 : !py.int │ %15 = py.constant.constant 3 : !py.int │ %16 = py.constant.constant 3 : !py.int │ %17 = py.constant.constant 3 : !py.int │ %18 = py.constant.constant 3 : !py.int │ %19 = py.constant.constant 3 : !py.int │ %20 = py.constant.constant 3 : !py.int │ %21 = py.constant.constant 3 : !py.int │ %22 = py.constant.constant 3 : !py.int │ %23 = py.constant.constant 3 : !py.int │ %24 = py.constant.constant 3 : !py.int │ %25 = py.constant.constant 3 : !py.int │ %26 = py.constant.constant 3 : !py.int │ %27 = py.constant.constant 3 : !py.int │ %28 = py.constant.constant 3 : !py.int │ %29 = py.constant.constant 3 : !py.int │ %30 = py.constant.constant 3 : !py.int │ %31 = py.len.len(value=%xs : !py.IList[!py.int, Literal(3,int)]) : !py.int │ func.return %0 } // func.func len_func --- src/kirin/rewrite/fold.py | 2 ++ test/dialects/test_infer_len.py | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/kirin/rewrite/fold.py b/src/kirin/rewrite/fold.py index e2d9d65d8..b7a74eb7f 100644 --- a/src/kirin/rewrite/fold.py +++ b/src/kirin/rewrite/fold.py @@ -29,6 +29,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: has_done_something = False for old_result in node.results: if (value := self.get_const(old_result)) is not None: + if not old_result.uses: + continue stmt = Constant(value.data) stmt.insert_before(node) old_result.replace_by(stmt.result) diff --git a/test/dialects/test_infer_len.py b/test/dialects/test_infer_len.py index 0576a1ac0..f76553657 100644 --- a/test/dialects/test_infer_len.py +++ b/test/dialects/test_infer_len.py @@ -7,7 +7,13 @@ def test(): rule = rewrite.Fixpoint( - rewrite.Walk(rewrite.Chain(ilist.rewrite.HintLen(), rewrite.ConstantFold())) + rewrite.Walk( + rewrite.Chain( + ilist.rewrite.HintLen(), + rewrite.ConstantFold(), + rewrite.DeadCodeElimination(), + ) + ) ) @basic @@ -24,6 +30,8 @@ def len_func3(xs: ilist.IList[int, Any]): stmt = len_func.callable_region.blocks[0].stmts.at(0) assert isinstance(stmt, py.Constant) assert stmt.value.unwrap() == 3 + assert len(len_func.callable_region.blocks[0].stmts) == 2 stmt = len_func3.callable_region.blocks[0].stmts.at(0) assert isinstance(stmt, py.Len) + assert len(len_func3.callable_region.blocks[0].stmts) == 2 From 8231e91007fd19f6da8f4f17b8259d47adf13f68 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 12:30:28 -0700 Subject: [PATCH 6/9] Remove failing test, since TypeInfer no longer always does something If the type-in for a pass was to always do something, this would lead to infinite fixpoint loops (or loops would reach max iterations) --- .../dataflow/typeinfer/test_always_rewrite.py | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 test/analysis/dataflow/typeinfer/test_always_rewrite.py diff --git a/test/analysis/dataflow/typeinfer/test_always_rewrite.py b/test/analysis/dataflow/typeinfer/test_always_rewrite.py deleted file mode 100644 index 72c0822e1..000000000 --- a/test/analysis/dataflow/typeinfer/test_always_rewrite.py +++ /dev/null @@ -1,18 +0,0 @@ -from kirin.passes import TypeInfer -from kirin.prelude import basic_no_opt - - -def test_always_rewrites(): - @basic_no_opt - def unstable(x: int): # type: ignore - y = x + 1 - if y > 10: - z = y - else: - z = y + 1.2 - return z - - result = TypeInfer(dialects=unstable.dialects, no_raise=False).fixpoint(unstable) - assert ( - result.has_done_something - ) # this will always be true because TypeInfer always rewrites type From f88535da3062a4629d26924c6778e0e31fea91a3 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 12:30:49 -0700 Subject: [PATCH 7/9] Asser worklist is empty before rewrite --- src/kirin/rewrite/walk.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/kirin/rewrite/walk.py b/src/kirin/rewrite/walk.py index 4d9a23638..9b8b7b51c 100644 --- a/src/kirin/rewrite/walk.py +++ b/src/kirin/rewrite/walk.py @@ -31,6 +31,8 @@ def rewrite(self, node: IRNode) -> RewriteResult: # NOTE: because the rewrite pass may mutate the node # thus we need to save the list of nodes to be processed # first before we start processing them + assert self.worklist.is_empty() + self.populate_worklist(node) has_done_something = False subnode = self.worklist.pop() From f22dfd7565634502d21e80a0c1529eaf9b3eec27 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 14:54:30 -0700 Subject: [PATCH 8/9] Pyright --- src/kirin/dialects/ilist/rewrite/hint_len.py | 2 +- src/kirin/rewrite/wrap_const.py | 6 ++++-- test/testing/test_assert_structurally_same.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/kirin/dialects/ilist/rewrite/hint_len.py b/src/kirin/dialects/ilist/rewrite/hint_len.py index c34bbc14d..b20229ec7 100644 --- a/src/kirin/dialects/ilist/rewrite/hint_len.py +++ b/src/kirin/dialects/ilist/rewrite/hint_len.py @@ -35,7 +35,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: existing_hint = node.result.hints.get("const") new_hint = const.Value(coll_len) - if new_hint.is_structurally_equal(existing_hint): + if existing_hint is not None and new_hint.is_structurally_equal(existing_hint): return RewriteResult() node.result.hints["const"] = new_hint diff --git a/src/kirin/rewrite/wrap_const.py b/src/kirin/rewrite/wrap_const.py index 2015ab6c3..93d5a189d 100644 --- a/src/kirin/rewrite/wrap_const.py +++ b/src/kirin/rewrite/wrap_const.py @@ -48,8 +48,10 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: has_done_something = True if ( - trait := node.get_trait(ir.MaybePure) - ) and node in self.frame.should_be_pure and not trait.is_pure(node): + (trait := node.get_trait(ir.MaybePure)) + and node in self.frame.should_be_pure + and not trait.is_pure(node) + ): trait.set_pure(node) has_done_something = True return RewriteResult(has_done_something=has_done_something) diff --git a/test/testing/test_assert_structurally_same.py b/test/testing/test_assert_structurally_same.py index 1de3380e2..186f109f6 100644 --- a/test/testing/test_assert_structurally_same.py +++ b/test/testing/test_assert_structurally_same.py @@ -1,7 +1,7 @@ import pytest -from kirin.testing import assert_structurally_same from kirin.prelude import structural_no_opt +from kirin.testing import assert_structurally_same from kirin.dialects import py, func From 1c791a5b807f54b24ecbbee1765eee07a7d1818a Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Sat, 25 Oct 2025 14:58:04 -0700 Subject: [PATCH 9/9] Fix correct return of has_done_something that was broken in previous commit --- src/kirin/passes/abc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/kirin/passes/abc.py b/src/kirin/passes/abc.py index 9c76e56c0..40309de9e 100644 --- a/src/kirin/passes/abc.py +++ b/src/kirin/passes/abc.py @@ -32,9 +32,11 @@ def __call__(self, mt: Method) -> RewriteResult: return result def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult: + result = RewriteResult() for _ in range(max_iter): - result = self.unsafe_run(mt) - if not result.has_done_something: + result_ = self.unsafe_run(mt) + result = result_.join(result) + if not result_.has_done_something: break mt.verify() return result