diff --git a/src/kirin/rewrite/inline.py b/src/kirin/rewrite/inline.py index a15d2fd5a..4b81c4370 100644 --- a/src/kirin/rewrite/inline.py +++ b/src/kirin/rewrite/inline.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, cast from dataclasses import dataclass from kirin import ir @@ -68,6 +68,121 @@ def inline_call_like( args (tuple[ir.SSAValue, ...]): the arguments of the call (first one is the callee) region (ir.Region): the region of the callee """ + if not call_like.parent_block: + return + + if not call_like.parent_region: + return + + # NOTE: we cannot change region because it may be used elsewhere + inline_region: ir.Region = region.clone() + + # Preserve source information by attributing inlined code to the call site + if call_like.source is not None: + for block in inline_region.blocks: + if block.source is None: + block.source = call_like.source + for stmt in block.stmts: + if stmt.source is None: + stmt.source = call_like.source + + if self._can_use_simple_inline(inline_region): + return self._inline_simple(call_like, args, inline_region.blocks[0]) + + return self._inline_complex(call_like, args, inline_region) + + def _can_use_simple_inline(self, inline_region: ir.Region) -> bool: + """Check if we can use the fast path for simple single-block inlining. + + Args: + inline_region: The cloned region to be inlined + + Returns: + True if simple inline is possible (single block with simple return) + """ + if len(inline_region.blocks) != 1: + return False + + block = inline_region.blocks[0] + + # Last statement must be a simple return + if not isinstance(block.last_stmt, func.Return): + return False + + return True + + def _inline_simple( + self, + call_like: ir.Statement, + args: tuple[ir.SSAValue, ...], + func_block: ir.Block, + ): + """Fast path: inline single-block function by splicing statements. + + For simple functions with no control flow, we just clone the function's + statements and insert them before the call site. 
+ No new blocks are created, no statement parent updates are needed. + + Complexity: O(k) where k = number of statements in function (typically small) + + Args: + call_like: The call statement to replace + args: Arguments to the call (first is callee, rest are parameters) + func_block: The single block from the cloned function region + """ + ssa_map: dict[ir.SSAValue, ir.SSAValue] = {} + for func_arg, call_arg in zip(func_block.args, args): + ssa_map[func_arg] = call_arg + if func_arg.name and call_arg.name is None: + call_arg.name = func_arg.name + + for stmt in func_block.stmts: + if isinstance(stmt, func.Return): + return_value = ssa_map.get(stmt.value, stmt.value) + + if call_like.results: + for call_result in call_like.results: + call_result.replace_by(return_value) + + # Don't insert the return statement itself + break + + new_stmt = stmt.from_stmt( + stmt, + args=[ssa_map.get(arg, arg) for arg in stmt.args], + regions=[r.clone(ssa_map) for r in stmt.regions], + successors=stmt.successors, # successors are empty for simple stmts + ) + + new_stmt.insert_before(call_like) + + # Update SSA mapping for newly created results + for old_result, new_result in zip(stmt.results, new_stmt.results): + ssa_map[old_result] = new_result + if old_result.name: + new_result.name = old_result.name + + call_like.delete() + + def _inline_complex( + self, + call_like: ir.Statement, + args: tuple[ir.SSAValue, ...], + inline_region: ir.Region, + ): + """Inline multi-block function with control flow. + + This handles the general case: the function has multiple blocks, or its single block does not end in a `func.Return`. + + Complexity: O(n+k) where n = statements after call site (due to moving them) + and k = number of statements in function. + + Args: + call_like: The call statement to replace + args: Arguments to the call + inline_region: The cloned function region to inline + """ + # # #
@@ -89,26 +204,8 @@ def inline_call_like( # split the current block into two, and replace the return with # the branch instruction # 4. remove the call - if not call_like.parent_block: - return - - if not call_like.parent_region: - return - - # NOTE: we cannot change region because it may be used elsewhere - inline_region: ir.Region = region.clone() - - # Preserve source information by attributing inlined code to the call site - if call_like.source is not None: - for block in inline_region.blocks: - if block.source is None: - block.source = call_like.source - for stmt in block.stmts: - if stmt.source is None: - stmt.source = call_like.source - - parent_block: ir.Block = call_like.parent_block - parent_region: ir.Region = call_like.parent_region + parent_block: ir.Block = cast(ir.Block, call_like.parent_block) + parent_region: ir.Region = cast(ir.Region, call_like.parent_region) # wrap what's after invoke into a block after_block = ir.Block() @@ -150,4 +247,3 @@ def inline_call_like( successor=entry_block, ).insert_before(call_like) call_like.delete() - return