From 4f4d0be11519e4717be9d5bdb0e9cf0ed4323938 Mon Sep 17 00:00:00 2001 From: Thomas Benson Date: Mon, 13 Apr 2026 17:37:19 +0000 Subject: [PATCH 1/3] Add a unit_stride_last capability Add a capability to indicate that a tensor has unit stride in the last dimension. When true, we can elide loading and multiplying by the last stride. The last dimension being unit-stride is the nominal case since that is what is created via make_tensor() without user-provided strides. This capability approach applies to the matxOpT*Kernel dispatch. It will not apply to custom kernels in MatX. Signed-off-by: Thomas Benson --- include/matx/core/capabilities.h | 22 ++++- include/matx/core/nvrtc_helper.h | 2 + include/matx/core/tensor_impl.h | 139 +++++++++++++++++-------------- include/matx/executors/cuda.h | 27 ++++-- 4 files changed, 114 insertions(+), 76 deletions(-) diff --git a/include/matx/core/capabilities.h b/include/matx/core/capabilities.h index 31fce76c8..43968467e 100644 --- a/include/matx/core/capabilities.h +++ b/include/matx/core/capabilities.h @@ -72,6 +72,7 @@ namespace detail { ALIASED_MEMORY, // Whether the operator's input and output pointers alias GLOBAL_KERNEL, // Kernel operates entirely on a global level per chunk of data. False when at least one operator works on a block level PASS_THROUGH_THREADS, // All threads must call operator() on nested operators; bounds checking done at tensor level + UNIT_STRIDE_LAST, // Whether all leaf tensors have stride[RANK-1] == 1 // Add more capabilities as needed }; @@ -89,17 +90,18 @@ namespace detail { #if !defined(__CUDACC_RTC__) - template + template struct CapabilityParams { static constexpr ElementsPerThread ept = EPT; static constexpr bool jit = JIT; + static constexpr bool unit_stride_last = UNIT_STRIDE_LAST; static constexpr int osize = 0; static constexpr int block_size = 0; // For JIT there will be other capabilties patched in with a string - }; + }; - using DefaultCapabilities = CapabilityParams; + using DefaultCapabilities = CapabilityParams; // Concept to detect scoped enums template @@ -256,6 +258,18 @@ namespace detail { static constexpr bool default_value = false; // Default: operators do their own bounds checking static constexpr bool or_identity = false; static constexpr bool and_identity = true; + }; + + template <> + struct capability_attributes { + using type = bool; + using input_type = VoidCapabilityType; + // Non-tensor ops (scalars, generators) are trivially unit-stride, so default those to true + // since we will AND all unit stride capabilities in the expression tree. Tensor-like + // types should handle the unit stride query in their get_capability() method. + static constexpr bool default_value = true; + static constexpr bool or_identity = false; + static constexpr bool and_identity = true; }; @@ -324,6 +338,8 @@ namespace detail { return CapabilityQueryType::AND_QUERY; // The expression should generate LTOIR code if all its children generate it. case OperatorCapability::PASS_THROUGH_THREADS: return CapabilityQueryType::OR_QUERY; // If ANY operator needs pass-through, all threads must call operator() + case OperatorCapability::UNIT_STRIDE_LAST: + return CapabilityQueryType::AND_QUERY; // All leaf tensors must have stride[RANK-1] == 1 default: // Default to OR_QUERY or handle as an error/assertion if a capability isn't mapped. return CapabilityQueryType::OR_QUERY; diff --git a/include/matx/core/nvrtc_helper.h b/include/matx/core/nvrtc_helper.h index 80eed3d08..70cc6647b 100644 --- a/include/matx/core/nvrtc_helper.h +++ b/include/matx/core/nvrtc_helper.h @@ -245,6 +245,8 @@ std::string generate_capability_params_string([[maybe_unused]] const Op &op, Ele "struct CapabilityParams {\n" " static constexpr ElementsPerThread ept = EPT;\n" " static constexpr bool jit = JIT;\n" + // Note: no unit_stride_last here. JIT bakes strides as constexpr + // values, so the compiler already eliminates multiply-by-1. " static constexpr int osize = " + std::to_string(osize) + ";\n" " static constexpr int block_size = " + std::to_string(block_size) + ";\n" " static constexpr bool pass_through_threads = " + pass_through_str + ";\n" diff --git a/include/matx/core/tensor_impl.h b/include/matx/core/tensor_impl.h index 32507262a..69b21eb6f 100644 --- a/include/matx/core/tensor_impl.h +++ b/include/matx/core/tensor_impl.h @@ -154,63 +154,63 @@ class tensor_impl_t { " T *ldata_;\n" + " constexpr static cuda::std::array strides_ = { " + detail::array_to_string(desc_.Strides()) + " };\n" + " constexpr static cuda::std::array sizes_ = { " + detail::array_to_string(desc_.Shape()) + " };\n" + - " template \n" + + " template \n" + " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetVal([[maybe_unused]] cuda::std::tuple tup) {\n" + " if constexpr (I < sizeof...(Is)) {\n" + - " if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" + - " return GetVal(tup) + cuda::std::get(tup)*(strides_[I] * static_cast(EPT));\n" + + " if constexpr (CapType::ept != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" + + " return GetVal(tup) + cuda::std::get(tup)*(strides_[I] * static_cast(CapType::ept));\n" + " }\n" + " else {\n" + - " return GetVal(tup) + cuda::std::get(tup)*(strides_[I]);\n" + + " return GetVal(tup) + cuda::std::get(tup)*(strides_[I]);\n" + " }\n" + " }\n" + " else {\n" + " return 0;\n" + " }\n" + " }\n" + - " template \n" + + " template \n" + " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetValC([[maybe_unused]] const cuda::std::tuple tup) const {\n" + " if constexpr (I < sizeof...(Is)) {\n" + - " if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" + - " return GetValC(tup) + cuda::std::get(tup)*(strides_[I] * static_cast(EPT));\n" + + " if constexpr (CapType::ept != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" + + " return GetValC(tup) + cuda::std::get(tup)*(strides_[I] * static_cast(CapType::ept));\n" + " }\n" + " else {\n" + - " return GetValC(tup) + cuda::std::get(tup)*(strides_[I]);\n" + + " return GetValC(tup) + cuda::std::get(tup)*(strides_[I]);\n" + " }\n" + " }\n" + " else {\n" + " return 0;\n" + " }\n" + - " }\n" + - " template \n" + + " }\n" + + " template \n" + " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetOffsetOptimized(Is... indices) const {\n" + " constexpr size_t rank = sizeof...(Is);\n" + - " constexpr int EPT_int = static_cast(EPT);\n" + + " constexpr int EPT_int = static_cast(CapType::ept);\n" + " const cuda::std::array idx{indices...};\n" + " \n" + " if constexpr (rank == 1) {\n" + - " if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" + + " if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" + " return idx[0] * (strides_[0] * EPT_int);\n" + " } else {\n" + " return idx[0] * strides_[0];\n" + " }\n" + " }\n" + " else if constexpr (rank == 2) {\n" + - " if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" + + " if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" + " return idx[0] * strides_[0] + idx[1] * (strides_[1] * EPT_int);\n" + " } else {\n" + " return idx[0] * strides_[0] + idx[1] * strides_[1];\n" + " }\n" + " }\n" + " else if constexpr (rank == 3) {\n" + - " if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" + + " if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" + " return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * (strides_[2] * EPT_int);\n" + " } else {\n" + " return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2];\n" + " }\n" + " }\n" + " else if constexpr (rank == 4) {\n" + - " if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" + + " if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" + " return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2] + idx[3] * (strides_[3] * EPT_int);\n" + " } else {\n" + " return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2] + idx[3] * strides_[3];\n" + @@ -218,7 +218,7 @@ class tensor_impl_t { " }\n" + " else {\n" + " // For rank > 4, fall back to the recursive implementation\n" + - " return GetValC(cuda::std::make_tuple(indices...));\n" + + " return GetValC(cuda::std::make_tuple(indices...));\n" + " }\n" + " }\n" + " template \n" + @@ -246,7 +246,7 @@ class tensor_impl_t { " return ReturnType{};\n" + " }\n" + " }\n" + - " const index_t offset = GetOffsetOptimized(indices...);\n" + + " const index_t offset = GetOffsetOptimized(indices...);\n" + " if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" + " return ldata_[offset];\n" + " } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {\n" + @@ -272,7 +272,7 @@ class tensor_impl_t { " return dummy_;\n" + " }\n" + " }\n" + - " const index_t offset = GetOffsetOptimized(indices...);\n" + + " const index_t offset = GetOffsetOptimized(indices...);\n" + " if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" + " return ldata_[offset];\n" + " } else {\n" + @@ -296,7 +296,7 @@ class tensor_impl_t { " template \n" + " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* data_ptr(index_t block_idx, index_t ttl_threads) const noexcept\n" + " {\n" - " //const index_t offset = GetOffsetOptimized(indices...);\n" + + " //const index_t offset = GetOffsetOptimized(indices...);\n" + " //return ldata_ + offset;\n" + " return ldata_ + block_idx * ttl_threads * static_cast(CapType::ept);\n" + " }\n" + @@ -1107,7 +1107,7 @@ MATX_IGNORE_WARNING_POP_GCC template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* GetPointer(Is... indices) const noexcept { - return data_.ldata_ + GetOffsetOptimized(indices...); + return data_.ldata_ + GetOffsetOptimized(indices...); } // Locates position of an element at given indices, or returns -1 when not @@ -1204,16 +1204,11 @@ MATX_IGNORE_WARNING_POP_GCC return desc_.IsContiguous(); } - template + template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetVal([[maybe_unused]] cuda::std::tuple tup) { if constexpr (I < sizeof...(Is)) { MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") - if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) { - return GetVal(tup) + cuda::std::get(tup)*(this->desc_.Stride(I) * static_cast(EPT)); - } - else { - return GetVal(tup) + cuda::std::get(tup)*(this->desc_.Stride(I)); - } + return GetVal(tup) + DimStride(sizeof...(Is)), CapType>(cuda::std::get(tup)); MATX_IGNORE_WARNING_POP_GCC } else { @@ -1221,59 +1216,68 @@ MATX_IGNORE_WARNING_POP_GCC } } - // Optimized offset calculation for ranks 1-4 with explicit stride multiplications - template + // Compute the stride contribution for a single dimension, eliding the + // load and multiply when the last dimension is known at dispatch time + // to have unit stride (via CapType::unit_stride_last). + template + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type + DimStride(index_t idx_val) const { + constexpr bool is_unit = CapType::unit_stride_last && (DIM == RANK_VAL - 1); + constexpr bool is_last = (DIM == RANK_VAL - 1); + constexpr bool has_ept = (CapType::ept != detail::ElementsPerThread::ONE) && is_last; + + if constexpr (is_unit && has_ept) { + return idx_val * static_cast(CapType::ept); + } else if constexpr (is_unit) { + return idx_val; + } else if constexpr (has_ept) { + return idx_val * (this->desc_.Stride(DIM) * static_cast(CapType::ept)); + } else { + return idx_val * this->desc_.Stride(DIM); + } + } + + // Optimized offset calculation for ranks 1-4 with explicit stride multiplications. + // When CapType::unit_stride_last is true, the stride load (ULDC) and + // multiply (IMAD) for the last dimension are elided. + template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetOffsetOptimized(Is... indices) const { -MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") +MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") constexpr size_t rank = sizeof...(Is); - constexpr int EPT_int = static_cast(EPT); const cuda::std::array idx{indices...}; - + + constexpr int R = static_cast(rank); + if constexpr (rank == 1) { - if constexpr (EPT != detail::ElementsPerThread::ONE) { - return idx[0] * (this->desc_.Stride(0) * EPT_int); - } else { - return idx[0] * this->desc_.Stride(0); - } + return DimStride<0, R, CapType>(idx[0]); } else if constexpr (rank == 2) { - if constexpr (EPT != detail::ElementsPerThread::ONE) { - return idx[0] * this->desc_.Stride(0) + idx[1] * (this->desc_.Stride(1) * EPT_int); - } else { - return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1); - } + return DimStride<0, R, CapType>(idx[0]) + + DimStride<1, R, CapType>(idx[1]); } else if constexpr (rank == 3) { - if constexpr (EPT != detail::ElementsPerThread::ONE) { - return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * (this->desc_.Stride(2) * EPT_int); - } else { - return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2); - } + return DimStride<0, R, CapType>(idx[0]) + + DimStride<1, R, CapType>(idx[1]) + + DimStride<2, R, CapType>(idx[2]); } else if constexpr (rank == 4) { - if constexpr (EPT != detail::ElementsPerThread::ONE) { - return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2) + idx[3] * (this->desc_.Stride(3) * EPT_int); - } else { - return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2) + idx[3] * this->desc_.Stride(3); - } + return DimStride<0, R, CapType>(idx[0]) + + DimStride<1, R, CapType>(idx[1]) + + DimStride<2, R, CapType>(idx[2]) + + DimStride<3, R, CapType>(idx[3]); } else { // For rank > 4, fall back to the recursive implementation - return GetValC(cuda::std::make_tuple(indices...)); + return GetValC(cuda::std::make_tuple(indices...)); } -MATX_IGNORE_WARNING_POP_GCC +MATX_IGNORE_WARNING_POP_GCC } - template + template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetValC([[maybe_unused]] const cuda::std::tuple tup) const { if constexpr (I < sizeof...(Is)) { MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") - if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) { - return GetValC(tup) + cuda::std::get(tup)*(this->desc_.Stride(I) * static_cast(EPT)); - } - else { - return GetValC(tup) + cuda::std::get(tup)*(this->desc_.Stride(I)); - } + return GetValC(tup) + DimStride(sizeof...(Is)), CapType>(cuda::std::get(tup)); MATX_IGNORE_WARNING_POP_GCC } else { @@ -1299,7 +1303,7 @@ MATX_IGNORE_WARNING_POP_GCC assert(data_.ldata_ != nullptr); #endif constexpr int EPT_int = static_cast(CapType::ept); - const index_t offset = GetOffsetOptimized(indices...); + const index_t offset = GetOffsetOptimized(indices...); if constexpr (CapType::ept == detail::ElementsPerThread::ONE) { return data_.ldata_[offset]; @@ -1329,7 +1333,7 @@ MATX_IGNORE_WARNING_POP_GCC { static_assert(sizeof...(Is) == M, "Number of indices of data_ptr must match rank of tensor"); if constexpr (!is_sparse_data_v) { - const index_t offset = GetOffsetOptimized(indices...); + const index_t offset = GetOffsetOptimized(indices...); return data_.ldata_ + offset; } else { @@ -1363,7 +1367,7 @@ MATX_IGNORE_WARNING_POP_GCC assert(data_.ldata_ != nullptr); #endif constexpr int EPT_int = static_cast(CapType::ept); - const index_t offset = GetOffsetOptimized(indices...); + const index_t offset = GetOffsetOptimized(indices...); if constexpr (CapType::ept == detail::ElementsPerThread::ONE) { return data_.ldata_[offset]; @@ -1563,6 +1567,13 @@ MATX_IGNORE_WARNING_POP_GCC return overlaps; } } + else if constexpr (Cap == OperatorCapability::UNIT_STRIDE_LAST) { + if constexpr (Rank() == 0) { + return true; + } else { + return (Stride(Rank() - 1) == 1); + } + } else { return detail::capability_attributes::default_value; } diff --git a/include/matx/executors/cuda.h b/include/matx/executors/cuda.h index f0776288e..4156a6ab8 100644 --- a/include/matx/executors/cuda.h +++ b/include/matx/executors/cuda.h @@ -106,14 +106,18 @@ namespace matx // Find the best launch parameters auto [best_ept, shm_size, block_size, groups_per_block] = detail::find_best_launch_params(op, kernel_provider, 256, false); + // Check if all leaf tensors have unit stride in the last dimension. + // This allows eliding one ULDC + IMAD per tensor access. + const bool unit_stride_last = detail::get_operator_capability(op); + // Helper lambda to handle kernel dispatch. This is templated on the EPT - // type since that's what the kernels are templated on. - auto dispatch_kernel = [&](auto&& kernel_handler) { + // and unit-stride-last flag since those are what the kernels are templated on. + auto dispatch_kernel = [&](auto&& kernel_handler) { int max_tpb = 256; bool stride = detail::get_grid_dims(blocks, threads, sizes, static_cast(EPT), max_tpb); - using CapType = detail::CapabilityParams; + using CapType = detail::CapabilityParams; if constexpr (Op::Rank() == 0) { kernel_handler([&]() { @@ -160,14 +164,19 @@ namespace matx } }; - // Helper lambda to launch kernel + // Helper lambda to launch kernel, branching on unit_stride_last auto launch_kernel = [&]() { - dispatch_kernel.template operator()([&](auto launch_func) { - MATX_LOG_DEBUG("Launching CUDA kernel: rank={}, blocks=({},{},{}), threads=({},{},{}), EPT={}, stream={}", - Op::Rank(), blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, - static_cast(EPT), reinterpret_cast(stream_)); + auto handler = [&](auto launch_func) { + MATX_LOG_DEBUG("Launching CUDA kernel: rank={}, blocks=({},{},{}), threads=({},{},{}), EPT={}, USL={}, stream={}", + Op::Rank(), blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, + static_cast(EPT), unit_stride_last, reinterpret_cast(stream_)); launch_func(); - }); + }; + if (unit_stride_last) { + dispatch_kernel.template operator()(handler); + } else { + dispatch_kernel.template operator()(handler); + } }; // Launch the correct kernel based on the best EPT found From 3bd71647fc4fbff66495862e5270cc62888fc61b Mon Sep 17 00:00:00 2001 From: Thomas Benson Date: Mon, 13 Apr 2026 12:49:31 -0700 Subject: [PATCH 2/3] Add USL support for 5D+ tensors Signed-off-by: Thomas Benson --- include/matx/core/tensor_impl.h | 2 +- include/matx/executors/cuda.h | 19 +++++++++++++++---- include/matx/executors/kernel.h | 6 +++--- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/include/matx/core/tensor_impl.h b/include/matx/core/tensor_impl.h index 69b21eb6f..6ecf9ac0e 100644 --- a/include/matx/core/tensor_impl.h +++ b/include/matx/core/tensor_impl.h @@ -1222,8 +1222,8 @@ MATX_IGNORE_WARNING_POP_GCC template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type DimStride(index_t idx_val) const { - constexpr bool is_unit = CapType::unit_stride_last && (DIM == RANK_VAL - 1); constexpr bool is_last = (DIM == RANK_VAL - 1); + constexpr bool is_unit = CapType::unit_stride_last && is_last; constexpr bool has_ept = (CapType::ept != detail::ElementsPerThread::ONE) && is_last; if constexpr (is_unit && has_ept) { diff --git a/include/matx/executors/cuda.h b/include/matx/executors/cuda.h index 4156a6ab8..d7623e642 100644 --- a/include/matx/executors/cuda.h +++ b/include/matx/executors/cuda.h @@ -117,7 +117,8 @@ namespace matx bool stride = detail::get_grid_dims(blocks, threads, sizes, static_cast(EPT), max_tpb); - using CapType = detail::CapabilityParams; + constexpr bool JIT = false; + using CapType = detail::CapabilityParams; if constexpr (Op::Rank() == 0) { kernel_handler([&]() { @@ -205,10 +206,20 @@ namespace matx } else { auto ept_type = detail::EPTQueryInput{false}; - const auto ept_bounds = detail::get_operator_capability(op, ept_type); - bool stride = detail::get_grid_dims(blocks, threads, sizes, static_cast(ept_bounds[1]), 1024); + const auto ept_bounds = detail::get_operator_capability(op, ept_type); + bool stride = detail::get_grid_dims(blocks, threads, sizes, static_cast(ept_bounds[1]), 1024); index_t dims = cuda::std::accumulate(cuda::std::begin(sizes) + 1, cuda::std::end(sizes), 1, cuda::std::multiplies()); - detail::matxOpTDKernel<<>>(op, sizes, dims); + constexpr bool JIT = false; + const bool usl = detail::get_operator_capability(op); + if (usl) { + constexpr bool USL = true; + using CapType = detail::CapabilityParams; + detail::matxOpTDKernel<<>>(op, sizes, dims); + } else { + constexpr bool USL = false; + using CapType = detail::CapabilityParams; + detail::matxOpTDKernel<<>>(op, sizes, dims); + } } #else MATX_ASSERT_STR(false, matxInvalidParameter, "Cannot call device executor using host compiler"); diff --git a/include/matx/executors/kernel.h b/include/matx/executors/kernel.h index 4eee1edbe..280f5aa04 100644 --- a/include/matx/executors/kernel.h +++ b/include/matx/executors/kernel.h @@ -186,7 +186,7 @@ __global__ void matxOpT4StrideKernel(Op op, index_t size0, index_t size1, index_ * @param sizes sizes of each dimension * @param mult Product of sizes of all but first dimension */ -template +template __global__ void matxOpTDKernel(Op op, const cuda::std::array sizes, index_t mult) { cuda::std::array indices; @@ -211,12 +211,12 @@ __global__ void matxOpTDKernel(Op op, const cuda::std::array) { cuda::std::apply([&](auto... args){ - (*op)(args...); + (*op).template operator()(args...); }, indices); } else { cuda::std::apply([&](auto... args){ - op(args...); + op.template operator()(args...); }, indices); } } From 5dec0ce5ffeeb0f8291bfcfbb652b02bbd7f64e5 Mon Sep 17 00:00:00 2001 From: Thomas Benson Date: Mon, 13 Apr 2026 19:14:44 -0700 Subject: [PATCH 3/3] Propagate/combine capabilities through Interp1D object Signed-off-by: Thomas Benson --- include/matx/operators/interp.h | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/include/matx/operators/interp.h b/include/matx/operators/interp.h index 0ca38f054..32e3c10b7 100644 --- a/include/matx/operators/interp.h +++ b/include/matx/operators/interp.h @@ -166,13 +166,19 @@ namespace matx { } template - __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType&) const { + __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType& in) const { if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { const auto my_cap = cuda::std::array{ElementsPerThread::ONE, ElementsPerThread::ONE}; return my_cap; } else { auto self_has_cap = detail::capability_attributes::default_value; - return self_has_cap; + return combine_capabilities(self_has_cap, + detail::get_operator_capability(dl_, in), + detail::get_operator_capability(d_, in), + detail::get_operator_capability(du_, in), + detail::get_operator_capability(b_, in), + detail::get_operator_capability(x_, in), + detail::get_operator_capability(v_, in)); } } @@ -508,7 +514,10 @@ namespace matx { } else { auto self_has_cap = detail::capability_attributes::default_value; // Note: m_ is a temporary internal tensor, not an input operator passed to constructor - return self_has_cap; + return combine_capabilities(self_has_cap, + detail::get_operator_capability(x_, in), + detail::get_operator_capability(v_, in), + detail::get_operator_capability(xq_, in)); } }