Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions include/matx/core/capabilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ namespace detail {
ALIASED_MEMORY, // Whether the operator's input and output pointers alias
GLOBAL_KERNEL, // Kernel operates entirely on a global level per chunk of data. False when at least one operator works on a block level
PASS_THROUGH_THREADS, // All threads must call operator() on nested operators; bounds checking done at tensor level
UNIT_STRIDE_LAST, // Whether all leaf tensors have stride[RANK-1] == 1
// Add more capabilities as needed
};

Expand All @@ -89,17 +90,18 @@ namespace detail {


#if !defined(__CUDACC_RTC__)
template <ElementsPerThread EPT, bool JIT>
template <ElementsPerThread EPT, bool JIT, bool UNIT_STRIDE_LAST = false>
struct CapabilityParams {
static constexpr ElementsPerThread ept = EPT;
static constexpr bool jit = JIT;
static constexpr bool unit_stride_last = UNIT_STRIDE_LAST;
static constexpr int osize = 0;
static constexpr int block_size = 0;

    // For JIT there will be other capabilities patched in with a string
};
};

using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false>;
using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false, false>;

// Concept to detect scoped enums
template<typename T>
Expand Down Expand Up @@ -256,6 +258,18 @@ namespace detail {
static constexpr bool default_value = false; // Default: operators do their own bounds checking
static constexpr bool or_identity = false;
static constexpr bool and_identity = true;
};

template <>
struct capability_attributes<OperatorCapability::UNIT_STRIDE_LAST> {
  using type = bool;
  using input_type = VoidCapabilityType;
  // This capability is combined with AND across the expression tree, so any
  // operator without backing memory (scalars, generators, and other
  // non-tensor ops) is trivially unit-stride and defaults to true. Tensor-like
  // types answer the query themselves inside their get_capability() method.
  static constexpr bool default_value = true;
  static constexpr bool and_identity = true;
  static constexpr bool or_identity = false;
};


