Skip to content

Commit b62afbc

Browse files
authored
[mlir][OpenMP] Add __atomic_store to AtomicInfo (llvm#121055)
This PR adds functionality for `__atomic_store` libcall in AtomicInfo. This allows for supporting complex types in `atomic write`. Fixes llvm#113479 Fixes llvm#115652
1 parent 6ffccea commit b62afbc

File tree

7 files changed

+108
-6
lines changed

7 files changed

+108
-6
lines changed

llvm/include/llvm/Frontend/Atomic/Atomic.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ class AtomicInfo {
9797
bool IsVolatile, bool IsWeak);
9898

9999
std::pair<LoadInst *, AllocaInst *> EmitAtomicLoadLibcall(AtomicOrdering AO);
100+
101+
void EmitAtomicStoreLibcall(AtomicOrdering AO, Value *Source);
100102
};
101103
} // end namespace llvm
102104

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3285,11 +3285,12 @@ class OpenMPIRBuilder {
32853285
/// \param Expr The value to store.
32863286
/// \param AO Atomic ordering of the generated atomic
32873287
/// instructions.
3288+
/// \param AllocaIP Insert point for allocas
32883289
///
32893290
/// \return Insertion point after generated atomic Write IR.
32903291
InsertPointTy createAtomicWrite(const LocationDescription &Loc,
32913292
AtomicOpValue &X, Value *Expr,
3292-
AtomicOrdering AO);
3293+
AtomicOrdering AO, InsertPointTy AllocaIP);
32933294

32943295
/// Emit atomic update for constructs: X = X BinOp Expr ,or X = Expr BinOp X
32953296
/// For complex Operations: X = UpdateOp(X) => CmpExch X, old_X, UpdateOp(X)

llvm/lib/Frontend/Atomic/Atomic.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,42 @@ AtomicInfo::EmitAtomicLoadLibcall(AtomicOrdering AO) {
145145
AllocaResult);
146146
}
147147

148+
void AtomicInfo::EmitAtomicStoreLibcall(AtomicOrdering AO, Value *Source) {
149+
LLVMContext &Ctx = getLLVMContext();
150+
SmallVector<Value *, 6> Args;
151+
AttributeList Attr;
152+
Module *M = Builder->GetInsertBlock()->getModule();
153+
const DataLayout &DL = M->getDataLayout();
154+
Args.push_back(
155+
ConstantInt::get(DL.getIntPtrType(Ctx), this->getAtomicSizeInBits() / 8));
156+
157+
Value *PtrVal = getAtomicPointer();
158+
PtrVal = Builder->CreateAddrSpaceCast(PtrVal, PointerType::getUnqual(Ctx));
159+
Args.push_back(PtrVal);
160+
161+
auto CurrentIP = Builder->saveIP();
162+
Builder->restoreIP(AllocaIP);
163+
Value *SourceAlloca = Builder->CreateAlloca(Source->getType());
164+
Builder->restoreIP(CurrentIP);
165+
Builder->CreateStore(Source, SourceAlloca);
166+
SourceAlloca = Builder->CreatePointerBitCastOrAddrSpaceCast(
167+
SourceAlloca, Builder->getPtrTy());
168+
Args.push_back(SourceAlloca);
169+
170+
Constant *OrderingVal =
171+
ConstantInt::get(Type::getInt32Ty(Ctx), (int)toCABI(AO));
172+
Args.push_back(OrderingVal);
173+
174+
SmallVector<Type *, 6> ArgTys;
175+
for (Value *Arg : Args)
176+
ArgTys.push_back(Arg->getType());
177+
FunctionType *FnType = FunctionType::get(Type::getVoidTy(Ctx), ArgTys, false);
178+
FunctionCallee LibcallFn =
179+
M->getOrInsertFunction("__atomic_store", FnType, Attr);
180+
CallInst *Call = Builder->CreateCall(LibcallFn, Args);
181+
Call->setAttributes(Attr);
182+
}
183+
148184
std::pair<Value *, Value *> AtomicInfo::EmitAtomicCompareExchange(
149185
Value *ExpectedVal, Value *DesiredVal, AtomicOrdering Success,
150186
AtomicOrdering Failure, bool IsVolatile, bool IsWeak) {

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8684,20 +8684,30 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
86848684
OpenMPIRBuilder::InsertPointTy
86858685
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
86868686
AtomicOpValue &X, Value *Expr,
8687-
AtomicOrdering AO) {
8687+
AtomicOrdering AO, InsertPointTy AllocaIP) {
86888688
if (!updateToLocation(Loc))
86898689
return Loc.IP;
86908690

86918691
assert(X.Var->getType()->isPointerTy() &&
86928692
"OMP Atomic expects a pointer to target memory");
86938693
Type *XElemTy = X.ElemTy;
86948694
assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8695-
XElemTy->isPointerTy()) &&
8695+
XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
86968696
"OMP atomic write expected a scalar type");
86978697

86988698
if (XElemTy->isIntegerTy()) {
86998699
StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
87008700
XSt->setAtomic(AO);
8701+
} else if (XElemTy->isStructTy()) {
8702+
LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read");
8703+
const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
8704+
unsigned LoadSize =
8705+
LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
8706+
OpenMPIRBuilder::AtomicInfo atomicInfo(
8707+
&Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
8708+
OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
8709+
atomicInfo.EmitAtomicStoreLibcall(AO, Expr);
8710+
OldVal->eraseFromParent();
87018711
} else {
87028712
// We need to bitcast and perform atomic op as integers
87038713
IntegerType *IntCastTy =

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3875,6 +3875,9 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicWriteFlt) {
38753875
IRBuilder<> Builder(BB);
38763876

38773877
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3878+
BasicBlock *EntryBB = BB;
3879+
OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB,
3880+
EntryBB->getFirstInsertionPt());
38783881

38793882
LLVMContext &Ctx = M->getContext();
38803883
Type *Float32 = Type::getFloatTy(Ctx);
@@ -3884,7 +3887,8 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicWriteFlt) {
38843887
AtomicOrdering AO = AtomicOrdering::Monotonic;
38853888
Constant *ValToWrite = ConstantFP::get(Float32, 1.0);
38863889

3887-
Builder.restoreIP(OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO));
3890+
Builder.restoreIP(
3891+
OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO, AllocaIP));
38883892

