-
Notifications
You must be signed in to change notification settings - Fork 170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Python] Qvector init from state #1713
Changes from 29 commits
ebd7574
aa542b1
138174f
495ce33
a65fd11
b96ea19
45bab2a
8964dfa
06b405f
4abd9c2
5bf877f
38eb835
742e7ff
f84a810
d938e29
4444e2d
d95794f
6f06de7
b7a6eaa
03c9f8b
8dca23a
aa9f8e4
078dffa
c3c8580
9cd4c69
a6dee4a
94db953
6575bea
ab7eeaa
bd60230
138ea58
1e6a28f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,13 +143,17 @@ class QmemRAIIOpRewrite | |
StringRef functionName; | ||
if (Type eleTy = dyn_cast<LLVM::LLVMPointerType>(ccState.getType()) | ||
.getElementType()) { | ||
if (auto elePtrTy = dyn_cast<LLVM::LLVMPointerType>(eleTy)) | ||
eleTy = elePtrTy.getElementType(); | ||
if (auto arrayTy = dyn_cast<LLVM::LLVMArrayType>(eleTy)) | ||
eleTy = arrayTy.getElementType(); | ||
bool fromComplex = false; | ||
if (auto complexTy = dyn_cast<LLVM::LLVMStructType>(eleTy)) { | ||
fromComplex = true; | ||
eleTy = complexTy.getBody()[0]; | ||
} | ||
if (eleTy == rewriter.getI8Type()) | ||
functionName = cudaq::opt::QIRArrayQubitAllocateArrayWithCudaqStatePtr; | ||
if (eleTy == rewriter.getF64Type()) | ||
functionName = | ||
fromComplex | ||
|
@@ -187,6 +191,7 @@ class QmemRAIIOpRewrite | |
// Create QIR allocation with initializer function. | ||
auto *ctx = rewriter.getContext(); | ||
auto ptrTy = cudaq::opt::factory::getPointerType(ctx); | ||
|
||
FlatSymbolRefAttr raiiSymbolRef = | ||
cudaq::opt::factory::createLLVMFunctionSymbol( | ||
functionName, array_qbit_type, {i64Ty, ptrTy}, parentModule); | ||
|
@@ -2011,6 +2016,8 @@ void cudaq::opt::initializeTypeConversions(LLVMTypeConverter &typeConverter) { | |
[](quake::VeqType type) { return getArrayType(type.getContext()); }); | ||
typeConverter.addConversion( | ||
[](quake::RefType type) { return getQubitType(type.getContext()); }); | ||
typeConverter.addConversion( | ||
[](cc::StateType type) { return factory::stateImplType(type); }); | ||
typeConverter.addConversion([](cc::CallableType type) { | ||
return lambdaAsPairOfPointers(type.getContext()); | ||
}); | ||
|
@@ -2026,6 +2033,9 @@ void cudaq::opt::initializeTypeConversions(LLVMTypeConverter &typeConverter) { | |
if (isa<NoneType>(eleTy)) | ||
return factory::getPointerType(type.getContext()); | ||
eleTy = typeConverter.convertType(eleTy); | ||
if (isa<NoneType>(eleTy)) | ||
return factory::getPointerType(type.getContext()); | ||
|
||
Comment on lines
+2036
to
+2038
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why doesn't the recursion handle this? The recursive call on line 2035 ought to handle pointers to pointers. |
||
if (auto arrTy = dyn_cast<cc::ArrayType>(eleTy)) { | ||
// If array has a static size, it becomes an LLVMArrayType. | ||
assert(arrTy.isUnknownSize()); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -434,8 +434,10 @@ class GenerateKernelExecution | |
hasTrailingData = true; | ||
continue; | ||
} | ||
if (isa<cudaq::cc::PointerType>(currEleTy)) | ||
if (isa<cudaq::cc::PointerType>(currEleTy) && | ||
!isStatePointerType(currEleTy)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: don't need braces. |
||
continue; | ||
} | ||
|
||
// cast to the struct element type, void* -> TYPE * | ||
argPtr = builder.create<cudaq::cc::CastOp>( | ||
|
@@ -933,6 +935,13 @@ class GenerateKernelExecution | |
builder.create<cudaq::cc::StoreOp>(loc, endPtr, sret2); | ||
} | ||
|
||
static bool isStatePointerType(mlir::Type ty) { | ||
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(ty)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: no braces |
||
return isa<cudaq::cc::StateType>(ptrTy.getElementType()); | ||
} | ||
return false; | ||
} | ||
|
||
static MutableArrayRef<BlockArgument> | ||
dropAnyHiddenArguments(MutableArrayRef<BlockArgument> args, | ||
FunctionType funcTy, bool hasThisPointer) { | ||
|
@@ -941,7 +950,8 @@ class GenerateKernelExecution | |
cudaq::cc::numberOfHiddenArgs(hasThisPointer, hiddenSRet); | ||
if (count > 0 && args.size() >= count && | ||
std::all_of(args.begin(), args.begin() + count, [](auto i) { | ||
return isa<cudaq::cc::PointerType>(i.getType()); | ||
return isa<cudaq::cc::PointerType>(i.getType()) && | ||
!isStatePointerType(i.getType()); | ||
Comment on lines
+953
to
+954
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't get this one. We're dropping hidden arguments. Why would a |
||
})) | ||
return args.drop_front(count); | ||
return args; | ||
|
@@ -1207,8 +1217,9 @@ class GenerateKernelExecution | |
hasTrailingData = true; | ||
continue; | ||
} | ||
if (isa<cudaq::cc::PointerType>(inTy)) | ||
if (isa<cudaq::cc::PointerType>(inTy) && !isStatePointerType(inTy)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't look correct. We cannot pass a |
||
continue; | ||
|
||
stVal = builder.create<cudaq::cc::InsertValueOp>(loc, stVal.getType(), | ||
stVal, arg, idx); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can be simplified to just:
return cudaq::opt::factory::getPointerType(eleTy.getContext());
On the other hand, we don't use the
eleTy
, so we could erase this function and just usegetPointerType
at the call site.