[ARM] Fix int8 NCHWc compute and alter layout (#10839)
This PR fixes a bug in the TE ARM int8 compute for NCHWc conv2d, introduced in #10310. The compute itself, not the schedule, is broken for the following reasons:

* We are using `n_elems = 8` in https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L350. Thus, the innermost axis of the transformed kernel has extent 8: https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L375
* In the TE compute, we iterate over the innermost axis `ic_s_inner` of the kernel at https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L577. `ic_s_inner` has extent `n_elems` according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L566. `n_elems` is 4 by default according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L478
* The ARM code that calls this compute does not explicitly pass `n_elems`, according to https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_int8.py#L106-L108
* Thus, even though the innermost axis of the kernel has extent 8, the TE compute loops over only `n_elems = 4` of those elements in the input channel dimension, so half of each innermost block never contributes to the output (see the sketch below).
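
A minimal numpy sketch of the mismatch (hypothetical shapes and names, not the actual TE compute): the packed kernel holds 8 int8 values per innermost block, but the reduction only iterates over the first `n_elems = 4`, so the remaining values are ignored.

```python
# Toy illustration (numpy only) of the extent mismatch described above.
# The innermost kernel axis produced by alter layout has 8 values per block,
# while the TE reduction loop covers only n_elems = 4 of them.
import numpy as np

packed_extent = 8  # innermost kernel axis extent chosen by alter layout
n_elems = 4        # default reduction extent in the TE compute

data = np.arange(1, packed_extent + 1, dtype=np.int32)    # one input-channel block
kernel = np.arange(1, packed_extent + 1, dtype=np.int32)  # matching weight block

broken = sum(data[i] * kernel[i] for i in range(n_elems))         # 30
correct = sum(data[i] * kernel[i] for i in range(packed_extent))  # 204

print(broken, correct)  # half of each block never reaches the accumulator
```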

Initially, I tried to keep `n_elems = 8` in alter layout and fix the intrinsic definition. But `n_elems = 8` breaks tensorization pattern matching, since the compute now does a 4x8 innermost loop while this intrinsic is supposed to do a 4x4 dot product; see https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L467-L479. Setting `num_int8_elements = 8` there did fix the tensorize pattern matching, but the result was still incorrect.

Rather than fixing the intrinsic implementation in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L492 to adapt it to a 4x8 dot product, I settled on setting `n_elems = 4` in alter layout. It turned out this change was enough to get the correct output. Moreover, `n_elems = 8` is simply wrong for the dot product path in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/conv2d_int8.py#L154-L155, which computes a 4x4 dot product in one instruction.
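
For reference, here is a hedged numpy model of the 4x4 dot product that this path assumes (4 int32 output lanes, each accumulating a dot product over 4 int8 values); the function name below is illustrative, not a TVM API:

```python
# Hedged sketch of the 4x4 dot-product semantics (illustration only,
# not the TIR intrinsic): each of the 4 int32 output lanes accumulates
# a dot product of 4 int8 data values with one row of 4 int8 weights.
import numpy as np

def dot_4x4_model(data_i8, weight_i8, acc_i32):
    """acc_i32[i] += dot(data_i8[0:4], weight_i8[i, 0:4]) for i in 0..3."""
    assert data_i8.shape == (4,) and weight_i8.shape == (4, 4)
    acc_i32 += weight_i8.astype(np.int32) @ data_i8.astype(np.int32)
    return acc_i32

data = np.array([1, -2, 3, -4], dtype=np.int8)
weight = np.arange(16, dtype=np.int8).reshape(4, 4)
acc = np.zeros(4, dtype=np.int32)
print(dot_4x4_model(data, weight, acc))  # [-8, -16, -24, -32]
```

With `n_elems = 8`, the packed kernel would hand this path blocks of 8 input-channel values, which does not match the 4-element dot product sketched above.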

@tkonolige I suggest running the perf benchmark again, since the numbers in #10310 are invalid.

cc @mbrookhart @Mousius  @junrushao1994 @vinx13
masahi committed Apr 1, 2022
1 parent 63bb3b9 commit 912993f
Showing 6 changed files with 26 additions and 23 deletions.
2 changes: 1 addition & 1 deletion python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -347,7 +347,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)

n_elems = 8
n_elems = 4

if cfg.is_fallback:
_get_default_config_int8(
9 changes: 6 additions & 3 deletions python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -57,7 +57,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn

oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape)
oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, _ = get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
else:
# data is nchw, implicitly treat it as nchw1c
@@ -103,8 +103,10 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
if len(data.shape) == 4:
data, kernel = _pack_data(cfg, data, kernel)

n_elems = int(kernel.shape[-1])

return nn.conv2d_NCHWc_int8(
data, kernel, strides, padding, dilation, layout, out_layout, out_dtype
data, kernel, strides, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems
)


@@ -149,7 +151,8 @@ def _callback(op):

args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
# int8 conv kernel is 7-dim
_, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape)
_, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape)
assert n_elems == 4
dtype = "uint" if data.dtype == "uint8" else "int"
if is_dotprod_available():
intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
21 changes: 11 additions & 10 deletions python/tvm/topi/arm_cpu/tensor_intrin.py
@@ -614,21 +614,22 @@ def _instr(index):
ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl)))
return ib.get()

def pairwise_add_mul(idx):
# this broadcasts data to the vector size
a_int8 = ins[0].vload([0], "int8x4")
re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
vec_ai32 = re_int32.astype("int32x2")
vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)
# this broadcasts data to the vector size
a_int8 = ins[0].vload([0], "int8x4")
re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
vec_ai32 = re_int32.astype("int32x2")
vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)

vec_b = ins[1].vload([idx * 2, 0], int_8xl) # we take two inputs at a time
vec_b = ins[1].vload([0, 0], "int8x16")

def pairwise_add_mul(extract_half):
vec_b_half = tvm.tir.call_intrin("int8x8", extract_half, vec_b)
multiply = tvm.tir.call_llvm_pure_intrin(
"int16x8",
"llvm.aarch64.neon.smull.v8i16", # saturating pairwise multiplication
tvm.tir.const(2, "uint32"),
vec_a,
vec_b,
vec_b_half,
)
pairwise_reduction = tvm.tir.call_llvm_pure_intrin(
"int32x4",
@@ -638,8 +639,8 @@ def pairwise_add_mul(idx):
)
return pairwise_reduction

pair_1 = pairwise_add_mul(0)
pair_2 = pairwise_add_mul(1)
pair_1 = pairwise_add_mul("tir.vectorlow")
pair_2 = pairwise_add_mul("tir.vectorhigh")
quad_reduction = tvm.tir.call_llvm_pure_intrin(
"int32x4",
"llvm.aarch64.neon.addp.v4i32",
1 change: 0 additions & 1 deletion python/tvm/topi/nn/conv2d.py
@@ -486,7 +486,6 @@ def conv2d_NCHWc_int8(
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(
kernel.shape
)
num_filter = oc_chunk * oc_bn
groups = ic_chunk // ic_chunk_group

dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
2 changes: 1 addition & 1 deletion python/tvm/topi/x86/conv2d_int8.py
@@ -120,7 +120,7 @@ def _pack_data(cfg, data, kernel):
kernel = te.compute(
(oc_chunk, ic_chunk, kh, kw, ic_bn // n_elems, oc_bn, n_elems),
lambda occ, icc, k_h, k_w, icbc, ocb, icbb: kernel[
occ * oc_bn + ocb, icc * ic_bn + icbc * ic_bn // n_elems + icbb, k_h, k_w
occ * oc_bn + ocb, icc * ic_bn + icbc * n_elems + icbb, k_h, k_w
],
name="kernel_vec",
)
14 changes: 7 additions & 7 deletions tests/python/topi/python/test_topi_conv2d_int8.py
@@ -21,7 +21,6 @@
import tvm
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
from tvm import topi
import tvm.topi.testing
from tvm.contrib.pickle_memoize import memoize
@@ -34,6 +33,7 @@
from common import Int8Fallback
import tvm.testing
import pytest
import platform


def compile_conv2d_NHWC_gemm_int8_arm(
@@ -299,7 +299,6 @@ def get_ref_data():

a_np, w_np, b_np, c_np = get_ref_data()

print("Running on target: %s" % target)
with tvm.target.Target(target):
C = compute(
A,
@@ -311,8 +310,6 @@ def get_ref_data():
"NCHW",
out_dtype,
)
print(C.shape)
print(bias.shape)
if add_bias:
C = topi.add(C, bias)
if add_relu:
@@ -342,6 +339,8 @@ def get_ref_data():
if build_only:
return

print("Running on target: %s" % target)

func(*run_args)

tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
@@ -364,14 +363,15 @@ def get_ref_data():
# ),
]

# TODO(tvm-team): Properly run ARM code on CI aarch64 environment
build_only_aarch64 = platform.machine() != "aarch64"

targets.append(
(
"llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
topi.arm_cpu.conv2d_NCHWc_int8,
topi.arm_cpu.schedule_conv2d_NCHWc_int8,
8,
True,
build_only_aarch64,
)
)

@@ -382,7 +382,7 @@ def get_ref_data():
topi.arm_cpu.conv2d_NCHWc_int8,
topi.arm_cpu.schedule_conv2d_NCHWc_int8,
8,
True,
build_only_aarch64,
)
)

