[MLIR][XLA] Add HLO FuncOp to LHLO FuncOp legalization.
dfki-ehna committed Jan 29, 2020
1 parent 874eaa9 commit a97d62e
Showing 3 changed files with 199 additions and 65 deletions.
tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir (39 additions, 1 deletion)
@@ -11,6 +11,44 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// CHECK-LABEL: func @func_op
+func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
+  %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
+  // CHECK-NEXT: "xla_lhlo.max"(%arg0, %arg1, %[[MAX_RESULT]])
+  // CHECK-NEXT: "xla_lhlo.copy"(%[[MAX_RESULT]], %arg2)
+  // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
+  return %0 : tensor<4xf32>
+  // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
+}
+
+// CHECK-LABEL: func @func_op_long
+func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
+  // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
+  // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
+  // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
+  // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
+  %1 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
+  // CHECK-NEXT: "xla_lhlo.max"(%arg0, %arg1, %[[MAX_RESULT]])
+  %2 = xla_hlo.add %arg0, %1 {name = "maximum.47"} : tensor<4xf32>
+  // CHECK-NEXT: "xla_lhlo.add"(%arg0, %[[MAX_RESULT]], %[[ADD_RESULT]])
+  %3 = xla_hlo.min %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
+  // CHECK-NEXT: "xla_lhlo.min"(%arg0, %arg1, %[[MIN_RESULT]])
+  %4 = xla_hlo.sub %arg1, %3 {name = "maximum.47"} : tensor<4xf32>
+  // CHECK-NEXT: "xla_lhlo.sub"(%arg1, %[[MIN_RESULT]], %[[SUB_RESULT]])
+  %5 = xla_hlo.mul %2, %4 {name = "maximum.47"} : tensor<4xf32>
+  // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
+  // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
+  // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
+  // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
+  // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
+  // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %arg2)
+  // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
+  return %5 : tensor<4xf32>
+  // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
+}
+
 // CHECK-LABEL: func @fusion
 func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
              %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -120,7 +158,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
   %tensor_result = "xla_hlo.convert"(%tensor_operand)
       : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: return
+  // CHECK: xla_lhlo.terminator
   tensor_store %tensor_result, %result : memref<2x2xf32>
   return
 }
tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc (159 additions, 63 deletions)
@@ -39,54 +39,48 @@ namespace {

 constexpr StringRef kTempBufferAttr = "temp";
 
-Value GetTensorStoreOrReturnMemRef(Value value) {
+/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc.
+Operation* FindInsertionPointForCopy(Value value) {
   for (const auto& user : value.getUsers()) {
-    if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
-      if (tensor_store.getOperand(0) == value) {
-        return tensor_store.getOperand(1);
-      }
-    }
-    if (auto return_op = dyn_cast<xla_hlo::ReturnOp>(user)) {
-      if (return_op.getOperand(0) == value) {
-        auto block = return_op.getOperation()->getBlock();
-        return *block->args_rbegin();
-      }
+    if (auto dealloc = dyn_cast<DeallocOp>(user)) {
+      return user;
     }
   }
   return nullptr;
 }
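
A note on the new helper: any copy into a result buffer has to be created before that buffer's DeallocOp, and this lookup is what the new return lowering below uses to find that point. A minimal usage sketch, assuming a rewriter, loc, a source value buffer, and a destination dest are in scope (mirroring StdToLhloReturnOpConverter later in this file):

    if (Operation* dealloc = FindInsertionPointForCopy(buffer)) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(dealloc);
      rewriter.create<xla_lhlo::CopyOp>(loc, llvm::None, buffer, dest);
    }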

-Operation* GetLastUse(Value value) {
-  Operation* last = value.getDefiningOp();
-  for (auto& user : value.getUses()) {
-    Operation* user_op = user.getOwner();
-    if (!user_op->isBeforeInBlock(last)) {
-      last = user_op;
+Value GetTensorStore(Value value) {
+  for (const auto& user : value.getUsers()) {
+    if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
+      if (tensor_store.getOperand(0) == value) {
+        return tensor_store.getOperand(1);
+      }
     }
   }
-  return last;
+  return nullptr;
 }

 Value InsertAllocAndDealloc(Location loc, Value result,
                             ConversionPatternRewriter* rewriter) {
   auto result_type = result.getType().dyn_cast<ShapedType>();
   if (!result_type || !result_type.hasStaticShape()) {
-    emitError(loc,
-              "tensor to buffer conversion expects statically shaped results");
+    result.getDefiningOp()->emitOpError()
+        << "tensor to buffer conversion expects statically shaped results";
   }
   auto memref_type =
       MemRefType::get(result_type.getShape(), result_type.getElementType());
 
-  Operation* last = GetLastUse(result);
-
   Operation* op = result.getDefiningOp();
+  auto block = op->getBlock();
+
   OpBuilder allocBuilder(op);
+  allocBuilder.setInsertionPointToStart(block);  // Inserting at the beginning
   auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
 
   alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
 
-  allocBuilder.setInsertionPoint(op->getBlock(),
-                                 std::next(Block::iterator(last)));
+  allocBuilder.setInsertionPoint(block, std::prev(block->end()));
   allocBuilder.create<DeallocOp>(loc, alloc);
 
   return alloc;
 }
 
@@ -95,7 +89,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
 /// function to store that values held in the tensor.
 Value GetBufferForResultValue(Location loc, Value result,
                               ConversionPatternRewriter* rewriter) {
-  if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) {
+  if (auto existing_memref = GetTensorStore(result)) {
     return existing_memref;
   }
   return InsertAllocAndDealloc(loc, result, rewriter);
@@ -110,11 +104,6 @@ class HloToLhloOpConverter : public ConversionPattern {
   PatternMatchResult matchAndRewrite(
       Operation* op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
-    if (op->getParentRegion()->getBlocks().size() != 1) {
-      emitError(op->getLoc(),
-                "tensor to buffer conversion expects a single block in the "
-                "region containing the operation");
-    }
     const auto& original_results = op->getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
     for (auto result : original_results) {
@@ -129,7 +118,7 @@
   }
 };
 
-struct HloToLHloReduceConverter
+struct HloToLHloReduceOpConverter
     : public OpConversionPattern<xla_hlo::ReduceOp> {
  public:
   using OpConversionPattern::OpConversionPattern;
@@ -141,9 +130,9 @@ struct HloToLHloReduceConverter
     // TODO(b/137624192) Implement variadic reduce.
     if (op.getNumResults() != 1) return matchFailure();
     if (op.getParentRegion()->getBlocks().size() != 1) {
-      emitError(loc,
-                "tensor to buffer conversion expects a single block in the "
-                "region containing the operation");
+      op.emitOpError() << "tensor to buffer conversion expects a single block "
+                          "in the region containing the operation";
+      return matchFailure();
     }
     const auto& original_results = op.getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
@@ -185,11 +174,10 @@
   }
 };
 
-class HloToLhloTensorLoadConverter : public ConversionPattern {
+class HloToLhloTensorLoadOpConverter : public ConversionPattern {
  public:
-  explicit HloToLhloTensorLoadConverter(MLIRContext* context)
+  explicit HloToLhloTensorLoadOpConverter(MLIRContext* context)
       : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {}
-
   PatternMatchResult matchAndRewrite(
       Operation* op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
@@ -199,9 +187,9 @@
 };
 
 // TODO(b/137624192): Rewrite into a copy and elide copy if possible.
-class HloToLhloTensorStoreConverter : public ConversionPattern {
+class HloToLhloTensorStoreOpConverter : public ConversionPattern {
  public:
-  explicit HloToLhloTensorStoreConverter(MLIRContext* context)
+  explicit HloToLhloTensorStoreOpConverter(MLIRContext* context)
       : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {}
 
   PatternMatchResult matchAndRewrite(
@@ -212,19 +200,6 @@
   }
 };
 
-// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
-class HloToLhloReturnConverter : public OpConversionPattern<xla_hlo::ReturnOp> {
- public:
-  using OpConversionPattern::OpConversionPattern;
-
-  PatternMatchResult matchAndRewrite(
-      xla_hlo::ReturnOp op, ArrayRef<Value> operands,
-      ConversionPatternRewriter& rewriter) const final {
-    rewriter.eraseOp(op);
-    return matchSuccess();
-  }
-};
-
 // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 // buffers if necessary.
 //
@@ -265,26 +240,146 @@ class HloToLhloReturnConverter : public OpConversionPattern<xla_hlo::ReturnOp> {
 //     return
 //   }
 // }
-struct HloLegalizeToLhlo : public FunctionPass<HloLegalizeToLhlo> {
-  void runOnFunction() override {
+//
+// FuncOp signature conversion example:
+//
+// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+//   %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
+//   %1 = xla_hlo.add %arg0, %0 {name = "maximum.47"} : tensor<4xf32>
+//   return %1 : tensor<4xf32>
+// }
+//
+// Transformed function with an extra argument for the result. The types have
+// been converted from tensor to memref.
+//
+// func @func_op(%arg0: memref<4xf32>,
+//               %arg1: memref<4xf32>,
+//               %arg2: memref<4xf32>) {
+//   %0 = alloc() {temp = true} : memref<4xf32>
+//   %1 = alloc() {temp = true} : memref<4xf32>
+//   "xla_lhlo.max"(%arg0, %arg1, %1) {name = "maximum.47"} :
+//       (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
+//   "xla_lhlo.add"(%arg0, %1, %0) {name = "maximum.47"} :
+//       (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
+//   dealloc %1 : memref<4xf32>
+//   "xla_lhlo.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
+//   dealloc %0 : memref<4xf32>
+//   "xla_lhlo.terminator"() : () -> ()
+// }
+
+struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
+  void runOnModule() override {
     OwningRewritePatternList patterns;
-    ConversionTarget target(getContext());
+    auto& context = getContext();
+    ConversionTarget target(context);
     target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
-
-    auto func = getFunction();
-    populateHLOToLHLOConversionPattern(func.getContext(), &patterns);
-    if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalOp<ModuleOp>();
+    target.addIllegalOp<mlir::ReturnOp>();
+    target.addIllegalOp<mlir::TensorLoadOp>();
+    target.addIllegalOp<mlir::TensorStoreOp>();
+    target.addLegalOp<ModuleTerminatorOp>();
+    target.addIllegalDialect<xla_hlo::XlaHloDialect>();
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      auto inputs = op.getType().getInputs();
+      return std::all_of(inputs.begin(), inputs.end(),
+                         [](Type input) { return input.isa<MemRefType>(); });
+    });
+
+    auto module = getModule();
+    populateHLOToLHLOConversionPattern(module.getContext(), &patterns);
+
+    if (failed(applyFullConversion(module, target, patterns, nullptr))) {
       signalPassFailure();
     }
   }
 };
 
+Type ConvertType(Type t) {
+  if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
+    return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  }
+  return t;
+}
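
ConvertType is the type-lowering rule behind the new FuncOp conversion: ranked tensors become memrefs of the same shape and element type, and every other type passes through untouched. A quick illustration; this snippet is hypothetical and not part of the commit (context is an MLIRContext*):

    auto f32 = FloatType::getF32(context);
    Type buffer = ConvertType(RankedTensorType::get({4}, f32));  // memref<4xf32>
    Type same = ConvertType(f32);                                // still f32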

 }  // namespace
 
+/// Transforms FuncOp arguments and results from tensors to buffers. Tensor
+/// results are converted to memrefs and appended to the argument list.
+class HloToLhloFuncOpConverter : public OpConversionPattern<FuncOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      FuncOp funcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    if (funcOp.getBody().getBlocks().size() > 1) {
+      funcOp.emitOpError() << "tensor to buffer conversion expects a single "
+                              "block in the region containing the operation";
+      return matchFailure();
+    }
+
+    auto funcType = funcOp.getType();
+
+    TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
+    for (auto argType : llvm::enumerate(funcType.getInputs())) {
+      conversion.addInputs(argType.index(), ConvertType(argType.value()));
+    }
+    for (auto resType : funcType.getResults()) {
+      conversion.addInputs(ConvertType(resType));
+    }
+    rewriter.updateRootInPlace(funcOp, [&] {
+      funcOp.setType(
+          rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
+      rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
+    });
+    return matchSuccess();
+  }
+};
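
Applied to @func_op from the updated test, this converter rewrites the type (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> into (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> (): the inputs are remapped in place and each result type is appended as one extra buffer argument. A sketch of the equivalent type computation, assuming the MLIR API of this vintage:

    auto f32 = FloatType::getF32(context);
    auto memrefTy = MemRefType::get({4}, f32);
    // Result types are folded into the inputs; the result list ends up empty.
    auto newFuncTy = FunctionType::get({memrefTy, memrefTy, memrefTy},
                                       /*results=*/{}, context);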

+/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each
+/// result to the corresponding buffer argument.
+class StdToLhloReturnOpConverter : public OpConversionPattern<mlir::ReturnOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      mlir::ReturnOp returnOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    auto numReturnValues = returnOp.getNumOperands();
+    auto funcOp = returnOp.getParentOfType<FuncOp>();
+    auto numFuncArgs = funcOp.getNumArguments();
+    auto loc = returnOp.getLoc();
+
+    for (auto operand : llvm::enumerate(operands)) {
+      auto returnArgNumber = numFuncArgs - numReturnValues + operand.index();
+      auto dstBuffer = funcOp.getArgument(returnArgNumber);
+      if (dstBuffer == operand.value()) {
+        continue;
+      }
+
+      auto dealloc = FindInsertionPointForCopy(operand.value());
+
+      if (dealloc == nullptr) {
+        returnOp.emitOpError() << "Missing dealloc for operand " << operand.index();
+        return matchFailure();
+      }
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(dealloc);
+      rewriter.create<xla_lhlo::CopyOp>(loc, llvm::None, operand.value(),
+                                        funcOp.getArgument(returnArgNumber));
+    }
+    rewriter.replaceOpWithNewOp<xla_lhlo::TerminatorOp>(returnOp);
+    return matchSuccess();
+  }
+};
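
The returnArgNumber arithmetic depends on HloToLhloFuncOpConverter having appended every result buffer after the original arguments. Worked through for @func_op in the test file (plain arithmetic, purely illustrative):

    int numFuncArgs = 3;      // %arg0, %arg1 plus the appended buffer %arg2
    int numReturnValues = 1;  // a single returned value
    int returnArgNumber = numFuncArgs - numReturnValues + 0;  // operand index 0
    assert(returnArgNumber == 2);  // the copy targets %arg2, as the test checks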

 void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                         OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<
+      HloToLHloReduceOpConverter,
+      HloToLhloFuncOpConverter,
       HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>,
       HloToLhloOpConverter<xla_hlo::AddOp, xla_lhlo::AddOp>,
       HloToLhloOpConverter<xla_hlo::AndOp, xla_lhlo::AndOp>,
@@ -308,13 +403,14 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
       HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
       HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
       HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
-      HloToLHloReduceConverter, HloToLhloReturnConverter,
-      HloToLhloTensorLoadConverter, HloToLhloTensorStoreConverter
+      HloToLhloTensorLoadOpConverter,
+      HloToLhloTensorStoreOpConverter,
+      StdToLhloReturnOpConverter
   >(context);
   // clang-format on
 }
 
-std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToLhloPass() {
+std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
   return absl::make_unique<HloLegalizeToLhlo>();
 }
 
tensorflow/compiler/mlir/xla/transforms/passes.h (1 addition, 1 deletion)
@@ -53,7 +53,7 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToStdPass();

 // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 // buffers if necessary.
-std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToLhloPass();
+std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass();
 
 }  // namespace xla_hlo
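
With the return type changed to OpPassBase<ModuleOp>, callers now schedule this pass on a module-level pass manager rather than a function one. A hypothetical sketch, assuming the pass-manager API of early-2020 MLIR:

    mlir::PassManager pm(&context);
    pm.addPass(xla_hlo::createLegalizeToLhloPass());
    if (failed(pm.run(module))) {
      // Legalization failed; applyFullConversion reported the offending op.
    }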
