Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 118 additions & 22 deletions src/kirin/rewrite/inline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, cast
from dataclasses import dataclass

from kirin import ir
Expand Down Expand Up @@ -68,6 +68,121 @@ def inline_call_like(
args (tuple[ir.SSAValue, ...]): the arguments of the call (first one is the callee)
region (ir.Region): the region of the callee
"""
if not call_like.parent_block:
return

if not call_like.parent_region:
return

# NOTE: we cannot change region because it may be used elsewhere
inline_region: ir.Region = region.clone()

# Preserve source information by attributing inlined code to the call site
if call_like.source is not None:
for block in inline_region.blocks:
if block.source is None:
block.source = call_like.source
for stmt in block.stmts:
if stmt.source is None:
stmt.source = call_like.source

if self._can_use_simple_inline(inline_region):
return self._inline_simple(call_like, args, inline_region.blocks[0])

return self._inline_complex(call_like, args, inline_region)

def _can_use_simple_inline(self, inline_region: ir.Region) -> bool:
"""Check if we can use the fast path for simple single-block inlining.

Args:
inline_region: The cloned region to be inlined

Returns:
True if simple inline is possible (single block with simple return)
"""
if len(inline_region.blocks) != 1:
return False

block = inline_region.blocks[0]

# Last statement must be a simple return
if not isinstance(block.last_stmt, func.Return):
return False

return True

def _inline_simple(
self,
call_like: ir.Statement,
args: tuple[ir.SSAValue, ...],
func_block: ir.Block,
):
"""Fast path: inline single-block function by splicing statements.

For simple functions with no control flow, we just clone the function's
statements and insert them before the call site.
No new blocks are created, no statement parent updates are needed.

Complexity: O(k) where k = number of statements in function (typically small)

Args:
call_like: The call statement to replace
args: Arguments to the call (first is callee, rest are parameters)
func_block: The single block from the cloned function region
"""
ssa_map: dict[ir.SSAValue, ir.SSAValue] = {}
for func_arg, call_arg in zip(func_block.args, args):
ssa_map[func_arg] = call_arg
if func_arg.name and call_arg.name is None:
call_arg.name = func_arg.name

for stmt in func_block.stmts:
if isinstance(stmt, func.Return):
return_value = ssa_map.get(stmt.value, stmt.value)

if call_like.results:
for call_result in call_like.results:
call_result.replace_by(return_value)

# Don't insert the return statement itself
break

new_stmt = stmt.from_stmt(
stmt,
args=[ssa_map.get(arg, arg) for arg in stmt.args],
regions=[r.clone(ssa_map) for r in stmt.regions],
successors=stmt.successors, # successors are empty for simple stmts
)

new_stmt.insert_before(call_like)

# Update SSA mapping for newly created results
for old_result, new_result in zip(stmt.results, new_stmt.results):
ssa_map[old_result] = new_result
if old_result.name:
new_result.name = old_result.name

call_like.delete()

def _inline_complex(
self,
call_like: ir.Statement,
args: tuple[ir.SSAValue, ...],
inline_region: ir.Region,
):
"""Inline multi-block function with control flow.

This handles the general case where the function has multiple blocks

Complexity: O(n+k) where n = statements after call site (due to moving them)
and k = number of statements in function.

Args:
call_like: The call statement to replace
args: Arguments to the call
inline_region: The cloned function region to inline
"""

# <stmt>
# <stmt>
# <br (a, b, c)>
Expand All @@ -89,26 +204,8 @@ def inline_call_like(
# split the current block into two, and replace the return with
# the branch instruction
# 4. remove the call
if not call_like.parent_block:
return

if not call_like.parent_region:
return

# NOTE: we cannot change region because it may be used elsewhere
inline_region: ir.Region = region.clone()

# Preserve source information by attributing inlined code to the call site
if call_like.source is not None:
for block in inline_region.blocks:
if block.source is None:
block.source = call_like.source
for stmt in block.stmts:
if stmt.source is None:
stmt.source = call_like.source

parent_block: ir.Block = call_like.parent_block
parent_region: ir.Region = call_like.parent_region
parent_block: ir.Block = cast(ir.Block, call_like.parent_block)
parent_region: ir.Region = cast(ir.Region, call_like.parent_region)

# wrap what's after invoke into a block
after_block = ir.Block()
Expand Down Expand Up @@ -150,4 +247,3 @@ def inline_call_like(
successor=entry_block,
).insert_before(call_like)
call_like.delete()
return