Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/kirin/ir/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def wrapper(py_func: Callable) -> Method:
f"`{py_func.__name__}` is already defined in the current scope and is not a Method."
)

lineno_offset = call_site_frame.f_lineno - 1
lineno_offset = py_func.__code__.co_firstlineno - 1
file = call_site_frame.f_code.co_filename

code = self.lowering.python_function(py_func, lineno_offset=lineno_offset)
Expand Down
7 changes: 7 additions & 0 deletions src/kirin/ir/nodes/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ def insert_after(self, stmt: Statement) -> None:
self.parent = stmt.parent
stmt._next_stmt = self

if self.source is None and stmt.source is not None:
self.source = stmt.source

if self.parent:
self.parent._stmt_len += 1

Expand Down Expand Up @@ -302,6 +305,9 @@ def insert_before(self, stmt: Statement) -> None:
self.parent = stmt.parent
stmt._prev_stmt = self

if self.source is None and stmt.source is not None:
self.source = stmt.source

if self.parent:
self.parent._stmt_len += 1

Expand Down Expand Up @@ -506,6 +512,7 @@ def from_stmt(
attributes=attributes or other.attributes,
result_types=[result.type for result in other._results],
args_slice=other._name_args_slice,
source=other.source,
)
# inherit the hint:
for result, other_result in zip(obj._results, other._results):
Expand Down
2 changes: 1 addition & 1 deletion src/kirin/lowering/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def _push_stmt(self, stmt: StmtType) -> StmtType:
raise BuildError(
f"Unsupported dialect `{stmt.dialect.name}` from statement {stmt.name}"
)
self.curr_block.stmts.append(stmt)
if stmt.source is None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since stmts.append calls Statement.insert which sets the source info, it's now important to set the source before.

stmt.source = self.state.source
self.curr_block.stmts.append(stmt)
return stmt

def _push_block(self, block: Block):
Expand Down
4 changes: 3 additions & 1 deletion src/kirin/lowering/python/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def lower_global(self, state: State[ast.AST], node: ast.AST) -> LoweringABC.Resu
def visit(self, state: State[ast.AST], node: ast.AST) -> Result:
if hasattr(node, "lineno"):
state.source = SourceInfo.from_ast(node, state.file)
state.source.offset(state.lineno_offset, state.col_offset)
name = node.__class__.__name__
if name in self.registry.ast_table:
return self.registry.ast_table[name].lower(state, node)
Expand All @@ -148,7 +149,8 @@ def generic_visit(self, state: State[ast.AST], node: ast.AST) -> Result:

def visit_Call(self, state: State[ast.AST], node: ast.Call) -> Result:
if hasattr(node.func, "lineno"):
state.source = SourceInfo.from_ast(node.func, state.file)
state.source = SourceInfo.from_ast(node, state.file)
state.source.offset(state.lineno_offset, state.col_offset)

global_callee_result = state.get_global(node.func, no_raise=True)
if global_callee_result is None:
Expand Down
10 changes: 10 additions & 0 deletions src/kirin/rewrite/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ def inline_call_like(

# NOTE: we cannot change region because it may be used elsewhere
inline_region: ir.Region = region.clone()

# Preserve source information by attributing inlined code to the call site
if call_like.source is not None:
for block in inline_region.blocks:
if block.source is None:
block.source = call_like.source
for stmt in block.stmts:
if stmt.source is None:
stmt.source = call_like.source

parent_block: ir.Block = call_like.parent_block
parent_region: ir.Region = call_like.parent_region

Expand Down
10 changes: 10 additions & 0 deletions test/ir/test_stmt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from kirin.ir import Block
from kirin.source import SourceInfo
from kirin.dialects import py


Expand Down Expand Up @@ -50,3 +51,12 @@ def test_stmt_from_stmt():
y = x.from_stmt(x)

assert y.result.hints["const"] == py.constant.types.Int


def test_stmt_from_stmt_preserves_source_info():
x = py.Constant(1)
x.source = SourceInfo(lineno=1, col_offset=0, end_lineno=None, end_col_offset=None)

y = x.from_stmt(x)
assert y.source == x.source
assert y.source is x.source
36 changes: 36 additions & 0 deletions test/lowering/test_source_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

from kirin.source import SourceInfo
from kirin.prelude import basic_no_opt


def get_line_of(target: str) -> int:
for i, line in enumerate(open(__file__), 1):
if target in line:
return i


@pytest.mark.parametrize("similar", [True, False])
def test_stmt_source_info(similar: bool):
@basic_no_opt
def test(x: int):
y = 2
a = 4**2
return y + 2 + a

if similar:
test = test.similar()

stmts = test.callable_region.blocks[0].stmts

def get_line_from_source_info(source: SourceInfo) -> int:
return source.lineno + source.lineno_begin

for stmt in stmts:
assert stmt.source.file == __file__

assert get_line_from_source_info(stmts.at(0).source) == get_line_of("y = 2")
assert get_line_from_source_info(stmts.at(2).source) == get_line_of("a = 4**2")
assert get_line_from_source_info(stmts.at(4).source) == get_line_of(
"return y + 2 + a"
)
23 changes: 23 additions & 0 deletions test/passes/test_inline_pass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from kirin.prelude import basic_no_opt
from kirin.passes.inline import InlinePass
from kirin.dialects.py.constant import Constant


@basic_no_opt
Expand Down Expand Up @@ -40,3 +41,25 @@ def main_inline_pass2(x: int):
assert a == b

assert len(main_inline_pass2.callable_region.blocks[0].stmts) == 4


def test_inline_preserves_source_info():
def get_line_of(target: str) -> int:
for i, line in enumerate(open(__file__), 1):
if target in line:
return i

@basic_no_opt
def main_inline_pass(x: int):
y = inline_func(x)
return y + 2

inline = InlinePass(main_inline_pass.dialects)
inline(main_inline_pass)

stmt = main_inline_pass.callable_region.blocks[0].stmts.at(0)
line = stmt.source.lineno + stmt.source.lineno_begin
assert stmt.value.data == 1
assert isinstance(stmt, Constant)

assert get_line_of("return x - 1") == line