diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 86f386abb827..130aea32f844 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -130,10 +130,10 @@ class TargetKind : public ObjectRef { */ TVM_DLL static Optional Get(const String& target_kind_name); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode); - - private: /*! \brief Mutable access to the container class */ TargetKindNode* operator->() { return static_cast(data_.get()); } + + private: TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer( const String& attr_name); friend class TargetKindRegEntry; diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 8dc54a19b998..ce45b2e75959 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -23,7 +23,7 @@ from tvm._ffi.base import TVMError from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op -from ....target.x86 import target_has_sse42 +from ....target.x86 import target_has_features from ....topi.utils import is_target from .. import op as reg @@ -457,8 +457,7 @@ def _shift(data, zero_point, out_dtype): def is_fast_int8_on_intel(): """Checks whether the hardware has support for fast Int8 arithmetic operations.""" - target = tvm.target.Target.current(allow_none=False) - return target_has_sse42(target.mcpu) + return target_has_features("sse4.2") # Helper function to align up given value. diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 968935f062b2..312fd18bf6b2 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -27,7 +27,7 @@ from tvm.runtime import Object from tvm.target import Target from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE -from tvm.target.x86 import target_has_sse41 +from tvm.target.x86 import target_has_features from . import _make, _requantize @@ -54,8 +54,9 @@ def _get_node_default_rounding(): @staticmethod def _get_node_default_compute_dtype(): target = Target.current(True) - if target and str(target.kind) == "llvm" and target_has_sse41(target.mcpu): - return "float32" + if target and str(target.kind) == "llvm": + if target_has_features("sse4.1", target): + return "float32" return "int64" diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index eacaf45fff44..5d43f4ae24ab 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -71,6 +71,38 @@ def llvm_get_intrinsic_name(intrin_id: int) -> str: return _ffi_api.llvm_get_intrinsic_name(intrin_id) +def llvm_x86_get_archlist(only64bit=False): + """Get X86 CPU name list. + + Parameters + ---------- + only64bit : bool + Filter 64bit architectures. + + Returns + ------- + features : list[str] + String list of X86 architectures. + """ + return _ffi_api.llvm_x86_get_archlist(only64bit) + + +def llvm_x86_get_features(cpu_name): + """Get X86 CPU features. + + Parameters + ---------- + cpu_name : string + X86 CPU name (e.g. "skylake"). + + Returns + ------- + features : list[str] + String list of X86 CPU features. + """ + return _ffi_api.llvm_x86_get_features(cpu_name) + + def llvm_version_major(allow_none=False): """Get the major LLVM version. diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py index b08f0be98c7f..a3dcb62e8aa7 100644 --- a/python/tvm/target/x86.py +++ b/python/tvm/target/x86.py @@ -16,127 +16,48 @@ # under the License. """Common x86 related utilities""" from .._ffi import register_func -from .target import Target - - -@register_func("tvm.target.x86.target_has_sse41") -def target_has_sse41(target): - return ( - target_has_sse42(target) - or target_has_avx(target) - or target_has_avx2(target) - or target_has_avx512(target) - or target_has_vnni(target) - or target - in { - "btver2", - "penryn", - } - ) - - -@register_func("tvm.target.x86.target_has_sse42") -def target_has_sse42(target): - return ( - target_has_avx(target) - or target_has_avx2(target) - or target_has_avx512(target) - or target_has_vnni(target) - or target - in { - "silvermont", - "slm", - "goldmont", - "goldmont-plus", - "tremont", - "nehalem", - "corei7", - "westmere", - "bdver1", - "bdver2", - "bdver3", - "x86-64-v2", - } - ) - - -@register_func("tvm.target.x86.target_has_avx") -def target_has_avx(target): - return ( - target_has_avx2(target) - or target_has_avx512(target) - or target_has_vnni(target) - or target in {"sandybridge", "corei7-avx", "ivybridge", "core-avx-i"} - ) - - -@register_func("tvm.target.x86.target_has_avx2") -def target_has_avx2(target): - return ( - target_has_avx512(target) - or target_has_vnni(target) - or target - in { - "haswell", - "core-avx2", - "broadwell", - "skylake", - "bdver4", - "znver1", - "znver2", - "znver3", - "x86-64-v3", - } - ) - - -@register_func("tvm.target.x86.target_has_avx512") -def target_has_avx512(target): - return target in { - "skylake-avx512", - "skx", - "knl", - "knm", - "x86-64-v4", - "cannonlake", - # explicit enumeration of VNNI capable due to collision with alderlake - "cascadelake", - "icelake-client", - "icelake-server", - "rocketlake", - "tigerlake", - "cooperlake", - "sapphirerapids", - } - - -@register_func("tvm.target.x86.target_has_vnni") -def target_has_vnni(target): - return target in { - "cascadelake", - "icelake-client", - "icelake-server", - "rocketlake", - "tigerlake", - "cooperlake", - "sapphirerapids", - "alderlake", - } - - -@register_func("tvm.target.x86.target_has_amx") -def target_has_amx(target): - return target in { - "sapphirerapids", - } +from . import _ffi_api +from ..ir.container import Array + + +@register_func("tvm.target.x86.target_has_features") +def target_has_features(features, target=None): + """Check X86 CPU features. + Parameters + ---------- + features : str or Array + Feature(s) to check. + target : Target + Optional TVM target, default `None` use the global context target. + Returns + ------- + has_feats : bool + True if feature(s) are in the target arch. + """ + has_feats = True + assert isinstance(features, (Array, str)) + features = [features] if isinstance(features, str) else features + for feat in features: + has_feats &= _ffi_api.llvm_x86_has_feature(feat, target) + return has_feats @register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") def get_simd_32bit_lanes(): - mcpu = Target.current().mcpu - fp32_vec_len = 4 - if target_has_avx512(mcpu): - fp32_vec_len = 16 - elif target_has_avx2(mcpu): - fp32_vec_len = 8 - return fp32_vec_len + """X86 SIMD optimal vector length lookup. + Parameters + ---------- + Returns + ------- + vec_len : int + The optimal vector length of CPU from the global context target. + """ + vec_len = 4 + # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) + # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) + # + llvm.x86.avx512.pmaddw.d.512" + if target_has_features(["avx512bw", "avx512f"]): + vec_len = 16 + elif target_has_features("avx2"): + vec_len = 8 + return vec_len diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 96b143326847..85accab87b2a 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -21,7 +21,7 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl -from tvm.target.x86 import target_has_amx, target_has_avx512 +from tvm.target.x86 import target_has_features from .. import generic, nn from ..transform import layout_transform @@ -38,8 +38,10 @@ def batch_matmul_int8_compute(cfg, x, y, *_): packed_y = layout_transform(y, "BNK", packed_y_layout) _, n_o, _, n_i, _ = packed_y.shape ak = te.reduce_axis((0, k), name="k") - mcpu = tvm.target.Target.current().mcpu - if target_has_avx512(mcpu): + # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) + # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) + # + llvm.x86.avx512.pmaddw.d.512" + if target_has_features(["avx512bw", "avx512f"]): attrs_info = {"schedule_rule": "batch_matmul_int8"} else: attrs_info = None @@ -233,14 +235,16 @@ def _callback(op): def schedule_batch_matmul_int8(cfg, outs): """Schedule for batch_matmul_int8""" s = te.create_schedule([x.op for x in outs]) - mcpu = tvm.target.Target.current().mcpu def _callback(op): if "batch_matmul_int8" in op.tag: layout_trans = op.input_tensors[1] - if target_has_amx(mcpu): + if target_has_features("amx-int8"): batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans) - elif target_has_avx512(mcpu): + # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) + # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) + # + llvm.x86.avx512.pmaddw.d.512" + elif target_has_features(["avx512bw", "avx512f"]): batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans) traverse_inline(s, outs[0].op, _callback) diff --git a/python/tvm/topi/x86/conv2d_int8.py b/python/tvm/topi/x86/conv2d_int8.py index 9d325343529d..7c01967e87d3 100644 --- a/python/tvm/topi/x86/conv2d_int8.py +++ b/python/tvm/topi/x86/conv2d_int8.py @@ -20,7 +20,7 @@ import tvm from tvm import autotvm, te -from tvm.target.x86 import target_has_sse42 +from tvm.target.x86 import target_has_features from .. import nn, tag from ..generic import conv2d as conv2d_generic @@ -49,7 +49,10 @@ def _get_default_config_int8( """ if is_depthwise: # Fallback to FP32 default config until a VNNI schedule is defined. - wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype) + wkl = _get_depthwise_conv2d_workload( + data, kernel, strides, padding, dilation, out_dtype, layout + ) + from .depthwise_conv2d import _fallback_schedule _fallback_schedule(cfg, wkl) @@ -81,8 +84,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype): is_llvm_support = llvm_version >= 8 # 3) Check target - mcpu = tvm.target.Target.current().mcpu - is_target_support = target_has_sse42(mcpu) + is_target_support = target_has_features("sse4.2") return is_dtype_support and is_llvm_support and is_target_support diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index a31127065eac..2437b1a69564 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -23,7 +23,7 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, dnnl, mkl -from tvm.target.x86 import get_simd_32bit_lanes, target_has_amx, target_has_avx512 +from tvm.target.x86 import get_simd_32bit_lanes, target_has_features from .. import generic, tag from ..utils import get_const_tuple, traverse_inline @@ -298,13 +298,15 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None): def schedule_dense_int8(cfg, outs): """Create a schedule for dense__int8""" s = te.create_schedule([x.op for x in outs]) - mcpu = tvm.target.Target.current().mcpu def _callback(op): if "dense_int8" in op.tag: - if target_has_amx(mcpu): + if target_has_features("amx-int8"): dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) - elif target_has_avx512(mcpu): + # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) + # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) + # + llvm.x86.avx512.pmaddw.d.512" + elif target_has_features(["avx512bw", "avx512f"]): dense_int8_schedule(cfg, s, op.output(0), outs[0]) traverse_inline(s, outs[0].op, _callback) @@ -316,8 +318,10 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): m, k = X.shape n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") - mcpu = tvm.target.Target.current().mcpu - if target_has_avx512(mcpu): + # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) + # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) + # + llvm.x86.avx512.pmaddw.d.512" + if target_has_features(["avx512bw", "avx512f"]): target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"} else: target_attr = None diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 973f94ecb9e5..ef6df7dd2c9b 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -19,7 +19,7 @@ import tvm from tvm import autotvm, relay, te -from tvm.target.x86 import target_has_amx, target_has_avx512 +from tvm.target.x86 import target_has_features from .. import nn from ..nn import dense_alter_layout @@ -28,9 +28,12 @@ def check_int8_applicable(x, y, allow_padding=False): - mcpu = tvm.target.Target.current().mcpu - # TODO(vvchernov): may be also target_has_avx2 or lower? - simd_avai = target_has_avx512(mcpu) or target_has_amx(mcpu) + # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) + # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) + # + llvm.x86.avx512.pmaddw.d.512" + simd_avai = target_has_features(["avx512bw", "avx512f"]) + simd_avai |= target_has_features("amx-int8") + # TODO(vvchernov): may be also target_has_features("avx2") or lower? return ( simd_avai and "int8" in x.dtype diff --git a/python/tvm/topi/x86/tensor_intrin.py b/python/tvm/topi/x86/tensor_intrin.py index 8368f755e97e..f2e84a62ecbd 100644 --- a/python/tvm/topi/x86/tensor_intrin.py +++ b/python/tvm/topi/x86/tensor_intrin.py @@ -19,15 +19,15 @@ import tvm from tvm import te import tvm.target.codegen -from tvm.target.x86 import target_has_sse42, target_has_vnni, get_simd_32bit_lanes +from tvm.target.x86 import target_has_features, get_simd_32bit_lanes def dot_16x1x16_uint8_int8_int32(): """Dispatch the most optimized intrin depending on the target""" - mcpu = tvm.target.Target.current().mcpu - - assert target_has_sse42(mcpu), "An old Intel machine that does not have fast Int8 support." - if target_has_vnni(mcpu): + assert target_has_features( + "sse4.2" + ), "An old Intel machine that does not have fast Int8 support." + if target_has_features("avx512vnni") or target_has_features("avxvnni"): # VNNI capable platform return dot_16x1x16_uint8_int8_int32_cascadelake() # vpmaddubsw/vpmaddwd fallback diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index c96554e6a2d6..4657f962f32c 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -24,18 +24,21 @@ namespace meta_schedule { String GetRuleKindFromTarget(const Target& target) { if (target->kind->name == "llvm") { - static const PackedFunc* f_check_vnni = - runtime::Registry::Get("tvm.target.x86.target_has_vnni"); - ICHECK(f_check_vnni != nullptr) << "The `target_has_vnni` func is not in tvm registry."; - if (target->GetAttr("mcpu") && - (*f_check_vnni)(target->GetAttr("mcpu").value())) { + static const PackedFunc* llvm_x86_has_feature_fn_ptr = + runtime::Registry::Get("target.llvm_x86_has_feature"); + ICHECK(llvm_x86_has_feature_fn_ptr != nullptr) + << "The `target.llvm_x86_has_feature` func is not in tvm registry."; + bool have_avx512vnni = (*llvm_x86_has_feature_fn_ptr)("avx512vnni", target); + bool have_avxvnni = (*llvm_x86_has_feature_fn_ptr)("avxvnni", target); + if (have_avx512vnni || have_avxvnni) { return "vnni"; } else { - static const PackedFunc* f_check_avx512 = - runtime::Registry::Get("tvm.target.x86.target_has_avx512"); - ICHECK(f_check_avx512 != nullptr) << "The `target_has_avx512` func is not in tvm registry."; - if (target->GetAttr("mcpu") && - (*f_check_avx512)(target->GetAttr("mcpu").value())) { + // avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) + // avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) + // + llvm.x86.avx512.pmaddw.d.512" + bool have_avx512f = (*llvm_x86_has_feature_fn_ptr)("avx512f", target); + bool have_avx512bw = (*llvm_x86_has_feature_fn_ptr)("avx512bw", target); + if (have_avx512bw && have_avx512f) { return "avx512"; } } diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 1dd1eae9a89b..b57710b26686 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -121,12 +121,9 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs, } bool has_current_target_sse41_support() { - auto target = Target::Current(true); - Optional mcpu = - target.defined() ? target->GetAttr("mcpu") : Optional(nullptr); - auto target_has_sse41_fn_ptr = tvm::runtime::Registry::Get("tvm.target.x86.target_has_sse41"); - ICHECK(target_has_sse41_fn_ptr) << "Function tvm.target.x86.target_has_sse41 not found"; - return mcpu && (*target_has_sse41_fn_ptr)(mcpu.value()); + auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature"); + ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found"; + return (*llvm_x86_has_feature_fn_ptr)("sse4.1", Target::Current(true)); } /* diff --git a/src/relay/qnn/op/requantize_config.h b/src/relay/qnn/op/requantize_config.h index f83789592bec..956bc3533b81 100644 --- a/src/relay/qnn/op/requantize_config.h +++ b/src/relay/qnn/op/requantize_config.h @@ -61,14 +61,13 @@ class RequantizeConfigNode : public Object { // For the x86 architecture, the float32 computation is expected to give significant speedup, // with little loss in the accuracy of the requantize operation. auto target = Target::Current(true); - auto target_has_sse41 = tvm::runtime::Registry::Get("tvm.target.x86.target_has_sse41"); - ICHECK(target_has_sse41) << "Function tvm.target.x86.target_has_sse41 not found"; - if (target.defined() && target->kind->name == "llvm" && - (target->GetAttr("mcpu") && - (*target_has_sse41)(target->GetAttr("mcpu").value()))) { - return "float32"; + auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature"); + ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found"; + if (target.defined() && target->kind->name == "llvm") { + if ((*llvm_x86_has_feature_fn_ptr)("sse4.1", target)) { + return "float32"; + } } - return "int64"; } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 85750fbf146e..168163c416cf 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -45,6 +45,16 @@ #include #include #include +#if TVM_LLVM_VERSION < 110 +#include +#include +#else +#if TVM_LLVM_VERSION < 170 +#include +#else +#include +#endif +#endif #include #include #include @@ -77,6 +87,25 @@ #include "codegen_llvm.h" #include "llvm_instance.h" +#if TVM_LLVM_VERSION < 110 +namespace llvm { +// SubtargetSubTypeKV view +template MCSubtargetInfo::*Member> +struct ArchViewer { + friend ArrayRef& archViewer(MCSubtargetInfo Obj) { return Obj.*Member; } +}; +template struct ArchViewer<&MCSubtargetInfo::ProcDesc>; +ArrayRef& archViewer(MCSubtargetInfo); +// SubtargetFeatureKV view +template MCSubtargetInfo::*Member> +struct FeatViewer { + friend ArrayRef& featViewer(MCSubtargetInfo Obj) { return Obj.*Member; } +}; +template struct FeatViewer<&MCSubtargetInfo::ProcFeatures>; +ArrayRef& featViewer(MCSubtargetInfo); +} // namespace llvm +#endif + namespace tvm { namespace codegen { @@ -485,6 +514,133 @@ TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t #endif }); +#if TVM_LLVM_VERSION < 110 +static const llvm::MCSubtargetInfo* llvm_compat_get_subtargetinfo(const std::string triple, + const std::string cpu_name) { + std::string error; + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + // create a LLVM x86 instance + auto* llvm_instance = llvm::TargetRegistry::lookupTarget(triple, error); + // create a target machine + llvm::TargetOptions target_options; + auto RM = llvm::Optional(); + auto* tm = llvm_instance->createTargetMachine(triple, cpu_name.c_str(), "", target_options, RM); + // create subtarget info module + const llvm::MCSubtargetInfo* MCInfo = tm->getMCSubtargetInfo(); + + return MCInfo; +} + +static const Array llvm_compat_get_archlist(const std::string triple) { + // get the subtarget info module + const auto* MCInfo = llvm_compat_get_subtargetinfo(triple, ""); + // get all X86 arches + llvm::ArrayRef x86_arches = + llvm::archViewer(*(llvm::MCSubtargetInfo*)MCInfo); + Array cpu_arches; + for (auto& arch : x86_arches) { + cpu_arches.push_back(arch.Key); + } + return cpu_arches; +} + +static const Array llvm_compat_get_features(const std::string triple, + const std::string cpu_name) { + // get the subtarget info module + const auto* MCInfo = llvm_compat_get_subtargetinfo(triple, cpu_name.c_str()); + // get all features + llvm::ArrayRef x86_features = + llvm::featViewer(*(llvm::MCSubtargetInfo*)MCInfo); + // only targeted CPU features + Array cpu_features; + for (auto& feat : x86_features) { + if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { + cpu_features.push_back(feat.Key); + } + } + return cpu_features; +} +#endif + +TVM_REGISTER_GLOBAL("target.llvm_x86_get_archlist") + .set_body_typed([](bool only64bit) -> Array { + Array cpu_arches; +#if TVM_LLVM_VERSION < 110 + cpu_arches = llvm_compat_get_archlist("x86_64--"); +#else + llvm::SmallVector x86_arches; + llvm::X86::fillValidCPUArchList(x86_arches, only64bit); + for (auto& arch : x86_arches) { + cpu_arches.push_back(arch.str()); + } +#endif + return cpu_arches; + }); + +TVM_REGISTER_GLOBAL("target.llvm_x86_get_features") + .set_body_typed([](std::string cpu_name) -> Array { + Array cpu_features; +#if TVM_LLVM_VERSION < 110 + cpu_features = llvm_compat_get_features("x86_64--", cpu_name); +#else + llvm::SmallVector x86_features; + llvm::X86::getFeaturesForCPU(cpu_name, x86_features); + for (auto& feat : x86_features) { + cpu_features.push_back(feat.str()); + } +#endif + return cpu_features; + }); + +TVM_REGISTER_GLOBAL("target.llvm_x86_has_feature") + .set_body_typed([](String feature, const Target& target) -> bool { + // target argument is optional (nullptr or None) + // if not explicit then use the current context target + Optional mcpu = target.defined() ? target->GetAttr("mcpu") + : Target::Current(false)->GetAttr("mcpu"); + Optional> mattr = target.defined() + ? target->GetAttr>("mattr") + : Target::Current(false)->GetAttr>("mattr"); + String name = target.defined() ? target->kind->name : Target::Current(false)->kind->name; + // lookup only for `llvm` targets having -mcpu + if ((name != "llvm") || !mcpu) { + return false; + } + // lookup in -mattr flags + bool is_in_mattr = + !mattr ? false + : std::any_of(mattr.value().begin(), mattr.value().end(), + [&](const String& var) { return var == ("+" + feature); }); +#if TVM_LLVM_VERSION < 110 + auto x86_arches = llvm_compat_get_archlist("x86_64--"); + // decline on invalid arch (avoid llvm assertion) + if (!std::any_of(x86_arches.begin(), x86_arches.end(), + [&](const String& var) { return var == mcpu.value(); })) { + return false; + } + // lookup in -mcpu llvm architecture flags + auto cpu_features = llvm_compat_get_features("x86_64--", mcpu.value()); + bool has_feature = std::any_of(cpu_features.begin(), cpu_features.end(), + [&](const String& var) { return var == feature; }); +#else + llvm::SmallVector x86_arches; + llvm::X86::fillValidCPUArchList(x86_arches, false); + // decline on invalid arch (avoid llvm assertion) + if (!std::any_of(x86_arches.begin(), x86_arches.end(), + [&](const llvm::StringRef& var) { return var == mcpu.value().c_str(); })) { + return false; + } + // lookup in -mcpu llvm architecture flags + llvm::SmallVector x86_features; + llvm::X86::getFeaturesForCPU(mcpu.value().c_str(), x86_features); + bool has_feature = + std::any_of(x86_features.begin(), x86_features.end(), + [&](const llvm::StringRef& var) { return var == feature.c_str(); }); +#endif + return has_feature || is_in_mattr; + }); + TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { return TVM_LLVM_VERSION / 10; }); diff --git a/tests/python/target/test_x86_features.py b/tests/python/target/test_x86_features.py new file mode 100644 index 000000000000..31a823b504eb --- /dev/null +++ b/tests/python/target/test_x86_features.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm.target import _ffi_api, codegen, Target +from tvm.target.x86 import target_has_features + +LLVM_VERSION = codegen.llvm_version_major() + +min_llvm_version, tvm_target, x86_feature, is_supported = tvm.testing.parameters( + # sse4.1 + (-1, "llvm -mcpu=btver2", "sse4a", True), + (-1, "llvm -mcpu=penryn", "sse4.1", True), + (-1, "llvm -mcpu=silvermont", "sse4.2", True), + (11, "llvm -mcpu=slm", "sse4.2", True), + (-1, "llvm -mcpu=goldmont", "sse4.2", True), + (-1, "llvm -mcpu=goldmont-plus", "sse4.2", True), + (-1, "llvm -mcpu=tremont", "sse4.2", True), + (-1, "llvm -mcpu=nehalem", "sse4.2", True), + (11, "llvm -mcpu=corei7", "sse4.2", True), + (-1, "llvm -mcpu=westmere", "sse4.2", True), + (-1, "llvm -mcpu=bdver1", "sse4.2", True), + (-1, "llvm -mcpu=bdver2", "sse4.2", True), + (-1, "llvm -mcpu=bdver3", "sse4.2", True), + (11, "llvm -mcpu=x86-64-v2", "sse4.2", True), + # avx + (-1, "llvm -mcpu=sandybridge", "avx", True), + (11, "llvm -mcpu=corei7-avx", "avx", True), + (-1, "llvm -mcpu=ivybridge", "avx", True), + (11, "llvm -mcpu=core-avx-i", "avx", True), + # avx2 + (-1, "llvm -mcpu=haswell", "avx2", True), + (11, "llvm -mcpu=core-avx2", "avx2", True), + (-1, "llvm -mcpu=broadwell", "avx2", True), + (-1, "llvm -mcpu=skylake", "avx2", True), + (-1, "llvm -mcpu=bdver4", "avx2", True), + (-1, "llvm -mcpu=znver1", "avx2", True), + (-1, "llvm -mcpu=znver2", "avx2", True), + (11, "llvm -mcpu=znver3", "avx2", True), + (11, "llvm -mcpu=x86-64-v3", "avx2", True), + # avx512bw + (-1, "llvm -mcpu=skylake-avx512", "avx512bw", True), + (11, "llvm -mcpu=skx", "avx512bw", True), + (11, "llvm -mcpu=knl", "avx512bw", False), + (-1, "llvm -mcpu=knl", "avx512f", True), + (11, "llvm -mcpu=knl", ["avx512bw", "avx512f"], False), + (11, "llvm -mcpu=knl", ("avx512bw", "avx512f"), False), + (-1, "llvm -mcpu=knl", "avx512cd", True), + (11, "llvm -mcpu=knl", ["avx512cd", "avx512f"], True), + (11, "llvm -mcpu=knl", ("avx512cd", "avx512f"), True), + (-1, "llvm -mcpu=knl", "avx512er", True), + (-1, "llvm -mcpu=knl", "avx512pf", True), + (11, "llvm -mcpu=knm", "avx512bw", False), + (-1, "llvm -mcpu=knm", "avx512f", True), + (-1, "llvm -mcpu=knm", "avx512cd", True), + (-1, "llvm -mcpu=knm", "avx512er", True), + (-1, "llvm -mcpu=knm", "avx512pf", True), + (11, "llvm -mcpu=x86-64-v4", "avx512bw", True), + (-1, "llvm -mcpu=cannonlake", "avx512bw", True), + # explicit enumeration of VNNI capable due to collision with alderlake + (11, "llvm -mcpu=alderlake", "avx512bw", False), + (-1, "llvm -mcpu=cascadelake", "avx512bw", True), + (-1, "llvm -mcpu=icelake-client", "avx512bw", True), + (-1, "llvm -mcpu=icelake-server", "avx512bw", True), + (11, "llvm -mcpu=rocketlake", "avx512bw", True), + (-1, "llvm -mcpu=tigerlake", "avx512bw", True), + (-1, "llvm -mcpu=cooperlake", "avx512bw", True), + (11, "llvm -mcpu=sapphirerapids", "avx512bw", True), + # avx512vnni + (11, "llvm -mcpu=alderlake", "avx512vnni", False), + (11, "llvm -mcpu=alderlake", "avxvnni", True), + (-1, "llvm -mcpu=cascadelake", "avx512vnni", True), + (-1, "llvm -mcpu=icelake-client", "avx512vnni", True), + (-1, "llvm -mcpu=icelake-server", "avx512vnni", True), + (11, "llvm -mcpu=rocketlake", "avx512vnni", True), + (-1, "llvm -mcpu=tigerlake", "avx512vnni", True), + (-1, "llvm -mcpu=cooperlake", "avx512vnni", True), + (11, "llvm -mcpu=sapphirerapids", "avx512vnni", True), + # amx-int8 + (11, "llvm -mcpu=sapphirerapids", "amx-int8", True), + # generic CPU (no features) but with extra -mattr + (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "avx2", True), + (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "sse4.1", True), + (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "ssse3", False), +) + + +def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_supported): + """Test X86 features support for different targets. + + Parameters + ---------- + min_llvm_version : int + Minimal LLVM version. + tvm_target : str + TVM target. + x86_feature : str + X86 CPU feature. + is_supported : bool + Expected result. + """ + + ## + ## no context + ## + + # check for feature via the python api (no explicit target, no context target) + try: + assert target_has_features(x86_feature) == is_supported + assert False + except tvm.error.InternalError as e: + msg = str(e) + assert ( + msg.find( + "InternalError: Check failed: (allow_not_defined) is false: Target context required" + ) + != -1 + ) + + if isinstance(x86_feature, str): + # check for feature via the ffi llvm api (no explicit target, no context target) + try: + assert _ffi_api.llvm_x86_has_feature(x86_feature, None) == is_supported + assert False + except tvm.error.InternalError as e: + msg = str(e) + assert ( + msg.find( + "InternalError: Check failed: (allow_not_defined) is false: Target context required" + ) + != -1 + ) + + # skip test on llvm_version + if LLVM_VERSION < min_llvm_version: + return + + # check for feature via the python api (with explicit target, no context target) + assert target_has_features(x86_feature, Target(tvm_target)) == is_supported + if isinstance(x86_feature, str): + # check for feature via the ffi llvm api (with explicit target, no context target) + assert _ffi_api.llvm_x86_has_feature(x86_feature, Target(tvm_target)) == is_supported + + ## + ## with context + ## + + with Target(tvm_target): + mcpu = Target.current(False).mcpu + # check for feature via the python api (current context target) + assert target_has_features(x86_feature) == is_supported + # check for feature via the python api (with explicit target) + assert target_has_features(x86_feature, Target(tvm_target)) == is_supported + if isinstance(x86_feature, str): + # check for feature via the ffi llvm api (current context target) + assert _ffi_api.llvm_x86_has_feature(x86_feature, None) == is_supported + # check for feature via the ffi llvm api (with explicit target) + assert _ffi_api.llvm_x86_has_feature(x86_feature, Target(tvm_target)) == is_supported + # check for feature in target's llvm full x86 CPU feature list + if not Target(tvm_target).mattr: + assert (x86_feature in codegen.llvm_x86_get_features(mcpu)) == is_supported