Skip to content

Commit

Permalink
Fix value error bug with validating mlirTypeFromPyType for np.ndarray (
Browse files Browse the repository at this point in the history
…#1406)

* Fix value error bug with validating mlirTypeFromPyType for np.array

Signed-off-by: Alex McCaskey <amccaskey@nvidia.com>

* add the test

Signed-off-by: Alex McCaskey <amccaskey@nvidia.com>

* fix ci failure

Signed-off-by: Alex McCaskey <amccaskey@nvidia.com>

* Move new test due to issue #1400

---------

Signed-off-by: Alex McCaskey <amccaskey@nvidia.com>
Co-authored-by: Ben Howe <bhowe@nvidia.com>
  • Loading branch information
amccaskey and bmhowe23 committed Mar 21, 2024
1 parent 4ec139c commit 81db600
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/cudaq/kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,11 @@ def mlirTypeFromPyType(argType, ctx, **kwargs):
return ComplexType.get(mlirTypeFromPyType(float, ctx))

if argType in [list, np.ndarray, List]:
if 'argInstance' not in kwargs or kwargs['argInstance'] == None:
if 'argInstance' not in kwargs:
return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx))
if argType != np.ndarray:
if kwargs['argInstance'] == None:
return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx))

argInstance = kwargs['argInstance']
argTypeToCompareTo = kwargs[
Expand Down
11 changes: 11 additions & 0 deletions python/tests/kernel/test_kernel_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,17 @@ def test2() -> int:
assert test2() == 10


def test_no_valueerror_np_array():

@cudaq.kernel
def test(var: np.ndarray):
q = cudaq.qubit()
ry(var[0], q)
mz(q)

test(np.array([1., 2.]))


def test_draw():

@cudaq.kernel
Expand Down

0 comments on commit 81db600

Please sign in to comment.