diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index 932058b232..6ab73cea17 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -8,7 +8,9 @@ from functools import partialmethod import random +import re import string +import typing from .quake_value import QuakeValue from .kernel_decorator import PyKernelDecorator from .utils import mlirTypeFromPyType, nvqppPrefix, emitFatalError, mlirTypeToPyType @@ -215,7 +217,9 @@ def __init__(self, argTypeList): with self.ctx, InsertionPoint(self.module.body), self.loc: self.mlirArgTypes = [ - mlirTypeFromPyType(argType, self.ctx) for argType in argTypeList + mlirTypeFromPyType(argType[0], self.ctx, argInstance=argType[1]) + for argType in + [self.__processArgType(ty) for ty in argTypeList] ] self.funcOp = func.FuncOp(self.funcName, (self.mlirArgTypes, []), @@ -231,6 +235,25 @@ def __init__(self, argTypeList): self.insertPoint = InsertionPoint.at_block_begin(e) + def __processArgType(self, ty): + """ + Process input argument type. Specifically, try to infer the + element type for a list, e.g. list[float]. + """ + if ty in [cudaq_runtime.qvector, cudaq_runtime.qubit]: + return ty, None + if typing.get_origin(ty) == list or isinstance(ty(), list): + if '[' in str(ty) and ']' in str(ty): + allowedTypeMap = {'int': int, 'bool': bool, 'float': float} + # Infer the slice type + result = re.search(r'ist\[(.*)\]', str(ty)) + eleTyName = result.group(1) + pyType = allowedTypeMap[eleTyName] + if eleTyName != None and eleTyName in allowedTypeMap: + return list, [allowedTypeMap[eleTyName]()] + emitFatalError(f'Invalid type for kernel builder {ty}') + return ty, None + def getIntegerAttr(self, type, value): """ Return an MLIR Integer Attribute of the given IntegerType. diff --git a/python/cudaq/kernel/utils.py b/python/cudaq/kernel/utils.py index 679f8045b5..daa1892ee9 100644 --- a/python/cudaq/kernel/utils.py +++ b/python/cudaq/kernel/utils.py @@ -173,21 +173,26 @@ def mlirTypeFromPyType(argType, ctx, **kwargs): return ComplexType.get(mlirTypeFromPyType(float, ctx)) if argType in [list, np.ndarray, List]: - if 'argInstance' not in kwargs: + if 'argInstance' not in kwargs or kwargs['argInstance'] == None: return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx)) argInstance = kwargs['argInstance'] - argTypeToCompareTo = kwargs['argTypeToCompareTo'] + argTypeToCompareTo = kwargs[ + 'argTypeToCompareTo'] if 'argTypeToCompareTo' in kwargs else None + + if isinstance(argInstance[0], bool): + return cc.StdvecType.get(ctx, mlirTypeFromPyType(bool, ctx)) if isinstance(argInstance[0], int): return cc.StdvecType.get(ctx, mlirTypeFromPyType(int, ctx)) if isinstance(argInstance[0], float): - # check if we are comparing to a complex... - eleTy = cc.StdvecType.getElementType(argTypeToCompareTo) - if ComplexType.isinstance(eleTy): - emitFatalError( - "Invalid runtime argument to kernel. list[complex] required, but list[float] provided." - ) + if argTypeToCompareTo != None: + # check if we are comparing to a complex... + eleTy = cc.StdvecType.getElementType(argTypeToCompareTo) + if ComplexType.isinstance(eleTy): + emitFatalError( + "Invalid runtime argument to kernel. list[complex] required, but list[float] provided." + ) return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx)) if isinstance(argInstance[0], complex): return cc.StdvecType.get(ctx, mlirTypeFromPyType(complex, ctx)) diff --git a/python/tests/builder/test_kernel_builder.py b/python/tests/builder/test_kernel_builder.py index 43e40ee392..3a1a24b24d 100644 --- a/python/tests/builder/test_kernel_builder.py +++ b/python/tests/builder/test_kernel_builder.py @@ -1195,6 +1195,12 @@ def test_draw(): assert circuit == expected_str +def test_list_subscript(): + kernelAndArgs = cudaq.make_kernel(bool, list[bool], List[int], list[float]) + print(kernelAndArgs[0]) + assert len(kernelAndArgs) == 5 and len(kernelAndArgs[0].arguments) == 4 + + # leave for gdb debugging if __name__ == "__main__": loc = os.path.abspath(__file__)