Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/kirin/dialects/scf/scf2cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
result.replace_by(exit_block.args.append_from(result.type, result.name))

curr_block = node.parent_block
assert isinstance(curr_block, ir.Block), "Node must be inside a block"
assert curr_block.IS_BLOCK, "Node must be inside a block"

Check failure on line 37 in src/kirin/dialects/scf/scf2cf.py

View workflow job for this annotation

GitHub Actions / pyright

"IS_BLOCK" is not a known attribute of "None" (reportOptionalMemberAccess)

curr_block.stmts.append(

Check failure on line 39 in src/kirin/dialects/scf/scf2cf.py

View workflow job for this annotation

GitHub Actions / pyright

"stmts" is not a known attribute of "None" (reportOptionalMemberAccess)
Branch(arguments=(), successor=(entr_block := ir.Block()))
)

Expand All @@ -47,11 +47,11 @@
curr_block = node.parent_block
region = node.parent_region

assert isinstance(region, ir.Region), "Node must be inside a region"
assert isinstance(curr_block, ir.Block), "Node must be inside a block"
assert region.IS_REGION, "Node must be inside a region"

Check failure on line 50 in src/kirin/dialects/scf/scf2cf.py

View workflow job for this annotation

GitHub Actions / pyright

"IS_REGION" is not a known attribute of "None" (reportOptionalMemberAccess)
assert curr_block.IS_BLOCK, "Node must be inside a block"

Check failure on line 51 in src/kirin/dialects/scf/scf2cf.py

View workflow job for this annotation

GitHub Actions / pyright

"IS_BLOCK" is not a known attribute of "None" (reportOptionalMemberAccess)

block_idx = region._block_idx[curr_block]

Check failure on line 53 in src/kirin/dialects/scf/scf2cf.py

View workflow job for this annotation

GitHub Actions / pyright

"_block_idx" is not a known attribute of "None" (reportOptionalMemberAccess)

Check failure on line 53 in src/kirin/dialects/scf/scf2cf.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "Block | None" cannot be assigned to parameter "key" of type "Block" in function "__getitem__"   Type "Block | None" is not assignable to type "Block"     "None" is not assignable to "Block" (reportArgumentType)
return region, block_idx

Check failure on line 54 in src/kirin/dialects/scf/scf2cf.py

View workflow job for this annotation

GitHub Actions / pyright

Type "tuple[Region | None, int]" is not assignable to return type "tuple[Region, int]"   Type "Region | None" is not assignable to type "Region"     "None" is not assignable to "Region" (reportReturnType)


class ForRule(ScfRule):
Expand Down
8 changes: 4 additions & 4 deletions src/kirin/dialects/scf/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@
then_body: ir.Region | ir.Block,
else_body: ir.Region | ir.Block | None = None,
):
if isinstance(then_body, ir.Region):
if then_body.IS_REGION:
then_body_region = then_body
if then_body_region.blocks:

Check failure on line 35 in src/kirin/dialects/scf/stmts.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot access attribute "blocks" for class "Block"   Attribute "blocks" is unknown (reportAttributeAccessIssue)
then_body_block = then_body_region.blocks[-1]

Check failure on line 36 in src/kirin/dialects/scf/stmts.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot access attribute "blocks" for class "Block"   Attribute "blocks" is unknown (reportAttributeAccessIssue)
else:
then_body_block = None
elif isinstance(then_body, ir.Block):
elif then_body.IS_BLOCK:
then_body_block = then_body
then_body_region = ir.Region(then_body)

Check failure on line 41 in src/kirin/dialects/scf/stmts.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "Region | Block" cannot be assigned to parameter "blocks" of type "Block | Iterable[Block]" in function "__init__"   Type "Region | Block" is not assignable to type "Block | Iterable[Block]"     Type "Region" is not assignable to type "Block | Iterable[Block]"       "Region" is not assignable to "Block"       "Region" is incompatible with protocol "Iterable[Block]"         "__iter__" is not present (reportArgumentType)

if isinstance(else_body, ir.Region):
if else_body.IS_REGION:
if not else_body.blocks: # empty region
else_body_region = else_body
else_body_block = None
Expand All @@ -50,7 +50,7 @@
else:
else_body_region = else_body
else_body_block = else_body_region.blocks[0]
elif isinstance(else_body, ir.Block):
elif else_body.IS_BLOCK:
else_body_region = ir.Region(else_body)
else_body_block = else_body
else:
Expand Down
4 changes: 1 addition & 3 deletions src/kirin/ir/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def attach(self, method: Method):
return
self.method = method

from kirin.ir.nodes.stmt import Statement

