From fe0017e1ad9ae846a0e4efb7985858b55568ffb9 Mon Sep 17 00:00:00 2001 From: Andrew Woloszyn Date: Fri, 17 May 2024 13:21:28 -0400 Subject: [PATCH] Add support for the onnx.SequenceConstruct op. (#3316) --- .../Conversion/TorchOnnxToTorch/Patterns.h | 12 ++++++++++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 14 +++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 24 +++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 3230cc8b46a..c00522a763f 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -97,6 +97,18 @@ struct OpBinder { return success(); } + ParseResult tensorListResultType(Torch::ListType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto tt = dyn_cast(op->getResult(0).getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + type0 = tt; + return success(); + } + ParseResult tensorResultTypes(llvm::SmallVector &typeList) { for (auto result : op->getResults()) { auto t = toValidTensorType(result.getType()); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index d9ecb930290..037633490d9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -518,6 +518,20 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( cstStrReduction); return success(); }); + patterns.onOp( + "SequenceConstruct", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallVector operands; + Torch::ListType resultType; + + if (binder.tensorOperands(operands, binder.getNumOperands()) || + binder.tensorListResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operands); + return success(); + }); patterns.onOp( "Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index e52ccd6daf4..9432702b6b1 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2075,6 +2075,30 @@ func.func @test_random_uniform_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.v // ----- +// CHECK-LABEL: func.func @test_sequence_construct_3 +module { + func.func @test_sequence_construct_3(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[SEQ:.+]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> +// CHECK: return %[[SEQ]] : !torch.list> + %0 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + return %0 : !torch.list> + } +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_construct_1 +module { + func.func @test_sequence_construct_1(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[SEQ:.+]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[2,3,4],f32>) -> !torch.list> +// CHECK: return %[[SEQ]] : !torch.list> + %0 = torch.operator "onnx.SequenceConstruct"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.list> + return %0 : !torch.list> + } +} + +// ----- + // CHECK-LABEL: func.func @test_sce_mean_3d func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none