From f8293f837c2fa96017b34f11c3e30cf9bcf1637e Mon Sep 17 00:00:00 2001 From: kaihsin Date: Tue, 18 Nov 2025 19:42:56 -0500 Subject: [PATCH 01/11] simplify --- src/kirin/ir/attrs/_types.pyi | 4 +- src/kirin/ir/attrs/types.py | 159 ++++++++++++++++++ src/kirin/ir/method.py | 4 +- src/kirin/types.py | 6 +- .../dataflow/typeinfer/test_inter_method.py | 2 + test/dialects/{py => py_dialect}/__init__.py | 0 .../{py => py_dialect}/test_assign.py | 0 test/dialects/{py => py_dialect}/test_iter.py | 0 .../{py => py_dialect}/test_tuple_infer.py | 0 test/dialects/test_pytypes.py | 20 +++ 10 files changed, 189 insertions(+), 6 deletions(-) rename test/dialects/{py => py_dialect}/__init__.py (100%) rename test/dialects/{py => py_dialect}/test_assign.py (100%) rename test/dialects/{py => py_dialect}/test_iter.py (100%) rename test/dialects/{py => py_dialect}/test_tuple_infer.py (100%) diff --git a/src/kirin/ir/attrs/_types.pyi b/src/kirin/ir/attrs/_types.pyi index aebf867ba..1f08bcd49 100644 --- a/src/kirin/ir/attrs/_types.pyi +++ b/src/kirin/ir/attrs/_types.pyi @@ -1,7 +1,7 @@ from dataclasses import dataclass from .abc import Attribute -from .types import Union, Generic, Literal, PyClass, TypeVar, TypeAttribute +from .types import Union, Generic, Literal, PyClass, TypeVar, TypeAttribute, FunctionType, TypeofMethodType @dataclass class _TypeAttribute(Attribute): @@ -11,3 +11,5 @@ class _TypeAttribute(Attribute): def is_subseteq_PyClass(self, other: PyClass) -> bool: ... def is_subseteq_Generic(self, other: Generic) -> bool: ... def is_subseteq_fallback(self, other: TypeAttribute) -> bool: ... + def is_subseteq_TypeofMethodType(self, other: TypeofMethodType) -> bool: ... + def is_subseteq_FunctionType(self, other: FunctionType) -> bool: ... \ No newline at end of file diff --git a/src/kirin/ir/attrs/types.py b/src/kirin/ir/attrs/types.py index 5dfcbe8ee..2ebaffeb5 100644 --- a/src/kirin/ir/attrs/types.py +++ b/src/kirin/ir/attrs/types.py @@ -715,6 +715,165 @@ def deserialize( out.vararg = vararg return out +@typing.final +@dataclass(eq=False) +class TypeofMethodType(TypeAttribute, metaclass=SingletonTypeMeta): + name = "TypeofMethodType" + + def __hash__(self) -> int: + return hash((TypeofMethodType,)) + + def is_structurally_equal( + self, other: Attribute, context: dict | None = None + ) -> bool: + return isinstance(other, TypeofMethodType) + + def is_subseteq_TypeofMethodType(self, other: "TypeofMethodType") -> bool: + return True + + def serialize(self, serializer: "Serializer") -> "SerializationUnit": + return SerializationUnit( + kind="type-attribute", + module_name=self.__module__, + class_name=self.__class__.__name__, + data=dict(), + ) + + @classmethod + def deserialize( + cls, serUnit: "SerializationUnit", deserializer: "Deserializer" + ) -> "TypeofMethodType": + return TypeofMethodType() + + def __getitem__( + self, + typ: ( + tuple[list[TypeAttribute], TypeAttribute] + | tuple[list[TypeAttribute]] + ), + ) -> "FunctionType": + if isinstance(typ, tuple) and len(typ) == 2: + return FunctionType(tuple(typ[0]), typ[1]) + elif isinstance(typ, tuple) and len(typ) == 1: + return FunctionType(tuple(typ[0])) + else: + raise TypeError("Invalid type arguments for TypeofMethodType") + + + +@typing.final +@dataclass(eq=False) +class FunctionType(TypeAttribute): + name = "MethodType" + params_type: tuple[TypeAttribute, ...] + return_type: TypeAttribute | None = None + + def __init__( + self, + params_type: tuple[TypeAttribute, ...], + return_type: TypeAttribute | None = None, + ): + self.params_type = params_type + self.return_type = return_type + + def __hash__(self) -> int: + return hash((FunctionType, self.params_type, self.return_type)) + + def __repr__(self) -> str: + params = ", ".join(map(repr, self.params_type)) + if self.return_type is not None: + return f"({params}) -> {repr(self.return_type)}" + else: + return f"({params}) -> None" + + def print_impl(self, printer: Printer) -> None: + printer.plain_print("(") + printer.print_seq(self.params_type, delim=", ") + printer.plain_print(") -> ") + if self.return_type is not None: + printer.print(self.return_type) + else: + printer.plain_print("None") + + def __getitem__( + self, + typ: ( + tuple[tuple[TypeAttribute, ...], TypeAttribute] + | tuple[tuple[TypeAttribute, ...]] + ), + ) -> "FunctionType": + if isinstance(typ, tuple) and len(typ) == 2: + return self.where(typ[0], typ[1]) + elif isinstance(typ, tuple) and len(typ) == 1: + return self.where(typ[0]) + else: + raise TypeError("Invalid type arguments for MethodType") + + def where( + self, typ: tuple[TypeAttribute, ...], return_type: TypeAttribute | None = None + ) -> "FunctionType": + if len(typ) != len(self.params_type): + raise TypeError("Number of type arguments does not match") + if all(v.is_subseteq(bound) for v, bound in zip(typ, self.params_type)): + if return_type is None: + return FunctionType(typ, self.return_type) + elif self.return_type is not None and return_type.is_subseteq( + self.return_type + ): + return FunctionType(typ, return_type) + raise TypeError("Type arguments do not match") + + def is_structurally_equal( + self, other: Attribute, context: dict | None = None + ) -> bool: + return ( + isinstance(other, FunctionType) + and self.params_type == other.params_type + and self.return_type == other.return_type + ) + + def is_subseteq_FunctionType(self, other: "FunctionType") -> bool: + if len(self.params_type) != len(other.params_type): + return False + for s_param, o_param in zip(self.params_type, other.params_type): + if not s_param.is_subseteq(o_param): + return False + if self.return_type is None: + return True + elif other.return_type is None: + return False + else: + return self.return_type.is_subseteq(other.return_type) + + def is_subseteq_TypeofMethodType(self, other: "TypeofMethodType") -> bool: + return True + + def is_subseteq_Union(self, other: Union) -> bool: + return any(self.is_subseteq(t) for t in other.types) + + def is_subseteq_fallback(self, other: TypeAttribute) -> bool: + return False + + def serialize(self, serializer: "Serializer") -> "SerializationUnit": + return SerializationUnit( + kind="type-attribute", + module_name=self.__module__, + class_name=self.__class__.__name__, + data={ + "params_type": serializer.serialize_tuple(self.params_type), + "return_type": serializer.serialize(self.return_type), + }, + ) + + @classmethod + def deserialize( + cls, serUnit: "SerializationUnit", deserializer: "Deserializer" + ) -> "FunctionType": + params_type = deserializer.deserialize_tuple(serUnit.data["params_type"]) + return_type = deserializer.deserialize(serUnit.data["return_type"]) + return FunctionType(params_type, return_type) + + def _typeparams_list2tuple(args: tuple[TypeVarValue, ...]) -> tuple[TypeOrVararg, ...]: "provides the syntax sugar [A, B, C] type Generic(tuple, A, B, C)" diff --git a/src/kirin/ir/method.py b/src/kirin/ir/method.py index cd9da5d6a..e44bb41ee 100644 --- a/src/kirin/ir/method.py +++ b/src/kirin/ir/method.py @@ -25,7 +25,7 @@ ) from .exception import ValidationError from .nodes.stmt import Statement -from .attrs.types import Generic +from .attrs.types import FunctionType if typing.TYPE_CHECKING: from kirin.ir.group import DialectGroup @@ -141,7 +141,7 @@ def self_type(self): """Return the type of the self argument of the method.""" trait = self.code.get_present_trait(HasSignature) signature = trait.get_signature(self.code) - return Generic(Method, Generic(tuple, *signature.inputs), signature.output) + return FunctionType(params_type=signature.inputs, return_type=signature.output) @property def callable_region(self): diff --git a/src/kirin/types.py b/src/kirin/types.py index 2f57e173d..0f41bb7c4 100644 --- a/src/kirin/types.py +++ b/src/kirin/types.py @@ -13,6 +13,8 @@ TypeVar as TypeVar, BottomType as BottomType, TypeAttribute as TypeAttribute, + TypeofMethodType as TypeofMethodType, + FunctionType as FunctionType, hint2type as hint2type, is_tuple_of as is_tuple_of, ) @@ -32,6 +34,4 @@ Dict = Generic(dict, TypeVar("K"), TypeVar("V")) Set = Generic(set, TypeVar("T")) FrozenSet = Generic(frozenset, TypeVar("T")) -TypeofFunctionType = Generic[type(lambda: None)] -FunctionType = Generic(type(lambda: None), Tuple, Vararg(Any)) -MethodType = Generic(Method, TypeVar("Params", Tuple), TypeVar("Ret")) +MethodType = TypeofMethodType() \ No newline at end of file diff --git a/test/analysis/dataflow/typeinfer/test_inter_method.py b/test/analysis/dataflow/typeinfer/test_inter_method.py index bc4bccf23..d8373172b 100644 --- a/test/analysis/dataflow/typeinfer/test_inter_method.py +++ b/test/analysis/dataflow/typeinfer/test_inter_method.py @@ -57,4 +57,6 @@ def _new(qid: int): def alloc(n_iter: int): return ilist.map(_new, ilist.range(n_iter)) + alloc.print() assert alloc.return_type.is_subseteq(ilist.IListType[types.Literal(1), types.Any]) +test_method_constant_type_infer() \ No newline at end of file diff --git a/test/dialects/py/__init__.py b/test/dialects/py_dialect/__init__.py similarity index 100% rename from test/dialects/py/__init__.py rename to test/dialects/py_dialect/__init__.py diff --git a/test/dialects/py/test_assign.py b/test/dialects/py_dialect/test_assign.py similarity index 100% rename from test/dialects/py/test_assign.py rename to test/dialects/py_dialect/test_assign.py diff --git a/test/dialects/py/test_iter.py b/test/dialects/py_dialect/test_iter.py similarity index 100% rename from test/dialects/py/test_iter.py rename to test/dialects/py_dialect/test_iter.py diff --git a/test/dialects/py/test_tuple_infer.py b/test/dialects/py_dialect/test_tuple_infer.py similarity index 100% rename from test/dialects/py/test_tuple_infer.py rename to test/dialects/py_dialect/test_tuple_infer.py diff --git a/test/dialects/test_pytypes.py b/test/dialects/test_pytypes.py index 61dfd61ba..e97107527 100644 --- a/test/dialects/test_pytypes.py +++ b/test/dialects/test_pytypes.py @@ -17,6 +17,7 @@ NoneType, BottomType, TypeAttribute, + MethodType, ) @@ -112,3 +113,22 @@ def test_generic_topbottom(): assert t.meet(TypeAttribute.bottom()).is_subseteq(TypeAttribute.bottom()) assert t.join(TypeAttribute.top()).is_structurally_equal(TypeAttribute.top()) assert t.meet(TypeAttribute.top()).is_structurally_equal(t) + + +def test_method_type(): + t1 = MethodType[[Int, Float], Bool] + t2 = MethodType[[Int, Float], Bool] + + assert t1.is_subseteq(t2) + + t3 = MethodType[[Int, Float], AnyType()] + assert t1.is_subseteq(t3) + + t4 = MethodType[[Int, Float], String] + assert not t1.is_subseteq(t4) + + Var = TypeVar("Var") + t5 = MethodType[[Int, Var], Bool] + assert t1.is_subseteq(t5) + + From d7adde5c4dd513840e3308e2986eeed23fd32793 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Tue, 18 Nov 2025 19:53:19 -0500 Subject: [PATCH 02/11] tmp --- test/analysis/dataflow/typeinfer/test_inter_method.py | 3 +-- test/dialects/test_pytypes.py | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/test/analysis/dataflow/typeinfer/test_inter_method.py b/test/analysis/dataflow/typeinfer/test_inter_method.py index d8373172b..8375f0735 100644 --- a/test/analysis/dataflow/typeinfer/test_inter_method.py +++ b/test/analysis/dataflow/typeinfer/test_inter_method.py @@ -57,6 +57,5 @@ def _new(qid: int): def alloc(n_iter: int): return ilist.map(_new, ilist.range(n_iter)) - alloc.print() assert alloc.return_type.is_subseteq(ilist.IListType[types.Literal(1), types.Any]) -test_method_constant_type_infer() \ No newline at end of file + diff --git a/test/dialects/test_pytypes.py b/test/dialects/test_pytypes.py index e97107527..9a640abd2 100644 --- a/test/dialects/test_pytypes.py +++ b/test/dialects/test_pytypes.py @@ -130,5 +130,3 @@ def test_method_type(): Var = TypeVar("Var") t5 = MethodType[[Int, Var], Bool] assert t1.is_subseteq(t5) - - From 6f2a2e1fd2d76966a5e6aea4b12bb5d0c105a423 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 16:02:21 -0500 Subject: [PATCH 03/11] fix type hint --- src/kirin/analysis/typeinfer/solve.py | 31 +++++++++++++++++++++++++++ src/kirin/dialects/func/__init__.py | 2 +- src/kirin/dialects/func/attrs.py | 10 ++++----- src/kirin/dialects/func/stmts.py | 14 ++++++------ src/kirin/dialects/ilist/stmts.py | 15 ++++++------- src/kirin/dialects/lowering/func.py | 4 +--- test/ir/test_isequal.py | 4 ++-- test/lowering/test_func.py | 6 ++++-- 8 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/kirin/analysis/typeinfer/solve.py b/src/kirin/analysis/typeinfer/solve.py index b08d576a1..213a24561 100644 --- a/src/kirin/analysis/typeinfer/solve.py +++ b/src/kirin/analysis/typeinfer/solve.py @@ -69,6 +69,13 @@ def substitute(self, typ: types.TypeAttribute) -> types.TypeAttribute: ) elif isinstance(typ, types.Union): return types.Union(self.substitute(t) for t in typ.types) + elif isinstance(typ, types.FunctionType): + return types.FunctionType( + params_type=tuple(self.substitute(t) for t in typ.params_type), + return_type=self.substitute(typ.return_type) + if typ.return_type + else None, + ) return typ def solve( @@ -94,6 +101,8 @@ def solve( return self.solve_Generic(annot, value) elif isinstance(annot, types.Union): return self.solve_Union(annot, value) + elif isinstance(annot, types.FunctionType): + return self.solve_FunctionType(annot, value) if annot.is_subseteq(value): return Ok @@ -133,6 +142,28 @@ def solve_Generic(self, annot: types.Generic, value: types.TypeAttribute): return result return Ok + def solve_FunctionType(self, annot: types.FunctionType, value: types.TypeAttribute): + if not isinstance(value, types.FunctionType): + return ResolutionError(annot, value) + + for var, val in zip(annot.params_type, value.params_type): + result = self.solve(var, val) + if not result: + return result + + if not annot.return_type or not value.return_type: + return Ok + + result = self.solve(annot.return_type, value.return_type) + if not result: + return result + + return Ok + + + + + def solve_Union(self, annot: types.Union, value: types.TypeAttribute): for typ in annot.types: result = self.solve(typ, value) diff --git a/src/kirin/dialects/func/__init__.py b/src/kirin/dialects/func/__init__.py index 0aeecc34f..913385f3f 100644 --- a/src/kirin/dialects/func/__init__.py +++ b/src/kirin/dialects/func/__init__.py @@ -5,7 +5,7 @@ constprop as constprop, typeinfer as typeinfer, ) -from kirin.dialects.func.attrs import Signature as Signature, MethodType as MethodType +from kirin.dialects.func.attrs import Signature as Signature from kirin.dialects.func.stmts import ( Call as Call, Invoke as Invoke, diff --git a/src/kirin/dialects/func/attrs.py b/src/kirin/dialects/func/attrs.py index 3d6ec46ae..017ab2b1c 100644 --- a/src/kirin/dialects/func/attrs.py +++ b/src/kirin/dialects/func/attrs.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from kirin import types -from kirin.ir import Method, Attribute +from kirin.ir import Attribute from kirin.print.printer import Printer from kirin.serialization.core.serializationunit import SerializationUnit @@ -12,10 +12,10 @@ from ._dialect import dialect -TypeofMethodType = types.PyClass[Method] -MethodType = types.Generic( - Method, types.TypeVar("Params", types.Tuple), types.TypeVar("Ret") -) +# TypeofMethodType = types.PyClass[Method] +# MethodType = types.Generic( +# Method, types.TypeVar("Params", types.Tuple), types.TypeVar("Ret") +# ) TypeLatticeElem = TypeVar("TypeLatticeElem", bound="types.TypeAttribute") diff --git a/src/kirin/dialects/func/stmts.py b/src/kirin/dialects/func/stmts.py index f026f21b2..309ea26ff 100644 --- a/src/kirin/dialects/func/stmts.py +++ b/src/kirin/dialects/func/stmts.py @@ -1,13 +1,13 @@ from __future__ import annotations -from types import MethodType as ClassMethodType, FunctionType +from types import MethodType as PyClassMethodType, FunctionType as PyFunctionType from typing import TypeVar from kirin import ir, types from kirin.decl import info, statement from kirin.print.printer import Printer -from .attrs import Signature, MethodType +from .attrs import Signature from ._dialect import dialect @@ -58,7 +58,7 @@ class Function(ir.Statement): """The signature of the function at declaration.""" body: ir.Region = info.region(multi=True) """The body of the function.""" - result: ir.ResultValue = info.result(MethodType) + result: ir.ResultValue = info.result(types.MethodType) """The result of the function.""" def print_impl(self, printer: Printer) -> None: @@ -115,7 +115,7 @@ class Lambda(ir.Statement): """The signature of the function at declaration.""" captured: tuple[ir.SSAValue, ...] = info.argument() body: ir.Region = info.region(multi=True) - result: ir.ResultValue = info.result(MethodType) + result: ir.ResultValue = info.result(types.MethodType) def check(self) -> None: assert self.body.blocks, "lambda body must not be empty" @@ -145,7 +145,7 @@ def print_impl(self, printer: Printer) -> None: class GetField(ir.Statement): name = "getfield" traits = frozenset({ir.Pure()}) - obj: ir.SSAValue = info.argument(MethodType) + obj: ir.SSAValue = info.argument(types.MethodType) field: int = info.attribute() # NOTE: mypy somehow doesn't understand default init=False result: ir.ResultValue = info.result(init=False) @@ -249,7 +249,7 @@ def print_impl(self, printer: Printer) -> None: def check_type(self) -> None: if not self.callee.type.is_subseteq(types.MethodType): - if self.callee.type.is_subseteq(types.PyClass(FunctionType)): + if self.callee.type.is_subseteq(types.PyClass(PyFunctionType)): raise ir.TypeCheckError( self, f"callee must be a method type, got {self.callee.type}", @@ -257,7 +257,7 @@ def check_type(self) -> None: "consider decorating it with kernel decorator", ) - if self.callee.type.is_subseteq(types.PyClass(ClassMethodType)): + if self.callee.type.is_subseteq(types.PyClass(PyClassMethodType)): raise ir.TypeCheckError( self, "callee must be a method type, got class method", diff --git a/src/kirin/dialects/ilist/stmts.py b/src/kirin/dialects/ilist/stmts.py index b6d078210..f3a4bd81e 100644 --- a/src/kirin/dialects/ilist/stmts.py +++ b/src/kirin/dialects/ilist/stmts.py @@ -75,9 +75,7 @@ class Map(ir.Statement): class Foldr(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) - fn: ir.SSAValue = info.argument( - types.Generic(ir.Method, [ElemT, OutElemT], OutElemT) - ) + fn: ir.SSAValue = info.argument(types.MethodType[[ElemT, OutElemT], OutElemT]) collection: ir.SSAValue = info.argument(IListType[ElemT]) init: ir.SSAValue = info.argument(OutElemT) result: ir.ResultValue = info.result(OutElemT) @@ -87,9 +85,8 @@ class Foldr(ir.Statement): class Foldl(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) - fn: ir.SSAValue = info.argument( - types.Generic(ir.Method, [OutElemT, ElemT], OutElemT) - ) + fn: ir.SSAValue = info.argument(types.MethodType[[OutElemT, ElemT], OutElemT]) + collection: ir.SSAValue = info.argument(IListType[ElemT]) init: ir.SSAValue = info.argument(OutElemT) result: ir.ResultValue = info.result(OutElemT) @@ -104,7 +101,7 @@ class Scan(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) fn: ir.SSAValue = info.argument( - types.Generic(ir.Method, [OutElemT, ElemT], types.Tuple[OutElemT, ResultT]) + types.MethodType[[OutElemT, ElemT], types.Tuple[OutElemT, ResultT]] ) collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen]) init: ir.SSAValue = info.argument(OutElemT) @@ -117,7 +114,7 @@ class Scan(ir.Statement): class ForEach(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) - fn: ir.SSAValue = info.argument(types.Generic(ir.Method, [ElemT], types.NoneType)) + fn: ir.SSAValue = info.argument(types.MethodType[[ElemT], types.NoneType]) collection: ir.SSAValue = info.argument(IListType[ElemT]) @@ -141,7 +138,7 @@ class Sorted(ir.Statement): purity: bool = info.attribute(default=False) collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen]) key: ir.SSAValue = info.argument( - types.Union((types.Generic(ir.Method, [ElemT], ElemT), types.NoneType)) + types.Union((types.MethodType[[ElemT], ElemT], types.NoneType)) ) reverse: ir.SSAValue = info.argument(types.Bool) result: ir.ResultValue = info.result(IListType[ElemT, ListLen]) diff --git a/src/kirin/dialects/lowering/func.py b/src/kirin/dialects/lowering/func.py index 6b98efecb..19b57fc49 100644 --- a/src/kirin/dialects/lowering/func.py +++ b/src/kirin/dialects/lowering/func.py @@ -35,9 +35,7 @@ def lower_FunctionDef( entries: dict[str, ir.SSAValue] = {} entr_block = ir.Block() fn_self = entr_block.args.append_from( - types.Generic( - ir.Method, types.Tuple.where(signature.inputs), signature.output - ), + types.MethodType[list(signature.inputs), signature.output], node.name + "_self", ) entries[node.name] = fn_self diff --git a/test/ir/test_isequal.py b/test/ir/test_isequal.py index 5e552b9f7..e0d6cac01 100644 --- a/test/ir/test_isequal.py +++ b/test/ir/test_isequal.py @@ -4,7 +4,7 @@ def test_is_structurally_equal_ignoring_hint(): block = ir.Block() - block.args.append_from(types.PyClass(ir.Method), "self") + block.args.append_from(types.MethodType, "self") source_func = func.Function( sym_name="main", signature=func.Signature( @@ -15,7 +15,7 @@ def test_is_structurally_equal_ignoring_hint(): ) block = ir.Block() - block.args.append_from(types.PyClass(ir.Method), "self") + block.args.append_from(types.MethodType, "self") expected_func = func.Function( sym_name="main", signature=func.Signature( diff --git a/test/lowering/test_func.py b/test/lowering/test_func.py index 678c1b104..282518dfd 100644 --- a/test/lowering/test_func.py +++ b/test/lowering/test_func.py @@ -33,14 +33,16 @@ def recursive(n): return recursive(n - 1) code = lower.python_function(recursive) + code.print() assert isinstance(code, func.Function) assert len(code.body.blocks) == 3 assert isinstance(code.body.blocks[0].last_stmt, cf.ConditionalBranch) assert isinstance(code.body.blocks[2].stmts.at(2), func.Call) stmt: func.Call = code.body.blocks[2].stmts.at(2) # type: ignore assert isinstance(stmt.callee, ir.BlockArgument) - assert stmt.callee.type.is_subseteq(func.MethodType) - + print(stmt.callee.type) + assert stmt.callee.type.is_subseteq(types.MethodType) +test_recursive_func() def test_invalid_func_call(): From ec8a29e4275054657d32272d247268f41839213c Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 16:04:09 -0500 Subject: [PATCH 04/11] fix linit --- docs/cookbook/foodlang/cf_rewrite.md | 2 +- example/food/script.py | 2 +- example/quantum/script.py | 6 ++++-- src/kirin/analysis/typeinfer/solve.py | 18 +++++++----------- src/kirin/dialects/ilist/stmts.py | 2 +- src/kirin/ir/attrs/_types.pyi | 13 +++++++++++-- src/kirin/ir/attrs/types.py | 8 ++------ src/kirin/types.py | 5 ++--- .../dataflow/typeinfer/test_inter_method.py | 1 - test/dialects/test_pytypes.py | 2 +- test/lowering/test_func.py | 3 +++ 11 files changed, 33 insertions(+), 29 deletions(-) diff --git a/docs/cookbook/foodlang/cf_rewrite.md b/docs/cookbook/foodlang/cf_rewrite.md index a8581c7a1..b0e75ce7f 100644 --- a/docs/cookbook/foodlang/cf_rewrite.md +++ b/docs/cookbook/foodlang/cf_rewrite.md @@ -132,7 +132,7 @@ def food(self): if fold: fold_pass(mt) - + if hungry: Walk(NewFoodAndNap()).rewrite(mt.code) diff --git a/example/food/script.py b/example/food/script.py index 54de42116..47a48c608 100644 --- a/example/food/script.py +++ b/example/food/script.py @@ -1,9 +1,9 @@ # type: ignore -from emit import EmitReceptMain from group import food from stmts import Eat, Nap, Cook, NewFood from recept import FeeAnalysis +from emit import EmitReceptMain from interp import FoodMethods as FoodMethods from lattice import AtLeastXItem from rewrite import NewFoodAndNap diff --git a/example/quantum/script.py b/example/quantum/script.py index 93f0d225c..3831d0d0e 100644 --- a/example/quantum/script.py +++ b/example/quantum/script.py @@ -4,6 +4,7 @@ from enum import Enum from typing import ClassVar from dataclasses import dataclass + from qulacs import QuantumState @@ -27,7 +28,7 @@ class Basis(Enum): # [section] from kirin import ir, types, lowering -from kirin.decl import statement, info +from kirin.decl import info, statement from kirin.prelude import basic # our language definitions and compiler begins @@ -161,8 +162,9 @@ def main(state: QuantumState): # we need to implement the runtime for the quantum circuit # let's just import qulacs a quantum circuit simulator +from qulacs import QuantumState, gate + from kirin import interp -from qulacs import gate, QuantumState @dialect.register diff --git a/src/kirin/analysis/typeinfer/solve.py b/src/kirin/analysis/typeinfer/solve.py index 213a24561..7899d00e6 100644 --- a/src/kirin/analysis/typeinfer/solve.py +++ b/src/kirin/analysis/typeinfer/solve.py @@ -72,9 +72,9 @@ def substitute(self, typ: types.TypeAttribute) -> types.TypeAttribute: elif isinstance(typ, types.FunctionType): return types.FunctionType( params_type=tuple(self.substitute(t) for t in typ.params_type), - return_type=self.substitute(typ.return_type) - if typ.return_type - else None, + return_type=( + self.substitute(typ.return_type) if typ.return_type else None + ), ) return typ @@ -145,24 +145,20 @@ def solve_Generic(self, annot: types.Generic, value: types.TypeAttribute): def solve_FunctionType(self, annot: types.FunctionType, value: types.TypeAttribute): if not isinstance(value, types.FunctionType): return ResolutionError(annot, value) - + for var, val in zip(annot.params_type, value.params_type): result = self.solve(var, val) if not result: return result - + if not annot.return_type or not value.return_type: return Ok - + result = self.solve(annot.return_type, value.return_type) if not result: return result - - return Ok - - - + return Ok def solve_Union(self, annot: types.Union, value: types.TypeAttribute): for typ in annot.types: diff --git a/src/kirin/dialects/ilist/stmts.py b/src/kirin/dialects/ilist/stmts.py index f3a4bd81e..39e523b7f 100644 --- a/src/kirin/dialects/ilist/stmts.py +++ b/src/kirin/dialects/ilist/stmts.py @@ -86,7 +86,7 @@ class Foldl(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) fn: ir.SSAValue = info.argument(types.MethodType[[OutElemT, ElemT], OutElemT]) - + collection: ir.SSAValue = info.argument(IListType[ElemT]) init: ir.SSAValue = info.argument(OutElemT) result: ir.ResultValue = info.result(OutElemT) diff --git a/src/kirin/ir/attrs/_types.pyi b/src/kirin/ir/attrs/_types.pyi index 1f08bcd49..ebde8e27f 100644 --- a/src/kirin/ir/attrs/_types.pyi +++ b/src/kirin/ir/attrs/_types.pyi @@ -1,7 +1,16 @@ from dataclasses import dataclass from .abc import Attribute -from .types import Union, Generic, Literal, PyClass, TypeVar, TypeAttribute, FunctionType, TypeofMethodType +from .types import ( + Union, + Generic, + Literal, + PyClass, + TypeVar, + FunctionType, + TypeAttribute, + TypeofMethodType, +) @dataclass class _TypeAttribute(Attribute): @@ -12,4 +21,4 @@ class _TypeAttribute(Attribute): def is_subseteq_Generic(self, other: Generic) -> bool: ... def is_subseteq_fallback(self, other: TypeAttribute) -> bool: ... def is_subseteq_TypeofMethodType(self, other: TypeofMethodType) -> bool: ... - def is_subseteq_FunctionType(self, other: FunctionType) -> bool: ... \ No newline at end of file + def is_subseteq_FunctionType(self, other: FunctionType) -> bool: ... diff --git a/src/kirin/ir/attrs/types.py b/src/kirin/ir/attrs/types.py index 2ebaffeb5..5f33b41d3 100644 --- a/src/kirin/ir/attrs/types.py +++ b/src/kirin/ir/attrs/types.py @@ -715,6 +715,7 @@ def deserialize( out.vararg = vararg return out + @typing.final @dataclass(eq=False) class TypeofMethodType(TypeAttribute, metaclass=SingletonTypeMeta): @@ -747,10 +748,7 @@ def deserialize( def __getitem__( self, - typ: ( - tuple[list[TypeAttribute], TypeAttribute] - | tuple[list[TypeAttribute]] - ), + typ: tuple[list[TypeAttribute], TypeAttribute] | tuple[list[TypeAttribute]], ) -> "FunctionType": if isinstance(typ, tuple) and len(typ) == 2: return FunctionType(tuple(typ[0]), typ[1]) @@ -760,7 +758,6 @@ def __getitem__( raise TypeError("Invalid type arguments for TypeofMethodType") - @typing.final @dataclass(eq=False) class FunctionType(TypeAttribute): @@ -874,7 +871,6 @@ def deserialize( return FunctionType(params_type, return_type) - def _typeparams_list2tuple(args: tuple[TypeVarValue, ...]) -> tuple[TypeOrVararg, ...]: "provides the syntax sugar [A, B, C] type Generic(tuple, A, B, C)" return tuple(Generic(tuple, *arg) if isinstance(arg, list) else arg for arg in args) diff --git a/src/kirin/types.py b/src/kirin/types.py index 0f41bb7c4..12ee7d1d7 100644 --- a/src/kirin/types.py +++ b/src/kirin/types.py @@ -2,7 +2,6 @@ import numbers -from kirin.ir.method import Method from kirin.ir.attrs.types import ( Union as Union, Vararg as Vararg, @@ -12,9 +11,9 @@ PyClass as PyClass, TypeVar as TypeVar, BottomType as BottomType, + FunctionType as FunctionType, TypeAttribute as TypeAttribute, TypeofMethodType as TypeofMethodType, - FunctionType as FunctionType, hint2type as hint2type, is_tuple_of as is_tuple_of, ) @@ -34,4 +33,4 @@ Dict = Generic(dict, TypeVar("K"), TypeVar("V")) Set = Generic(set, TypeVar("T")) FrozenSet = Generic(frozenset, TypeVar("T")) -MethodType = TypeofMethodType() \ No newline at end of file +MethodType = TypeofMethodType() diff --git a/test/analysis/dataflow/typeinfer/test_inter_method.py b/test/analysis/dataflow/typeinfer/test_inter_method.py index 8375f0735..bc4bccf23 100644 --- a/test/analysis/dataflow/typeinfer/test_inter_method.py +++ b/test/analysis/dataflow/typeinfer/test_inter_method.py @@ -58,4 +58,3 @@ def alloc(n_iter: int): return ilist.map(_new, ilist.range(n_iter)) assert alloc.return_type.is_subseteq(ilist.IListType[types.Literal(1), types.Any]) - diff --git a/test/dialects/test_pytypes.py b/test/dialects/test_pytypes.py index 9a640abd2..ccc9280ff 100644 --- a/test/dialects/test_pytypes.py +++ b/test/dialects/test_pytypes.py @@ -16,8 +16,8 @@ TypeVar, NoneType, BottomType, - TypeAttribute, MethodType, + TypeAttribute, ) diff --git a/test/lowering/test_func.py b/test/lowering/test_func.py index 282518dfd..af9d64333 100644 --- a/test/lowering/test_func.py +++ b/test/lowering/test_func.py @@ -42,8 +42,11 @@ def recursive(n): assert isinstance(stmt.callee, ir.BlockArgument) print(stmt.callee.type) assert stmt.callee.type.is_subseteq(types.MethodType) + + test_recursive_func() + def test_invalid_func_call(): def undefined(n): From 8f6d0da55729308af3c6525bb8c1de558e1c0ed8 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 16:10:26 -0500 Subject: [PATCH 05/11] add test --- .../dataflow/typeinfer/test_infer_lambda.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 test/analysis/dataflow/typeinfer/test_infer_lambda.py diff --git a/test/analysis/dataflow/typeinfer/test_infer_lambda.py b/test/analysis/dataflow/typeinfer/test_infer_lambda.py new file mode 100644 index 000000000..c9303bdf6 --- /dev/null +++ b/test/analysis/dataflow/typeinfer/test_infer_lambda.py @@ -0,0 +1,16 @@ +from kirin.prelude import structural +from kirin.dialects import ilist +from kirin import types + +def test_infer_lambda(): + @structural(typeinfer=True, fold=False, no_raise=False) + def main(n): + def map_func(i): + return n + 1 + + return ilist.map(map_func, ilist.range(4)) + + map_stmt = main.callable_region.blocks[0].stmts.at(-2) + assert isinstance(map_stmt, ilist.Map) + assert map_stmt.result.type == ilist.IListType[types.Int, types.Literal(4)] + From c73a843badfb0ccca25e429209d95ef1b4c2f563 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 16:44:08 -0500 Subject: [PATCH 06/11] fix call typeinfer --- src/kirin/dialects/func/typeinfer.py | 13 ++---------- src/kirin/ir/attrs/types.py | 9 +++++++-- .../dataflow/typeinfer/test_infer_lambda.py | 20 +++++++++++++++++++ .../dataflow/typeinfer/test_inter_method.py | 2 ++ test/lowering/test_method_hint.py | 4 +--- 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/kirin/dialects/func/typeinfer.py b/src/kirin/dialects/func/typeinfer.py index eba287567..932ba6989 100644 --- a/src/kirin/dialects/func/typeinfer.py +++ b/src/kirin/dialects/func/typeinfer.py @@ -54,20 +54,11 @@ def call(self, interp_: TypeInference, frame: Frame, stmt: Call): def _solve_method_type(self, interp: TypeInference, frame: Frame, stmt: Call): mt_inferred = frame.get(stmt.callee) - if not isinstance(mt_inferred, types.Generic): - return (types.Bottom,) - if len(mt_inferred.vars) != 2: - return (types.Bottom,) - args = mt_inferred.vars[0] - result = mt_inferred.vars[1] - if not args.is_subseteq(types.Tuple): + if not isinstance(mt_inferred, types.FunctionType): return (types.Bottom,) - resolve = TypeResolution() - # NOTE: we are not using [...] below to be compatible with 3.10 - resolve.solve(args, types.Tuple.where(frame.get_values(stmt.inputs))) - return (resolve.substitute(result),) + return (mt_inferred.return_type,) @impl(Invoke) def invoke(self, interp_: TypeInference, frame: Frame, stmt: Invoke): diff --git a/src/kirin/ir/attrs/types.py b/src/kirin/ir/attrs/types.py index 5f33b41d3..ad562d871 100644 --- a/src/kirin/ir/attrs/types.py +++ b/src/kirin/ir/attrs/types.py @@ -926,7 +926,6 @@ def hint2type(hint) -> TypeAttribute: if origin is None: # non-generic return PyClass(hint) - body = PyClass(origin) args = typing.get_args(hint) params = [] for arg in args: @@ -934,4 +933,10 @@ def hint2type(hint) -> TypeAttribute: params.append([hint2type(elem) for elem in arg]) else: params.append(hint2type(arg)) - return Generic(body, *params) + + if origin.__name__ == "Method": + assert len(params) == 2, "method type hint should be ir.Method[[params], return_type]" + return FunctionType(tuple(params[0]), params[1]) + else: + body = PyClass(origin) + return Generic(body, *params) diff --git a/test/analysis/dataflow/typeinfer/test_infer_lambda.py b/test/analysis/dataflow/typeinfer/test_infer_lambda.py index c9303bdf6..b3c1f361b 100644 --- a/test/analysis/dataflow/typeinfer/test_infer_lambda.py +++ b/test/analysis/dataflow/typeinfer/test_infer_lambda.py @@ -1,6 +1,7 @@ from kirin.prelude import structural from kirin.dialects import ilist from kirin import types +from kirin import ir def test_infer_lambda(): @structural(typeinfer=True, fold=False, no_raise=False) @@ -14,3 +15,22 @@ def map_func(i): assert isinstance(map_stmt, ilist.Map) assert map_stmt.result.type == ilist.IListType[types.Int, types.Literal(4)] + +def test_infer_method_type_hint_call(): + + @structural(typeinfer=True, fold=False, no_raise=False) + def main(n, fx: ir.Method[[int], int]): + return fx(n) + + assert main.return_type == types.Int + +def test_infer_method_type_hint(): + + @structural(typeinfer=True, fold=False, no_raise=False) + def main(n, fx: ir.Method[[int], int]): + def map_func(i): + return n + 1 + fx(i) + + return ilist.map(map_func, ilist.range(4)) + + assert main.return_type == ilist.IListType[types.Int, types.Literal(4)] \ No newline at end of file diff --git a/test/analysis/dataflow/typeinfer/test_inter_method.py b/test/analysis/dataflow/typeinfer/test_inter_method.py index bc4bccf23..6a18ac073 100644 --- a/test/analysis/dataflow/typeinfer/test_inter_method.py +++ b/test/analysis/dataflow/typeinfer/test_inter_method.py @@ -58,3 +58,5 @@ def alloc(n_iter: int): return ilist.map(_new, ilist.range(n_iter)) assert alloc.return_type.is_subseteq(ilist.IListType[types.Literal(1), types.Any]) + + diff --git a/test/lowering/test_method_hint.py b/test/lowering/test_method_hint.py index caefcef5d..48ec31c9b 100644 --- a/test/lowering/test_method_hint.py +++ b/test/lowering/test_method_hint.py @@ -11,6 +11,4 @@ def test(x: int, y: int) -> float: return test - assert main.return_type == types.Generic( - ir.Method, [types.Int, types.Int], types.Float - ) + assert main.return_type == types.MethodType[[types.Int, types.Int], types.Float] From e822504699bf7703977ffccfcadf93389a003f62 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 16:59:56 -0500 Subject: [PATCH 07/11] fix lint --- src/kirin/dialects/func/typeinfer.py | 2 +- src/kirin/ir/attrs/types.py | 4 +++- .../dataflow/typeinfer/test_infer_lambda.py | 15 ++++++++------- .../dataflow/typeinfer/test_inter_method.py | 2 -- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/kirin/dialects/func/typeinfer.py b/src/kirin/dialects/func/typeinfer.py index 932ba6989..862fdcf45 100644 --- a/src/kirin/dialects/func/typeinfer.py +++ b/src/kirin/dialects/func/typeinfer.py @@ -3,7 +3,7 @@ from kirin import ir, types from kirin.interp import Frame, MethodTable, ReturnValue, impl from kirin.analysis import const -from kirin.analysis.typeinfer import TypeInference, TypeResolution +from kirin.analysis.typeinfer import TypeInference from kirin.dialects.func.stmts import ( Call, Invoke, diff --git a/src/kirin/ir/attrs/types.py b/src/kirin/ir/attrs/types.py index ad562d871..69cb45396 100644 --- a/src/kirin/ir/attrs/types.py +++ b/src/kirin/ir/attrs/types.py @@ -935,7 +935,9 @@ def hint2type(hint) -> TypeAttribute: params.append(hint2type(arg)) if origin.__name__ == "Method": - assert len(params) == 2, "method type hint should be ir.Method[[params], return_type]" + assert ( + len(params) == 2 + ), "method type hint should be ir.Method[[params], return_type]" return FunctionType(tuple(params[0]), params[1]) else: body = PyClass(origin) diff --git a/test/analysis/dataflow/typeinfer/test_infer_lambda.py b/test/analysis/dataflow/typeinfer/test_infer_lambda.py index b3c1f361b..3a12fa076 100644 --- a/test/analysis/dataflow/typeinfer/test_infer_lambda.py +++ b/test/analysis/dataflow/typeinfer/test_infer_lambda.py @@ -1,14 +1,14 @@ +from kirin import ir, types from kirin.prelude import structural from kirin.dialects import ilist -from kirin import types -from kirin import ir + def test_infer_lambda(): @structural(typeinfer=True, fold=False, no_raise=False) def main(n): def map_func(i): return n + 1 - + return ilist.map(map_func, ilist.range(4)) map_stmt = main.callable_region.blocks[0].stmts.at(-2) @@ -21,16 +21,17 @@ def test_infer_method_type_hint_call(): @structural(typeinfer=True, fold=False, no_raise=False) def main(n, fx: ir.Method[[int], int]): return fx(n) - + assert main.return_type == types.Int + def test_infer_method_type_hint(): @structural(typeinfer=True, fold=False, no_raise=False) def main(n, fx: ir.Method[[int], int]): def map_func(i): return n + 1 + fx(i) - + return ilist.map(map_func, ilist.range(4)) - - assert main.return_type == ilist.IListType[types.Int, types.Literal(4)] \ No newline at end of file + + assert main.return_type == ilist.IListType[types.Int, types.Literal(4)] diff --git a/test/analysis/dataflow/typeinfer/test_inter_method.py b/test/analysis/dataflow/typeinfer/test_inter_method.py index 6a18ac073..bc4bccf23 100644 --- a/test/analysis/dataflow/typeinfer/test_inter_method.py +++ b/test/analysis/dataflow/typeinfer/test_inter_method.py @@ -58,5 +58,3 @@ def alloc(n_iter: int): return ilist.map(_new, ilist.range(n_iter)) assert alloc.return_type.is_subseteq(ilist.IListType[types.Literal(1), types.Any]) - - From 1c0d98bbfb06c86524b804ec10b24b2edd85c6e8 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 17:10:17 -0500 Subject: [PATCH 08/11] remove comment --- src/kirin/dialects/func/attrs.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/kirin/dialects/func/attrs.py b/src/kirin/dialects/func/attrs.py index 017ab2b1c..52d557104 100644 --- a/src/kirin/dialects/func/attrs.py +++ b/src/kirin/dialects/func/attrs.py @@ -12,10 +12,6 @@ from ._dialect import dialect -# TypeofMethodType = types.PyClass[Method] -# MethodType = types.Generic( -# Method, types.TypeVar("Params", types.Tuple), types.TypeVar("Ret") -# ) TypeLatticeElem = TypeVar("TypeLatticeElem", bound="types.TypeAttribute") From 984fc94544dcb8e431684cd06e0861b4f718ec11 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 18:02:20 -0500 Subject: [PATCH 09/11] fix pyright --- src/kirin/ir/attrs/types.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/kirin/ir/attrs/types.py b/src/kirin/ir/attrs/types.py index 69cb45396..51113ff46 100644 --- a/src/kirin/ir/attrs/types.py +++ b/src/kirin/ir/attrs/types.py @@ -716,6 +716,9 @@ def deserialize( return out +TypeArg: typing.TypeAlias = TypeAttribute | TypeVar + + @typing.final @dataclass(eq=False) class TypeofMethodType(TypeAttribute, metaclass=SingletonTypeMeta): @@ -748,7 +751,7 @@ def deserialize( def __getitem__( self, - typ: tuple[list[TypeAttribute], TypeAttribute] | tuple[list[TypeAttribute]], + typ: tuple[typing.Sequence[TypeArg], TypeArg] | tuple[typing.Sequence[TypeArg]], ) -> "FunctionType": if isinstance(typ, tuple) and len(typ) == 2: return FunctionType(tuple(typ[0]), typ[1]) @@ -794,10 +797,7 @@ def print_impl(self, printer: Printer) -> None: def __getitem__( self, - typ: ( - tuple[tuple[TypeAttribute, ...], TypeAttribute] - | tuple[tuple[TypeAttribute, ...]] - ), + typ: tuple[tuple[TypeArg, ...], TypeArg] | tuple[tuple[TypeArg, ...]], ) -> "FunctionType": if isinstance(typ, tuple) and len(typ) == 2: return self.where(typ[0], typ[1]) From aed5b074a877dbdef7afd419c04016684cf21504 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 18:03:55 -0500 Subject: [PATCH 10/11] remove print --- test/lowering/test_func.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/lowering/test_func.py b/test/lowering/test_func.py index af9d64333..317480df6 100644 --- a/test/lowering/test_func.py +++ b/test/lowering/test_func.py @@ -33,7 +33,6 @@ def recursive(n): return recursive(n - 1) code = lower.python_function(recursive) - code.print() assert isinstance(code, func.Function) assert len(code.body.blocks) == 3 assert isinstance(code.body.blocks[0].last_stmt, cf.ConditionalBranch) @@ -44,9 +43,6 @@ def recursive(n): assert stmt.callee.type.is_subseteq(types.MethodType) -test_recursive_func() - - def test_invalid_func_call(): def undefined(n): From 069d7431f9339dccd30d1c37ee3a0c51e964386b Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 19 Nov 2025 18:04:13 -0500 Subject: [PATCH 11/11] remove print --- test/lowering/test_func.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/lowering/test_func.py b/test/lowering/test_func.py index 317480df6..eaeddf5ac 100644 --- a/test/lowering/test_func.py +++ b/test/lowering/test_func.py @@ -39,7 +39,6 @@ def recursive(n): assert isinstance(code.body.blocks[2].stmts.at(2), func.Call) stmt: func.Call = code.body.blocks[2].stmts.at(2) # type: ignore assert isinstance(stmt.callee, ir.BlockArgument) - print(stmt.callee.type) assert stmt.callee.type.is_subseteq(types.MethodType)