From 9a0ea1e19fe20005b049a535289564f1fb1591e5 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Tue, 28 Oct 2025 11:51:33 -0700 Subject: [PATCH 1/3] Add simple O(1) inlining --- src/kirin/rewrite/inline.py | 125 ++++++++++++++++++++++++++++++------ 1 file changed, 106 insertions(+), 19 deletions(-) diff --git a/src/kirin/rewrite/inline.py b/src/kirin/rewrite/inline.py index a15d2fd5a..fc6fa1c11 100644 --- a/src/kirin/rewrite/inline.py +++ b/src/kirin/rewrite/inline.py @@ -68,6 +68,112 @@ def inline_call_like( args (tuple[ir.SSAValue, ...]): the arguments of the call (first one is the callee) region (ir.Region): the region of the callee """ + if not call_like.parent_block: + return + + if not call_like.parent_region: + return + + # NOTE: we cannot change region because it may be used elsewhere + inline_region: ir.Region = region.clone() + + if self._can_use_simple_inline(inline_region): + return self._inline_simple(call_like, args, inline_region.blocks[0]) + + return self._inline_complex(call_like, args, inline_region) + + def _can_use_simple_inline(self, inline_region: ir.Region) -> bool: + """Check if we can use the fast path for simple single-block inlining. + + Args: + inline_region: The cloned region to be inlined + + Returns: + True if simple inline is possible (single block with simple return) + """ + if len(inline_region.blocks) != 1: + return False + + block = inline_region.blocks[0] + + # Last statement must be a simple return + if not block.last_stmt or not isinstance(block.last_stmt, func.Return): + return False + + return True + + def _inline_simple( + self, + call_like: ir.Statement, + args: tuple[ir.SSAValue, ...], + func_block: ir.Block, + ): + """Fast path: inline single-block function by splicing statements. + + For simple functions with no control flow, we just clone the function's + statements and insert them before the call site. + No new blocks are created, no statement parent updates are needed. + + Complexity: O(k) where k = number of statements in function (typically small) + + Args: + call_like: The call statement to replace + args: Arguments to the call (first is callee, rest are parameters) + func_block: The single block from the cloned function region + """ + ssa_map: dict[ir.SSAValue, ir.SSAValue] = {} + for func_arg, call_arg in zip(func_block.args, args): + ssa_map[func_arg] = call_arg + if func_arg.name and call_arg.name is None: + call_arg.name = func_arg.name + + for stmt in func_block.stmts: + if isinstance(stmt, func.Return): + return_value = ssa_map.get(stmt.value, stmt.value) + + if call_like.results: + for call_result in call_like.results: + call_result.replace_by(return_value) + + # Don't insert the return statement itself + continue + + new_stmt = stmt.from_stmt( + stmt, + args=[ssa_map.get(arg, arg) for arg in stmt.args], + regions=[r.clone(ssa_map) for r in stmt.regions], + successors=stmt.successors, # successors are empty for simple stmts + ) + + new_stmt.insert_before(call_like) + + # Update SSA mapping for newly created results + for old_result, new_result in zip(stmt.results, new_stmt.results): + ssa_map[old_result] = new_result + if old_result.name: + new_result.name = old_result.name + + call_like.delete() + + def _inline_complex( + self, + call_like: ir.Statement, + args: tuple[ir.SSAValue, ...], + inline_region: ir.Region, + ): + """Inline multi-block function with control flow. + + This handles the general case where the function has multiple blocks + + Complexity: O(n k) where n = statements after call site (due to moving them) + and k = number of statements in function. + + Args: + call_like: The call statement to replace + args: Arguments to the call + inline_region: The cloned function region to inline + """ + # # #
@@ -89,24 +195,6 @@ def inline_call_like( # split the current block into two, and replace the return with # the branch instruction # 4. remove the call - if not call_like.parent_block: - return - - if not call_like.parent_region: - return - - # NOTE: we cannot change region because it may be used elsewhere - inline_region: ir.Region = region.clone() - - # Preserve source information by attributing inlined code to the call site - if call_like.source is not None: - for block in inline_region.blocks: - if block.source is None: - block.source = call_like.source - for stmt in block.stmts: - if stmt.source is None: - stmt.source = call_like.source - parent_block: ir.Block = call_like.parent_block parent_region: ir.Region = call_like.parent_region @@ -150,4 +238,3 @@ def inline_call_like( successor=entry_block, ).insert_before(call_like) call_like.delete() - return From a436cc5610e2d5580235c10366378c7123ff3464 Mon Sep 17 00:00:00 2001 From: rafaelha <34611791+rafaelha@users.noreply.github.com> Date: Tue, 28 Oct 2025 16:37:11 -0400 Subject: [PATCH 2/3] Apply suggestion from @cduck Co-authored-by: Casey Duckering --- src/kirin/rewrite/inline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kirin/rewrite/inline.py b/src/kirin/rewrite/inline.py index fc6fa1c11..b90e69c9b 100644 --- a/src/kirin/rewrite/inline.py +++ b/src/kirin/rewrite/inline.py @@ -97,7 +97,7 @@ def _can_use_simple_inline(self, inline_region: ir.Region) -> bool: block = inline_region.blocks[0] # Last statement must be a simple return - if not block.last_stmt or not isinstance(block.last_stmt, func.Return): + if not isinstance(block.last_stmt, func.Return): return False return True From bf2c0258914be707a475e305bdedc8072ab116e0 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Tue, 28 Oct 2025 14:03:55 -0700 Subject: [PATCH 3/3] Address comments --- src/kirin/rewrite/inline.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/kirin/rewrite/inline.py b/src/kirin/rewrite/inline.py index b90e69c9b..4b81c4370 100644 --- a/src/kirin/rewrite/inline.py +++ b/src/kirin/rewrite/inline.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, cast from dataclasses import dataclass from kirin import ir @@ -77,6 +77,15 @@ def inline_call_like( # NOTE: we cannot change region because it may be used elsewhere inline_region: ir.Region = region.clone() + # Preserve source information by attributing inlined code to the call site + if call_like.source is not None: + for block in inline_region.blocks: + if block.source is None: + block.source = call_like.source + for stmt in block.stmts: + if stmt.source is None: + stmt.source = call_like.source + if self._can_use_simple_inline(inline_region): return self._inline_simple(call_like, args, inline_region.blocks[0]) @@ -136,7 +145,7 @@ def _inline_simple( call_result.replace_by(return_value) # Don't insert the return statement itself - continue + break new_stmt = stmt.from_stmt( stmt, @@ -165,7 +174,7 @@ def _inline_complex( This handles the general case where the function has multiple blocks - Complexity: O(n k) where n = statements after call site (due to moving them) + Complexity: O(n+k) where n = statements after call site (due to moving them) and k = number of statements in function. Args: @@ -195,8 +204,8 @@ def _inline_complex( # split the current block into two, and replace the return with # the branch instruction # 4. remove the call - parent_block: ir.Block = call_like.parent_block - parent_region: ir.Region = call_like.parent_region + parent_block: ir.Block = cast(ir.Block, call_like.parent_block) + parent_region: ir.Region = cast(ir.Region, call_like.parent_region) # wrap what's after invoke into a block after_block = ir.Block()