diff --git a/docs/cookbook/foodlang/cf_rewrite.md b/docs/cookbook/foodlang/cf_rewrite.md index a8581c7a1..b0e75ce7f 100644 --- a/docs/cookbook/foodlang/cf_rewrite.md +++ b/docs/cookbook/foodlang/cf_rewrite.md @@ -132,7 +132,7 @@ def food(self): if fold: fold_pass(mt) - + if hungry: Walk(NewFoodAndNap()).rewrite(mt.code) diff --git a/example/food/script.py b/example/food/script.py index 54de42116..47a48c608 100644 --- a/example/food/script.py +++ b/example/food/script.py @@ -1,9 +1,9 @@ # type: ignore -from emit import EmitReceptMain from group import food from stmts import Eat, Nap, Cook, NewFood from recept import FeeAnalysis +from emit import EmitReceptMain from interp import FoodMethods as FoodMethods from lattice import AtLeastXItem from rewrite import NewFoodAndNap diff --git a/example/quantum/script.py b/example/quantum/script.py index 93f0d225c..3831d0d0e 100644 --- a/example/quantum/script.py +++ b/example/quantum/script.py @@ -4,6 +4,7 @@ from enum import Enum from typing import ClassVar from dataclasses import dataclass + from qulacs import QuantumState @@ -27,7 +28,7 @@ class Basis(Enum): # [section] from kirin import ir, types, lowering -from kirin.decl import statement, info +from kirin.decl import info, statement from kirin.prelude import basic # our language definitions and compiler begins @@ -161,8 +162,9 @@ def main(state: QuantumState): # we need to implement the runtime for the quantum circuit # let's just import qulacs a quantum circuit simulator +from qulacs import QuantumState, gate + from kirin import interp -from qulacs import gate, QuantumState @dialect.register diff --git a/src/kirin/analysis/typeinfer/solve.py b/src/kirin/analysis/typeinfer/solve.py index b08d576a1..7899d00e6 100644 --- a/src/kirin/analysis/typeinfer/solve.py +++ b/src/kirin/analysis/typeinfer/solve.py @@ -69,6 +69,13 @@ def substitute(self, typ: types.TypeAttribute) -> types.TypeAttribute: ) elif isinstance(typ, types.Union): return types.Union(self.substitute(t) for t in typ.types) + elif isinstance(typ, types.FunctionType): + return types.FunctionType( + params_type=tuple(self.substitute(t) for t in typ.params_type), + return_type=( + self.substitute(typ.return_type) if typ.return_type else None + ), + ) return typ def solve( @@ -94,6 +101,8 @@ def solve( return self.solve_Generic(annot, value) elif isinstance(annot, types.Union): return self.solve_Union(annot, value) + elif isinstance(annot, types.FunctionType): + return self.solve_FunctionType(annot, value) if annot.is_subseteq(value): return Ok @@ -133,6 +142,24 @@ def solve_Generic(self, annot: types.Generic, value: types.TypeAttribute): return result return Ok + def solve_FunctionType(self, annot: types.FunctionType, value: types.TypeAttribute): + if not isinstance(value, types.FunctionType): + return ResolutionError(annot, value) + + for var, val in zip(annot.params_type, value.params_type): + result = self.solve(var, val) + if not result: + return result + + if not annot.return_type or not value.return_type: + return Ok + + result = self.solve(annot.return_type, value.return_type) + if not result: + return result + + return Ok + def solve_Union(self, annot: types.Union, value: types.TypeAttribute): for typ in annot.types: result = self.solve(typ, value) diff --git a/src/kirin/dialects/func/__init__.py b/src/kirin/dialects/func/__init__.py index 0aeecc34f..913385f3f 100644 --- a/src/kirin/dialects/func/__init__.py +++ b/src/kirin/dialects/func/__init__.py @@ -5,7 +5,7 @@ constprop as constprop, typeinfer as typeinfer, ) -from kirin.dialects.func.attrs import Signature as Signature, MethodType as MethodType +from kirin.dialects.func.attrs import Signature as Signature from kirin.dialects.func.stmts import ( Call as Call, Invoke as Invoke, diff --git a/src/kirin/dialects/func/attrs.py b/src/kirin/dialects/func/attrs.py index 3d6ec46ae..52d557104 100644 --- a/src/kirin/dialects/func/attrs.py +++ b/src/kirin/dialects/func/attrs.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from kirin import types -from kirin.ir import Method, Attribute +from kirin.ir import Attribute from kirin.print.printer import Printer from kirin.serialization.core.serializationunit import SerializationUnit @@ -12,10 +12,6 @@ from ._dialect import dialect -TypeofMethodType = types.PyClass[Method] -MethodType = types.Generic( - Method, types.TypeVar("Params", types.Tuple), types.TypeVar("Ret") -) TypeLatticeElem = TypeVar("TypeLatticeElem", bound="types.TypeAttribute") diff --git a/src/kirin/dialects/func/stmts.py b/src/kirin/dialects/func/stmts.py index f026f21b2..309ea26ff 100644 --- a/src/kirin/dialects/func/stmts.py +++ b/src/kirin/dialects/func/stmts.py @@ -1,13 +1,13 @@ from __future__ import annotations -from types import MethodType as ClassMethodType, FunctionType +from types import MethodType as PyClassMethodType, FunctionType as PyFunctionType from typing import TypeVar from kirin import ir, types from kirin.decl import info, statement from kirin.print.printer import Printer -from .attrs import Signature, MethodType +from .attrs import Signature from ._dialect import dialect @@ -58,7 +58,7 @@ class Function(ir.Statement): """The signature of the function at declaration.""" body: ir.Region = info.region(multi=True) """The body of the function.""" - result: ir.ResultValue = info.result(MethodType) + result: ir.ResultValue = info.result(types.MethodType) """The result of the function.""" def print_impl(self, printer: Printer) -> None: @@ -115,7 +115,7 @@ class Lambda(ir.Statement): """The signature of the function at declaration.""" captured: tuple[ir.SSAValue, ...] = info.argument() body: ir.Region = info.region(multi=True) - result: ir.ResultValue = info.result(MethodType) + result: ir.ResultValue = info.result(types.MethodType) def check(self) -> None: assert self.body.blocks, "lambda body must not be empty" @@ -145,7 +145,7 @@ def print_impl(self, printer: Printer) -> None: class GetField(ir.Statement): name = "getfield" traits = frozenset({ir.Pure()}) - obj: ir.SSAValue = info.argument(MethodType) + obj: ir.SSAValue = info.argument(types.MethodType) field: int = info.attribute() # NOTE: mypy somehow doesn't understand default init=False result: ir.ResultValue = info.result(init=False) @@ -249,7 +249,7 @@ def print_impl(self, printer: Printer) -> None: def check_type(self) -> None: if not self.callee.type.is_subseteq(types.MethodType): - if self.callee.type.is_subseteq(types.PyClass(FunctionType)): + if self.callee.type.is_subseteq(types.PyClass(PyFunctionType)): raise ir.TypeCheckError( self, f"callee must be a method type, got {self.callee.type}", @@ -257,7 +257,7 @@ def check_type(self) -> None: "consider decorating it with kernel decorator", ) - if self.callee.type.is_subseteq(types.PyClass(ClassMethodType)): + if self.callee.type.is_subseteq(types.PyClass(PyClassMethodType)): raise ir.TypeCheckError( self, "callee must be a method type, got class method", diff --git a/src/kirin/dialects/func/typeinfer.py b/src/kirin/dialects/func/typeinfer.py index eba287567..862fdcf45 100644 --- a/src/kirin/dialects/func/typeinfer.py +++ b/src/kirin/dialects/func/typeinfer.py @@ -3,7 +3,7 @@ from kirin import ir, types from kirin.interp import Frame, MethodTable, ReturnValue, impl from kirin.analysis import const -from kirin.analysis.typeinfer import TypeInference, TypeResolution +from kirin.analysis.typeinfer import TypeInference from kirin.dialects.func.stmts import ( Call, Invoke, @@ -54,20 +54,11 @@ def call(self, interp_: TypeInference, frame: Frame, stmt: Call): def _solve_method_type(self, interp: TypeInference, frame: Frame, stmt: Call): mt_inferred = frame.get(stmt.callee) - if not isinstance(mt_inferred, types.Generic): - return (types.Bottom,) - if len(mt_inferred.vars) != 2: - return (types.Bottom,) - args = mt_inferred.vars[0] - result = mt_inferred.vars[1] - if not args.is_subseteq(types.Tuple): + if not isinstance(mt_inferred, types.FunctionType): return (types.Bottom,) - resolve = TypeResolution() - # NOTE: we are not using [...] below to be compatible with 3.10 - resolve.solve(args, types.Tuple.where(frame.get_values(stmt.inputs))) - return (resolve.substitute(result),) + return (mt_inferred.return_type,) @impl(Invoke) def invoke(self, interp_: TypeInference, frame: Frame, stmt: Invoke): diff --git a/src/kirin/dialects/ilist/stmts.py b/src/kirin/dialects/ilist/stmts.py index b6d078210..39e523b7f 100644 --- a/src/kirin/dialects/ilist/stmts.py +++ b/src/kirin/dialects/ilist/stmts.py @@ -75,9 +75,7 @@ class Map(ir.Statement): class Foldr(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) - fn: ir.SSAValue = info.argument( - types.Generic(ir.Method, [ElemT, OutElemT], OutElemT) - ) + fn: ir.SSAValue = info.argument(types.MethodType[[ElemT, OutElemT], OutElemT]) collection: ir.SSAValue = info.argument(IListType[ElemT]) init: ir.SSAValue = info.argument(OutElemT) result: ir.ResultValue = info.result(OutElemT) @@ -87,9 +85,8 @@ class Foldr(ir.Statement): class Foldl(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) - fn: ir.SSAValue = info.argument( - types.Generic(ir.Method, [OutElemT, ElemT], OutElemT) - ) + fn: ir.SSAValue = info.argument(types.MethodType[[OutElemT, ElemT], OutElemT]) + collection: ir.SSAValue = info.argument(IListType[ElemT]) init: ir.SSAValue = info.argument(OutElemT) result: ir.ResultValue = info.result(OutElemT) @@ -104,7 +101,7 @@ class Scan(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) fn: ir.SSAValue = info.argument( - types.Generic(ir.Method, [OutElemT, ElemT], types.Tuple[OutElemT, ResultT]) + types.MethodType[[OutElemT, ElemT], types.Tuple[OutElemT, ResultT]] ) collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen]) init: ir.SSAValue = info.argument(OutElemT) @@ -117,7 +114,7 @@ class Scan(ir.Statement): class ForEach(ir.Statement): traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()}) purity: bool = info.attribute(default=False) - fn: ir.SSAValue = info.argument(types.Generic(ir.Method, [ElemT], types.NoneType)) + fn: ir.SSAValue = info.argument(types.MethodType[[ElemT], types.NoneType]) collection: ir.SSAValue = info.argument(IListType[ElemT]) @@ -141,7 +138,7 @@ class Sorted(ir.Statement): purity: bool = info.attribute(default=False) collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen]) key: ir.SSAValue = info.argument( - types.Union((types.Generic(ir.Method, [ElemT], ElemT), types.NoneType)) + types.Union((types.MethodType[[ElemT], ElemT], types.NoneType)) ) reverse: ir.SSAValue = info.argument(types.Bool) result: ir.ResultValue = info.result(IListType[ElemT, ListLen]) diff --git a/src/kirin/dialects/lowering/func.py b/src/kirin/dialects/lowering/func.py index 6b98efecb..19b57fc49 100644 --- a/src/kirin/dialects/lowering/func.py +++ b/src/kirin/dialects/lowering/func.py @@ -35,9 +35,7 @@ def lower_FunctionDef( entries: dict[str, ir.SSAValue] = {} entr_block = ir.Block() fn_self = entr_block.args.append_from( - types.Generic( - ir.Method, types.Tuple.where(signature.inputs), signature.output - ), + types.MethodType[list(signature.inputs), signature.output], node.name + "_self", ) entries[node.name] = fn_self diff --git a/src/kirin/ir/attrs/_types.pyi b/src/kirin/ir/attrs/_types.pyi index aebf867ba..ebde8e27f 100644 --- a/src/kirin/ir/attrs/_types.pyi +++ b/src/kirin/ir/attrs/_types.pyi @@ -1,7 +1,16 @@ from dataclasses import dataclass from .abc import Attribute -from .types import Union, Generic, Literal, PyClass, TypeVar, TypeAttribute +from .types import ( + Union, + Generic, + Literal, + PyClass, + TypeVar, + FunctionType, + TypeAttribute, + TypeofMethodType, +) @dataclass class _TypeAttribute(Attribute): @@ -11,3 +20,5 @@ class _TypeAttribute(Attribute): def is_subseteq_PyClass(self, other: PyClass) -> bool: ... def is_subseteq_Generic(self, other: Generic) -> bool: ... def is_subseteq_fallback(self, other: TypeAttribute) -> bool: ... + def is_subseteq_TypeofMethodType(self, other: TypeofMethodType) -> bool: ... + def is_subseteq_FunctionType(self, other: FunctionType) -> bool: ... diff --git a/src/kirin/ir/attrs/types.py b/src/kirin/ir/attrs/types.py index 5dfcbe8ee..51113ff46 100644 --- a/src/kirin/ir/attrs/types.py +++ b/src/kirin/ir/attrs/types.py @@ -716,6 +716,161 @@ def deserialize( return out +TypeArg: typing.TypeAlias = TypeAttribute | TypeVar + + +@typing.final +@dataclass(eq=False) +class TypeofMethodType(TypeAttribute, metaclass=SingletonTypeMeta): + name = "TypeofMethodType" + + def __hash__(self) -> int: + return hash((TypeofMethodType,)) + + def is_structurally_equal( + self, other: Attribute, context: dict | None = None + ) -> bool: + return isinstance(other, TypeofMethodType) + + def is_subseteq_TypeofMethodType(self, other: "TypeofMethodType") -> bool: + return True + + def serialize(self, serializer: "Serializer") -> "SerializationUnit": + return SerializationUnit( + kind="type-attribute", + module_name=self.__module__, + class_name=self.__class__.__name__, + data=dict(), + ) + + @classmethod + def deserialize( + cls, serUnit: "SerializationUnit", deserializer: "Deserializer" + ) -> "TypeofMethodType": + return TypeofMethodType() + + def __getitem__( + self, + typ: tuple[typing.Sequence[TypeArg], TypeArg] | tuple[typing.Sequence[TypeArg]], + ) -> "FunctionType": + if isinstance(typ, tuple) and len(typ) == 2: + return FunctionType(tuple(typ[0]), typ[1]) + elif isinstance(typ, tuple) and len(typ) == 1: + return FunctionType(tuple(typ[0])) + else: + raise TypeError("Invalid type arguments for TypeofMethodType") + + +@typing.final +@dataclass(eq=False) +class FunctionType(TypeAttribute): + name = "MethodType" + params_type: tuple[TypeAttribute, ...] + return_type: TypeAttribute | None = None + + def __init__( + self, + params_type: tuple[TypeAttribute, ...], + return_type: TypeAttribute | None = None, + ): + self.params_type = params_type + self.return_type = return_type + + def __hash__(self) -> int: + return hash((FunctionType, self.params_type, self.return_type)) + + def __repr__(self) -> str: + params = ", ".join(map(repr, self.params_type)) + if self.return_type is not None: + return f"({params}) -> {repr(self.return_type)}" + else: + return f"({params}) -> None" + + def print_impl(self, printer: Printer) -> None: + printer.plain_print("(") + printer.print_seq(self.params_type, delim=", ") + printer.plain_print(") -> ") + if self.return_type is not None: + printer.print(self.return_type) + else: + printer.plain_print("None") + + def __getitem__( + self, + typ: tuple[tuple[TypeArg, ...], TypeArg] | tuple[tuple[TypeArg, ...]], + ) -> "FunctionType": + if isinstance(typ, tuple) and len(typ) == 2: + return self.where(typ[0], typ[1]) + elif isinstance(typ, tuple) and len(typ) == 1: + return self.where(typ[0]) + else: + raise TypeError("Invalid type arguments for MethodType") + + def where( + self, typ: tuple[TypeAttribute, ...], return_type: TypeAttribute | None = None + ) -> "FunctionType": + if len(typ) != len(self.params_type): + raise TypeError("Number of type arguments does not match") + if all(v.is_subseteq(bound) for v, bound in zip(typ, self.params_type)): + if return_type is None: + return FunctionType(typ, self.return_type) + elif self.return_type is not None and return_type.is_subseteq( + self.return_type + ): + return FunctionType(typ, return_type) + raise TypeError("Type arguments do not match") + + def is_structurally_equal( + self, other: Attribute, context: dict | None = None + ) -> bool: + return ( + isinstance(other, FunctionType) + and self.params_type == other.params_type + and self.return_type == other.return_type + ) + + def is_subseteq_FunctionType(self, other: "FunctionType") -> bool: + if len(self.params_type) != len(other.params_type): + return False + for s_param, o_param in zip(self.params_type, other.params_type): + if not s_param.is_subseteq(o_param): + return False + if self.return_type is None: + return True + elif other.return_type is None: + return False + else: + return self.return_type.is_subseteq(other.return_type) + + def is_subseteq_TypeofMethodType(self, other: "TypeofMethodType") -> bool: + return True + + def is_subseteq_Union(self, other: Union) -> bool: + return any(self.is_subseteq(t) for t in other.types) + + def is_subseteq_fallback(self, other: TypeAttribute) -> bool: + return False + + def serialize(self, serializer: "Serializer") -> "SerializationUnit": + return SerializationUnit( + kind="type-attribute", + module_name=self.__module__, + class_name=self.__class__.__name__, + data={ + "params_type": serializer.serialize_tuple(self.params_type), + "return_type": serializer.serialize(self.return_type), + }, + ) + + @classmethod + def deserialize( + cls, serUnit: "SerializationUnit", deserializer: "Deserializer" + ) -> "FunctionType": + params_type = deserializer.deserialize_tuple(serUnit.data["params_type"]) + return_type = deserializer.deserialize(serUnit.data["return_type"]) + return FunctionType(params_type, return_type) + + def _typeparams_list2tuple(args: tuple[TypeVarValue, ...]) -> tuple[TypeOrVararg, ...]: "provides the syntax sugar [A, B, C] type Generic(tuple, A, B, C)" return tuple(Generic(tuple, *arg) if isinstance(arg, list) else arg for arg in args) @@ -771,7 +926,6 @@ def hint2type(hint) -> TypeAttribute: if origin is None: # non-generic return PyClass(hint) - body = PyClass(origin) args = typing.get_args(hint) params = [] for arg in args: @@ -779,4 +933,12 @@ def hint2type(hint) -> TypeAttribute: params.append([hint2type(elem) for elem in arg]) else: params.append(hint2type(arg)) - return Generic(body, *params) + + if origin.__name__ == "Method": + assert ( + len(params) == 2 + ), "method type hint should be ir.Method[[params], return_type]" + return FunctionType(tuple(params[0]), params[1]) + else: + body = PyClass(origin) + return Generic(body, *params) diff --git a/src/kirin/ir/method.py b/src/kirin/ir/method.py index cd9da5d6a..e44bb41ee 100644 --- a/src/kirin/ir/method.py +++ b/src/kirin/ir/method.py @@ -25,7 +25,7 @@ ) from .exception import ValidationError from .nodes.stmt import Statement -from .attrs.types import Generic +from .attrs.types import FunctionType if typing.TYPE_CHECKING: from kirin.ir.group import DialectGroup @@ -141,7 +141,7 @@ def self_type(self): """Return the type of the self argument of the method.""" trait = self.code.get_present_trait(HasSignature) signature = trait.get_signature(self.code) - return Generic(Method, Generic(tuple, *signature.inputs), signature.output) + return FunctionType(params_type=signature.inputs, return_type=signature.output) @property def callable_region(self): diff --git a/src/kirin/types.py b/src/kirin/types.py index 2f57e173d..12ee7d1d7 100644 --- a/src/kirin/types.py +++ b/src/kirin/types.py @@ -2,7 +2,6 @@ import numbers -from kirin.ir.method import Method from kirin.ir.attrs.types import ( Union as Union, Vararg as Vararg, @@ -12,7 +11,9 @@ PyClass as PyClass, TypeVar as TypeVar, BottomType as BottomType, + FunctionType as FunctionType, TypeAttribute as TypeAttribute, + TypeofMethodType as TypeofMethodType, hint2type as hint2type, is_tuple_of as is_tuple_of, ) @@ -32,6 +33,4 @@ Dict = Generic(dict, TypeVar("K"), TypeVar("V")) Set = Generic(set, TypeVar("T")) FrozenSet = Generic(frozenset, TypeVar("T")) -TypeofFunctionType = Generic[type(lambda: None)] -FunctionType = Generic(type(lambda: None), Tuple, Vararg(Any)) -MethodType = Generic(Method, TypeVar("Params", Tuple), TypeVar("Ret")) +MethodType = TypeofMethodType() diff --git a/test/analysis/dataflow/typeinfer/test_infer_lambda.py b/test/analysis/dataflow/typeinfer/test_infer_lambda.py new file mode 100644 index 000000000..3a12fa076 --- /dev/null +++ b/test/analysis/dataflow/typeinfer/test_infer_lambda.py @@ -0,0 +1,37 @@ +from kirin import ir, types +from kirin.prelude import structural +from kirin.dialects import ilist + + +def test_infer_lambda(): + @structural(typeinfer=True, fold=False, no_raise=False) + def main(n): + def map_func(i): + return n + 1 + + return ilist.map(map_func, ilist.range(4)) + + map_stmt = main.callable_region.blocks[0].stmts.at(-2) + assert isinstance(map_stmt, ilist.Map) + assert map_stmt.result.type == ilist.IListType[types.Int, types.Literal(4)] + + +def test_infer_method_type_hint_call(): + + @structural(typeinfer=True, fold=False, no_raise=False) + def main(n, fx: ir.Method[[int], int]): + return fx(n) + + assert main.return_type == types.Int + + +def test_infer_method_type_hint(): + + @structural(typeinfer=True, fold=False, no_raise=False) + def main(n, fx: ir.Method[[int], int]): + def map_func(i): + return n + 1 + fx(i) + + return ilist.map(map_func, ilist.range(4)) + + assert main.return_type == ilist.IListType[types.Int, types.Literal(4)] diff --git a/test/dialects/py/__init__.py b/test/dialects/py_dialect/__init__.py similarity index 100% rename from test/dialects/py/__init__.py rename to test/dialects/py_dialect/__init__.py diff --git a/test/dialects/py/test_assign.py b/test/dialects/py_dialect/test_assign.py similarity index 100% rename from test/dialects/py/test_assign.py rename to test/dialects/py_dialect/test_assign.py diff --git a/test/dialects/py/test_iter.py b/test/dialects/py_dialect/test_iter.py similarity index 100% rename from test/dialects/py/test_iter.py rename to test/dialects/py_dialect/test_iter.py diff --git a/test/dialects/py/test_tuple_infer.py b/test/dialects/py_dialect/test_tuple_infer.py similarity index 100% rename from test/dialects/py/test_tuple_infer.py rename to test/dialects/py_dialect/test_tuple_infer.py diff --git a/test/dialects/test_pytypes.py b/test/dialects/test_pytypes.py index 61dfd61ba..ccc9280ff 100644 --- a/test/dialects/test_pytypes.py +++ b/test/dialects/test_pytypes.py @@ -16,6 +16,7 @@ TypeVar, NoneType, BottomType, + MethodType, TypeAttribute, ) @@ -112,3 +113,20 @@ def test_generic_topbottom(): assert t.meet(TypeAttribute.bottom()).is_subseteq(TypeAttribute.bottom()) assert t.join(TypeAttribute.top()).is_structurally_equal(TypeAttribute.top()) assert t.meet(TypeAttribute.top()).is_structurally_equal(t) + + +def test_method_type(): + t1 = MethodType[[Int, Float], Bool] + t2 = MethodType[[Int, Float], Bool] + + assert t1.is_subseteq(t2) + + t3 = MethodType[[Int, Float], AnyType()] + assert t1.is_subseteq(t3) + + t4 = MethodType[[Int, Float], String] + assert not t1.is_subseteq(t4) + + Var = TypeVar("Var") + t5 = MethodType[[Int, Var], Bool] + assert t1.is_subseteq(t5) diff --git a/test/ir/test_isequal.py b/test/ir/test_isequal.py index 5e552b9f7..e0d6cac01 100644 --- a/test/ir/test_isequal.py +++ b/test/ir/test_isequal.py @@ -4,7 +4,7 @@ def test_is_structurally_equal_ignoring_hint(): block = ir.Block() - block.args.append_from(types.PyClass(ir.Method), "self") + block.args.append_from(types.MethodType, "self") source_func = func.Function( sym_name="main", signature=func.Signature( @@ -15,7 +15,7 @@ def test_is_structurally_equal_ignoring_hint(): ) block = ir.Block() - block.args.append_from(types.PyClass(ir.Method), "self") + block.args.append_from(types.MethodType, "self") expected_func = func.Function( sym_name="main", signature=func.Signature( diff --git a/test/lowering/test_func.py b/test/lowering/test_func.py index 678c1b104..eaeddf5ac 100644 --- a/test/lowering/test_func.py +++ b/test/lowering/test_func.py @@ -39,7 +39,7 @@ def recursive(n): assert isinstance(code.body.blocks[2].stmts.at(2), func.Call) stmt: func.Call = code.body.blocks[2].stmts.at(2) # type: ignore assert isinstance(stmt.callee, ir.BlockArgument) - assert stmt.callee.type.is_subseteq(func.MethodType) + assert stmt.callee.type.is_subseteq(types.MethodType) def test_invalid_func_call(): diff --git a/test/lowering/test_method_hint.py b/test/lowering/test_method_hint.py index caefcef5d..48ec31c9b 100644 --- a/test/lowering/test_method_hint.py +++ b/test/lowering/test_method_hint.py @@ -11,6 +11,4 @@ def test(x: int, y: int) -> float: return test - assert main.return_type == types.Generic( - ir.Method, [types.Int, types.Int], types.Float - ) + assert main.return_type == types.MethodType[[types.Int, types.Int], types.Float]