Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python] Qvector init from state #1713

Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ebd7574
Add support for initialization of qvector from lists
annagrin Apr 25, 2024
aa542b1
Fix formatting changes
annagrin May 7, 2024
138174f
Merge branch 'experimental/stateHandling' of https://github.com/NVIDI…
annagrin May 7, 2024
495ce33
Support nd.array creation with dtypes and creating qvector from it
annagrin May 9, 2024
a65fd11
Merge branch 'experimental/stateHandling' of https://github.com/NVIDI…
annagrin May 10, 2024
b96ea19
temp
annagrin May 13, 2024
45bab2a
Add automatic conversion to the simulation precision data type
annagrin May 14, 2024
8964dfa
Use simulation dtype for alloca operation, add length checking for sc…
annagrin May 14, 2024
06b405f
Merge branch 'experimental/stateHandling' of https://github.com/NVIDI…
annagrin May 14, 2024
4abd9c2
Support creating nd arrays and initializing vectors. Add a vector cop…
annagrin May 15, 2024
5bf877f
Merge branch 'experimental/stateHandling' of https://github.com/NVIDI…
annagrin May 15, 2024
38eb835
Merge branch 'experimental/stateHandling' of https://github.com/NVIDI…
annagrin May 15, 2024
742e7ff
Fixed spelling
annagrin May 15, 2024
f84a810
Remove dictionary.dic
annagrin May 15, 2024
d938e29
Support cudaq.amplitudes inside kernels
annagrin May 16, 2024
4444e2d
Fixed test failures
annagrin May 16, 2024
d95794f
Merge branch 'experimental/stateHandling' of https://github.com/NVIDI…
annagrin May 16, 2024
6f06de7
Updated test
annagrin May 16, 2024
b7a6eaa
Fixed more failing tests
annagrin May 16, 2024
03c9f8b
Remove creating qvector of const length to be handled later
annagrin May 16, 2024
8dca23a
Temp
annagrin May 17, 2024
aa9f8e4
Try a couple of ways to get state data
annagrin May 22, 2024
078dffa
Merge with experimental/stateHandling
annagrin May 22, 2024
c3c8580
Made it work e2e for testing
annagrin May 22, 2024
9cd4c69
Cleanup
annagrin May 22, 2024
a6dee4a
Support captured cudaq states in kernels
annagrin May 24, 2024
94db953
Merge branch 'experimental/stateHandling' of https://github.com/NVIDI…
annagrin May 24, 2024
6575bea
Added error tests for trying to create a state from incorrect precisi…
annagrin May 24, 2024
ab7eeaa
Remove unneeded comments
annagrin May 24, 2024
bd60230
Use a counter ID for captured states instead of hashes
annagrin May 24, 2024
138ea58
Cleanup
annagrin May 24, 2024
1e6a28f
Merge branch 'experimental/stateHandling' into qvector-init-from-stat…
annagrin May 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added dictionary.dic
Binary file not shown.
6 changes: 6 additions & 0 deletions include/cudaq/Optimizer/Builder/Factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ inline mlir::LLVM::LLVMStructType stdVectorImplType(mlir::Type eleTy) {
return mlir::LLVM::LLVMStructType::getLiteral(ctx, eleTys);
}

inline mlir::Type stateImplType(mlir::Type eleTy) {
auto *ctx = eleTy.getContext();
auto eTy = cudaq::opt::factory::getCharType(ctx);
return cudaq::opt::factory::getPointerType(eTy);
Comment on lines +123 to +125
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be simplified to just:

return cudaq::opt::factory::getPointerType(eleTy.getContext());

On the other hand, we don't use the eleTy, so we could erase this function and just use getPointerType at the call site.

}

// Host side types for std::string and std::vector

