diff --git a/pyproject.toml b/pyproject.toml index adb4ff6..0a6c50a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,5 +88,12 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [tool.coverage.run] include = ["src/bloqade/*"] +[tool.coverage.report] +exclude_also = [ + '@overload', + '@_wraps' +] + + [tool.pytest.ini_options] testpaths = "test/" diff --git a/src/bloqade/geometry/__init__.py b/src/bloqade/geometry/__init__.py index d202e14..1f46a44 100644 --- a/src/bloqade/geometry/__init__.py +++ b/src/bloqade/geometry/__init__.py @@ -1,2 +1 @@ -def main() -> None: - print("Hello from bloqade-analog!") +from .dialects import grid as grid diff --git a/src/bloqade/geometry/dialects/grid/__init__.py b/src/bloqade/geometry/dialects/grid/__init__.py index 498952e..fcfb207 100644 --- a/src/bloqade/geometry/dialects/grid/__init__.py +++ b/src/bloqade/geometry/dialects/grid/__init__.py @@ -1,4 +1,18 @@ from ._dialect import dialect as dialect +from ._interface import ( + from_positions as from_positions, + get as get, + get_xpos as get_xpos, + get_ypos as get_ypos, + new as new, + repeat as repeat, + scale as scale, + shape as shape, + shift as shift, + sub_grid as sub_grid, + x_bounds as x_bounds, + y_bounds as y_bounds, +) from ._typeinfer import TypeInferMethods as TypeInferMethods from .concrete import GridInterpreter as GridInterpreter from .stmts import ( diff --git a/src/bloqade/geometry/dialects/grid/_interface.py b/src/bloqade/geometry/dialects/grid/_interface.py new file mode 100644 index 0000000..3cf5446 --- /dev/null +++ b/src/bloqade/geometry/dialects/grid/_interface.py @@ -0,0 +1,222 @@ +import typing + +from kirin.dialects import ilist +from kirin.lowering import wraps as _wraps + +from .stmts import ( + FromPositions, + Get, + GetSubGrid, + GetXBounds, + GetXPos, + GetYBounds, + GetYPos, + New, + Repeat, + Scale, + Shape, + Shift, +) +from .types import Grid + + +@_wraps(New) +def new( + x_spacing: ilist.IList[float, typing.Any] | list[float], + y_spacing: ilist.IList[float, typing.Any] | list[float], + x_init: float, + y_init: float, +) -> Grid[typing.Any, typing.Any]: + """ + Create a new grid with the given spacing and initial position. + + Args: + x_spacing (IList[float] | list[float]): The spacing in the x direction. + y_spacing (IList[float] | list[float]): The spacing in the y direction. + x_init (float): The initial position in the x direction. + y_init (float): The initial position in the y direction. + + Returns: + Grid: A new grid object. + """ + ... + + +Nx = typing.TypeVar("Nx") +Ny = typing.TypeVar("Ny") + + +@typing.overload +def from_positions( + x_positions: ilist.IList[float, Nx], y_positions: ilist.IList[float, Ny] +) -> Grid[Nx, Ny]: ... +@typing.overload +def from_positions( + x_positions: ilist.IList[float, Nx], y_positions: list[float] +) -> Grid[Nx, typing.Any]: ... +@typing.overload +def from_positions( + x_positions: list[float], y_positions: ilist.IList[float, Ny] +) -> Grid[typing.Any, Ny]: ... +@typing.overload +def from_positions( + x_positions: list[float], y_positions: list[float] +) -> Grid[typing.Any, typing.Any]: ... +@_wraps(FromPositions) +def from_positions(x_positions, y_positions): + """Construct a grid from the given x and y positions. + + Args: + x_positions (IList[float] | list[float]): A list or ilist of floats representing the x-coordinates of grid points. + y_positions (IList[float] | list[float]): A list or ilist of floats representing the y-coordinates of grid points. + + Returns: + Grid: a grid object + """ + + +@_wraps(Get) +def get(grid: Grid, idx: tuple[int, int]) -> tuple[float, float]: + """Get the coordinate (x, y) of a grid at the given index. + + Args: + grid (Grid): a grid object + idx (tuple[int, int]): a tuple of (x, y) indices + Returns: + tuple[float, float]: a tuple of (x, y) positions + tuple[None, None]: if the grid has no initial x or y position + """ + ... + + +@_wraps(GetXPos) +def get_xpos(grid: Grid[Nx, typing.Any]) -> ilist.IList[float, Nx]: + """Get the x positions of a grid. + + Args: + grid: a grid object + Returns: + ilist.IList[float, typing.Any]: a list of x positions + """ + ... + + +@_wraps(GetYPos) +def get_ypos(grid: Grid[typing.Any, Ny]) -> ilist.IList[float, Ny]: + """Get the y positions of a grid. + + Args: + grid: a grid object + Returns: + ilist.IList[float, typing.Any]: a list of y positions + """ + ... + + +@typing.overload +def sub_grid( + grid: Grid, x_indices: ilist.IList[int, Nx], y_indices: list[int] +) -> Grid[Nx, typing.Any]: ... +@typing.overload +def sub_grid( + grid: Grid, x_indices: list[int], y_indices: ilist.IList[int, Ny] +) -> Grid[typing.Any, Ny]: ... +@typing.overload +def sub_grid( + grid: Grid, x_indices: list[int], y_indices: list[int] +) -> Grid[typing.Any, typing.Any]: ... +@_wraps(GetSubGrid) +def sub_grid(grid, x_indices, y_indices): + """Get a subgrid from the given grid. + + Args: + grid (Grid): a grid object + x_indices: a list/ilist of x indices + y_indices: a list/ilist of y indices + Returns: + Grid: a subgrid object + """ + ... + + +@_wraps(GetXBounds) +def x_bounds(grid: Grid[typing.Any, typing.Any]) -> tuple[float, float]: + """Get the x bounds of a grid. + + Args: + grid (Grid): a grid object + Returns: + tuple[float, float]: a tuple of (min_x, max_x) + """ + ... + + +@_wraps(GetYBounds) +def y_bounds(grid: Grid[typing.Any, typing.Any]) -> tuple[float, float]: + """Get the y bounds of a grid. + + Args: + grid (Grid): a grid object + Returns: + tuple[float, float]: a tuple of (min_y, max_y) + tuple[None, None]: if the grid has no initial y position + """ + ... + + +@_wraps(Repeat) +def repeat( + grid: Grid, x_times: int, y_times: int, x_spacing: float, y_spacing: float +) -> Grid: + """Repeat a grid in the x and y directions. + + Args: + grid (Grid): a grid object + x_times (int): number of times to repeat in the x direction + y_times (int): number of times to repeat in the y direction + x_spacing (float): spacing in the x direction + y_spacing (float): spacing in the y direction + Returns: + Grid: a new grid object with the repeated pattern + """ + ... + + +@_wraps(Scale) +def scale(grid: Grid[Nx, Ny], x_scale: float, y_scale: float) -> Grid[Nx, Ny]: + """Scale a grid in the x and y directions. + + Args: + grid (Grid): a grid object + x_scale (float): scaling factor in the x direction + y_scale (float): scaling factor in the y direction + Returns: + Grid: a new grid object that has been scaled + """ + ... + + +@_wraps(Shift) +def shift(grid: Grid[Nx, Ny], x_shift: float, y_shift: float) -> Grid[Nx, Ny]: + """Shift a grid in the x and y directions. + + Args: + grid (Grid): a grid object + x_shift (float): shift in the x direction + y_shift (float): shift in the y direction + Returns: + Grid: a new grid object that has been shifted + """ + ... + + +@_wraps(Shape) +def shape(grid: Grid) -> tuple[int, int]: + """Get the shape of a grid. + + Args: + grid (Grid): a grid object + Returns: + tuple[int, int]: a tuple of (num_x, num_y) + """ + ... diff --git a/test/grid/test_concrete.py b/test/grid/test_concrete.py index ca91a09..9526973 100644 --- a/test/grid/test_concrete.py +++ b/test/grid/test_concrete.py @@ -4,7 +4,7 @@ from kirin import interp, ir from kirin.dialects import ilist -from bloqade.geometry.dialects import grid +from bloqade.geometry import grid class TestGridInterpreter: diff --git a/test/grid/test_typeinfer.py b/test/grid/test_typeinfer.py index 770ab82..0068420 100644 --- a/test/grid/test_typeinfer.py +++ b/test/grid/test_typeinfer.py @@ -1,17 +1,24 @@ +from typing import Any, Literal + from kirin import types +from kirin.dialects import ilist -from bloqade.geometry.dialects import grid +from bloqade.geometry import grid from bloqade.geometry.prelude import geometry def test_typeinfer(): - @geometry - def test_method(): - return grid.New([1, 2], [1, 2], 0, 0) + @geometry(typeinfer=True) + def test_1(spacing: ilist.IList[float, Literal[2]]): + return grid.new(spacing, [1.0, 2.0], 0.0, 0.0) - test_method.return_type.is_equal(grid.GridType[types.Literal(3), types.Literal(3)]) + assert test_1.return_type.is_equal( + grid.GridType[types.Literal(3), types.Literal(3)] + ) + @geometry(typeinfer=True) + def test_2(spacing: ilist.IList[float, Any]): + return grid.new(spacing, [1.0, 2.0], 0.0, 0.0) -if __name__ == "__main__": - test_typeinfer() + assert test_2.return_type.is_equal(grid.GridType[types.Any, types.Literal(3)])