From 55d27e63b56efc060730eb859c7bfb1b358aa4a6 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Mon, 28 Apr 2025 09:20:17 -0600 Subject: [PATCH] Remove conv2d conversion to tosa.fully_connected This is a part of the following upstream changes: https://github.com/llvm/llvm-project/pull/126152/ [mlir][tosa] Remove FullyConnectedOp from TOSA Dialect Upstream commit is 4ec1990 --- .../mlir/Dialect/Tosa/Transforms/Passes.h | 1 - .../Dialect/Tosa/Transforms/CMakeLists.txt | 1 - .../Tosa/Transforms/TosaDecomposeConv2D.cpp | 163 ------------------ .../Transforms/TosaOptionalDecompositions.cpp | 1 - .../Dialect/Tosa/tosa-decompose-conv2d.mlir | 70 -------- 5 files changed, 236 deletions(-) delete mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp delete mode 100644 mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index c5a4c16604a30..483057f8ccb94 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -26,7 +26,6 @@ namespace tosa { // Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops. // The rewrites can be selectively added to a conversion pass. -void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaDecomposeDepthwise(MLIRContext *ctx, diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 5b0f5ec4cd568..9c3345b617cc5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,6 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaDecomposeTransposeConv.cpp - TosaDecomposeConv2D.cpp TosaDecomposeDepthwise.cpp TosaFolders.cpp TosaInferShapes.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp deleted file mode 100644 index 04a709c596779..0000000000000 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ /dev/null @@ -1,163 +0,0 @@ -//===- TosaDecomposeConv2D.cpp --------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically -// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" - -using namespace mlir; -using namespace mlir::tosa; - -namespace { - -SmallVector convertFromMlirShape(ArrayRef shape) { - return to_vector(llvm::map_range(shape, [](int64_t dim) { - return ShapedType::isDynamic(dim) ? -1 : dim; - })); -} - -struct Conv2DIsFullyConnected : public OpRewritePattern { - explicit Conv2DIsFullyConnected(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(tosa::Conv2DOp op, - PatternRewriter &rewriter) const override { - Value input = op.getInput(); - Value weight = op.getWeight(); - ShapedType inputType = cast(input.getType()); - ShapedType weightType = cast(weight.getType()); - ShapedType resultType = cast(op.getType()); - - auto numDynamic = - llvm::count_if(inputType.getShape(), ShapedType::isDynamic); - if (numDynamic > 1) - return rewriter.notifyMatchFailure( - op, "at most one dim in input may be dynamic"); - if (!weightType.hasRank()) - return rewriter.notifyMatchFailure(op, "unranked weight input"); - - if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; })) - return failure(); - - // Only works for a 1x1 kernel. - ArrayRef weightShape = weightType.getShape(); - if (weightShape[1] != 1 || weightShape[2] != 1) - return failure(); - - llvm::ArrayRef padAttr = op.getPad(); - llvm::SmallVector pad(8, 0); - for (const auto &it : llvm::enumerate(padAttr)) - pad[it.index() + 2] = it.value(); - - if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) { - Type inputETy = inputType.getElementType(); - Attribute zeroAttr = rewriter.getZeroAttr(inputETy); - if (op.getQuantizationInfo()) { - auto quantizationInfo = op.getQuantizationInfo(); - int64_t iZp = quantizationInfo->getInputZp(); - - if (!validIntegerRange(cast(inputETy), iZp)) - return rewriter.notifyMatchFailure( - op, "tosa.conv op quantization has zp outside of input range"); - - zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); - } - - llvm::SmallVector newShape(inputType.getShape()); - - for (int i = 0, s = newShape.size(); i < s; ++i) { - if (newShape[i] != ShapedType::kDynamic) { - newShape[i] += pad[i * 2] + pad[i * 2 + 1]; - } - } - - auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type()); - auto padSize = - DenseIntElementsAttr::get(padSizeTy, ArrayRef(pad)); - Value padSizeVal = - rewriter.create(op->getLoc(), padSizeTy, padSize); - - auto padTy = RankedTensorType::get({}, inputETy); - auto padAttr = DenseElementsAttr::get(padTy, zeroAttr); - Value padVal = - rewriter.create(op->getLoc(), padTy, padAttr); - inputType = RankedTensorType::get(newShape, inputETy); - input = rewriter.create(op->getLoc(), inputType, input, - padSizeVal, padVal); - } - - // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. - ArrayRef inputShape = inputType.getShape(); - int64_t combined = ShapedType::kDynamic; - if (numDynamic == 0) - combined = inputShape[0] * inputShape[1] * inputShape[2]; - llvm::SmallVector revisedInputShape{combined, inputShape[3]}; - auto revisedInputShapeType = - RankedTensorType::get(revisedInputShape, inputType.getElementType()); - auto reshapedInput = rewriter - .create( - op.getLoc(), revisedInputShapeType, input, - rewriter.getDenseI64ArrayAttr( - convertFromMlirShape(revisedInputShape))) - .getResult(); - - // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. - llvm::SmallVector revisedWeightShape{weightShape[0], - weightShape[3]}; - auto revisedWeightShapeType = RankedTensorType::get( - revisedWeightShape, - dyn_cast(weight.getType()).getElementType()); - auto reshapedWeight = rewriter - .create( - op.getLoc(), revisedWeightShapeType, weight, - rewriter.getDenseI64ArrayAttr( - convertFromMlirShape(revisedWeightShape))) - .getResult(); - - // Perform a fully connected network over the reshaped input and weight. - llvm::SmallVector fullyConnectedShape{combined, weightShape[0]}; - auto fullyConnectedShapeType = - RankedTensorType::get(fullyConnectedShape, resultType.getElementType()); - - Value fullyConnectedValue; - if (op.getQuantizationInfo()) { - fullyConnectedValue = - rewriter - .create( - op.getLoc(), fullyConnectedShapeType, reshapedInput, - reshapedWeight, op.getBias(), *op.getQuantizationInfo()) - .getResult(); - } else { - fullyConnectedValue = rewriter - .create( - op.getLoc(), fullyConnectedShapeType, - reshapedInput, reshapedWeight, op.getBias()) - .getResult(); - } - - // Reshape output to [N, IH, IW, OC]. - llvm::SmallVector outputShape{inputShape[0], inputShape[1], - inputShape[2], weightShape[0]}; - rewriter.replaceOpWithNewOp( - op, resultType, fullyConnectedValue, - rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape))); - return success(); - } -}; - -} // namespace - -void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx, - RewritePatternSet &patterns) { - patterns.add(ctx); -} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp index 603185e48aa94..ffa2ea3d0629f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp @@ -38,7 +38,6 @@ struct TosaOptionalDecompositions RewritePatternSet patterns(ctx); auto func = getOperation(); - mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns); mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns); mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns); diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir deleted file mode 100644 index 8df4630f9c17f..0000000000000 --- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s - -// ----- - -// CHECK-LABEL: @conv2d_as_fully_connected -func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> { - // CHECK-NOT: tosa.conv2d - // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK-SAME: -> tensor<400x2xf32> - // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK-SAME: -> tensor<3x2xf32> - // CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2 - // CHECK-SAME: -> tensor<400x3xf32> - // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} - // CHECK-SAME: -> tensor<4x10x10x3xf32> - // CHECK: return %[[VAR3]] - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> - return %0 : tensor<4x10x10x3xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_as_fully_connected_quant -func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> { - // CHECK-NOT: tosa.conv2d - // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK-SAME: -> tensor<400x2xi8> - // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK-SAME: -> tensor<3x2xi8> - // CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2 - // CHECK-SAME: quantization_info = #tosa.conv_quant - // CHECK-SAME: -> tensor<400x3xi32> - // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} - // CHECK-SAME: -> tensor<4x10x10x3xi32> - // CHECK: return %[[VAR3]] - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> - return %0 : tensor<4x10x10x3xi32> -} - -// ----- - -// CHECK-LABEL: func.func @conv_with_dynamic_dim( -// CHECK-SAME: %[[VAL_0:.*]]: tensor, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<384x1x1x64xi8>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<384xi32>) -> tensor { -func.func @conv_with_dynamic_dim(%arg0: tensor, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor { -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8> -// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {quantization_info = #tosa.conv_quant} : (tensor, tensor<384x64xi8>, tensor<384xi32>) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor -// CHECK: return %[[VAL_6]] : tensor -// CHECK: } - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @conv2d_as_fully_connected_padded -func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> { - // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>} - // CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor} - // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<8xi64>, tensor) -> tensor<4x12x12x2xi8> - // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array} - // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array} - // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant} - // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32> - return %0 : tensor<4x12x12x3xi32> -}