[Target][TOPI] Use LLVM for x86 CPU feature lookup (#15685)
This PR leverages LLVM itself for CPU feature lookup, replacing the hard-coded lists.
Relying on LLVM keeps the X86 family and feature tables maintainable.

---
Changes:

* Introduce a single ```target_has_features(...)``` replacing all the per-feature ```target_has_XXX()``` helpers
* PY+FFI: expose new ```llvm_x86_get_archlist```, ```llvm_x86_get_features``` & ```llvm_x86_has_feature```
* PY: expose a new ```target_has_features``` wrapper around ```_ffi_api.llvm_x86_has_feature```

---

A unit test comprehensively checks the new lookup against the old behaviour.
For better reliability, the same approach to feature checking can be extended to other architectures.
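
As a quick orientation, a minimal sketch of the new check (the `-mcpu` value below is just an example, not something this PR prescribes):

```python
from tvm.target import Target
from tvm.target.x86 import target_has_features

# The check consults LLVM for the current context target,
# so no hard-coded CPU lists are involved anymore.
with Target("llvm -mcpu=cascadelake"):
    assert target_has_features("sse4.2")
    assert target_has_features(["avx512bw", "avx512f"])
```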
cbalint13 committed Sep 14, 2023
1 parent e2e1d44 commit 67df20f
Showing 15 changed files with 473 additions and 176 deletions.
4 changes: 2 additions & 2 deletions include/tvm/target/target_kind.h
@@ -130,10 +130,10 @@ class TargetKind : public ObjectRef {
*/
TVM_DLL static Optional<TargetKind> Get(const String& target_kind_name);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode);

private:
/*! \brief Mutable access to the container class */
TargetKindNode* operator->() { return static_cast<TargetKindNode*>(data_.get()); }

private:
TVM_DLL static const AttrRegistryMapContainerMap<TargetKind>& GetAttrMapContainer(
const String& attr_name);
friend class TargetKindRegEntry;
5 changes: 2 additions & 3 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -23,7 +23,7 @@
from tvm._ffi.base import TVMError
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op

from ....target.x86 import target_has_sse42
from ....target.x86 import target_has_features
from ....topi.utils import is_target
from .. import op as reg

@@ -457,8 +457,7 @@ def _shift(data, zero_point, out_dtype):

def is_fast_int8_on_intel():
"""Checks whether the hardware has support for fast Int8 arithmetic operations."""
target = tvm.target.Target.current(allow_none=False)
return target_has_sse42(target.mcpu)
return target_has_features("sse4.2")


# Helper function to align up given value.
7 changes: 4 additions & 3 deletions python/tvm/relay/qnn/op/qnn.py
@@ -27,7 +27,7 @@
from tvm.runtime import Object
from tvm.target import Target
from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE
from tvm.target.x86 import target_has_sse41
from tvm.target.x86 import target_has_features

from . import _make, _requantize

@@ -54,8 +54,9 @@ def _get_node_default_rounding():
@staticmethod
def _get_node_default_compute_dtype():
target = Target.current(True)
if target and str(target.kind) == "llvm" and target_has_sse41(target.mcpu):
return "float32"
if target and str(target.kind) == "llvm":
if target_has_features("sse4.1", target):
return "float32"

return "int64"

32 changes: 32 additions & 0 deletions python/tvm/target/codegen.py
@@ -71,6 +71,38 @@ def llvm_get_intrinsic_name(intrin_id: int) -> str:
return _ffi_api.llvm_get_intrinsic_name(intrin_id)


def llvm_x86_get_archlist(only64bit=False):
"""Get the list of X86 CPU names known to LLVM.
Parameters
----------
only64bit : bool
If True, return only 64-bit capable architectures.
Returns
-------
arch_list : list[str]
String list of X86 CPU (architecture) names.
"""
return _ffi_api.llvm_x86_get_archlist(only64bit)


def llvm_x86_get_features(cpu_name):
"""Get X86 CPU features.
Parameters
----------
cpu_name : string
X86 CPU name (e.g. "skylake").
Returns
-------
features : list[str]
String list of X86 CPU features.
"""
return _ffi_api.llvm_x86_get_features(cpu_name)


def llvm_version_major(allow_none=False):
"""Get the major LLVM version.
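
A short sketch combining the two new helpers to inspect what the linked LLVM knows (output depends on the LLVM build; `avx512vnni` is LLVM's spelling of the VNNI feature):

```python
from tvm.target.codegen import llvm_x86_get_archlist, llvm_x86_get_features

# Walk the 64-bit x86 CPU names known to the linked LLVM and
# report which of them carry the VNNI feature.
for cpu in llvm_x86_get_archlist(only64bit=True):
    if "avx512vnni" in llvm_x86_get_features(cpu):
        print(cpu)
```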
161 changes: 41 additions & 120 deletions python/tvm/target/x86.py
@@ -16,127 +16,48 @@
# under the License.
"""Common x86 related utilities"""
from .._ffi import register_func
from .target import Target


@register_func("tvm.target.x86.target_has_sse41")
def target_has_sse41(target):
return (
target_has_sse42(target)
or target_has_avx(target)
or target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"btver2",
"penryn",
}
)


@register_func("tvm.target.x86.target_has_sse42")
def target_has_sse42(target):
return (
target_has_avx(target)
or target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"silvermont",
"slm",
"goldmont",
"goldmont-plus",
"tremont",
"nehalem",
"corei7",
"westmere",
"bdver1",
"bdver2",
"bdver3",
"x86-64-v2",
}
)


@register_func("tvm.target.x86.target_has_avx")
def target_has_avx(target):
return (
target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target in {"sandybridge", "corei7-avx", "ivybridge", "core-avx-i"}
)


@register_func("tvm.target.x86.target_has_avx2")
def target_has_avx2(target):
return (
target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"haswell",
"core-avx2",
"broadwell",
"skylake",
"bdver4",
"znver1",
"znver2",
"znver3",
"x86-64-v3",
}
)


@register_func("tvm.target.x86.target_has_avx512")
def target_has_avx512(target):
return target in {
"skylake-avx512",
"skx",
"knl",
"knm",
"x86-64-v4",
"cannonlake",
# explicit enumeration of VNNI capable due to collision with alderlake
"cascadelake",
"icelake-client",
"icelake-server",
"rocketlake",
"tigerlake",
"cooperlake",
"sapphirerapids",
}


@register_func("tvm.target.x86.target_has_vnni")
def target_has_vnni(target):
return target in {
"cascadelake",
"icelake-client",
"icelake-server",
"rocketlake",
"tigerlake",
"cooperlake",
"sapphirerapids",
"alderlake",
}


@register_func("tvm.target.x86.target_has_amx")
def target_has_amx(target):
return target in {
"sapphirerapids",
}
from . import _ffi_api
from ..ir.container import Array


@register_func("tvm.target.x86.target_has_features")
def target_has_features(features, target=None):
"""Check X86 CPU features.
Parameters
----------
features : str or Array
Feature(s) to check.
target : Target
Optional TVM target, default `None` use the global context target.
Returns
-------
has_feats : bool
True if feature(s) are in the target arch.
"""
has_feats = True
assert isinstance(features, (Array, str))
features = [features] if isinstance(features, str) else features
for feat in features:
has_feats &= _ffi_api.llvm_x86_has_feature(feat, target)
return has_feats


@register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
def get_simd_32bit_lanes():
mcpu = Target.current().mcpu
fp32_vec_len = 4
if target_has_avx512(mcpu):
fp32_vec_len = 16
elif target_has_avx2(mcpu):
fp32_vec_len = 8
return fp32_vec_len
"""X86 SIMD optimal vector length lookup.
Parameters
----------
Returns
-------
vec_len : int
The optimal vector length of CPU from the global context target.
"""
vec_len = 4
# avx512f:  llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#         + llvm.x86.avx512.pmaddw.d.512
if target_has_features(["avx512bw", "avx512f"]):
vec_len = 16
elif target_has_features("avx2"):
vec_len = 8
return vec_len
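
The unit test mentioned above cross-checks this LLVM-backed lookup against the removed hard-coded sets; an illustrative reduction of that idea, with `OLD_AVX512_CPUS` copied from the deleted `target_has_avx512` list:

```python
from tvm.target.codegen import llvm_x86_get_features

# CPU set from the removed target_has_avx512() above.
OLD_AVX512_CPUS = {
    "skylake-avx512", "skx", "knl", "knm", "x86-64-v4", "cannonlake",
    "cascadelake", "icelake-client", "icelake-server", "rocketlake",
    "tigerlake", "cooperlake", "sapphirerapids",
}

# Every CPU the old list considered AVX-512 capable should carry
# avx512f according to LLVM as well.
for cpu in OLD_AVX512_CPUS:
    assert "avx512f" in llvm_x86_get_features(cpu)
```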
16 changes: 10 additions & 6 deletions python/tvm/topi/x86/batch_matmul.py
Expand Up @@ -21,7 +21,7 @@
from tvm import autotvm, te
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas, mkl
from tvm.target.x86 import target_has_amx, target_has_avx512
from tvm.target.x86 import target_has_features

from .. import generic, nn
from ..transform import layout_transform
@@ -38,8 +38,10 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
packed_y = layout_transform(y, "BNK", packed_y_layout)
_, n_o, _, n_i, _ = packed_y.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_avx512(mcpu):
# avx512f:  llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#         + llvm.x86.avx512.pmaddw.d.512
if target_has_features(["avx512bw", "avx512f"]):
attrs_info = {"schedule_rule": "batch_matmul_int8"}
else:
attrs_info = None
@@ -233,14 +235,16 @@ def _callback(op):
def schedule_batch_matmul_int8(cfg, outs):
"""Schedule for batch_matmul_int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

def _callback(op):
if "batch_matmul_int8" in op.tag:
layout_trans = op.input_tensors[1]
if target_has_amx(mcpu):
if target_has_features("amx-int8"):
batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
elif target_has_avx512(mcpu):
# avx512f:  llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#         + llvm.x86.avx512.pmaddw.d.512
elif target_has_features(["avx512bw", "avx512f"]):
batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans)

traverse_inline(s, outs[0].op, _callback)
10 changes: 6 additions & 4 deletions python/tvm/topi/x86/conv2d_int8.py
@@ -20,7 +20,7 @@

import tvm
from tvm import autotvm, te
from tvm.target.x86 import target_has_sse42
from tvm.target.x86 import target_has_features

from .. import nn, tag
from ..generic import conv2d as conv2d_generic
@@ -49,7 +49,10 @@ def _get_default_config_int8(
"""
if is_depthwise:
# Fallback to FP32 default config until a VNNI schedule is defined.
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
wkl = _get_depthwise_conv2d_workload(
data, kernel, strides, padding, dilation, out_dtype, layout
)

from .depthwise_conv2d import _fallback_schedule

_fallback_schedule(cfg, wkl)
@@ -81,8 +84,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
is_llvm_support = llvm_version >= 8

# 3) Check target
mcpu = tvm.target.Target.current().mcpu
is_target_support = target_has_sse42(mcpu)
is_target_support = target_has_features("sse4.2")

return is_dtype_support and is_llvm_support and is_target_support

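
Putting the three conditions together, a hedged sketch of gating an int8 schedule on `is_int8_hw_support` (the uint8-data/int8-kernel pairing is the fast-int8 convention assumed here, not shown in this hunk):

```python
from tvm.target import Target
from tvm.topi.x86.conv2d_int8 import is_int8_hw_support

with Target("llvm -mcpu=cascadelake"):
    # True only when dtypes, LLVM version, and target features all allow it.
    print(is_int8_hw_support("uint8", "int8"))
```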
16 changes: 10 additions & 6 deletions python/tvm/topi/x86/dense.py
@@ -23,7 +23,7 @@
from tvm import autotvm, te
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas, dnnl, mkl
from tvm.target.x86 import get_simd_32bit_lanes, target_has_amx, target_has_avx512
from tvm.target.x86 import get_simd_32bit_lanes, target_has_features

from .. import generic, tag
from ..utils import get_const_tuple, traverse_inline
@@ -298,13 +298,15 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
def schedule_dense_int8(cfg, outs):
"""Create a schedule for dense__int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

def _callback(op):
if "dense_int8" in op.tag:
if target_has_amx(mcpu):
if target_has_features("amx-int8"):
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
elif target_has_avx512(mcpu):
# avx512f:  llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#         + llvm.x86.avx512.pmaddw.d.512
elif target_has_features(["avx512bw", "avx512f"]):
dense_int8_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
@@ -316,8 +318,10 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
m, k = X.shape
n_o, _, n_i, _ = packed_w.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_avx512(mcpu):
# avx512f:  llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#         + llvm.x86.avx512.pmaddw.d.512
if target_has_features(["avx512bw", "avx512f"]):
target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"}
else:
target_attr = None
11 changes: 7 additions & 4 deletions python/tvm/topi/x86/dense_alter_op.py
@@ -19,7 +19,7 @@

import tvm
from tvm import autotvm, relay, te
from tvm.target.x86 import target_has_amx, target_has_avx512
from tvm.target.x86 import target_has_features

from .. import nn
from ..nn import dense_alter_layout
@@ -28,9 +28,12 @@


def check_int8_applicable(x, y, allow_padding=False):
mcpu = tvm.target.Target.current().mcpu
# TODO(vvchernov): may be also target_has_avx2 or lower?
simd_avai = target_has_avx512(mcpu) or target_has_amx(mcpu)
# avx512f:  llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#         + llvm.x86.avx512.pmaddw.d.512
simd_avai = target_has_features(["avx512bw", "avx512f"])
simd_avai |= target_has_features("amx-int8")
# TODO(vvchernov): may be also target_has_features("avx2") or lower?
return (
simd_avai
and "int8" in x.dtype