Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/kirin/dialects/scf/scf2cf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import cast
from dataclasses import field, dataclass

from kirin import ir
Expand Down Expand Up @@ -34,7 +35,10 @@ def get_entr_and_exit_blks(self, node: For | IfElse):
result.replace_by(exit_block.args.append_from(result.type, result.name))

curr_block = node.parent_block
assert curr_block.IS_BLOCK, "Node must be inside a block"
assert (
curr_block is not None and curr_block.IS_BLOCK
), "Node must be inside a block"
curr_block = cast(ir.Block, curr_block)

curr_block.stmts.append(
Branch(arguments=(), successor=(entr_block := ir.Block()))
Expand All @@ -47,8 +51,12 @@ def get_curr_blk_info(self, node: For | IfElse) -> tuple[ir.Region, int]:
curr_block = node.parent_block
region = node.parent_region

assert region.IS_REGION, "Node must be inside a region"
assert curr_block.IS_BLOCK, "Node must be inside a block"
assert region is not None and region.IS_REGION, "Node must be inside a region"
region = cast(ir.Region, region)
assert (
curr_block is not None and curr_block.IS_BLOCK
), "Node must be inside a block"
curr_block = cast(ir.Block, curr_block)

block_idx = region._block_idx[curr_block]
return region, block_idx
Expand Down
27 changes: 14 additions & 13 deletions src/kirin/dialects/scf/stmts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import cast

from kirin import ir, types
from kirin.ir import Block, Region
from kirin.decl import info, statement
from kirin.print.printer import Printer

Expand Down Expand Up @@ -31,31 +34,29 @@
else_body: ir.Region | ir.Block | None = None,
):
if then_body.IS_REGION:
then_body_region = then_body
then_body_region = cast(Region, then_body)
if then_body_region.blocks:
then_body_block = then_body_region.blocks[-1]
else:
then_body_block = None
elif then_body.IS_BLOCK:
then_body_block = then_body
then_body_region = ir.Region(then_body)
else: # then_body.IS_BLOCK:
then_body_block = cast(Block, then_body)
then_body_region = Region(then_body)

Check failure on line 44 in src/kirin/dialects/scf/stmts.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "Region | Block" cannot be assigned to parameter "blocks" of type "Block | Iterable[Block]" in function "__init__"   Type "Region | Block" is not assignable to type "Block | Iterable[Block]"     Type "Region" is not assignable to type "Block | Iterable[Block]"       "Region" is not assignable to "Block"       "Region" is incompatible with protocol "Iterable[Block]"         "__iter__" is not present (reportArgumentType)

if else_body.IS_REGION:
if else_body is None:
else_body_region = ir.Region()
else_body_block = None
elif else_body.IS_REGION:
else_body_region = cast(Region, else_body)
if not else_body.blocks: # empty region

Check failure on line 51 in src/kirin/dialects/scf/stmts.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot access attribute "blocks" for class "Block"   Attribute "blocks" is unknown (reportAttributeAccessIssue)
else_body_region = else_body
else_body_block = None
elif len(else_body.blocks) == 0:

Check failure on line 53 in src/kirin/dialects/scf/stmts.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot access attribute "blocks" for class "Block"   Attribute "blocks" is unknown (reportAttributeAccessIssue)
else_body_region = else_body
else_body_block = None
else:
else_body_region = else_body
else_body_block = else_body_region.blocks[0]
elif else_body.IS_BLOCK:
else_body_region = ir.Region(else_body)
else: # else_body.IS_BLOCK:
else_body_region = ir.Region(cast(Block, else_body))
else_body_block = else_body
else:
else_body_region = ir.Region()
else_body_block = None

# if either then or else body has yield, we generate results
# we assume if both have yields, they have the same number of results
Expand All @@ -63,7 +64,7 @@
if then_body_block is not None:
then_yield = then_body_block.last_stmt
else_yield = (
else_body_block.last_stmt if else_body_block is not None else None

Check failure on line 67 in src/kirin/dialects/scf/stmts.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot access attribute "last_stmt" for class "Region"   Attribute "last_stmt" is unknown (reportAttributeAccessIssue)
)
if then_yield is not None and isinstance(then_yield, Yield):
results = then_yield.values
Expand Down
7 changes: 4 additions & 3 deletions src/kirin/ir/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import sys
import inspect
import textwrap
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from rich.console import Console

from kirin.exception import StaticCheckError
from kirin.print.printer import Printer

if TYPE_CHECKING:
from kirin.ir import IRNode, Method
from kirin.ir import IRNode, Method, Statement


class ValidationError(StaticCheckError):
Expand Down Expand Up @@ -39,7 +39,8 @@ def attach(self, method: Method):
map(lambda each_line: " " * 4 + each_line, node_str.splitlines())
)
if self.node.IS_STATEMENT:
dialect = self.node.dialect.name if self.node.dialect else "<no dialect>"
stmt = cast("Statement", self.node)
dialect = stmt.dialect.name if stmt.dialect else "<no dialect>"
self.args += (
"when verifying the following statement",
f" `{dialect}.{type(self.node).__name__}` at\n",
Expand Down
15 changes: 12 additions & 3 deletions src/kirin/lowering/frame.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Optional, overload
from typing import (
TYPE_CHECKING,
Any,
Generic,
TypeVar,
Callable,
Optional,
cast,
overload,
)
from dataclasses import field, dataclass

from kirin.ir import Block, Region, SSAValue, Statement
Expand Down Expand Up @@ -54,9 +63,9 @@

def push(self, node: StmtType | Block) -> StmtType | Block:
if node.IS_BLOCK:
return self._push_block(node)
return self._push_block(cast(Block, node))
elif node.IS_STATEMENT:
return self._push_stmt(node)
return self._push_stmt(cast(Statement, node))

Check failure on line 68 in src/kirin/lowering/frame.py

View workflow job for this annotation

GitHub Actions / pyright

Type "Statement" is not assignable to return type "StmtType@push | Block"   Type "Statement" is not assignable to type "StmtType@push | Block"     Type "Statement" is not assignable to type "StmtType@push"     "Statement" is not assignable to "Block" (reportReturnType)
else:
raise BuildError(f"Unsupported type {type(node)} in push()")

Expand Down
8 changes: 4 additions & 4 deletions src/kirin/rewrite/walk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, cast
from dataclasses import field, dataclass

from kirin.ir import Block, Region, Statement
Expand Down Expand Up @@ -49,11 +49,11 @@ def populate_worklist(self, node: IRNode) -> None:
return

if node.IS_STATEMENT:
self.populate_worklist_Statement(node)
self.populate_worklist_Statement(cast(Statement, node))
elif node.IS_REGION:
self.populate_worklist_Region(node)
self.populate_worklist_Region(cast(Region, node))
elif node.IS_BLOCK:
self.populate_worklist_Block(node)
self.populate_worklist_Block(cast(Block, node))
else:
raise NotImplementedError(f"populate_worklist_{node.__class__.__name__}")

Expand Down
Loading