Expand Down Expand Up @@ -324,6 +338,8 @@ namespace detail {
return CapabilityQueryType::AND_QUERY; // The expression should generate LTOIR code if all its children generate it.
case OperatorCapability::PASS_THROUGH_THREADS:
return CapabilityQueryType::OR_QUERY; // If ANY operator needs pass-through, all threads must call operator()
case OperatorCapability::UNIT_STRIDE_LAST:
return CapabilityQueryType::AND_QUERY; // All leaf tensors must have stride[RANK-1] == 1
default:
// Default to OR_QUERY or handle as an error/assertion if a capability isn't mapped.
return CapabilityQueryType::OR_QUERY;
Expand Down
2 changes: 2 additions & 0 deletions include/matx/core/nvrtc_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ std::string generate_capability_params_string([[maybe_unused]] const Op &op, Ele
"struct CapabilityParams {\n"
" static constexpr ElementsPerThread ept = EPT;\n"
" static constexpr bool jit = JIT;\n"
// Note: no unit_stride_last here. JIT bakes strides as constexpr
// values, so the compiler already eliminates multiply-by-1.
" static constexpr int osize = " + std::to_string(osize) + ";\n"
" static constexpr int block_size = " + std::to_string(block_size) + ";\n"
" static constexpr bool pass_through_threads = " + pass_through_str + ";\n"
Expand Down
139 changes: 75 additions & 64 deletions include/matx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,71 +154,71 @@ class tensor_impl_t {
" T *ldata_;\n" +
" constexpr static cuda::std::array<index_t, " + std::to_string(Rank()) + "> strides_ = { " + detail::array_to_string(desc_.Strides()) + " };\n" +
" constexpr static cuda::std::array<index_t, " + std::to_string(Rank()) + "> sizes_ = { " + detail::array_to_string(desc_.Shape()) + " };\n" +
" template <detail::ElementsPerThread EPT, int I = 0, typename ...Is>\n" +
" template <typename CapType, int I = 0, typename ...Is>\n" +
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetVal([[maybe_unused]] cuda::std::tuple<Is...> tup) {\n" +
" if constexpr (I < sizeof...(Is)) {\n" +
" if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" +
" return GetVal<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(EPT));\n" +
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" +
" return GetVal<CapType, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(CapType::ept));\n" +
" }\n" +
" else {\n" +
" return GetVal<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n" +
" return GetVal<CapType, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n" +
" }\n" +
" }\n" +
" else {\n" +
" return 0;\n" +
" }\n" +
" }\n" +
" template <detail::ElementsPerThread EPT, int I = 0, typename ...Is>\n" +
" template <typename CapType, int I = 0, typename ...Is>\n" +
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetValC([[maybe_unused]] const cuda::std::tuple<Is...> tup) const {\n" +
" if constexpr (I < sizeof...(Is)) {\n" +
" if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" +
" return GetValC<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(EPT));\n" +
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" +
" return GetValC<CapType, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(CapType::ept));\n" +
" }\n" +
" else {\n" +
" return GetValC<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n" +
" return GetValC<CapType, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n" +
" }\n" +
" }\n" +
" else {\n" +
" return 0;\n" +
" }\n" +
" }\n" +
" template <detail::ElementsPerThread EPT, typename... Is>\n" +
" }\n" +
" template <typename CapType, typename... Is>\n" +
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetOffsetOptimized(Is... indices) const {\n" +
" constexpr size_t rank = sizeof...(Is);\n" +
" constexpr int EPT_int = static_cast<int>(EPT);\n" +
" constexpr int EPT_int = static_cast<int>(CapType::ept);\n" +
" const cuda::std::array<index_t, rank> idx{indices...};\n" +
" \n" +
" if constexpr (rank == 1) {\n" +
" if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" +
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" +
" return idx[0] * (strides_[0] * EPT_int);\n" +
" } else {\n" +
" return idx[0] * strides_[0];\n" +
" }\n" +
" }\n" +
" else if constexpr (rank == 2) {\n" +
" if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" +
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" +
" return idx[0] * strides_[0] + idx[1] * (strides_[1] * EPT_int);\n" +
" } else {\n" +
" return idx[0] * strides_[0] + idx[1] * strides_[1];\n" +
" }\n" +
" }\n" +
" else if constexpr (rank == 3) {\n" +
" if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" +
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" +
" return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * (strides_[2] * EPT_int);\n" +
" } else {\n" +
" return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2];\n" +
" }\n" +
" }\n" +
" else if constexpr (rank == 4) {\n" +
" if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" +
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" +
" return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2] + idx[3] * (strides_[3] * EPT_int);\n" +
" } else {\n" +
" return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2] + idx[3] * strides_[3];\n" +
" }\n" +
" }\n" +
" else {\n" +
" // For rank > 4, fall back to the recursive implementation\n" +
" return GetValC<EPT, 0, Is...>(cuda::std::make_tuple(indices...));\n" +
" return GetValC<CapType, 0, Is...>(cuda::std::make_tuple(indices...));\n" +
" }\n" +
" }\n" +
" template <typename CapType, int I = 0, typename... Is>\n" +
Expand Down Expand Up @@ -246,7 +246,7 @@ class tensor_impl_t {
" return ReturnType{};\n" +
" }\n" +
" }\n" +
" const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
" const index_t offset = GetOffsetOptimized<CapType>(indices...);\n" +
" if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" +
" return ldata_[offset];\n" +
" } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {\n" +
Expand All @@ -272,7 +272,7 @@ class tensor_impl_t {
" return dummy_;\n" +
" }\n" +
" }\n" +
" const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
" const index_t offset = GetOffsetOptimized<CapType>(indices...);\n" +
" if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" +
" return ldata_[offset];\n" +
" } else {\n" +
Expand All @@ -296,7 +296,7 @@ class tensor_impl_t {
" template <typename CapType, int M = RANK, typename... Is>\n" +
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* data_ptr(index_t block_idx, index_t ttl_threads) const noexcept\n" +
" {\n"
" //const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
" //const index_t offset = GetOffsetOptimized<CapType>(indices...);\n" +
" //return ldata_ + offset;\n" +
" return ldata_ + block_idx * ttl_threads * static_cast<index_t>(CapType::ept);\n" +
" }\n" +
Expand Down Expand Up @@ -1107,7 +1107,7 @@ MATX_IGNORE_WARNING_POP_GCC
template <typename... Is>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* GetPointer(Is... indices) const noexcept
{
return data_.ldata_ + GetOffsetOptimized<detail::ElementsPerThread::ONE>(indices...);
return data_.ldata_ + GetOffsetOptimized<detail::DefaultCapabilities>(indices...);
}

// Locates position of an element at given indices, or returns -1 when not
Expand Down Expand Up @@ -1204,76 +1204,80 @@ MATX_IGNORE_WARNING_POP_GCC
return desc_.IsContiguous();
}

template <typename detail::ElementsPerThread EPT, int I = 0, typename ...Is>
template <typename CapType, int I = 0, typename ...Is>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetVal([[maybe_unused]] cuda::std::tuple<Is...> tup) {
if constexpr (I < sizeof...(Is)) {
MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {
return GetVal<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(this->desc_.Stride(I) * static_cast<index_t>(EPT));
}
else {
return GetVal<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(this->desc_.Stride(I));
}
return GetVal<CapType, I+1, Is...>(tup) + DimStride<I, static_cast<int>(sizeof...(Is)), CapType>(cuda::std::get<I>(tup));
MATX_IGNORE_WARNING_POP_GCC
}
else {
return 0;
}
}

// Optimized offset calculation for ranks 1-4 with explicit stride multiplications
template <detail::ElementsPerThread EPT, typename... Is>
// Stride contribution of one dimension to a linear offset.
// When CapType::unit_stride_last marks the innermost dimension as having
// stride 1, the stride load (ULDC) and multiply (IMAD) are elided at compile
// time; the ElementsPerThread factor is still applied to the last dimension.
template <int DIM, int RANK_VAL, typename CapType>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type
DimStride(index_t idx_val) const {
  constexpr bool last_dim = (DIM + 1 == RANK_VAL);
  // Vector-width multiplier only applies on the innermost dimension.
  constexpr index_t ept_mult =
      (last_dim && CapType::ept != detail::ElementsPerThread::ONE)
          ? static_cast<index_t>(CapType::ept)
          : index_t(1);

  if constexpr (last_dim && CapType::unit_stride_last) {
    // Stride is known to be 1 at dispatch time; skip the descriptor read.
    return idx_val * ept_mult;
  } else {
    return idx_val * (this->desc_.Stride(DIM) * ept_mult);
  }
}
Comment thread
tbensonatl marked this conversation as resolved.

// Optimized offset calculation for ranks 1-4 with explicit stride multiplications.
// When CapType::unit_stride_last is true, the stride load (ULDC) and
// multiply (IMAD) for the last dimension are elided.
template <typename CapType, typename... Is>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetOffsetOptimized(Is... indices) const {
MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
constexpr size_t rank = sizeof...(Is);
constexpr int EPT_int = static_cast<int>(EPT);
const cuda::std::array<index_t, rank> idx{indices...};


constexpr int R = static_cast<int>(rank);

if constexpr (rank == 1) {
if constexpr (EPT != detail::ElementsPerThread::ONE) {
return idx[0] * (this->desc_.Stride(0) * EPT_int);
} else {
return idx[0] * this->desc_.Stride(0);
}
return DimStride<0, R, CapType>(idx[0]);
}
else if constexpr (rank == 2) {
if constexpr (EPT != detail::ElementsPerThread::ONE) {
return idx[0] * this->desc_.Stride(0) + idx[1] * (this->desc_.Stride(1) * EPT_int);
} else {
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1);
}
return DimStride<0, R, CapType>(idx[0])
+ DimStride<1, R, CapType>(idx[1]);
}
else if constexpr (rank == 3) {
if constexpr (EPT != detail::ElementsPerThread::ONE) {
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * (this->desc_.Stride(2) * EPT_int);
} else {
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2);
}
return DimStride<0, R, CapType>(idx[0])
+ DimStride<1, R, CapType>(idx[1])
+ DimStride<2, R, CapType>(idx[2]);
}
else if constexpr (rank == 4) {
if constexpr (EPT != detail::ElementsPerThread::ONE) {
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2) + idx[3] * (this->desc_.Stride(3) * EPT_int);
} else {
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2) + idx[3] * this->desc_.Stride(3);
}
return DimStride<0, R, CapType>(idx[0])
+ DimStride<1, R, CapType>(idx[1])
+ DimStride<2, R, CapType>(idx[2])
+ DimStride<3, R, CapType>(idx[3]);
}
else {
// For rank > 4, fall back to the recursive implementation
return GetValC<EPT, 0, Is...>(cuda::std::make_tuple(indices...));
return GetValC<CapType, 0, Is...>(cuda::std::make_tuple(indices...));
}
MATX_IGNORE_WARNING_POP_GCC
MATX_IGNORE_WARNING_POP_GCC
}

template <detail::ElementsPerThread EPT, int I = 0, typename ...Is>
template <typename CapType, int I = 0, typename ...Is>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetValC([[maybe_unused]] const cuda::std::tuple<Is...> tup) const {
if constexpr (I < sizeof...(Is)) {
MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {
return GetValC<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(this->desc_.Stride(I) * static_cast<index_t>(EPT));
}
else {
return GetValC<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(this->desc_.Stride(I));
}
return GetValC<CapType, I+1, Is...>(tup) + DimStride<I, static_cast<int>(sizeof...(Is)), CapType>(cuda::std::get<I>(tup));
MATX_IGNORE_WARNING_POP_GCC
}
else {
Expand All @@ -1299,7 +1303,7 @@ MATX_IGNORE_WARNING_POP_GCC
assert(data_.ldata_ != nullptr);
#endif
constexpr int EPT_int = static_cast<int>(CapType::ept);
const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);
const index_t offset = GetOffsetOptimized<CapType>(indices...);

if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
return data_.ldata_[offset];
Expand Down Expand Up @@ -1329,7 +1333,7 @@ MATX_IGNORE_WARNING_POP_GCC
{
static_assert(sizeof...(Is) == M, "Number of indices of data_ptr must match rank of tensor");
if constexpr (!is_sparse_data_v<TensorData>) {
const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);
const index_t offset = GetOffsetOptimized<CapType>(indices...);
return data_.ldata_ + offset;
}
else {
Expand Down Expand Up @@ -1363,7 +1367,7 @@ MATX_IGNORE_WARNING_POP_GCC
assert(data_.ldata_ != nullptr);
#endif
constexpr int EPT_int = static_cast<int>(CapType::ept);
const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);
const index_t offset = GetOffsetOptimized<CapType>(indices...);

if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
return data_.ldata_[offset];
Expand Down Expand Up @@ -1563,6 +1567,13 @@ MATX_IGNORE_WARNING_POP_GCC
return overlaps;
}
}
else if constexpr (Cap == OperatorCapability::UNIT_STRIDE_LAST) {
if constexpr (Rank() == 0) {
return true;
} else {
return (Stride(Rank() - 1) == 1);
}
}
else {
return detail::capability_attributes<Cap>::default_value;
}
Comment thread
tbensonatl marked this conversation as resolved.
Expand Down
Loading