Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bloqade/geometry/dialects/grid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._dialect import dialect as dialect
from ._typeinfer import TypeInferMethods as TypeInferMethods
from .concrete import GridInterpreter as GridInterpreter
from .stmts import (
FromPositions as FromPositions,
Expand Down
29 changes: 29 additions & 0 deletions src/bloqade/geometry/dialects/grid/_typeinfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import cast

from kirin import types
from kirin.analysis import TypeInference
from kirin.dialects import ilist
from kirin.interp import Frame, MethodTable, impl

from ._dialect import dialect
from .stmts import New
from .types import GridType


@dialect.register(key="typeinfer")
class TypeInferMethods(MethodTable):

def get_len(self, typ: types.TypeAttribute):
if typ.is_subseteq(ilist.IListType[types.Int, types.Any]):
Copy link

Copilot AI May 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get_len function checks for a list of integers while grid.New declares spacing as a list of floats. Verify that this discrepancy is intentional or update the types to ensure consistency in type inference.

Suggested change
if typ.is_subseteq(ilist.IListType[types.Int, types.Any]):
if typ.is_subseteq(ilist.IListType[types.Int, types.Any]) or typ.is_subseteq(ilist.IListType[types.Float, types.Any]):

Copilot uses AI. Check for mistakes.
typ = cast(types.Generic, typ)
if isinstance(typ.vars[1], types.Literal):
return types.Literal(typ.vars[1].data + 1)

return types.Any

@impl(New)
def inter_new(self, _: TypeInference, frame: Frame[types.TypeAttribute], node: New):
x_len = self.get_len(frame.get(node.x_spacing))
y_len = self.get_len(frame.get(node.y_spacing))

return (GridType[x_len, y_len],)
12 changes: 9 additions & 3 deletions src/bloqade/geometry/dialects/grid/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@ class New(ir.Statement):
name = "new"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})

x_spacing: ir.SSAValue = info.argument(type=ilist.IListType[types.Float, types.Any])
y_spacing: ir.SSAValue = info.argument(type=ilist.IListType[types.Float, types.Any])
x_spacing: ir.SSAValue = info.argument(
type=ilist.IListType[types.Float, types.TypeVar("NumXStep")]
)
y_spacing: ir.SSAValue = info.argument(
type=ilist.IListType[types.Float, types.TypeVar("NumYStep")]
)
x_init: ir.SSAValue = info.argument(types.Float)
y_init: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(GridType[types.Any, types.Any])
result: ir.ResultValue = info.result(
GridType[types.TypeVar("NumX"), types.TypeVar("NumY")]
)


# Maybe do this with hints?
Expand Down
44 changes: 44 additions & 0 deletions src/bloqade/geometry/prelude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from kirin import ir
from kirin.ir.method import Method
from kirin.passes.default import Default
from kirin.prelude import structural
from typing_extensions import Annotated, Doc

from bloqade.geometry.dialects import grid


@ir.dialect_group(structural.add(grid))
def geometry(
self,
):
"""Structural kernel with optimization passes."""

def run_pass(
mt: Annotated[Method, Doc("The method to run pass on.")],
*,
verify: Annotated[
bool, Doc("run `verify` before running passes, default is `True`")
] = True,
typeinfer: Annotated[
bool,
Doc(
"run type inference and apply the inferred type to IR, default `False`"
),
] = False,
fold: Annotated[bool, Doc("run folding passes")] = True,
aggressive: Annotated[
bool, Doc("run aggressive folding passes if `fold=True`")
] = False,
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
) -> None:
default_pass = Default(
self,
verify=verify,
fold=fold,
aggressive=aggressive,
typeinfer=typeinfer,
no_raise=no_raise,
)
default_pass.fixpoint(mt)

return run_pass
Empty file.
Empty file.
17 changes: 17 additions & 0 deletions test/grid/test_typeinfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from kirin import types

from bloqade.geometry.dialects import grid
from bloqade.geometry.prelude import geometry


def test_typeinfer():

@geometry
def test_method():
return grid.New([1, 2], [1, 2], 0, 0)

test_method.return_type.is_equal(grid.GridType[types.Literal(3), types.Literal(3)])


if __name__ == "__main__":
test_typeinfer()