[Frontend][PaddlePaddle] PaddlePaddle model with NCHW data format that supports quantization #16651

Merged · 17 commits · Mar 7, 2024
83 changes: 74 additions & 9 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -31,6 +31,7 @@
from .. import function as _function
from .. import ty as _ty
from .. import op as _op
from .. import qnn as _qnn
from .common import (
autopad,
fold_constant,
@@ -314,9 +315,9 @@ def convert_conv2d(g, op, block):
strides = op.attr("strides")

kernel = g.get_node(op.input("Filter")[0])
input_x = g.get_node(op.input("Input")[0])
data_layout = op.attr("data_format")
kernel_layout = "OIHW" if data_layout == "NCHW" else "HWIO"
out_channels, _, k_h, k_w = infer_shape(kernel)
if padding_algorithm == "VALID":
paddings = [0, 0]
@@ -336,9 +337,15 @@ def convert_conv2d(g, op, block):
msg = f'Value {padding_algorithm} in attribute "padding" of operator Conv is not "valid."'
raise tvm.error.OpAttributeInvalid(msg)

is_quantized = op.has_attr("quantization_type")
# PaddlePaddle's weight layout is "OIHW"; TVM needs "HWIO" when the op's
# data_format is "NHWC". There are two situations when converting the weight layout:
# 1. conv2d is not a quantized op: its weight input is the weight tensor itself,
#    so we transpose the weight data directly while processing conv2d.
# 2. conv2d is a quantized op: its weight input is the output of the
#    quantize_linear operator, so the weight layout is transformed while
#    processing quantize_linear instead.
if (not is_quantized) and (data_layout == "NHWC"):
kernel_data = g.get_params(op.input("Filter")[0])
kernel_data = kernel_data.asnumpy()
kernel_data = kernel_data.transpose((2, 3, 1, 0))
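
For reference, the permutation maps the (O, I, H, W) axes onto (H, W, I, O). A minimal numpy sketch with a hypothetical weight tensor:

import numpy as np

# Hypothetical OIHW weight: 8 output channels, 3 input channels, 5x5 kernel.
weights_oihw = np.random.randn(8, 3, 5, 5).astype("float32")

# transpose((2, 3, 1, 0)) reorders (O, I, H, W) -> (H, W, I, O).
weights_hwio = weights_oihw.transpose((2, 3, 1, 0))
assert weights_hwio.shape == (5, 5, 3, 8)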
@@ -1626,7 +1633,7 @@ def convert_pool3d(g, op, block):
raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))

# handle the special case
# where the kernel size is larger than the input size:
# shrink the kernel size to the input size
if (
not isinstance(in_h, _op.Expr)
@@ -1812,6 +1819,59 @@ def convert_roi_align(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_dequantize_linear(g, op, block):
"""Operator converter for dequantize_linear."""

data_node_name = op.input("X")[0]
data_node = g.get_node(data_node_name)

# paddle_scale = tvm_scale * 127
paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
tvm_quantize_scale = paddle_quantize_scale / 127.0

tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()

tvm_quantize_axis = op.attr("quant_axis")
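# quant_axis == -1 marks per-tensor quantization in PaddlePaddle; the scale
# is a scalar in that case, so fall back to axis 0 for the QNN op below.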
if tvm_quantize_axis == -1:
tvm_quantize_axis = 0

if len(infer_shape(data_node)) < 2:
tvm_quantize_axis = 0

out = _qnn.op.dequantize(
data=data_node,
input_scale=_op.const(tvm_quantize_scale, "float32"),
input_zero_point=_op.const(tvm_quantize_zp, "int32"),
axis=tvm_quantize_axis,
)
g.add_node(op.output("Y")[0], out)
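
The /127 rescaling above can be sanity-checked numerically. A minimal sketch, assuming PaddleSlim records a symmetric int8 scale as tvm_scale * 127 (numbers are made up):

import numpy as np

paddle_scale = 63.5  # hypothetical per-tensor scale read from the model
zero_point = 0       # PaddleSlim int8 quantization is symmetric
q = np.array([-127, 0, 64, 127], dtype="int32")

# qnn.dequantize computes (q - zp) * input_scale, with input_scale = paddle_scale / 127
real = (q - zero_point) * (paddle_scale / 127.0)
print(real)  # [-63.5   0.   32.   63.5]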


def convert_quantize_linear(g, op, block):
"""Operator converter for dequantize_linear."""

data_node_name = op.input("X")[0]
data_node = g.get_node(data_node_name)

# paddle_scale = tvm_scale * 127
paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
tvm_quantize_scale = paddle_quantize_scale / 127.0

tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
tvm_quantize_axis = op.attr("quant_axis")

if tvm_quantize_axis == -1:
tvm_quantize_axis = 0

out = _qnn.op.quantize(
data=data_node,
output_scale=_op.const(tvm_quantize_scale, "float32"),
output_zero_point=_op.const(tvm_quantize_zp, "int32"),
axis=tvm_quantize_axis,
)
g.add_node(op.output("Y")[0], out)
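
In the other direction, quantize followed by dequantize should round-trip to within half a quantization step under the same /127 convention. A minimal sketch with made-up numbers (np.round rounds half to even; QNN's rounding mode may differ):

import numpy as np

paddle_scale = 63.5               # hypothetical scale as stored by PaddleSlim
tvm_scale = paddle_scale / 127.0  # 0.5, the scale handed to qnn.quantize
x = np.array([-63.5, 0.25, 32.0, 63.5], dtype="float32")

q = np.clip(np.round(x / tvm_scale), -127, 127).astype("int32")  # quantize
x_hat = q * tvm_scale                                            # dequantize
assert np.all(np.abs(x - x_hat) <= tvm_scale / 2)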


def convert_rnn(g, op, block):
"""Operator converter for rnn."""

@@ -2386,11 +2446,11 @@ def convert_slice(g, op, block):
def convert_softmax(g, op, block):
"""Operator converter for softmax."""

axis = op.attr("axis")
input_shape = block.var(op.input("X")[0]).shape
if axis < 0:
axis = len(input_shape) + axis
x = g.get_node(op.input("X")[0])
m = _op.max(x, axis, keepdims=True)
e = _op.exp(x - m)
out = e / _op.sum(e, axis, keepdims=True)
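
For context, this converter implements softmax in its numerically stable form: subtracting the per-axis max before exponentiating leaves the result unchanged but avoids overflow for large logits. A small numpy sketch of the same identity:

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])  # naive np.exp(x) would overflow
m = x.max(axis=-1, keepdims=True)
e = np.exp(x - m)
print(e / e.sum(axis=-1, keepdims=True))  # [0.09003057 0.24472847 0.66524096]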
@@ -2905,6 +2965,9 @@ def convert_where_index(g, op, block):
"unstack": convert_unstack,
"where": convert_where,
"where_index": convert_where_index,
# Quantized
"dequantize_linear": convert_dequantize_linear,
"quantize_linear": convert_quantize_linear,
}


@@ -2938,7 +3001,7 @@ def get_params(self, name=None):

if name is None:
return self.params
assert name in self.params, f"The name({name}) is not in params"
return self.params[name]

def extract_parameters(self, program, scope=None):
@@ -2947,9 +3010,12 @@ def extract_parameters(self, program, scope=None):
self.params = {}
variables = program.global_block().vars
for name in variables:
if name.endswith("feed") or name.endswith("fetch"):
continue
# This judgment causes PaddleInference models exported by PaddleSlim
# to skip some operators whose parameters need to be read in NHWC format.
var = program.global_block().var(name)
if not var.persistable:
continue
if isinstance(scope, dict):
@@ -3018,7 +3084,6 @@ def from_program(self, program, shape_dict, scope):
for op in block.ops:
if op.type == "fetch":
output_names.append(op.input("X")[0])

outputs = [self.nodes[name] for name in output_names]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
