Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/cookbook/foodlang/cf_rewrite.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def food(self):

if fold:
fold_pass(mt)

if hungry:
Walk(NewFoodAndNap()).rewrite(mt.code)

Expand Down
2 changes: 1 addition & 1 deletion example/food/script.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# type: ignore
from emit import EmitReceptMain
from group import food
from stmts import Eat, Nap, Cook, NewFood
from recept import FeeAnalysis

from emit import EmitReceptMain
from interp import FoodMethods as FoodMethods
from lattice import AtLeastXItem
from rewrite import NewFoodAndNap
Expand Down
6 changes: 4 additions & 2 deletions example/quantum/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from typing import ClassVar
from dataclasses import dataclass

from qulacs import QuantumState


Expand All @@ -27,7 +28,7 @@ class Basis(Enum):

# [section]
from kirin import ir, types, lowering
from kirin.decl import statement, info
from kirin.decl import info, statement
from kirin.prelude import basic

# our language definitions and compiler begins
Expand Down Expand Up @@ -161,8 +162,9 @@ def main(state: QuantumState):
# we need to implement the runtime for the quantum circuit
# let's just import qulacs a quantum circuit simulator

from qulacs import QuantumState, gate

from kirin import interp
from qulacs import gate, QuantumState


@dialect.register
Expand Down
27 changes: 27 additions & 0 deletions src/kirin/analysis/typeinfer/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def substitute(self, typ: types.TypeAttribute) -> types.TypeAttribute:
)
elif isinstance(typ, types.Union):
return types.Union(self.substitute(t) for t in typ.types)
elif isinstance(typ, types.FunctionType):
return types.FunctionType(
params_type=tuple(self.substitute(t) for t in typ.params_type),
return_type=(
self.substitute(typ.return_type) if typ.return_type else None
),
)
return typ

