Skip to content

Commit

Permalink
[mlir] use the new stateful LLVM type translator by default
Browse files Browse the repository at this point in the history
Previous type model in the LLVM dialect did not support identified structure
types properly and therefore could use stateless translations implemented as
free functions. The new model supports identified structs and must keep track
of the identified structure types present in the target context (LLVMContext or
MLIRContext) to avoid creating duplicate structs due to LLVM's type
auto-renaming. Expose the stateful type translation classes and use them during
translation, storing the state as part of ModuleTranslation.

Drop the test type translation mechanism that is no longer necessary and update
the tests to exercise type translation as part of the main translation flow.

Update the code in vector-to-LLVM dialect conversion that relied on stateless
translation to use the new class in a stateless manner.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85297
  • Loading branch information
ftynse committed Aug 5, 2020
1 parent e1de85f commit b2ab375
Show file tree
Hide file tree
Showing 13 changed files with 277 additions and 414 deletions.
6 changes: 0 additions & 6 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ struct LLVMTypeStorage;
struct LLVMDialectImpl;
} // namespace detail

/// Converts an MLIR LLVM dialect type to LLVM IR type. Note that this function
/// exists exclusively for the purpose of gradual transition to the first-party
/// modeling of LLVM types. It should not be used outside MLIR-to-LLVM
/// translation.
llvm::Type *convertLLVMType(LLVMType type);

///// Ops /////
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc"
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Value.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/BasicBlock.h"
Expand Down Expand Up @@ -127,6 +128,9 @@ class ModuleTranslation {
/// Mappings between llvm.mlir.global definitions and corresponding globals.
DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;

/// A stateful object used to translate types.
TypeToLLVMIRTranslator typeTranslator;

protected:
/// Mappings between original and translated values, used for lookups.
llvm::StringMap<llvm::Function *> functionMapping;
Expand Down
48 changes: 46 additions & 2 deletions mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
#ifndef MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
#define MLIR_TARGET_LLVMIR_TYPETRANSLATION_H

#include <memory>

namespace llvm {
class DataLayout;
class LLVMContext;
class Type;
} // namespace llvm
Expand All @@ -27,8 +30,49 @@ namespace LLVM {

class LLVMType;

llvm::Type *translateTypeToLLVMIR(LLVMType type, llvm::LLVMContext &context);
LLVMType translateTypeFromLLVMIR(llvm::Type *type, MLIRContext &context);
namespace detail {
class TypeToLLVMIRTranslatorImpl;
class TypeFromLLVMIRTranslatorImpl;
} // namespace detail

/// Utility class to translate MLIR LLVM dialect types to LLVM IR. Stores the
/// translation state, in particular any identified structure types that can be
/// reused in further translation.
class TypeToLLVMIRTranslator {
public:
TypeToLLVMIRTranslator(llvm::LLVMContext &context);
~TypeToLLVMIRTranslator();

/// Returns the perferred alignment for the type given the data layout. Note
/// that this will perform type conversion and store its results for future
/// uses.
// TODO: this should be removed when MLIR has proper data layout.
unsigned getPreferredAlignment(LLVM::LLVMType type,
const llvm::DataLayout &layout);

/// Translates the given MLIR LLVM dialect type to LLVM IR.
llvm::Type *translateType(LLVM::LLVMType type);

private:
/// Private implementation.
std::unique_ptr<detail::TypeToLLVMIRTranslatorImpl> impl;
};

/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
/// the translation state, in particular any identified structure types that are
/// reused across translations.
class TypeFromLLVMIRTranslator {
public:
TypeFromLLVMIRTranslator(MLIRContext &context);
~TypeFromLLVMIRTranslator();

/// Translates the given LLVM IR type to the MLIR LLVM dialect.
LLVM::LLVMType translateType(llvm::Type *type);

private:
/// Private implementation.
std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
};

} // namespace LLVM
} // namespace mlir
Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,12 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
if (!elementTy)
return failure();

auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
// TODO: this should be abstracted away to avoid depending on translation.
align = dataLayout.getPrefTypeAlignment(LLVM::translateTypeToLLVMIR(
elementTy.cast<LLVM::LLVMType>(),
typeConverter.getDialect()->getLLVMContext()));
// TODO: this should use the MLIR data layout when it becomes available and
// stop depending on translation.
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
.getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
dialect->getLLVMModule().getDataLayout());
return success();
}

Expand Down
87 changes: 14 additions & 73 deletions mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Translation.h"

#include "llvm/IR/Attributes.h"
Expand Down Expand Up @@ -48,7 +49,8 @@ class Importer {
public:
Importer(MLIRContext *context, ModuleOp module)
: b(context), context(context), module(module),
unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) {
unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)),
typeTranslator(*context) {
b.setInsertionPointToStart(module.getBody());
dialect = context->getRegisteredDialect<LLVMDialect>();
}
Expand Down Expand Up @@ -129,6 +131,8 @@ class Importer {
Location unknownLoc;
/// Cached dialect.
LLVMDialect *dialect;
/// The stateful type translator (contains named structs).
LLVM::TypeFromLLVMIRTranslator typeTranslator;
};
} // namespace

Expand All @@ -149,79 +153,16 @@ Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
}

LLVMType Importer::processType(llvm::Type *type) {
switch (type->getTypeID()) {
case llvm::Type::FloatTyID:
return LLVMType::getFloatTy(dialect);
case llvm::Type::DoubleTyID:
return LLVMType::getDoubleTy(dialect);
case llvm::Type::IntegerTyID:
return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth());
case llvm::Type::PointerTyID: {
LLVMType elementType = processType(type->getPointerElementType());
if (!elementType)
return nullptr;
return elementType.getPointerTo(type->getPointerAddressSpace());
}
case llvm::Type::ArrayTyID: {
LLVMType elementType = processType(type->getArrayElementType());
if (!elementType)
return nullptr;
return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
}
case llvm::Type::ScalableVectorTyID: {
emitError(unknownLoc) << "scalable vector types not supported";
return nullptr;
}
case llvm::Type::FixedVectorTyID: {
auto *typeVTy = llvm::cast<llvm::FixedVectorType>(type);
LLVMType elementType = processType(typeVTy->getElementType());
if (!elementType)
return nullptr;
return LLVMType::getVectorTy(elementType, typeVTy->getNumElements());
}
case llvm::Type::VoidTyID:
return LLVMType::getVoidTy(dialect);
case llvm::Type::FP128TyID:
return LLVMType::getFP128Ty(dialect);
case llvm::Type::X86_FP80TyID:
return LLVMType::getX86_FP80Ty(dialect);
case llvm::Type::StructTyID: {
SmallVector<LLVMType, 4> elementTypes;
elementTypes.reserve(type->getStructNumElements());
for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) {
LLVMType ty = processType(type->getStructElementType(i));
if (!ty)
return nullptr;
elementTypes.push_back(ty);
}
return LLVMType::getStructTy(dialect, elementTypes,
cast<llvm::StructType>(type)->isPacked());
}
case llvm::Type::FunctionTyID: {
llvm::FunctionType *fty = cast<llvm::FunctionType>(type);
SmallVector<LLVMType, 4> paramTypes;
for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) {
LLVMType ty = processType(fty->getParamType(i));
if (!ty)
return nullptr;
paramTypes.push_back(ty);
}
LLVMType result = processType(fty->getReturnType());
if (!result)
return nullptr;
if (LLVMType result = typeTranslator.translateType(type))
return result;

return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg());
}
default: {
// FIXME: Diagnostic should be able to natively handle types that have
// operator<<(raw_ostream&) defined.
std::string s;
llvm::raw_string_ostream os(s);
os << *type;
emitError(unknownLoc) << "unhandled type: " << os.str();
return nullptr;
}
}
// FIXME: Diagnostic should be able to natively handle types that have
// operator<<(raw_ostream&) defined.
std::string s;
llvm::raw_string_ostream os(s);
os << *type;
emitError(unknownLoc) << "unhandled type: " << os.str();
return nullptr;
}

