Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modelopt/onnx/quantization/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def build_non_residual_input_map(

# Generally if both the inputs have a backbone then both backbones are of the same type
if backbone1 and backbone2:
if backbone1 == backbone2 or backbone1.op != backbone2.op:
if backbone1 == backbone2:
non_residual_inputs[node.name] = None
continue

Expand Down
182 changes: 182 additions & 0 deletions tests/_test_utils/onnx_quantization/lib_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,185 @@ def build_conv_concat_model():
onnx.checker.check_model(model_inferred)

return model_inferred


def build_convtranspose_conv_residual_model():
# Define your model inputs and outputs
input_names = ["input_0"]
output_names = ["output_0"]
input_shapes = [(2, 39, 96, 192)]
output_shapes = [(2, 32, 192, 384)]

inputs = [
helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
for input_name, input_shape in zip(input_names, input_shapes)
]
outputs = [
helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
for output_name, output_shape in zip(output_names, output_shapes)
]

# Create the ONNX graph with the nodes
nodes = [
helper.make_node(
op_type="ConvTranspose",
inputs=["input_0", "weights_1", "bias_1"],
outputs=["convtranspose1_convtranspose/ConvTranspose:0"],
name="convtranspose1_convtranspose/ConvTranspose",
dilations=[1, 1],
group=1,
kernel_shape=[2, 2],
pads=[0, 0, 0, 0],
strides=[2, 2],
),
helper.make_node(
op_type="Relu",
inputs=["convtranspose1_convtranspose/ConvTranspose:0"],
outputs=["relu1_relu/Relu:0"],
name="relu1_relu/Relu",
),
helper.make_node(
op_type="Conv",
inputs=["relu1_relu/Relu:0", "weights_2"],
outputs=["conv2_conv/Conv2D:0"],
name="conv2_conv/Conv2D",
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
helper.make_node(
op_type="BatchNormalization",
inputs=["conv2_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
outputs=["bn1_batchnorm/BatchNormalization:0"],
name="bn1_batchnorm/BatchNormalization",
),
helper.make_node(
op_type="Relu",
inputs=["bn1_batchnorm/BatchNormalization:0"],
outputs=["relu2_relu/Relu:0"],
name="relu2_relu/Relu",
),
helper.make_node(
op_type="Conv",
inputs=["relu2_relu/Relu:0", "weights_3"],
outputs=["conv3_conv/Conv2D:0"],
name="conv3_conv/Conv2D",
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
helper.make_node(
op_type="BatchNormalization",
inputs=["conv3_conv/Conv2D:0", "bn2_scale", "bn2_bias", "bn2_mean", "bn2_var"],
outputs=["bn2_batchnorm/BatchNormalization:0"],
name="bn2_batchnorm/BatchNormalization",
),
helper.make_node(
op_type="Add",
inputs=["relu1_relu/Relu:0", "bn2_batchnorm/BatchNormalization:0"],
outputs=["add1_add/Add:0"],
name="add1_add/Add",
),
helper.make_node(
op_type="Relu",
inputs=["add1_add/Add:0"],
outputs=["output_0"],
name="relu3_relu/Relu",
),
]

# Create the ONNX initializers
initializers = [
helper.make_tensor(
name="weights_1",
data_type=onnx.TensorProto.FLOAT,
dims=(39, 32, 2, 2),
vals=np.random.uniform(low=0.5, high=1.0, size=39 * 32 * 2 * 2),
),
helper.make_tensor(
name="bias_1",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
helper.make_tensor(
name="weights_2",
data_type=onnx.TensorProto.FLOAT,
dims=(32, 32, 3, 3),
vals=np.random.uniform(low=0.5, high=1.0, size=32 * 32 * 3 * 3),
),
helper.make_tensor(
name="bn1_scale",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
helper.make_tensor(
name="bn1_bias",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
helper.make_tensor(
name="bn1_mean",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
helper.make_tensor(
name="bn1_var",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
helper.make_tensor(
name="weights_3",
data_type=onnx.TensorProto.FLOAT,
dims=(32, 32, 3, 3),
vals=np.random.uniform(low=0.5, high=1.0, size=32 * 32 * 3 * 3),
),
helper.make_tensor(
name="bn2_scale",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
helper.make_tensor(
name="bn2_bias",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
helper.make_tensor(
name="bn2_mean",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
helper.make_tensor(
name="bn2_var",
data_type=onnx.TensorProto.FLOAT,
dims=(32,),
vals=np.random.uniform(low=0.5, high=1.0, size=32),
),
]

# Create the ONNX graph with the nodes and initializers
graph = helper.make_graph(
nodes, "convtranspose_conv_residual", inputs, outputs, initializer=initializers
)

# Create the ONNX model
model = helper.make_model(graph)
model.opset_import[0].version = 13
model.ir_version = 10

# Check the ONNX model
model_inferred = onnx.shape_inference.infer_shapes(model)
onnx.checker.check_model(model_inferred)

return model_inferred
38 changes: 36 additions & 2 deletions tests/unit/onnx/test_quantize_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@
import onnx_graphsurgeon as gs
import pytest
import torch
from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx
from _test_utils.onnx_quantization.lib_test_models import (
SimpleMLP,
build_convtranspose_conv_residual_model,
export_as_onnx,
)

import modelopt.onnx.quantization as moq
from modelopt.onnx.utils import save_onnx


def _assert_nodes_are_quantized(nodes):
Expand Down Expand Up @@ -52,6 +57,35 @@ def test_int8(tmp_path, high_precision_dtype):
# Load the output model and check QDQ node placements
graph = gs.import_onnx(onnx.load(output_onnx_path))

# Check that all MatMul nodes are quantized
# Check that all MatMul nodes are quantized
mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
assert _assert_nodes_are_quantized(mm_nodes)


def test_convtranspose_conv_residual_int8(tmp_path):
onnx_model = build_convtranspose_conv_residual_model()
onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
save_onnx(onnx_model, onnx_path)

moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

# Output model should be produced in the same tmp_path
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

# Check that quantized explicit model is generated
assert os.path.isfile(output_onnx_path)

# Load the output model and check QDQ node placements
graph = gs.import_onnx(onnx.load(output_onnx_path))

# Check that Conv and ConvTransposed are quantized
conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
assert _assert_nodes_are_quantized(conv_nodes)

# Check that only 1 input of Add is quantized
add_nodes = [n for n in graph.nodes if n.op == "Add"]
for node in add_nodes:
quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
assert len(quantized_inputs) == 1, (
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
)