38893893
IntegerType *IntCastTy =
38903894
IntegerType::get(M->getContext(), Float32->getScalarSizeInBits());
@@ -3918,8 +3922,11 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicWriteInt) {
39183922
ConstantInt *ValToWrite = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
39193923

39203924
BasicBlock *EntryBB = BB;
3925+
OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB,
3926+
EntryBB->getFirstInsertionPt());
39213927

3922-
Builder.restoreIP(OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO));
3928+
Builder.restoreIP(
3929+
OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO, AllocaIP));
39233930

39243931
StoreInst *StoreofAtomic = nullptr;
39253932

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2808,6 +2808,8 @@ convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
28082808
return failure();
28092809

28102810
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2811+
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2812+
findAllocaInsertPoint(builder, moduleTranslation);
28112813

28122814
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
28132815
llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
@@ -2816,7 +2818,8 @@ convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
28162818
llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
28172819
llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
28182820
/*isVolatile=*/false};
2819-
builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
2821+
builder.restoreIP(
2822+
ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
28202823
return success();
28212824
}
28222825

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,6 +1481,49 @@ llvm.func @omp_atomic_update(%x:!llvm.ptr, %expr: i32, %xbool: !llvm.ptr, %exprb
14811481

14821482
// -----
14831483

1484+
// CHECK-LABEL: @omp_atomic_write
1485+
llvm.func @omp_atomic_write() {
1486+
// CHECK: %[[ALLOCA0:.*]] = alloca { float, float }, align 8
1487+
// CHECK: %[[ALLOCA1:.*]] = alloca { float, float }, align 8
1488+
// CHECK: %[[X:.*]] = alloca float, i64 1, align 4
1489+
// CHECK: %[[R1:.*]] = alloca float, i64 1, align 4
1490+
// CHECK: %[[ALLOCA:.*]] = alloca { float, float }, i64 1, align 8
1491+
// CHECK: %[[LOAD:.*]] = load float, ptr %[[R1]], align 4
1492+
// CHECK: %[[IDX1:.*]] = insertvalue { float, float } undef, float %[[LOAD]], 0
1493+
// CHECK: %[[IDX2:.*]] = insertvalue { float, float } %[[IDX1]], float 0.000000e+00, 1
1494+
// CHECK: br label %entry
1495+
1496+
// CHECK: entry:
1497+
// CHECK: store { float, float } %[[IDX2]], ptr %[[ALLOCA1]], align 4
1498+
// CHECK: call void @__atomic_store(i64 8, ptr %[[ALLOCA]], ptr %[[ALLOCA1]], i32 0)
1499+
// CHECK: store { float, float } { float 1.000000e+00, float 1.000000e+00 }, ptr %[[ALLOCA0]], align 4
1500+
// CHECK: call void @__atomic_store(i64 8, ptr %[[ALLOCA]], ptr %[[ALLOCA0]], i32 0)
1501+
1502+
%0 = llvm.mlir.constant(1 : i64) : i64
1503+
%1 = llvm.alloca %0 x f32 {bindc_name = "x"} : (i64) -> !llvm.ptr
1504+
%2 = llvm.mlir.constant(1 : i64) : i64
1505+
%3 = llvm.alloca %2 x f32 {bindc_name = "r1"} : (i64) -> !llvm.ptr
1506+
%4 = llvm.mlir.constant(1 : i64) : i64
1507+
%5 = llvm.alloca %4 x !llvm.struct<(f32, f32)> {bindc_name = "c1"} : (i64) -> !llvm.ptr
1508+
%6 = llvm.mlir.constant(1.000000e+00 : f32) : f32
1509+
%7 = llvm.mlir.constant(0.000000e+00 : f32) : f32
1510+
%8 = llvm.mlir.constant(1 : i64) : i64
1511+
%9 = llvm.mlir.constant(1 : i64) : i64
1512+
%10 = llvm.mlir.constant(1 : i64) : i64
1513+
%11 = llvm.load %3 : !llvm.ptr -> f32
1514+
%12 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
1515+
%13 = llvm.insertvalue %11, %12[0] : !llvm.struct<(f32, f32)>
1516+
%14 = llvm.insertvalue %7, %13[1] : !llvm.struct<(f32, f32)>
1517+
omp.atomic.write %5 = %14 : !llvm.ptr, !llvm.struct<(f32, f32)>
1518+
%15 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
1519+
%16 = llvm.insertvalue %6, %15[0] : !llvm.struct<(f32, f32)>
1520+
%17 = llvm.insertvalue %6, %16[1] : !llvm.struct<(f32, f32)>
1521+
omp.atomic.write %5 = %17 : !llvm.ptr, !llvm.struct<(f32, f32)>
1522+
llvm.return
1523+
}
1524+
1525+
// -----
1526+
14841527
//CHECK: %[[ATOMIC_TEMP_LOAD:.*]] = alloca { float, float }, align 8
14851528
//CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8
14861529
//CHECK: {{.*}} = alloca { float, float }, i64 1, align 8

0 commit comments

Comments
 (0)