cudaq::cc::StructType stlStringType(mlir::MLIRContext *ctx);
Expand Down
2 changes: 2 additions & 0 deletions include/cudaq/Optimizer/CodeGen/QIRFunctionNames.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ static constexpr const char QIRArrayQubitAllocateArrayWithStateComplex32[] =
"__quantum__rt__qubit_allocate_array_with_state_complex32";
static constexpr const char QIRArrayQubitAllocateArrayWithStatePtr[] =
"__quantum__rt__qubit_allocate_array_with_state_ptr";
static constexpr const char QIRArrayQubitAllocateArrayWithCudaqStatePtr[] =
"__quantum__rt__qubit_allocate_array_with_cudaq_state_ptr";
static constexpr const char QIRQubitAllocate[] =
"__quantum__rt__qubit_allocate";
static constexpr const char QIRArrayQubitReleaseArray[] =
Expand Down
10 changes: 10 additions & 0 deletions lib/Optimizer/CodeGen/ConvertToQIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,17 @@ class QmemRAIIOpRewrite
StringRef functionName;
if (Type eleTy = dyn_cast<LLVM::LLVMPointerType>(ccState.getType())
.getElementType()) {
if (auto elePtrTy = dyn_cast<LLVM::LLVMPointerType>(eleTy))
eleTy = elePtrTy.getElementType();
if (auto arrayTy = dyn_cast<LLVM::LLVMArrayType>(eleTy))
eleTy = arrayTy.getElementType();
bool fromComplex = false;
if (auto complexTy = dyn_cast<LLVM::LLVMStructType>(eleTy)) {
fromComplex = true;
eleTy = complexTy.getBody()[0];
}
if (eleTy == rewriter.getI8Type())
functionName = cudaq::opt::QIRArrayQubitAllocateArrayWithCudaqStatePtr;
if (eleTy == rewriter.getF64Type())
functionName =
fromComplex
Expand Down Expand Up @@ -187,6 +191,7 @@ class QmemRAIIOpRewrite
// Create QIR allocation with initializer function.
auto *ctx = rewriter.getContext();
auto ptrTy = cudaq::opt::factory::getPointerType(ctx);

FlatSymbolRefAttr raiiSymbolRef =
cudaq::opt::factory::createLLVMFunctionSymbol(
functionName, array_qbit_type, {i64Ty, ptrTy}, parentModule);
Expand Down Expand Up @@ -2011,6 +2016,8 @@ void cudaq::opt::initializeTypeConversions(LLVMTypeConverter &typeConverter) {
[](quake::VeqType type) { return getArrayType(type.getContext()); });
typeConverter.addConversion(
[](quake::RefType type) { return getQubitType(type.getContext()); });
typeConverter.addConversion(
[](cc::StateType type) { return factory::stateImplType(type); });
typeConverter.addConversion([](cc::CallableType type) {
return lambdaAsPairOfPointers(type.getContext());
});
Expand All @@ -2026,6 +2033,9 @@ void cudaq::opt::initializeTypeConversions(LLVMTypeConverter &typeConverter) {
if (isa<NoneType>(eleTy))
return factory::getPointerType(type.getContext());
eleTy = typeConverter.convertType(eleTy);
if (isa<NoneType>(eleTy))
return factory::getPointerType(type.getContext());

Comment on lines +2036 to +2038
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why doesn't the recursion handle this? The recursive call on line 2035 ought to handle pointers to pointers.

if (auto arrTy = dyn_cast<cc::ArrayType>(eleTy)) {
// If array has a static size, it becomes an LLVMArrayType.
assert(arrTy.isUnknownSize());
Expand Down
17 changes: 14 additions & 3 deletions lib/Optimizer/Transforms/GenKernelExecution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,10 @@ class GenerateKernelExecution
hasTrailingData = true;
continue;
}
if (isa<cudaq::cc::PointerType>(currEleTy))
if (isa<cudaq::cc::PointerType>(currEleTy) &&
!isStatePointerType(currEleTy)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: don't need braces.

continue;
}

// cast to the struct element type, void* -> TYPE *
argPtr = builder.create<cudaq::cc::CastOp>(
Expand Down Expand Up @@ -933,6 +935,13 @@ class GenerateKernelExecution
builder.create<cudaq::cc::StoreOp>(loc, endPtr, sret2);
}

static bool isStatePointerType(mlir::Type ty) {
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(ty)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: no braces

return isa<cudaq::cc::StateType>(ptrTy.getElementType());
}
return false;
}

static MutableArrayRef<BlockArgument>
dropAnyHiddenArguments(MutableArrayRef<BlockArgument> args,
FunctionType funcTy, bool hasThisPointer) {
Expand All @@ -941,7 +950,8 @@ class GenerateKernelExecution
cudaq::cc::numberOfHiddenArgs(hasThisPointer, hiddenSRet);
if (count > 0 && args.size() >= count &&
std::all_of(args.begin(), args.begin() + count, [](auto i) {
return isa<cudaq::cc::PointerType>(i.getType());
return isa<cudaq::cc::PointerType>(i.getType()) &&
!isStatePointerType(i.getType());
Comment on lines +953 to +954
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get this one. We're dropping hidden arguments. Why would a cudaq::state* appear as a this or sret? Do we have an example?

}))
return args.drop_front(count);
return args;
Expand Down Expand Up @@ -1207,8 +1217,9 @@ class GenerateKernelExecution
hasTrailingData = true;
continue;
}
if (isa<cudaq::cc::PointerType>(inTy))
if (isa<cudaq::cc::PointerType>(inTy) && !isStatePointerType(inTy))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look correct. We cannot pass a cudaq::state* as a pointer-free value.

continue;

stVal = builder.create<cudaq::cc::InsertValueOp>(loc, stVal.getType(),
stVal, arg, idx);
}
Expand Down
100 changes: 93 additions & 7 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #
import ast
import hashlib
import graphlib
import sys, os
from typing import Callable
Expand All @@ -19,6 +20,8 @@
from ..mlir.dialects import builtin, func, arith, math, complex
from ..mlir._mlir_libs._quakeDialects import cudaq_runtime, load_intrinsic, register_all_dialects

