Skip to content

Commit

Permalink
[Hexagon] F2qi avgpool bug fix (#15599)
Browse files Browse the repository at this point in the history
F2qi avgpool bug fix
  • Loading branch information
rasagna-quic committed Sep 15, 2023
1 parent 64ab31e commit 08a6ee5
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 34 deletions.
31 changes: 31 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,37 @@ def avgpool2d(expr, type_map):
t = type_map[arg]
out_t = type_map[expr]

# dq > nn.avg_pool2d > q
# Use the same input quantization parameters for output if the pattern is not the above.
# Type_map is a map of graphs and their Tensoraffinetypes
# Find the current "nn.avg_pool2d" op after checking for the "qnn.quantize" op in the graph.
# Structure for .. dq > op > q will be q [op [dq ..
def check(y, expr):
if isinstance(y, type(expr)):
if y.op.name != "nn.avg_pool2d":
return True
# check if this is the expr avg_pool
if y.attrs != expr.attrs:
return True
return False

for x in type_map.items():
if isinstance(x[0], type(expr)):
if x[0].op.name == "qnn.quantize":
prev = x[0]
y = prev.args[0]
while check(y, expr):
prev = y
y = prev.args[0]
if (
isinstance(y, type(expr))
and y.op.name == "nn.avg_pool2d"
and y.attrs == expr.attrs
):
if prev.op.name != "qnn.quantize":
out_t = t
break

out = relay.qnn.op.avg_pool2d(
arg,
t.scale,
Expand Down
22 changes: 10 additions & 12 deletions python/tvm/topi/hexagon/qnn/avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ def saturate(x: te.Tensor, dtype: str):
return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype)))


def get_temp_dtype(h, w, dtype):
temp_dtype = "int16" if h * w < 256 else "int32"
if dtype in ("uint8", "int8"):
return temp_dtype
else:
raise RuntimeError(f"Unsupported output dtype, {odtype}'")


def qnn_avg_pool2d_NCHW(
data: te.Tensor,
kernel: list,
Expand All @@ -59,12 +67,7 @@ def qnn_avg_pool2d_NCHW(
rh = te.reduce_axis((0, kh), name="rh")
rw = te.reduce_axis((0, kw), name="rw")

if odtype == "uint8":
temp_dtype = "uint16"
elif odtype == "int8":
temp_dtype = "int16"
else:
raise RuntimeError(f"Unsupported output dtype, {odtype}'")
temp_dtype = get_temp_dtype(kh, kw, odtype)

sh, sw = stride
dh, dw = dilation
Expand Down Expand Up @@ -155,12 +158,7 @@ def qnn_avg_pool2d_NHWC(
rh = te.reduce_axis((0, kh), name="rh")
rw = te.reduce_axis((0, kw), name="rw")

if odtype == "uint8":
temp_dtype = "uint16"
elif odtype == "int8":
temp_dtype = "int16"
else:
raise RuntimeError(f"Unsupported output dtype, {odtype}'")
temp_dtype = get_temp_dtype(kh, kw, odtype)

sh, sw = stride
dh, dw = dilation
Expand Down
22 changes: 22 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy
import tvm
from tvm import te
from tvm.relay.backend import Executor


def ceildiv(o, d):
Expand Down Expand Up @@ -112,6 +113,27 @@ def build_and_run(inputs, func, target: str, target_host: str, *args, **kwargs):
return tensors[-1].asnumpy()


def build_module(relay_mod, target):
"""builds a relay module for a specified target"""
params = {}
executor = Executor("aot", {"link-params": True})
lowered = tvm.relay.build(
relay_mod,
tvm.target.Target(target, host=target),
executor=executor,
params=params,
)
return lowered


def run_module(mod, inputs):
"""invokes run function of specified module with inputs provided"""
mod.set_input(**inputs)
mod.run()
output = mod.get_output(0).numpy()
return output


def get_conv2d_nhwc_shape(shape_nhwc, kernel_size, strides, padding, dilation, out_channels):
assert len(shape_nhwc) == 4
kernel = []
Expand Down
Loading

0 comments on commit 08a6ee5

Please sign in to comment.