From 9190e1c0ef628b9cc432b507390dc6eec416f6ab Mon Sep 17 00:00:00 2001
From: Longsheng Mou
Date: Fri, 10 Jan 2025 09:23:50 +0800
Subject: [PATCH] [mlir][linalg] Handle reassociationIndices correctly for 0D
 tensor (#121683)

This PR fixes a bug where a value was written into a 0-sized
reassociationIndices, preventing a crash. Fixes #116043.
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 25 +++++++++++--------
 .../TosaToLinalg/tosa-to-linalg.mlir         | 23 +++++++++++++++++
 2 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f..1d7ead16e8b63 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -600,30 +600,33 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                         int64_t rank) {
   // No need to expand if we are already at the desired rank
-  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
-  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
-  int64_t numExtraDims = rank - shapedType.getRank();
+  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
+  assert(tensorType && "expected a ranked tensor type");
+  int64_t tensorRank = tensorType.getRank();
+  int64_t numExtraDims = rank - tensorRank;
   assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
   if (!numExtraDims)
     return tensor;
 
   // Compute reassociation indices
-  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
-      shapedType.getRank());
+  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
   int64_t index = 0;
-  for (index = 0; index <= numExtraDims; index++)
-    reassociationIndices[0].push_back(index);
-  for (size_t position = 1; position < reassociationIndices.size(); position++)
-    reassociationIndices[position].push_back(index++);
+  if (tensorRank != 0) {
+    for (index = 0; index <= numExtraDims; index++)
+      reassociationIndices[0].push_back(index);
+    for (size_t position = 1; position < reassociationIndices.size();
+         position++)
+      reassociationIndices[position].push_back(index++);
+  }
 
   // Compute result type
   SmallVector<int64_t> resultShape;
   for (index = 0; index < numExtraDims; index++)
     resultShape.push_back(1);
-  for (auto size : shapedType.getShape())
+  for (auto size : tensorType.getShape())
     resultShape.push_back(size);
   auto resultType =
-      RankedTensorType::get(resultShape, shapedType.getElementType());
+      RankedTensorType::get(resultShape, tensorType.getElementType());
 
   // Emit 'tensor.expand_shape' op
   return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8..c840fb8648d7b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -100,6 +100,29 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
 
 // -----
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @test_add_0d_broadcast(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
+// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<2x1xf32>
+// CHECK: return %[[RESULT]] : tensor<2x1xf32>
+// CHECK: }
+func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
+  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
+  return %0 : tensor<2x1xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: @test_add_1d_all_dynamic
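--
Note: expandRank sizes reassociationIndices by the source tensor's rank, so
for a 0D tensor the vector is empty and the old unconditional
reassociationIndices[0].push_back(index) wrote through a nonexistent element.
With the tensorRank != 0 guard, a 0D source keeps an empty reassociation,
which is what tensor.expand_shape expects when every result dimension is a
new unit dimension. A minimal sketch of the two cases (%t0 and %t1 are
hypothetical values; the rank-1 example is illustrative, not from the patch):

  // 0D source: empty reassociation, both result dims are new unit dims.
  %e0 = tensor.expand_shape %t0 [] output_shape [1, 1]
      : tensor<f32> into tensor<1x1xf32>

  // 1D source: the new leading unit dim is grouped with the first source
  // dim, giving reassociation [[0, 1]].
  %e1 = tensor.expand_shape %t1 [[0, 1]] output_shape [1, 4]
      : tensor<4xf32> into tensor<1x4xf32>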