
[ONNX] Use torch.aten.quantize_per_channel + torch.aten.dequantize.self for QLinearConv weights #4190

Open
@anuragsingh-tt

Description

The ONNX QLinearConv to Torch conversion dequantizes the inputs and weights manually, subtracting the zero point and multiplying by the scale:

// -----
module {
  func.func @test_qlinearconv_weight_per_channel_quantization(%arg0: !torch.vtensor<[?,3,224,224],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[64,3,7,7],si8>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[64],si8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8: !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.mlfeaturizers = 1 : si64, com.microsoft.nchwc = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
    %0 = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
    %1 = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
    %2 = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
    %3 = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
    %none = torch.constant.none
    %false = torch.constant.bool false
    %int6 = torch.constant.int 6
    %4 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[?,3,224,224],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3,224,224],f32>
    %float1.000000e00 = torch.constant.float 1.000000e+00
    %5 = torch.aten.sub.Scalar %4, %0, %float1.000000e00 : !torch.vtensor<[?,3,224,224],f32>, !torch.int, !torch.float -> !torch.vtensor<[?,3,224,224],f32>
    %6 = torch.aten.mul.Scalar %5, %2 : !torch.vtensor<[?,3,224,224],f32>, !torch.float -> !torch.vtensor<[?,3,224,224],f32>
    %int1 = torch.constant.int 1
    %7 = torch.aten.unsqueeze %arg4, %int1 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64,1],f32>
    %8 = torch.aten.unsqueeze %arg5, %int1 : !torch.vtensor<[64],si8>, !torch.int -> !torch.vtensor<[64,1],si8>
    %int2 = torch.constant.int 2
    %9 = torch.aten.unsqueeze %7, %int2 : !torch.vtensor<[64,1],f32>, !torch.int -> !torch.vtensor<[64,1,1],f32>
    %10 = torch.aten.unsqueeze %8, %int2 : !torch.vtensor<[64,1],si8>, !torch.int -> !torch.vtensor<[64,1,1],si8>
    %int3 = torch.constant.int 3
    %11 = torch.aten.unsqueeze %9, %int3 : !torch.vtensor<[64,1,1],f32>, !torch.int -> !torch.vtensor<[64,1,1,1],f32>
    %12 = torch.aten.unsqueeze %10, %int3 : !torch.vtensor<[64,1,1],si8>, !torch.int -> !torch.vtensor<[64,1,1,1],si8>
    %13 = torch.aten.to.dtype %arg3, %int6, %false, %false, %none : !torch.vtensor<[64,3,7,7],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64,3,7,7],f32>
    %14 = torch.aten.sub.Tensor %13, %12, %float1.000000e00 : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],si8>, !torch.float -> !torch.vtensor<[64,3,7,7],f32>
    %15 = torch.aten.mul.Tensor %14, %11 : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],f32> -> !torch.vtensor<[64,3,7,7],f32>
    %16 = torch.aten.to.dtype %arg8, %int6, %false, %false, %none : !torch.vtensor<[64],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64],f32>
    %int3_0 = torch.constant.int 3
    %int3_1 = torch.constant.int 3
    %int3_2 = torch.constant.int 3
    %int3_3 = torch.constant.int 3
    %int0 = torch.constant.int 0
    %17 = torch.prim.ListConstruct %int3_0, %int3_1 : (!torch.int, !torch.int) -> !torch.list<int>
    %int1_4 = torch.constant.int 1
    %int1_5 = torch.constant.int 1
    %int2_6 = torch.constant.int 2
    %int2_7 = torch.constant.int 2
    %18 = torch.prim.ListConstruct %int1_4, %int1_5 : (!torch.int, !torch.int) -> !torch.list<int>
    %19 = torch.prim.ListConstruct %int2_6, %int2_7 : (!torch.int, !torch.int) -> !torch.list<int>
    %20 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %false_8 = torch.constant.bool false
    %int1_9 = torch.constant.int 1
    %21 = torch.aten.convolution %6, %15, %16, %19, %17, %18, %false_8, %20, %int1_9 : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,64,112,112],f32>
    %int13 = torch.constant.int 13
    %22 = torch.aten.quantize_per_tensor %21, %3, %1, %int13 : !torch.vtensor<[?,64,112,112],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.quint8>
    %23 = torch.aten.int_repr %22 : !torch.vtensor<[?,64,112,112],!torch.quint8> -> !torch.vtensor<[?,64,112,112],ui8>
    return %23 : !torch.vtensor<[?,64,112,112],ui8>
  }
}

but an explicit QuantizeLinear -> DequantizeLinear -> Conv sequence in ONNX instead lowers to torch.aten.quantize_per_tensor -> torch.aten.int_repr -> torch.aten._make_per_tensor_quantized_tensor -> torch.aten.dequantize.self, which would be the more appropriate sequence here as well (with the per-channel variants, torch.aten.quantize_per_channel / torch.aten._make_per_channel_quantized_tensor, for the per-channel-quantized weights).
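
For the weight path specifically, a minimal sketch of what that lowering could look like (value names, the si64 widening of the zero points, and the exact dtype requirements of the ops are illustrative assumptions, not taken from the existing conversion):

    // Sketch only: wrap the already-quantized si8 weights in a per-channel
    // quantized tensor and let dequantize.self recover the f32 values,
    // replacing the manual sub/mul above. %arg3/%arg4/%arg5 are the weight
    // ints, scales, and zero points from the signature (%false/%none reused
    // from the function above); %int0 is the quantization axis (output
    // channels). Assumes the zero points are widened to si64 first; the
    // exact dtype constraints of the op may differ.
    %int0 = torch.constant.int 0
    %int4 = torch.constant.int 4
    %zps = torch.aten.to.dtype %arg5, %int4, %false, %false, %none : !torch.vtensor<[64],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64],si64>
    %w_q = torch.aten._make_per_channel_quantized_tensor %arg3, %arg4, %zps, %int0 : !torch.vtensor<[64,3,7,7],si8>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64,3,7,7],!torch.qint8>
    %w_fp = torch.aten.dequantize.self %w_q : !torch.vtensor<[64,3,7,7],!torch.qint8> -> !torch.vtensor<[64,3,7,7],f32>

This keeps the quantization parameters attached to the tensor type rather than baking them into arithmetic, so downstream passes that fuse quantized ops can still recognize the pattern.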
