Skip to content

[Stablehlo] Unable to convert QLinearConv from ONNX #4189

Open
@anuragsingh-tt

Description

@anuragsingh-tt

A test case (linked here) was added to convert the ONNX QLinearConv operator to Torch.

After running --torch-fuse-quantized-ops on this MLIR, the following is produced:

// -----
module {
  func.func @test_qlinearconv_weight_per_channel_quantization(%arg0: !torch.vtensor<[?,3,224,224],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[64,3,7,7],si8>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[64],si8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8: !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.mlfeaturizers = 1 : si64, com.microsoft.nchwc = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
    %0 = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
    %1 = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
    %2 = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
    %3 = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
    %none = torch.constant.none
    %false = torch.constant.bool false
    %int6 = torch.constant.int 6
    %4 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[?,3,224,224],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3,224,224],f32>
    %float1.000000e00 = torch.constant.float 1.000000e+00
    %5 = torch.aten.sub.Scalar %4, %0, %float1.000000e00 : !torch.vtensor<[?,3,224,224],f32>, !torch.int, !torch.float -> !torch.vtensor<[?,3,224,224],f32>
    %6 = torch.aten.mul.Scalar %5, %2 : !torch.vtensor<[?,3,224,224],f32>, !torch.float -> !torch.vtensor<[?,3,224,224],f32>
    %int1 = torch.constant.int 1
    %7 = torch.aten.unsqueeze %arg4, %int1 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64,1],f32>
    %8 = torch.aten.unsqueeze %arg5, %int1 : !torch.vtensor<[64],si8>, !torch.int -> !torch.vtensor<[64,1],si8>
    %int2 = torch.constant.int 2
    %9 = torch.aten.unsqueeze %7, %int2 : !torch.vtensor<[64,1],f32>, !torch.int -> !torch.vtensor<[64,1,1],f32>
    %10 = torch.aten.unsqueeze %8, %int2 : !torch.vtensor<[64,1],si8>, !torch.int -> !torch.vtensor<[64,1,1],si8>
    %int3 = torch.constant.int 3
    %11 = torch.aten.unsqueeze %9, %int3 : !torch.vtensor<[64,1,1],f32>, !torch.int -> !torch.vtensor<[64,1,1,1],f32>
    %12 = torch.aten.unsqueeze %10, %int3 : !torch.vtensor<[64,1,1],si8>, !torch.int -> !torch.vtensor<[64,1,1,1],si8>
    %13 = torch.aten.to.dtype %arg3, %int6, %false, %false, %none : !torch.vtensor<[64,3,7,7],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64,3,7,7],f32>
    %14 = torch.aten.sub.Tensor %13, %12, %float1.000000e00 : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],si8>, !torch.float -> !torch.vtensor<[64,3,7,7],f32>
    %15 = torch.aten.mul.Tensor %14, %11 : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],f32> -> !torch.vtensor<[64,3,7,7],f32>
    %16 = torch.aten.to.dtype %arg8, %int6, %false, %false, %none : !torch.vtensor<[64],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64],f32>
    %int3_0 = torch.constant.int 3
    %int3_1 = torch.constant.int 3
    %int3_2 = torch.constant.int 3
    %int3_3 = torch.constant.int 3
    %int0 = torch.constant.int 0
    %17 = torch.prim.ListConstruct %int3_0, %int3_1 : (!torch.int, !torch.int) -> !torch.list<int>
    %int1_4 = torch.constant.int 1
    %int1_5 = torch.constant.int 1
    %int2_6 = torch.constant.int 2
    %int2_7 = torch.constant.int 2
    %18 = torch.prim.ListConstruct %int1_4, %int1_5 : (!torch.int, !torch.int) -> !torch.list<int>
    %19 = torch.prim.ListConstruct %int2_6, %int2_7 : (!torch.int, !torch.int) -> !torch.list<int>
    %20 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %false_8 = torch.constant.bool false
    %int1_9 = torch.constant.int 1
    %21 = torch.aten.convolution %6, %15, %16, %19, %17, %18, %false_8, %20, %int1_9 : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,64,112,112],f32>
    %int13 = torch.constant.int 13
    %22 = torch.aten.quantize_per_tensor %21, %3, %1, %int13 : !torch.vtensor<[?,64,112,112],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.quint8>
    %23 = torch.aten.int_repr %22 : !torch.vtensor<[?,64,112,112],!torch.quint8> -> !torch.vtensor<[?,64,112,112],ui8>
    return %23 : !torch.vtensor<[?,64,112,112],ui8>
  }
}

This module cannot be lowered to StableHLO using --torch-backend-to-stablehlo-backend-pipeline:

<stdin>:34:11: error: failed to legalize operation 'torch.aten.quantize_per_tensor' that was explicitly marked illegal %22 = torch.aten.quantize_per_tensor %21, %3, %1, %int13 : !torch.vtensor<[?,64,112,112],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.quint8>

This is because ConvertAtenQuantizePerTensorOp requires both the scale and the zero point to be produced by constant operations:

auto *zeroPoint = op.getZeroPoint().getDefiningOp();
if (!zeroPoint || !isa<ConstantIntOp>(zeroPoint)) {
  return failure();
}

auto scale = op.getScale().getDefiningOp();
if (!scale || !isa<ConstantFloatOp>(scale)) {
  return failure();
}

In the MLIR above, the scale and zero point are coming from %3 and %1, which are results of torch.aten.item operations rather than constants, so the conversion pattern fails to match and legalization aborts.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions