Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,12 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[tool.coverage.run]
include = ["src/bloqade/*"]

[tool.coverage.report]
exclude_also = [
'@overload',
'@_wraps'
]


[tool.pytest.ini_options]
testpaths = "test/"
3 changes: 1 addition & 2 deletions src/bloqade/geometry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
def main() -> None:
print("Hello from bloqade-analog!")
from .dialects import grid as grid
14 changes: 14 additions & 0 deletions src/bloqade/geometry/dialects/grid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
from ._dialect import dialect as dialect
from ._interface import (
from_positions as from_positions,
get as get,
get_xpos as get_xpos,
get_ypos as get_ypos,
new as new,
repeat as repeat,
scale as scale,
shape as shape,
shift as shift,
sub_grid as sub_grid,
x_bounds as x_bounds,
y_bounds as y_bounds,
)
from ._typeinfer import TypeInferMethods as TypeInferMethods
from .concrete import GridInterpreter as GridInterpreter
from .stmts import (
Expand Down
222 changes: 222 additions & 0 deletions src/bloqade/geometry/dialects/grid/_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import typing

from kirin.dialects import ilist
from kirin.lowering import wraps as _wraps

from .stmts import (
FromPositions,
Get,
GetSubGrid,
GetXBounds,
GetXPos,
GetYBounds,
GetYPos,
New,
Repeat,
Scale,
Shape,
Shift,
)
from .types import Grid


@_wraps(New)
def new(
x_spacing: ilist.IList[float, typing.Any] | list[float],
y_spacing: ilist.IList[float, typing.Any] | list[float],
x_init: float,
y_init: float,
) -> Grid[typing.Any, typing.Any]:
"""
Create a new grid with the given spacing and initial position.

Args:
x_spacing (IList[float] | list[float]): The spacing in the x direction.
y_spacing (IList[float] | list[float]): The spacing in the y direction.
x_init (float): The initial position in the x direction.
y_init (float): The initial position in the y direction.

Returns:
Grid: A new grid object.
"""
...


Nx = typing.TypeVar("Nx")
Ny = typing.TypeVar("Ny")


@typing.overload
def from_positions(
x_positions: ilist.IList[float, Nx], y_positions: ilist.IList[float, Ny]
) -> Grid[Nx, Ny]: ...
@typing.overload
def from_positions(
x_positions: ilist.IList[float, Nx], y_positions: list[float]
) -> Grid[Nx, typing.Any]: ...
@typing.overload
def from_positions(
x_positions: list[float], y_positions: ilist.IList[float, Ny]
) -> Grid[typing.Any, Ny]: ...
@typing.overload
def from_positions(
x_positions: list[float], y_positions: list[float]
) -> Grid[typing.Any, typing.Any]: ...
@_wraps(FromPositions)
def from_positions(x_positions, y_positions):
"""Construct a grid from the given x and y positions.

Args:
x_positions (IList[float] | list[float]): A list or ilist of floats representing the x-coordinates of grid points.
y_positions (IList[float] | list[float]): A list or ilist of floats representing the y-coordinates of grid points.

Returns:
Grid: a grid object
"""


@_wraps(Get)
def get(grid: Grid, idx: tuple[int, int]) -> tuple[float, float]:
"""Get the coordinate (x, y) of a grid at the given index.

Args:
grid (Grid): a grid object
idx (tuple[int, int]): a tuple of (x, y) indices
Returns:
tuple[float, float]: a tuple of (x, y) positions
tuple[None, None]: if the grid has no initial x or y position
"""
...


@_wraps(GetXPos)
def get_xpos(grid: Grid[Nx, typing.Any]) -> ilist.IList[float, Nx]:
"""Get the x positions of a grid.

Args:
grid: a grid object
Returns:
ilist.IList[float, typing.Any]: a list of x positions
"""
...


@_wraps(GetYPos)
def get_ypos(grid: Grid[typing.Any, Ny]) -> ilist.IList[float, Ny]:
"""Get the y positions of a grid.

Args:
grid: a grid object
Returns:
ilist.IList[float, typing.Any]: a list of y positions
"""
...


@typing.overload
def sub_grid(
grid: Grid, x_indices: ilist.IList[int, Nx], y_indices: list[int]
) -> Grid[Nx, typing.Any]: ...
@typing.overload
def sub_grid(
grid: Grid, x_indices: list[int], y_indices: ilist.IList[int, Ny]
) -> Grid[typing.Any, Ny]: ...
@typing.overload
def sub_grid(
grid: Grid, x_indices: list[int], y_indices: list[int]
) -> Grid[typing.Any, typing.Any]: ...
@_wraps(GetSubGrid)
def sub_grid(grid, x_indices, y_indices):
"""Get a subgrid from the given grid.

Args:
grid (Grid): a grid object
x_indices: a list/ilist of x indices
y_indices: a list/ilist of y indices
Returns:
Grid: a subgrid object
"""
...


@_wraps(GetXBounds)
def x_bounds(grid: Grid[typing.Any, typing.Any]) -> tuple[float, float]:
"""Get the x bounds of a grid.

Args:
grid (Grid): a grid object
Returns:
tuple[float, float]: a tuple of (min_x, max_x)
"""
...


@_wraps(GetYBounds)
def y_bounds(grid: Grid[typing.Any, typing.Any]) -> tuple[float, float]:
"""Get the y bounds of a grid.

Args:
grid (Grid): a grid object
Returns:
tuple[float, float]: a tuple of (min_y, max_y)
tuple[None, None]: if the grid has no initial y position
"""
...


@_wraps(Repeat)
def repeat(
grid: Grid, x_times: int, y_times: int, x_spacing: float, y_spacing: float
) -> Grid:
"""Repeat a grid in the x and y directions.

Args:
grid (Grid): a grid object
x_times (int): number of times to repeat in the x direction
y_times (int): number of times to repeat in the y direction
x_spacing (float): spacing in the x direction
y_spacing (float): spacing in the y direction
Returns:
Grid: a new grid object with the repeated pattern
"""
...


@_wraps(Scale)
def scale(grid: Grid[Nx, Ny], x_scale: float, y_scale: float) -> Grid[Nx, Ny]:
"""Scale a grid in the x and y directions.

Args:
grid (Grid): a grid object
x_scale (float): scaling factor in the x direction
y_scale (float): scaling factor in the y direction
Returns:
Grid: a new grid object that has been scaled
"""
...


@_wraps(Shift)
def shift(grid: Grid[Nx, Ny], x_shift: float, y_shift: float) -> Grid[Nx, Ny]:
"""Shift a grid in the x and y directions.

Args:
grid (Grid): a grid object
x_shift (float): shift in the x direction
y_shift (float): shift in the y direction
Returns:
Grid: a new grid object that has been shifted
"""
...


@_wraps(Shape)
def shape(grid: Grid) -> tuple[int, int]:
"""Get the shape of a grid.

Args:
grid (Grid): a grid object
Returns:
tuple[int, int]: a tuple of (num_x, num_y)
"""
...
2 changes: 1 addition & 1 deletion test/grid/test_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from kirin import interp, ir
from kirin.dialects import ilist

from bloqade.geometry.dialects import grid
from bloqade.geometry import grid


class TestGridInterpreter:
Expand Down
21 changes: 14 additions & 7 deletions test/grid/test_typeinfer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from typing import Any, Literal

from kirin import types
from kirin.dialects import ilist

from bloqade.geometry.dialects import grid
from bloqade.geometry import grid
from bloqade.geometry.prelude import geometry


def test_typeinfer():

@geometry
def test_method():
return grid.New([1, 2], [1, 2], 0, 0)
@geometry(typeinfer=True)
def test_1(spacing: ilist.IList[float, Literal[2]]):
return grid.new(spacing, [1.0, 2.0], 0.0, 0.0)

test_method.return_type.is_equal(grid.GridType[types.Literal(3), types.Literal(3)])
assert test_1.return_type.is_equal(
grid.GridType[types.Literal(3), types.Literal(3)]
)

@geometry(typeinfer=True)
def test_2(spacing: ilist.IList[float, Any]):
return grid.new(spacing, [1.0, 2.0], 0.0, 0.0)

if __name__ == "__main__":
test_typeinfer()
assert test_2.return_type.is_equal(grid.GridType[types.Any, types.Literal(3)])