From 8a71216997761fb9f3cf0849ef4ded155e0cf925 Mon Sep 17 00:00:00 2001 From: Alex McCaskey Date: Mon, 11 Mar 2024 23:59:06 +0000 Subject: [PATCH 1/6] Enable list[T] specification with kernel builder. Signed-off-by: Alex McCaskey --- python/cudaq/kernel/kernel_builder.py | 22 +++++++++++++++++++++- python/cudaq/kernel/utils.py | 19 ++++++++++++------- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index 932058b232..cf4eee0ee9 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -7,8 +7,11 @@ # ============================================================================ # from functools import partialmethod +from pydoc import locate import random +import re import string +import typing from .quake_value import QuakeValue from .kernel_decorator import PyKernelDecorator from .utils import mlirTypeFromPyType, nvqppPrefix, emitFatalError, mlirTypeToPyType @@ -215,7 +218,9 @@ def __init__(self, argTypeList): with self.ctx, InsertionPoint(self.module.body), self.loc: self.mlirArgTypes = [ - mlirTypeFromPyType(argType, self.ctx) for argType in argTypeList + mlirTypeFromPyType(argType[0], self.ctx, argInstance=argType[1]) + for argType in + [self.__processArgType(ty) for ty in argTypeList] ] self.funcOp = func.FuncOp(self.funcName, (self.mlirArgTypes, []), @@ -231,6 +236,21 @@ def __init__(self, argTypeList): self.insertPoint = InsertionPoint.at_block_begin(e) + def __processArgType(self, ty): + """ + Process input argument type. Specifically, try to infer the + element type for a list, e.g. list[float]. + """ + if typing.get_origin(ty) == list or isinstance(ty(), list): + if '[' in str(ty) and ']' in str(ty): + # Infer the slice type + result = re.search('ist\[(.*)\]', str(ty)) + eleTyName = result.group(1) + if eleTyName != None and locate(eleTyName) != None: + return list, [locate(eleTyName)()] + emitFatalError(f'Invalid type for kernel builder {ty}') + return ty, None, None + def getIntegerAttr(self, type, value): """ Return an MLIR Integer Attribute of the given IntegerType. diff --git a/python/cudaq/kernel/utils.py b/python/cudaq/kernel/utils.py index 679f8045b5..c6d79ed8da 100644 --- a/python/cudaq/kernel/utils.py +++ b/python/cudaq/kernel/utils.py @@ -177,17 +177,22 @@ def mlirTypeFromPyType(argType, ctx, **kwargs): return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx)) argInstance = kwargs['argInstance'] - argTypeToCompareTo = kwargs['argTypeToCompareTo'] + argTypeToCompareTo = kwargs[ + 'argTypeToCompareTo'] if 'argTypeToCompareTo' in kwargs else None + + if isinstance(argInstance[0], bool): + return cc.StdvecType.get(ctx, mlirTypeFromPyType(bool, ctx)) if isinstance(argInstance[0], int): return cc.StdvecType.get(ctx, mlirTypeFromPyType(int, ctx)) if isinstance(argInstance[0], float): - # check if we are comparing to a complex... - eleTy = cc.StdvecType.getElementType(argTypeToCompareTo) - if ComplexType.isinstance(eleTy): - emitFatalError( - "Invalid runtime argument to kernel. list[complex] required, but list[float] provided." - ) + if argTypeToCompareTo != None: + # check if we are comparing to a complex... + eleTy = cc.StdvecType.getElementType(argTypeToCompareTo) + if ComplexType.isinstance(eleTy): + emitFatalError( + "Invalid runtime argument to kernel. list[complex] required, but list[float] provided." + ) return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx)) if isinstance(argInstance[0], complex): return cc.StdvecType.get(ctx, mlirTypeFromPyType(complex, ctx)) From de5aa465029c961e6de317ca892820b85704807f Mon Sep 17 00:00:00 2001 From: Alex McCaskey Date: Tue, 12 Mar 2024 00:07:42 +0000 Subject: [PATCH 2/6] small fixes Signed-off-by: Alex McCaskey --- python/cudaq/kernel/kernel_builder.py | 6 ++++-- python/cudaq/kernel/utils.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index cf4eee0ee9..b8f69a1908 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -241,15 +241,17 @@ def __processArgType(self, ty): Process input argument type. Specifically, try to infer the element type for a list, e.g. list[float]. """ + if ty in [cudaq_runtime.qvector, cudaq_runtime.qubit]: + return ty, None if typing.get_origin(ty) == list or isinstance(ty(), list): if '[' in str(ty) and ']' in str(ty): # Infer the slice type - result = re.search('ist\[(.*)\]', str(ty)) + result = re.search(r'ist\[(.*)\]', str(ty)) eleTyName = result.group(1) if eleTyName != None and locate(eleTyName) != None: return list, [locate(eleTyName)()] emitFatalError(f'Invalid type for kernel builder {ty}') - return ty, None, None + return ty, None def getIntegerAttr(self, type, value): """ diff --git a/python/cudaq/kernel/utils.py b/python/cudaq/kernel/utils.py index c6d79ed8da..daa1892ee9 100644 --- a/python/cudaq/kernel/utils.py +++ b/python/cudaq/kernel/utils.py @@ -173,7 +173,7 @@ def mlirTypeFromPyType(argType, ctx, **kwargs): return ComplexType.get(mlirTypeFromPyType(float, ctx)) if argType in [list, np.ndarray, List]: - if 'argInstance' not in kwargs: + if 'argInstance' not in kwargs or kwargs['argInstance'] == None: return cc.StdvecType.get(ctx, mlirTypeFromPyType(float, ctx)) argInstance = kwargs['argInstance'] From 49e1753193b9de190e74f4cc15c894b6200a21c9 Mon Sep 17 00:00:00 2001 From: Alex McCaskey Date: Tue, 12 Mar 2024 23:01:16 +0000 Subject: [PATCH 3/6] add a test Signed-off-by: Alex McCaskey --- python/tests/builder/test_kernel_builder.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tests/builder/test_kernel_builder.py b/python/tests/builder/test_kernel_builder.py index 43e40ee392..3a1a24b24d 100644 --- a/python/tests/builder/test_kernel_builder.py +++ b/python/tests/builder/test_kernel_builder.py @@ -1195,6 +1195,12 @@ def test_draw(): assert circuit == expected_str +def test_list_subscript(): + kernelAndArgs = cudaq.make_kernel(bool, list[bool], List[int], list[float]) + print(kernelAndArgs[0]) + assert len(kernelAndArgs) == 5 and len(kernelAndArgs[0].arguments) == 4 + + # leave for gdb debugging if __name__ == "__main__": loc = os.path.abspath(__file__) From 812a83203811f6e9ff2511a4624a04a752b243db Mon Sep 17 00:00:00 2001 From: Alex McCaskey Date: Wed, 13 Mar 2024 13:52:57 +0000 Subject: [PATCH 4/6] Attempt to fix module not callable error Signed-off-by: Alex McCaskey --- python/cudaq/kernel/kernel_builder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index b8f69a1908..a8c037d670 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -248,8 +248,9 @@ def __processArgType(self, ty): # Infer the slice type result = re.search(r'ist\[(.*)\]', str(ty)) eleTyName = result.group(1) - if eleTyName != None and locate(eleTyName) != None: - return list, [locate(eleTyName)()] + pyType = locate(eleTyName) + if eleTyName != None and pyType != None: + return list, [pyType()] emitFatalError(f'Invalid type for kernel builder {ty}') return ty, None From baf9224005fa722aeb519d769b8c8950cc92c4ee Mon Sep 17 00:00:00 2001 From: Alex McCaskey Date: Wed, 13 Mar 2024 14:35:52 +0000 Subject: [PATCH 5/6] don't use pydoc.locate Signed-off-by: Alex McCaskey --- python/cudaq/kernel/kernel_builder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index a8c037d670..8e81d58819 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -7,7 +7,6 @@ # ============================================================================ # from functools import partialmethod -from pydoc import locate import random import re import string @@ -245,12 +244,13 @@ def __processArgType(self, ty): return ty, None if typing.get_origin(ty) == list or isinstance(ty(), list): if '[' in str(ty) and ']' in str(ty): + allowedTypeMap = {'int':int, 'bool':bool, 'float':float} # Infer the slice type result = re.search(r'ist\[(.*)\]', str(ty)) eleTyName = result.group(1) - pyType = locate(eleTyName) - if eleTyName != None and pyType != None: - return list, [pyType()] + pyType = allowedTypeMap[eleTyName] + if eleTyName != None and eleTyName in allowedTypeMap: + return list, [allowedTypeMap[eleTyName]()] emitFatalError(f'Invalid type for kernel builder {ty}') return ty, None From 51b42b425e4255c040acbfbf8c4450e52db533cc Mon Sep 17 00:00:00 2001 From: Alex McCaskey Date: Wed, 13 Mar 2024 14:38:22 +0000 Subject: [PATCH 6/6] format Signed-off-by: Alex McCaskey --- python/cudaq/kernel/kernel_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index 8e81d58819..6ab73cea17 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -244,7 +244,7 @@ def __processArgType(self, ty): return ty, None if typing.get_origin(ty) == list or isinstance(ty(), list): if '[' in str(ty) and ']' in str(ty): - allowedTypeMap = {'int':int, 'bool':bool, 'float':float} + allowedTypeMap = {'int': int, 'bool': bool, 'float': float} # Infer the slice type result = re.search(r'ist\[(.*)\]', str(ty)) eleTyName = result.group(1)