diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp index 5c1cafc23d2f..68da99652d74 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp @@ -95,6 +95,11 @@ static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand) { // If the generic op is "just" copy, then fuse always. Block &body = producerOp->getRegion(0).front(); if (std::begin(body)->hasTrait()) return true; + if (llvm::all_of(body.getArguments(), + [](BlockArgument arg) { return arg.use_empty(); })) { + // The operands aren't used; it's just a `linalg.index` op. + return true; + } // If producer does not have a single user, dont fuse. if (!producerOp->hasOneUse()) return false; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir index a8d060ddf78e..45adb4c00eff 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir @@ -274,3 +274,59 @@ module { // CHECK-SAME: ins(%[[GENERIC1]], %[[GENERIC0]] : // CHECK-SAME: outs(%[[FILL]] : // CHECK: return %[[GENERIC2]] + +// ----- + +func.func @fuse_iota_ops(%arg0: tensor<10x20xi32>) -> (tensor<10x20xi32>, tensor<10x20xi32>) { + %c20 = arith.constant 20 : index + %0 = tensor.empty() : tensor<10x20xi32> + %1 = tensor.empty() : tensor<10x20xindex> + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + outs(%1 : tensor<10x20xindex>) { + ^bb0(%b0 : index): + %3 = linalg.index 0 : index + %4 = linalg.index 1 : index + %5 = arith.muli %4, %c20 : index + %6 = arith.addi %3, %5 : index + linalg.yield %6 : index + } -> 
tensor<10x20xindex> + %7 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %2: tensor<10x20xi32>, tensor<10x20xindex>) outs(%0 : tensor<10x20xi32>) { + ^bb0(%b0 : i32, %b1 : index, %b2 : i32): + %8 = arith.index_cast %b1 : index to i32 + %9 = arith.addi %8, %b0 : i32 + linalg.yield %9 : i32 + } -> tensor<10x20xi32> + %8 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %2: tensor<10x20xi32>, tensor<10x20xindex>) outs(%0 : tensor<10x20xi32>) { + ^bb0(%b0 : i32, %b1 : index, %b2 : i32): + %8 = arith.index_cast %b1 : index to i32 + %9 = arith.muli %8, %b0 : i32 + linalg.yield %9 : i32 + } -> tensor<10x20xi32> + return %7, %8 : tensor<10x20xi32>, tensor<10x20xi32> +} +// CHECK-LABEL: func @fuse_iota_ops( +// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi32>) +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<10x20xi32> +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]] : tensor<10x20xi32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<10x20xi32>) +// CHECK: linalg.index +// CHECK: linalg.index +// CHECK: arith.addi +// CHECK: linalg.yield +// CHECK: %[[GENERIC2:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]] : tensor<10x20xi32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<10x20xi32>) +// CHECK: linalg.index +// CHECK: linalg.index +// CHECK: arith.muli +// CHECK: linalg.yield +// CHECK: return %[[GENERIC1]], %[[GENERIC2]]