Skip to content
19 changes: 15 additions & 4 deletions src/kirin/dialects/ilist/rewrite/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
typ = result.type
data = hint.data
if isinstance(typ, types.PyClass) and typ.is_subseteq(types.PyClass(IList)):
has_done_something = self._rewrite_IList_type(result, data)
has_done_something = has_done_something or self._rewrite_IList_type(
result, data
)
elif isinstance(typ, types.Generic) and typ.body.is_subseteq(
types.PyClass(IList)
):
has_done_something = self._rewrite_IList_type(result, data)
has_done_something = has_done_something or self._rewrite_IList_type(
result, data
)
return RewriteResult(has_done_something=has_done_something)

def rewrite_Constant(self, node: Constant) -> RewriteResult:
Expand All @@ -53,6 +57,13 @@ def _rewrite_IList_type(self, result: ir.SSAValue, data):
for elem in data.data[1:]:
elem_type = elem_type.join(types.PyClass(type(elem)))

result.type = IListType[elem_type, types.Literal(len(data.data))]
result.hints["const"] = const.Value(data)
new_type = IListType[elem_type, types.Literal(len(data.data))]
new_hint = const.Value(data)

# Check if type and hint are already correct
if result.type == new_type and result.hints.get("const") == new_hint:
return False

result.type = new_type
result.hints["const"] = new_hint
return True
7 changes: 6 additions & 1 deletion src/kirin/dialects/ilist/rewrite/hint_len.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if (coll_len := self._get_collection_len(node.value)) is None:
return RewriteResult()

node.result.hints["const"] = const.Value(coll_len)
existing_hint = node.result.hints.get("const")
new_hint = const.Value(coll_len)

if existing_hint is not None and new_hint.is_structurally_equal(existing_hint):
return RewriteResult()

node.result.hints["const"] = new_hint
return RewriteResult(has_done_something=True)
6 changes: 4 additions & 2 deletions src/kirin/dialects/ilist/rewrite/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class List2IList(RewriteRule):
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
has_done_something = False
for arg in node.args:
has_done_something = self._rewrite_SSAValue_type(arg)
has_done_something = has_done_something or self._rewrite_SSAValue_type(arg)
return RewriteResult(has_done_something=has_done_something)

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
Expand All @@ -25,7 +25,9 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
)

for result in node.results:
has_done_something = self._rewrite_SSAValue_type(result)
has_done_something = has_done_something or self._rewrite_SSAValue_type(
result
)

return RewriteResult(has_done_something=has_done_something)

Expand Down
2 changes: 1 addition & 1 deletion src/kirin/passes/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult:
for _ in range(max_iter):
result_ = self.unsafe_run(mt)
result = result_.join(result)
if not result.has_done_something:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _ is important. If the pass does something in the first iteration, result.has_done_something will always be true for all iterations.

if not result_.has_done_something:
break
mt.verify()
return result
Expand Down
12 changes: 8 additions & 4 deletions src/kirin/rewrite/apply_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult:
has_done_something = False
for arg in node.args:
if arg in self.results:
arg.type = self.results[arg]
has_done_something = True
arg_type = self.results[arg]
if arg.type != arg_type:
arg.type = arg_type
has_done_something = True

return RewriteResult(has_done_something=has_done_something)

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
has_done_something = False
for result in node._results:
if result in self.results:
result.type = self.results[result]
has_done_something = True
arg_type = self.results[result]
if result.type != arg_type:
result.type = arg_type
has_done_something = True

if (trait := node.get_trait(ir.HasSignature)) is not None and (
callable_trait := node.get_trait(ir.CallableStmtInterface)
Expand Down
5 changes: 3 additions & 2 deletions src/kirin/rewrite/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class CommonSubexpressionElimination(RewriteRule):

def rewrite_Block(self, node: ir.Block) -> RewriteResult:
seen: dict[Info, ir.Statement] = {}
has_done_something = False

for stmt in node.stmts:
if not stmt.has_trait(ir.Pure):
Expand All @@ -81,10 +82,10 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult:
for result, old_result in zip(stmt._results, old_stmt.results):
result.replace_by(old_result)
stmt.delete()
return RewriteResult(has_done_something=True)
has_done_something = True
else:
seen[info] = stmt
return RewriteResult()
return RewriteResult(has_done_something=has_done_something)

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not node.regions:
Expand Down
2 changes: 2 additions & 0 deletions src/kirin/rewrite/fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
has_done_something = False
for old_result in node.results:
if (value := self.get_const(old_result)) is not None:
if not old_result.uses:
continue
stmt = Constant(value.data)
stmt.insert_before(node)
old_result.replace_by(stmt.result)
Expand Down
2 changes: 2 additions & 0 deletions src/kirin/rewrite/walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def rewrite(self, node: IRNode) -> RewriteResult:
# NOTE: because the rewrite pass may mutate the node
# thus we need to save the list of nodes to be processed
# first before we start processing them
assert self.worklist.is_empty()

self.populate_worklist(node)
has_done_something = False
subnode = self.worklist.pop()
Expand Down
6 changes: 4 additions & 2 deletions src/kirin/rewrite/wrap_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
has_done_something = True

if (
trait := node.get_trait(ir.MaybePure)
) and node in self.frame.should_be_pure:
(trait := node.get_trait(ir.MaybePure))
and node in self.frame.should_be_pure
and not trait.is_pure(node)
):
trait.set_pure(node)
has_done_something = True
return RewriteResult(has_done_something=has_done_something)
18 changes: 0 additions & 18 deletions test/analysis/dataflow/typeinfer/test_always_rewrite.py

This file was deleted.

10 changes: 9 additions & 1 deletion test/dialects/test_infer_len.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

def test():
rule = rewrite.Fixpoint(
rewrite.Walk(rewrite.Chain(ilist.rewrite.HintLen(), rewrite.ConstantFold()))
rewrite.Walk(
rewrite.Chain(
ilist.rewrite.HintLen(),
rewrite.ConstantFold(),
rewrite.DeadCodeElimination(),
)
)
)

@basic
Expand All @@ -24,6 +30,8 @@ def len_func3(xs: ilist.IList[int, Any]):
stmt = len_func.callable_region.blocks[0].stmts.at(0)
assert isinstance(stmt, py.Constant)
assert stmt.value.unwrap() == 3
assert len(len_func.callable_region.blocks[0].stmts) == 2

stmt = len_func3.callable_region.blocks[0].stmts.at(0)
assert isinstance(stmt, py.Len)
assert len(len_func3.callable_region.blocks[0].stmts) == 2
2 changes: 1 addition & 1 deletion test/testing/test_assert_structurally_same.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from kirin.testing import assert_structurally_same
from kirin.prelude import structural_no_opt
from kirin.testing import assert_structurally_same
from kirin.dialects import py, func


Expand Down