State = cudaq_runtime.State

# This file implements the CUDA-Q Python AST to MLIR conversion.
# It provides a `PyASTBridge` class that implements the `ast.NodeVisitor` type
# to walk the Python AST for a `cudaq.kernel` annotated function and generate
Expand Down Expand Up @@ -110,6 +113,8 @@ def __init__(self, **kwargs):
symbol table, which maps variable names to constructed `mlir.Values`.
"""
self.valueStack = deque()
self.cudaqStateHashes = kwargs[
'cudaqStateHashes'] if 'cudaqStateHashes' in kwargs else None
self.knownResultType = kwargs[
'knownResultType'] if 'knownResultType' in kwargs else None
if 'existingModule' in kwargs:
Expand Down Expand Up @@ -1842,19 +1847,27 @@ def bodyBuilder(iterVal):
return

if node.func.attr == 'qvector':
value = self.ifPointerThenLoad(self.popValue())
if (IntegerType.isinstance(value.type)):
valueOrPtr = self.popValue()
initializerTy = valueOrPtr.type

if cc.PointerType.isinstance(initializerTy):
initializerTy = cc.PointerType.getElementType(
initializerTy)

if (IntegerType.isinstance(initializerTy)):
# handle `cudaq.qvector(n)`
value = self.ifPointerThenLoad(valueOrPtr)
ty = self.getVeqType()
qubits = quake.AllocaOp(ty, size=value).result
self.pushValue(qubits)
return
if cc.StdvecType.isinstance(value.type):
if cc.StdvecType.isinstance(initializerTy):
# handle `cudaq.qvector(initState)`

# Validate the length in case of a constant initializer:
# `cudaq.qvector([1., 0., ...])`
# `cudaq.qvector(np.array([1., 0., ...]))`
value = self.ifPointerThenLoad(valueOrPtr)
listScalar = None
arrNode = node.args[0]
if isinstance(arrNode, ast.List):
Expand Down Expand Up @@ -1899,14 +1912,32 @@ def bodyBuilder(iterVal):
veqTy = quake.VeqType.get(self.ctx)

qubits = quake.AllocaOp(veqTy, size=numQubits).result
initials = cc.StdvecDataOp(ptrTy, value).result
data = cc.StdvecDataOp(ptrTy, value).result
init = quake.InitializeStateOp(veqTy, qubits,
initials).result
data).result
self.pushValue(init)
return

if cc.StateType.isinstance(initializerTy):
# handle `cudaq.qvector(state)`
statePtr = self.ifNotPointerThenStore(valueOrPtr)

symName = '__nvqpp_cudaq_state_numberOfQubits'
load_intrinsic(self.module, symName)
i64Ty = self.getIntegerType()
numQubits = func.CallOp([i64Ty], symName,
[statePtr]).result

veqTy = quake.VeqType.get(self.ctx)
qubits = quake.AllocaOp(veqTy, size=numQubits).result
init = quake.InitializeStateOp(veqTy, qubits,
statePtr).result

self.pushValue(init)
return

self.emitFatalError(
f"unsupported qvector argument type: {value.type} (unknown)",
f"unsupported qvector argument type: {value.type}",
node)
return

Expand Down Expand Up @@ -3330,6 +3361,54 @@ def visit_BinOp(self, node):
else:
self.emitFatalError(f"unhandled binary operator - {node.op}", node)

def __store_state(self, value: State):
# Compute a unique hash string for the state data
hashValue = hashlib.sha1(value).hexdigest(
annagrin marked this conversation as resolved.
Show resolved Hide resolved
)[:10] + self.name.removeprefix('__nvqppBuilderKernel_')
print(f'state has value for {self.name}: {hashValue}')
annagrin marked this conversation as resolved.
Show resolved Hide resolved

stateTy = cc.StateType.get(self.ctx)
statePtrTy = cc.PointerType.get(self.ctx, stateTy)

# Generate a function that stores the state value in a global
globalTy = statePtrTy
globalName = f'nvqpp.cudaq.state.{hashValue}'
setStateName = f'nvqpp.set.cudaq.state.{hashValue}'
with InsertionPoint.at_block_begin(self.module.body):
cc.GlobalOp(TypeAttr.get(globalTy), globalName, external=True)
setStateFunc = func.FuncOp(setStateName,
FunctionType.get(inputs=[statePtrTy],
results=[]),
loc=self.loc)
entry = setStateFunc.add_entry_block()
with InsertionPoint(entry):
zero = self.getConstantInt(0)
address = cc.AddressOfOp(cc.PointerType.get(self.ctx, globalTy),
FlatSymbolRefAttr.get(globalName))
ptr = cc.ComputePtrOp(
cc.PointerType.get(self.ctx, statePtrTy), address, [zero],
DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx))

cc.StoreOp(entry.arguments[0], ptr)
func.ReturnOp([])

# Record the unique hash value
if hashValue not in self.cudaqStateHashes:
self.cudaqStateHashes.append(hashValue)

# Store the state into a global variable
cudaq_runtime.storePointerToCudaqState(self.name, hashValue, value)

# Return the pointer to stored state
zero = self.getConstantInt(0)
address = cc.AddressOfOp(cc.PointerType.get(self.ctx, globalTy),
FlatSymbolRefAttr.get(globalName)).result
ptr = cc.ComputePtrOp(
cc.PointerType.get(self.ctx, statePtrTy), address, [zero],
DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx)).result
statePtr = cc.LoadOp(ptr).result
return statePtr

def visit_Name(self, node):
"""
Visit `ast.Name` nodes and extract the correct value from the symbol table.
Expand Down Expand Up @@ -3378,6 +3457,11 @@ def visit_Name(self, node):
# Only support a small subset of types here
complexType = type(1j)
value = self.capturedVars[node.id]

