Skip to content

Commit a168ddc

Browse files
[MLIR][LLVM] Block address support (llvm#134335)
Add support for import and translate. MLIR does not support using basic block references outside a function (like LLVM does), This PR does not consider changes to MLIR to that respect. It instead introduces two new ops: `llvm.blockaddress` and `llvm.blocktag`. Here's an example: ``` llvm.func @ba() -> !llvm.ptr { %0 = llvm.blockaddress <function = @ba, tag = <id = 1>> : !llvm.ptr llvm.br ^bb1 ^bb1: // pred: ^bb0 llvm.blocktag <id = 1> llvm.return %0 : !llvm.ptr } ``` Value `%0` hold the address of block tagged as `id = 1` in function `@ba`. Block tags need to be unique within a function and use of `llvm.blockaddress` requires a matching tag in a `llvm.blocktag`.
1 parent 1a99284 commit a168ddc

File tree

16 files changed

+515
-31
lines changed

16 files changed

+515
-31
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,25 @@ def LLVM_DSOLocalEquivalentAttr : LLVM_Attr<"DSOLocalEquivalent",
12241224
let assemblyFormat = "$sym";
12251225
}
12261226

1227+
//===----------------------------------------------------------------------===//
1228+
// BlockAddressAttr
1229+
//===----------------------------------------------------------------------===//
1230+
1231+
def LLVM_BlockTagAttr : LLVM_Attr<"BlockTag", "blocktag"> {
1232+
let parameters = (ins "uint32_t":$id);
1233+
let assemblyFormat = "`<` struct(params) `>`";
1234+
}
1235+
1236+
/// Folded into from LLVM_BlockAddressAttr.
1237+
def LLVM_BlockAddressAttr : LLVM_Attr<"BlockAddress", "blockaddress"> {
1238+
let description = [{
1239+
Describes a block address identified by a pair of `$function` and `$tag`.
1240+
}];
1241+
let parameters = (ins "FlatSymbolRefAttr":$function,
1242+
"BlockTagAttr":$tag);
1243+
let assemblyFormat = "`<` struct(params) `>`";
1244+
}
1245+
12271246
//===----------------------------------------------------------------------===//
12281247
// VecTypeHintAttr
12291248
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,6 +1625,84 @@ def LLVM_DSOLocalEquivalentOp : LLVM_Op<"dso_local_equivalent",
16251625
let hasFolder = 1;
16261626
}
16271627

