From 5de799cb7632a4aa84440ffeb69284cfd713e55b Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Thu, 1 Jun 2023 14:23:00 +0000 Subject: [PATCH] Generic support for legalizing tosa.custom_op into another dialect operation. --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 11 ++++++++--- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir | 10 ++++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index a7750a7f7518c..105ee086db723 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -508,9 +508,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, // tosa::CustomOp if (auto customOp = dyn_cast(op)) { - return llvm::StringSwitch(customOp.getIdentifierAttr().str()) - .Case("atan2", rewriter.create(loc, resultTypes, args)) - .Default(nullptr); + // Only legalize tosa.custom_op's that are marked as implementable with + // 'linalg.generic' by looking at the 'implementation_attrs' attribute + auto implementationAttr = customOp.getImplementationAttrs(); + if (implementationAttr == "linalg.generic") { + OperationState state(loc, customOp.getIdentifierAttr(), args, + resultTypes); + return rewriter.create(state)->getResult(0); + } } (void)rewriter.notifyMatchFailure( diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index b94867a9f7e51..6483e29e7a9c2 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1414,9 +1414,12 @@ func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, % // CHECK-LABEL: @test_custom_ops func.func @test_custom_ops(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> () { + // CHECK: linalg.generic + // CHECK: math.sin // CHECK: linalg.generic // CHECK: math.atan2 - %2 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "atan2", implementation_attrs = "UNDEF"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %2 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.sin", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>) -> tensor<1xf32> + %3 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "math.atan2", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> return } @@ -1426,9 +1429,12 @@ func.func @test_custom_ops(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> () { // CHECK-LABEL: @test_custom_ops_dyn func.func @test_custom_ops_dyn(%arg0: tensor, %arg1: tensor) -> () { + // CHECK: linalg.generic + // CHECK: math.cos // CHECK: linalg.generic // CHECK: math.atan2 - %2 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "atan2", implementation_attrs = "UNDEF"}> : (tensor, tensor) -> tensor + %2 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.cos", implementation_attrs = "linalg.generic"}> : (tensor) -> tensor + %3 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "math.atan2", implementation_attrs = "linalg.generic"}> : (tensor, tensor) -> tensor return } \ No newline at end of file