Skip to content

[CIR] Streamline creation of mlir::IntegerAttrs using mlir::Builder #141830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/Support/ErrorHandling.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
Expand Down Expand Up @@ -167,9 +168,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
}

mlir::TypedAttr getConstPtrAttr(mlir::Type type, int64_t value) {
auto valueAttr = mlir::IntegerAttr::get(
mlir::IntegerType::get(type.getContext(), 64), value);
return cir::ConstPtrAttr::get(type, valueAttr);
return cir::ConstPtrAttr::get(type, getI64IntegerAttr(value));
}

mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
Expand Down Expand Up @@ -197,14 +196,9 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {

mlir::Value createDummyValue(mlir::Location loc, mlir::Type type,
clang::CharUnits alignment) {
auto addr = createAlloca(loc, getPointerTo(type), type, {},
getSizeFromCharUnits(getContext(), alignment));
mlir::IntegerAttr alignAttr;
uint64_t align = alignment.getQuantity();
if (align)
alignAttr = getI64IntegerAttr(align);

return create<cir::LoadOp>(loc, addr, /*isDeref=*/false, alignAttr);
mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
auto addr = createAlloca(loc, getPointerTo(type), type, {}, alignmentAttr);
return create<cir::LoadOp>(loc, addr, /*isDeref=*/false, alignmentAttr);
}

cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base,
Expand Down Expand Up @@ -428,13 +422,29 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return OpBuilder::InsertPoint(block, block->begin());
};

mlir::IntegerAttr getSizeFromCharUnits(mlir::MLIRContext *ctx,
clang::CharUnits size) {
// Note that mlir::IntegerType is used instead of cir::IntType here
// because we don't need sign information for this to be useful, so keep
// it simple.
return mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64),
size.getQuantity());
//
// Alignment and size helpers
//

// Note that mlir::IntegerType is used instead of cir::IntType here because we
// don't need sign information for these to be useful, so keep it simple.

// For 0 alignment, any overload of `getAlignmentAttr` returns an empty
// attribute.
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment) {
return getAlignmentAttr(alignment.getQuantity());
}

mlir::IntegerAttr getAlignmentAttr(llvm::Align alignment) {
return getAlignmentAttr(alignment.value());
}

mlir::IntegerAttr getAlignmentAttr(int64_t alignment) {
return alignment ? getI64IntegerAttr(alignment) : mlir::IntegerAttr();
}

mlir::IntegerAttr getSizeFromCharUnits(clang::CharUnits size) {
return getI64IntegerAttr(size.getQuantity());
}

/// Create a loop condition.
Expand Down
15 changes: 4 additions & 11 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,22 +282,15 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {

cir::LoadOp createLoad(mlir::Location loc, Address addr,
bool isVolatile = false) {
mlir::IntegerAttr align;
uint64_t alignment = addr.getAlignment().getQuantity();
if (alignment)
align = getI64IntegerAttr(alignment);
mlir::IntegerAttr align = getAlignmentAttr(addr.getAlignment());
return create<cir::LoadOp>(loc, addr.getPointer(), /*isDeref=*/false,
align);
}

cir::StoreOp createStore(mlir::Location loc, mlir::Value val, Address dst,
::mlir::IntegerAttr align = {}) {
if (!align) {
uint64_t alignment = dst.getAlignment().getQuantity();
if (alignment)
align = mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64),
alignment);
}
mlir::IntegerAttr align = {}) {
if (!align)
align = getAlignmentAttr(dst.getAlignment());
return CIRBaseBuilderTy::createStore(loc, val, dst.getPointer(), align);
}

Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class CIRGenModule : public CIRGenTypeCache {
const clang::FunctionDecl *funcDecl);

mlir::IntegerAttr getSize(CharUnits size) {
return builder.getSizeFromCharUnits(&getMLIRContext(), size);
return builder.getSizeFromCharUnits(size);
}

