Skip to content

Commit 9922aad

Browse files
authored
[OpenMPIRBuilder] Added if clause for teams (llvm#69139)
This patch adds support for the `if` clause on `teams` construct. The value of the argument must be an integer value. If the value evaluates to true (non-zero) integer, then the number of threads is determined by `num_threads` clause (or default and ICV if `num_threads` is absent). When the condition evaluates to false (zero), then the bounds are set to 1. ([OpenMP 5.2 Section 10.2](https://www.openmp.org/spec-html/5.2/openmpse58.html)) This essentially means that ``` upperbound = ifexpr ? upperbound : 1 lowerbound = ifexpr ? lowerbound : 1 ```
1 parent 122064a commit 9922aad

File tree

3 files changed

+165
-13
lines changed

3 files changed

+165
-13
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1923,11 +1923,12 @@ class OpenMPIRBuilder {
19231923
/// \param NumTeamsUpper Upper bound on the number of teams.
19241924
/// \param ThreadLimit on the number of threads that may participate in a
19251925
/// contention group created by each team.
1926-
InsertPointTy createTeams(const LocationDescription &Loc,
1927-
BodyGenCallbackTy BodyGenCB,
1928-
Value *NumTeamsLower = nullptr,
1929-
Value *NumTeamsUpper = nullptr,
1930-
Value *ThreadLimit = nullptr);
1926+
/// \param IfExpr is the integer argument value of the if condition on the
1927+
/// teams clause.
1928+
InsertPointTy
1929+
createTeams(const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
1930+
Value *NumTeamsLower = nullptr, Value *NumTeamsUpper = nullptr,
1931+
Value *ThreadLimit = nullptr, Value *IfExpr = nullptr);
19311932

19321933
/// Generate conditional branch and relevant BasicBlocks through which private
19331934
/// threads copy the 'copyin' variables from Master copy to threadprivate

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5734,7 +5734,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
57345734
OpenMPIRBuilder::InsertPointTy
57355735
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
57365736
BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
5737-
Value *NumTeamsUpper, Value *ThreadLimit) {
5737+
Value *NumTeamsUpper, Value *ThreadLimit,
5738+
Value *IfExpr) {
57385739
if (!updateToLocation(Loc))
57395740
return InsertPointTy();
57405741

@@ -5773,7 +5774,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
57735774
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
57745775

57755776
// Push num_teams
5776-
if (NumTeamsLower || NumTeamsUpper || ThreadLimit) {
5777+
if (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr) {
57775778
assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
57785779
"if lowerbound is non-null, then upperbound must also be non-null "
57795780
"for bounds on num_teams");
@@ -5784,6 +5785,22 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
57845785
if (NumTeamsLower == nullptr)
57855786
NumTeamsLower = NumTeamsUpper;
57865787

5788+
if (IfExpr) {
5789+
assert(IfExpr->getType()->isIntegerTy() &&
5790+
"argument to if clause must be an integer value");
5791+
5792+
// upper = ifexpr ? upper : 1
5793+
if (IfExpr->getType() != Int1)
5794+
IfExpr = Builder.CreateICmpNE(IfExpr,
5795+
ConstantInt::get(IfExpr->getType(), 0));
5796+
NumTeamsUpper = Builder.CreateSelect(
5797+
IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");
5798+
5799+
// lower = ifexpr ? lower : 1
5800+
NumTeamsLower = Builder.CreateSelect(
5801+
IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
5802+
}
5803+
57875804
if (ThreadLimit == nullptr)
57885805
ThreadLimit = Builder.getInt32(0);
57895806

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 140 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
40334033
};
40344034

40354035
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4036-
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
4036+
Builder.restoreIP(OMPBuilder.createTeams(
4037+
Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
4038+
/*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));
40374039

40384040
OMPBuilder.finalize();
40394041
Builder.CreateRetVoid();
@@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
40954097
Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
40964098
/*NumTeamsLower=*/nullptr,
40974099
/*NumTeamsUpper=*/nullptr,
4098-
/*ThreadLimit=*/F->arg_begin()));
4100+
/*ThreadLimit=*/F->arg_begin(),
4101+
/*IfExpr=*/nullptr));
40994102

41004103
Builder.CreateRetVoid();
41014104
OMPBuilder.finalize();
@@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
41444147
// `num_teams`
41454148
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
41464149
/*NumTeamsLower=*/nullptr,
4147-
/*NumTeamsUpper=*/F->arg_begin()));
4150+
/*NumTeamsUpper=*/F->arg_begin(),
4151+
/*ThreadLimit=*/nullptr,
4152+
/*IfExpr=*/nullptr));
41484153

