Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/kirin/dialects/scf/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ def __init__(
then_body_block = None
else: # then_body.IS_BLOCK:
then_body_block = cast(Block, then_body)
then_body_region = Region(then_body)
then_body_region = cast(Region, then_body)

if else_body is None:
else_body_region = ir.Region()
else_body_block = None
elif else_body.IS_REGION:
else_body_region = cast(Region, else_body)
if not else_body.blocks: # empty region
if not else_body_region.blocks: # empty region
else_body_block = None
elif len(else_body.blocks) == 0:
elif len(else_body_region.blocks) == 0:
else_body_block = None
else:
else_body_block = else_body_region.blocks[0]
Expand All @@ -63,6 +63,7 @@ def __init__(
results = ()
if then_body_block is not None:
then_yield = then_body_block.last_stmt
else_body_block = cast(Block, else_body_block)
else_yield = (
else_body_block.last_stmt if else_body_block is not None else None
)
Expand Down
4 changes: 4 additions & 0 deletions src/kirin/ir/nodes/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,10 @@ def is_structurally_equal( # pyright: ignore[reportIncompatibleMethodOverride]
if context is None:
context = {}

if self in context:
return context[self] is other
context[self] = other

if len(self._args) != len(other._args) or len(self.stmts) != len(other.stmts):
return False

Expand Down
5 changes: 2 additions & 3 deletions src/kirin/lowering/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .state import State

CallbackFn = Callable[["Frame", SSAValue], SSAValue]
StmtType = TypeVar("StmtType", bound=Statement)


@dataclass
Expand Down Expand Up @@ -53,8 +54,6 @@ class Frame(Generic[Stmt]):
def __repr__(self):
return f"Frame({len(self.defs)} defs, {len(self.globals)} globals)"

StmtType = TypeVar("StmtType", bound=Statement)

@overload
def push(self, node: StmtType) -> StmtType: ...

Expand All @@ -65,7 +64,7 @@ def push(self, node: StmtType | Block) -> StmtType | Block:
if node.IS_BLOCK:
return self._push_block(cast(Block, node))
elif node.IS_STATEMENT:
return self._push_stmt(cast(Statement, node))
return self._push_stmt(cast(StmtType, node))
else:
raise BuildError(f"Unsupported type {type(node)} in push()")

Expand Down
3 changes: 3 additions & 0 deletions src/kirin/passes/fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CFGCompactify,
InlineGetItem,
DeadCodeElimination,
ClosureFieldLowering,
)
from kirin.passes.abc import Pass
from kirin.rewrite.abc import RewriteResult
Expand All @@ -28,6 +29,7 @@ class Fold(Pass):
- `InlineGetItem`
- `Call2Invoke`
- `DeadCodeElimination`
- `ClosureFieldLowering`
"""

hint_const: HintConst = field(init=False)
Expand All @@ -46,6 +48,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
InlineGetItem(),
Call2Invoke(),
DeadCodeElimination(),
ClosureFieldLowering(),
)
)
)
Expand Down
1 change: 1 addition & 0 deletions src/kirin/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
from .wrap_const import WrapConst as WrapConst
from .call2invoke import Call2Invoke as Call2Invoke
from .type_assert import InlineTypeAssert as InlineTypeAssert
from .closurefieldlowering import ClosureFieldLowering as ClosureFieldLowering
60 changes: 60 additions & 0 deletions src/kirin/rewrite/closurefieldlowering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from dataclasses import dataclass

from kirin import ir
from kirin.dialects import py, func
from kirin.rewrite.abc import RewriteRule, RewriteResult


@dataclass
class ClosureFieldLowering(RewriteRule):
"""Lowers captured closure fields into py.Constants.
- Trigger on func.Invoke
- If the callee Method has non-empty .fields, lower its func.GetField to py.Constant
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not isinstance(node, func.Invoke):
return RewriteResult(has_done_something=False)

method = node.callee
if not method.fields:
return RewriteResult(has_done_something=False)
# Replace func.GetField with py.Constant.
changed = self._lower_captured_fields(method)
if changed:
method.fields = ()
return RewriteResult(has_done_something=changed)

def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
fld = getfield_stmt.attributes.get("field")
if fld:
return getfield_stmt.field
else:
return None

def _lower_captured_fields(self, method: ir.Method) -> bool:
changed = False
fields = method.fields
if not fields:
return False

for region in method.code.regions:
for block in region.blocks:
for stmt in list(block.stmts):
if not isinstance(stmt, func.GetField):
continue
idx = self._get_field_index(stmt)
if idx is None:
continue
captured = fields[idx]
# Skip Methods.
if isinstance(captured, ir.Method):
continue
# Replace GetField with Constant.
const_stmt = py.Constant(captured)
const_stmt.insert_before(stmt)
if stmt.results and const_stmt.results:
stmt.results[0].replace_by(const_stmt.results[0])
stmt.delete()
changed = True
return changed
23 changes: 22 additions & 1 deletion test/serialization/test_jsonserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def foo(x: int, y: float, z: bool):
c = [[(200.0, 200.0), (210.0, 200.0)]]
if z:
c.append([(222.0, 333.0)])
c = [(222.0, 333.0)]
else:
return [1, 2, 3, 4]
return c
Expand Down Expand Up @@ -47,6 +47,23 @@ def my_kernel2(y: int):
return my_kernel1(y) * 10


@basic
def foo2(y: int):

def inner(x: int):
return x * y + 1

return inner


inner_ker = foo2(y=10)


@basic
def main_lambda(z: int):
return inner_ker(z)


@basic
def slicing():
in1 = ("a", "b", "c", "d", "e", "f", "g", "h")
Expand Down Expand Up @@ -94,6 +111,10 @@ def test_round_trip5():
round_trip(slicing)


def test_round_trip6():
round_trip(main_lambda)


def test_deterministic():
serializer = Serializer()
s1 = serializer.encode(loop_ilist)
Expand Down