if isinstance(value, State):
self.pushValue(self.__store_state(value))
return

if isinstance(value, (list, np.ndarray)) and isinstance(
value[0], (int, bool, float, np.float32, np.float64,
complexType, np.complex64, np.complex128)):
Expand Down Expand Up @@ -3445,7 +3529,7 @@ def visit_Name(self, node):
errorType = f"{errorType}[{type(value[0]).__name__}]"

self.emitFatalError(
f"Invalid type for variable ({node.id}) captured from parent scope (only int, bool, float, complex, and list/np.ndarray[int|bool|float|complex] accepted, type was {errorType}).",
f"Invalid type for variable ({node.id}) captured from parent scope (only int, bool, float, complex, cudaq.State, and list/np.ndarray[int|bool|float|complex] accepted, type was {errorType}).",
node)

# Throw an exception for the case that the name is not
Expand Down Expand Up @@ -3474,9 +3558,11 @@ def compile_to_mlir(astModule, metadata, **kwargs):
lineNumberOffset = kwargs['location'] if 'location' in kwargs else ('', 0)
parentVariables = kwargs[
'parentVariables'] if 'parentVariables' in kwargs else {}
cudaqStateHashes = kwargs['cudaqStateHashes']

# Create the AST Bridge
bridge = PyASTBridge(verbose=verbose,
cudaqStateHashes=cudaqStateHashes,
knownResultType=returnType,
returnTypeIsFromPython=True,
locationOffset=lineNumberOffset,
Expand Down
9 changes: 7 additions & 2 deletions python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class PyKernelDecorator(object):

def __init__(self, function, verbose=False, module=None, kernelName=None):
self.kernelFunction = function
self.cudaqStateHashes = []
self.module = None if module == None else module
self.verbose = verbose
self.name = kernelName if kernelName != None else self.kernelFunction.__name__
Expand Down Expand Up @@ -163,7 +164,8 @@ def compile(self):
verbose=self.verbose,
returnType=self.returnType,
location=self.location,
parentVariables=self.globalScopedVars)
parentVariables=self.globalScopedVars,
cudaqStateHashes=self.cudaqStateHashes)

