From 4741a8dbe15a095e66b198fab0666f73c8c0ab9f Mon Sep 17 00:00:00 2001 From: Casey Duckering Date: Thu, 16 Oct 2025 15:58:05 -0400 Subject: [PATCH] [smaller change] Replace isinstance with attribute checks for performance-critical code --- src/kirin/dialects/scf/scf2cf.py | 6 +++--- src/kirin/dialects/scf/stmts.py | 8 ++++---- src/kirin/ir/exception.py | 4 +--- src/kirin/ir/nodes/base.py | 6 +++++- src/kirin/ir/nodes/block.py | 4 +++- src/kirin/ir/nodes/region.py | 4 +++- src/kirin/ir/nodes/stmt.py | 2 ++ src/kirin/ir/ssa.py | 2 ++ src/kirin/lowering/frame.py | 4 ++-- src/kirin/rewrite/abc.py | 13 +++++++------ src/kirin/rewrite/walk.py | 6 +++--- 11 files changed, 35 insertions(+), 24 deletions(-) diff --git a/src/kirin/dialects/scf/scf2cf.py b/src/kirin/dialects/scf/scf2cf.py index 0d72c56bc..fd15add53 100644 --- a/src/kirin/dialects/scf/scf2cf.py +++ b/src/kirin/dialects/scf/scf2cf.py @@ -34,7 +34,7 @@ def get_entr_and_exit_blks(self, node: For | IfElse): result.replace_by(exit_block.args.append_from(result.type, result.name)) curr_block = node.parent_block - assert isinstance(curr_block, ir.Block), "Node must be inside a block" + assert curr_block.IS_BLOCK, "Node must be inside a block" curr_block.stmts.append( Branch(arguments=(), successor=(entr_block := ir.Block())) @@ -47,8 +47,8 @@ def get_curr_blk_info(self, node: For | IfElse) -> tuple[ir.Region, int]: curr_block = node.parent_block region = node.parent_region - assert isinstance(region, ir.Region), "Node must be inside a region" - assert isinstance(curr_block, ir.Block), "Node must be inside a block" + assert region.IS_REGION, "Node must be inside a region" + assert curr_block.IS_BLOCK, "Node must be inside a block" block_idx = region._block_idx[curr_block] return region, block_idx diff --git a/src/kirin/dialects/scf/stmts.py b/src/kirin/dialects/scf/stmts.py index d5f506a36..d909f125a 100644 --- a/src/kirin/dialects/scf/stmts.py +++ b/src/kirin/dialects/scf/stmts.py @@ -30,17 +30,17 @@ def __init__( then_body: ir.Region | ir.Block, else_body: ir.Region | ir.Block | None = None, ): - if isinstance(then_body, ir.Region): + if then_body.IS_REGION: then_body_region = then_body if then_body_region.blocks: then_body_block = then_body_region.blocks[-1] else: then_body_block = None - elif isinstance(then_body, ir.Block): + elif then_body.IS_BLOCK: then_body_block = then_body then_body_region = ir.Region(then_body) - if isinstance(else_body, ir.Region): + if else_body.IS_REGION: if not else_body.blocks: # empty region else_body_region = else_body else_body_block = None @@ -50,7 +50,7 @@ def __init__( else: else_body_region = else_body else_body_block = else_body_region.blocks[0] - elif isinstance(else_body, ir.Block): + elif else_body.IS_BLOCK: else_body_region = ir.Region(else_body) else_body_block = else_body else: diff --git a/src/kirin/ir/exception.py b/src/kirin/ir/exception.py index 326f8e6ef..550836135 100644 --- a/src/kirin/ir/exception.py +++ b/src/kirin/ir/exception.py @@ -26,8 +26,6 @@ def attach(self, method: Method): return self.method = method - from kirin.ir.nodes.stmt import Statement - console = Console(force_terminal=True, force_jupyter=False, file=sys.stderr) printer = Printer(console=console) # NOTE: populate the printer with the method body @@ -40,7 +38,7 @@ def attach(self, method: Method): node_str = "\n".join( map(lambda each_line: " " * 4 + each_line, node_str.splitlines()) ) - if isinstance(self.node, Statement): + if self.node.IS_STATEMENT: dialect = self.node.dialect.name if self.node.dialect else "" self.args += ( "when verifying the following statement", diff --git a/src/kirin/ir/nodes/base.py b/src/kirin/ir/nodes/base.py index 333d94422..baf7bd2f1 100644 --- a/src/kirin/ir/nodes/base.py +++ b/src/kirin/ir/nodes/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar, Iterator +from typing import TYPE_CHECKING, Generic, TypeVar, ClassVar, Iterator from dataclasses import field, dataclass from typing_extensions import Self @@ -29,6 +29,10 @@ class IRNode(Generic[ParentType], ABC, Printable): source: SourceInfo | None = field(default=None, init=False, repr=False) + IS_REGION: ClassVar[bool] = False + IS_BLOCK: ClassVar[bool] = False + IS_STATEMENT: ClassVar[bool] = False + def assert_parent(self, type_: type[IRNode], parent) -> None: assert ( isinstance(parent, type_) or parent is None diff --git a/src/kirin/ir/nodes/block.py b/src/kirin/ir/nodes/block.py index 616ec032b..19ef1bd7c 100644 --- a/src/kirin/ir/nodes/block.py +++ b/src/kirin/ir/nodes/block.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Iterator, cast +from typing import TYPE_CHECKING, ClassVar, Iterable, Iterator, cast from dataclasses import field, dataclass from collections.abc import Sequence @@ -244,6 +244,8 @@ class Block(IRNode["Region"]): [`.print()`][kirin.print.printable.Printable.print] method. """ + IS_BLOCK: ClassVar[bool] = True + _args: tuple[BlockArgument, ...] # NOTE: we need linked list since stmts are inserted frequently diff --git a/src/kirin/ir/nodes/region.py b/src/kirin/ir/nodes/region.py index 730477219..3a8b49152 100644 --- a/src/kirin/ir/nodes/region.py +++ b/src/kirin/ir/nodes/region.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Iterator, cast +from typing import TYPE_CHECKING, ClassVar, Iterable, Iterator, cast from dataclasses import field, dataclass from typing_extensions import Self @@ -94,6 +94,8 @@ class Region(IRNode["Statement"]): [`.print()`][kirin.print.printable.Printable.print] method. """ + IS_REGION: ClassVar[bool] = True + _blocks: list[Block] = field(default_factory=list, repr=False) _block_idx: dict[Block, int] = field(default_factory=dict, repr=False) _parent: Statement | None = field(default=None, repr=False) diff --git a/src/kirin/ir/nodes/stmt.py b/src/kirin/ir/nodes/stmt.py index edf8d119c..0581b6649 100644 --- a/src/kirin/ir/nodes/stmt.py +++ b/src/kirin/ir/nodes/stmt.py @@ -131,6 +131,8 @@ class Statement(IRNode["Block"]): [`.print()`][kirin.print.printable.Printable.print] method. """ + IS_STATEMENT: ClassVar[bool] = True + name: ClassVar[str] dialect: ClassVar[Dialect | None] = field(default=None, init=False, repr=False) traits: ClassVar[frozenset[Trait["Statement"]]] = frozenset() diff --git a/src/kirin/ir/ssa.py b/src/kirin/ir/ssa.py index f15c28597..5134dcc6e 100644 --- a/src/kirin/ir/ssa.py +++ b/src/kirin/ir/ssa.py @@ -24,6 +24,8 @@ class SSAValue(ABC, Printable): """Base class for all SSA values in the IR.""" + IS_SSA_VALUE: ClassVar[bool] = True + type: TypeAttribute = field(default_factory=AnyType, init=False, repr=True) """The type of this SSA value.""" hints: dict[str, Attribute] = field(default_factory=dict, init=False, repr=False) diff --git a/src/kirin/lowering/frame.py b/src/kirin/lowering/frame.py index 3e4c4b0b8..66e560ccb 100644 --- a/src/kirin/lowering/frame.py +++ b/src/kirin/lowering/frame.py @@ -53,9 +53,9 @@ def push(self, node: StmtType) -> StmtType: ... def push(self, node: Block) -> Block: ... def push(self, node: StmtType | Block) -> StmtType | Block: - if isinstance(node, Block): + if node.IS_BLOCK: return self._push_block(node) - elif isinstance(node, Statement): + elif node.IS_STATEMENT: return self._push_stmt(node) else: raise BuildError(f"Unsupported type {type(node)} in push()") diff --git a/src/kirin/rewrite/abc.py b/src/kirin/rewrite/abc.py index d76d995fd..46517ea10 100644 --- a/src/kirin/rewrite/abc.py +++ b/src/kirin/rewrite/abc.py @@ -1,4 +1,5 @@ from abc import ABC +from typing import cast from dataclasses import field, dataclass from kirin.ir import Pure, Block, IRNode, Region, MaybePure, Statement @@ -29,12 +30,12 @@ class RewriteRule(ABC): """ def rewrite(self, node: IRNode) -> RewriteResult: - if isinstance(node, Region): - return self.rewrite_Region(node) - elif isinstance(node, Block): - return self.rewrite_Block(node) - elif isinstance(node, Statement): - return self.rewrite_Statement(node) + if node.IS_REGION: + return self.rewrite_Region(cast(Region, node)) + elif node.IS_BLOCK: + return self.rewrite_Block(cast(Block, node)) + elif node.IS_STATEMENT: + return self.rewrite_Statement(cast(Statement, node)) else: return RewriteResult() diff --git a/src/kirin/rewrite/walk.py b/src/kirin/rewrite/walk.py index 20479b423..2115f456c 100644 --- a/src/kirin/rewrite/walk.py +++ b/src/kirin/rewrite/walk.py @@ -48,11 +48,11 @@ def populate_worklist(self, node: IRNode) -> None: if self.skip(node): return - if isinstance(node, Statement): + if node.IS_STATEMENT: self.populate_worklist_Statement(node) - elif isinstance(node, Region): + elif node.IS_REGION: self.populate_worklist_Region(node) - elif isinstance(node, Block): + elif node.IS_BLOCK: self.populate_worklist_Block(node) else: raise NotImplementedError(f"populate_worklist_{node.__class__.__name__}")