diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td index 04899f1852b18..7e2b1f56ac6f7 100644 --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -3950,10 +3950,26 @@ let Visibility = [ClangOption, CC1Option, FC1Option, FlangOption] in { let Group = f_Group in { def fopenmp_target_debug_EQ : Joined<["-"], "fopenmp-target-debug=">; -def fopenmp_assume_teams_oversubscription : Flag<["-"], "fopenmp-assume-teams-oversubscription">; -def fopenmp_assume_threads_oversubscription : Flag<["-"], "fopenmp-assume-threads-oversubscription">; -def fno_openmp_assume_teams_oversubscription : Flag<["-"], "fno-openmp-assume-teams-oversubscription">; -def fno_openmp_assume_threads_oversubscription : Flag<["-"], "fno-openmp-assume-threads-oversubscription">; +def fopenmp_assume_teams_oversubscription : Flag<["-"], "fopenmp-assume-teams-oversubscription">, + HelpText<"Allow the optimizer to discreetly increase the number of " + "teams. May cause environment variables that set " + "the number of teams to be ignored. The combination of " + "-fopenmp-assume-teams-oversubscription " + "and -fopenmp-assume-threads-oversubscription " + "may allow the conversion of loops into sequential code by " + "ensuring that each team/thread executes at most one iteration.">; +def fopenmp_assume_threads_oversubscription : Flag<["-"], "fopenmp-assume-threads-oversubscription">, + HelpText<"Allow the optimizer to discreetly increase the number of " + "threads. May cause environment variables that set " + "the number of threads to be ignored. The combination of " + "-fopenmp-assume-teams-oversubscription " + "and -fopenmp-assume-threads-oversubscription " + "may allow the conversion of loops into sequential code by " + "ensuring that each team/thread executes at most one iteration.">; +def fno_openmp_assume_teams_oversubscription : Flag<["-"], "fno-openmp-assume-teams-oversubscription">, + HelpText<"Do not assume teams oversubscription.">; +def fno_openmp_assume_threads_oversubscription : Flag<["-"], "fno-openmp-assume-threads-oversubscription">, + HelpText<"Do not assume threads oversubscription.">; def fopenmp_assume_no_thread_state : Flag<["-"], "fopenmp-assume-no-thread-state">, HelpText<"Assert no thread in a parallel region modifies an ICV">, MarshallingInfoFlag>; diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index f43ef932e965a..fbab7f3bb402c 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1085,11 +1085,13 @@ class OpenMPIRBuilder { /// preheader of the loop. /// \param LoopType Information about type of loop worksharing. /// It corresponds to type of loop workshare OpenMP pragma. + /// \param NoLoop If true, no-loop code is generated. /// /// \returns Point where to insert code after the workshare construct. InsertPointTy applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP, - omp::WorksharingLoopType LoopType); + omp::WorksharingLoopType LoopType, + bool NoLoop); /// Modifies the canonical loop to be a statically-scheduled workshare loop. /// @@ -1209,6 +1211,7 @@ class OpenMPIRBuilder { /// present. /// \param LoopType Information about type of loop worksharing. /// It corresponds to type of loop workshare OpenMP pragma. + /// \param NoLoop If true, no-loop code is generated. /// /// \returns Point where to insert code after the workshare construct. 
LLVM_ABI InsertPointOrErrorTy applyWorkshareLoop( @@ -1219,7 +1222,8 @@ class OpenMPIRBuilder { bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false, bool HasOrderedClause = false, omp::WorksharingLoopType LoopType = - omp::WorksharingLoopType::ForStaticLoop); + omp::WorksharingLoopType::ForStaticLoop, + bool NoLoop = false); /// Tile a loop nest. /// diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 9b67465faab0b..94c51531ee251 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -4979,7 +4979,7 @@ static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType, BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg, Value *TripCount, - Function &LoopBodyFn) { + Function &LoopBodyFn, bool NoLoop) { Type *TripCountTy = TripCount->getType(); Module &M = OMPBuilder->M; IRBuilder<> &Builder = OMPBuilder->Builder; @@ -5007,8 +5007,10 @@ static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder, RealArgs.push_back(ConstantInt::get(TripCountTy, 0)); if (LoopType == WorksharingLoopType::DistributeForStaticLoop) { RealArgs.push_back(ConstantInt::get(TripCountTy, 0)); + RealArgs.push_back(ConstantInt::get(Builder.getInt8Ty(), NoLoop)); + } else { + RealArgs.push_back(ConstantInt::get(Builder.getInt8Ty(), 0)); } - RealArgs.push_back(ConstantInt::get(Builder.getInt8Ty(), 0)); Builder.CreateCall(RTLFn, RealArgs); } @@ -5016,7 +5018,7 @@ static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder, static void workshareLoopTargetCallback( OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident, Function &OutlinedFn, const SmallVector &ToBeDeleted, - WorksharingLoopType LoopType) { + WorksharingLoopType LoopType, bool NoLoop) { IRBuilder<> &Builder = OMPIRBuilder->Builder; BasicBlock *Preheader = CLI->getPreheader(); Value *TripCount = CLI->getTripCount(); @@ -5063,17 +5065,16 @@ static void workshareLoopTargetCallback( OutlinedFnCallInstruction->eraseFromParent(); createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident, - LoopBodyArg, TripCount, OutlinedFn); + LoopBodyArg, TripCount, OutlinedFn, NoLoop); for (auto &ToBeDeletedItem : ToBeDeleted) ToBeDeletedItem->eraseFromParent(); CLI->invalidate(); } -OpenMPIRBuilder::InsertPointTy -OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI, - InsertPointTy AllocaIP, - WorksharingLoopType LoopType) { +OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget( + DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP, + WorksharingLoopType LoopType, bool NoLoop) { uint32_t SrcLocStrSize; Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize); Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); @@ -5156,7 +5157,7 @@ OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI, OI.PostOutlineCB = [=, ToBeDeletedVec = std::move(ToBeDeleted)](Function &OutlinedFn) { workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec, - LoopType); + LoopType, NoLoop); }; addOutlineInfo(std::move(OI)); return CLI->getAfterIP(); @@ -5167,9 +5168,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop( bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier, bool HasNonmonotonicModifier, bool HasOrderedClause, - WorksharingLoopType LoopType) { + WorksharingLoopType LoopType, bool NoLoop) { if 
(Config.isTargetDevice()) - return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType); + return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType, NoLoop); OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType( SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier, HasNonmonotonicModifier, HasOrderedClause); diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll index bf140c7fa216a..b1fe7b1b2b7ee 100644 --- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll +++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll @@ -574,5 +574,164 @@ exit: ret void } +define void @test_ptr_aligned_by_2_and_4_via_assumption(ptr %start, ptr %end) { +; CHECK-LABEL: 'test_ptr_aligned_by_2_and_4_via_assumption' +; CHECK-NEXT: Classifying expressions for: @test_ptr_aligned_by_2_and_4_via_assumption +; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ] +; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4 +; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: Determining loop execution counts for: @test_ptr_aligned_by_2_and_4_via_assumption +; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. +; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. +; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903 +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; +entry: + call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 2) ] + call void @llvm.assume(i1 true) [ "align"(ptr %end, i64 4) ] + br label %loop + +loop: + %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ] + store ptr %iv, ptr %iv + %iv.next = getelementptr i8, ptr %iv, i64 4 + %ec = icmp ne ptr %iv.next, %end + br i1 %ec, label %loop, label %exit + +exit: + ret void +} + +define void @test_ptrs_aligned_by_4_via_assumption(ptr %start, ptr %end) { +; CHECK-LABEL: 'test_ptrs_aligned_by_4_via_assumption' +; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_4_via_assumption +; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ] +; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4 +; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: Determining loop execution 
counts for: @test_ptrs_aligned_by_4_via_assumption +; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. +; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. +; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903 +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; +entry: + call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 4) ] + call void @llvm.assume(i1 true) [ "align"(ptr %end, i64 4) ] + br label %loop + +loop: + %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ] + store ptr %iv, ptr %iv + %iv.next = getelementptr i8, ptr %iv, i64 4 + %ec = icmp ne ptr %iv.next, %end + br i1 %ec, label %loop, label %exit + +exit: + ret void +} + +define void @test_ptrs_aligned_by_8_via_assumption(ptr %start, ptr %end) { +; CHECK-LABEL: 'test_ptrs_aligned_by_8_via_assumption' +; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_8_via_assumption +; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ] +; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4 +; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_8_via_assumption +; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. +; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. 
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903 +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; +entry: + call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 8) ] + call void @llvm.assume(i1 true) [ "align"(ptr %end, i64 8) ] + br label %loop + +loop: + %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ] + store ptr %iv, ptr %iv + %iv.next = getelementptr i8, ptr %iv, i64 4 + %ec = icmp ne ptr %iv.next, %end + br i1 %ec, label %loop, label %exit + +exit: + ret void +} + +declare i1 @cond() + +define void @test_ptr_aligned_by_4_via_assumption_multiple_loop_predecessors(ptr %start, ptr %end) { +; CHECK-LABEL: 'test_ptr_aligned_by_4_via_assumption_multiple_loop_predecessors' +; CHECK-NEXT: Classifying expressions for: @test_ptr_aligned_by_4_via_assumption_multiple_loop_predecessors +; CHECK-NEXT: %c = call i1 @cond() +; CHECK-NEXT: --> %c U: full-set S: full-set +; CHECK-NEXT: %iv = phi ptr [ %start, %then ], [ %start, %else ], [ %iv.next, %loop ] +; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4 +; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: Determining loop execution counts for: @test_ptr_aligned_by_4_via_assumption_multiple_loop_predecessors +; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. +; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. 
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903 +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Predicates: +; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; +entry: + call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 2) ] + call void @llvm.assume(i1 true) [ "align"(ptr %end, i64 4) ] + %c = call i1 @cond() + br i1 %c, label %then, label %else + +then: + br label %loop + +else: + br label %loop + +loop: + %iv = phi ptr [ %start, %then] , [ %start, %else ], [ %iv.next, %loop ] + store ptr %iv, ptr %iv + %iv.next = getelementptr i8, ptr %iv, i64 4 + %ec = icmp ne ptr %iv.next, %end + br i1 %ec, label %loop, label %exit + +exit: + ret void +} + declare void @llvm.assume(i1) declare void @llvm.experimental.guard(i1, ...) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index 9dbe6897a3304..f693a0737e0fc 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -230,14 +230,24 @@ def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">; def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>; def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>; def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>; +def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 3>; def TargetRegionFlags : OpenMP_BitEnumAttr< "TargetRegionFlags", - "target region property flags", [ + "These flags describe properties of the target kernel. " + "TargetRegionFlagsGeneric - denotes a generic kernel. " + "TargetRegionFlagsSpmd - denotes an SPMD kernel. " + "TargetRegionFlagsNoLoop - denotes a kernel where " + "num_teams * num_threads >= loop_trip_count. It allows the conversion " + "of loops into sequential code by ensuring that each team/thread " + "executes at most one iteration. " + "TargetRegionFlagsTripCount - indicates whether the loop trip count " + "should be calculated.", [ TargetRegionFlagsNone, TargetRegionFlagsGeneric, TargetRegionFlagsSpmd, - TargetRegionFlagsTripCount + TargetRegionFlagsTripCount, + TargetRegionFlagsNoLoop ]>; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 3d70e28ed23ab..f01ad05a778ec 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2111,6 +2111,31 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { }); } +/// Check if we can promote an SPMD kernel to a no-loop kernel. +static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, + WsloopOp *wsLoopOp) { + // A num_teams clause can break the no-loop teams/threads assumption. 
+ if (teamsOp.getNumTeamsUpper()) + return false; + + // Reduction kernels are slower in no-loop mode. + if (teamsOp.getNumReductionVars()) + return false; + if (wsLoopOp->getNumReductionVars()) + return false; + + // Check if the user allows the promotion of kernels to no-loop mode. + OffloadModuleInterface offloadMod = + capturedOp->getParentOfType<OffloadModuleInterface>(); + if (!offloadMod) + return false; + auto ompFlags = offloadMod.getFlags(); + if (!ompFlags) + return false; + return ompFlags.getAssumeTeamsOversubscription() && + ompFlags.getAssumeThreadsOversubscription(); +} + TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { // A non-null captured op is only valid if it resides inside of a TargetOp // and is the result of calling getInnermostCapturedOmpOp() on it. @@ -2139,7 +2164,8 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { // Detect target-teams-distribute-parallel-wsloop[-simd]. if (numWrappers == 2) { - if (!isa<WsloopOp>(innermostWrapper)) + WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper); + if (!wsloopOp) return TargetRegionFlags::generic; innermostWrapper = std::next(innermostWrapper); @@ -2150,12 +2176,17 @@ if (!isa_and_present<ParallelOp>(parallelOp)) return TargetRegionFlags::generic; - Operation *teamsOp = parallelOp->getParentOp(); - if (!isa_and_present<TeamsOp>(teamsOp)) + TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp()); + if (!teamsOp) return TargetRegionFlags::generic; - if (teamsOp->getParentOp() == targetOp.getOperation()) - return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + if (teamsOp->getParentOp() == targetOp.getOperation()) { + TargetRegionFlags result = + TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp)) + result = result | TargetRegionFlags::no_loop; + return result; + } } // Detect target-teams-distribute[-simd] and target-teams-loop. else if (isa(innermostWrapper)) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 4921a1990b6e8..bb1f3d0d6c4ad 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2591,13 +2591,34 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, } builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); + + // Check if we can generate a no-loop kernel. + bool noLoopMode = false; + omp::TargetOp targetOp = wsloopOp->getParentOfType<omp::TargetOp>(); + if (targetOp) { + Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp(); + // We need this check because, without it, noLoopMode would be set to true + // for every omp.wsloop nested inside a no-loop SPMD target region, even if + // that loop is not the top-level SPMD one. 
+ if (loopOp == targetCapturedOp) { + omp::TargetRegionFlags kernelFlags = + targetOp.getKernelExecFlags(targetCapturedOp); + if (omp::bitEnumContainsAll(kernelFlags, + omp::TargetRegionFlags::spmd | + omp::TargetRegionFlags::no_loop) && + !omp::bitEnumContainsAny(kernelFlags, + omp::TargetRegionFlags::generic)) + noLoopMode = true; + } + } + llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = ompBuilder->applyWorkshareLoop( ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier, convertToScheduleKind(schedule), chunk, isSimd, scheduleMod == omp::ScheduleModifier::monotonic, scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, - workshareLoopType); + workshareLoopType, noLoopMode); if (failed(handleError(wsloopIP, opInst))) return failure(); @@ -5425,6 +5446,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC : llvm::omp::OMP_TGT_EXEC_MODE_SPMD; + if (omp::bitEnumContainsAll(kernelFlags, + omp::TargetRegionFlags::spmd | + omp::TargetRegionFlags::no_loop) && + !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)) + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP; + attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; diff --git a/offload/test/offloading/fortran/target-no-loop.f90 b/offload/test/offloading/fortran/target-no-loop.f90 new file mode 100644 index 0000000000000..8e40e20e73e70 --- /dev/null +++ b/offload/test/offloading/fortran/target-no-loop.f90 @@ -0,0 +1,96 @@ +! REQUIRES: flang + +! RUN: %libomptarget-compile-fortran-generic -O3 -fopenmp-assume-threads-oversubscription -fopenmp-assume-teams-oversubscription +! RUN: env LIBOMPTARGET_INFO=16 OMP_NUM_TEAMS=16 OMP_TEAMS_THREAD_LIMIT=16 %libomptarget-run-generic 2>&1 | %fcheck-generic +function check_errors(array) result (errors) + integer, intent(in) :: array(1024) + integer :: errors + integer :: i + errors = 0 + do i = 1, 1024 + if ( array( i) .ne. (i) ) then + errors = errors + 1 + end if + end do +end function + +program main + use omp_lib + implicit none + integer :: i,j,red + integer :: array(1024), errors = 0 + array = 1 + + ! No-loop kernel + !$omp target teams distribute parallel do + do i = 1, 1024 + array(i) = i + end do + errors = errors + check_errors(array) + + ! SPMD kernel (num_teams clause blocks promotion to no-loop) + array = 1 + !$omp target teams distribute parallel do num_teams(3) + do i = 1, 1024 + array(i) = i + end do + + errors = errors + check_errors(array) + + ! No-loop kernel + array = 1 + !$omp target teams distribute parallel do num_threads(64) + do i = 1, 1024 + array(i) = i + end do + + errors = errors + check_errors(array) + + ! SPMD kernel + array = 1 + !$omp target parallel do + do i = 1, 1024 + array(i) = i + end do + + errors = errors + check_errors(array) + + ! Generic kernel + array = 1 + !$omp target teams distribute + do i = 1, 1024 + array(i) = i + end do + + errors = errors + check_errors(array) + + ! SPMD kernel (reduction clause blocks promotion to no-loop) + array = 1 + red =0 + !$omp target teams distribute parallel do reduction(+:red) + do i = 1, 1024 + red = red + array(i) + end do + + if (red .ne. 1024) then + errors = errors + 1 + end if + + print *,"number of errors: ", errors + +end program main + +! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}} SPMD-No-Loop mode +! CHECK: info: #Args: 3 Teams x Thrds: 64x 16 +! 
CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}} SPMD mode +! CHECK: info: #Args: 3 Teams x Thrds: 3x 16 {{.*}} +! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}} SPMD-No-Loop mode +! CHECK: info: #Args: 3 Teams x Thrds: 64x 16 {{.*}} +! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}} SPMD mode +! CHECK: info: #Args: 3 Teams x Thrds: 1x 16 +! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}} Generic mode +! CHECK: info: #Args: 3 Teams x Thrds: 16x 16 {{.*}} +! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}} SPMD mode +! CHECK: info: #Args: 4 Teams x Thrds: 16x 16 {{.*}} +! CHECK: number of errors: 0 + diff --git a/openmp/device/src/Workshare.cpp b/openmp/device/src/Workshare.cpp index 59a2cc3f27aca..653104ce883d1 100644 --- a/openmp/device/src/Workshare.cpp +++ b/openmp/device/src/Workshare.cpp @@ -800,10 +800,6 @@ template class StaticLoopChunker { // If we know we have more threads than iterations we can indicate that to // avoid an outer loop. - if (config::getAssumeThreadsOversubscription()) { - OneIterationPerThread = true; - } - if (OneIterationPerThread) ASSERT(NumThreads >= NumIters, "Broken assumption"); @@ -851,10 +847,6 @@ template class StaticLoopChunker { // If we know we have more blocks than iterations we can indicate that to // avoid an outer loop. - if (config::getAssumeTeamsOversubscription()) { - OneIterationPerThread = true; - } - if (OneIterationPerThread) ASSERT(NumBlocks >= NumIters, "Broken assumption"); @@ -914,11 +906,6 @@ template class StaticLoopChunker { // If we know we have more threads (across all blocks) than iterations we // can indicate that to avoid an outer loop. - if (config::getAssumeTeamsOversubscription() & - config::getAssumeThreadsOversubscription()) { - OneIterationPerThread = true; - } - if (OneIterationPerThread) ASSERT(NumBlocks * NumThreads >= NumIters, "Broken assumption");