# Grab the dependent capture variables, if any
self.dependentCaptures = extraMetadata[
Expand Down Expand Up @@ -294,13 +296,16 @@ def __call__(self, *args):
self.module,
*processedArgs,
callable_names=callableNames)
cudaq_runtime.deletePointersToCudaqState(self.cudaqStateHashes)
else:
return cudaq_runtime.pyAltLaunchKernelR(
result = cudaq_runtime.pyAltLaunchKernelR(
self.name,
self.module,
mlirTypeFromPyType(self.returnType, self.module.context),
*processedArgs,
callable_names=callableNames)
cudaq_runtime.deletePointersToCudaqState(self.cudaqStateHashes)
return result


def kernel(function=None, **kwargs):
Expand Down
5 changes: 5 additions & 0 deletions python/cudaq/kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Callable, List
import ast, sys, traceback

State = cudaq_runtime.State
qvector = cudaq_runtime.qvector
qubit = cudaq_runtime.qubit
pauli_word = cudaq_runtime.pauli_word
Expand Down Expand Up @@ -87,6 +88,8 @@ def emitFatalErrorOverride(msg):
if annotation.value.id == 'cudaq':
if annotation.attr in ['qview', 'qvector']:
return quake.VeqType.get(ctx)
if annotation.attr in ['State']:
return cc.PointerType.get(ctx, cc.StateType.get(ctx))
if annotation.attr == 'qubit':
return quake.RefType.get(ctx)
if annotation.attr == 'pauli_word':
Expand Down Expand Up @@ -193,6 +196,8 @@ def mlirTypeFromPyType(argType, ctx, **kwargs):
return ComplexType.get(mlirTypeFromPyType(np.float32, ctx))
if argType == pauli_word:
return cc.CharspanType.get(ctx)
if argType == State:
return cc.PointerType.get(ctx, cc.StateType.get(ctx))

if argType in [list, np.ndarray, List]:
if 'argInstance' not in kwargs:
Expand Down
18 changes: 14 additions & 4 deletions python/runtime/cudaq/algorithms/py_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ void bindPyState(py::module &mod) {
dataTypeSize}, /* strides */
true /* readonly */
);

return py::buffer_info(dataPtr, dataTypeSize, /*itemsize */
desc, 1, /* ndim */
{shape[0]}, /* shape */
Expand Down Expand Up @@ -153,9 +152,20 @@ void bindPyState(py::module &mod) {
reinterpret_cast<std::complex<float> *>(info.ptr),
info.size));
}

return state::from_data(std::make_pair(
reinterpret_cast<std::complex<double> *>(info.ptr), info.size));
if (info.format ==
py::format_descriptor<std::complex<double>>::format()) {
return state::from_data(std::make_pair(
reinterpret_cast<std::complex<double> *>(info.ptr),
info.size));
}
throw std::runtime_error(
"A numpy array with only floating point elements passed to "
"state.from_data. input must be of complex float type, "
"please "
"add to your array creation `dtype=numpy.complex64` if "
"simulation is FP32 and `dtype=numpy.complex128` if "
"simulation if FP64, or dtype=cudaq.complex() for "
"precision-agnostic code");
},
"Return a state from data.")
.def_static(
Expand Down
Loading
Loading