1628+
//===----------------------------------------------------------------------===//
1629+
// BlockAddressOp & BlockTagOp
1630+
//===----------------------------------------------------------------------===//
1631+
1632+
def LLVM_BlockAddressOp : LLVM_Op<"blockaddress",
1633+
[Pure, ConstantLike, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
1634+
let arguments = (ins LLVM_BlockAddressAttr:$block_addr);
1635+
let results = (outs LLVM_AnyPointer:$res);
1636+
1637+
let summary = "Creates a LLVM blockaddress ptr";
1638+
1639+
let description = [{
1640+
Creates an SSA value containing a pointer to a basic block. The block
1641+
address information (function and block) is given by the `BlockAddressAttr`
1642+
attribute. This operation assumes an existing `llvm.blocktag` operation
1643+
identifying an existing MLIR block within a function. Example:
1644+
1645+
```mlir
1646+
llvm.mlir.global private @g() : !llvm.ptr {
1647+
%0 = llvm.blockaddress <function = @fn, tag = <id = 0>> : !llvm.ptr
1648+
llvm.return %0 : !llvm.ptr
1649+
}
1650+
1651+
llvm.func @fn() {
1652+
llvm.br ^bb1
1653+
^bb1: // pred: ^bb0
1654+
llvm.blocktag <id = 0>
1655+
llvm.return
1656+
}
1657+
```
1658+
}];
1659+
1660+
let assemblyFormat = [{
1661+
$block_addr
1662+
attr-dict `:` qualified(type($res))
1663+
}];
1664+
1665+
let extraClassDeclaration = [{
1666+
/// Return the llvm.func operation that is referenced here.
1667+
LLVMFuncOp getFunction(SymbolTableCollection &symbolTable);
1668+
1669+
/// Search for the matching `llvm.blocktag` operation. This is performed
1670+
/// by walking the function in `block_addr`.
1671+
BlockTagOp getBlockTagOp();
1672+
}];
1673+
1674+
let hasVerifier = 1;
1675+
let hasFolder = 1;
1676+
}
1677+
1678+
def LLVM_BlockTagOp : LLVM_Op<"blocktag"> {
1679+
let description = [{
1680+
This operation uses a `tag` to uniquely identify an MLIR block in a
1681+
function. The same tag is used by `llvm.blockaddress` in order to compute
1682+
the target address.
1683+
1684+
A given function should have at most one `llvm.blocktag` operation with a
1685+
given `tag`. This operation cannot be used as a terminator but can be
1686+
placed everywhere else in a block.
1687+
1688+
Example:
1689+
1690+
```mlir
1691+
llvm.func @f() -> !llvm.ptr {
1692+
%addr = llvm.blockaddress <function = @f, tag = <id = 1>> : !llvm.ptr
1693+
llvm.br ^bb1
1694+
^bb1:
1695+
llvm.blocktag <id = 1>
1696+
llvm.return %addr : !llvm.ptr
1697+
}
1698+
```
1699+
}];
1700+
let arguments = (ins LLVM_BlockTagAttr:$tag);
1701+
let assemblyFormat = [{ $tag attr-dict }];
1702+
// Covered as part of LLVMFuncOp verifier.
1703+
let hasVerifier = 0;
1704+
}
1705+
16281706
def LLVM_ComdatSelectorOp : LLVM_Op<"comdat_selector", [Symbol]> {
16291707
let arguments = (ins
16301708
SymbolNameAttr:$sym_name,

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,29 @@ class ModuleTranslation {
136136
return callMapping.lookup(op);
137137
}
138138

139+
/// Maps a blockaddress operation to its corresponding placeholder LLVM
140+
/// value.
141+
void mapUnresolvedBlockAddress(BlockAddressOp op, llvm::Value *cst) {
142+
auto result = unresolvedBlockAddressMapping.try_emplace(op, cst);
143+
(void)result;
144+
assert(result.second &&
145+
"attempting to map a blockaddress that is already mapped");
146+
}
147+
148+
/// Maps a blockaddress operation to its corresponding placeholder LLVM
149+
/// value.
150+
void mapBlockTag(BlockAddressAttr attr, BlockTagOp blockTag) {
151+
// Attempts to map already mapped block labels which is fine if the given
152+
// labels are verified to be unique.
153+
blockTagMapping[attr] = blockTag;
154+
}
155+
156+
/// Finds an MLIR block that corresponds to the given MLIR call
157+
/// operation.
158+
BlockTagOp lookupBlockTag(BlockAddressAttr attr) const {
159+
return blockTagMapping.lookup(attr);
160+
}
161+
139162
/// Removes the mapping for blocks contained in the region and values defined
140163
/// in these blocks.
141164
void forgetMapping(Region &region);
@@ -338,6 +361,8 @@ class ModuleTranslation {
338361
LogicalResult convertFunctions();
339362
LogicalResult convertComdats();
340363

364+
LogicalResult convertUnresolvedBlockAddress();
365+
341366
/// Handle conversion for both globals and global aliases.
342367
///
343368
/// - Create named global variables that correspond to llvm.mlir.global
@@ -433,6 +458,16 @@ class ModuleTranslation {
433458
/// This map is populated on module entry.
434459
DenseMap<ComdatSelectorOp, llvm::Comdat *> comdatMapping;
435460

461+
/// Mapping from llvm.blockaddress operations to their corresponding LLVM
462+
/// constant placeholders. After all basic blocks are translated, this
463+
/// mapping is used to replace the placeholders with the LLVM block addresses.
464+
DenseMap<BlockAddressOp, llvm::Value *> unresolvedBlockAddressMapping;
465+
466+
/// Mapping from a BlockAddressAttr attribute to a matching BlockTagOp. This
467+
/// is used to cache BlockTagOp locations instead of walking a LLVMFuncOp in
468+
/// search for those.
469+
DenseMap<BlockAddressAttr, BlockTagOp> blockTagMapping;
470+
436471
/// Stack of user-specified state elements, useful when translating operations
437472
/// with regions.
438473
SmallVector<std::unique_ptr<StackFrame>> stack;

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2305,6 +2305,28 @@ static LogicalResult verifyComdat(Operation *op,
23052305
return success();
23062306
}
23072307

2308+
static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
2309+
llvm::DenseSet<BlockTagAttr> blockTags;
2310+
BlockTagOp badBlockTagOp;
2311+
if (funcOp
2312+
.walk([&](BlockTagOp blockTagOp) {
2313+
if (blockTags.contains(blockTagOp.getTag())) {
2314+
badBlockTagOp = blockTagOp;
2315+
return WalkResult::interrupt();
2316+
}
2317+
blockTags.insert(blockTagOp.getTag());
2318+
return WalkResult::advance();
2319+
})
2320+
.wasInterrupted()) {
2321+
badBlockTagOp.emitError()
2322+
<< "duplicate block tag '" << badBlockTagOp.getTag().getId()
2323+
<< "' in the same function: ";
2324+
return failure();
2325+
}
2326+
2327+
return success();
2328+
}
2329+
23082330
/// Parse common attributes that might show up in the same order in both
23092331
/// GlobalOp and AliasOp.
23102332
template <typename OpType>
@@ -3060,6 +3082,9 @@ LogicalResult LLVMFuncOp::verify() {
30603082
return emitError(diagnosticMessage);
30613083
}
30623084

3085+
if (failed(verifyBlockTags(*this)))
3086+
return failure();
3087+
30633088
return success();
30643089
}
30653090

@@ -3815,6 +3840,56 @@ void InlineAsmOp::getEffects(
38153840
}
38163841
}
38173842

3843+
//===----------------------------------------------------------------------===//
3844+
// BlockAddressOp
3845+
//===----------------------------------------------------------------------===//
3846+
3847+
LogicalResult
3848+
BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3849+
Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
3850+
getBlockAddr().getFunction());
3851+
auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
3852+
3853+
if (!function)
3854+
return emitOpError("must reference a function defined by 'llvm.func'");
3855+
3856+
return success();
3857+
}
3858+
3859+
LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
3860+
return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
3861+
parentLLVMModule(*this), getBlockAddr().getFunction()));
3862+
}
3863+
3864+
BlockTagOp BlockAddressOp::getBlockTagOp() {
3865+
auto funcOp = dyn_cast<LLVMFuncOp>(mlir::SymbolTable::lookupNearestSymbolFrom(
3866+
parentLLVMModule(*this), getBlockAddr().getFunction()));
3867+
if (!funcOp)
3868+
return nullptr;
3869+
3870+
BlockTagOp blockTagOp = nullptr;
3871+
funcOp.walk([&](LLVM::BlockTagOp labelOp) {
3872+
if (labelOp.getTag() == getBlockAddr().getTag()) {
3873+
blockTagOp = labelOp;
3874+
return WalkResult::interrupt();
3875+
}
3876+
return WalkResult::advance();
3877+
});
3878+
return blockTagOp;
3879+
}
3880+
3881+
LogicalResult BlockAddressOp::verify() {
3882+
if (!getBlockTagOp())
3883+
return emitOpError(
3884+
"expects an existing block label target in the referenced function");
3885+
3886+
return success();
3887+
}
3888+
3889+
/// Fold a blockaddress operation to a dedicated blockaddress
3890+
/// attribute.
3891+
OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }
3892+
38183893
//===----------------------------------------------------------------------===//
38193894
// AssumeOp (intrinsic)
38203895
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -731,8 +731,10 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
731731
}
732732