const llvm::Triple &getTriple() const { return target.getTriple(); }
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {

if (parser.parseOptionalKeyword("null").succeeded()) {
value = mlir::IntegerAttr::get(
mlir::IntegerType::get(parser.getContext(), 64), 0);
value = parser.getBuilder().getI64IntegerAttr(0);
return success();
}

Expand Down
40 changes: 16 additions & 24 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,7 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
case cir::CastKind::int_to_bool: {
mlir::Value llvmSrcVal = adaptor.getOperands().front();
mlir::Value zeroInt = rewriter.create<mlir::LLVM::ConstantOp>(
castOp.getLoc(), llvmSrcVal.getType(),
mlir::IntegerAttr::get(llvmSrcVal.getType(), 0));
castOp.getLoc(), llvmSrcVal.getType(), 0);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroInt);
break;
Expand Down Expand Up @@ -630,9 +629,8 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
if (rewriteSub) {
index = rewriter.create<mlir::LLVM::SubOp>(
index.getLoc(), index.getType(),
rewriter.create<mlir::LLVM::ConstantOp>(
index.getLoc(), index.getType(),
mlir::IntegerAttr::get(index.getType(), 0)),
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
index.getType(), 0),
index);
rewriter.eraseOp(sub);
}
Expand All @@ -648,8 +646,7 @@ mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite(
mlir::ConversionPatternRewriter &rewriter) const {
assert(!cir::MissingFeatures::opAllocaDynAllocSize());
mlir::Value size = rewriter.create<mlir::LLVM::ConstantOp>(
op.getLoc(), typeConverter->convertType(rewriter.getIndexType()),
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
op.getLoc(), typeConverter->convertType(rewriter.getIndexType()), 1);
mlir::Type elementTy =
convertTypeForMemory(*getTypeConverter(), dataLayout, op.getAllocaType());
mlir::Type resultTy = convertTypeForMemory(*getTypeConverter(), dataLayout,
Expand Down Expand Up @@ -1111,18 +1108,16 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
switch (op.getKind()) {
case cir::UnaryOpKind::Inc: {
assert(!isVector && "++ not allowed on vector types");
mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1));
auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(
op, llvmType, adaptor.getInput(), one, maybeNSW);
return mlir::success();
}
case cir::UnaryOpKind::Dec: {
assert(!isVector && "-- not allowed on vector types");
mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1));
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
op, llvmType, adaptor.getInput(), one, maybeNSW);
auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, adaptor.getInput(),
one, maybeNSW);
return mlir::success();
}
case cir::UnaryOpKind::Plus:
Expand All @@ -1133,10 +1128,9 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
if (isVector)
zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
else
zero = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0);
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
op, llvmType, zero, adaptor.getInput(), maybeNSW);
op, zero, adaptor.getInput(), maybeNSW);
return mlir::success();
}
case cir::UnaryOpKind::Not: {
Expand All @@ -1150,11 +1144,10 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
minusOne =
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, denseVec);
} else {
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1);
}
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
op, llvmType, adaptor.getInput(), minusOne);
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
minusOne);
return mlir::success();
}
}
Expand Down Expand Up @@ -1206,10 +1199,9 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
return op.emitError() << "Unsupported unary operation on boolean type";
case cir::UnaryOpKind::Not: {
assert(!isVector && "NYI: op! on vector mask");
mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, llvmType,
adaptor.getInput(), one);
auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
one);
return mlir::success();
}
}
Expand Down
8 changes: 3 additions & 5 deletions clang/unittests/CIR/PointerLikeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,10 @@ class CIROpenACCPointerLikeTest : public ::testing::Test {
llvm::StringMap<unsigned> recordNames;

mlir::IntegerAttr getAlignOne(mlir::MLIRContext *ctx) {
// Note that mlir::IntegerType is used instead of cir::IntType here
// because we don't need sign information for this to be useful, so keep
// it simple.
// Note that mlir::IntegerType is used instead of cir::IntType here because
// we don't need sign information for this to be useful, so keep it simple.
clang::CharUnits align = clang::CharUnits::One();
return mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64),
align.getQuantity());
return b.getI64IntegerAttr(align.getQuantity());
}

mlir::StringAttr getUniqueRecordName(const std::string &baseName) {
Expand Down
Loading