Skip to content

Commit

Permalink
Fix Python handling of kernel list parameters (#1410)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmhowe23 committed Mar 21, 2024
1 parent 9510eda commit 65ea909
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
20 changes: 16 additions & 4 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import random
import re
import string
import typing
import sys
from typing import get_origin, List
from .quake_value import QuakeValue
from .kernel_decorator import PyKernelDecorator
from .utils import mlirTypeFromPyType, nvqppPrefix, emitFatalError, mlirTypeToPyType
Expand Down Expand Up @@ -242,7 +243,7 @@ def __processArgType(self, ty):
"""
if ty in [cudaq_runtime.qvector, cudaq_runtime.qubit]:
return ty, None
if typing.get_origin(ty) == list or isinstance(ty(), list):
if get_origin(ty) == list or isinstance(ty(), list):
if '[' in str(ty) and ']' in str(ty):
allowedTypeMap = {'int': int, 'bool': bool, 'float': float}
# Infer the slice type
Expand Down Expand Up @@ -1084,11 +1085,22 @@ def __call__(self, *args):
f"Invalid number of arguments passed to kernel `{self.funcName}` ({len(args)} provided, {len(self.mlirArgTypes)} required"
)

def getListType(eleType: type):
if sys.version_info < (3, 9):
return List[eleType]
else:
return list[eleType]

# validate the argument types
processedArgs = []
for i, arg in enumerate(args):
mlirType = mlirTypeFromPyType(type(arg), self.ctx)
if mlirType != self.mlirArgTypes[i]:
argType = type(arg)
listType = None
if argType == list:
listType = getListType(type(arg[0]))
mlirType = mlirTypeFromPyType(argType, self.ctx)
if mlirType != self.mlirArgTypes[
i] and listType != mlirTypeToPyType(self.mlirArgTypes[i]):
emitFatalError(
f"Invalid runtime argument type ({type(arg)} provided, {mlirTypeToPyType(self.mlirArgTypes[i])} required)"
)
Expand Down
15 changes: 15 additions & 0 deletions python/tests/builder/test_kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,13 @@ def kernelThatTakesIntAndListListFloat(qubits: cudaq.qview, qbit: int,
cudaq.observe(ansatz, hamiltonian).expectation(),
atol=1e-2)

kernelAndArgs = cudaq.make_kernel(List[bool])
cudaq.sample(kernelAndArgs[0], [False, True, False])
kernelAndArgs = cudaq.make_kernel(List[int])
cudaq.sample(kernelAndArgs[0], [1, 2, 3, 4])
kernelAndArgs = cudaq.make_kernel(List[float])
cudaq.sample(kernelAndArgs[0], [5.5, 6.5, 7.5])


@skipIfPythonLessThan39
def test_call_kernel_expressions_list():
Expand Down Expand Up @@ -1113,6 +1120,13 @@ def kernelThatTakesIntAndListListFloat(qubits: cudaq.qview, qbit: int,
cudaq.observe(ansatz, hamiltonian).expectation(),
atol=1e-2)

kernelAndArgs = cudaq.make_kernel(list[bool])
cudaq.sample(kernelAndArgs[0], [False, True, False])
kernelAndArgs = cudaq.make_kernel(list[int])
cudaq.sample(kernelAndArgs[0], [1, 2, 3, 4])
kernelAndArgs = cudaq.make_kernel(list[float])
cudaq.sample(kernelAndArgs[0], [5.5, 6.5, 7.5])


def test_adequate_number_params():

Expand Down Expand Up @@ -1200,6 +1214,7 @@ def test_list_subscript():
kernelAndArgs = cudaq.make_kernel(bool, list[bool], List[int], list[float])
print(kernelAndArgs[0])
assert len(kernelAndArgs) == 5 and len(kernelAndArgs[0].arguments) == 4
cudaq.sample(kernelAndArgs[0], False, [False], [3], [3.5])


# leave for gdb debugging
Expand Down

0 comments on commit 65ea909

Please sign in to comment.