@@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
4033
4033
};
4034
4034
4035
4035
OpenMPIRBuilder::LocationDescription Loc ({Builder.saveIP (), DL});
4036
- Builder.restoreIP (OMPBuilder.createTeams (Builder, BodyGenCB));
4036
+ Builder.restoreIP (OMPBuilder.createTeams (
4037
+ Builder, BodyGenCB, /* NumTeamsLower=*/ nullptr , /* NumTeamsUpper=*/ nullptr ,
4038
+ /* ThreadLimit=*/ nullptr , /* IfExpr=*/ nullptr ));
4037
4039
4038
4040
OMPBuilder.finalize ();
4039
4041
Builder.CreateRetVoid ();
@@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
4095
4097
Builder.restoreIP (OMPBuilder.createTeams (/* =*/ Builder, BodyGenCB,
4096
4098
/* NumTeamsLower=*/ nullptr ,
4097
4099
/* NumTeamsUpper=*/ nullptr ,
4098
- /* ThreadLimit=*/ F->arg_begin ()));
4100
+ /* ThreadLimit=*/ F->arg_begin (),
4101
+ /* IfExpr=*/ nullptr ));
4099
4102
4100
4103
Builder.CreateRetVoid ();
4101
4104
OMPBuilder.finalize ();
@@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
4144
4147
// `num_teams`
4145
4148
Builder.restoreIP (OMPBuilder.createTeams (Builder, BodyGenCB,
4146
4149
/* NumTeamsLower=*/ nullptr ,
4147
- /* NumTeamsUpper=*/ F->arg_begin ()));
4150
+ /* NumTeamsUpper=*/ F->arg_begin (),
4151
+ /* ThreadLimit=*/ nullptr ,
4152
+ /* IfExpr=*/ nullptr ));
4148
4153
4149
4154
Builder.CreateRetVoid ();
4150
4155
OMPBuilder.finalize ();
@@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
4197
4202
// `F` already has an integer argument, so we use that as upper bound to
4198
4203
// `num_teams`
4199
4204
Builder.restoreIP (
4200
- OMPBuilder.createTeams (Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
4205
+ OMPBuilder.createTeams (Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
4206
+ /* ThreadLimit=*/ nullptr , /* IfExpr=*/ nullptr ));
4201
4207
4202
4208
Builder.CreateRetVoid ();
4203
4209
OMPBuilder.finalize ();
@@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
4255
4261
};
4256
4262
4257
4263
OpenMPIRBuilder::LocationDescription Loc ({Builder.saveIP (), DL});
4258
- Builder.restoreIP (OMPBuilder.createTeams (Builder, BodyGenCB, NumTeamsLower,
4259
- NumTeamsUpper, ThreadLimit));
4264
+ Builder.restoreIP (OMPBuilder.createTeams (
4265
+ Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr ));
4260
4266
4261
4267
Builder.CreateRetVoid ();
4262
4268
OMPBuilder.finalize ();
@@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
4284
4290
OMPBuilder.getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_teams));
4285
4291
}
4286
4292
4293
+ TEST_F (OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
4294
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4295
+ OpenMPIRBuilder OMPBuilder (*M);
4296
+ OMPBuilder.initialize ();
4297
+ F->setName (" func" );
4298
+ IRBuilder<> &Builder = OMPBuilder.Builder ;
4299
+ Builder.SetInsertPoint (BB);
4300
+
4301
+ Value *IfExpr = Builder.CreateLoad (Builder.getInt1Ty (),
4302
+ Builder.CreateAlloca (Builder.getInt1Ty ()));
4303
+
4304
+ Function *FakeFunction =
4305
+ Function::Create (FunctionType::get (Builder.getVoidTy (), false ),
4306
+ GlobalValue::ExternalLinkage, " fakeFunction" , M.get ());
4307
+
4308
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4309
+ Builder.restoreIP (CodeGenIP);
4310
+ Builder.CreateCall (FakeFunction, {});
4311
+ };
4312
+
4313
+ // `F` already has an integer argument, so we use that as upper bound to
4314
+ // `num_teams`
4315
+ Builder.restoreIP (OMPBuilder.createTeams (
4316
+ Builder, BodyGenCB, /* NumTeamsLower=*/ nullptr , /* NumTeamsUpper=*/ nullptr ,
4317
+ /* ThreadLimit=*/ nullptr , IfExpr));
4318
+
4319
+ Builder.CreateRetVoid ();
4320
+ OMPBuilder.finalize ();
4321
+
4322
+ ASSERT_FALSE (verifyModule (*M));
4323
+
4324
+ CallInst *PushNumTeamsCallInst =
4325
+ findSingleCall (F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4326
+ ASSERT_NE (PushNumTeamsCallInst, nullptr );
4327
+ Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand (2 );
4328
+ Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand (3 );
4329
+ Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand (4 );
4330
+
4331
+ // Check the lower_bound
4332
+ ASSERT_NE (NumTeamsLower, nullptr );
4333
+ SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
4334
+ ASSERT_NE (NumTeamsLowerSelectInst, nullptr );
4335
+ EXPECT_EQ (NumTeamsLowerSelectInst->getCondition (), IfExpr);
4336
+ EXPECT_EQ (NumTeamsLowerSelectInst->getTrueValue (), Builder.getInt32 (0 ));
4337
+ EXPECT_EQ (NumTeamsLowerSelectInst->getFalseValue (), Builder.getInt32 (1 ));
4338
+
4339
+ // Check the upper_bound
4340
+ ASSERT_NE (NumTeamsUpper, nullptr );
4341
+ SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
4342
+ ASSERT_NE (NumTeamsUpperSelectInst, nullptr );
4343
+ EXPECT_EQ (NumTeamsUpperSelectInst->getCondition (), IfExpr);
4344
+ EXPECT_EQ (NumTeamsUpperSelectInst->getTrueValue (), Builder.getInt32 (0 ));
4345
+ EXPECT_EQ (NumTeamsUpperSelectInst->getFalseValue (), Builder.getInt32 (1 ));
4346
+
4347
+ // Check thread_limit
4348
+ EXPECT_EQ (ThreadLimit, Builder.getInt32 (0 ));
4349
+ }
4350
+
4351
+ TEST_F (OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
4352
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4353
+ OpenMPIRBuilder OMPBuilder (*M);
4354
+ OMPBuilder.initialize ();
4355
+ F->setName (" func" );
4356
+ IRBuilder<> &Builder = OMPBuilder.Builder ;
4357
+ Builder.SetInsertPoint (BB);
4358
+
4359
+ Value *IfExpr = Builder.CreateLoad (
4360
+ Builder.getInt32Ty (), Builder.CreateAlloca (Builder.getInt32Ty ()));
4361
+ Value *NumTeamsLower = Builder.CreateAdd (F->arg_begin (), Builder.getInt32 (5 ));
4362
+ Value *NumTeamsUpper =
4363
+ Builder.CreateAdd (F->arg_begin (), Builder.getInt32 (10 ));
4364
+ Value *ThreadLimit = Builder.CreateAdd (F->arg_begin (), Builder.getInt32 (20 ));
4365
+
4366
+ Function *FakeFunction =
4367
+ Function::Create (FunctionType::get (Builder.getVoidTy (), false ),
4368
+ GlobalValue::ExternalLinkage, " fakeFunction" , M.get ());
4369
+
4370
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4371
+ Builder.restoreIP (CodeGenIP);
4372
+ Builder.CreateCall (FakeFunction, {});
4373
+ };
4374
+
4375
+ // `F` already has an integer argument, so we use that as upper bound to
4376
+ // `num_teams`
4377
+ Builder.restoreIP (OMPBuilder.createTeams (Builder, BodyGenCB, NumTeamsLower,
4378
+ NumTeamsUpper, ThreadLimit, IfExpr));
4379
+
4380
+ Builder.CreateRetVoid ();
4381
+ OMPBuilder.finalize ();
4382
+
4383
+ ASSERT_FALSE (verifyModule (*M));
4384
+
4385
+ CallInst *PushNumTeamsCallInst =
4386
+ findSingleCall (F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4387
+ ASSERT_NE (PushNumTeamsCallInst, nullptr );
4388
+ Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand (2 );
4389
+ Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand (3 );
4390
+ Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand (4 );
4391
+
4392
+ // Get the boolean conversion of if expression
4393
+ ASSERT_EQ (IfExpr->getNumUses (), 1U );
4394
+ User *IfExprInst = IfExpr->user_back ();
4395
+ ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
4396
+ ASSERT_NE (IfExprCmpInst, nullptr );
4397
+ EXPECT_EQ (IfExprCmpInst->getPredicate (), ICmpInst::Predicate::ICMP_NE);
4398
+ EXPECT_EQ (IfExprCmpInst->getOperand (0 ), IfExpr);
4399
+ EXPECT_EQ (IfExprCmpInst->getOperand (1 ), Builder.getInt32 (0 ));
4400
+
4401
+ // Check the lower_bound
4402
+ ASSERT_NE (NumTeamsLowerArg, nullptr );
4403
+ SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
4404
+ ASSERT_NE (NumTeamsLowerSelectInst, nullptr );
4405
+ EXPECT_EQ (NumTeamsLowerSelectInst->getCondition (), IfExprCmpInst);
4406
+ EXPECT_EQ (NumTeamsLowerSelectInst->getTrueValue (), NumTeamsLower);
4407
+ EXPECT_EQ (NumTeamsLowerSelectInst->getFalseValue (), Builder.getInt32 (1 ));
4408
+
4409
+ // Check the upper_bound
4410
+ ASSERT_NE (NumTeamsUpperArg, nullptr );
4411
+ SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
4412
+ ASSERT_NE (NumTeamsUpperSelectInst, nullptr );
4413
+ EXPECT_EQ (NumTeamsUpperSelectInst->getCondition (), IfExprCmpInst);
4414
+ EXPECT_EQ (NumTeamsUpperSelectInst->getTrueValue (), NumTeamsUpper);
4415
+ EXPECT_EQ (NumTeamsUpperSelectInst->getFalseValue (), Builder.getInt32 (1 ));
4416
+
4417
+ // Check thread_limit
4418
+ EXPECT_EQ (ThreadLimitArg, ThreadLimit);
4419
+ }
4420
+
4287
4421
// / Returns the single instruction of InstTy type in BB that uses the value V.
4288
4422
// / If there is more than one such instruction, returns null.
4289
4423
template <typename InstTy>
0 commit comments