[TOPI][Relay] Add conv2d NHWC hybrid schedule for arm_cpu #16106

Merged: 6 commits, Nov 24, 2023
10 changes: 6 additions & 4 deletions include/tvm/relay/attrs/nn.h
@@ -197,12 +197,14 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode<ConvWinogradWeig

 /*! \brief Attributes used in gemm weight transformation operators */
 struct ConvGemmWeightTransformAttrs : public tvm::AttrsNode<ConvGemmWeightTransformAttrs> {
-  int tile_rows;
-  int tile_cols;
+  int tile_N;
+  int tile_K;

   TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") {
-    TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm.");
-    TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm.");
+    TVM_ATTR_FIELD(tile_N).describe(
+        "Tile size across N axis of the weight transformation for ConvGemm. (N = OC)");
+    TVM_ATTR_FIELD(tile_K).describe(
+        "Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)");
   }
 };

2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -798,7 +798,7 @@ def mirror_pad_func(attrs, inputs, _):
 @reg.register_compute("nn.contrib_conv2d_gemm_weight_transform")
 def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype):
     """Compute definition of contrib_conv2d_gemm_weight_transform"""
-    out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_rows, attrs.tile_cols)
+    out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_N, attrs.tile_K)
     return [out]


12 changes: 6 additions & 6 deletions python/tvm/relay/op/nn/nn.py
@@ -2741,7 +2741,7 @@ def contrib_conv2d_winograd_weight_transform(weight, tile_size):
     return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)


-def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
+def contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K):
     r"""Weight Transformation part for 2D convolution with gemm algorithm.

     We separate this as a single op to enable pre-compute for inference.
@@ -2751,17 +2751,17 @@ def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
     ----------
     weights : tvm.relay.Expr
         The weight expressions.
-    tile_rows: int
-        Tile rows of the weight transformation for ConvGemm.
-    tile_cols: int
-        Tile columns of the weight transformation for ConvGemm.
+    tile_N: int
+        Tile size across N axis of the weight transformation for ConvGemm. (N = OC)
+    tile_K: int
+        Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)

     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols)
+    return _make.contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K)


def contrib_conv3d_winograd_weight_transform(weight, tile_size):
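For context, the two tile sizes describe the GEMM view of the convolution: flattening an HWIO kernel gives a weight matrix B with K = KH * KW * IC rows and N = OC columns, and tile_N / tile_K tile those two axes. A minimal, hedged sketch of how a caller could obtain the tile sizes and invoke this op; the target string and shapes are made up for illustration and assume an AArch64-capable TVM build:

```python
import tvm
from tvm import relay
from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed

# Hypothetical fp32 HWIO kernel: (KH, KW, IC, OC) = (3, 3, 64, 128),
# so K = 3 * 3 * 64 = 576 and N = 128.
weights = relay.var("weights", shape=(3, 3, 64, 128), dtype="float32")

target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+neon")
with target:
    # Non-quantized path: A is not interleaved, so interleave_A=False.
    tile_N, tile_K = get_tiling_B_transformed(interleave_A=False, in_dtype="float32")

transformed = relay.nn.contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K)
```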
123 changes: 78 additions & 45 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -211,37 +211,50 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 )
             elif kernel_layout == "HWIO":
                 is_aarch64 = target.features.is_aarch64
-                has_asimd = target.features.has_asimd
                 has_dot_prod = target.features.has_dotprod
                 has_matmul_i8 = target.features.has_matmul_i8

-                if data.dtype in ["int8", "uint8"]:
-                    if has_matmul_i8:
-                        strategy.add_implementation(
-                            wrap_compute_conv2d(
-                                topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
-                            ),
-                            wrap_topi_schedule(
-                                topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
-                            ),
-                            name="conv2d_NHWC_quantized_interleaved.arm_cpu",
-                        )
-                    if has_dot_prod:
-                        strategy.add_implementation(
-                            wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
-                            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
-                            name="conv2d_NHWC_quantized_native.arm_cpu",
-                        )
-                    if is_aarch64 and has_asimd:
-                        strategy.add_implementation(
-                            wrap_compute_conv2d(
-                                topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
-                            ),
-                            wrap_topi_schedule(
-                                topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
-                            ),
-                            name="conv2d_NHWC_quantized_interleaved.arm_cpu",
-                        )
+                interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
+                interleaved_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
+                native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native
+                native_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_native
+                # Quantized cases
+                if is_aarch64 and data.dtype in ["int8", "uint8"]:
+                    if has_matmul_i8 and has_dot_prod:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(interleaved_compute),
+                            wrap_topi_schedule(interleaved_schedule),
+                            name="conv2d_NHWC_quantized_interleaved.arm_cpu",
+                        )
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(native_compute),
+                            wrap_topi_schedule(native_schedule),
+                            name="conv2d_NHWC_quantized_native.arm_cpu",
+                        )
+                    elif has_matmul_i8:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(interleaved_compute),
+                            wrap_topi_schedule(interleaved_schedule),
+                            name="conv2d_NHWC_quantized_interleaved.arm_cpu",
+                        )
+                    elif has_dot_prod:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(native_compute),
+                            wrap_topi_schedule(native_schedule),
+                            name="conv2d_NHWC_quantized_native.arm_cpu",
+                        )
+                    else:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(interleaved_compute),
+                            wrap_topi_schedule(interleaved_schedule),
+                            name="conv2d_NHWC_quantized_interleaved.arm_cpu",
+                        )
+                # Non-quantized cases
+                if is_aarch64 and data.dtype in ["float32", "float16"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid),
+                        name="conv2d_NHWC_hybrid.arm_cpu",
+                    )
                 if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
                     # TODO(@giuseros)
                     # This strategy errors out for quantized data types when tuning.
@@ -250,6 +263,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                         wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
                         name="conv2d_nhwc_spatial_pack.arm_cpu",
+                        plevel=5,
                     )
             else:
                 raise RuntimeError(f"Unsupported kernel layout {kernel_layout} for conv2d NHWC")
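The restructured HWIO branch therefore picks schedules purely from the target features: both GEMM variants when the i8mm (smmla/ummla) and dot-product extensions are present, only the interleaved variant with i8mm alone, only the native variant with dot-product alone, and the interleaved variant as the fallback, with the new hybrid schedule covering fp32/fp16 on any AArch64 target. As a rough, hedged illustration (not part of the patch; the target strings are hypothetical and feature detection depends on the TVM/LLVM build), the same flags the strategy reads can be inspected directly:

```python
import tvm

# Hypothetical AArch64 targets ordered from most to least capable.
targets = {
    "i8mm + dotprod": "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.6a,+dotprod,+i8mm",
    "dotprod only": "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
    "plain NEON": "llvm -mtriple=aarch64-linux-gnu -mattr=+neon",
}

for label, target_str in targets.items():
    features = tvm.target.Target(target_str).features
    if features.has_matmul_i8 and features.has_dotprod:
        int8_schedules = ["interleaved", "native"]
    elif features.has_matmul_i8:
        int8_schedules = ["interleaved"]
    elif features.has_dotprod:
        int8_schedules = ["native"]
    else:
        int8_schedules = ["interleaved"]
    # fp32/fp16 always get the hybrid GEMM schedule on AArch64.
    print(label, "-> int8:", int8_schedules, "| fp32/fp16: ['hybrid']")
```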
@@ -485,40 +499,59 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
     data = inputs[0]
     strategy = _op.OpStrategy()
     is_aarch64 = target.features.is_aarch64
-    has_asimd = target.features.has_asimd
     has_dot_prod = target.features.has_dotprod
     has_matmul_i8 = target.features.has_matmul_i8

     interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
+    interleaved_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
     native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
+    native_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
-    if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
-        if has_matmul_i8:
-            strategy.add_implementation(
-                wrap_compute_conv2d_gemm(interleaved_compute),
-                wrap_topi_schedule(
-                    topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
-                ),
-                name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
-            )
-        if has_dot_prod:
-            strategy.add_implementation(
-                wrap_compute_conv2d_gemm(native_compute),
-                wrap_topi_schedule(
-                    topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
-                ),
-                name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
-            )
-        if is_aarch64 and has_asimd:
-            strategy.add_implementation(
-                wrap_compute_conv2d_gemm(interleaved_compute),
-                wrap_topi_schedule(
-                    topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
-                ),
-                name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
-            )
+    if layout == "NHWC" and data.dtype in ["int8", "uint8", "float32", "float16"]:
+        # Non-AArch64 cases
+        if not is_aarch64:
+            raise RuntimeError("Unsupported non-AArch64 conv2d_NHWC_without_transform")
+        # AArch64 cases
+        if data.dtype in ["int8", "uint8"]:
+            # Quantized cases
+            if has_matmul_i8 and has_dot_prod:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(interleaved_compute),
+                    wrap_topi_schedule(interleaved_schedule),
+                    name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+                )
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(native_compute),
+                    wrap_topi_schedule(native_schedule),
+                    name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
+                )
+            elif has_matmul_i8:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(interleaved_compute),
+                    wrap_topi_schedule(interleaved_schedule),
+                    name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+                )
+            elif has_dot_prod:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(native_compute),
+                    wrap_topi_schedule(native_schedule),
+                    name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
+                )
+            else:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(interleaved_compute),
+                    wrap_topi_schedule(interleaved_schedule),
+                    name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+                )
+        elif data.dtype in ["float32", "float16"]:
+            # Non-quantized cases
+            strategy.add_implementation(
+                wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform),
+                wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
+                name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
+            )
     else:
         raise RuntimeError(
-            f"Unsupported conv2d_NHWC_quantized_without_transform layout {layout}"
+            f"Unsupported conv2d_NHWC_without_transform layout {layout}"
             f"with datatype {data.dtype}"
         )
     return strategy
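All of the schedules named above compute conv2d as an im2col GEMM: the weight-transform op pre-tiles B, while the main op builds A on the fly. A hedged NumPy reference of the underlying computation (no tiling, padding-to-tile-multiples, or interleaving; unit stride, no spatial padding, NHWC data and HWIO weights) may help make the M, K, and N axes concrete:

```python
import numpy as np

def conv2d_nhwc_as_gemm(data, weights):
    """Reference im2col GEMM: NHWC data, HWIO weights, stride 1, no padding."""
    batch, height, width, in_c = data.shape
    k_h, k_w, _, out_c = weights.shape
    out_h, out_w = height - k_h + 1, width - k_w + 1

    # A: im2col matrix of shape (batch * out_h * out_w, k_h * k_w * in_c)  ->  M x K
    rows = []
    for n in range(batch):
        for y in range(out_h):
            for x in range(out_w):
                rows.append(data[n, y : y + k_h, x : x + k_w, :].reshape(-1))
    matrix_a = np.stack(rows)

    # B: flattened weights of shape (k_h * k_w * in_c, out_c)  ->  K x N
    matrix_b = weights.reshape(k_h * k_w * in_c, out_c)

    # C = A @ B, reshaped back to NHWC
    return (matrix_a @ matrix_b).reshape(batch, out_h, out_w, out_c)

out = conv2d_nhwc_as_gemm(
    np.random.rand(1, 8, 8, 3).astype("float32"),
    np.random.rand(3, 3, 3, 16).astype("float32"),
)
print(out.shape)  # (1, 6, 6, 16)
```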
105 changes: 57 additions & 48 deletions python/tvm/topi/arm_cpu/arm_utils.py
@@ -20,9 +20,9 @@
 from tvm.target import Target


-def get_tiling_B_interleaved_t(interleave_A):
+def get_tiling_B_transformed(interleave_A, in_dtype):
     """Compute the tiling information for matrix B', where B'
-    is the transposed and interleaved version of matrix B in C=A*B.
+    is the tiled, interleaved (and transposed) version of matrix B in C=A*B.

     The tiling information is chosen to maximize register usage during the
     tile computation.
@@ -36,59 +36,68 @@ def get_tiling_B_transformed(interleave_A, in_dtype):

     Parameters
     ----------
-    interleave_A: bool
-        determines if A is expected to be interleaved
+    interleave_A : bool
+        determines if A is expected to be interleaved
+    in_dtype : str
+        input datatype


     Returns
     ----------
-    tile_rows_B: the output tile rows of B'
-    tile_cols_B: the output tile columns of B'
+    tile_N: the output tile size of B' on N axis (N = OC)
+    tile_K: the output tile size of B' on K axis (K = KW * KH * IC)
     """
     target = Target.current(allow_none=False)

-    if target.features.has_matmul_i8:
-        # If smmla/ummla is available, A must be interleaved.
-        # Each load from B' will contain 8 elements
-        # and we are loading 12 rows of B' (i.e., 12 columns of B)
-        tile_rows_B = 12
-        tile_cols_B = 8
-    elif target.features.has_dotprod:
-        # The number of tile rows of B' vary depending on the
-        # strategy:
-        # * If we are interleaving A, then we select 12 columns from B'(i.e.,
-        #   12 rows from B).
-        # * If we are not interleaving A, then we select 16 columns from B'(i.e.,
-        #   16 rows from B).
-        tile_rows_B = 12 if interleave_A else 16
-
-        # Dot product instruction groups 2 (u)int16x8 vectors in
-        # groups of 4 and compute the dot product among those groups
-        # This means that the number of columns in a tile of B' (i.e., the
-        # rows of the original matrix B) need to be 4.
-        tile_cols_B = 4
+    if in_dtype in ["int8", "uint8"]:
+        if target.features.has_matmul_i8:
+            # If smmla/ummla is available, A must be interleaved.
+            # Each load from B' will contain 8 elements
+            # and we are loading 12 rows of B' (i.e., 12 columns of B)
+            tile_N = 12
+            tile_K = 8
+        elif target.features.has_dotprod:
+            # The number of tile rows of B' vary depending on the
+            # strategy:
+            # * If we are interleaving A, then we select 12 columns from B'(i.e.,
+            #   12 rows from B).
+            # * If we are not interleaving A, then we select 16 columns from B'(i.e.,
+            #   16 rows from B).
+            tile_N = 12 if interleave_A else 16
+
+            # Dot product instruction groups 2 (u)int16x8 vectors in
+            # groups of 4 and compute the dot product among those groups
+            # This means that the number of columns in a tile of B' (i.e., the
+            # rows of the original matrix B) need to be 4.
+            tile_K = 4
+        else:
+            # If no acceleration is available, A must be interleaved. In this case
+            # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
+            tile_N = 4
+            tile_K = 16
     else:
-        # If no acceleration is available, A must be interleaved. In this case
-        # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
-        tile_rows_B = 4
-        tile_cols_B = 16
+        # In non-quantized cases, A is not interleaved.
+        # Each load from B' contains 16 elements (i.e. 16 columns from B)
+        # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
+        tile_N = 16
+        tile_K = 4

-    return tile_rows_B, tile_cols_B
+    return tile_N, tile_K
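Taken together, the branches above reduce to four (tile_N, tile_K) pairs. A hedged sketch that simply queries them; the target strings are hypothetical and assume the listed -mattr features are detected by the TVM/LLVM build in use:

```python
import tvm
from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed

cases = [
    # (-mattr features, in_dtype, interleave_A)
    ("+v8.6a,+dotprod,+i8mm", "int8", True),   # smmla/ummla path
    ("+v8.2a,+dotprod", "int8", False),        # dot-product path, non-interleaved A
    ("+neon", "int8", True),                   # no extension, interleaved fallback
    ("+neon", "float32", False),               # non-quantized hybrid path
]
for mattr, dtype, interleave_A in cases:
    with tvm.target.Target(f"llvm -mtriple=aarch64-linux-gnu -mattr={mattr}"):
        print(mattr, dtype, get_tiling_B_transformed(interleave_A, dtype))
# Expected, following the branches above: (12, 8), (16, 4), (4, 16), (16, 4)
```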


-def get_conv2d_weights_padding(N, K, tile_rows, tile_cols):
+def get_conv2d_weights_padding(N, K, tile_N, tile_K):
     """Compute the necessary padding for matrix B', where B'
-    is the transposed and interleaved version of matrix B in C=A*B.
+    is the transformed version of matrix B in C=A*B.

     Parameters
     ----------
     N : int
-        Number of rows in B' = OC
+        Number of columns in B = OC
     K : int
-        Number of columns in B' = KW * KH * IC
-    tile_rows : int
-        tile rows of B'
-    tile_cols : int
-        tile columns of B'
+        Number of rows in B = KW * KH * IC
+    tile_N : int
+        tile size of B' on N axis
+    tile_K : int
+        tile size of B' on K axis

     Returns
     ----------
@@ -98,16 +107,16 @@ def get_conv2d_weights_padding(N, K, tile_rows, tile_cols):
     pad_N = 0
     pad_K = 0

-    if N % tile_rows != 0:
-        pad_N = tile_rows - (N % tile_rows)
+    if N % tile_N != 0:
+        pad_N = tile_N - (N % tile_N)

-    # Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such
-    # that the columns is multiple of 4
-    column_multiplier = 4
-    tile_cols_multiplied = tile_cols * column_multiplier
-    K_misalignment = K % tile_cols_multiplied
+    # Tensorize will later make use of 4 tiles at once across the K axis so make sure we pad such
+    # that K is multiple of 4
+    K_multiplier = 4
+    tile_K_multiplied = tile_K * K_multiplier
+    K_misalignment = K % tile_K_multiplied

     if K_misalignment != 0:
-        pad_K = tile_cols_multiplied - K_misalignment
+        pad_K = tile_K_multiplied - K_misalignment

     return pad_N, pad_K
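A worked example of the padding rule with hypothetical sizes: a 3x3 kernel with IC = 3 and OC = 37 on the plain NEON int8 path (tile_N = 4, tile_K = 16) gives N = 37 and K = 27, so:

```python
from tvm.topi.arm_cpu.arm_utils import get_conv2d_weights_padding

N, K = 37, 27           # OC = 37, K = KH * KW * IC = 3 * 3 * 3
tile_N, tile_K = 4, 16  # plain NEON int8 tiling from get_tiling_B_transformed

pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_N, tile_K)
# pad_N = 4 - (37 % 4) = 3, so N is padded to 40.
# K must become a multiple of tile_K * 4 = 64, so pad_K = 64 - 27 = 37 and K is padded to 64.
print(pad_N, pad_K)  # 3 37
```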