733733
bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
734-
// The inliner cannot handle variadic function arguments.
735-
return !isa<LLVM::VaStartOp>(op);
734+
// The inliner cannot handle variadic function arguments and blocktag
735+
// operations prevent inlining since they the blockaddress operations
736+
// reference them via the callee symbol.
737+
return !(isa<LLVM::VaStartOp>(op) || isa<LLVM::BlockTagOp>(op));
736738
}
737739

738740
/// Handle the given inlined return by replacing it with a branch. This

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,59 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
555555
return success();
556556
}
557557

558+
// Emit blockaddress. We first need to find the LLVM block referenced by this
559+
// operation and then create a LLVM block address for it.
560+
if (auto blockAddressOp = dyn_cast<LLVM::BlockAddressOp>(opInst)) {
561+
// getBlockTagOp() walks a function to search for block labels. Check
562+
// whether it's in cache first.
563+
BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
564+
BlockTagOp blockTagOp = moduleTranslation.lookupBlockTag(blockAddressAttr);
565+
if (!blockTagOp) {
566+
blockTagOp = blockAddressOp.getBlockTagOp();
567+
moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp);
568+
}
569+
570+
llvm::Value *llvmValue = nullptr;
571+
StringRef fnName = blockAddressAttr.getFunction().getValue();
572+
if (llvm::BasicBlock *llvmBlock =
573+
moduleTranslation.lookupBlock(blockTagOp->getBlock())) {
574+
llvm::Function *llvmFn = moduleTranslation.lookupFunction(fnName);
575+
llvmValue = llvm::BlockAddress::get(llvmFn, llvmBlock);
576+
} else {
577+
// The matching LLVM block is not yet emitted, a placeholder is created
578+
// in its place. When the LLVM block is emitted later in translation,
579+
// the llvmValue is replaced with the actual llvm::BlockAddress.
580+
// A GlobalVariable is chosen as placeholder because in general LLVM
581+
// constants are uniqued and are not proper for RAUW, since that could
582+
// harm unrelated uses of the constant.
583+
llvmValue = new llvm::GlobalVariable(
584+
*moduleTranslation.getLLVMModule(),
585+
llvm::PointerType::getUnqual(moduleTranslation.getLLVMContext()),
586+
/*isConstant=*/true, llvm::GlobalValue::LinkageTypes::ExternalLinkage,
587+
/*Initializer=*/nullptr,
588+
Twine("__mlir_block_address_")
589+
.concat(Twine(fnName))
590+
.concat(Twine((uint64_t)blockAddressOp.getOperation())));
591+
moduleTranslation.mapUnresolvedBlockAddress(blockAddressOp, llvmValue);
592+
}
593+
594+
moduleTranslation.mapValue(blockAddressOp.getResult(), llvmValue);
595+
return success();
596+
}
597+
598+
// Emit block label. If this label is seen before BlockAddressOp is
599+
// translated, go ahead and already map it.
600+
if (auto blockTagOp = dyn_cast<LLVM::BlockTagOp>(opInst)) {
601+
auto funcOp = blockTagOp->getParentOfType<LLVMFuncOp>();
602+
BlockAddressAttr blockAddressAttr = BlockAddressAttr::get(
603+
&moduleTranslation.getContext(),
604+
FlatSymbolRefAttr::get(&moduleTranslation.getContext(),
605+
funcOp.getName()),
606+
blockTagOp.getTag());
607+
moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp);
608+
return success();
609+
}
610+
558611
return failure();
559612
}
560613

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,9 +1381,18 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
13811381
return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes();
13821382
}
13831383

