diff --git a/src/bloqade/geometry/dialects/grid/__init__.py b/src/bloqade/geometry/dialects/grid/__init__.py index a86ac1b..498952e 100644 --- a/src/bloqade/geometry/dialects/grid/__init__.py +++ b/src/bloqade/geometry/dialects/grid/__init__.py @@ -1,4 +1,5 @@ from ._dialect import dialect as dialect +from ._typeinfer import TypeInferMethods as TypeInferMethods from .concrete import GridInterpreter as GridInterpreter from .stmts import ( FromPositions as FromPositions, diff --git a/src/bloqade/geometry/dialects/grid/_typeinfer.py b/src/bloqade/geometry/dialects/grid/_typeinfer.py new file mode 100644 index 0000000..111aac2 --- /dev/null +++ b/src/bloqade/geometry/dialects/grid/_typeinfer.py @@ -0,0 +1,29 @@ +from typing import cast + +from kirin import types +from kirin.analysis import TypeInference +from kirin.dialects import ilist +from kirin.interp import Frame, MethodTable, impl + +from ._dialect import dialect +from .stmts import New +from .types import GridType + + +@dialect.register(key="typeinfer") +class TypeInferMethods(MethodTable): + + def get_len(self, typ: types.TypeAttribute): + if typ.is_subseteq(ilist.IListType[types.Int, types.Any]): + typ = cast(types.Generic, typ) + if isinstance(typ.vars[1], types.Literal): + return types.Literal(typ.vars[1].data + 1) + + return types.Any + + @impl(New) + def inter_new(self, _: TypeInference, frame: Frame[types.TypeAttribute], node: New): + x_len = self.get_len(frame.get(node.x_spacing)) + y_len = self.get_len(frame.get(node.y_spacing)) + + return (GridType[x_len, y_len],) diff --git a/src/bloqade/geometry/dialects/grid/stmts.py b/src/bloqade/geometry/dialects/grid/stmts.py index b999688..6778364 100644 --- a/src/bloqade/geometry/dialects/grid/stmts.py +++ b/src/bloqade/geometry/dialects/grid/stmts.py @@ -26,11 +26,17 @@ class New(ir.Statement): name = "new" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x_spacing: ir.SSAValue = info.argument(type=ilist.IListType[types.Float, types.Any]) - y_spacing: ir.SSAValue = info.argument(type=ilist.IListType[types.Float, types.Any]) + x_spacing: ir.SSAValue = info.argument( + type=ilist.IListType[types.Float, types.TypeVar("NumXStep")] + ) + y_spacing: ir.SSAValue = info.argument( + type=ilist.IListType[types.Float, types.TypeVar("NumYStep")] + ) x_init: ir.SSAValue = info.argument(types.Float) y_init: ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(GridType[types.Any, types.Any]) + result: ir.ResultValue = info.result( + GridType[types.TypeVar("NumX"), types.TypeVar("NumY")] + ) # Maybe do this with hints? diff --git a/src/bloqade/geometry/prelude.py b/src/bloqade/geometry/prelude.py new file mode 100644 index 0000000..e7e5eb0 --- /dev/null +++ b/src/bloqade/geometry/prelude.py @@ -0,0 +1,44 @@ +from kirin import ir +from kirin.ir.method import Method +from kirin.passes.default import Default +from kirin.prelude import structural +from typing_extensions import Annotated, Doc + +from bloqade.geometry.dialects import grid + + +@ir.dialect_group(structural.add(grid)) +def geometry( + self, +): + """Structural kernel with optimization passes.""" + + def run_pass( + mt: Annotated[Method, Doc("The method to run pass on.")], + *, + verify: Annotated[ + bool, Doc("run `verify` before running passes, default is `True`") + ] = True, + typeinfer: Annotated[ + bool, + Doc( + "run type inference and apply the inferred type to IR, default `False`" + ), + ] = False, + fold: Annotated[bool, Doc("run folding passes")] = True, + aggressive: Annotated[ + bool, Doc("run aggressive folding passes if `fold=True`") + ] = False, + no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True, + ) -> None: + default_pass = Default( + self, + verify=verify, + fold=fold, + aggressive=aggressive, + typeinfer=typeinfer, + no_raise=no_raise, + ) + default_pass.fixpoint(mt) + + return run_pass diff --git a/src/bloqade/geometry/rewrite/__init__.py b/src/bloqade/geometry/rewrite/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/bloqade/geometry/rewrite/desugar.py b/src/bloqade/geometry/rewrite/desugar.py new file mode 100644 index 0000000..e69de29 diff --git a/test/grid/test_typeinfer.py b/test/grid/test_typeinfer.py new file mode 100644 index 0000000..770ab82 --- /dev/null +++ b/test/grid/test_typeinfer.py @@ -0,0 +1,17 @@ +from kirin import types + +from bloqade.geometry.dialects import grid +from bloqade.geometry.prelude import geometry + + +def test_typeinfer(): + + @geometry + def test_method(): + return grid.New([1, 2], [1, 2], 0, 0) + + test_method.return_type.is_equal(grid.GridType[types.Literal(3), types.Literal(3)]) + + +if __name__ == "__main__": + test_typeinfer()