// We only need integers, floats, doubles, and vectors and tensors thereof for
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
ompDialect(
module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()),
llvmDialect(module->getContext()->getRegisteredDialect<LLVMDialect>()) {
llvmDialect(module->getContext()->getRegisteredDialect<LLVMDialect>()),
typeTranslator(this->llvmModule->getContext()) {
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
}
Expand Down Expand Up @@ -935,7 +936,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
llvm::Type *ModuleTranslation::convertType(LLVMType type) {
// Lock the LLVM context as we create types in it.
llvm::sys::SmartScopedLock<true> lock(llvmDialect->getLLVMContextMutex());
return LLVM::translateTypeToLLVMIR(type, llvmDialect->getLLVMContext());
return typeTranslator.translateType(type);
}

/// A helper to look up remapped operands in the value remapping table.`
Expand Down
57 changes: 37 additions & 20 deletions mlir/lib/Target/LLVMIR/TypeTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
#include "mlir/IR/MLIRContext.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"

using namespace mlir;

namespace {
namespace mlir {
namespace LLVM {
namespace detail {
/// Support for translating MLIR LLVM dialect types to LLVM IR.
class TypeToLLVMIRTranslator {
class TypeToLLVMIRTranslatorImpl {
public:
/// Constructs a class creating types in the given LLVM context.
TypeToLLVMIRTranslator(llvm::LLVMContext &context) : context(context) {}
TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {}

/// Translates a single type.
llvm::Type *translateType(LLVM::LLVMType type) {
Expand Down Expand Up @@ -160,22 +163,32 @@ class TypeToLLVMIRTranslator {
/// type instead of creating a new type.
llvm::DenseMap<LLVM::LLVMType, llvm::Type *> knownTranslations;
};
} // end namespace

/// Translates a type from MLIR LLVM dialect to LLVM IR. This does not maintain
/// the mapping for identified structs so new structs will be created with
/// auto-renaming on each call. This is intended exclusively for testing.
llvm::Type *mlir::LLVM::translateTypeToLLVMIR(LLVM::LLVMType type,
llvm::LLVMContext &context) {
return TypeToLLVMIRTranslator(context).translateType(type);
} // end namespace detail
} // end namespace LLVM
} // end namespace mlir

LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context)
: impl(new detail::TypeToLLVMIRTranslatorImpl(context)) {}

LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() {}

llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(LLVM::LLVMType type) {
return impl->translateType(type);
}

namespace {
unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment(
LLVM::LLVMType type, const llvm::DataLayout &layout) {
return layout.getPrefTypeAlignment(translateType(type));
}

namespace mlir {
namespace LLVM {
namespace detail {
/// Support for translating LLVM IR types to MLIR LLVM dialect types.
class TypeFromLLVMIRTranslator {
class TypeFromLLVMIRTranslatorImpl {
public:
/// Constructs a class creating types in the given MLIR context.
TypeFromLLVMIRTranslator(MLIRContext &context) : context(context) {}
TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}

/// Translates the given type.
LLVM::LLVMType translateType(llvm::Type *type) {
Expand Down Expand Up @@ -299,11 +312,15 @@ class TypeFromLLVMIRTranslator {
/// The context in which MLIR types are created.
MLIRContext &context;
};
} // end namespace
} // end namespace detail
} // end namespace LLVM
} // end namespace mlir

LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
: impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}

LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}

/// Translates a type from LLVM IR to MLIR LLVM dialect. This is intended
/// exclusively for testing.
LLVM::LLVMType mlir::LLVM::translateTypeFromLLVMIR(llvm::Type *type,
MLIRContext &context) {
return TypeFromLLVMIRTranslator(context).translateType(type);
LLVM::LLVMType LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
return impl->translateType(type);
}
2 changes: 1 addition & 1 deletion mlir/test/Target/import.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
%struct.t = type {}
%struct.s = type { %struct.t, i64 }

; CHECK: llvm.mlir.global external @g1() : !llvm.struct<(struct<()>, i64)>
; CHECK: llvm.mlir.global external @g1() : !llvm.struct<"struct.s", (struct<"struct.t", ()>, i64)>
@g1 = external global %struct.s, align 8
; CHECK: llvm.mlir.global external @g2() : !llvm.double
@g2 = external global double, align 8
Expand Down
Loading

0 comments on commit b2ab375

Please sign in to comment.