From 07cfb5f9851d2ca20fbe6e1f4b4eae954e4d2280 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Fri, 9 May 2025 14:50:43 -0400 Subject: [PATCH 1/2] adding type infer for --- .../geometry/dialects/grid/__init__.py | 1 + .../geometry/dialects/grid/_typeinfer.py | 29 ++++++++++++ src/bloqade/geometry/dialects/grid/stmts.py | 12 +++-- src/bloqade/geometry/prelude.py | 44 +++++++++++++++++++ src/bloqade/geometry/rewrite/__init__.py | 0 src/bloqade/geometry/rewrite/desugar.py | 0 test/grid/test_typeinfer.py | 17 +++++++ 7 files changed, 100 insertions(+), 3 deletions(-) create mode 100644 src/bloqade/geometry/dialects/grid/_typeinfer.py create mode 100644 src/bloqade/geometry/prelude.py create mode 100644 src/bloqade/geometry/rewrite/__init__.py create mode 100644 src/bloqade/geometry/rewrite/desugar.py create mode 100644 test/grid/test_typeinfer.py diff --git a/src/bloqade/geometry/dialects/grid/__init__.py b/src/bloqade/geometry/dialects/grid/__init__.py index a86ac1b..498952e 100644 --- a/src/bloqade/geometry/dialects/grid/__init__.py +++ b/src/bloqade/geometry/dialects/grid/__init__.py @@ -1,4 +1,5 @@ from ._dialect import dialect as dialect +from ._typeinfer import TypeInferMethods as TypeInferMethods from .concrete import GridInterpreter as GridInterpreter from .stmts import ( FromPositions as FromPositions, diff --git a/src/bloqade/geometry/dialects/grid/_typeinfer.py b/src/bloqade/geometry/dialects/grid/_typeinfer.py new file mode 100644 index 0000000..54625ee --- /dev/null +++ b/src/bloqade/geometry/dialects/grid/_typeinfer.py @@ -0,0 +1,29 @@ +from kirin import types +from kirin.analysis import TypeInference +from kirin.interp import Frame, MethodTable, impl + +from ._dialect import dialect +from .stmts import New +from .types import GridType + + +@dialect.register(key="typeinfer") +class TypeInferMethods(MethodTable): + + def get_len(self, typ: types.TypeAttribute): + if isinstance(typ, types.Generic) and isinstance(typ.vars[1], types.Literal): + return types.Literal(typ.vars[1].data + 1) + else: + return types.Any + + @impl(New) + def inter_new( + self, interp_: TypeInference, frame: Frame[types.TypeAttribute], node: New + ): + x_spacing_type = frame.get(node.x_spacing) + y_spacing_type = frame.get(node.y_spacing) + + x_len = self.get_len(x_spacing_type) + y_len = self.get_len(y_spacing_type) + + return (GridType[x_len, y_len],) diff --git a/src/bloqade/geometry/dialects/grid/stmts.py b/src/bloqade/geometry/dialects/grid/stmts.py index b999688..6778364 100644 --- a/src/bloqade/geometry/dialects/grid/stmts.py +++ b/src/bloqade/geometry/dialects/grid/stmts.py @@ -26,11 +26,17 @@ class New(ir.Statement): name = "new" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x_spacing: ir.SSAValue = info.argument(type=ilist.IListType[types.Float, types.Any]) - y_spacing: ir.SSAValue = info.argument(type=ilist.IListType[types.Float, types.Any]) + x_spacing: ir.SSAValue = info.argument( + type=ilist.IListType[types.Float, types.TypeVar("NumXStep")] + ) + y_spacing: ir.SSAValue = info.argument( + type=ilist.IListType[types.Float, types.TypeVar("NumYStep")] + ) x_init: ir.SSAValue = info.argument(types.Float) y_init: ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(GridType[types.Any, types.Any]) + result: ir.ResultValue = info.result( + GridType[types.TypeVar("NumX"), types.TypeVar("NumY")] + ) # Maybe do this with hints? diff --git a/src/bloqade/geometry/prelude.py b/src/bloqade/geometry/prelude.py new file mode 100644 index 0000000..e7e5eb0 --- /dev/null +++ b/src/bloqade/geometry/prelude.py @@ -0,0 +1,44 @@ +from kirin import ir +from kirin.ir.method import Method +from kirin.passes.default import Default +from kirin.prelude import structural +from typing_extensions import Annotated, Doc + +from bloqade.geometry.dialects import grid + + +@ir.dialect_group(structural.add(grid)) +def geometry( + self, +): + """Structural kernel with optimization passes.""" + + def run_pass( + mt: Annotated[Method, Doc("The method to run pass on.")], + *, + verify: Annotated[ + bool, Doc("run `verify` before running passes, default is `True`") + ] = True, + typeinfer: Annotated[ + bool, + Doc( + "run type inference and apply the inferred type to IR, default `False`" + ), + ] = False, + fold: Annotated[bool, Doc("run folding passes")] = True, + aggressive: Annotated[ + bool, Doc("run aggressive folding passes if `fold=True`") + ] = False, + no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True, + ) -> None: + default_pass = Default( + self, + verify=verify, + fold=fold, + aggressive=aggressive, + typeinfer=typeinfer, + no_raise=no_raise, + ) + default_pass.fixpoint(mt) + + return run_pass diff --git a/src/bloqade/geometry/rewrite/__init__.py b/src/bloqade/geometry/rewrite/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/bloqade/geometry/rewrite/desugar.py b/src/bloqade/geometry/rewrite/desugar.py new file mode 100644 index 0000000..e69de29 diff --git a/test/grid/test_typeinfer.py b/test/grid/test_typeinfer.py new file mode 100644 index 0000000..770ab82 --- /dev/null +++ b/test/grid/test_typeinfer.py @@ -0,0 +1,17 @@ +from kirin import types + +from bloqade.geometry.dialects import grid +from bloqade.geometry.prelude import geometry + + +def test_typeinfer(): + + @geometry + def test_method(): + return grid.New([1, 2], [1, 2], 0, 0) + + test_method.return_type.is_equal(grid.GridType[types.Literal(3), types.Literal(3)]) + + +if __name__ == "__main__": + test_typeinfer() From 588ee204fdd2df106817eeb9d70e38b66e11b0e2 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Fri, 9 May 2025 14:56:56 -0400 Subject: [PATCH 2/2] making type check a bit tighter bound --- .../geometry/dialects/grid/_typeinfer.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/bloqade/geometry/dialects/grid/_typeinfer.py b/src/bloqade/geometry/dialects/grid/_typeinfer.py index 54625ee..111aac2 100644 --- a/src/bloqade/geometry/dialects/grid/_typeinfer.py +++ b/src/bloqade/geometry/dialects/grid/_typeinfer.py @@ -1,5 +1,8 @@ +from typing import cast + from kirin import types from kirin.analysis import TypeInference +from kirin.dialects import ilist from kirin.interp import Frame, MethodTable, impl from ._dialect import dialect @@ -11,19 +14,16 @@ class TypeInferMethods(MethodTable): def get_len(self, typ: types.TypeAttribute): - if isinstance(typ, types.Generic) and isinstance(typ.vars[1], types.Literal): - return types.Literal(typ.vars[1].data + 1) - else: - return types.Any + if typ.is_subseteq(ilist.IListType[types.Int, types.Any]): + typ = cast(types.Generic, typ) + if isinstance(typ.vars[1], types.Literal): + return types.Literal(typ.vars[1].data + 1) + + return types.Any @impl(New) - def inter_new( - self, interp_: TypeInference, frame: Frame[types.TypeAttribute], node: New - ): - x_spacing_type = frame.get(node.x_spacing) - y_spacing_type = frame.get(node.y_spacing) - - x_len = self.get_len(x_spacing_type) - y_len = self.get_len(y_spacing_type) + def inter_new(self, _: TypeInference, frame: Frame[types.TypeAttribute], node: New): + x_len = self.get_len(frame.get(node.x_spacing)) + y_len = self.get_len(frame.get(node.y_spacing)) return (GridType[x_len, y_len],)