Skip to content

Commit

Permalink
[SVE] Implement scalable vectors in TVM
Browse files Browse the repository at this point in the history
This prototype is to accompany the open SVE RFC. It implements the
design outlined in the RFC.

The main changes to the stack include:

1. tir.split can accept an expression with vscale as a factor

2. LoopVectorizer can create Ramp and Broadcast nodes with scalable
lanes

3. BufferLoad and BufferStore nodes can accept an optional predicate
which is created in LoopVectorizer

4. LLVM codegen can lower the scalable predicated vectors into
llvm.masked.* intrinsics

The prototype is currently missing tir.tile and TVMScript parser support
for predicates.

Co-authored-by: Luke Hutton <luke.hutton@arm.com>
Co-authored-by: Neil Hickey <neil.hickey@arm.com>
  • Loading branch information
3 people committed Jan 4, 2024
1 parent 97f6e65 commit c5933da
Show file tree
Hide file tree
Showing 75 changed files with 1,871 additions and 316 deletions.
47 changes: 41 additions & 6 deletions include/tvm/runtime/data_type.h
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>

#include <cstring>
#include <string>
#include <type_traits>

Expand Down Expand Up @@ -71,11 +72,16 @@ class DataType {
* \param code The type code.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
* \param scalable Whether or not the data type is scalable.
*/
DataType(int code, int bits, int lanes) {
DataType(int code, int bits, int lanes, bool scalable = false) {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
if (scalable) {
data_.lanes = static_cast<uint16_t>(-lanes);
} else {
data_.lanes = static_cast<uint16_t>(lanes);
}
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
Expand All @@ -90,7 +96,14 @@ class DataType {
/*! \return number of bytes to store each scalar. */
int bytes() const { return (bits() + 7) / 8; }
/*! \return number of lanes in the data. */
int lanes() const { return static_cast<int>(data_.lanes); }
int lanes() const {
  // Scalable types keep their lane count negated in data_.lanes; flip the sign back so
  // callers always receive the (non-negative) lane count.
  const int raw_lanes = static_cast<int16_t>(data_.lanes);
  return is_scalable() ? -raw_lanes : raw_lanes;
}
/*! \return whether type is a scalar type. */
bool is_scalar() const { return lanes() == 1; }
/*! \return whether type is a scalar type. */
Expand All @@ -114,17 +127,28 @@ class DataType {
/*! \return whether type is a handle type. */
bool is_handle() const { return code() == DataType::kHandle && !is_void(); }
/*! \return whether type is a vector type. */
bool is_vector() const { return lanes() > 1; }
bool is_vector() const {
  // Look at the signed encoding directly: anything other than 0 or 1 is a vector,
  // which includes the negative values used for scalable lane counts.
  const auto raw_lanes = static_cast<int16_t>(data_.lanes);
  return !(raw_lanes == 0 || raw_lanes == 1);
}
/*! \return whether type is a bool vector type. */
bool is_vector_bool() const { return is_vector() && bits() == 1; }
/*! \return whether type is a Void type. */
bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; }
/*! \return Whether the type is scalable, i.e. its lane field holds a negative value. */
bool is_scalable() const {
  const auto raw_lanes = static_cast<int16_t>(data_.lanes);
  return raw_lanes < 0;
}
/*!
* \brief Create a new data type by changing the lanes to a specified value.
* \param lanes The target number of lanes.
* \return the result type.
*/
DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); }
/*!
 * \brief Create a scalable data type that keeps this type's code and bits.
 * \param lanes The target number of lanes.
 * \return A copy of this DataType whose lane count is `lanes` and is marked scalable.
 */
DataType with_scalable_lanes(int lanes) const {
  // Let the constructor apply the scalable encoding rather than negating here.
  return DataType(data_.code, data_.bits, lanes, /*scalable=*/true);
}
/*!
* \brief Create a new data type by changing the bits to a specified value.
* \param bits The target number of bits.
Expand Down Expand Up @@ -247,6 +271,9 @@ class DataType {
* \return Number of bytes needed.
*/
inline int GetVectorBytes(DataType dtype) {
if (dtype.is_scalable()) {
LOG(FATAL) << "Cannot get vector bytes of scalable vector";
}
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
Expand Down Expand Up @@ -357,8 +384,12 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
}
if (t.code == kTVMOpaqueHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);

int16_t lanes = static_cast<int16_t>(t.lanes);
if (lanes > 1) {
os << 'x' << lanes;
} else if (lanes < 0) {
os << 'x' << -lanes << "xvscale";
}
return os;
}
Expand Down Expand Up @@ -424,6 +455,10 @@ inline DLDataType String2DLDataType(std::string s) {
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
if (strncmp(endpt, "xvscale", 7) == 0) {
t.lanes = -t.lanes;
endpt = endpt + 7;
}
ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/tir/builtin.h
Expand Up @@ -859,6 +859,15 @@ TVM_DLL const Op& start_profile_intrinsic();
*/
TVM_DLL const Op& end_profile_intrinsic();

/*!
 * \brief Intrinsic op for `vscale`, the multiplier that appears in scalable lane counts
 *        (data types with scalable lanes are printed as `<lanes>xvscale`).
 *
 * NOTE(review): presumably lowered to LLVM's `llvm.vscale` intrinsic by the LLVM
 * codegen -- confirm.
 */
TVM_DLL const Op& vscale();

/*!
 * \brief Build the predicate (lane mask) describing which vector lanes are currently active.
 *
 * The mask is calculated from a current value and a bound.
 *
 * NOTE(review): presumably lane i is active when (current value + i) is below the bound,
 * matching LLVM's `llvm.get.active.lane.mask` -- confirm against the codegen.
 */
TVM_DLL const Op& get_active_lane_mask();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
17 changes: 11 additions & 6 deletions include/tvm/tir/expr.h
Expand Up @@ -630,23 +630,27 @@ class BufferLoadNode : public PrimExprNode {
Buffer buffer;
/*! \brief The indices location to be loaded. */
Array<PrimExpr> indices;
/*! \brief The buffer predicate */
PrimExpr predicate;

/*!
 * \brief Expose the fields of this node to an attribute visitor.
 * \param v The visitor that receives each field under its string key.
 */
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
// The predicate is visited like any other field, under the key "predicate".
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
equal(indices, other->indices);
equal(indices, other->indices) && equal(predicate, other->predicate);
}

/*!
 * \brief Feed the fields that define this load into the structural hash.
 * \param hash_reduce The reducer that accumulates the hash.
 */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(indices);
// The predicate is hashed as well, matching SEqualReduce, which also compares it.
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferLoad";
Expand Down Expand Up @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode {
*/
class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices,
PrimExpr predicate = PrimExpr(nullptr), Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};
Expand Down Expand Up @@ -746,7 +751,7 @@ class RampNode : public PrimExprNode {
/*! \brief The stride of each step. */
PrimExpr stride;
/*! \brief Total number of lanes. */
int lanes;
PrimExpr lanes;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
Expand Down Expand Up @@ -778,7 +783,7 @@ class RampNode : public PrimExprNode {
*/
class Ramp : public PrimExpr {
public:
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
};
Expand All @@ -789,7 +794,7 @@ class BroadcastNode : public PrimExprNode {
/*! \brief The base value. */
PrimExpr value;
/*! \brief The number of lanes. */
int lanes;
PrimExpr lanes;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
Expand Down Expand Up @@ -818,7 +823,7 @@ class BroadcastNode : public PrimExprNode {
*/
class Broadcast : public PrimExpr {
public:
TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
};
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Expand Up @@ -350,10 +350,15 @@ class ScheduleNode : public runtime::Object {
* \param factors The positive tiling factors. At most one of them may be `NullOpt`, in which
* case that factor is inferred.
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \param disable_predication If enabled, don't create a predicate for guarding the
* loop. This can be useful when splitting with scalable factors that the schedule writer
* knows are divisible. Warning: enabling this feature may result in incorrect code generation
* if not used carefully.
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
bool preserve_unit_iters = true) = 0;
bool preserve_unit_iters = true,
bool disable_predication = false) = 0;
/*!
* \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
* It requires:
Expand Down
8 changes: 6 additions & 2 deletions include/tvm/tir/stmt.h
Expand Up @@ -231,23 +231,27 @@ class BufferStoreNode : public StmtNode {
PrimExpr value;
/*! \brief The indices location to be stored. */
Array<PrimExpr> indices;
/*! \brief The predicate for this store. */
PrimExpr predicate;

/*!
 * \brief Expose the fields of this node to an attribute visitor.
 * \param v The visitor that receives each field under its string key.
 */
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("value", &value);
v->Visit("indices", &indices);
// The predicate is visited like any other field, under the key "predicate".
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(value, other->value) &&
equal(indices, other->indices);
equal(indices, other->indices) && equal(predicate, other->predicate);
}

/*!
 * \brief Feed the fields that define this store into the structural hash.
 * \param hash_reduce The reducer that accumulates the hash.
 */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(value);
hash_reduce(indices);
// The predicate is hashed as well, matching SEqualReduce, which also compares it.
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferStore";
Expand All @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode {
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span = Span());
PrimExpr predicate = PrimExpr(nullptr), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/_ffi/runtime_ctypes.py
Expand Up @@ -135,7 +135,10 @@ def __init__(self, type_str):

arr = type_str.split("x")
head = arr[0]
self.lanes = int(arr[1]) if len(arr) > 1 else 1
if len(arr) == 3 and arr[2] == "vscale":
self.lanes = ctypes.c_uint16(-int(arr[1]))
elif len(arr) > 1:
self.lanes = ctypes.c_uint16(int(arr[1]))
bits = 32

if head.startswith("int"):
Expand Down Expand Up @@ -188,8 +191,11 @@ def __repr__(self):

type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
x = "%s%d" % (type_name, self.bits)
if self.lanes != 1:
lanes_as_int = ctypes.c_int16(self.lanes).value
if lanes_as_int > 1:
x += "x%d" % self.lanes
elif lanes_as_int < 0:
x += "x%dxvscale" % -lanes_as_int
return x

def __eq__(self, other):
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Expand Up @@ -1289,7 +1289,7 @@ def buffer_store(
if lanes == 1:
expr_indices.append(index.start)
else:
expr_indices.append(ramp(index.start, step, int(lanes)))
expr_indices.append(ramp(index.start, step, lanes))
else:
expr_indices.append(index)
if isinstance(value, bool) and buffer.dtype == "bool":
Expand Down Expand Up @@ -1853,6 +1853,7 @@ def wrapped(*args, **kwargs):
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
vscale = _op_wrapper(_tir_op.vscale)
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
Expand Down Expand Up @@ -1886,7 +1887,6 @@ def wrapped(*args, **kwargs):
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)


broadcast = Broadcast
ramp = Ramp
fabs = abs
Expand Down Expand Up @@ -2189,4 +2189,5 @@ def wrapped(*args, **kwargs):
"IterVar",
"CommReducer",
"Range",
"vscale",
]
17 changes: 12 additions & 5 deletions python/tvm/testing/utils.py
Expand Up @@ -1003,7 +1003,7 @@ def _corstone300_compile_time_check():


# check cpu features
def _has_cpu_feat(features):
def has_cpu_feat(features):
cpu = codegen.llvm_get_system_cpu()
triple = codegen.llvm_get_system_triple()
target = "llvm -mtriple=%s -mcpu=%s" % (triple, cpu)
Expand All @@ -1015,21 +1015,28 @@ def _has_cpu_feat(features):
requires_arm_dot = Feature(
"arm_dot",
"ARM dot product",
run_time_check=lambda: _has_cpu_feat("dotprod"),
run_time_check=lambda: has_cpu_feat("dotprod"),
)


# Test feature marker for AArch64 SVE. The run-time check asks has_cpu_feat
# whether the "sve" feature is reported.
requires_aarch64_sve = Feature(
"arm_sve",
"AArch64 SVE",
run_time_check=lambda: has_cpu_feat("sve"),
)


requires_x86_vnni = Feature(
"x86_vnni",
"x86 VNNI Extensions",
run_time_check=lambda: (_has_cpu_feat("avx512vnni") or _has_cpu_feat("avxvnni")),
run_time_check=lambda: (has_cpu_feat("avx512vnni") or has_cpu_feat("avxvnni")),
)


requires_x86_avx512 = Feature(
"x86_avx512",
"x86 AVX512 Extensions",
run_time_check=lambda: _has_cpu_feat(
run_time_check=lambda: has_cpu_feat(
["avx512bw", "avx512cd", "avx512dq", "avx512vl", "avx512f"]
),
)
Expand All @@ -1038,7 +1045,7 @@ def _has_cpu_feat(features):
requires_x86_amx = Feature(
"x86_amx",
"x86 AMX Extensions",
run_time_check=lambda: _has_cpu_feat("amx-int8"),
run_time_check=lambda: has_cpu_feat("amx-int8"),
)


Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Expand Up @@ -88,6 +88,7 @@
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import vscale
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
11 changes: 9 additions & 2 deletions python/tvm/tir/expr.py
Expand Up @@ -1095,6 +1095,9 @@ class BufferLoad(PrimExprWithOp):
indices : List[PrimExpr]
The buffer indices.
predicate : Optional[PrimExpr]
The buffer predicate
span : Optional[Span]
The location of this expression in the source code.
"""
Expand All @@ -1103,10 +1106,14 @@ class BufferLoad(PrimExprWithOp):
indices: List[PrimExpr]

def __init__(
self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None
self,
buffer: Buffer,
indices: List[PrimExpr],
predicate: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferLoad, buffer, indices, span # type: ignore
_ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore
)


Expand Down

0 comments on commit c5933da

Please sign in to comment.