From f33594a5782be73d1a456358e72ee34d9ac034c6 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Thu, 3 Oct 2024 11:26:06 +0100 Subject: [PATCH] [MLIR][OpenMP] Introduce host_eval clause to omp.target This patch defines a map-like clause named `host_eval` used to capture host values for use inside of target regions on restricted cases: - As `num_teams` or `thread_limit` of a nested `omp.target` operation. - As `num_threads` of a nested `omp.parallel` operation or as bounds or steps of a nested `omp.loop_nest`, if it is a target SPMD kernel. This replaces the following `omp.target` arguments: `trip_count`, `num_threads`, `num_teams_lower`, `num_teams_upper` and `teams_thread_limit`. --- mlir/docs/Dialects/OpenMPDialect/_index.md | 58 +++++++++++- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 38 ++++++++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 31 ++----- .../Dialect/OpenMP/OpenMPOpsInterfaces.td | 27 +++++- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 93 ++++++++++++------- mlir/test/Dialect/OpenMP/invalid.mlir | 71 +++++++++++++- mlir/test/Dialect/OpenMP/ops.mlir | 38 +++++++- 7 files changed, 296 insertions(+), 60 deletions(-) diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md index 3d28fe7819129..e0dd3f598e84b 100644 --- a/mlir/docs/Dialects/OpenMPDialect/_index.md +++ b/mlir/docs/Dialects/OpenMPDialect/_index.md @@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the introduction of private copies of the same underlying variable defined outside the MLIR operation the clause is attached to. Currently, clauses with this property can be classified into three main categories: - - Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`. + - Map-like clauses: `host_eval`, `map`, `use_device_addr` and +`use_device_ptr`. - Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`. - Privatization clauses: `private`. @@ -522,3 +523,58 @@ omp.parallel ... { omp.terminator } {omp.composite} ``` + +## Host-Evaluated Clauses in Target Regions + +The `omp.target` operation, which represents the OpenMP `target` construct, is +marked with the `IsolatedFromAbove` trait. This means that, inside of its +region, no MLIR values defined outside of the op itself can be used. This is +consistent with the OpenMP specification of the `target` construct, which +mandates that all host device values used inside of the `target` region must +either be privatized (data-sharing) or mapped (data-mapping). + +Normally, clauses applied to a construct are evaluated before entering that +construct. Further, in some cases, the OpenMP specification stipulates that +clauses be evaluated _on the host device_ on entry to a parent `target` +construct. In particular, the `num_teams` and `thread_limit` clauses of the +`teams` construct must be evaluated on the host device if it's nested inside or +combined with a `target` construct. + +Additionally, the runtime library targeted by the MLIR to LLVM IR translation of +the OpenMP dialect supports the optimized launch of SPMD kernels (i.e. +`target teams distribute parallel {do,for}` in OpenMP), which requires +specifying in advance what the total trip count of the loop is. Consequently, it +is also beneficial to evaluate the trip count on the host device prior to the +kernel launch. + +These host-evaluated values in MLIR would need to be placed outside of the +`omp.target` region and also attached to the corresponding nested operations, +which is not possible because of the `IsolatedFromAbove` trait. The solution +implemented to address this problem has been to introduce the `host_eval` +argument to the `omp.target` operation. It works similarly to a `map` clause, +but its only intended use is to forward host-evaluated values to their +corresponding operation inside of the region. Any uses outside of the previously +described result in a verifier error. + +```mlir +// Initialize %0, %1, %2, %3... +omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) { + omp.teams num_teams(to %nt : i32) { + omp.parallel { + omp.distribute { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + // ... + omp.yield + } + omp.terminator + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator +} +``` diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 886554f66afff..ddcde74a363d4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -419,6 +419,44 @@ class OpenMP_HintClauseSkip< def OpenMP_HintClause : OpenMP_HintClauseSkip<>; +//===----------------------------------------------------------------------===// +// Not in the spec: Clause-like structure to hold host-evaluated values. +//===----------------------------------------------------------------------===// + +class OpenMP_HostEvalClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let traits = [ + BlockArgOpenMPOpInterface + ]; + + let arguments = (ins + Variadic:$host_eval_vars + ); + + let extraClassDeclaration = [{ + unsigned numHostEvalBlockArgs() { + return getHostEvalVars().size(); + } + }]; + + let description = [{ + The optional `host_eval_vars` holds values defined outside of the region of + the `IsolatedFromAbove` operation for which a corresponding entry block + argument is defined. The only legal uses for these captured values are the + following: + - `num_teams` or `thread_limit` clause of an immediately nested + `omp.teams` operation. + - If the operation is the top-level `omp.target` of a target SPMD kernel: + - `num_threads` clause of the nested `omp.parallel` operation. + - Bounds and steps of the nested `omp.loop_nest` operation. + }]; +} + +def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [3.4] `if` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 05a02254b1027..b8dd8cbdd2a79 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1116,20 +1116,16 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [ // 2.14.5 target construct //===----------------------------------------------------------------------===// -// TODO: Remove num_threads, teams_thread_limit and trip_count and implement the -// passthrough approach described here: -// https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106. def TargetOp : OpenMP_Op<"target", traits = [ AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove, OutlineableOpenMPOpInterface ], clauses = [ // TODO: Complete clause list (defaultmap, uses_allocators). OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause, - OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause, - OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip, - OpenMP_NowaitClause, OpenMP_NumTeamsClauseSkip, - OpenMP_NumThreadsClauseSkip, OpenMP_PrivateClause, - OpenMP_ThreadLimitClause + OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause, OpenMP_IfClause, + OpenMP_InReductionClause, OpenMP_IsDevicePtrClause, + OpenMP_MapClauseSkip, OpenMP_NowaitClause, + OpenMP_PrivateClause, OpenMP_ThreadLimitClause ], singleRegion = true> { let summary = "target construct"; let description = [{ @@ -1156,10 +1152,6 @@ def TargetOp : OpenMP_Op<"target", traits = [ an `omp.parallel`. }] # clausesDescription; - let arguments = !con(clausesArgs, - (ins Optional:$trip_count, - Optional:$teams_thread_limit)); - let builders = [ OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)> ]; @@ -1184,15 +1176,12 @@ def TargetOp : OpenMP_Op<"target", traits = [ bool isTargetSPMDLoop(); }] # clausesExtraClassDeclaration; - let assemblyFormat = clausesReqAssemblyFormat # - " oilist(" # clausesOptAssemblyFormat # [{ - | `trip_count` `(` $trip_count `:` type($trip_count) `)` - | `teams_thread_limit` `(` $teams_thread_limit `:` type($teams_thread_limit) `)` - }] # ")" # [{ - custom( - $region, $in_reduction_vars, type($in_reduction_vars), - $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars), - $private_vars, type($private_vars), $private_syms) attr-dict + let assemblyFormat = clausesAssemblyFormat # [{ + custom( + $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars, + type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms, + $map_vars, type($map_vars), $private_vars, type($private_vars), + $private_syms) attr-dict }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index 8b72689dc3fd8..b119d097780ff 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { let methods = [ // Default-implemented methods to be overriden by the corresponding clauses. + InterfaceMethod<"Get number of block arguments defined by `host_eval`.", + "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{ + return 0; + }]>, InterfaceMethod<"Get number of block arguments defined by `in_reduction`.", "unsigned", "numInReductionBlockArgs", (ins), [{}], [{ return 0; @@ -55,9 +59,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { }]>, // Unified access methods for clause-associated entry block arguments. + InterfaceMethod<"Get start index of block arguments defined by `host_eval`.", + "unsigned", "getHostEvalBlockArgsStart", (ins), [{ + return 0; + }]>, InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.", "unsigned", "getInReductionBlockArgsStart", (ins), [{ - return 0; + auto iface = ::llvm::cast(*$_op); + return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs(); }]>, InterfaceMethod<"Get start index of block arguments defined by `map`.", "unsigned", "getMapBlockArgsStart", (ins), [{ @@ -91,6 +100,13 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs(); }]>, + InterfaceMethod<"Get block arguments defined by `host_eval`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getHostEvalBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs()); + }]>, InterfaceMethod<"Get block arguments defined by `in_reduction`.", "::llvm::MutableArrayRef<::mlir::BlockArgument>", "getInReductionBlockArgs", (ins), [{ @@ -147,10 +163,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { let verify = [{ auto iface = ::llvm::cast($_op); - unsigned expectedArgs = iface.numInReductionBlockArgs() + - iface.numMapBlockArgs() + iface.numPrivateBlockArgs() + - iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() + - iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs(); + unsigned expectedArgs = iface.numHostEvalBlockArgs() + + iface.numInReductionBlockArgs() + iface.numMapBlockArgs() + + iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() + + iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() + + iface.numUseDevicePtrBlockArgs(); if ($_op->getRegion(0).getNumArguments() < expectedArgs) return $_op->emitOpError() << "expected at least " << expectedArgs << " entry block argument(s)"; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index bd28c44ce1e2d..74205238fbbe5 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -501,6 +501,7 @@ struct ReductionParseArgs { : vars(vars), types(types), byref(byref), syms(syms) {} }; struct AllRegionParseArgs { + std::optional hostEvalArgs; std::optional inReductionArgs; std::optional mapArgs; std::optional privateArgs; @@ -627,6 +628,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, AllRegionParseArgs args) { llvm::SmallVector entryBlockArgs; + if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval", + args.hostEvalArgs))) + return parser.emitError(parser.getCurrentLocation()) + << "invalid `host_eval` format"; + if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction", args.inReductionArgs))) return parser.emitError(parser.getCurrentLocation()) @@ -665,8 +671,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, return parser.parseRegion(region, entryBlockArgs); } -static ParseResult parseInReductionMapPrivateRegion( +static ParseResult parseHostEvalInReductionMapPrivateRegion( OpAsmParser &parser, Region ®ion, + SmallVectorImpl &hostEvalVars, + SmallVectorImpl &hostEvalTypes, SmallVectorImpl &inReductionVars, SmallVectorImpl &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, @@ -675,6 +683,7 @@ static ParseResult parseInReductionMapPrivateRegion( llvm::SmallVectorImpl &privateVars, llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms) { AllRegionParseArgs args; + args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes); args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.mapArgs.emplace(mapVars, mapTypes); @@ -788,6 +797,7 @@ struct ReductionPrintArgs { : vars(vars), types(types), byref(byref), syms(syms) {} }; struct AllRegionPrintArgs { + std::optional hostEvalArgs; std::optional inReductionArgs; std::optional mapArgs; std::optional privateArgs; @@ -866,6 +876,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, auto iface = llvm::cast(op); MLIRContext *ctx = op->getContext(); + printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(), + args.hostEvalArgs); printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(), args.inReductionArgs); printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(), @@ -886,12 +898,14 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, p.printRegion(region, /*printEntryBlockArgs=*/false); } -static void printInReductionMapPrivateRegion( - OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, +static void printHostEvalInReductionMapPrivateRegion( + OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hostEvalVars, + TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) { AllRegionPrintArgs args; + args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes); args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.mapArgs.emplace(mapVars, mapTypes); @@ -969,6 +983,7 @@ static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes); printBlockArgRegion(p, op, region, args); } + /// Verifies Reduction Clause static LogicalResult verifyReductionVarList(Operation *op, std::optional reductionSyms, @@ -1654,14 +1669,12 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, // inReductionByref, inReductionSyms. TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars, - clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr, + clauses.device, clauses.hasDeviceAddrVars, + clauses.hostEvalVars, clauses.ifExpr, /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, - clauses.mapVars, clauses.nowait, /*num_teams_lower=*/nullptr, - /*num_teams_upper=*/nullptr, /*num_threads_var=*/nullptr, - clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.threadLimit, - /*trip_count=*/nullptr, /*teams_thread_limit=*/nullptr); + clauses.mapVars, clauses.nowait, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit); } /// Only allow OpenMP terminators and non-OpenMP ops that have known memory @@ -1710,18 +1723,44 @@ LogicalResult TargetOp::verify() { if (std::distance(teamsOps.begin(), teamsOps.end()) > 1) return emitError("target containing multiple teams constructs"); - if (!isTargetSPMDLoop() && getTripCount()) - return emitError("trip_count set on non-SPMD target region"); + // Check that host_eval values are only used in legal ways. + bool isTargetSPMD = isTargetSPMDLoop(); + for (Value hostEvalArg : + cast(getOperation()).getHostEvalBlockArgs()) { + for (Operation *user : hostEvalArg.getUsers()) { + if (auto teamsOp = dyn_cast(user)) { + if (llvm::is_contained({teamsOp.getNumTeamsLower(), + teamsOp.getNumTeamsUpper(), + teamsOp.getThreadLimit()}, + hostEvalArg)) + continue; + + return emitOpError() << "host_eval argument only legal as 'num_teams' " + "and 'thread_limit' in 'omp.teams'"; + } + if (auto parallelOp = dyn_cast(user)) { + if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads()) + continue; - if (teamsOps.empty()) { - if (getNumTeamsLower() || getNumTeamsUpper() || getTeamsThreadLimit()) - return emitError( - "num_teams and teams_thread_limit arguments only allowed if there is " - "an omp.teams child operation"); - } else { - if (failed(verifyNumTeamsClause(*this, getNumTeamsLower(), - getNumTeamsUpper()))) - return failure(); + return emitOpError() + << "host_eval argument only legal as 'num_threads' in " + "'omp.parallel' when representing target SPMD"; + } + if (auto loopNestOp = dyn_cast(user)) { + if (isTargetSPMD && + (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) || + llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) || + llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) + continue; + + return emitOpError() + << "host_eval argument only legal as loop bounds and steps in " + "'omp.loop_nest' when representing target SPMD"; + } + + return emitOpError() << "host_eval argument illegal use in '" + << user->getName() << "' operation"; + } } LogicalResult verifyDependVars = @@ -1954,17 +1993,9 @@ LogicalResult TeamsOp::verify() { return emitError("expected to be nested inside of omp.target or not nested " "in any OpenMP dialect operations"); - auto offloadModOp = - llvm::cast(*(*this)->getParentOfType()); - if (targetOp && !offloadModOp.getIsTargetDevice()) { - if (getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit()) - return emitError("num_teams and thread_limit arguments expected to be " - "attached to parent omp.target operation"); - } else { - if (failed(verifyNumTeamsClause(*this, getNumTeamsLower(), - getNumTeamsUpper()))) - return failure(); - } + if (failed( + verifyNumTeamsClause(*this, getNumTeamsLower(), getNumTeamsUpper()))) + return failure(); // Check for allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 389ca4769ab21..c7c275a1af27e 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2132,11 +2132,80 @@ func.func @omp_target_update_data_depend(%a: memref) { // ----- +func.func @omp_target_multiple_teams() { + // expected-error @below {{target containing multiple teams constructs}} + omp.target { + omp.teams { + omp.terminator + } + omp.teams { + omp.terminator + } + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_host_eval1(%x : !llvm.ptr) { + // expected-error @below {{op host_eval argument illegal use in 'llvm.load' operation}} + omp.target host_eval(%x -> %arg0 : !llvm.ptr) { + %0 = llvm.load %arg0 : !llvm.ptr -> f32 + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_host_eval2(%x : i1) { + // expected-error @below {{op host_eval argument only legal as 'num_teams' and 'thread_limit' in 'omp.teams'}} + omp.target host_eval(%x -> %arg0 : i1) { + omp.teams if(%arg0) { + omp.terminator + } + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_host_eval3(%x : i32) { + // expected-error @below {{op host_eval argument only legal as 'num_threads' in 'omp.parallel' when representing target SPMD}} + omp.target host_eval(%x -> %arg0 : i32) { + omp.parallel num_threads(%arg0 : i32) { + omp.terminator + } + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_host_eval3(%x : i32) { + // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD}} + omp.target host_eval(%x -> %arg0 : i32) { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) { + omp.yield + } + omp.terminator + } + omp.terminator + } + return +} + +// ----- + func.func @omp_target_depend(%data_var: memref) { // expected-error @below {{op expected as many depend values as depend variables}} "omp.target"(%data_var) ({ "omp.terminator"() : () -> () - }) {depend_kinds = [], operandSegmentSizes = array} : (memref) -> () + }) {depend_kinds = [], operandSegmentSizes = array} : (memref) -> () "func.return"() : () -> () } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 69deebe11d934..3c1c590431b79 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -770,7 +770,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic "omp.target"(%device, %if_cond, %num_threads) ({ // CHECK: omp.terminator omp.terminator - }) {nowait, operandSegmentSizes = array} : ( si32, i1, i32 ) -> () + }) {nowait, operandSegmentSizes = array} : ( si32, i1, i32 ) -> () // Test with optional map clause. // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref, tensor) map_clauses(tofrom) capture(ByRef) -> memref {name = ""} @@ -2749,3 +2749,39 @@ func.func @omp_target_private(%map1: memref, %map2: memref, %priv_ return } + +func.func @omp_target_host_eval(%x : i32) { + // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) { + // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32) + // CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32) + omp.target host_eval(%x -> %arg0 : i32) { + omp.teams num_teams(to %arg0 : i32) thread_limit(%arg0 : i32) { + omp.terminator + } + omp.terminator + } + + // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) { + // CHECK: omp.teams + // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) { + // CHECK: omp.distribute { + // CHECK: omp.wsloop { + // CHECK: omp.loop_nest (%{{.*}}) : i32 = (%[[HOST_ARG]]) to (%[[HOST_ARG]]) step (%[[HOST_ARG]]) { + omp.target host_eval(%x -> %arg0 : i32) { + omp.teams { + omp.parallel num_threads(%arg0 : i32) { + omp.distribute { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator + } + return +}