def solve(
Expand All @@ -94,6 +101,8 @@ def solve(
return self.solve_Generic(annot, value)
elif isinstance(annot, types.Union):
return self.solve_Union(annot, value)
elif isinstance(annot, types.FunctionType):
return self.solve_FunctionType(annot, value)

if annot.is_subseteq(value):
return Ok
Expand Down Expand Up @@ -133,6 +142,24 @@ def solve_Generic(self, annot: types.Generic, value: types.TypeAttribute):
return result
return Ok

def solve_FunctionType(self, annot: types.FunctionType, value: types.TypeAttribute):
if not isinstance(value, types.FunctionType):
return ResolutionError(annot, value)

for var, val in zip(annot.params_type, value.params_type):
result = self.solve(var, val)
if not result:
return result

if not annot.return_type or not value.return_type:
return Ok

result = self.solve(annot.return_type, value.return_type)
if not result:
return result

return Ok

def solve_Union(self, annot: types.Union, value: types.TypeAttribute):
for typ in annot.types:
result = self.solve(typ, value)
Expand Down
2 changes: 1 addition & 1 deletion src/kirin/dialects/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
constprop as constprop,
typeinfer as typeinfer,
)
from kirin.dialects.func.attrs import Signature as Signature, MethodType as MethodType
from kirin.dialects.func.attrs import Signature as Signature
from kirin.dialects.func.stmts import (
Call as Call,
Invoke as Invoke,
Expand Down
6 changes: 1 addition & 5 deletions src/kirin/dialects/func/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass

from kirin import types
from kirin.ir import Method, Attribute
from kirin.ir import Attribute
from kirin.print.printer import Printer
from kirin.serialization.core.serializationunit import SerializationUnit

Expand All @@ -12,10 +12,6 @@

from ._dialect import dialect

TypeofMethodType = types.PyClass[Method]
MethodType = types.Generic(
Method, types.TypeVar("Params", types.Tuple), types.TypeVar("Ret")
)
TypeLatticeElem = TypeVar("TypeLatticeElem", bound="types.TypeAttribute")


Expand Down
14 changes: 7 additions & 7 deletions src/kirin/dialects/func/stmts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from types import MethodType as ClassMethodType, FunctionType
from types import MethodType as PyClassMethodType, FunctionType as PyFunctionType
from typing import TypeVar

from kirin import ir, types
from kirin.decl import info, statement
from kirin.print.printer import Printer

from .attrs import Signature, MethodType
from .attrs import Signature
from ._dialect import dialect


Expand Down Expand Up @@ -58,7 +58,7 @@ class Function(ir.Statement):
"""The signature of the function at declaration."""
body: ir.Region = info.region(multi=True)
"""The body of the function."""
result: ir.ResultValue = info.result(MethodType)
result: ir.ResultValue = info.result(types.MethodType)
"""The result of the function."""

def print_impl(self, printer: Printer) -> None:
Expand Down Expand Up @@ -115,7 +115,7 @@ class Lambda(ir.Statement):
"""The signature of the function at declaration."""
captured: tuple[ir.SSAValue, ...] = info.argument()
body: ir.Region = info.region(multi=True)
result: ir.ResultValue = info.result(MethodType)
result: ir.ResultValue = info.result(types.MethodType)

def check(self) -> None:
assert self.body.blocks, "lambda body must not be empty"
Expand Down Expand Up @@ -145,7 +145,7 @@ def print_impl(self, printer: Printer) -> None:
class GetField(ir.Statement):
name = "getfield"
traits = frozenset({ir.Pure()})
obj: ir.SSAValue = info.argument(MethodType)
obj: ir.SSAValue = info.argument(types.MethodType)
field: int = info.attribute()
# NOTE: mypy somehow doesn't understand default init=False
result: ir.ResultValue = info.result(init=False)
Expand Down Expand Up @@ -249,15 +249,15 @@ def print_impl(self, printer: Printer) -> None:

def check_type(self) -> None:
if not self.callee.type.is_subseteq(types.MethodType):
if self.callee.type.is_subseteq(types.PyClass(FunctionType)):
if self.callee.type.is_subseteq(types.PyClass(PyFunctionType)):
raise ir.TypeCheckError(
self,
f"callee must be a method type, got {self.callee.type}",
help="did you call a Python function directly? "
"consider decorating it with kernel decorator",
)

if self.callee.type.is_subseteq(types.PyClass(ClassMethodType)):
if self.callee.type.is_subseteq(types.PyClass(PyClassMethodType)):
raise ir.TypeCheckError(
self,
"callee must be a method type, got class method",
Expand Down
15 changes: 3 additions & 12 deletions src/kirin/dialects/func/typeinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kirin import ir, types
from kirin.interp import Frame, MethodTable, ReturnValue, impl
from kirin.analysis import const
from kirin.analysis.typeinfer import TypeInference, TypeResolution
from kirin.analysis.typeinfer import TypeInference
from kirin.dialects.func.stmts import (
Call,
Invoke,
Expand Down Expand Up @@ -54,20 +54,11 @@ def call(self, interp_: TypeInference, frame: Frame, stmt: Call):

def _solve_method_type(self, interp: TypeInference, frame: Frame, stmt: Call):
mt_inferred = frame.get(stmt.callee)
if not isinstance(mt_inferred, types.Generic):
return (types.Bottom,)

if len(mt_inferred.vars) != 2:
return (types.Bottom,)
args = mt_inferred.vars[0]
result = mt_inferred.vars[1]
if not args.is_subseteq(types.Tuple):
if not isinstance(mt_inferred, types.FunctionType):
return (types.Bottom,)

resolve = TypeResolution()
# NOTE: we are not using [...] below to be compatible with 3.10
resolve.solve(args, types.Tuple.where(frame.get_values(stmt.inputs)))
return (resolve.substitute(result),)
return (mt_inferred.return_type,)

@impl(Invoke)
def invoke(self, interp_: TypeInference, frame: Frame, stmt: Invoke):
Expand Down
15 changes: 6 additions & 9 deletions src/kirin/dialects/ilist/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ class Map(ir.Statement):
class Foldr(ir.Statement):
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
purity: bool = info.attribute(default=False)
fn: ir.SSAValue = info.argument(
types.Generic(ir.Method, [ElemT, OutElemT], OutElemT)
)
fn: ir.SSAValue = info.argument(types.MethodType[[ElemT, OutElemT], OutElemT])
collection: ir.SSAValue = info.argument(IListType[ElemT])
init: ir.SSAValue = info.argument(OutElemT)
result: ir.ResultValue = info.result(OutElemT)
Expand All @@ -87,9 +85,8 @@ class Foldr(ir.Statement):
class Foldl(ir.Statement):
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
purity: bool = info.attribute(default=False)
fn: ir.SSAValue = info.argument(
types.Generic(ir.Method, [OutElemT, ElemT], OutElemT)
)
fn: ir.SSAValue = info.argument(types.MethodType[[OutElemT, ElemT], OutElemT])

collection: ir.SSAValue = info.argument(IListType[ElemT])
init: ir.SSAValue = info.argument(OutElemT)
result: ir.ResultValue = info.result(OutElemT)
Expand All @@ -104,7 +101,7 @@ class Scan(ir.Statement):
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
purity: bool = info.attribute(default=False)
fn: ir.SSAValue = info.argument(
types.Generic(ir.Method, [OutElemT, ElemT], types.Tuple[OutElemT, ResultT])
types.MethodType[[OutElemT, ElemT], types.Tuple[OutElemT, ResultT]]
)
collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
init: ir.SSAValue = info.argument(OutElemT)
Expand All @@ -117,7 +114,7 @@ class Scan(ir.Statement):
class ForEach(ir.Statement):
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
purity: bool = info.attribute(default=False)
fn: ir.SSAValue = info.argument(types.Generic(ir.Method, [ElemT], types.NoneType))
fn: ir.SSAValue = info.argument(types.MethodType[[ElemT], types.NoneType])
collection: ir.SSAValue = info.argument(IListType[ElemT])


Expand All @@ -141,7 +138,7 @@ class Sorted(ir.Statement):
purity: bool = info.attribute(default=False)
collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
key: ir.SSAValue = info.argument(
types.Union((types.Generic(ir.Method, [ElemT], ElemT), types.NoneType))
types.Union((types.MethodType[[ElemT], ElemT], types.NoneType))
)
reverse: ir.SSAValue = info.argument(types.Bool)
result: ir.ResultValue = info.result(IListType[ElemT, ListLen])
4 changes: 1 addition & 3 deletions src/kirin/dialects/lowering/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def lower_FunctionDef(
entries: dict[str, ir.SSAValue] = {}
entr_block = ir.Block()
fn_self = entr_block.args.append_from(
types.Generic(
ir.Method, types.Tuple.where(signature.inputs), signature.output
),
types.MethodType[list(signature.inputs), signature.output],
node.name + "_self",
)
entries[node.name] = fn_self
Expand Down
13 changes: 12 additions & 1 deletion src/kirin/ir/attrs/_types.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from dataclasses import dataclass

from .abc import Attribute
from .types import Union, Generic, Literal, PyClass, TypeVar, TypeAttribute
from .types import (
Union,
Generic,
Literal,
PyClass,
TypeVar,
FunctionType,
TypeAttribute,
TypeofMethodType,
)

@dataclass
class _TypeAttribute(Attribute):
Expand All @@ -11,3 +20,5 @@ class _TypeAttribute(Attribute):
def is_subseteq_PyClass(self, other: PyClass) -> bool: ...
def is_subseteq_Generic(self, other: Generic) -> bool: ...
def is_subseteq_fallback(self, other: TypeAttribute) -> bool: ...
def is_subseteq_TypeofMethodType(self, other: TypeofMethodType) -> bool: ...
def is_subseteq_FunctionType(self, other: FunctionType) -> bool: ...
Loading