1384+
if (auto *blockAddr = dyn_cast<llvm::BlockAddress>(constant)) {
1385+
auto fnSym =
1386+
FlatSymbolRefAttr::get(context, blockAddr->getFunction()->getName());
1387+
auto blockTag =
1388+
BlockTagAttr::get(context, blockAddr->getBasicBlock()->getNumber());
1389+
return builder
1390+
.create<BlockAddressOp>(loc, convertType(blockAddr->getType()),
1391+
BlockAddressAttr::get(context, fnSym, blockTag))
1392+
.getRes();
1393+
}
1394+
13841395
StringRef error = "";
1385-
if (isa<llvm::BlockAddress>(constant))
1386-
error = " since blockaddress(...) is unsupported";
13871396

13881397
if (isa<llvm::ConstantPtrAuth>(constant))
13891398
error = " since ptrauth(...) is unsupported";
@@ -2448,8 +2457,13 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
24482457
SmallVector<llvm::BasicBlock *> reachableBasicBlocks;
24492458
for (llvm::BasicBlock &basicBlock : *func) {
24502459
// Skip unreachable blocks.
2451-
if (!reachable.contains(&basicBlock))
2460+
if (!reachable.contains(&basicBlock)) {
2461+
if (basicBlock.hasAddressTaken())
2462+
return emitError(funcOp.getLoc())
2463+
<< "unreachable block '" << basicBlock.getName()
2464+
<< "' with address taken";
24522465
continue;
2466+
}
24532467
Region &body = funcOp.getBody();
24542468
Block *block = builder.createBlock(&body, body.end());
24552469
mapBlock(&basicBlock, block);
@@ -2606,6 +2620,13 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
26062620
}
26072621
}
26082622
}
2623+
2624+
if (bb->hasAddressTaken()) {
2625+
OpBuilder::InsertionGuard guard(builder);
2626+
builder.setInsertionPointToStart(block);
2627+
builder.create<BlockTagOp>(block->getParentOp()->getLoc(),
2628+
BlockTagAttr::get(context, bb->getNumber()));
2629+
}
26092630
return success();
26102631
}
26112632

0 commit comments

Comments
 (0)