Skip to content

Commit

Permalink
Simplify implementation after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
1tnguyen committed May 29, 2024
1 parent 3d064e6 commit 59d218e
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 35 deletions.
35 changes: 0 additions & 35 deletions lib/Optimizer/CodeGen/ConvertToQIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ static LLVM::LLVMStructType lambdaAsPairOfPointers(MLIRContext *context) {
return LLVM::LLVMStructType::getLiteral(context, pairOfPointers);
}

static mlir::Type getStateType(mlir::MLIRContext *context) {
return mlir::LLVM::LLVMStructType::getOpaque("cudaq::state", context);
}
namespace {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -141,36 +138,6 @@ class QmemRAIIOpRewrite
// Get the CC Pointer for the state
auto ccState = adaptor.getInitState();

auto stateTy = getStateType(rewriter.getContext());
// If this is a `state` input
if (ccState.getType() == mlir::LLVM::LLVMPointerType::get(stateTy)) {
auto *ctx = rewriter.getContext();
auto ptrTy = cudaq::opt::factory::getPointerType(ctx);
FlatSymbolRefAttr getSimStateSymbolRef =
cudaq::opt::factory::createLLVMFunctionSymbol(
"__nvqpp_cudaq_state_getSimulationState", ptrTy, {ptrTy},
parentModule);

FlatSymbolRefAttr allocateWithStateSymbolRef =
cudaq::opt::factory::createLLVMFunctionSymbol(
"__quantum__rt__qubit_allocate_array_with_state_ptr",
array_qbit_type, {ptrTy}, parentModule);

// Call the allocation function
Value castedInitState =
rewriter.create<LLVM::BitcastOp>(loc, ptrTy, ccState);
// Get the underlying `SimulationState`
Value initSimState =
rewriter
.create<LLVM::CallOp>(loc, ptrTy, getSimStateSymbolRef,
ArrayRef<Value>{castedInitState})
.getResults()[0];
rewriter.replaceOpWithNewOp<LLVM::CallOp>(raii, array_qbit_type,
allocateWithStateSymbolRef,
ArrayRef<Value>{initSimState});
return success();
}
// This is state vector input.
// Inspect the element type of the complex data, need to
// know if its f32 or f64
StringRef functionName;
Expand Down Expand Up @@ -2065,8 +2032,6 @@ void cudaq::opt::initializeTypeConversions(LLVMTypeConverter &typeConverter) {
auto eleTy = type.getElementType();
if (isa<NoneType>(eleTy))
return factory::getPointerType(type.getContext());
if (isa<cc::StateType>(eleTy))
return factory::getPointerType(getStateType(type.getContext()));
eleTy = typeConverter.convertType(eleTy);
if (isa<NoneType>(eleTy))
return factory::getPointerType(type.getContext());
Expand Down
1 change: 1 addition & 0 deletions runtime/cudaq/builder/kernel_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ __nvqpp_getStateVectorData_fp32(StateVectorStorage &stateVectorStorage,
return getStateVectorData<float>(stateVectorStorage, index);
}

/// Runtime callback to get the state pointer of a captured `cudaq::state`.
cudaq::state *__nvqpp_getStatePtr(StateVectorStorage &stateVectorStorage,
std::intptr_t index) {
return std::get<cudaq::state *>(stateVectorStorage[index]);
Expand Down

0 comments on commit 59d218e

Please sign in to comment.