console = Console(force_terminal=True, force_jupyter=False, file=sys.stderr)
printer = Printer(console=console)
# NOTE: populate the printer with the method body
Expand All @@ -40,7 +38,7 @@ def attach(self, method: Method):
node_str = "\n".join(
map(lambda each_line: " " * 4 + each_line, node_str.splitlines())
)
if isinstance(self.node, Statement):
if self.node.IS_STATEMENT:
dialect = self.node.dialect.name if self.node.dialect else "<no dialect>"
self.args += (
"when verifying the following statement",
Expand Down
6 changes: 5 additions & 1 deletion src/kirin/ir/nodes/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, TypeVar, Iterator
from typing import TYPE_CHECKING, Generic, TypeVar, ClassVar, Iterator
from dataclasses import field, dataclass

from typing_extensions import Self
Expand Down Expand Up @@ -29,6 +29,10 @@ class IRNode(Generic[ParentType], ABC, Printable):

source: SourceInfo | None = field(default=None, init=False, repr=False)

IS_REGION: ClassVar[bool] = False
IS_BLOCK: ClassVar[bool] = False
IS_STATEMENT: ClassVar[bool] = False

def assert_parent(self, type_: type[IRNode], parent) -> None:
assert (
isinstance(parent, type_) or parent is None
Expand Down
4 changes: 3 additions & 1 deletion src/kirin/ir/nodes/block.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Iterator, cast
from typing import TYPE_CHECKING, ClassVar, Iterable, Iterator, cast
from dataclasses import field, dataclass
from collections.abc import Sequence

Expand Down Expand Up @@ -244,6 +244,8 @@ class Block(IRNode["Region"]):
[`.print()`][kirin.print.printable.Printable.print] method.
"""

IS_BLOCK: ClassVar[bool] = True

_args: tuple[BlockArgument, ...]

# NOTE: we need linked list since stmts are inserted frequently
Expand Down
4 changes: 3 additions & 1 deletion src/kirin/ir/nodes/region.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Iterator, cast
from typing import TYPE_CHECKING, ClassVar, Iterable, Iterator, cast
from dataclasses import field, dataclass

from typing_extensions import Self
Expand Down Expand Up @@ -94,6 +94,8 @@ class Region(IRNode["Statement"]):
[`.print()`][kirin.print.printable.Printable.print] method.
"""

IS_REGION: ClassVar[bool] = True

_blocks: list[Block] = field(default_factory=list, repr=False)
_block_idx: dict[Block, int] = field(default_factory=dict, repr=False)
_parent: Statement | None = field(default=None, repr=False)
Expand Down
2 changes: 2 additions & 0 deletions src/kirin/ir/nodes/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class Statement(IRNode["Block"]):
[`.print()`][kirin.print.printable.Printable.print] method.
"""

IS_STATEMENT: ClassVar[bool] = True

name: ClassVar[str]
dialect: ClassVar[Dialect | None] = field(default=None, init=False, repr=False)
traits: ClassVar[frozenset[Trait["Statement"]]] = frozenset()
Expand Down
2 changes: 2 additions & 0 deletions src/kirin/ir/ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
class SSAValue(ABC, Printable):
"""Base class for all SSA values in the IR."""

IS_SSA_VALUE: ClassVar[bool] = True

type: TypeAttribute = field(default_factory=AnyType, init=False, repr=True)
"""The type of this SSA value."""
hints: dict[str, Attribute] = field(default_factory=dict, init=False, repr=False)
Expand Down
4 changes: 2 additions & 2 deletions src/kirin/lowering/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def push(self, node: StmtType) -> StmtType: ...
def push(self, node: Block) -> Block: ...

def push(self, node: StmtType | Block) -> StmtType | Block:
if isinstance(node, Block):
if node.IS_BLOCK:
return self._push_block(node)
elif isinstance(node, Statement):
elif node.IS_STATEMENT:
return self._push_stmt(node)
else:
raise BuildError(f"Unsupported type {type(node)} in push()")
Expand Down
13 changes: 7 additions & 6 deletions src/kirin/rewrite/abc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC
from typing import cast
from dataclasses import field, dataclass

from kirin.ir import Pure, Block, IRNode, Region, MaybePure, Statement
Expand Down Expand Up @@ -29,12 +30,12 @@ class RewriteRule(ABC):
"""

def rewrite(self, node: IRNode) -> RewriteResult:
if isinstance(node, Region):
return self.rewrite_Region(node)
elif isinstance(node, Block):
return self.rewrite_Block(node)
elif isinstance(node, Statement):
return self.rewrite_Statement(node)
if node.IS_REGION:
return self.rewrite_Region(cast(Region, node))
elif node.IS_BLOCK:
return self.rewrite_Block(cast(Block, node))
elif node.IS_STATEMENT:
return self.rewrite_Statement(cast(Statement, node))
else:
return RewriteResult()

Expand Down
6 changes: 3 additions & 3 deletions src/kirin/rewrite/walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ def populate_worklist(self, node: IRNode) -> None:
if self.skip(node):
return

if isinstance(node, Statement):
if node.IS_STATEMENT:
self.populate_worklist_Statement(node)
elif isinstance(node, Region):
elif node.IS_REGION:
self.populate_worklist_Region(node)
elif isinstance(node, Block):
elif node.IS_BLOCK:
self.populate_worklist_Block(node)
else:
raise NotImplementedError(f"populate_worklist_{node.__class__.__name__}")
Expand Down