41494154
Builder.CreateRetVoid();
41504155
OMPBuilder.finalize();
@@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
41974202
// `F` already has an integer argument, so we use that as upper bound to
41984203
// `num_teams`
41994204
Builder.restoreIP(
4200-
OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
4205+
OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
4206+
/*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));
42014207

42024208
Builder.CreateRetVoid();
42034209
OMPBuilder.finalize();
@@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
42554261
};
42564262

42574263
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4258-
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
4259-
NumTeamsUpper, ThreadLimit));
4264+
Builder.restoreIP(OMPBuilder.createTeams(
4265+
Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr));
42604266

42614267
Builder.CreateRetVoid();
42624268
OMPBuilder.finalize();
@@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
42844290
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
42854291
}
42864292

4293+
TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
4294+
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4295+
OpenMPIRBuilder OMPBuilder(*M);
4296+
OMPBuilder.initialize();
4297+
F->setName("func");
4298+
IRBuilder<> &Builder = OMPBuilder.Builder;
4299+
Builder.SetInsertPoint(BB);
4300+
4301+
Value *IfExpr = Builder.CreateLoad(Builder.getInt1Ty(),
4302+
Builder.CreateAlloca(Builder.getInt1Ty()));
4303+
4304+
Function *FakeFunction =
4305+
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
4306+
GlobalValue::ExternalLinkage, "fakeFunction", M.get());
4307+
4308+
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4309+
Builder.restoreIP(CodeGenIP);
4310+
Builder.CreateCall(FakeFunction, {});
4311+
};
4312+
4313+
// `F` already has an integer argument, so we use that as upper bound to
4314+
// `num_teams`
4315+
Builder.restoreIP(OMPBuilder.createTeams(
4316+
Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
4317+
/*ThreadLimit=*/nullptr, IfExpr));
4318+
4319+
Builder.CreateRetVoid();
4320+
OMPBuilder.finalize();
4321+
4322+
ASSERT_FALSE(verifyModule(*M));
4323+
4324+
CallInst *PushNumTeamsCallInst =
4325+
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4326+
ASSERT_NE(PushNumTeamsCallInst, nullptr);
4327+
Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand(2);
4328+
Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand(3);
4329+
Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand(4);
4330+
4331+
// Check the lower_bound
4332+
ASSERT_NE(NumTeamsLower, nullptr);
4333+
SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
4334+
ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
4335+
EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExpr);
4336+
EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), Builder.getInt32(0));
4337+
EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
4338+
4339+
// Check the upper_bound
4340+
ASSERT_NE(NumTeamsUpper, nullptr);
4341+
SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
4342+
ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
4343+
EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExpr);
4344+
EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), Builder.getInt32(0));
4345+
EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
4346+
4347+
// Check thread_limit
4348+
EXPECT_EQ(ThreadLimit, Builder.getInt32(0));
4349+
}
4350+
4351+
TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
4352+
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4353+
OpenMPIRBuilder OMPBuilder(*M);
4354+
OMPBuilder.initialize();
4355+
F->setName("func");
4356+
IRBuilder<> &Builder = OMPBuilder.Builder;
4357+
Builder.SetInsertPoint(BB);
4358+
4359+
Value *IfExpr = Builder.CreateLoad(
4360+
Builder.getInt32Ty(), Builder.CreateAlloca(Builder.getInt32Ty()));
4361+
Value *NumTeamsLower = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5));
4362+
Value *NumTeamsUpper =
4363+
Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10));
4364+
Value *ThreadLimit = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20));
4365+
4366+
Function *FakeFunction =
4367+
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
4368+
GlobalValue::ExternalLinkage, "fakeFunction", M.get());
4369+
4370+
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4371+
Builder.restoreIP(CodeGenIP);
4372+
Builder.CreateCall(FakeFunction, {});
4373+
};
4374+
4375+
// `F` already has an integer argument, so we use that as upper bound to
4376+
// `num_teams`
4377+
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
4378+
NumTeamsUpper, ThreadLimit, IfExpr));
4379+
4380+
Builder.CreateRetVoid();
4381+
OMPBuilder.finalize();
4382+
4383+
ASSERT_FALSE(verifyModule(*M));
4384+
4385+
CallInst *PushNumTeamsCallInst =
4386+
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4387+
ASSERT_NE(PushNumTeamsCallInst, nullptr);
4388+
Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand(2);
4389+
Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand(3);
4390+
Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand(4);
4391+
4392+
// Get the boolean conversion of if expression
4393+
ASSERT_EQ(IfExpr->getNumUses(), 1U);
4394+
User *IfExprInst = IfExpr->user_back();
4395+
ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
4396+
ASSERT_NE(IfExprCmpInst, nullptr);
4397+
EXPECT_EQ(IfExprCmpInst->getPredicate(), ICmpInst::Predicate::ICMP_NE);
4398+
EXPECT_EQ(IfExprCmpInst->getOperand(0), IfExpr);
4399+
EXPECT_EQ(IfExprCmpInst->getOperand(1), Builder.getInt32(0));
4400+
4401+
// Check the lower_bound
4402+
ASSERT_NE(NumTeamsLowerArg, nullptr);
4403+
SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
4404+
ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
4405+
EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExprCmpInst);
4406+
EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), NumTeamsLower);
4407+
EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
4408+
4409+
// Check the upper_bound
4410+
ASSERT_NE(NumTeamsUpperArg, nullptr);
4411+
SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
4412+
ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
4413+
EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExprCmpInst);
4414+
EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), NumTeamsUpper);
4415+
EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
4416+
4417+
// Check thread_limit
4418+
EXPECT_EQ(ThreadLimitArg, ThreadLimit);
4419+
}
4420+
42874421
/// Returns the single instruction of InstTy type in BB that uses the value V.
42884422
/// If there is more than one such instruction, returns null.
42894423
template <typename InstTy>

0 commit comments

Comments
 (0)