diff --git a/flang/test/Lower/OpenMP/derived-type-map.f90 b/flang/test/Lower/OpenMP/derived-type-map.f90 index 0b08aacdc6b59..4c002126b91e4 100644 --- a/flang/test/Lower/OpenMP/derived-type-map.f90 +++ b/flang/test/Lower/OpenMP/derived-type-map.f90 @@ -1,5 +1,6 @@ !RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s +!CHECK: omp.declare_mapper @[[MAPPER1:_QQFmaptype_derived_implicit_allocatablescalar_and_array_omp_default_mapper]] : !fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> { !CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicitEscalar_arr"} !CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {uniq_name = "_QFmaptype_derived_implicitEscalar_arr"} : (!fir.ref,int:i32}>>) -> (!fir.ref,int:i32}>>, !fir.ref,int:i32}>>) @@ -18,6 +19,26 @@ subroutine mapType_derived_implicit !$omp end target end subroutine mapType_derived_implicit +!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box,int:i32}>>> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"} +!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"} : (!fir.ref,int:i32}>>>>) -> (!fir.ref,int:i32}>>>>, !fir.ref,int:i32}>>>>) +!CHECK: %[[BOX_ADDR:.*]] = fir.box_offset %[[DECLARE]]#1 base_addr : (!fir.ref,int:i32}>>>>) -> !fir.llvm_ptr,int:i32}>>> +!CHECK: %[[BASE_MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref,int:i32}>>>>, !fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(implicit, tofrom) capture(ByRef) var_ptr_ptr(%[[BOX_ADDR]] : !fir.llvm_ptr,int:i32}>>>) mapper(@[[MAPPER1]]) -> !fir.llvm_ptr,int:i32}>>> {name = ""} +!CHECK: %[[DESC_MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref,int:i32}>>>>, !fir.box,int:i32}>>>) map_clauses(always, implicit, descriptor, to, attach) capture(ByRef) members(%[[BASE_MAP]] : [0] : !fir.llvm_ptr,int:i32}>>>) -> !fir.ref,int:i32}>>>> {name = "scalar_arr"} +!CHECK: omp.target map_entries(%[[DESC_MAP]] -> %[[ARG0:.*]], %[[BASE_MAP]] -> %[[ARG1:.*]] : !fir.ref,int:i32}>>>>, !fir.llvm_ptr,int:i32}>>>) { +subroutine mapType_derived_implicit_allocatable + type :: scalar_and_array + real(4) :: real + integer(4) :: array(10) + integer(4) :: int + end type scalar_and_array + type(scalar_and_array), allocatable :: scalar_arr + + allocate (scalar_arr) + !$omp target + scalar_arr%int = 1 + !$omp end target +end subroutine mapType_derived_implicit_allocatable + !CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_explicitEscalar_arr"} !CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {uniq_name = "_QFmaptype_derived_explicitEscalar_arr"} : (!fir.ref,int:i32}>>) -> (!fir.ref,int:i32}>>, !fir.ref,int:i32}>>) !CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref,int:i32}>>, !fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref,int:i32}>> {name = "scalar_arr"} diff --git a/flang/test/Lower/OpenMP/map-character.f90 b/flang/test/Lower/OpenMP/map-character.f90 index 4e57c18cac10a..93bbfe14af872 100644 --- a/flang/test/Lower/OpenMP/map-character.f90 +++ b/flang/test/Lower/OpenMP/map-character.f90 @@ -63,4 +63,3 @@ end subroutine TestOfCharacter !CHECK: %[[UNBOXED_TGT_A0:.*]]:2 = fir.unboxchar %[[TGT_A0_BC_LD]] : (!fir.boxchar<1>) -> (!fir.ref>, index) !CHECK: %[[TGT_A0_DECL:.*]]:2 = hlfir.declare %[[TGT_A0]] typeparams %[[UNBOXED_TGT_A0]]#1 {{.*}} -> (!fir.boxchar<1>, !fir.ref>) !CHECK: %[[TGT_A1_DECL:.*]]:2 = hlfir.declare %[[TGT_A1]] typeparams %[[UNBOXED_TGT_A1]]#1 {{.*}} -> (!fir.boxchar<1>, !fir.ref>) - diff --git a/flang/test/Lower/OpenMP/map-neg-alloca-derived-type-array.f90 b/flang/test/Lower/OpenMP/map-neg-alloca-derived-type-array.f90 index 7ad8605144038..ef0bfbe157f17 100644 --- a/flang/test/Lower/OpenMP/map-neg-alloca-derived-type-array.f90 +++ b/flang/test/Lower/OpenMP/map-neg-alloca-derived-type-array.f90 @@ -24,4 +24,4 @@ subroutine map_negative_bounds_allocatable_dtype() ! CHECK: %[[VAL_11:.*]] = fir.coordinate_of %[[VAL_10]], data : (!fir.ref>>}>>) -> !fir.ref>>> ! CHECK: %[[VAL_12:.*]] = fir.box_offset %[[VAL_11]] base_addr : (!fir.ref>>>) -> !fir.llvm_ptr>> ! CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>>>, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%[[VAL_12]] : !fir.llvm_ptr>>) bounds({{.*}}) -> !fir.llvm_ptr>> {name = ""} -! CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>>>, !fir.box>>) map_clauses(to) capture(ByRef) -> !fir.ref>>> {name = {{.*}}} +! CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>>>, !fir.box>>) map_clauses(always, to) capture(ByRef) -> !fir.ref>>> {name = {{.*}}} diff --git a/flang/test/Lower/OpenMP/optional-argument-map-2.f90 b/flang/test/Lower/OpenMP/optional-argument-map-2.f90 index 79a3d9a5ea823..9c138b13567d6 100644 --- a/flang/test/Lower/OpenMP/optional-argument-map-2.f90 +++ b/flang/test/Lower/OpenMP/optional-argument-map-2.f90 @@ -105,7 +105,7 @@ end module mod ! CHECK-NO-FPRIV: %[[VAL_21:.*]] = omp.map.bounds lower_bound(%[[VAL_17]] : index) upper_bound(%[[VAL_20]] : index) extent(%[[VAL_19]]#1 : index) stride(%[[VAL_18]] : index) start_idx(%[[VAL_17]] : index) {stride_in_bytes = true} ! CHECK-NO-FPRIV: %[[VAL_22:.*]] = fir.box_offset %[[VAL_0]] base_addr : (!fir.ref>) -> !fir.llvm_ptr>> ! CHECK-NO-FPRIV: %[[VAL_23:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>, !fir.char<1,?>) map_clauses(implicit, to) capture(ByRef) var_ptr_ptr(%[[VAL_22]] : !fir.llvm_ptr>>) bounds(%14) -> !fir.llvm_ptr>> {name = ""} -! CHECK-NO-FPRIV: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>, !fir.boxchar<1>) map_clauses(implicit, to) capture(ByRef) members(%[[VAL_23]] : [0] : !fir.llvm_ptr>>) -> !fir.ref> {name = ""} +! CHECK-NO-FPRIV: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>, !fir.boxchar<1>) map_clauses(always, implicit, to) capture(ByRef) members(%[[VAL_23]] : [0] : !fir.llvm_ptr>>) -> !fir.ref> {name = ""} ! CHECK-NO-FPRIV: omp.target map_entries(%[[VAL_7]] -> %[[VAL_25:.*]], %[[VAL_16]] -> %[[VAL_26:.*]], %[[VAL_24]] -> %[[VAL_27:.*]], %[[VAL_23]] -> %[[VAL_28:.*]] : !fir.ref>, !fir.ref>, !fir.ref>, !fir.llvm_ptr>>) { ! CHECK-NO-FPRIV: %[[VAL_29:.*]] = fir.load %[[VAL_27]] : !fir.ref> ! CHECK-NO-FPRIV: %[[VAL_30:.*]]:2 = fir.unboxchar %[[VAL_29]] : (!fir.boxchar<1>) -> (!fir.ref>, index) diff --git a/flang/test/Lower/OpenMP/optional-argument-map-3.f90 b/flang/test/Lower/OpenMP/optional-argument-map-3.f90 index 1c3cc50d32d67..78947eb00fe9f 100644 --- a/flang/test/Lower/OpenMP/optional-argument-map-3.f90 +++ b/flang/test/Lower/OpenMP/optional-argument-map-3.f90 @@ -33,7 +33,7 @@ end subroutine foo ! CHECK: } ! CHECK: %[[VAL_3:.*]] = fir.box_offset %[[VAL_0]] base_addr : (!fir.ref>>) -> !fir.llvm_ptr>> ! CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>>, f32) map_clauses(implicit, tofrom) capture(ByRef) var_ptr_ptr(%[[VAL_3]] : !fir.llvm_ptr>>) bounds(%{{.*}}) -> !fir.llvm_ptr>> {name = ""} -! CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>>, !fir.box>) map_clauses(implicit, to) capture(ByRef) members(%[[VAL_4]] : [0] : !fir.llvm_ptr>>) -> !fir.ref> {name = "dt"} +! CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>>, !fir.box>) map_clauses(always, implicit, to) capture(ByRef) members(%[[VAL_4]] : [0] : !fir.llvm_ptr>>) -> !fir.ref> {name = "dt"} ! CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}%[[VAL_5]] -> {{.*}}, %[[VAL_4]] -> {{.*}} : {{.*}}) { ! CHECK: } else { ! CHECK: %[[VAL_6:.*]] = fir.is_present %[[VAL_1]]#1 : (!fir.box>) -> i1 @@ -42,5 +42,5 @@ end subroutine foo ! CHECK: } ! CHECK: %[[VAL_7:.*]] = fir.box_offset %[[VAL_0]] base_addr : (!fir.ref>>) -> !fir.llvm_ptr>> ! CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>>, f32) map_clauses(implicit, tofrom) capture(ByRef) var_ptr_ptr(%[[VAL_7]] : !fir.llvm_ptr>>) bounds(%{{.*}}) -> !fir.llvm_ptr>> {name = ""} -! CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>>, !fir.box>) map_clauses(implicit, to) capture(ByRef) members(%[[VAL_8]] : [0] : !fir.llvm_ptr>>) -> !fir.ref> {name = "dt"} +! CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref>>, !fir.box>) map_clauses(always, implicit, to) capture(ByRef) members(%[[VAL_8]] : [0] : !fir.llvm_ptr>>) -> !fir.ref> {name = "dt"} ! CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[VAL_9]] ->{{.*}}, %[[VAL_8]] -> {{.*}} : {{.*}}) { diff --git a/flang/test/Lower/OpenMP/target-enter-data-default-openmp52.f90 b/flang/test/Lower/OpenMP/target-enter-data-default-openmp52.f90 index d5311d7f1a6dc..e95d829e9768a 100644 --- a/flang/test/Lower/OpenMP/target-enter-data-default-openmp52.f90 +++ b/flang/test/Lower/OpenMP/target-enter-data-default-openmp52.f90 @@ -11,7 +11,7 @@ subroutine initialize() allocate(A) !$omp target enter data map(A) !CHECK-52: omp.map.info var_ptr(%2 : !fir.ref>>, f32) map_clauses(to) capture(ByRef) var_ptr_ptr(%5 : !fir.llvm_ptr>) -> !fir.llvm_ptr> {name = ""} - !CHECK-52: omp.map.info var_ptr(%2 : !fir.ref>>, !fir.box>) map_clauses(to) capture(ByRef) members(%6 : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} + !CHECK-52: omp.map.info var_ptr(%2 : !fir.ref>>, !fir.box>) map_clauses(always, to) capture(ByRef) members(%6 : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} !CHECK-51: to and alloc map types are permitted end subroutine initialize diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index c2951cce1c8fd..0b6fc30b8db70 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3712,6 +3712,17 @@ convertToCaptureClauseKind( llvm_unreachable("unhandled capture clause"); } +static Operation *getGlobalOpFromValue(Value value) { + Operation *op = value.getDefiningOp(); + if (auto addrCast = dyn_cast_if_present(op)) + op = addrCast->getOperand(0).getDefiningOp(); + if (auto addressOfOp = dyn_cast_if_present(op)) { + auto modOp = addressOfOp->getParentOfType(); + return modOp.lookupSymbol(addressOfOp.getGlobalName()); + } + return nullptr; +} + static llvm::SmallString<64> getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder) { @@ -3735,85 +3746,57 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, } static bool isDeclareTargetLink(Value value) { - Operation *op = value.getDefiningOp(); - if (auto addrCast = llvm::dyn_cast_if_present(op)) - op = addrCast->getOperand(0).getDefiningOp(); - - if (auto addressOfOp = llvm::dyn_cast_if_present(op)) { - auto modOp = addressOfOp->getParentOfType(); - Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName()); - if (auto declareTargetGlobal = - llvm::dyn_cast(gOp)) - if (declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::link) - return true; - } + if (auto declareTargetGlobal = + dyn_cast_if_present( + getGlobalOpFromValue(value))) + if (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::link) + return true; return false; } static bool isDeclareTargetTo(Value value) { - Operation *op = value.getDefiningOp(); - if (auto addrCast = llvm::dyn_cast_if_present(op)) - op = addrCast->getOperand(0).getDefiningOp(); - - if (auto addressOfOp = llvm::dyn_cast_if_present(op)) { - auto modOp = addressOfOp->getParentOfType(); - Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName()); - if (auto declareTargetGlobal = - llvm::dyn_cast(gOp)) { - if (declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::to || - declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::enter) - return true; - } - } + if (auto declareTargetGlobal = + dyn_cast_if_present( + getGlobalOpFromValue(value))) + if (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::to || + declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::enter) + return true; return false; } -// Returns the reference pointer generated by the lowering of the declare target -// operation in cases where the link clause is used or the to clause is used in -// USM mode. +// Returns the reference pointer generated by the lowering of the declare +// target operation in cases where the link clause is used or the to clause is +// used in USM mode. static llvm::Value * -getRefPtrIfDeclareTarget(mlir::Value value, +getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - Operation *op = value.getDefiningOp(); - if (auto addrCast = llvm::dyn_cast_if_present(op)) - op = addrCast->getOperand(0).getDefiningOp(); - - // An easier way to do this may just be to keep track of any pointer - // references and their mapping to their respective operation - if (auto addressOfOp = llvm::dyn_cast_if_present(op)) { - if (auto gOp = llvm::dyn_cast_or_null( - addressOfOp->getParentOfType().lookupSymbol( - addressOfOp.getGlobalName()))) { - - if (auto declareTargetGlobal = - llvm::dyn_cast( - gOp.getOperation())) { - - // In this case, we must utilise the reference pointer generated by the - // declare target operation, similar to Clang - if ((declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::link) || - (declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::to && - ompBuilder->Config.hasRequiresUnifiedSharedMemory())) { - llvm::SmallString<64> suffix = - getDeclareTargetRefPtrSuffix(gOp, *ompBuilder); - - if (gOp.getSymName().contains(suffix)) - return moduleTranslation.getLLVMModule()->getNamedValue( - gOp.getSymName()); + if (auto gOp = + dyn_cast_or_null(getGlobalOpFromValue(value))) { + if (auto declareTargetGlobal = + dyn_cast(gOp.getOperation())) { + // In this case, we must utilise the reference pointer generated by + // the declare target operation, similar to Clang + if ((declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::link) || + (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::to && + ompBuilder->Config.hasRequiresUnifiedSharedMemory())) { + llvm::SmallString<64> suffix = + getDeclareTargetRefPtrSuffix(gOp, *ompBuilder); + if (gOp.getSymName().contains(suffix)) return moduleTranslation.getLLVMModule()->getNamedValue( - (gOp.getSymName().str() + suffix.str()).str()); - } + gOp.getSymName()); + + return moduleTranslation.getLLVMModule()->getNamedValue( + (gOp.getSymName().str() + suffix.str()).str()); } } } - return nullptr; } @@ -3857,7 +3840,7 @@ struct MapInfoData : MapInfosTy { } }; -enum class TargetDirective : uint32_t { +enum class TargetDirectiveEnumTy : uint32_t { None = 0, Target = 1, TargetData = 2, @@ -3866,18 +3849,20 @@ enum class TargetDirective : uint32_t { TargetUpdate = 5 }; -static TargetDirective getTargetDirectiveFromOp(Operation *op) { - return llvm::TypeSwitch(op) - .Case([](omp::TargetDataOp) { return TargetDirective::TargetData; }) +static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) { + return llvm::TypeSwitch(op) + .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; }) .Case([](omp::TargetEnterDataOp) { - return TargetDirective::TargetEnterData; + return TargetDirectiveEnumTy::TargetEnterData; }) .Case([&](omp::TargetExitDataOp) { - return TargetDirective::TargetExitData; + return TargetDirectiveEnumTy::TargetExitData; + }) + .Case([&](omp::TargetUpdateOp) { + return TargetDirectiveEnumTy::TargetUpdate; }) - .Case([&](omp::TargetUpdateOp) { return TargetDirective::TargetUpdate; }) - .Case([&](omp::TargetOp) { return TargetDirective::Target; }) - .Default([&](Operation *op) { return TargetDirective::None; }); + .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; }) + .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; }); } } // namespace @@ -3911,7 +3896,7 @@ static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type, // This calculates the size to transfer based on bounds and the underlying // element type, provided bounds have been specified (Fortran // pointers/allocatables/target and arrays that have sections specified fall - // into this as well). + // into this as well) if (!memberClause.getBounds().empty()) { llvm::Value *elementCount = builder.getInt64(1); for (auto bounds : memberClause.getBounds()) { @@ -4235,9 +4220,9 @@ static void sortMapIndices(llvm::SmallVector &indices, }); } -static mlir::omp::MapInfoOp -getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) { - mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); +static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, + bool first) { + ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); // Only 1 member has been mapped, we can return it. if (indexAttr.size() == 1) if (auto mapOp = @@ -4439,7 +4424,7 @@ static void processIndividualMap(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, - TargetDirective targetDirective, + TargetDirectiveEnumTy targetDirective, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE, bool isTargetParam = true, int mapDataParentIdx = -1) { @@ -4453,7 +4438,7 @@ processIndividualMap(llvm::IRBuilderBase &builder, if (isPtrTy && mapData.IsDeclareTarget[mapDataIdx]) mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ; - if (isTargetParam && (targetDirective == TargetDirective::Target && + if (isTargetParam && (targetDirective == TargetDirectiveEnumTy::Target && !mapData.IsDeclareTarget[mapDataIdx])) mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; @@ -4566,7 +4551,7 @@ void processAttachMap(LLVM::ModuleTranslation &moduleTranslation, uint64_t mapDataIndex, bool parentMap, llvm::SmallVectorImpl &immediateMapDataIdxs, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, - TargetDirective targetDirective) { + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); auto parentClause = @@ -4585,7 +4570,7 @@ void processAttachMap(LLVM::ModuleTranslation &moduleTranslation, // We only wish to apply this if this specific map will be the input // parameter to the kernel for the collection of maps that are linked // together. - if (parentMap && (targetDirective == TargetDirective::Target && + if (parentMap && (targetDirective == TargetDirectiveEnumTy::Target && !mapData.IsDeclareTarget[mapDataIndex])) descMapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; // We move attach onto the "binding" attach map, the initial map @@ -4682,7 +4667,7 @@ static void mapParentWithMembers( llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, - TargetDirective targetDirective) { + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); // Map the first segment of the parent. If a user-defined mapper is attached, @@ -4690,7 +4675,7 @@ static void mapParentWithMembers( // base entry so the mapper receives correct copy semantics via its 'type' // parameter. Also keep TARGET_PARAM when required for kernel arguments. llvm::omp::OpenMPOffloadMappingFlags baseFlag = - (targetDirective == TargetDirective::Target && + (targetDirective == TargetDirectiveEnumTy::Target && !mapData.IsDeclareTarget[mapDataIndex]) ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; @@ -4774,7 +4759,7 @@ static void mapParentWithMembers( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); - if (targetDirective == TargetDirective::TargetUpdate || hasMapClose) { + if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) { combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataIndex]); @@ -4853,7 +4838,7 @@ static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, - TargetDirective targetDirective) { + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); @@ -5013,10 +4998,10 @@ createAlteredByCaptureMap(MapInfoData &mapData, static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, - MapInfoData &mapData, TargetDirective targetDirective) { + MapInfoData &mapData, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); - // We wish to modify some of the methods in which arguments are // passed based on their capture type by the target region, this can // involve generating new loads and stores, which changes the @@ -5058,11 +5043,13 @@ static void genMapInfos(llvm::IRBuilderBase &builder, static llvm::Expected emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - llvm::StringRef mapperFuncName); + llvm::StringRef mapperFuncName, + TargetDirectiveEnumTy targetDirective); static llvm::Expected getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { + LLVM::ModuleTranslation &moduleTranslation, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); auto declMapperOp = cast(op); @@ -5074,13 +5061,14 @@ getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, return lookupFunc; return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation, - mapperFuncName); + mapperFuncName, targetDirective); } static llvm::Expected emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - llvm::StringRef mapperFuncName) { + llvm::StringRef mapperFuncName, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); auto declMapperOp = cast(op); @@ -5109,10 +5097,10 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, builder); genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData, - TargetDirective::None); + TargetDirectiveEnumTy::None); - // Drop the mapping that is no longer necessary so that the same region can - // be processed multiple times. + // Drop the mapping that is no longer necessary so that the same region + // can be processed multiple times. moduleTranslation.forgetMapping(declMapperOp.getRegion()); return combinedInfo; }; @@ -5121,7 +5109,7 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, if (!combinedInfo.Mappers[i]) return nullptr; return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::Expected newFn = ompBuilder->emitUserDefinedMapper( @@ -5142,11 +5130,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, SmallVector useDeviceAddrVars; llvm::omp::RuntimeFunction RTLFn; DataLayout DL = DataLayout(op->getParentOfType()); - TargetDirective targetDirective = getTargetDirectiveFromOp(op); + TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true, - /*SeparateBeginEndCalls=*/true); + llvm::OpenMPIRBuilder::TargetDataInfo info( + /*RequiresDevicePointerInfo=*/true, + /*SeparateBeginEndCalls=*/true); bool isTargetDevice = ompBuilder->Config.isTargetDevice(); bool isOffloadEntry = isTargetDevice || !ompBuilder->Config.TargetTriples.empty(); @@ -5375,7 +5364,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, return nullptr; info.HasMapper = true; return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); @@ -5652,10 +5641,9 @@ handleDeclareTargetMapVar(MapInfoData &mapData, moduleTranslation.getOpenMPBuilder() ->Config.hasRequiresUnifiedSharedMemory())) { builder.SetCurrentDebugLocation(insn->getDebugLoc()); - auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(), + substitute = builder.CreateLoad(mapData.BasePointers[i]->getType(), mapData.BasePointers[i]); - load->moveBefore(insn); - substitute = load; + cast(substitute)->moveBefore(insn->getIterator()); } user->replaceUsesOfWith(mapData.OriginalValue[i], substitute); } @@ -5982,8 +5970,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, int32_t minTeamsVal = 1, maxTeamsVal = -1; if (castOrGetParentOfType(capturedOp)) { - // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match - // clang and set min and max to the same value. + // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, + // match clang and set min and max to the same value. if (numTeamsUpper) { if (auto val = extractConstInteger(numTeamsUpper)) minTeamsVal = maxTeamsVal = *val; @@ -6174,9 +6162,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, auto parentFn = opInst.getParentOfType(); auto argIface = cast(opInst); auto &targetRegion = targetOp.getRegion(); - // Holds the private vars that have been mapped along with the block argument - // that corresponds to the MapInfoOp corresponding to the private var in - // question. So, for instance: + // Holds the private vars that have been mapped along with the block + // argument that corresponds to the MapInfoOp corresponding to the private + // var in question. So, for instance: // // %10 = omp.map.info var_ptr(%6#0 : !fir.ref>>, ..) // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1) @@ -6191,7 +6179,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, ArrayRef mapBlockArgs = argIface.getMapBlockArgs(); ArrayRef hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs(); llvm::Function *llvmOutlinedFn = nullptr; - TargetDirective targetDirective = getTargetDirectiveFromOp(&opInst); + TargetDirectiveEnumTy targetDirective = + getTargetDirectiveEnumTyFromOp(&opInst); // TODO: It can also be false if a compile-time constant `false` IF clause is // specified. @@ -6425,7 +6414,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, return nullptr; info.HasMapper = true; return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::Value *ifCond = nullptr; diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir new file mode 100644 index 0000000000000..93a417430ecee --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// This tests the replacement of operations for `declare target to` with the +// generated `declare target to` global variable inside of target op regions when +// lowering to IR for device. Unfortunately, as the host file is not passed as a +// module attribute, we miss out on the metadata and entry info. + +module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { + // CHECK-DAG: @_QMtest_0Ezii = global [11 x float] zeroinitializer + llvm.mlir.global external @_QMtest_0Ezii() {addr_space = 0 : i32, omp.declare_target = #omp.declaretarget} : !llvm.array<11 x f32> { + %0 = llvm.mlir.zero : !llvm.array<11 x f32> + llvm.return %0 : !llvm.array<11 x f32> + } + + // CHECK-LABEL: define weak_odr protected amdgpu_kernel void @{{.*}}(ptr %{{.*}}) {{.*}} { + // CHECK-DAG: omp.target: + // CHECK-DAG: store float 1.000000e+00, ptr @_QMtest_0Ezii, align 4 + // CHECK-DAG: br label %omp.region.cont + llvm.func @_QQmain() { + %0 = llvm.mlir.constant(1 : index) : i64 + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.mlir.constant(11 : index) : i64 + %3 = llvm.mlir.addressof @_QMtest_0Ezii : !llvm.ptr + %5 = omp.map.info var_ptr(%3 : !llvm.ptr, !llvm.array<11 x f32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr + omp.target map_entries(%5 -> %arg0 : !llvm.ptr) { + %6 = llvm.mlir.constant(1.0 : f32) : f32 + %7 = llvm.mlir.constant(0 : i64) : i64 + %8 = llvm.getelementptr %arg0[%7] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + llvm.store %6, %8 : f32, !llvm.ptr + omp.terminator + } + llvm.return + } +} diff --git a/offload/test/offloading/fortran/declare-target-to-allocatable-vars-in-target-with-update.f90 b/offload/test/offloading/fortran/declare-target-to-allocatable-vars-in-target-with-update.f90 new file mode 100644 index 0000000000000..727a08b093400 --- /dev/null +++ b/offload/test/offloading/fortran/declare-target-to-allocatable-vars-in-target-with-update.f90 @@ -0,0 +1,41 @@ +! Test that checks an allocatable array can be marked implicit +! `declare target to` and functions without issue. +! REQUIRES: flang, amdgpu + +! RUN: %libomptarget-compile-fortran-run-and-check-generic +module test + implicit none + integer, allocatable, dimension(:) :: alloca_arr + !$omp declare target(alloca_arr) +end module test + +program main + use test + implicit none + integer :: cycle, i + + allocate(alloca_arr(10)) + + do i = 1, 10 + alloca_arr(i) = 0 + end do + + !$omp target data map(to:alloca_arr) + do cycle = 1, 2 + !$omp target + do i = 1, 10 + alloca_arr(i) = alloca_arr(i) + i + end do + !$omp end target + + ! NOTE: Technically doesn't affect the results, but there is a + ! regression case that'll cause a runtime crash if this is + ! invoked more than once, so this checks for that. + !$omp target update from(alloca_arr) + end do + !$omp end target data + + print *, alloca_arr +end program + +! CHECK: 2 4 6 8 10 12 14 16 18 20 diff --git a/offload/test/offloading/fortran/declare-target-to-vars-target-region-and-update.f90 b/offload/test/offloading/fortran/declare-target-to-vars-target-region-and-update.f90 new file mode 100644 index 0000000000000..16433af2c922c --- /dev/null +++ b/offload/test/offloading/fortran/declare-target-to-vars-target-region-and-update.f90 @@ -0,0 +1,40 @@ +! Test the implicit `declare target to` interaction with `target update from` +! REQUIRES: flang, amdgpu + +! RUN: %libomptarget-compile-fortran-run-and-check-generic +module test + implicit none + integer :: array(10) + !$omp declare target(array) +end module test + +PROGRAM main + use test + implicit none + integer :: i + + do i = 1, 10 + array(i) = 0 + end do + + !$omp target + do i = 1, 10 + array(i) = i + end do + !$omp end target + + !$omp target + do i = 1, 10 + array(i) = array(i) + i + end do + !$omp end target + + print *, array + + !$omp target update from(array) + + print *, array +END PROGRAM + +! CHECK: 0 0 0 0 0 0 0 0 0 0 +! CHECK: 2 4 6 8 10 12 14 16 18 20 diff --git a/offload/test/offloading/fortran/declare-target-to-zero-index-allocatable-target-map.f90 b/offload/test/offloading/fortran/declare-target-to-zero-index-allocatable-target-map.f90 new file mode 100644 index 0000000000000..0d650f6b44009 --- /dev/null +++ b/offload/test/offloading/fortran/declare-target-to-zero-index-allocatable-target-map.f90 @@ -0,0 +1,30 @@ +! Test `declare target to` interaction with an allocatable with a non-default +! range +! REQUIRES: flang, amdgpu + +! RUN: %libomptarget-compile-fortran-run-and-check-generic +module test_0 + real(4), allocatable :: zero_off(:) + !$omp declare target(zero_off) +end module test_0 + +program main + use test_0 + implicit none + + allocate(zero_off(0:10)) + + zero_off(0) = 30.0 + zero_off(1) = 40.0 + zero_off(10) = 25.0 + + !$omp target map(tofrom: zero_off) + zero_off(0) = zero_off(1) + !$omp end target + + print *, zero_off(0) + print *, zero_off(1) +end program + +! CHECK: 40. +! CHECK: 40.