diff --git a/src/kirin/dialects/scf/stmts.py b/src/kirin/dialects/scf/stmts.py index 43dd38cda..3c4ea008f 100644 --- a/src/kirin/dialects/scf/stmts.py +++ b/src/kirin/dialects/scf/stmts.py @@ -41,16 +41,16 @@ def __init__( then_body_block = None else: # then_body.IS_BLOCK: then_body_block = cast(Block, then_body) - then_body_region = Region(then_body) + then_body_region = cast(Region, then_body) if else_body is None: else_body_region = ir.Region() else_body_block = None elif else_body.IS_REGION: else_body_region = cast(Region, else_body) - if not else_body.blocks: # empty region + if not else_body_region.blocks: # empty region else_body_block = None - elif len(else_body.blocks) == 0: + elif len(else_body_region.blocks) == 0: else_body_block = None else: else_body_block = else_body_region.blocks[0] @@ -63,6 +63,7 @@ def __init__( results = () if then_body_block is not None: then_yield = then_body_block.last_stmt + else_body_block = cast(Block, else_body_block) else_yield = ( else_body_block.last_stmt if else_body_block is not None else None ) diff --git a/src/kirin/ir/nodes/block.py b/src/kirin/ir/nodes/block.py index 19ef1bd7c..cbd657f2e 100644 --- a/src/kirin/ir/nodes/block.py +++ b/src/kirin/ir/nodes/block.py @@ -391,6 +391,10 @@ def is_structurally_equal( # pyright: ignore[reportIncompatibleMethodOverride] if context is None: context = {} + if self in context: + return context[self] is other + context[self] = other + if len(self._args) != len(other._args) or len(self.stmts) != len(other.stmts): return False diff --git a/src/kirin/lowering/frame.py b/src/kirin/lowering/frame.py index 57c2587b5..6eb88ef2e 100644 --- a/src/kirin/lowering/frame.py +++ b/src/kirin/lowering/frame.py @@ -21,6 +21,7 @@ from .state import State CallbackFn = Callable[["Frame", SSAValue], SSAValue] +StmtType = TypeVar("StmtType", bound=Statement) @dataclass @@ -53,8 +54,6 @@ class Frame(Generic[Stmt]): def __repr__(self): return f"Frame({len(self.defs)} defs, {len(self.globals)} globals)" - StmtType = TypeVar("StmtType", bound=Statement) - @overload def push(self, node: StmtType) -> StmtType: ... @@ -65,7 +64,7 @@ def push(self, node: StmtType | Block) -> StmtType | Block: if node.IS_BLOCK: return self._push_block(cast(Block, node)) elif node.IS_STATEMENT: - return self._push_stmt(cast(Statement, node)) + return self._push_stmt(cast(StmtType, node)) else: raise BuildError(f"Unsupported type {type(node)} in push()") diff --git a/src/kirin/passes/fold.py b/src/kirin/passes/fold.py index ce4d44f4a..170bca6dd 100644 --- a/src/kirin/passes/fold.py +++ b/src/kirin/passes/fold.py @@ -10,6 +10,7 @@ CFGCompactify, InlineGetItem, DeadCodeElimination, + ClosureFieldLowering, ) from kirin.passes.abc import Pass from kirin.rewrite.abc import RewriteResult @@ -28,6 +29,7 @@ class Fold(Pass): - `InlineGetItem` - `Call2Invoke` - `DeadCodeElimination` + - `ClosureFieldLowering` """ hint_const: HintConst = field(init=False) @@ -46,6 +48,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult: InlineGetItem(), Call2Invoke(), DeadCodeElimination(), + ClosureFieldLowering(), ) ) ) diff --git a/src/kirin/rewrite/__init__.py b/src/kirin/rewrite/__init__.py index a82d05c9d..d9261d79e 100644 --- a/src/kirin/rewrite/__init__.py +++ b/src/kirin/rewrite/__init__.py @@ -13,3 +13,4 @@ from .wrap_const import WrapConst as WrapConst from .call2invoke import Call2Invoke as Call2Invoke from .type_assert import InlineTypeAssert as InlineTypeAssert +from .closurefieldlowering import ClosureFieldLowering as ClosureFieldLowering diff --git a/src/kirin/rewrite/closurefieldlowering.py b/src/kirin/rewrite/closurefieldlowering.py new file mode 100644 index 000000000..ca9198037 --- /dev/null +++ b/src/kirin/rewrite/closurefieldlowering.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass + +from kirin import ir +from kirin.dialects import py, func +from kirin.rewrite.abc import RewriteRule, RewriteResult + + +@dataclass +class ClosureFieldLowering(RewriteRule): + """Lowers captured closure fields into py.Constants. + - Trigger on func.Invoke + - If the callee Method has non-empty .fields, lower its func.GetField to py.Constant + """ + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if not isinstance(node, func.Invoke): + return RewriteResult(has_done_something=False) + + method = node.callee + if not method.fields: + return RewriteResult(has_done_something=False) + # Replace func.GetField with py.Constant. + changed = self._lower_captured_fields(method) + if changed: + method.fields = () + return RewriteResult(has_done_something=changed) + + def _get_field_index(self, getfield_stmt: func.GetField) -> int | None: + fld = getfield_stmt.attributes.get("field") + if fld: + return getfield_stmt.field + else: + return None + + def _lower_captured_fields(self, method: ir.Method) -> bool: + changed = False + fields = method.fields + if not fields: + return False + + for region in method.code.regions: + for block in region.blocks: + for stmt in list(block.stmts): + if not isinstance(stmt, func.GetField): + continue + idx = self._get_field_index(stmt) + if idx is None: + continue + captured = fields[idx] + # Skip Methods. + if isinstance(captured, ir.Method): + continue + # Replace GetField with Constant. + const_stmt = py.Constant(captured) + const_stmt.insert_before(stmt) + if stmt.results and const_stmt.results: + stmt.results[0].replace_by(const_stmt.results[0]) + stmt.delete() + changed = True + return changed diff --git a/test/serialization/test_jsonserializer.py b/test/serialization/test_jsonserializer.py index d9159b5a9..3ac196f6f 100644 --- a/test/serialization/test_jsonserializer.py +++ b/test/serialization/test_jsonserializer.py @@ -9,7 +9,7 @@ def foo(x: int, y: float, z: bool): c = [[(200.0, 200.0), (210.0, 200.0)]] if z: - c.append([(222.0, 333.0)]) + c = [(222.0, 333.0)] else: return [1, 2, 3, 4] return c @@ -47,6 +47,23 @@ def my_kernel2(y: int): return my_kernel1(y) * 10 +@basic +def foo2(y: int): + + def inner(x: int): + return x * y + 1 + + return inner + + +inner_ker = foo2(y=10) + + +@basic +def main_lambda(z: int): + return inner_ker(z) + + @basic def slicing(): in1 = ("a", "b", "c", "d", "e", "f", "g", "h") @@ -94,6 +111,10 @@ def test_round_trip5(): round_trip(slicing) +def test_round_trip6(): + round_trip(main_lambda) + + def test_deterministic(): serializer = Serializer() s1 = serializer.encode(loop_ilist)