From 7105d13dd59e4db89a8058f2f36215bd3a2b5fa3 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 1 Apr 2025 12:23:27 +0800 Subject: [PATCH 01/14] issue/127: refactor elementwise infra to support Opaque input when calculate --- .../elementwise/cpu/elementwise_cpu.h | 109 ++++++++++++++ .../elementwise/cuda/elementwise_cuda.cuh | 81 ++++++++++ src/infiniop/elementwise/elementwise.h | 141 ++++++++++++++++++ 3 files changed, 331 insertions(+) create mode 100644 src/infiniop/elementwise/cpu/elementwise_cpu.h create mode 100644 src/infiniop/elementwise/cuda/elementwise_cuda.cuh create mode 100644 src/infiniop/elementwise/elementwise.h diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h new file mode 100644 index 000000000..518ee14e6 --- /dev/null +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -0,0 +1,109 @@ +#ifndef __INFINIOP_ELEMENTWISE_CPU_H__ +#define __INFINIOP_ELEMENTWISE_CPU_H__ + +#include "../../devices/cpu/common_cpu.h" +#include "../elementwise.h" +#include + +/** + * @brief Define the process for initializing a Descriptor of an elementwise operation + * for its CPU implementation + */ +#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \ + \ + op::elementwise::ElementwiseInfo elementwise_info; \ + CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \ + \ + *desc_ptr = new Descriptor( \ + dtype, \ + std::move(elementwise_info), \ + nullptr, \ + handle->device, \ + handle->device_id); + +DEVICE_IMPL(cpu) + +namespace op::elementwise::cpu { + +struct DeviceImpl::Opaque {}; + +template +infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) { + *device_info = new DeviceImpl(nullptr); + return INFINI_STATUS_SUCCESS; +} + +// Perform elementwise operation for different input types +template = 0> +void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, std::index_sequence, Args &&...args) { + Tout *out = reinterpret_cast(output); + std::tuple input_ptrs = {reinterpret_cast(inputs[Is])...}; + ptrdiff_t output_size = info.output_size; + +#pragma omp parallel for + for (ptrdiff_t i = 0; i < output_size; ++i) { + size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data()); + + auto get_input_idx = [&](size_t input_id) { + return info.input_contiguous[input_id] ? i + : (info.input_broadcasted[input_id] + ? 
op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data()) + : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data())); + }; + + out[out_idx] = utils::cast(Op{}(std::get(input_ptrs)[get_input_idx(Is)]..., std::forward(args)...)); + } +} + +// Invoke elementwise operation for different input types +template = 0> +void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + Args &&...args) { + + static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch"); + calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); +} + +// Perform elementwise operation when all inputs have the same type +template +void calculate_impl(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + std::index_sequence, + Args &&...args) { + + Tdata *out = reinterpret_cast(output); + std::array ins = {reinterpret_cast(inputs[Is])...}; + const ptrdiff_t output_size = info.output_size; + +#pragma omp parallel for + for (ptrdiff_t i = 0; i < output_size; ++i) { + size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data()); + + auto get_input_idx = [&](size_t input_id) { + return info.input_contiguous[input_id] ? i + : (info.input_broadcasted[input_id] + ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data()) + : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data())); + }; + + if constexpr (std::is_same_v) { + out[out_idx] = utils::cast(Op{}(utils::cast(ins[Is][get_input_idx(Is)])..., std::forward(args)...)); + } else { + out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward(args)...); + } + } +} + +// Invoke elementwise operation when all inputs have the same type +template +void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, Args &&...args) { + constexpr size_t N = Op::num_inputs; + calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); +} + +} // namespace op::elementwise::cpu + +#endif // __INFINIOP_ELEMENTWISE_CPU_H__ diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh new file mode 100644 index 000000000..cc7933df1 --- /dev/null +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -0,0 +1,81 @@ +// #ifndef __INFINIOP_ELEMENTWISE_CUDA_H__ +// #define __INFINIOP_ELEMENTWISE_CUDA_H__ + +// #include "../../devices/cuda/cuda_common.cuh" +// #include "../elementwise.h" + +// #define ELEMENTWISE_CUDA_OPAQUE(OP) \ +// \ +// namespace op::OP::cuda { \ +// struct Descriptor::Opaque { \ +// std::shared_ptr internal; \ +// }; \ +// \ +// Descriptor::~Descriptor() { \ +// delete _opaque; \ +// } \ +// } // namespace op::elementwise::cuda + +// namespace op::common_cuda::elementwise_op { + +// // Perform elementwise operation when all inputs have the same type +// template +// void _calculate_impl(const op::elementwise::ElementwiseInfo &info, +// void *output, +// const std::vector &inputs, +// std::index_sequence, +// Args &&...args) { + +// Tdata *out = reinterpret_cast(output); +// std::array ins = {reinterpret_cast(inputs[Is])...}; +// const ptrdiff_t output_size = info.output_size; + +// 
#pragma omp parallel for +// for (ptrdiff_t i = 0; i < output_size; ++i) { +// size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data()); + +// auto get_input_idx = [&](size_t input_id) { +// return info.input_contiguous[input_id] ? i +// : (info.input_broadcasted[input_id] +// ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data()) +// : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data())); +// }; + +// if constexpr (std::is_same_v) { +// out[out_idx] = utils::cast(Op{}(utils::cast(ins[Is][get_input_idx(Is)])..., std::forward(args)...)); +// } else { +// out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward(args)...); +// } +// } +// } + +// template +// void calculate_impl(const op::elementwise::ElementwiseInfo &info, +// void *output, +// const std::vector &inputs, +// std::index_sequence, +// Args &&...args) { + +// if (info.output_size == 0) { +// return; +// } +// Tdata *out = reinterpret_cast(output); +// std::array inputs_vec = {reinterpret_cast(inputs[Is])...}; + +// dim3 blockDims = dim3(std::min(static_cast(BLOCK_SIZE), info.output_size)); +// dim3 gridDims = dim3(std::min(ROUND_UP_DIV(info.output_size, blockDims.x), desc->max_grid_size)); +// uint64_t step = gridDims.x * blockDims.x; + +// _calculate_impl(info, out, inputs_vec, Is, std::forward(args)...); +// } + +// // Invoke elementwise operation when all inputs have the same type +// template +// void calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, Args &&...args) { +// constexpr size_t N = Op::num_inputs; +// calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); +// } + +// } // namespace op::common_cuda::elementwise_op + +// #endif // __INFINIOP_ELEMENTWISE_CUDA_H__ \ No newline at end of file diff --git a/src/infiniop/elementwise/elementwise.h b/src/infiniop/elementwise/elementwise.h new file mode 100644 index 000000000..ae3e258c5 --- /dev/null +++ b/src/infiniop/elementwise/elementwise.h @@ -0,0 +1,141 @@ +#ifndef __INFINIOP_ELEMENTWISE_H__ +#define __INFINIOP_ELEMENTWISE_H__ + +#include "../operator.h" +#include "../tensor.h" +#include +#include +#include +#include + +#define DEVICE_IMPL(NAMESPACE) \ + \ + namespace op::elementwise::NAMESPACE { \ + class DeviceImpl final { \ + struct Opaque; \ + std::unique_ptr _opaque; \ + \ + DeviceImpl(Opaque *opaque) : _opaque(opaque) {} \ + \ + public: \ + ~DeviceImpl() = default; \ + \ + template \ + static infiniStatus_t create( \ + DeviceImpl **device_info, \ + Args &&...args); \ + \ + /* Invoke elementwise operation when all inputs have the same type */ \ + template \ + void calculate( \ + const op::elementwise::ElementwiseInfo &info, \ + void *output, \ + const std::vector &inputs, \ + Args &&...args); \ + \ + /* Invoke elementwise operation for different input types */ \ + template = 0> \ + void calculate( \ + const op::elementwise::ElementwiseInfo &info, \ + void *output, \ + const std::vector &inputs, \ + Args &&...args); \ + }; \ + } + +#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \ + \ + namespace op::OP::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + infiniDtype_t _dtype; \ + op::elementwise::ElementwiseInfo _info; \ + std::unique_ptr _device_info; \ + \ + Descriptor( \ + infiniDtype_t dtype, \ + op::elementwise::ElementwiseInfo info, \ + 
op::elementwise::NAMESPACE::DeviceImpl *device_info, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _dtype(dtype), \ + _info(info), \ + _device_info(device_info) {} \ + \ + public: \ + ~Descriptor(); \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t output_desc, \ + std::vector input_descs); \ + \ + infiniStatus_t calculate( \ + void *output, \ + std::vector inputs, \ + void *stream) const; \ + }; \ + } + +namespace op::elementwise { + +// struct that stores data needed for elementwise operation +struct ElementwiseInfo { + size_t output_size; + size_t ndim; + bool output_contiguous; + std::vector input_contiguous; + std::vector input_broadcasted; + std::vector output_shape; + std::vector> input_shapes; + std::vector output_strides; + std::vector> input_strides; +}; + +inline infiniStatus_t createElementwiseInfo( + ElementwiseInfo &info, + infiniopTensorDescriptor_t output_desc, + std::vector input_descs) { + + if (!output_desc || input_descs.empty()) { + return INFINI_STATUS_BAD_PARAM; + } + + // Destination cannot have broadcast setup + if (output_desc->hasBroadcastDim()) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + const size_t input_size = input_descs.size(); + const size_t out_ndim = output_desc->ndim(); + + // Intializing the ElementwiseInfo struct + info.output_size = output_desc->numel(); + info.ndim = out_ndim; + info.output_contiguous = output_desc->isContiguous(); + + for (const auto &desc : input_descs) { + info.input_contiguous.emplace_back(desc->isContiguous()); + } + + for (size_t i = 0; i < input_size; ++i) { + const auto &desc = input_descs[i]; + info.input_broadcasted.emplace_back(!info.input_contiguous[i] && (desc->ndim() != out_ndim || desc->hasBroadcastDim())); + } + + info.output_shape = std::move(output_desc->shape()); + info.output_strides = std::move(output_desc->strides()); + for (const auto &desc : input_descs) { + info.input_shapes.emplace_back(desc->shape()); + info.input_strides.emplace_back(desc->strides()); + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::elementwise + +#endif // __INFINIOP_ELEMENTWISE_H__ From 6292da009b0808a3fae04f004ca0b86e9f3ff179 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 7 Apr 2025 20:46:02 +0800 Subject: [PATCH 02/14] issue/127: refactor elementwise framework, complete CUDA implementation, refactor swiglu using the generic elementwise framework --- .../elementwise/cpu/elementwise_cpu.h | 52 +- .../elementwise/cuda/elementwise_cuda.cuh | 612 +++++++++++++++--- .../elementwise/cuda/elementwise_cuda_api.cuh | 66 ++ src/infiniop/elementwise/elementwise.h | 172 +++-- src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc | 29 +- src/infiniop/ops/swiglu/cpu/swiglu_cpu.h | 11 +- src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu | 55 ++ src/infiniop/ops/swiglu/cuda/swiglu_cuda.cuh | 8 + .../ops/swiglu/cuda/swiglu_cuda_internal.cuh | 59 ++ src/infiniop/ops/swiglu/operator.cc | 26 +- src/utils.h | 2 + xmake/cuda.lua | 1 + 12 files changed, 904 insertions(+), 189 deletions(-) create mode 100644 src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh create mode 100644 src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu create mode 100644 src/infiniop/ops/swiglu/cuda/swiglu_cuda.cuh create mode 100644 src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h index 518ee14e6..23e5873c6 100644 --- 
a/src/infiniop/elementwise/cpu/elementwise_cpu.h +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -21,10 +21,43 @@ handle->device, \ handle->device_id); -DEVICE_IMPL(cpu) - namespace op::elementwise::cpu { +class DeviceImpl final { + struct Opaque; + std::shared_ptr _opaque; + + DeviceImpl(std::shared_ptr opaque) : _opaque(std::move(opaque)) {} + +public: + ~DeviceImpl() = default; + + template + static infiniStatus_t create( + DeviceImpl **device_info, + Args &&...args); + + /* Invoke elementwise operation when all inputs have the same type */ + template + void calculate( + const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args); + + /* Invoke elementwise operation for different input types */ + template = 0> + void calculate( + const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args); +}; + struct DeviceImpl::Opaque {}; template @@ -42,13 +75,13 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, #pragma omp parallel for for (ptrdiff_t i = 0; i < output_size; ++i) { - size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data()); + size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides); auto get_input_idx = [&](size_t input_id) { return info.input_contiguous[input_id] ? i : (info.input_broadcasted[input_id] - ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data()) - : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data())); + ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id]) + : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); }; out[out_idx] = utils::cast(Op{}(std::get(input_ptrs)[get_input_idx(Is)]..., std::forward(args)...)); @@ -60,6 +93,7 @@ template &inputs, + void *stream, Args &&...args) { static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch"); @@ -80,13 +114,13 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, #pragma omp parallel for for (ptrdiff_t i = 0; i < output_size; ++i) { - size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data()); + size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides); auto get_input_idx = [&](size_t input_id) { return info.input_contiguous[input_id] ? i : (info.input_broadcasted[input_id] - ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data()) - : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data())); + ? 
op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id]) + : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); }; if constexpr (std::is_same_v) { @@ -99,7 +133,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, // Invoke elementwise operation when all inputs have the same type template -void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, Args &&...args) { +void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, void *stream, Args &&...args) { constexpr size_t N = Op::num_inputs; calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); } diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh index cc7933df1..cfd0461ca 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -1,81 +1,531 @@ -// #ifndef __INFINIOP_ELEMENTWISE_CUDA_H__ -// #define __INFINIOP_ELEMENTWISE_CUDA_H__ - -// #include "../../devices/cuda/cuda_common.cuh" -// #include "../elementwise.h" - -// #define ELEMENTWISE_CUDA_OPAQUE(OP) \ -// \ -// namespace op::OP::cuda { \ -// struct Descriptor::Opaque { \ -// std::shared_ptr internal; \ -// }; \ -// \ -// Descriptor::~Descriptor() { \ -// delete _opaque; \ -// } \ -// } // namespace op::elementwise::cuda - -// namespace op::common_cuda::elementwise_op { - -// // Perform elementwise operation when all inputs have the same type -// template -// void _calculate_impl(const op::elementwise::ElementwiseInfo &info, -// void *output, -// const std::vector &inputs, -// std::index_sequence, -// Args &&...args) { - -// Tdata *out = reinterpret_cast(output); -// std::array ins = {reinterpret_cast(inputs[Is])...}; -// const ptrdiff_t output_size = info.output_size; - -// #pragma omp parallel for -// for (ptrdiff_t i = 0; i < output_size; ++i) { -// size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data()); - -// auto get_input_idx = [&](size_t input_id) { -// return info.input_contiguous[input_id] ? i -// : (info.input_broadcasted[input_id] -// ? 
op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data()) -// : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data())); -// }; - -// if constexpr (std::is_same_v) { -// out[out_idx] = utils::cast(Op{}(utils::cast(ins[Is][get_input_idx(Is)])..., std::forward(args)...)); -// } else { -// out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward(args)...); -// } -// } -// } - -// template -// void calculate_impl(const op::elementwise::ElementwiseInfo &info, -// void *output, -// const std::vector &inputs, -// std::index_sequence, -// Args &&...args) { - -// if (info.output_size == 0) { -// return; -// } -// Tdata *out = reinterpret_cast(output); -// std::array inputs_vec = {reinterpret_cast(inputs[Is])...}; - -// dim3 blockDims = dim3(std::min(static_cast(BLOCK_SIZE), info.output_size)); -// dim3 gridDims = dim3(std::min(ROUND_UP_DIV(info.output_size, blockDims.x), desc->max_grid_size)); -// uint64_t step = gridDims.x * blockDims.x; - -// _calculate_impl(info, out, inputs_vec, Is, std::forward(args)...); -// } - -// // Invoke elementwise operation when all inputs have the same type -// template -// void calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, Args &&...args) { -// constexpr size_t N = Op::num_inputs; -// calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); -// } - -// } // namespace op::common_cuda::elementwise_op - -// #endif // __INFINIOP_ELEMENTWISE_CUDA_H__ \ No newline at end of file +#ifndef __INFINIOP_ELEMENTWISE_CUDA_H__ +#define __INFINIOP_ELEMENTWISE_CUDA_H__ + +#include "../../../utils.h" +#include "../../devices/cuda/cuda_common.cuh" +#include "elementwise_cuda_api.cuh" +namespace op::elementwise::cuda { + +/** + * @brief Helper device function to expand a compile-time index sequence into individual constants + * and pass them to a lambda. + * + * @tparam Lambda Type of the lambda function to invoke. + * @tparam Is Index sequence values (automatically deduced). + * @param lambda Lambda to be called with std::integral_constant... as arguments. + */ +template +__device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence) { + lambda(std::integral_constant{}...); +} + +/** + * @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type. + * + * @tparam Op Operator type implementing operator()(Tdata...). + * @tparam Tdata Common data type for inputs and output. + * @tparam N Number of input tensors. + * @tparam Args Additional arguments to pass to the operator. + * + * @param output_size Total number of output elements. + * @param ndim Number of dimensions in tensors. + * @param output_contiguous Whether the output tensor is contiguous in memory. + * @param input_contiguous Array indicating if each input tensor is contiguous. + * @param input_broadcasted Array indicating if each input tensor is broadcasted. + * @param output_shape Shape of the output tensor. + * @param input_shapes Shapes of the input tensors. + * @param output_strides Strides for the output tensor. + * @param input_strides Strides for each input tensor. + * @param input_size Total number of input elements (optional, may be unused). + * @param output Output buffer. + * @param inputs Array of input pointers, all of type Tdata. + * @param offset Linear offset to support partitioned execution. 
+ * @param args Additional arguments passed to the operator. + */ +template +INFINIOP_CUDA_KERNEL elementwise_kernel( + size_t output_size, + size_t ndim, + bool output_contiguous, + const bool *__restrict__ input_contiguous, + const bool *__restrict__ input_broadcasted, + const size_t *__restrict__ output_shape, + const size_t *__restrict__ *__restrict__ input_shapes, + const ptrdiff_t *__restrict__ output_strides, + const ptrdiff_t *__restrict__ *__restrict__ input_strides, + size_t input_size, + Tdata *output, + const Tdata *const *inputs, + size_t offset, + Args... args) { + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + if (idx < output_size) { + size_t out_idx = output_contiguous ? idx + : device::cuda::indexToOffset(idx, ndim, output_shape, output_strides); + + auto get_input_idx = [&] __device__(size_t input_id) { + return input_contiguous[input_id] ? idx + : (input_broadcasted[input_id] + ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id]) + : device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id])); + }; + + // Use a helper to expand the index sequence into individual compile-time constants + auto expand_inputs = [&] __device__(auto... idxs) { + if constexpr (std::is_same_v) { + output[out_idx] = utils::cast( + Op{}(utils::cast(inputs[idxs.value][get_input_idx(idxs.value)])..., + std::forward(args)...)); + } else { + output[out_idx] = Op{}( + inputs[idxs.value][get_input_idx(idxs.value)]..., + std::forward(args)...); + } + }; + + call_expand(expand_inputs, std::make_index_sequence{}); + } +} + +/** + * @brief Casts an untyped device pointer to a typed pointer of type T. + * + * @tparam T Desired pointer type. + * @param ptr Untyped pointer. + * @return Pointer of type const T*. + */ +template +__device__ inline const T *typed_input_ptr(const void *ptr) { + return reinterpret_cast(ptr); +} + +/** + * @brief Launches a type-safe elementwise operation on a single output element. + * + * @tparam Op Operator type implementing a templated operator() for (Tout, Tin...). + * @tparam Tout Output data type. + * @tparam Tin Variadic input data types. + * @tparam Is Index sequence corresponding to each input. + * + * @param idx Linear index in the flattened output space. + * @param out_idx Actual output index (may be non-contiguous). + * @param ndim Number of dimensions in the tensors. + * @param input_contiguous Array indicating whether each input is contiguous. + * @param input_broadcasted Array indicating whether each input is broadcasted. + * @param input_shapes Shapes of the input tensors. + * @param input_strides Strides of the input tensors. + * @param inputs Raw pointers to input data. + * @param output Pointer to output data. + * @param ... Index sequence used for unpacking variadic inputs. + */ +template +__device__ void launch_op( + size_t idx, + size_t out_idx, + size_t ndim, + const bool *__restrict__ input_contiguous, + const bool *__restrict__ input_broadcasted, + const size_t *__restrict__ const *__restrict__ input_shapes, + const ptrdiff_t *__restrict__ const *__restrict__ input_strides, + const void *const *__restrict__ inputs, + Tout *output, + std::index_sequence) { + + auto get_input_idx = [&] __device__(size_t input_id) { + return input_contiguous[input_id] + ? idx + : (input_broadcasted[input_id] + ? 
device::cuda::indexToReducedOffset(idx, ndim, input_strides[0], input_strides[input_id]) + : device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id])); + }; + + output[out_idx] = Op{}.template operator()( + (typed_input_ptr(inputs[Is])[get_input_idx(Is)])...); +} + +/** + * @brief CUDA kernel for performing an elementwise operation on tensors with support + * for broadcasting and mixed data types. + * + * @tparam Op Operator type implementing a templated operator() for (Tout, Tin...). + * @tparam Tout Output data type. + * @tparam Tin Variadic input data types. + * + * @param output_size Total number of output elements. + * @param ndim Number of dimensions in the tensors. + * @param output_contiguous Whether the output tensor is contiguous. + * @param input_contiguous Array indicating whether each input is contiguous. + * @param input_broadcasted Array indicating whether each input is broadcasted. + * @param output_shape Shape of the output tensor. + * @param input_shapes Shapes of the input tensors. + * @param output_strides Strides of the output tensor. + * @param input_strides Strides of the input tensors. + * @param input_size Total number of input elements (unused here, but may be used for validation). + * @param output Pointer to the output buffer. + * @param inputs Array of untyped input pointers. + * @param offset Linear offset into the output for partitioned execution. + */ +template +INFINIOP_CUDA_KERNEL elementwise_kernel( + size_t output_size, + size_t ndim, + bool output_contiguous, + const bool *__restrict__ input_contiguous, + const bool *__restrict__ input_broadcasted, + const size_t *__restrict__ output_shape, + const size_t *__restrict__ const *__restrict__ input_shapes, + const ptrdiff_t *__restrict__ output_strides, + const ptrdiff_t *__restrict__ const *__restrict__ input_strides, + size_t input_size, + Tout *output, + const void *const *__restrict__ inputs, + size_t offset) { + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + if (idx >= output_size) { + return; + } + + size_t out_idx = output_contiguous + ? idx + : device::cuda::indexToOffset(idx, ndim, output_shape, output_strides); + + launch_op( + idx, + out_idx, + ndim, + input_contiguous, + input_broadcasted, + input_shapes, + input_strides, + inputs, + output, + std::index_sequence_for{}); +} + +struct DeviceImpl::Opaque { + std::shared_ptr internal; + + Opaque(const std::shared_ptr &internal) + : internal(internal) {} + + /** + * @brief Performs elementwise operations when all inputs and the output share the same data type. + * + * @tparam BLOCK_SIZE The block size for the kernel launch. + * @tparam N The number of input tensors. + * @tparam Op The operation to perform (e.g., addition, multiplication). + * @tparam Tdata The data type of the input and output tensors. + * @tparam Args Additional arguments to be passed to the operation. + * @param info Structure containing elementwise operation information (size, shape, etc.). + * @param output Pointer to the output memory where results will be stored. + * @param inputs Vector of pointers to input tensors. + * @param stream CUDA stream used for asynchronous execution. + * @param args Additional arguments for the operation. 
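+     *
+     * Illustrative dispatch (a hedged sketch, not part of this patch): DeviceImpl::calculate
+     * further down in this header forwards to this method roughly as
+     *
+     *   _opaque->calculateImpl<256, 2, SwiGLUOp, half>(
+     *       info, output, inputs, std::make_index_sequence<2>{}, stream);
+     *
+     * where SwiGLUOp is the two-input operator ported in this change and 256 is the block size.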
+ */ + template + void calculateImpl(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + std::index_sequence, + cudaStream_t stream, + Args &&...args) { + if (info.output_size == 0) { + return; + } + + // casting the output and the inputs to Tdata pointers + Tdata *out = reinterpret_cast(output); + const Tdata *inputs_arr[N]; + const Tdata **d_inputs_arr = nullptr; + for (size_t i = 0; i < N; ++i) { + inputs_arr[i] = reinterpret_cast(inputs[i]); + } + cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream); + cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream); + + // create and send the info to device + const bool *d_bools = nullptr; + const bool *d_input_contiguous = nullptr; + const bool *d_input_broadcasted = nullptr; + const int8_t *d_output_shape_strides = nullptr; + const size_t *d_output_shape = nullptr; + const ptrdiff_t *d_output_strides = nullptr; + const size_t **d_input_shapes = nullptr; + const ptrdiff_t **d_input_strides = nullptr; + std::vector tmp_device_ptrs(info.input_size); + std::vector tmp_device_ptrs_strides(info.input_size); + + infoToDevice(info, d_bools, d_input_contiguous, + d_input_broadcasted, d_output_shape_strides, d_output_shape, + d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, + d_input_strides, stream); + + dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); + dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast(internal->gridSizeX()))); + size_t step = gridDims.x * blockDims.x; + + for (size_t i = 0; i < info.output_size; i += step) { + elementwise_kernel<<>>( + info.output_size, + info.ndim, + info.output_contiguous, + d_input_contiguous, + d_input_broadcasted, + d_output_shape, + d_input_shapes, + d_output_strides, + d_input_strides, + info.input_size, out, d_inputs_arr, i, std::forward(args)...); + } + + freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream); + } + + /** + * @brief Performs elementwise operations when inputs and the outputs have mixed data types (i.e., different dtypes). + * + * @tparam BLOCK_SIZE The block size for the kernel launch. + * @tparam N The number of input tensors. + * @tparam Op The operation to perform (e.g., addition, multiplication). + * @tparam Tout The output data type. + * @tparam Tin The input data types. + * @tparam Args Additional arguments to be passed to the operation. + * @param info Structure containing elementwise operation information (size, shape, etc.). + * @param output Pointer to the output memory where results will be stored. + * @param inputs Vector of pointers to input tensors. + * @param stream CUDA stream used for asynchronous execution. + * @param args Additional arguments for the operation. 
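+     *
+     * Hedged sketch of the mixed-type dispatch (SomeMixedOp and the concrete types below are
+     * illustrative placeholders, not part of this patch); Op must expose the templated
+     * operator()<Tout, Tin...> that launch_op invokes, e.g. a half input combined with a float
+     * input producing a float output:
+     *
+     *   _opaque->calculateImpl<256, 2, SomeMixedOp, float, half, float>(
+     *       info, output, inputs, std::make_index_sequence<2>{}, stream);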
+ */ + template = 0> + void calculateImpl(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + std::index_sequence, + cudaStream_t stream, + Args &&...args) { + if (info.output_size == 0) { + return; + } + + Tout *out = reinterpret_cast(output); + + // Store input pointers with the correct types + const std::tuple inputs_arr{reinterpret_cast(inputs[Is])...}; + const void **d_inputs_arr = nullptr; + + // Create array of input pointers on host (void*) to copy to device + const void *host_input_ptrs[] = {reinterpret_cast(std::get(inputs_arr))...}; + cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream); + cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream); + + // Device pointers + const bool *d_bools = nullptr; + const bool *d_input_contiguous = nullptr; + const bool *d_input_broadcasted = nullptr; + const int8_t *d_output_shape_strides = nullptr; + const size_t *d_output_shape = nullptr; + const ptrdiff_t *d_output_strides = nullptr; + const size_t **d_input_shapes = nullptr; + const ptrdiff_t **d_input_strides = nullptr; + std::vector tmp_device_ptrs(info.input_size); + std::vector tmp_device_ptrs_strides(info.input_size); + + infoToDevice(info, d_bools, d_input_contiguous, + d_input_broadcasted, d_output_shape_strides, d_output_shape, + d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, + d_input_strides, stream); + + dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); + dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast(internal->gridSizeX()))); + size_t step = gridDims.x * blockDims.x; + + for (size_t i = 0; i < info.output_size; i += step) { + elementwise_kernel<<>>( + info.output_size, + info.ndim, + info.output_contiguous, + d_input_contiguous, + d_input_broadcasted, + d_output_shape, + d_input_shapes, + d_output_strides, + d_input_strides, + info.input_size, out, reinterpret_cast(d_inputs_arr), i); + } + + freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream); + } + +private: + /** + * @brief Transfers elementwise kernel metadata (shapes, strides, flags) from host to device. + * + * @tparam N Number of inputs. + * @param info Structure containing input/output metadata. + * @param d_bools Device pointer for input_contiguous and input_broadcasted flags. + * @param d_input_contiguous Device pointer to input contiguity flags. + * @param d_input_broadcasted Device pointer to input broadcasting flags. + * @param d_output_shape_strides Device buffer containing both output shape and strides. + * @param d_output_shape Device pointer to output shape. + * @param d_output_strides Device pointer to output strides. + * @param tmp_device_ptrs Temporary device pointers for input shapes. + * @param d_input_shapes Device array of pointers to input shapes. + * @param tmp_device_ptrs_strides Temporary device pointers for input strides. + * @param d_input_strides Device array of pointers to input strides. + * @param stream CUDA stream for async allocation and transfers. + * @return infiniStatus_t Status indicating success or failure. 
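+     *
+     * Layout of the packed transfers below: d_bools is a single allocation holding the
+     * input_size contiguity flags followed by the input_size broadcast flags, and
+     * d_output_shape_strides is a single byte buffer holding the output shape followed by the
+     * output strides (the strides start at byte offset ndim * sizeof(size_t)).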
+ */ + template + infiniStatus_t infoToDevice( + const op::elementwise::ElementwiseInfo &info, + const bool *&d_bools, + const bool *&d_input_contiguous, + const bool *&d_input_broadcasted, + const int8_t *&d_output_shape_strides, + const size_t *&d_output_shape, + const ptrdiff_t *&d_output_strides, + std::vector &tmp_device_ptrs, + const size_t **&d_input_shapes, + std::vector &tmp_device_ptrs_strides, + const ptrdiff_t **&d_input_strides, + cudaStream_t stream) const { + + cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream); + cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream); + + cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream); + cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream); + + cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream); + for (size_t i = 0; i < info.input_size; ++i) { + cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream); + cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i], + info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream); + } + cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(), + info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream); + + cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream); + for (size_t i = 0; i < info.input_size; ++i) { + cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream); + cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i], + info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream); + } + cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(), + info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream); + + d_input_contiguous = d_bools; + d_input_broadcasted = d_bools + info.input_size; + d_output_shape = reinterpret_cast(d_output_shape_strides); + d_output_strides = reinterpret_cast(d_output_shape_strides + info.ndim * sizeof(size_t)); + + return INFINI_STATUS_SUCCESS; + } + + /** + * @brief Frees all device-allocated memory used for metadata in elementwise kernel execution. + * + * @param d_inputs_arr Device array of input pointers. + * @param d_bools Device memory holding input flags. + * @param d_output_shape_strides Device buffer holding output shape and strides. + * @param input_size Number of input tensors. + * @param d_input_shapes Device array of input shape pointers. + * @param d_input_strides Device array of input stride pointers. + * @param stream CUDA stream for async deallocation. + * @return infiniStatus_t Status indicating success or failure. 
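+     *
+     * Note: this releases the five top-level device buffers passed in; the per-input shape and
+     * stride arrays that infoToDevice allocated into tmp_device_ptrs and tmp_device_ptrs_strides
+     * are not freed by this helper.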
+ */ + inline infiniStatus_t freeAllDevice(const void **d_inputs_arr, + const bool *d_bools, + const int8_t *d_output_shape_strides, + const size_t input_size, + const size_t **d_input_shapes, + const ptrdiff_t **d_input_strides, + cudaStream_t stream) const { + + cudaFreeAsync((void *)d_inputs_arr, stream); + cudaFreeAsync((void *)d_bools, stream); + cudaFreeAsync((void *)d_output_shape_strides, stream); + cudaFreeAsync((void *)d_input_shapes, stream); + cudaFreeAsync((void *)d_input_strides, stream); + return INFINI_STATUS_SUCCESS; + } +}; + +template +infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, + Args &&...args) { + auto opaque = std::make_shared(std::forward(args)...); + *device_info = new DeviceImpl(opaque); + return INFINI_STATUS_SUCCESS; +} + +/** + * @brief Launches elementwise operation where input types may differ. + * + * Dispatches to templated `calculateImpl` using specified output and input types. + * + * @tparam BLOCK_SIZE Number of threads per block. + * @tparam Op Operation functor defining the computation. + * @tparam Tout Output data type. + * @tparam Tin... Input data types (must match Op::num_inputs). + * @tparam Args... Additional arguments passed to the operation. + * @param info Metadata describing tensor shapes, strides, etc. + * @param output Pointer to output buffer on device. + * @param inputs Vector of input pointers (device memory). + * @param stream CUDA stream (opaque void*). + * @param args (UNUSED) Additional operation-specific arguments. + */ +template > +void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { + constexpr size_t N = Op::num_inputs; + static_assert(sizeof...(Tin) == N, "Input type count mismatch"); + _opaque->calculateImpl( + info, output, inputs, + std::make_index_sequence{}, + reinterpret_cast(stream), + std::forward(args)...); +} + +/** + * @brief Launches elementwise operation where all input types are the same. + * + * Calls the corresponding templated `calculateImpl` with a unified input type. + * + * @tparam BLOCK_SIZE Number of threads per block. + * @tparam Op Operation functor defining the computation. + * @tparam Tdata Data type for both input and output tensors. + * @tparam Args... Additional arguments passed to the operation. + * @param info Metadata describing tensor shapes, strides, etc. + * @param output Pointer to output buffer on device. + * @param inputs Vector of input pointers (device memory). + * @param stream CUDA stream (opaque void*). + * @param args Additional operation-specific arguments. 
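+ *
+ * Example taken from the swiglu port in this patch (swiglu_cuda.cu), reproduced here only for
+ * orientation:
+ *
+ *   _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream);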
+ */ +template +void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { + constexpr size_t N = Op::num_inputs; + _opaque->calculateImpl( + info, output, inputs, + std::make_index_sequence{}, + reinterpret_cast(stream), + std::forward(args)...); + cudaStreamSynchronize(reinterpret_cast(stream)); +} + +} // namespace op::elementwise::cuda + +#endif // __INFINIOP_ELEMENTWISE_CUDA_H__ diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh new file mode 100644 index 000000000..6e80f5a03 --- /dev/null +++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh @@ -0,0 +1,66 @@ +#ifndef __INFINIOP_ELEMENTWISE_CUDA_API_H__ +#define __INFINIOP_ELEMENTWISE_CUDA_API_H__ + +#include "../elementwise.h" + +namespace op::elementwise::cuda { + +/** + * @brief Define the methods and info needed by CUDA to perform elementwise operation + */ +class DeviceImpl final { + struct Opaque; + std::shared_ptr _opaque; + + DeviceImpl(std::shared_ptr opaque) : _opaque(std::move(opaque)) {} + +public: + ~DeviceImpl() = default; + + template + static infiniStatus_t create( + DeviceImpl **device_info, + Args &&...args); + + /* Invoke elementwise operation when all inputs have the same dtype */ + template + void calculate( + const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args); + + /* Invoke elementwise operation for different input types */ + template = 0> + void calculate( + const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args); +}; +} // namespace op::elementwise::cuda + +/** + * @brief Define the process for initializing a Descriptor of an elementwise operation + * for its CUDA implementation + */ +#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \ + \ + op::elementwise::ElementwiseInfo elementwise_info; \ + CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \ + \ + op::elementwise::cuda::DeviceImpl *device_impl; \ + CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \ + \ + *desc_ptr = new Descriptor( \ + dtype, \ + std::move(elementwise_info), \ + device_impl, \ + handle->device, \ + handle->device_id); + +#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__ diff --git a/src/infiniop/elementwise/elementwise.h b/src/infiniop/elementwise/elementwise.h index ae3e258c5..8be74326d 100644 --- a/src/infiniop/elementwise/elementwise.h +++ b/src/infiniop/elementwise/elementwise.h @@ -4,47 +4,12 @@ #include "../operator.h" #include "../tensor.h" #include +#include +#include #include #include #include -#define DEVICE_IMPL(NAMESPACE) \ - \ - namespace op::elementwise::NAMESPACE { \ - class DeviceImpl final { \ - struct Opaque; \ - std::unique_ptr _opaque; \ - \ - DeviceImpl(Opaque *opaque) : _opaque(opaque) {} \ - \ - public: \ - ~DeviceImpl() = default; \ - \ - template \ - static infiniStatus_t create( \ - DeviceImpl **device_info, \ - Args &&...args); \ - \ - /* Invoke elementwise operation when all inputs have the same type */ \ - template \ - void calculate( \ - const op::elementwise::ElementwiseInfo &info, \ - void *output, \ - const std::vector &inputs, \ - Args &&...args); \ - \ - /* Invoke elementwise operation for different input types */ \ - template = 0> \ - void calculate( \ - const op::elementwise::ElementwiseInfo &info, \ - void *output, 
\ - const std::vector &inputs, \ - Args &&...args); \ - }; \ - } - #define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \ \ namespace op::OP::NAMESPACE { \ @@ -61,7 +26,7 @@ int device_id) \ : InfiniopDescriptor{device_type, device_id}, \ _dtype(dtype), \ - _info(info), \ + _info(std::move(info)), \ _device_info(device_info) {} \ \ public: \ @@ -87,12 +52,84 @@ struct ElementwiseInfo { size_t output_size; size_t ndim; bool output_contiguous; - std::vector input_contiguous; - std::vector input_broadcasted; - std::vector output_shape; - std::vector> input_shapes; - std::vector output_strides; - std::vector> input_strides; + bool *input_contiguous; + bool *input_broadcasted; + size_t *output_shape; + size_t **input_shapes; + ptrdiff_t *output_strides; + ptrdiff_t **input_strides; + size_t input_size; + + ElementwiseInfo() = default; + + // Destructor to free allocated memory + ~ElementwiseInfo() { + delete[] input_contiguous; + delete[] input_broadcasted; + delete[] output_shape; + delete[] output_strides; + + for (size_t i = 0; i < input_size; ++i) { + delete[] input_shapes[i]; + delete[] input_strides[i]; + } + delete[] input_shapes; + delete[] input_strides; + } + + ElementwiseInfo(const ElementwiseInfo &other) + : output_size(other.output_size), + ndim(other.ndim), + output_contiguous(other.output_contiguous), + input_size(other.input_size) { + + input_contiguous = new bool[input_size]; + std::memcpy(input_contiguous, other.input_contiguous, input_size * sizeof(*input_contiguous)); + + input_broadcasted = new bool[input_size]; + std::memcpy(input_broadcasted, other.input_broadcasted, input_size * sizeof(*input_broadcasted)); + + output_shape = new size_t[ndim]; + std::memcpy(output_shape, other.output_shape, ndim * sizeof(*output_shape)); + + output_strides = new ptrdiff_t[ndim]; + std::memcpy(output_strides, other.output_strides, ndim * sizeof(*output_strides)); + + input_shapes = new size_t *[input_size]; + for (size_t i = 0; i < input_size; ++i) { + input_shapes[i] = new size_t[ndim]; + std::memcpy(input_shapes[i], other.input_shapes[i], ndim * sizeof(*input_shapes[i])); + } + + input_strides = new ptrdiff_t *[input_size]; + for (size_t i = 0; i < input_size; ++i) { + input_strides[i] = new ptrdiff_t[ndim]; + std::memcpy(input_strides[i], other.input_strides[i], ndim * sizeof(*input_strides[i])); + } + } + + ElementwiseInfo(ElementwiseInfo &&other) noexcept + : output_size(other.output_size), + ndim(other.ndim), + output_contiguous(other.output_contiguous), + input_contiguous(other.input_contiguous), + input_broadcasted(other.input_broadcasted), + output_shape(other.output_shape), + input_shapes(other.input_shapes), + output_strides(other.output_strides), + input_strides(other.input_strides), + input_size(other.input_size) { + other.input_contiguous = nullptr; + other.input_broadcasted = nullptr; + other.output_shape = nullptr; + other.input_shapes = nullptr; + other.output_strides = nullptr; + other.input_strides = nullptr; + other.input_size = 0; + } + + ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete; + ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete; }; inline infiniStatus_t createElementwiseInfo( @@ -109,28 +146,37 @@ inline infiniStatus_t createElementwiseInfo( return INFINI_STATUS_BAD_TENSOR_STRIDES; } - const size_t input_size = input_descs.size(); - const size_t out_ndim = output_desc->ndim(); - - // Intializing the ElementwiseInfo struct + info.input_size = input_descs.size(); + info.ndim = output_desc->ndim(); info.output_size = output_desc->numel(); 
- info.ndim = out_ndim; info.output_contiguous = output_desc->isContiguous(); - for (const auto &desc : input_descs) { - info.input_contiguous.emplace_back(desc->isContiguous()); - } - - for (size_t i = 0; i < input_size; ++i) { - const auto &desc = input_descs[i]; - info.input_broadcasted.emplace_back(!info.input_contiguous[i] && (desc->ndim() != out_ndim || desc->hasBroadcastDim())); - } - - info.output_shape = std::move(output_desc->shape()); - info.output_strides = std::move(output_desc->strides()); - for (const auto &desc : input_descs) { - info.input_shapes.emplace_back(desc->shape()); - info.input_strides.emplace_back(desc->strides()); + // Allocate memory for arrays + info.input_contiguous = new bool[info.input_size]; + info.input_broadcasted = new bool[info.input_size]; + info.output_shape = new size_t[info.ndim]; + info.output_strides = new ptrdiff_t[info.ndim]; + info.input_shapes = new size_t *[info.input_size]; + info.input_strides = new ptrdiff_t *[info.input_size]; + + // Fill arrays + const auto output_shape = output_desc->shape(); + const auto output_strides = output_desc->strides(); + std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape)); + std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides)); + + for (size_t i = 0; i < info.input_size; ++i) { + auto &desc = input_descs[i]; + info.input_contiguous[i] = desc->isContiguous(); + info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim()); + + info.input_shapes[i] = new size_t[desc->ndim()]; + const auto &in_shape = desc->shape(); + std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i])); + + info.input_strides[i] = new ptrdiff_t[desc->ndim()]; + const auto &in_strides = desc->strides(); + std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i])); } return INFINI_STATUS_SUCCESS; diff --git a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc index 9eb470aa7..eb64c75b4 100644 --- a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc +++ b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc @@ -8,11 +8,13 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle_, Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t up_desc, - infiniopTensorDescriptor_t gate_desc) { + std::vector input_desc) { auto handle = reinterpret_cast(handle_); auto dtype = out_desc->dtype(); + + const auto &up_desc = input_desc.at(0); + const auto &gate_desc = input_desc.at(1); const auto &out_shape = out_desc->shape(); const auto &up_shape = up_desc->shape(); const auto &gate_shape = gate_desc->shape(); @@ -21,35 +23,26 @@ infiniStatus_t Descriptor::create( CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape); - op::binary::BinaryInfo info; - CHECK_STATUS(op::binary::createBinaryInfo(info, out_desc, up_desc, gate_desc)); - - // Create descriptor - *desc_ptr = new Descriptor( - dtype, - std::move(info), - nullptr, - handle->device, - handle->device_id); + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR; return INFINI_STATUS_SUCCESS; } infiniStatus_t Descriptor::calculate( - void *c, - const void *a, - const void *b, + void *output, + std::vector inputs, void *stream) const { switch (_dtype) { case INFINI_DTYPE_F16: - op::common_cpu::binary_op::calculate(_info, c, a, b); + _device_info->calculate(_info, output, inputs, stream); break; case INFINI_DTYPE_F32: - 
op::common_cpu::binary_op::calculate(_info, c, a, b); + _device_info->calculate(_info, output, inputs, stream); break; case INFINI_DTYPE_F64: - op::common_cpu::binary_op::calculate(_info, c, a, b); + _device_info->calculate(_info, output, inputs, stream); break; default: return INFINI_STATUS_BAD_TENSOR_DTYPE; diff --git a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h index ac1eba6f1..67e42d2c6 100644 --- a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h +++ b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h @@ -1,11 +1,12 @@ #ifndef __SWIGLU_CPU_H__ #define __SWIGLU_CPU_H__ -#include "../../../binary/cpu/binary_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu.h" -BINARY_DESCRIPTOR(swiglu, cpu) +ELEMENTWISE_DESCRIPTOR(swiglu, cpu) -struct SwiGLUOp { +namespace op::swiglu::cpu { +typedef struct SwiGLUOp { private: template T sigmoid(const T &x) const { @@ -13,10 +14,12 @@ struct SwiGLUOp { } public: + static constexpr size_t num_inputs = 2; template T operator()(const T &up, const T &gate) const { return gate * sigmoid(gate) * up; } -}; +} SwiGLUOp; +} // namespace op::swiglu::cpu #endif // __SWIGLU_CPU_H__ diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu new file mode 100644 index 000000000..79f3572ff --- /dev/null +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu @@ -0,0 +1,55 @@ +#include "swiglu_cuda.cuh" +#include "swiglu_cuda_internal.cuh" + +namespace op::swiglu::cuda { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &up_desc = input_desc.at(0); + const auto &gate_desc = input_desc.at(1); + const auto &out_shape = out_desc->shape(); + const auto &up_shape = up_desc->shape(); + const auto &gate_shape = gate_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + if (!SAME_VEC(out_shape, up_shape, gate_shape)) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_F16: + _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream); + break; + case INFINI_DTYPE_F32: + _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream); + break; + case INFINI_DTYPE_F64: + _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::swiglu::cuda diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cuh b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cuh new file mode 100644 index 000000000..75e529ab1 --- /dev/null +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cuh @@ -0,0 +1,8 @@ +#ifndef __SWIGLU_CUDA_API_H__ +#define __SWIGLU_CUDA_API_H__ + +#include "../../../elementwise/cuda/elementwise_cuda_api.cuh" + +ELEMENTWISE_DESCRIPTOR(swiglu, cuda) + +#endif // __SWIGLU_CUDA_API_H__ diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh b/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh new file mode 100644 index 000000000..5d44a4bb7 --- /dev/null +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh @@ 
-0,0 +1,59 @@ +#ifndef __SWIGLU_CUDA_H__ +#define __SWIGLU_CUDA_H__ + +#include "../../../elementwise/cuda/elementwise_cuda.cuh" +#include + +namespace op::swiglu::cuda { +typedef struct SwiGLUOp { +private: + template + __device__ __forceinline__ T sigmoid(const T &x) const { + if constexpr (std::is_same_v) { + return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x)))); + } else if constexpr (std::is_same_v) { + return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x)))))); + } else if constexpr (std::is_same_v) { + return __frcp_rd(__fadd_rd(1, __expf(-x))); + } else { + return 1 / (1 + std::exp(-x)); + } + } + +public: + static constexpr size_t num_inputs = 2; + template + __device__ __forceinline__ T operator()(const T &up, const T &gate) const { + if constexpr (std::is_same_v) { + return __hmul2(__hmul2(gate, sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + return __hmul(__hmul(gate, sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up); + } else { + return gate * sigmoid(gate) * up; + } + } + + template + __device__ __forceinline__ Tc operator()(const Ta &up, const Tb &gate) const { + if constexpr (std::is_same_v) { + return __hmul2(__hmul2(gate, sigmoid(gate)), up); + } else if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (std::is_same_v) { + return __float2half(__fmul_rd(__fmul_rd(gate, sigmoid(gate)), __half2float(up))); + } else { + return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), __half2float(up)); + } + } else if constexpr (std::is_same_v) { + return __hmul(__hmul(gate, sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up); + } else { + return gate * sigmoid(gate) * up; + } + } +} SwiGLUOp; +} // namespace op::swiglu::cuda + +#endif // __SWIGLU_CUDA_H__ diff --git a/src/infiniop/ops/swiglu/operator.cc b/src/infiniop/ops/swiglu/operator.cc index 80be80bfd..de3ecb874 100644 --- a/src/infiniop/ops/swiglu/operator.cc +++ b/src/infiniop/ops/swiglu/operator.cc @@ -5,6 +5,9 @@ #ifdef ENABLE_CPU_API #include "cpu/swiglu_cpu.h" #endif +#ifdef ENABLE_CUDA_API +#include "cuda/swiglu_cuda.cuh" +#endif __C infiniStatus_t infiniopCreateSwiGLUDescriptor( infiniopHandle_t handle, @@ -19,19 +22,16 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( handle, \ reinterpret_cast(desc_ptr), \ c_desc, \ - a_desc, \ - b_desc) + {a_desc, \ + b_desc}) switch (handle->device) { #ifdef ENABLE_CPU_API CREATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: - return cudaCreateSwiGLUDescriptor((CudaHandle_t)handle, - (SwiGLUCudaDescriptor_t *)desc_ptr, - c_desc, a_desc, b_desc); +#ifdef ENABLE_CUDA_API + CREATE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { @@ -76,16 +76,15 @@ __C infiniStatus_t infiniopSwiGLU( #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast(desc) \ - ->calculate(c, a, b, stream) + ->calculate(c, {a, b}, stream) switch (desc->device_type) { #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: - return cudaSwiGLU((SwiGLUCudaDescriptor_t)desc, c, a, b, stream); +#ifdef ENABLE_CUDA_API + CALCULATE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { @@ -125,9 +124,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { #ifdef ENABLE_CPU_API DELETE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: - return 
cudaDestroySwiGLUDescriptor((SwiGLUCudaDescriptor_t)desc); +#ifdef ENABLE_CUDA_API + DELETE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { diff --git a/src/utils.h b/src/utils.h index 13a8c78a1..fa5469584 100644 --- a/src/utils.h +++ b/src/utils.h @@ -98,4 +98,6 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) { } } +#define CEIL_DIV(x, y) ((x + y - 1) / y) + #endif diff --git a/xmake/cuda.lua b/xmake/cuda.lua index 7c89c64e3..0d7ccfdae 100644 --- a/xmake/cuda.lua +++ b/xmake/cuda.lua @@ -28,6 +28,7 @@ target("infiniop-cuda") else add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror") add_cuflags("-Xcompiler=-fPIC") + add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") add_cxxflags("-fPIC") end From 7a8f2bca3240816bcbd08cc764ff93897cbc2849 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 8 Apr 2025 14:17:54 +0800 Subject: [PATCH 03/14] issue/127: modify swiglu test to correctly handle broadcast scenarios, add two broadcast testcases, correct elementwise cpu mix-precision implementation --- .../elementwise/cpu/elementwise_cpu.h | 2 +- test/infiniop/swiglu.py | 44 +++++++++++++++---- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h index 23e5873c6..a7357d568 100644 --- a/src/infiniop/elementwise/cpu/elementwise_cpu.h +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -84,7 +84,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); }; - out[out_idx] = utils::cast(Op{}(std::get(input_ptrs)[get_input_idx(Is)]..., std::forward(args)...)); + out[out_idx] = utils::cast(Op{}.template operator()(std::get(input_ptrs)[get_input_idx(Is)]..., std::forward(args)...)); } } diff --git a/test/infiniop/swiglu.py b/test/infiniop/swiglu.py index 1e145692a..09649af87 100644 --- a/test/infiniop/swiglu.py +++ b/test/infiniop/swiglu.py @@ -25,8 +25,10 @@ # shape, a_stride, b_stride, c_stride ((13, 4), None, None, None), ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), ((13, 4, 4), None, None, None), ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), ((16, 5632), None, None, None), ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), ((4, 4, 5632), None, None, None), @@ -76,6 +78,38 @@ class SwiGLUDescriptor(Structure): def swiglu(a, b): return a * b / (1 + torch.exp(-b.float()).to(b.dtype)) + + + +def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace): + """ + rearrange the tensors if needed and apply the inplace config. 
+ If inplace is requested and the output (i.e., c) would be placed onto a broadcasted input, + the inplace config is ignored and out-of-place computation is used instead. + """ + original_c_strides = c_strides if c_strides else c.stride() + + def _rearrange(tensor, strides): + if strides and 0 in strides: + tensor.set_(tensor.untyped_storage(), 0, tensor.shape, strides) + return tensor + else: + return rearrange_if_needed(tensor, strides) + + a, b, c = [ + _rearrange(tensor, stride) + for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_strides]) + ] + c = ( + c + if inplace == Inplace.OUT_OF_PLACE + else (a if inplace == Inplace.INPLACE_A else b) + ) + # if inplace is used and c has broadcasted strides, reset it to the original unbroadcasted strides + if 0 in c.stride(): + c.set_(c.untyped_storage(), 0, c.shape, original_c_strides) + + return a, b, c def test( @@ -98,18 +132,10 @@ def test( a = torch.rand(shape, dtype=dtype).to(torch_device) b = torch.rand(shape, dtype=dtype).to(torch_device) c = torch.rand(shape, dtype=dtype).to(torch_device) + a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace) ans = swiglu(a, b) - a, b, c = [ - rearrange_if_needed(tensor, stride) - for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride]) - ] - c = ( - c - if inplace == Inplace.OUT_OF_PLACE - else (a if inplace == Inplace.INPLACE_A else b) - ) a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]] c_tensor = ( to_tensor(c, lib) From a283a8fa5b6346ff9ac210c1568ecd08a97943c1 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 8 Apr 2025 16:04:38 +0800 Subject: [PATCH 04/14] issue/127: refactor ElementwiseInfo to use utils::Result, change elementwise calculate and calculateImpl to return infiniStatus_t, add CHECK_CUDA to cuda function calls --- .../elementwise/cpu/elementwise_cpu.h | 78 ++++++--- .../elementwise/cuda/elementwise_cuda.cuh | 135 +++++++-------- .../elementwise/cuda/elementwise_cuda_api.cuh | 30 ++-- src/infiniop/elementwise/elementwise.h | 154 ++++++++---------- src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc | 9 +- src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu | 9 +- 6 files changed, 222 insertions(+), 193 deletions(-) diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h index a7357d568..2fcea5625 100644 --- a/src/infiniop/elementwise/cpu/elementwise_cpu.h +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -9,20 +9,27 @@ * @brief Define the process for initializing a Descriptor of an elementwise operation * for its CPU implementation */ -#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \ - \ - op::elementwise::ElementwiseInfo elementwise_info; \ - CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \ - \ - *desc_ptr = new Descriptor( \ - dtype, \ - std::move(elementwise_info), \ - nullptr, \ - handle->device, \ +#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \ + \ + auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \ + CHECK_RESULT(info_result); \ + \ + *desc_ptr = new Descriptor( \ + dtype, \ + std::move(info_result.take()), \ + nullptr, \ + handle->device, \ handle->device_id); namespace op::elementwise::cpu { +/** + * @brief CPU-specific device implementation for resource management and + * calculation implementations. + * + * This class encapsulates device-specific behavior and execution logic. + * Use the static create() method to instantiate a DeviceImpl.
+ */ class DeviceImpl final { struct Opaque; std::shared_ptr _opaque; @@ -37,20 +44,48 @@ class DeviceImpl final { DeviceImpl **device_info, Args &&...args); - /* Invoke elementwise operation when all inputs have the same type */ + /** + * @brief Dispatches an elementwise operation with uniform input types. + * + * @tparam Op The elementwise operation to perform. + * @tparam Tdata The common data type of all inputs and output. + * @tparam Args Additional backend-specific arguments. + * @param info Precomputed tensor metadata (shapes, strides, etc.). + * @param output Pointer to the output tensor buffer. + * @param inputs Vector of input tensor data pointers. + * @param stream Device execution stream. + * @param args Additional backend-specific arguments. + * @return infiniStatus_t Status indicating success or failure. + */ template - void calculate( + infiniStatus_t calculate( const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, void *stream, Args &&...args); - /* Invoke elementwise operation for different input types */ + /** + * @brief Dispatches an elementwise operation with heterogeneous input types. + * + * Supports operations where each input may have a different type, as defined by Op. + * The number of input types must match the operation's expected input count. + * + * @tparam Op The elementwise operation to perform. + * @tparam Tout Output data type. + * @tparam Tin Variadic input data types. + * @tparam Args Additional backend-specific arguments. + * @param info Precomputed tensor metadata (shapes, strides, etc.). + * @param output Pointer to the output tensor buffer. + * @param inputs Vector of input tensor data pointers. + * @param stream Device execution stream. + * @param args Additional backend-specific arguments. + * @return infiniStatus_t Status indicating success or failure. 
+ */ template = 0> - void calculate( + infiniStatus_t calculate( const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, @@ -58,6 +93,7 @@ class DeviceImpl final { Args &&...args); }; +// Define the Opaque struct for CPU, which is empty struct DeviceImpl::Opaque {}; template @@ -90,14 +126,15 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, // Invoke elementwise operation for different input types template = 0> -void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - void *stream, - Args &&...args) { +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch"); calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); + return INFINI_STATUS_SUCCESS; } // Perform elementwise operation when all inputs have the same type @@ -133,9 +170,10 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, // Invoke elementwise operation when all inputs have the same type template -void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, void *stream, Args &&...args) { +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, void *stream, Args &&...args) { constexpr size_t N = Op::num_inputs; calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); + return INFINI_STATUS_SUCCESS; } } // namespace op::elementwise::cpu diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh index cfd0461ca..08a2c4fad 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -4,6 +4,7 @@ #include "../../../utils.h" #include "../../devices/cuda/cuda_common.cuh" #include "elementwise_cuda_api.cuh" + namespace op::elementwise::cuda { /** @@ -223,16 +224,17 @@ struct DeviceImpl::Opaque { * @param inputs Vector of pointers to input tensors. * @param stream CUDA stream used for asynchronous execution. * @param args Additional arguments for the operation. + * @return infiniStatus_t Status indicating success or failure. 
*/ template - void calculateImpl(const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - std::index_sequence, - cudaStream_t stream, - Args &&...args) { + infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + std::index_sequence, + cudaStream_t stream, + Args &&...args) { if (info.output_size == 0) { - return; + return INFINI_STATUS_SUCCESS; } // casting the output and the inputs to Tdata pointers @@ -242,8 +244,8 @@ struct DeviceImpl::Opaque { for (size_t i = 0; i < N; ++i) { inputs_arr[i] = reinterpret_cast(inputs[i]); } - cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream); - cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream); + CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream)); + CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream)); // create and send the info to device const bool *d_bools = nullptr; @@ -257,10 +259,10 @@ struct DeviceImpl::Opaque { std::vector tmp_device_ptrs(info.input_size); std::vector tmp_device_ptrs_strides(info.input_size); - infoToDevice(info, d_bools, d_input_contiguous, - d_input_broadcasted, d_output_shape_strides, d_output_shape, - d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, - d_input_strides, stream); + CHECK_STATUS(infoToDevice(info, d_bools, d_input_contiguous, + d_input_broadcasted, d_output_shape_strides, d_output_shape, + d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, + d_input_strides, stream)); dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast(internal->gridSizeX()))); @@ -280,7 +282,9 @@ struct DeviceImpl::Opaque { info.input_size, out, d_inputs_arr, i, std::forward(args)...); } - freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream); + CHECK_STATUS(freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides, + info.input_size, d_input_shapes, d_input_strides, stream)); + return INFINI_STATUS_SUCCESS; } /** @@ -297,17 +301,18 @@ struct DeviceImpl::Opaque { * @param inputs Vector of pointers to input tensors. * @param stream CUDA stream used for asynchronous execution. * @param args Additional arguments for the operation. + * @return infiniStatus_t Status indicating success or failure. 
*/ template = 0> - void calculateImpl(const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - std::index_sequence, - cudaStream_t stream, - Args &&...args) { + infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + std::index_sequence, + cudaStream_t stream, + Args &&...args) { if (info.output_size == 0) { - return; + return INFINI_STATUS_SUCCESS; } Tout *out = reinterpret_cast(output); @@ -318,8 +323,8 @@ struct DeviceImpl::Opaque { // Create array of input pointers on host (void*) to copy to device const void *host_input_ptrs[] = {reinterpret_cast(std::get(inputs_arr))...}; - cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream); - cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream); + CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream)); + CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream)); // Device pointers const bool *d_bools = nullptr; @@ -333,10 +338,10 @@ struct DeviceImpl::Opaque { std::vector tmp_device_ptrs(info.input_size); std::vector tmp_device_ptrs_strides(info.input_size); - infoToDevice(info, d_bools, d_input_contiguous, - d_input_broadcasted, d_output_shape_strides, d_output_shape, - d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, - d_input_strides, stream); + CHECK_STATUS(infoToDevice(info, d_bools, d_input_contiguous, + d_input_broadcasted, d_output_shape_strides, d_output_shape, + d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, + d_input_strides, stream)); dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast(internal->gridSizeX()))); @@ -356,7 +361,8 @@ struct DeviceImpl::Opaque { info.input_size, out, reinterpret_cast(d_inputs_arr), i); } - freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream); + CHECK_STATUS(freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream)); + return INFINI_STATUS_SUCCESS; } private: @@ -393,31 +399,31 @@ private: const ptrdiff_t **&d_input_strides, cudaStream_t stream) const { - cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream); - cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream); + CHECK_CUDA(cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream)); - cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream); - cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream); + 
CHECK_CUDA(cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream)); - cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream); + CHECK_CUDA(cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream)); for (size_t i = 0; i < info.input_size; ++i) { - cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream); - cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i], - info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream); + CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i], + info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream)); } - cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(), - info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream); + CHECK_CUDA(cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(), + info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream)); - cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream); + CHECK_CUDA(cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream)); for (size_t i = 0; i < info.input_size; ++i) { - cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream); - cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i], - info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream); + CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i], + info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream)); } - cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(), - info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream); + CHECK_CUDA(cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(), + info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream)); d_input_contiguous = d_bools; d_input_broadcasted = d_bools + info.input_size; @@ -447,11 +453,11 @@ private: const ptrdiff_t **d_input_strides, cudaStream_t stream) const { - cudaFreeAsync((void *)d_inputs_arr, stream); - cudaFreeAsync((void *)d_bools, stream); - cudaFreeAsync((void *)d_output_shape_strides, stream); - cudaFreeAsync((void *)d_input_shapes, stream); - cudaFreeAsync((void *)d_input_strides, stream); + CHECK_CUDA(cudaFreeAsync((void *)d_inputs_arr, stream)); + CHECK_CUDA(cudaFreeAsync((void *)d_bools, stream)); + CHECK_CUDA(cudaFreeAsync((void *)d_output_shape_strides, stream)); + CHECK_CUDA(cudaFreeAsync((void *)d_input_shapes, stream)); + CHECK_CUDA(cudaFreeAsync((void *)d_input_strides, stream)); return INFINI_STATUS_SUCCESS; } }; @@ -479,17 +485,18 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, * @param inputs Vector of input pointers (device memory). * @param stream CUDA stream (opaque void*). * @param args (UNUSED) Additional operation-specific arguments. 
+ * @return infiniStatus_t Status indicating success or failure. */ template > -void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - void *stream, - Args &&...args) { +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { constexpr size_t N = Op::num_inputs; static_assert(sizeof...(Tin) == N, "Input type count mismatch"); - _opaque->calculateImpl( + return _opaque->calculateImpl( info, output, inputs, std::make_index_sequence{}, reinterpret_cast(stream), @@ -510,20 +517,20 @@ void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, * @param inputs Vector of input pointers (device memory). * @param stream CUDA stream (opaque void*). * @param args Additional operation-specific arguments. + * @return infiniStatus_t Status indicating success or failure. */ template -void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - void *stream, - Args &&...args) { +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { constexpr size_t N = Op::num_inputs; - _opaque->calculateImpl( + return _opaque->calculateImpl( info, output, inputs, std::make_index_sequence{}, reinterpret_cast(stream), std::forward(args)...); - cudaStreamSynchronize(reinterpret_cast(stream)); } } // namespace op::elementwise::cuda diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh index 6e80f5a03..cf7034f8c 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh @@ -24,7 +24,7 @@ public: /* Invoke elementwise operation when all inputs have the same dtype */ template - void calculate( + infiniStatus_t calculate( const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, @@ -35,7 +35,7 @@ public: template = 0> - void calculate( + infiniStatus_t calculate( const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, @@ -48,19 +48,19 @@ public: * @brief Define the process for initializing a Descriptor of an elementwise operation * for its CUDA implementation */ -#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \ - \ - op::elementwise::ElementwiseInfo elementwise_info; \ - CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \ - \ - op::elementwise::cuda::DeviceImpl *device_impl; \ - CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \ - \ - *desc_ptr = new Descriptor( \ - dtype, \ - std::move(elementwise_info), \ - device_impl, \ - handle->device, \ +#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \ + \ + auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \ + CHECK_RESULT(info_result); \ + \ + op::elementwise::cuda::DeviceImpl *device_impl; \ + CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \ + \ + *desc_ptr = new Descriptor( \ + dtype, \ + std::move(info_result.take()), \ + device_impl, \ + handle->device, \ handle->device_id); #endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__ diff --git a/src/infiniop/elementwise/elementwise.h b/src/infiniop/elementwise/elementwise.h index 8be74326d..5a6a7de1c 100644 --- 
a/src/infiniop/elementwise/elementwise.h +++ b/src/infiniop/elementwise/elementwise.h @@ -1,6 +1,7 @@ #ifndef __INFINIOP_ELEMENTWISE_H__ #define __INFINIOP_ELEMENTWISE_H__ +#include "../../utils.h" #include "../operator.h" #include "../tensor.h" #include @@ -47,8 +48,22 @@ namespace op::elementwise { -// struct that stores data needed for elementwise operation +/** + * @brief Stores the metadata required for performing an elementwise operation. + * + * This struct encapsulates shape, stride, and layout information for both + * output and multiple input tensors involved in an elementwise operation. + * + * Memory is manually managed and freed in the destructor. + * Supports move construction but disallows copy construction and copy/move assignment. + * + * Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors. + */ struct ElementwiseInfo { +private: + ElementwiseInfo() = default; + +public: size_t output_size; size_t ndim; bool output_contiguous; @@ -60,9 +75,6 @@ struct ElementwiseInfo { ptrdiff_t **input_strides; size_t input_size; - ElementwiseInfo() = default; - - // Destructor to free allocated memory ~ElementwiseInfo() { delete[] input_contiguous; delete[] input_broadcasted; @@ -77,37 +89,6 @@ struct ElementwiseInfo { delete[] input_strides; } - ElementwiseInfo(const ElementwiseInfo &other) - : output_size(other.output_size), - ndim(other.ndim), - output_contiguous(other.output_contiguous), - input_size(other.input_size) { - - input_contiguous = new bool[input_size]; - std::memcpy(input_contiguous, other.input_contiguous, input_size * sizeof(*input_contiguous)); - - input_broadcasted = new bool[input_size]; - std::memcpy(input_broadcasted, other.input_broadcasted, input_size * sizeof(*input_broadcasted)); - - output_shape = new size_t[ndim]; - std::memcpy(output_shape, other.output_shape, ndim * sizeof(*output_shape)); - - output_strides = new ptrdiff_t[ndim]; - std::memcpy(output_strides, other.output_strides, ndim * sizeof(*output_strides)); - - input_shapes = new size_t *[input_size]; - for (size_t i = 0; i < input_size; ++i) { - input_shapes[i] = new size_t[ndim]; - std::memcpy(input_shapes[i], other.input_shapes[i], ndim * sizeof(*input_shapes[i])); - } - - input_strides = new ptrdiff_t *[input_size]; - for (size_t i = 0; i < input_size; ++i) { - input_strides[i] = new ptrdiff_t[ndim]; - std::memcpy(input_strides[i], other.input_strides[i], ndim * sizeof(*input_strides[i])); - } - } - ElementwiseInfo(ElementwiseInfo &&other) noexcept : output_size(other.output_size), ndim(other.ndim), @@ -128,60 +109,69 @@ struct ElementwiseInfo { other.input_size = 0; } + ElementwiseInfo(const ElementwiseInfo &other) = delete; ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete; ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete; -}; -inline infiniStatus_t createElementwiseInfo( - ElementwiseInfo &info, - infiniopTensorDescriptor_t output_desc, - std::vector input_descs) { + using ResultType = utils::Result; + + /** + * @brief Construct ElementwiseInfo from output and input tensor descriptors. + * @param output_desc Descriptor of the output tensor. + * @param input_descs Descriptors of the input tensors. + * @return Result with the successfully constructed ElementwiseInfo, + * or the status code. 
+ */ + static ResultType create( + infiniopTensorDescriptor_t output_desc, + std::vector input_descs) { + + if (!output_desc || input_descs.empty()) { + return INFINI_STATUS_BAD_PARAM; + } - if (!output_desc || input_descs.empty()) { - return INFINI_STATUS_BAD_PARAM; - } + // Destination cannot have broadcast setup + if (output_desc->hasBroadcastDim()) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } - // Destination cannot have broadcast setup - if (output_desc->hasBroadcastDim()) { - return INFINI_STATUS_BAD_TENSOR_STRIDES; - } + ElementwiseInfo info; + info.input_size = input_descs.size(); + info.ndim = output_desc->ndim(); + info.output_size = output_desc->numel(); + info.output_contiguous = output_desc->isContiguous(); + + // Allocate memory for arrays + info.input_contiguous = new bool[info.input_size]; + info.input_broadcasted = new bool[info.input_size]; + info.output_shape = new size_t[info.ndim]; + info.output_strides = new ptrdiff_t[info.ndim]; + info.input_shapes = new size_t *[info.input_size]; + info.input_strides = new ptrdiff_t *[info.input_size]; + + // Fill arrays + const auto output_shape = output_desc->shape(); + const auto output_strides = output_desc->strides(); + std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape)); + std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides)); + + for (size_t i = 0; i < info.input_size; ++i) { + auto &desc = input_descs[i]; + info.input_contiguous[i] = desc->isContiguous(); + info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim()); + + info.input_shapes[i] = new size_t[desc->ndim()]; + const auto &in_shape = desc->shape(); + std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i])); + + info.input_strides[i] = new ptrdiff_t[desc->ndim()]; + const auto &in_strides = desc->strides(); + std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i])); + } - info.input_size = input_descs.size(); - info.ndim = output_desc->ndim(); - info.output_size = output_desc->numel(); - info.output_contiguous = output_desc->isContiguous(); - - // Allocate memory for arrays - info.input_contiguous = new bool[info.input_size]; - info.input_broadcasted = new bool[info.input_size]; - info.output_shape = new size_t[info.ndim]; - info.output_strides = new ptrdiff_t[info.ndim]; - info.input_shapes = new size_t *[info.input_size]; - info.input_strides = new ptrdiff_t *[info.input_size]; - - // Fill arrays - const auto output_shape = output_desc->shape(); - const auto output_strides = output_desc->strides(); - std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape)); - std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides)); - - for (size_t i = 0; i < info.input_size; ++i) { - auto &desc = input_descs[i]; - info.input_contiguous[i] = desc->isContiguous(); - info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim()); - - info.input_shapes[i] = new size_t[desc->ndim()]; - const auto &in_shape = desc->shape(); - std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i])); - - info.input_strides[i] = new ptrdiff_t[desc->ndim()]; - const auto &in_strides = desc->strides(); - std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i])); + return 
ResultType(std::move(info)); } - - return INFINI_STATUS_SUCCESS; -} - +}; } // namespace op::elementwise #endif // __INFINIOP_ELEMENTWISE_H__ diff --git a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc index eb64c75b4..8413d295a 100644 --- a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc +++ b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc @@ -36,14 +36,11 @@ infiniStatus_t Descriptor::calculate( switch (_dtype) { case INFINI_DTYPE_F16: - _device_info->calculate(_info, output, inputs, stream); - break; + return _device_info->calculate(_info, output, inputs, stream); case INFINI_DTYPE_F32: - _device_info->calculate(_info, output, inputs, stream); - break; + return _device_info->calculate(_info, output, inputs, stream); case INFINI_DTYPE_F64: - _device_info->calculate(_info, output, inputs, stream); - break; + return _device_info->calculate(_info, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu index 79f3572ff..5b0e8cee6 100644 --- a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu @@ -38,14 +38,11 @@ infiniStatus_t Descriptor::calculate( switch (_dtype) { case INFINI_DTYPE_F16: - _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream); - break; + return _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream); case INFINI_DTYPE_F32: - _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream); - break; + return _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream); case INFINI_DTYPE_F64: - _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream); - break; + return _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } From 6df3f51fd2d006b0ed9a1c32a3f003ec73ea7a01 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 8 Apr 2025 16:06:49 +0800 Subject: [PATCH 05/14] issue/127: remove the testing-purpose mix-precision implementation of swiglu in swiglu_cuda.cuh --- .../ops/swiglu/cuda/swiglu_cuda_internal.cuh | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh b/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh index 5d44a4bb7..d832f8110 100644 --- a/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh @@ -34,25 +34,6 @@ public: return gate * sigmoid(gate) * up; } } - - template - __device__ __forceinline__ Tc operator()(const Ta &up, const Tb &gate) const { - if constexpr (std::is_same_v) { - return __hmul2(__hmul2(gate, sigmoid(gate)), up); - } else if constexpr (std::is_same_v && std::is_same_v) { - if constexpr (std::is_same_v) { - return __float2half(__fmul_rd(__fmul_rd(gate, sigmoid(gate)), __half2float(up))); - } else { - return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), __half2float(up)); - } - } else if constexpr (std::is_same_v) { - return __hmul(__hmul(gate, sigmoid(gate)), up); - } else if constexpr (std::is_same_v) { - return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up); - } else { - return gate * sigmoid(gate) * up; - } - } } SwiGLUOp; } // namespace op::swiglu::cuda From b0f75278088753cac01a2ca9c09956494158b18d Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 8 Apr 2025 16:19:06 +0800 Subject: [PATCH 06/14] issue/127: Fix CI format - remove redefinition of default 0 in 
enable_if, remove std::move() in elementwise_cpu.h, add inclusion --- src/infiniop/elementwise/cpu/elementwise_cpu.h | 4 ++-- src/infiniop/elementwise/elementwise.h | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h index 2fcea5625..ed71861cd 100644 --- a/src/infiniop/elementwise/cpu/elementwise_cpu.h +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -16,7 +16,7 @@ \ *desc_ptr = new Descriptor( \ dtype, \ - std::move(info_result.take()), \ + info_result.take(), \ nullptr, \ handle->device, \ handle->device_id); @@ -125,7 +125,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, } // Invoke elementwise operation for different input types -template = 0> +template > infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, diff --git a/src/infiniop/elementwise/elementwise.h b/src/infiniop/elementwise/elementwise.h index 5a6a7de1c..80011417f 100644 --- a/src/infiniop/elementwise/elementwise.h +++ b/src/infiniop/elementwise/elementwise.h @@ -5,6 +5,7 @@ #include "../operator.h" #include "../tensor.h" #include +#include #include #include #include From 40fdded5e7e1430b982d82a54ae50018c562b03c Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Wed, 9 Apr 2025 11:28:33 +0800 Subject: [PATCH 07/14] issue/127: fix CUDA mix-precision broadcasting input mismatch issue, adjust comment structure and template variable order --- .../elementwise/cuda/elementwise_cuda.cuh | 59 +++++-------------- .../elementwise/cuda/elementwise_cuda_api.cuh | 40 +++++++++++-- 2 files changed, 50 insertions(+), 49 deletions(-) diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh index 08a2c4fad..a7be14bdb 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -43,7 +43,7 @@ __device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence +template INFINIOP_CUDA_KERNEL elementwise_kernel( size_t output_size, size_t ndim, @@ -129,6 +129,7 @@ __device__ void launch_op( const bool *__restrict__ input_broadcasted, const size_t *__restrict__ const *__restrict__ input_shapes, const ptrdiff_t *__restrict__ const *__restrict__ input_strides, + const ptrdiff_t *__restrict__ output_strides, const void *const *__restrict__ inputs, Tout *output, std::index_sequence) { @@ -137,7 +138,7 @@ __device__ void launch_op( return input_contiguous[input_id] ? idx : (input_broadcasted[input_id] - ? device::cuda::indexToReducedOffset(idx, ndim, input_strides[0], input_strides[input_id]) + ? 
device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id]) : device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id])); }; @@ -200,6 +201,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( input_broadcasted, input_shapes, input_strides, + output_strides, inputs, output, std::index_sequence_for{}); @@ -269,7 +271,7 @@ struct DeviceImpl::Opaque { size_t step = gridDims.x * blockDims.x; for (size_t i = 0; i < info.output_size; i += step) { - elementwise_kernel<<>>( + elementwise_kernel<<>>( info.output_size, info.ndim, info.output_contiguous, @@ -400,8 +402,8 @@ private: cudaStream_t stream) const { CHECK_CUDA(cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream)); - CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream)); CHECK_CUDA(cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream)); CHECK_CUDA(cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream)); @@ -411,24 +413,24 @@ private: for (size_t i = 0; i < info.input_size; ++i) { CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream)); CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i], - info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream)); + info.ndim * sizeof(*tmp_device_ptrs[i]), cudaMemcpyHostToDevice, stream)); } CHECK_CUDA(cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(), - info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream)); + info.input_size * sizeof(*d_input_shapes), cudaMemcpyHostToDevice, stream)); CHECK_CUDA(cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream)); for (size_t i = 0; i < info.input_size; ++i) { - CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream)); + CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*tmp_device_ptrs_strides[i]), stream)); CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i], - info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream)); + info.ndim * sizeof(*tmp_device_ptrs_strides[i]), cudaMemcpyHostToDevice, stream)); } CHECK_CUDA(cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(), - info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream)); + info.input_size * sizeof(*d_input_strides), cudaMemcpyHostToDevice, stream)); d_input_contiguous = d_bools; d_input_broadcasted = d_bools + info.input_size; d_output_shape = reinterpret_cast(d_output_shape_strides); - d_output_strides = reinterpret_cast(d_output_shape_strides + info.ndim * sizeof(size_t)); + d_output_strides = reinterpret_cast(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)); return INFINI_STATUS_SUCCESS; } @@ -470,23 +472,7 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, return 
INFINI_STATUS_SUCCESS; } -/** - * @brief Launches elementwise operation where input types may differ. - * - * Dispatches to templated `calculateImpl` using specified output and input types. - * - * @tparam BLOCK_SIZE Number of threads per block. - * @tparam Op Operation functor defining the computation. - * @tparam Tout Output data type. - * @tparam Tin... Input data types (must match Op::num_inputs). - * @tparam Args... Additional arguments passed to the operation. - * @param info Metadata describing tensor shapes, strides, etc. - * @param output Pointer to output buffer on device. - * @param inputs Vector of input pointers (device memory). - * @param stream CUDA stream (opaque void*). - * @param args (UNUSED) Additional operation-specific arguments. - * @return infiniStatus_t Status indicating success or failure. - */ +/* Invoke elementwise operation for different input types */ template > infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, @@ -503,22 +489,7 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf std::forward(args)...); } -/** - * @brief Launches elementwise operation where all input types are the same. - * - * Calls the corresponding templated `calculateImpl` with a unified input type. - * - * @tparam BLOCK_SIZE Number of threads per block. - * @tparam Op Operation functor defining the computation. - * @tparam Tdata Data type for both input and output tensors. - * @tparam Args... Additional arguments passed to the operation. - * @param info Metadata describing tensor shapes, strides, etc. - * @param output Pointer to output buffer on device. - * @param inputs Vector of input pointers (device memory). - * @param stream CUDA stream (opaque void*). - * @param args Additional operation-specific arguments. - * @return infiniStatus_t Status indicating success or failure. - */ +/* Invoke elementwise operation when all inputs have the same dtype */ template infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh index cf7034f8c..78b1ea881 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh @@ -18,11 +18,25 @@ public: ~DeviceImpl() = default; template - static infiniStatus_t create( - DeviceImpl **device_info, - Args &&...args); + static infiniStatus_t create(DeviceImpl **device_info, Args &&...args); - /* Invoke elementwise operation when all inputs have the same dtype */ + /** + * @brief Launches elementwise operation where all input types are the same. + * + * Calls the corresponding templated `calculateImpl` with a unified input type. + * + * @tparam BLOCK_SIZE Number of threads per block. + * @tparam Op Operation functor defining the computation. + * @tparam Tdata Data type for both input and output tensors. + * @tparam Args... Additional arguments passed to the operation. + * + * @param info Metadata describing tensor shapes, strides, etc. + * @param output Pointer to output buffer on device. + * @param inputs Vector of input pointers (device memory). + * @param stream CUDA stream (opaque void*). + * @param args Additional operation-specific arguments. + * @return infiniStatus_t Status indicating success or failure. 
+ */ template infiniStatus_t calculate( const op::elementwise::ElementwiseInfo &info, @@ -31,7 +45,23 @@ public: void *stream, Args &&...args); - /* Invoke elementwise operation for different input types */ + /** + * @brief Launches elementwise operation where input types may differ. + * + * Dispatches to templated `calculateImpl` using specified output and input types. + * + * @tparam BLOCK_SIZE Number of threads per block. + * @tparam Op Operation functor defining the computation. + * @tparam Tout Output data type. + * @tparam Tin... Input data types (must match Op::num_inputs). + * @tparam Args... Additional arguments passed to the operation. + * @param info Metadata describing tensor shapes, strides, etc. + * @param output Pointer to output buffer on device. + * @param inputs Vector of input pointers (device memory). + * @param stream CUDA stream (opaque void*). + * @param args (UNUSED) Additional operation-specific arguments. + * @return infiniStatus_t Status indicating success or failure. + */ template = 0> From 9cc0c41644208da61430287e78b98ad6a9dec3f7 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 14 Apr 2025 15:47:28 +0800 Subject: [PATCH 08/14] issue/127: Refactor ElementwiseInfo, refactor elementwise to use workspace for storing meta, fix misc. issues --- include/infiniop/ops/swiglu.h | 4 + .../devices/cuda/cuda_kernel_common.cuh | 5 + .../elementwise/cpu/elementwise_cpu.h | 50 ++- .../elementwise/cuda/elementwise_cuda.cuh | 342 +++++++----------- .../elementwise/cuda/elementwise_cuda_api.cuh | 9 +- src/infiniop/elementwise/elementwise.h | 179 +++++---- src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc | 2 + src/infiniop/ops/swiglu/cpu/swiglu_cpu.h | 2 +- src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu | 16 +- src/infiniop/ops/swiglu/operator.cc | 43 ++- test/infiniop/swiglu.py | 22 +- 11 files changed, 364 insertions(+), 310 deletions(-) diff --git a/include/infiniop/ops/swiglu.h b/include/infiniop/ops/swiglu.h index 7a74f6382..1d4d87e17 100644 --- a/include/infiniop/ops/swiglu.h +++ b/include/infiniop/ops/swiglu.h @@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc); +__C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size); + __C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, + void *workspace, + size_t workspace_size, void *c, void const *a, void const *b, diff --git a/src/infiniop/devices/cuda/cuda_kernel_common.cuh b/src/infiniop/devices/cuda/cuda_kernel_common.cuh index b3f52db01..68ef36c2a 100644 --- a/src/infiniop/devices/cuda/cuda_kernel_common.cuh +++ b/src/infiniop/devices/cuda/cuda_kernel_common.cuh @@ -9,6 +9,10 @@ #define CUDA_BLOCK_SIZE_1024 1024 #define CUDA_BLOCK_SIZE_512 512 +#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess) + +namespace device::cuda { + // return the memory offset of original tensor, given the flattened index of broadcasted tensor __forceinline__ __device__ __host__ size_t indexToReducedOffset( @@ -38,6 +42,7 @@ indexToOffset( } return res; } +} // namespace device::cuda #ifdef ENABLE_CUDA_API #include diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h index ed71861cd..880ce027b 100644 --- a/src/infiniop/elementwise/cpu/elementwise_cpu.h +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -18,6 +18,7 @@ dtype, \ info_result.take(), \ nullptr, \ + 0, \ handle->device, \ handle->device_id); @@ 
-103,24 +104,34 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) { } // Perform elementwise operation for different input types -template = 0> -void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, std::index_sequence, Args &&...args) { +template = 0> +void calculate_impl(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + std::index_sequence, + Args &&...args) { + Tout *out = reinterpret_cast(output); std::tuple input_ptrs = {reinterpret_cast(inputs[Is])...}; - ptrdiff_t output_size = info.output_size; + ptrdiff_t output_size = info.getOutputSize(); #pragma omp parallel for for (ptrdiff_t i = 0; i < output_size; ++i) { - size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides); + size_t out_idx = info.isOutputContiguous() + ? i + : op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides()); auto get_input_idx = [&](size_t input_id) { - return info.input_contiguous[input_id] ? i - : (info.input_broadcasted[input_id] - ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id]) - : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); + return info.getInputContiguous()[input_id] + ? i + : (info.getInputBroadcasted()[input_id] + ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id)) + : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id))); }; - out[out_idx] = utils::cast(Op{}.template operator()(std::get(input_ptrs)[get_input_idx(Is)]..., std::forward(args)...)); + out[out_idx] = utils::cast( + Op{}.template operator()(std::get(input_ptrs)[get_input_idx(Is)]..., std::forward(args)...)); } } @@ -147,17 +158,20 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, Tdata *out = reinterpret_cast(output); std::array ins = {reinterpret_cast(inputs[Is])...}; - const ptrdiff_t output_size = info.output_size; + const ptrdiff_t output_size = info.getOutputSize(); #pragma omp parallel for for (ptrdiff_t i = 0; i < output_size; ++i) { - size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides); + size_t out_idx = info.isOutputContiguous() + ? i + : op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides()); auto get_input_idx = [&](size_t input_id) { - return info.input_contiguous[input_id] ? i - : (info.input_broadcasted[input_id] - ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id]) - : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); + return info.getInputContiguous()[input_id] + ? i + : (info.getInputBroadcasted()[input_id] + ? 
op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id)) + : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id))); }; if constexpr (std::is_same_v) { @@ -170,7 +184,11 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, // Invoke elementwise operation when all inputs have the same type template -infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, void *stream, Args &&...args) { +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { constexpr size_t N = Op::num_inputs; calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); return INFINI_STATUS_SUCCESS; diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh index a7be14bdb..de791d548 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -3,6 +3,7 @@ #include "../../../utils.h" #include "../../devices/cuda/cuda_common.cuh" +#include "../../devices/cuda/cuda_kernel_common.cuh" #include "elementwise_cuda_api.cuh" namespace op::elementwise::cuda { @@ -16,18 +17,17 @@ namespace op::elementwise::cuda { * @param lambda Lambda to be called with std::integral_constant... as arguments. */ template -__device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence) { +__device__ __forceinline__ void callExpand(Lambda lambda, std::index_sequence) { lambda(std::integral_constant{}...); } /** * @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type. * + * @tparam N Number of input tensors. * @tparam Op Operator type implementing operator()(Tdata...). * @tparam Tdata Common data type for inputs and output. - * @tparam N Number of input tensors. * @tparam Args Additional arguments to pass to the operator. - * * @param output_size Total number of output elements. * @param ndim Number of dimensions in tensors. * @param output_contiguous Whether the output tensor is contiguous in memory. @@ -37,24 +37,22 @@ __device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence -INFINIOP_CUDA_KERNEL elementwise_kernel( +INFINIOP_CUDA_KERNEL elementwiseKernel( size_t output_size, size_t ndim, bool output_contiguous, const bool *__restrict__ input_contiguous, const bool *__restrict__ input_broadcasted, const size_t *__restrict__ output_shape, - const size_t *__restrict__ *__restrict__ input_shapes, + const size_t *__restrict__ input_shapes, const ptrdiff_t *__restrict__ output_strides, - const ptrdiff_t *__restrict__ *__restrict__ input_strides, - size_t input_size, + const ptrdiff_t *__restrict__ input_strides, Tdata *output, const Tdata *const *inputs, size_t offset, @@ -68,8 +66,8 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( auto get_input_idx = [&] __device__(size_t input_id) { return input_contiguous[input_id] ? idx : (input_broadcasted[input_id] - ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id]) - : device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id])); + ? 
device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) + : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); }; // Use a helper to expand the index sequence into individual compile-time constants @@ -85,7 +83,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( } }; - call_expand(expand_inputs, std::make_index_sequence{}); + callExpand(expand_inputs, std::make_index_sequence{}); } } @@ -97,38 +95,38 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( * @return Pointer of type const T*. */ template -__device__ inline const T *typed_input_ptr(const void *ptr) { +__device__ inline const T *typedInputPtr(const void *ptr) { return reinterpret_cast(ptr); } /** - * @brief Launches a type-safe elementwise operation on a single output element. - * - * @tparam Op Operator type implementing a templated operator() for (Tout, Tin...). - * @tparam Tout Output data type. - * @tparam Tin Variadic input data types. - * @tparam Is Index sequence corresponding to each input. + * @brief Launches elementwise operation at a specific output index. * - * @param idx Linear index in the flattened output space. - * @param out_idx Actual output index (may be non-contiguous). - * @param ndim Number of dimensions in the tensors. - * @param input_contiguous Array indicating whether each input is contiguous. - * @param input_broadcasted Array indicating whether each input is broadcasted. - * @param input_shapes Shapes of the input tensors. - * @param input_strides Strides of the input tensors. - * @param inputs Raw pointers to input data. - * @param output Pointer to output data. - * @param ... Index sequence used for unpacking variadic inputs. + * @tparam Op Functor representing the elementwise operation. + * @tparam Tout Output data type. + * @tparam Tin... Input data types. + * @tparam Is... Index sequence for unpacking variadic inputs. + * @param idx Global linear index into the output tensor. + * @param out_idx Offset into the output array. + * @param ndim Number of dimensions in the tensors. + * @param input_contiguous Flags indicating whether each input is contiguous. + * @param input_broadcasted Flags indicating whether each input is broadcasted. + * @param input_shapes Flattened input shapes (N * ndim). + * @param input_strides Flattened input strides (N * ndim). + * @param output_strides Output tensor strides. + * @param inputs Array of pointers to input tensors. + * @param output Pointer to output tensor. + * @param ...Is Index sequence for iterating over input tensors. */ template -__device__ void launch_op( +__device__ void launchOp( size_t idx, size_t out_idx, size_t ndim, const bool *__restrict__ input_contiguous, const bool *__restrict__ input_broadcasted, - const size_t *__restrict__ const *__restrict__ input_shapes, - const ptrdiff_t *__restrict__ const *__restrict__ input_strides, + const size_t *__restrict__ input_shapes, + const ptrdiff_t *__restrict__ input_strides, const ptrdiff_t *__restrict__ output_strides, const void *const *__restrict__ inputs, Tout *output, @@ -138,12 +136,12 @@ __device__ void launch_op( return input_contiguous[input_id] ? idx : (input_broadcasted[input_id] - ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id]) - : device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id])); + ? 
device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) + : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); }; output[out_idx] = Op{}.template operator()( - (typed_input_ptr(inputs[Is])[get_input_idx(Is)])...); + (typedInputPtr(inputs[Is])[get_input_idx(Is)])...); } /** @@ -153,7 +151,6 @@ __device__ void launch_op( * @tparam Op Operator type implementing a templated operator() for (Tout, Tin...). * @tparam Tout Output data type. * @tparam Tin Variadic input data types. - * * @param output_size Total number of output elements. * @param ndim Number of dimensions in the tensors. * @param output_contiguous Whether the output tensor is contiguous. @@ -163,23 +160,21 @@ __device__ void launch_op( * @param input_shapes Shapes of the input tensors. * @param output_strides Strides of the output tensor. * @param input_strides Strides of the input tensors. - * @param input_size Total number of input elements (unused here, but may be used for validation). * @param output Pointer to the output buffer. * @param inputs Array of untyped input pointers. * @param offset Linear offset into the output for partitioned execution. */ template -INFINIOP_CUDA_KERNEL elementwise_kernel( +INFINIOP_CUDA_KERNEL elementwiseKernel( size_t output_size, size_t ndim, bool output_contiguous, const bool *__restrict__ input_contiguous, const bool *__restrict__ input_broadcasted, const size_t *__restrict__ output_shape, - const size_t *__restrict__ const *__restrict__ input_shapes, + const size_t *__restrict__ input_shapes, const ptrdiff_t *__restrict__ output_strides, - const ptrdiff_t *__restrict__ const *__restrict__ input_strides, - size_t input_size, + const ptrdiff_t *__restrict__ input_strides, Tout *output, const void *const *__restrict__ inputs, size_t offset) { @@ -193,7 +188,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( ? idx : device::cuda::indexToOffset(idx, ndim, output_shape, output_strides); - launch_op( + launchOp( idx, out_idx, ndim, @@ -214,252 +209,185 @@ struct DeviceImpl::Opaque { : internal(internal) {} /** - * @brief Performs elementwise operations when all inputs and the output share the same data type. + * @brief Executes an elementwise operation where all inputs and the output share the same data type. * - * @tparam BLOCK_SIZE The block size for the kernel launch. - * @tparam N The number of input tensors. - * @tparam Op The operation to perform (e.g., addition, multiplication). - * @tparam Tdata The data type of the input and output tensors. - * @tparam Args Additional arguments to be passed to the operation. - * @param info Structure containing elementwise operation information (size, shape, etc.). - * @param output Pointer to the output memory where results will be stored. - * @param inputs Vector of pointers to input tensors. - * @param stream CUDA stream used for asynchronous execution. - * @param args Additional arguments for the operation. - * @return infiniStatus_t Status indicating success or failure. + * @tparam BLOCK_SIZE CUDA block size used for kernel launch. + * @tparam N Number of input tensors. + * @tparam Op Functor representing the elementwise operation. + * @tparam Tdata Data type of both input and output tensors. + * @tparam Args Optional additional arguments passed to the operation. + * @param info Metadata about the operation including shape, size, and dimensionality. + * @param workspace Temporary workspace used for storing metadata on device. 
+ * @param output Pointer to the output buffer. + * @param inputs Vector of pointers to input buffers. + * @param stream CUDA stream for asynchronous execution. + * @param args Additional arguments forwarded to the operation. + * @return infiniStatus_t Returns success or failure status. */ - template + template infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info, + void *workspace, void *output, const std::vector &inputs, - std::index_sequence, cudaStream_t stream, Args &&...args) { - if (info.output_size == 0) { + auto output_size = info.getOutputSize(); + if (output_size == 0) { return INFINI_STATUS_SUCCESS; } // casting the output and the inputs to Tdata pointers Tdata *out = reinterpret_cast(output); - const Tdata *inputs_arr[N]; - const Tdata **d_inputs_arr = nullptr; - for (size_t i = 0; i < N; ++i) { - inputs_arr[i] = reinterpret_cast(inputs[i]); - } - CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream)); - CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream)); + const void **d_inputs_arr = nullptr; // create and send the info to device - const bool *d_bools = nullptr; const bool *d_input_contiguous = nullptr; const bool *d_input_broadcasted = nullptr; - const int8_t *d_output_shape_strides = nullptr; const size_t *d_output_shape = nullptr; const ptrdiff_t *d_output_strides = nullptr; - const size_t **d_input_shapes = nullptr; - const ptrdiff_t **d_input_strides = nullptr; - std::vector tmp_device_ptrs(info.input_size); - std::vector tmp_device_ptrs_strides(info.input_size); + const size_t *d_input_shapes = nullptr; + const ptrdiff_t *d_input_strides = nullptr; - CHECK_STATUS(infoToDevice(info, d_bools, d_input_contiguous, - d_input_broadcasted, d_output_shape_strides, d_output_shape, - d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, - d_input_strides, stream)); + CHECK_STATUS(infoToDevice(info, workspace, inputs.data(), d_inputs_arr, d_input_contiguous, d_input_broadcasted, + d_output_shape, d_output_strides, d_input_shapes, d_input_strides, stream)); dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); - dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast(internal->gridSizeX()))); + dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast(internal->gridSizeX()))); size_t step = gridDims.x * blockDims.x; - for (size_t i = 0; i < info.output_size; i += step) { - elementwise_kernel<<>>( - info.output_size, - info.ndim, - info.output_contiguous, + for (size_t i = 0; i < output_size; i += step) { + elementwiseKernel<<>>( + output_size, + info.getNdim(), + info.isOutputContiguous(), d_input_contiguous, d_input_broadcasted, d_output_shape, d_input_shapes, d_output_strides, d_input_strides, - info.input_size, out, d_inputs_arr, i, std::forward(args)...); + out, reinterpret_cast(d_inputs_arr), i, std::forward(args)...); } - CHECK_STATUS(freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides, - info.input_size, d_input_shapes, d_input_strides, stream)); return INFINI_STATUS_SUCCESS; } /** - * @brief Performs elementwise operations when inputs and the outputs have mixed data types (i.e., different dtypes). + * @brief Executes an elementwise operation with mixed input and output data types. * - * @tparam BLOCK_SIZE The block size for the kernel launch. - * @tparam N The number of input tensors. - * @tparam Op The operation to perform (e.g., addition, multiplication). 
- * @tparam Tout The output data type. - * @tparam Tin The input data types. - * @tparam Args Additional arguments to be passed to the operation. - * @param info Structure containing elementwise operation information (size, shape, etc.). - * @param output Pointer to the output memory where results will be stored. - * @param inputs Vector of pointers to input tensors. - * @param stream CUDA stream used for asynchronous execution. - * @param args Additional arguments for the operation. - * @return infiniStatus_t Status indicating success or failure. + * @tparam BLOCK_SIZE CUDA block size used for kernel launch. + * @tparam N Number of input tensors. + * @tparam Op Functor representing the elementwise operation. + * @tparam Tout Data type of the output tensor. + * @tparam Tin... Data types of the input tensors. + * @tparam Args Optional additional arguments passed to the operation.(UNUSED) + * @param info Metadata about the operation including shape, size, and dimensionality. + * @param workspace Temporary workspace used for storing metadata on device. + * @param output Pointer to the output buffer. + * @param inputs Vector of pointers to input buffers. + * @param stream CUDA stream for asynchronous execution. + * @param args Additional arguments forwarded to the operation. + * @return infiniStatus_t Returns success or failure status. */ - template = 0> infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info, + void *workspace, void *output, const std::vector &inputs, - std::index_sequence, cudaStream_t stream, Args &&...args) { - if (info.output_size == 0) { + auto output_size = info.getOutputSize(); + if (output_size == 0) { return INFINI_STATUS_SUCCESS; } Tout *out = reinterpret_cast(output); - - // Store input pointers with the correct types - const std::tuple inputs_arr{reinterpret_cast(inputs[Is])...}; const void **d_inputs_arr = nullptr; - // Create array of input pointers on host (void*) to copy to device - const void *host_input_ptrs[] = {reinterpret_cast(std::get(inputs_arr))...}; - CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream)); - CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream)); - // Device pointers - const bool *d_bools = nullptr; const bool *d_input_contiguous = nullptr; const bool *d_input_broadcasted = nullptr; - const int8_t *d_output_shape_strides = nullptr; const size_t *d_output_shape = nullptr; const ptrdiff_t *d_output_strides = nullptr; - const size_t **d_input_shapes = nullptr; - const ptrdiff_t **d_input_strides = nullptr; - std::vector tmp_device_ptrs(info.input_size); - std::vector tmp_device_ptrs_strides(info.input_size); + const size_t *d_input_shapes = nullptr; + const ptrdiff_t *d_input_strides = nullptr; - CHECK_STATUS(infoToDevice(info, d_bools, d_input_contiguous, - d_input_broadcasted, d_output_shape_strides, d_output_shape, - d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, - d_input_strides, stream)); + CHECK_STATUS(infoToDevice(info, workspace, inputs.data(), d_inputs_arr, d_input_contiguous, d_input_broadcasted, + d_output_shape, d_output_strides, d_input_shapes, d_input_strides, stream)); dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); - dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast(internal->gridSizeX()))); + dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast(internal->gridSizeX()))); size_t step = gridDims.x * blockDims.x; - for (size_t i 
= 0; i < info.output_size; i += step) { - elementwise_kernel<<>>( - info.output_size, - info.ndim, - info.output_contiguous, + for (size_t i = 0; i < output_size; i += step) { + elementwiseKernel<<>>( + output_size, + info.getNdim(), + info.isOutputContiguous(), d_input_contiguous, d_input_broadcasted, d_output_shape, d_input_shapes, d_output_strides, d_input_strides, - info.input_size, out, reinterpret_cast(d_inputs_arr), i); + out, reinterpret_cast(d_inputs_arr), i); } - CHECK_STATUS(freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream)); return INFINI_STATUS_SUCCESS; } private: /** - * @brief Transfers elementwise kernel metadata (shapes, strides, flags) from host to device. + * @brief Transfers elementwise operation metadata and input pointers from host to device memory. * - * @tparam N Number of inputs. - * @param info Structure containing input/output metadata. - * @param d_bools Device pointer for input_contiguous and input_broadcasted flags. - * @param d_input_contiguous Device pointer to input contiguity flags. - * @param d_input_broadcasted Device pointer to input broadcasting flags. - * @param d_output_shape_strides Device buffer containing both output shape and strides. - * @param d_output_shape Device pointer to output shape. - * @param d_output_strides Device pointer to output strides. - * @param tmp_device_ptrs Temporary device pointers for input shapes. - * @param d_input_shapes Device array of pointers to input shapes. - * @param tmp_device_ptrs_strides Temporary device pointers for input strides. - * @param d_input_strides Device array of pointers to input strides. - * @param stream CUDA stream for async allocation and transfers. - * @return infiniStatus_t Status indicating success or failure. + * @tparam N Number of input tensors. + * @param info Elementwise operation metadata (shapes, strides, flags, etc.). + * @param workspace Pointer to device workspace memory for storing metadata and input pointers. + * @param h_inputs_arr Host array of input tensor pointers. + * @param d_inputs_arr Output reference to device array of input tensor pointers. + * @param d_input_contiguous Output reference to device array indicating whether each input is contiguous. + * @param d_input_broadcasted Output reference to device array indicating whether each input is broadcasted. + * @param d_output_shape Output reference to device array holding the output tensor shape. + * @param d_output_strides Output reference to device array holding output tensor strides. + * @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim). + * @param d_input_strides Output reference to flattened input tensor strides (N * ndim). + * @param stream CUDA stream used for asynchronous memory transfer. + * @return infiniStatus_t Status indicating success or failure of the memory transfer and setup. 
*/ template infiniStatus_t infoToDevice( const op::elementwise::ElementwiseInfo &info, - const bool *&d_bools, + void *workspace, + const void *const *h_inputs_arr, + const void **&d_inputs_arr, const bool *&d_input_contiguous, const bool *&d_input_broadcasted, - const int8_t *&d_output_shape_strides, const size_t *&d_output_shape, const ptrdiff_t *&d_output_strides, - std::vector &tmp_device_ptrs, - const size_t **&d_input_shapes, - std::vector &tmp_device_ptrs_strides, - const ptrdiff_t **&d_input_strides, + const size_t *&d_input_shapes, + const ptrdiff_t *&d_input_strides, cudaStream_t stream) const { - CHECK_CUDA(cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream)); - CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream)); - - CHECK_CUDA(cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream)); - CHECK_CUDA(cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream)); - - CHECK_CUDA(cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream)); - for (size_t i = 0; i < info.input_size; ++i) { - CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream)); - CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i], - info.ndim * sizeof(*tmp_device_ptrs[i]), cudaMemcpyHostToDevice, stream)); - } - CHECK_CUDA(cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(), - info.input_size * sizeof(*d_input_shapes), cudaMemcpyHostToDevice, stream)); - - CHECK_CUDA(cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream)); - for (size_t i = 0; i < info.input_size; ++i) { - CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*tmp_device_ptrs_strides[i]), stream)); - CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i], - info.ndim * sizeof(*tmp_device_ptrs_strides[i]), cudaMemcpyHostToDevice, stream)); - } - CHECK_CUDA(cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(), - info.input_size * sizeof(*d_input_strides), cudaMemcpyHostToDevice, stream)); - - d_input_contiguous = d_bools; - d_input_broadcasted = d_bools + info.input_size; - d_output_shape = reinterpret_cast(d_output_shape_strides); - d_output_strides = reinterpret_cast(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)); - - return INFINI_STATUS_SUCCESS; - } + constexpr auto input_size = N; + const auto ndim = info.getNdim(); + constexpr auto input_arr_size = N * sizeof(*h_inputs_arr); + const int8_t *info_meta_start = info.getMetaStart(); + const int8_t *d_meta_start = reinterpret_cast(workspace) + input_arr_size; + + // copy the input pointer array and meta to device + CHECK_CUDA(cudaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), cudaMemcpyHostToDevice, stream)); + + // offset/assign the pointers + d_inputs_arr = 
reinterpret_cast(workspace); + d_output_shape = reinterpret_cast(d_meta_start); + d_output_strides = reinterpret_cast(d_output_shape + ndim); + d_input_shapes = reinterpret_cast(d_output_strides + ndim); + d_input_strides = reinterpret_cast(d_input_shapes + input_size * ndim); + d_input_contiguous = reinterpret_cast(d_input_strides + input_size * ndim); + d_input_broadcasted = reinterpret_cast(d_input_contiguous + input_size); - /** - * @brief Frees all device-allocated memory used for metadata in elementwise kernel execution. - * - * @param d_inputs_arr Device array of input pointers. - * @param d_bools Device memory holding input flags. - * @param d_output_shape_strides Device buffer holding output shape and strides. - * @param input_size Number of input tensors. - * @param d_input_shapes Device array of input shape pointers. - * @param d_input_strides Device array of input stride pointers. - * @param stream CUDA stream for async deallocation. - * @return infiniStatus_t Status indicating success or failure. - */ - inline infiniStatus_t freeAllDevice(const void **d_inputs_arr, - const bool *d_bools, - const int8_t *d_output_shape_strides, - const size_t input_size, - const size_t **d_input_shapes, - const ptrdiff_t **d_input_strides, - cudaStream_t stream) const { - - CHECK_CUDA(cudaFreeAsync((void *)d_inputs_arr, stream)); - CHECK_CUDA(cudaFreeAsync((void *)d_bools, stream)); - CHECK_CUDA(cudaFreeAsync((void *)d_output_shape_strides, stream)); - CHECK_CUDA(cudaFreeAsync((void *)d_input_shapes, stream)); - CHECK_CUDA(cudaFreeAsync((void *)d_input_strides, stream)); return INFINI_STATUS_SUCCESS; } }; @@ -476,6 +404,7 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, template > infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *workspace, void *output, const std::vector &inputs, void *stream, @@ -483,8 +412,7 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf constexpr size_t N = Op::num_inputs; static_assert(sizeof...(Tin) == N, "Input type count mismatch"); return _opaque->calculateImpl( - info, output, inputs, - std::make_index_sequence{}, + info, workspace, output, inputs, reinterpret_cast(stream), std::forward(args)...); } @@ -492,14 +420,14 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf /* Invoke elementwise operation when all inputs have the same dtype */ template infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *workspace, void *output, const std::vector &inputs, void *stream, Args &&...args) { constexpr size_t N = Op::num_inputs; return _opaque->calculateImpl( - info, output, inputs, - std::make_index_sequence{}, + info, workspace, output, inputs, reinterpret_cast(stream), std::forward(args)...); } diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh index 78b1ea881..eec8b3511 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh @@ -31,6 +31,7 @@ public: * @tparam Args... Additional arguments passed to the operation. * * @param info Metadata describing tensor shapes, strides, etc. + * @param workspace Pointer to workspace buffer on device. * @param output Pointer to output buffer on device. * @param inputs Vector of input pointers (device memory). * @param stream CUDA stream (opaque void*). 
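For orientation, the single workspace buffer that infoToDevice() fills above is laid out as the device copy of the input-pointer array followed by the packed metadata blob, in the same order ElementwiseInfo stores it. A minimal host-side sketch of that layout, where n_inputs and ndim stand in for info.getInputSize() and info.getNdim() (the helper itself is illustrative, not project code):

    #include <cstddef>

    // Sketch of the workspace layout consumed by infoToDevice(); the offsets mirror
    // the pointer arithmetic above.
    size_t workspaceBytes(size_t n_inputs, size_t ndim) {
        size_t bytes = 0;
        bytes += n_inputs * sizeof(void *);            // device copies of the input pointers
        bytes += ndim * sizeof(size_t);                // output shape
        bytes += ndim * sizeof(ptrdiff_t);             // output strides
        bytes += n_inputs * ndim * sizeof(size_t);     // all input shapes, flattened
        bytes += n_inputs * ndim * sizeof(ptrdiff_t);  // all input strides, flattened
        bytes += 2 * n_inputs * sizeof(bool);          // contiguous + broadcasted flags
        return bytes;  // equals getMetaMemSize() + getInputSize() * sizeof(void *)
    }

This total matches the workspace size the descriptor macro computes, so a caller can allocate one buffer per descriptor and reuse it across calls.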
@@ -40,6 +41,7 @@ public: template infiniStatus_t calculate( const op::elementwise::ElementwiseInfo &info, + void *workspace, void *output, const std::vector &inputs, void *stream, @@ -56,6 +58,7 @@ public: * @tparam Tin... Input data types (must match Op::num_inputs). * @tparam Args... Additional arguments passed to the operation. * @param info Metadata describing tensor shapes, strides, etc. + * @param workspace Pointer to workspace buffer on device. * @param output Pointer to output buffer on device. * @param inputs Vector of input pointers (device memory). * @param stream CUDA stream (opaque void*). @@ -67,6 +70,7 @@ public: std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> infiniStatus_t calculate( const op::elementwise::ElementwiseInfo &info, + void *workspace, void *output, const std::vector &inputs, void *stream, @@ -82,14 +86,17 @@ public: \ auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \ CHECK_RESULT(info_result); \ + auto info = info_result.take(); \ + auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \ \ op::elementwise::cuda::DeviceImpl *device_impl; \ CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \ \ *desc_ptr = new Descriptor( \ dtype, \ - std::move(info_result.take()), \ + std::move(info), \ device_impl, \ + workspace_size, \ handle->device, \ handle->device_id); diff --git a/src/infiniop/elementwise/elementwise.h b/src/infiniop/elementwise/elementwise.h index 80011417f..9feb8dd74 100644 --- a/src/infiniop/elementwise/elementwise.h +++ b/src/infiniop/elementwise/elementwise.h @@ -19,21 +19,26 @@ infiniDtype_t _dtype; \ op::elementwise::ElementwiseInfo _info; \ std::unique_ptr _device_info; \ + size_t _workspace_size; \ \ Descriptor( \ infiniDtype_t dtype, \ op::elementwise::ElementwiseInfo info, \ op::elementwise::NAMESPACE::DeviceImpl *device_info, \ + size_t workspace_size, \ infiniDevice_t device_type, \ int device_id) \ : InfiniopDescriptor{device_type, device_id}, \ _dtype(dtype), \ _info(std::move(info)), \ - _device_info(device_info) {} \ + _device_info(device_info), \ + _workspace_size(workspace_size) {} \ \ public: \ ~Descriptor(); \ \ + size_t workspaceSize() const { return _workspace_size; } \ + \ static infiniStatus_t create( \ infiniopHandle_t handle, \ Descriptor **desc_ptr, \ @@ -41,6 +46,7 @@ std::vector input_descs); \ \ infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ void *output, \ std::vector inputs, \ void *stream) const; \ @@ -62,57 +68,70 @@ namespace op::elementwise { */ struct ElementwiseInfo { private: - ElementwiseInfo() = default; + std::vector _meta; + size_t _output_size; + size_t _input_size; + size_t _ndim; + bool _output_contiguous; + + ElementwiseInfo(std::vector meta, + size_t output_size, + size_t input_size, + size_t ndim, + bool output_contiguous) + : _meta(std::move(meta)), _output_size(output_size), + _input_size(input_size), _ndim(ndim), + _output_contiguous(output_contiguous) {} public: - size_t output_size; - size_t ndim; - bool output_contiguous; - bool *input_contiguous; - bool *input_broadcasted; - size_t *output_shape; - size_t **input_shapes; - ptrdiff_t *output_strides; - ptrdiff_t **input_strides; - size_t input_size; - - ~ElementwiseInfo() { - delete[] input_contiguous; - delete[] input_broadcasted; - delete[] output_shape; - delete[] output_strides; - - for (size_t i = 0; i < input_size; ++i) { - delete[] input_shapes[i]; - delete[] input_strides[i]; + inline size_t 
getMetaMemSize() const { + return _meta.size(); + } + inline const int8_t *getMetaStart() const { + return _meta.data(); + } + inline size_t getOutputSize() const { + return _output_size; + } + inline size_t getInputSize() const { + return _input_size; + } + inline size_t getNdim() const { + return _ndim; + } + inline bool isOutputContiguous() const { + return _output_contiguous; + } + inline const size_t *getOutputShape() const { + return reinterpret_cast(_meta.data()); + } + inline const ptrdiff_t *getOutputStrides() const { + return reinterpret_cast(getOutputShape() + _ndim); + } + inline const size_t *getAllInputShapes() const { + return reinterpret_cast(getOutputStrides() + _ndim); + } + inline const size_t *getInputShape(const size_t &index) const { + if (index < _input_size) { + return reinterpret_cast(getAllInputShapes() + index * _ndim); + } + return nullptr; + } + inline const ptrdiff_t *getAllInputStrides() const { + return reinterpret_cast(getAllInputShapes() + _input_size * _ndim); + } + inline const ptrdiff_t *getInputStrides(const size_t &index) const { + if (index < _input_size) { + return reinterpret_cast(getAllInputStrides() + index * _ndim); } - delete[] input_shapes; - delete[] input_strides; - } - - ElementwiseInfo(ElementwiseInfo &&other) noexcept - : output_size(other.output_size), - ndim(other.ndim), - output_contiguous(other.output_contiguous), - input_contiguous(other.input_contiguous), - input_broadcasted(other.input_broadcasted), - output_shape(other.output_shape), - input_shapes(other.input_shapes), - output_strides(other.output_strides), - input_strides(other.input_strides), - input_size(other.input_size) { - other.input_contiguous = nullptr; - other.input_broadcasted = nullptr; - other.output_shape = nullptr; - other.input_shapes = nullptr; - other.output_strides = nullptr; - other.input_strides = nullptr; - other.input_size = 0; - } - - ElementwiseInfo(const ElementwiseInfo &other) = delete; - ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete; - ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete; + return nullptr; + } + inline const bool *getInputContiguous() const { + return reinterpret_cast(getAllInputStrides() + _input_size * _ndim); + } + inline const bool *getInputBroadcasted() const { + return reinterpret_cast(getInputContiguous() + _input_size); + } using ResultType = utils::Result; @@ -136,40 +155,48 @@ struct ElementwiseInfo { return INFINI_STATUS_BAD_TENSOR_STRIDES; } - ElementwiseInfo info; - info.input_size = input_descs.size(); - info.ndim = output_desc->ndim(); - info.output_size = output_desc->numel(); - info.output_contiguous = output_desc->isContiguous(); - - // Allocate memory for arrays - info.input_contiguous = new bool[info.input_size]; - info.input_broadcasted = new bool[info.input_size]; - info.output_shape = new size_t[info.ndim]; - info.output_strides = new ptrdiff_t[info.ndim]; - info.input_shapes = new size_t *[info.input_size]; - info.input_strides = new ptrdiff_t *[info.input_size]; - - // Fill arrays + auto input_size = input_descs.size(); + auto ndim = output_desc->ndim(); + auto output_size = output_desc->numel(); + auto output_contiguous = output_desc->isContiguous(); + + // Allocate memory for meta + auto shape_unit = output_desc->dim(0); + auto stride_unit = output_desc->stride(0); + size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit)) + + input_size * ndim * sizeof(shape_unit) + + input_size * ndim * sizeof(stride_unit) + + 2 * input_size * sizeof(bool); + std::vector 
meta(meta_mem_size); + int8_t *meta_ptr = meta.data(); + const auto output_shape = output_desc->shape(); const auto output_strides = output_desc->strides(); - std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape)); - std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides)); - for (size_t i = 0; i < info.input_size; ++i) { - auto &desc = input_descs[i]; - info.input_contiguous[i] = desc->isContiguous(); - info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim()); + // Pointers to the sections within _meta + size_t *output_shape_p = reinterpret_cast(meta_ptr); + ptrdiff_t *output_strides_p = reinterpret_cast(output_shape_p + ndim); + size_t *input_shapes = reinterpret_cast(output_strides_p + ndim); + ptrdiff_t *input_strides = reinterpret_cast(input_shapes + input_size * ndim); + bool *input_contiguous = reinterpret_cast(input_strides + input_size * ndim); + bool *input_broadcasted = input_contiguous + input_size; - info.input_shapes[i] = new size_t[desc->ndim()]; - const auto &in_shape = desc->shape(); - std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i])); + // Copy output shape and strides + std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p)); + std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p)); - info.input_strides[i] = new ptrdiff_t[desc->ndim()]; - const auto &in_strides = desc->strides(); - std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i])); + // Copy input shapes, strides, contiguous, and broadcasted flags + for (size_t i = 0; i < input_size; ++i) { + auto &desc = input_descs[i]; + const auto in_shape = desc->shape(); + const auto in_strides = desc->strides(); + std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes)); + std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides)); + input_contiguous[i] = desc->isContiguous(); + input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim()); } + ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous); return ResultType(std::move(info)); } }; diff --git a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc index 8413d295a..6c283c02c 100644 --- a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc +++ b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc @@ -30,6 +30,8 @@ infiniStatus_t Descriptor::create( } infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, void *output, std::vector inputs, void *stream) const { diff --git a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h index 67e42d2c6..65c1c7c33 100644 --- a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h +++ b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h @@ -10,7 +10,7 @@ typedef struct SwiGLUOp { private: template T sigmoid(const T &x) const { - return 1 / (1 + std::exp(-x)); + return T(1) / (T(1) + std::exp(-x)); } public: diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu index 5b0e8cee6..b6ab533d8 100644 --- a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu @@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create( const auto &gate_shape = gate_desc->shape(); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, 
INFINI_DTYPE_F64); - if (!SAME_VEC(out_shape, up_shape, gate_shape)) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } + CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape); // create CUDA elementwise descriptor CREATE_ELEMENTWISE_CUDA_DESCRIPTOR @@ -32,17 +30,23 @@ infiniStatus_t Descriptor::create( } infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, void *output, std::vector inputs, void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream); + return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream); + return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F64: - return _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream); + return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/swiglu/operator.cc b/src/infiniop/ops/swiglu/operator.cc index de3ecb874..3f90882c1 100644 --- a/src/infiniop/ops/swiglu/operator.cc +++ b/src/infiniop/ops/swiglu/operator.cc @@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( #undef CREATE } +__C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_CUDA_API + GET(INFINI_DEVICE_NVIDIA, cuda) +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size); + } +#endif +#ifdef ENABLE_ASCEND_API + GET(INFINI_DEVICE_ASCEND, ascend) +#endif +#ifdef ENABLE_METAX_GPU + case DevMetaxGpu: { + return macaGetSwiGLUWorkspaceSize((SwiGLUMacaDescriptor_t)desc, size); + } +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return musaGetSwiGLUWorkspaceSize((SwiGLUMusaDescriptor_t)desc, size); + } +#endif + } + +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + __C infiniStatus_t infiniopSwiGLU( infiniopSwiGLUDescriptor_t desc, + void *workspace, + size_t workspace_size, void *c, const void *a, const void *b, @@ -76,7 +117,7 @@ __C infiniStatus_t infiniopSwiGLU( #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast(desc) \ - ->calculate(c, {a, b}, stream) + ->calculate(workspace, workspace_size, c, {a, b}, stream) switch (desc->device_type) { diff --git a/test/infiniop/swiglu.py b/test/infiniop/swiglu.py index 09649af87..01d6f9612 100644 --- a/test/infiniop/swiglu.py +++ b/test/infiniop/swiglu.py @@ -1,6 +1,6 @@ import torch import ctypes -from ctypes import POINTER, Structure, c_int32, c_void_p +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64 from libinfiniop import ( infiniopHandle_t, infiniopTensorDescriptor_t, @@ -14,6 +14,7 @@ debug, get_tolerance, profile_operation, + create_workspace ) from enum import Enum, auto @@ -160,10 +161,19 @@ def test( for tensor in [a_tensor, b_tensor, c_tensor]: tensor.destroyDesc(lib) + workspace_size = c_uint64(0) + check_error( + 
lib.infiniopGetSwiGLUWorkspaceSize(descriptor, ctypes.byref(workspace_size)) + ) + workspace = create_workspace(workspace_size.value, c.device) + def lib_swiglu(): check_error( lib.infiniopSwiGLU( - descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + c_tensor.data, a_tensor.data, b_tensor.data, None ) ) @@ -196,10 +206,18 @@ def lib_swiglu(): infiniopTensorDescriptor_t, ] + lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32 + lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [ + infiniopSwiGLUDescriptor_t, + POINTER(c_uint64), + ] + lib.infiniopSwiGLU.restype = c_int32 lib.infiniopSwiGLU.argtypes = [ infiniopSwiGLUDescriptor_t, c_void_p, + c_uint64, + c_void_p, c_void_p, c_void_p, c_void_p, From 1d182fba55ed50a544df8e4ef34af7eaf5bd4f33 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 14 Apr 2025 17:00:18 +0800 Subject: [PATCH 09/14] issue/127: Optimize elementwise CUDA code by removing redundancy, change/correct kernel logic when all inputs have the same dtype --- .../elementwise/cuda/elementwise_cuda.cuh | 350 ++++++++---------- 1 file changed, 164 insertions(+), 186 deletions(-) diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh index de791d548..ce8224bd4 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -9,16 +9,74 @@ namespace op::elementwise::cuda { /** - * @brief Helper device function to expand a compile-time index sequence into individual constants - * and pass them to a lambda. + * @brief Casts an untyped device pointer to a typed pointer of type T. * - * @tparam Lambda Type of the lambda function to invoke. - * @tparam Is Index sequence values (automatically deduced). - * @param lambda Lambda to be called with std::integral_constant... as arguments. + * @tparam T Desired pointer type. + * @param ptr Untyped pointer. + * @return Pointer of type const T*. */ -template -__device__ __forceinline__ void callExpand(Lambda lambda, std::index_sequence) { - lambda(std::integral_constant{}...); +template +__device__ __forceinline__ const T *typedInputPtr(const void *ptr) { + return reinterpret_cast(ptr); +} + +/** + * @brief Computes the output index in memory, accounting for strides if non-contiguous. + * + * @param idx Linear index. + * @param is_contiguous Whether the output tensor is contiguous. + * @param ndim Number of dimensions. + * @param shape Shape of the output tensor. + * @param strides Strides of the output tensor. + * @return Memory offset index. + */ +__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim, + const size_t *shape, const ptrdiff_t *strides) { + return is_contiguous ? idx : device::cuda::indexToOffset(idx, ndim, shape, strides); +} + +/** + * @brief Computes input element offset for broadcasting and strided access. + * + * Used to map a linear output index to the corresponding index in an input tensor, + * considering contiguity and broadcasting. + */ +struct InputIndexer { + size_t idx; + size_t ndim; + const bool *input_contiguous; + const bool *input_broadcasted; + const size_t *input_shapes; + const ptrdiff_t *input_strides; + const ptrdiff_t *output_strides; + + /** + * @brief Computes the memory offset for a given input tensor at current index. + * + * @param input_id ID of the input tensor. + * @return Offset into the input tensor. 
+ */ + __device__ __forceinline__ size_t operator()(size_t input_id) const { + return input_contiguous[input_id] + ? idx + : (input_broadcasted[input_id] + ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) + : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); + } +}; + +/** + * @brief Invokes a callable with compile-time index constants. + * + * Used to unpack index sequence for variadic template processing of inputs. + * + * @tparam F Callable type. + * @tparam Is Compile-time index sequence. + * @param f Callable to invoke with index constants. + */ +template +__device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence) { + f(std::integral_constant{}...); } /** @@ -54,96 +112,25 @@ INFINIOP_CUDA_KERNEL elementwiseKernel( const ptrdiff_t *__restrict__ output_strides, const ptrdiff_t *__restrict__ input_strides, Tdata *output, - const Tdata *const *inputs, + const void *const *inputs, size_t offset, Args... args) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + if (idx < output_size) { - size_t out_idx = output_contiguous ? idx - : device::cuda::indexToOffset(idx, ndim, output_shape, output_strides); - - auto get_input_idx = [&] __device__(size_t input_id) { - return input_contiguous[input_id] ? idx - : (input_broadcasted[input_id] - ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) - : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); - }; - - // Use a helper to expand the index sequence into individual compile-time constants - auto expand_inputs = [&] __device__(auto... idxs) { - if constexpr (std::is_same_v) { - output[out_idx] = utils::cast( - Op{}(utils::cast(inputs[idxs.value][get_input_idx(idxs.value)])..., - std::forward(args)...)); - } else { - output[out_idx] = Op{}( - inputs[idxs.value][get_input_idx(idxs.value)]..., - std::forward(args)...); - } - }; - - callExpand(expand_inputs, std::make_index_sequence{}); + const Tdata *const *typed_inputs = reinterpret_cast(inputs); + size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides); + InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides}; + + unpackInputsAndApply( + [&](auto... Is) { + output[out_idx] = Op{}(typed_inputs[Is.value][indexer(Is.value)]..., std::forward(args)...); + }, + std::make_index_sequence{}); } } -/** - * @brief Casts an untyped device pointer to a typed pointer of type T. - * - * @tparam T Desired pointer type. - * @param ptr Untyped pointer. - * @return Pointer of type const T*. - */ -template -__device__ inline const T *typedInputPtr(const void *ptr) { - return reinterpret_cast(ptr); -} - -/** - * @brief Launches elementwise operation at a specific output index. - * - * @tparam Op Functor representing the elementwise operation. - * @tparam Tout Output data type. - * @tparam Tin... Input data types. - * @tparam Is... Index sequence for unpacking variadic inputs. - * @param idx Global linear index into the output tensor. - * @param out_idx Offset into the output array. - * @param ndim Number of dimensions in the tensors. - * @param input_contiguous Flags indicating whether each input is contiguous. - * @param input_broadcasted Flags indicating whether each input is broadcasted. - * @param input_shapes Flattened input shapes (N * ndim). 
- * @param input_strides Flattened input strides (N * ndim). - * @param output_strides Output tensor strides. - * @param inputs Array of pointers to input tensors. - * @param output Pointer to output tensor. - * @param ...Is Index sequence for iterating over input tensors. - */ -template -__device__ void launchOp( - size_t idx, - size_t out_idx, - size_t ndim, - const bool *__restrict__ input_contiguous, - const bool *__restrict__ input_broadcasted, - const size_t *__restrict__ input_shapes, - const ptrdiff_t *__restrict__ input_strides, - const ptrdiff_t *__restrict__ output_strides, - const void *const *__restrict__ inputs, - Tout *output, - std::index_sequence) { - - auto get_input_idx = [&] __device__(size_t input_id) { - return input_contiguous[input_id] - ? idx - : (input_broadcasted[input_id] - ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) - : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); - }; - - output[out_idx] = Op{}.template operator()( - (typedInputPtr(inputs[Is])[get_input_idx(Is)])...); -} - /** * @brief CUDA kernel for performing an elementwise operation on tensors with support * for broadcasting and mixed data types. @@ -180,26 +167,18 @@ INFINIOP_CUDA_KERNEL elementwiseKernel( size_t offset) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; - if (idx >= output_size) { - return; - } - size_t out_idx = output_contiguous - ? idx - : device::cuda::indexToOffset(idx, ndim, output_shape, output_strides); - - launchOp( - idx, - out_idx, - ndim, - input_contiguous, - input_broadcasted, - input_shapes, - input_strides, - output_strides, - inputs, - output, - std::index_sequence_for{}); + if (idx < output_size) { + size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides); + InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides}; + + unpackInputsAndApply( + [&](auto... 
Is) { + output[out_idx] = Op{}.template operator()( + (typedInputPtr(inputs[Is.value])[indexer(Is.value)])...); + }, + std::index_sequence_for{}); + } } struct DeviceImpl::Opaque { @@ -231,45 +210,12 @@ struct DeviceImpl::Opaque { const std::vector &inputs, cudaStream_t stream, Args &&...args) { - auto output_size = info.getOutputSize(); - if (output_size == 0) { - return INFINI_STATUS_SUCCESS; - } - - // casting the output and the inputs to Tdata pointers - Tdata *out = reinterpret_cast(output); - const void **d_inputs_arr = nullptr; - - // create and send the info to device - const bool *d_input_contiguous = nullptr; - const bool *d_input_broadcasted = nullptr; - const size_t *d_output_shape = nullptr; - const ptrdiff_t *d_output_strides = nullptr; - const size_t *d_input_shapes = nullptr; - const ptrdiff_t *d_input_strides = nullptr; - - CHECK_STATUS(infoToDevice(info, workspace, inputs.data(), d_inputs_arr, d_input_contiguous, d_input_broadcasted, - d_output_shape, d_output_strides, d_input_shapes, d_input_strides, stream)); - - dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); - dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast(internal->gridSizeX()))); - size_t step = gridDims.x * blockDims.x; - - for (size_t i = 0; i < output_size; i += step) { - elementwiseKernel<<>>( - output_size, - info.getNdim(), - info.isOutputContiguous(), - d_input_contiguous, - d_input_broadcasted, - d_output_shape, - d_input_shapes, - d_output_strides, - d_input_strides, - out, reinterpret_cast(d_inputs_arr), i, std::forward(args)...); - } - - return INFINI_STATUS_SUCCESS; + return launchElementwiseKernel( + info, workspace, + reinterpret_cast(output), inputs, + elementwiseKernel, + stream, + std::forward(args)...); } /** @@ -297,44 +243,12 @@ struct DeviceImpl::Opaque { const std::vector &inputs, cudaStream_t stream, Args &&...args) { - auto output_size = info.getOutputSize(); - if (output_size == 0) { - return INFINI_STATUS_SUCCESS; - } - - Tout *out = reinterpret_cast(output); - const void **d_inputs_arr = nullptr; - - // Device pointers - const bool *d_input_contiguous = nullptr; - const bool *d_input_broadcasted = nullptr; - const size_t *d_output_shape = nullptr; - const ptrdiff_t *d_output_strides = nullptr; - const size_t *d_input_shapes = nullptr; - const ptrdiff_t *d_input_strides = nullptr; - - CHECK_STATUS(infoToDevice(info, workspace, inputs.data(), d_inputs_arr, d_input_contiguous, d_input_broadcasted, - d_output_shape, d_output_strides, d_input_shapes, d_input_strides, stream)); - - dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); - dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast(internal->gridSizeX()))); - size_t step = gridDims.x * blockDims.x; - - for (size_t i = 0; i < output_size; i += step) { - elementwiseKernel<<>>( - output_size, - info.getNdim(), - info.isOutputContiguous(), - d_input_contiguous, - d_input_broadcasted, - d_output_shape, - d_input_shapes, - d_output_strides, - d_input_strides, - out, reinterpret_cast(d_inputs_arr), i); - } - - return INFINI_STATUS_SUCCESS; + return launchElementwiseKernel( + info, workspace, + reinterpret_cast(output), inputs, + elementwiseKernel, + stream, + std::forward(args)...); } private: @@ -390,6 +304,70 @@ private: return INFINI_STATUS_SUCCESS; } + + /** + * @brief Launches the elementwise kernel for the specified operation. + * + * @tparam BLOCK_SIZE Number of threads per block. + * @tparam N Number of input tensors. 
+ * @tparam KernelFunc Type of the kernel function pointer. + * @tparam Tout Output data type. + * @tparam Args Additional arguments to be forwarded to the kernel. + * + * @param info Metadata about the elementwise operation (shapes, strides, etc.). + * @param workspace CUDA memory used for storing metadata. + * @param output Pointer to output buffer on device. + * @param inputs Vector of device pointers to input tensors. + * @param kernel_func Kernel function to launch. + * @param stream CUDA stream for asynchronous execution. + * @param args Additional arguments passed to the kernel. + * @return infiniStatus_t Status code indicating success or failure. + */ + template + infiniStatus_t launchElementwiseKernel( + const op::elementwise::ElementwiseInfo &info, + void *workspace, + Tout *output, + const std::vector &inputs, + KernelFunc kernel_func, + cudaStream_t stream, + Args &&...args) { + + auto output_size = info.getOutputSize(); + if (output_size == 0) { + return INFINI_STATUS_SUCCESS; + } + + // Device pointers + const void **d_inputs_arr = nullptr; + const bool *d_input_contiguous = nullptr; + const bool *d_input_broadcasted = nullptr; + const size_t *d_output_shape = nullptr; + const ptrdiff_t *d_output_strides = nullptr; + const size_t *d_input_shapes = nullptr; + const ptrdiff_t *d_input_strides = nullptr; + + CHECK_STATUS(infoToDevice(info, workspace, inputs.data(), d_inputs_arr, + d_input_contiguous, d_input_broadcasted, + d_output_shape, d_output_strides, + d_input_shapes, d_input_strides, stream)); + + dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); + dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast(internal->gridSizeX()))); + size_t step = gridDims.x * blockDims.x; + + for (size_t i = 0; i < output_size; i += step) { + kernel_func<<>>( + output_size, info.getNdim(), info.isOutputContiguous(), + d_input_contiguous, d_input_broadcasted, + d_output_shape, d_input_shapes, + d_output_strides, d_input_strides, + output, reinterpret_cast(d_inputs_arr), + i, std::forward(args)...); + } + + return INFINI_STATUS_SUCCESS; + } }; template From d54312d0bd70de678a6b20567ba1e0361b576f12 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 14 Apr 2025 17:22:00 +0800 Subject: [PATCH 10/14] issue/127: remove the args forward in the calculateImpl() that handles different dtypes --- src/infiniop/elementwise/cuda/elementwise_cuda.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh index ce8224bd4..34ed8c0c2 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -247,8 +247,7 @@ struct DeviceImpl::Opaque { info, workspace, reinterpret_cast(output), inputs, elementwiseKernel, - stream, - std::forward(args)...); + stream); } private: From 82adff3d87c4983bc99422ba0a7bc1762c4bfecc Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 14 Apr 2025 17:39:06 +0800 Subject: [PATCH 11/14] issue/127: add a blank line between @tparam and @param for comments in cuda elementwise files --- src/infiniop/elementwise/cuda/elementwise_cuda.cuh | 7 +++++++ src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh | 1 + 2 files changed, 8 insertions(+) diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh index 34ed8c0c2..8d66d74c0 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh +++ 
b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -12,6 +12,7 @@ namespace op::elementwise::cuda { * @brief Casts an untyped device pointer to a typed pointer of type T. * * @tparam T Desired pointer type. + * * @param ptr Untyped pointer. * @return Pointer of type const T*. */ @@ -72,6 +73,7 @@ struct InputIndexer { * * @tparam F Callable type. * @tparam Is Compile-time index sequence. + * * @param f Callable to invoke with index constants. */ template @@ -86,6 +88,7 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence< * @tparam Op Operator type implementing operator()(Tdata...). * @tparam Tdata Common data type for inputs and output. * @tparam Args Additional arguments to pass to the operator. + * * @param output_size Total number of output elements. * @param ndim Number of dimensions in tensors. * @param output_contiguous Whether the output tensor is contiguous in memory. @@ -138,6 +141,7 @@ INFINIOP_CUDA_KERNEL elementwiseKernel( * @tparam Op Operator type implementing a templated operator() for (Tout, Tin...). * @tparam Tout Output data type. * @tparam Tin Variadic input data types. + * * @param output_size Total number of output elements. * @param ndim Number of dimensions in the tensors. * @param output_contiguous Whether the output tensor is contiguous. @@ -195,6 +199,7 @@ struct DeviceImpl::Opaque { * @tparam Op Functor representing the elementwise operation. * @tparam Tdata Data type of both input and output tensors. * @tparam Args Optional additional arguments passed to the operation. + * * @param info Metadata about the operation including shape, size, and dimensionality. * @param workspace Temporary workspace used for storing metadata on device. * @param output Pointer to the output buffer. @@ -227,6 +232,7 @@ struct DeviceImpl::Opaque { * @tparam Tout Data type of the output tensor. * @tparam Tin... Data types of the input tensors. * @tparam Args Optional additional arguments passed to the operation.(UNUSED) + * * @param info Metadata about the operation including shape, size, and dimensionality. * @param workspace Temporary workspace used for storing metadata on device. * @param output Pointer to the output buffer. @@ -255,6 +261,7 @@ private: * @brief Transfers elementwise operation metadata and input pointers from host to device memory. * * @tparam N Number of input tensors. + * * @param info Elementwise operation metadata (shapes, strides, flags, etc.). * @param workspace Pointer to device workspace memory for storing metadata and input pointers. * @param h_inputs_arr Host array of input tensor pointers. diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh index eec8b3511..70a90d308 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh @@ -57,6 +57,7 @@ public: * @tparam Tout Output data type. * @tparam Tin... Input data types (must match Op::num_inputs). * @tparam Args... Additional arguments passed to the operation. + * * @param info Metadata describing tensor shapes, strides, etc. * @param workspace Pointer to workspace buffer on device. * @param output Pointer to output buffer on device. 
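The unpackInputsAndApply()/InputIndexer pair introduced in this series is the usual index-sequence expansion, which can be easier to read in host code. The sketch below is an analogy only: AddOp and the two small arrays are made up for illustration, and only the expansion pattern corresponds to what the kernels do.

    #include <array>
    #include <cstddef>
    #include <utility>

    struct AddOp {  // stand-in two-input operator, not part of the library
        template <typename T>
        T operator()(T a, T b) const { return a + b; }
    };

    // Expands a compile-time index sequence so each input is indexed (and, in the
    // mixed-dtype kernel, typed) independently, mirroring
    // typed_inputs[Is.value][indexer(Is.value)]... in the device code.
    template <typename Op, typename T, std::size_t N, std::size_t... Is>
    T applyAt(const std::array<const T *, N> &ins, std::size_t elem, std::index_sequence<Is...>) {
        return Op{}(ins[Is][elem]...);
    }

    int main() {
        float a[3] = {1, 2, 3}, b[3] = {10, 20, 30};
        std::array<const float *, 2> ins = {a, b};
        // Expands to AddOp{}(ins[0][1], ins[1][1]) == 22.
        return applyAt<AddOp>(ins, 1, std::make_index_sequence<2>{}) == 22.0f ? 0 : 1;
    }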
From 256a5e3a260cd12eda0beacd10ee71bc887772d4 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 15 Apr 2025 14:54:41 +0800 Subject: [PATCH 12/14] issue/127: Add arguments to CREATE_ELEMENTWISE_PLATFORM_DESCRIPTOR macros for indirecting variable names, change DeviceImpl to use Result for the return type of the create function, change CEIL_DIV --- .../elementwise/cpu/elementwise_cpu.h | 36 +++++++++-------- .../elementwise/cuda/elementwise_cuda.cuh | 6 +-- .../elementwise/cuda/elementwise_cuda_api.cuh | 39 +++++++++++-------- src/infiniop/elementwise/elementwise.h | 2 +- src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc | 8 ++-- src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu | 8 ++-- src/utils.h | 2 +- 7 files changed, 53 insertions(+), 48 deletions(-) diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h index 880ce027b..e8d71a35a 100644 --- a/src/infiniop/elementwise/cpu/elementwise_cpu.h +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -8,18 +8,23 @@ /** * @brief Define the process for initializing a Descriptor of an elementwise operation * for its CPU implementation + * + * @param handle The device handle. + * @param dtype The output dtype. + * @param out_desc The output tensor descriptor. + * @param input_desc_vec A vector containing input tensor descriptors. */ -#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \ - \ - auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \ - CHECK_RESULT(info_result); \ - \ - *desc_ptr = new Descriptor( \ - dtype, \ - info_result.take(), \ - nullptr, \ - 0, \ - handle->device, \ +#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) \ + \ + auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc_vec); \ + CHECK_RESULT(info_result); \ + \ + *desc_ptr = new Descriptor( \ + dtype, \ + info_result.take(), \ + nullptr, \ + 0, \ + handle->device, \ handle->device_id); namespace op::elementwise::cpu { @@ -41,9 +46,7 @@ class DeviceImpl final { ~DeviceImpl() = default; template - static infiniStatus_t create( - DeviceImpl **device_info, - Args &&...args); + static utils::Result create(Args &&...args); /** * @brief Dispatches an elementwise operation with uniform input types. 
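The CEIL_DIV change further down in this patch is the standard macro-hygiene fix: without parentheses around each argument, passing an expression silently changes the arithmetic. A small self-contained check, with the two expansions copied from src/utils.h (the _OLD/_NEW names and the argument values are only for the side-by-side):

    #include <cassert>

    #define CEIL_DIV_OLD(x, y) ((x + y - 1) / y)
    #define CEIL_DIV_NEW(x, y) (((x) + (y)-1) / (y))

    int main() {
        // Intended result: ceil(9 / 6.0) == 2. The old form expands to
        // ((7 + 2 + 2 + 4 - 1) / 2 + 4), i.e. 7 + 4 == 11.
        assert(CEIL_DIV_OLD(7 + 2, 2 + 4) == 11);
        assert(CEIL_DIV_NEW(7 + 2, 2 + 4) == 2);
        return 0;
    }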
@@ -98,9 +101,8 @@ class DeviceImpl final { struct DeviceImpl::Opaque {}; template -infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) { - *device_info = new DeviceImpl(nullptr); - return INFINI_STATUS_SUCCESS; +utils::Result DeviceImpl::create(Args &&...args) { + return utils::Result(nullptr); } // Perform elementwise operation for different input types diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh index 8d66d74c0..6f99200db 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -377,11 +377,9 @@ private: }; template -infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, - Args &&...args) { +utils::Result DeviceImpl::create(Args &&...args) { auto opaque = std::make_shared(std::forward(args)...); - *device_info = new DeviceImpl(opaque); - return INFINI_STATUS_SUCCESS; + return utils::Result(new DeviceImpl(opaque)); } /* Invoke elementwise operation for different input types */ diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh index 70a90d308..8de1ab924 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh @@ -18,7 +18,7 @@ public: ~DeviceImpl() = default; template - static infiniStatus_t create(DeviceImpl **device_info, Args &&...args); + static utils::Result create(Args &&...args); /** * @brief Launches elementwise operation where all input types are the same. @@ -82,23 +82,28 @@ public: /** * @brief Define the process for initializing a Descriptor of an elementwise operation * for its CUDA implementation + * + * @param handle The device handle. + * @param dtype The output dtype. + * @param out_desc The output tensor descriptor. + * @param input_desc_vec A vector containing input tensor descriptors. 
*/ -#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \ - \ - auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \ - CHECK_RESULT(info_result); \ - auto info = info_result.take(); \ - auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \ - \ - op::elementwise::cuda::DeviceImpl *device_impl; \ - CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \ - \ - *desc_ptr = new Descriptor( \ - dtype, \ - std::move(info), \ - device_impl, \ - workspace_size, \ - handle->device, \ +#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) \ + \ + auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc_vec); \ + CHECK_RESULT(info_result); \ + auto info = info_result.take(); \ + auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \ + \ + auto device_impl_result = op::elementwise::cuda::DeviceImpl::create(handle->internal()); \ + CHECK_RESULT(device_impl_result); \ + \ + *desc_ptr = new Descriptor( \ + dtype, \ + std::move(info), \ + std::move(device_impl_result.take()), \ + workspace_size, \ + handle->device, \ handle->device_id); #endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__ diff --git a/src/infiniop/elementwise/elementwise.h b/src/infiniop/elementwise/elementwise.h index 9feb8dd74..df794d9c7 100644 --- a/src/infiniop/elementwise/elementwise.h +++ b/src/infiniop/elementwise/elementwise.h @@ -31,7 +31,7 @@ : InfiniopDescriptor{device_type, device_id}, \ _dtype(dtype), \ _info(std::move(info)), \ - _device_info(device_info), \ + _device_info(std::move(device_info)), \ _workspace_size(workspace_size) {} \ \ public: \ diff --git a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc index 6c283c02c..9b5b191b4 100644 --- a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc +++ b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc @@ -8,13 +8,13 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle_, Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, - std::vector input_desc) { + std::vector input_desc_vec) { auto handle = reinterpret_cast(handle_); auto dtype = out_desc->dtype(); - const auto &up_desc = input_desc.at(0); - const auto &gate_desc = input_desc.at(1); + const auto &up_desc = input_desc_vec.at(0); + const auto &gate_desc = input_desc_vec.at(1); const auto &out_shape = out_desc->shape(); const auto &up_shape = up_desc->shape(); const auto &gate_shape = gate_desc->shape(); @@ -24,7 +24,7 @@ infiniStatus_t Descriptor::create( CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape); // create CPU elementwise descriptor - CREATE_ELEMENTWISE_CPU_DESCRIPTOR; + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); return INFINI_STATUS_SUCCESS; } diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu index b6ab533d8..d1de22ed5 100644 --- a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu @@ -9,13 +9,13 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle_, Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, - std::vector input_desc) { + std::vector input_desc_vec) { auto handle = reinterpret_cast(handle_); auto dtype = out_desc->dtype(); - const auto &up_desc = input_desc.at(0); - const auto &gate_desc = input_desc.at(1); + const auto &up_desc = input_desc_vec.at(0); + const auto &gate_desc = input_desc_vec.at(1); const auto &out_shape = out_desc->shape(); const auto 
&up_shape = up_desc->shape(); const auto &gate_shape = gate_desc->shape(); @@ -24,7 +24,7 @@ infiniStatus_t Descriptor::create( CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape); // create CUDA elementwise descriptor - CREATE_ELEMENTWISE_CUDA_DESCRIPTOR + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) return INFINI_STATUS_SUCCESS; } diff --git a/src/utils.h b/src/utils.h index fa5469584..25ba3745f 100644 --- a/src/utils.h +++ b/src/utils.h @@ -98,6 +98,6 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) { } } -#define CEIL_DIV(x, y) ((x + y - 1) / y) +#define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) #endif From fe9c4aa5abcdc40c7505c084b70c248c3bcf229e Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 15 Apr 2025 16:16:13 +0800 Subject: [PATCH 13/14] issue/127: capitalize the arguments of the CREATE_ELEMENTWISE_PLATFORM_DESCRIPTOR marcos --- .../elementwise/cpu/elementwise_cpu.h | 18 ++++++++--------- .../elementwise/cuda/elementwise_cuda_api.cuh | 20 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h index e8d71a35a..25357b689 100644 --- a/src/infiniop/elementwise/cpu/elementwise_cpu.h +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -9,23 +9,23 @@ * @brief Define the process for initializing a Descriptor of an elementwise operation * for its CPU implementation * - * @param handle The device handle. - * @param dtype The output dtype. - * @param out_desc The output tensor descriptor. - * @param input_desc_vec A vector containing input tensor descriptors. + * @param HANDLE The device handle. + * @param DTYPE The output dtype. + * @param OUT_DESC The output tensor descriptor. + * @param INPUT_DESC_VEC A vector containing input tensor descriptors. */ -#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) \ +#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \ \ - auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc_vec); \ + auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \ CHECK_RESULT(info_result); \ \ *desc_ptr = new Descriptor( \ - dtype, \ + DTYPE, \ info_result.take(), \ nullptr, \ 0, \ - handle->device, \ - handle->device_id); + HANDLE->device, \ + HANDLE->device_id); namespace op::elementwise::cpu { diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh index 8de1ab924..67223fd85 100644 --- a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh +++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh @@ -83,27 +83,27 @@ public: * @brief Define the process for initializing a Descriptor of an elementwise operation * for its CUDA implementation * - * @param handle The device handle. - * @param dtype The output dtype. - * @param out_desc The output tensor descriptor. - * @param input_desc_vec A vector containing input tensor descriptors. + * @param HANDLE The device handle. + * @param DTYPE The output dtype. + * @param OUT_DESC The output tensor descriptor. + * @param INPUT_DESC_VEC A vector containing input tensor descriptors. 
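Note on the CEIL_DIV change above: without parentheses around the macro parameters, any call whose arguments are compound expressions expands incorrectly. A minimal illustration of the failure mode (the numbers are made up for this example and do not come from the patches):

    #include <cassert>

    // Unparenthesized form, as it was before this patch.
    #define CEIL_DIV_OLD(x, y) ((x + y - 1) / y)
    // Parenthesized form introduced by this patch.
    #define CEIL_DIV_NEW(x, y) (((x) + (y)-1) / (y))

    int main() {
        // CEIL_DIV_OLD(10, 2 + 2) expands to ((10 + 2 + 2 - 1) / 2 + 2),
        // i.e. 13 / 2 + 2 == 8, instead of the intended ceil(10 / 4) == 3.
        assert(CEIL_DIV_OLD(10, 2 + 2) == 8);
        assert(CEIL_DIV_NEW(10, 2 + 2) == 3);
        return 0;
    }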
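This patch also threads utils::Result<T> through DeviceImpl::create without showing the type itself. Judging only from how it is used here (CHECK_RESULT to propagate failure, take() to move the value out), a minimal stand-in could look like the sketch below. This is an assumption for illustration, not the project's actual utils::Result or CHECK_RESULT definition:

    #include <utility>

    // Hypothetical status-or-value wrapper, modeled on the call sites in this
    // patch. infiniStatus_t and INFINI_STATUS_SUCCESS come from the project's
    // public headers.
    template <typename T>
    class Result {
        infiniStatus_t _status;
        T _value;

    public:
        Result(T value) : _status(INFINI_STATUS_SUCCESS), _value(std::move(value)) {}
        Result(infiniStatus_t status) : _status(status), _value{} {}

        explicit operator bool() const { return _status == INFINI_STATUS_SUCCESS; }
        infiniStatus_t status() const { return _status; }

        // Move the contained value out after a successful check.
        T take() { return std::move(_value); }
    };

    // Propagate the error status to the caller on failure, the way the
    // descriptor macros use it after ElementwiseInfo::create and DeviceImpl::create.
    #define CHECK_RESULT(RESULT)          \
        do {                              \
            if (!(RESULT)) {              \
                return (RESULT).status(); \
            }                             \
        } while (0)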
From fe9c4aa5abcdc40c7505c084b70c248c3bcf229e Mon Sep 17 00:00:00 2001
From: Zimin Li
Date: Tue, 15 Apr 2025 16:16:13 +0800
Subject: [PATCH 13/14] issue/127: capitalize the arguments of the CREATE_ELEMENTWISE_PLATFORM_DESCRIPTOR macros

---
 .../elementwise/cpu/elementwise_cpu.h         | 18 ++++++++---------
 .../elementwise/cuda/elementwise_cuda_api.cuh | 20 +++++++++----------
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h
index e8d71a35a..25357b689 100644
--- a/src/infiniop/elementwise/cpu/elementwise_cpu.h
+++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h
@@ -9,23 +9,23 @@
  * @brief Define the process for initializing a Descriptor of an elementwise operation
  * for its CPU implementation
  *
- * @param handle The device handle.
- * @param dtype The output dtype.
- * @param out_desc The output tensor descriptor.
- * @param input_desc_vec A vector containing input tensor descriptors.
+ * @param HANDLE The device handle.
+ * @param DTYPE The output dtype.
+ * @param OUT_DESC The output tensor descriptor.
+ * @param INPUT_DESC_VEC A vector containing input tensor descriptors.
  */
-#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) \
+#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
  \
-    auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc_vec); \
+    auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
     CHECK_RESULT(info_result); \
  \
     *desc_ptr = new Descriptor( \
-        dtype, \
+        DTYPE, \
         info_result.take(), \
         nullptr, \
         0, \
-        handle->device, \
-        handle->device_id);
+        HANDLE->device, \
+        HANDLE->device_id);
 
 namespace op::elementwise::cpu {
 
diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh
index 8de1ab924..67223fd85 100644
--- a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh
+++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh
@@ -83,27 +83,27 @@ public:
  * @brief Define the process for initializing a Descriptor of an elementwise operation
  * for its CUDA implementation
  *
- * @param handle The device handle.
- * @param dtype The output dtype.
- * @param out_desc The output tensor descriptor.
- * @param input_desc_vec A vector containing input tensor descriptors.
+ * @param HANDLE The device handle.
+ * @param DTYPE The output dtype.
+ * @param OUT_DESC The output tensor descriptor.
+ * @param INPUT_DESC_VEC A vector containing input tensor descriptors.
  */
-#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) \
+#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
  \
-    auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc_vec); \
+    auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
     CHECK_RESULT(info_result); \
     auto info = info_result.take(); \
     auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
  \
-    auto device_impl_result = op::elementwise::cuda::DeviceImpl::create(handle->internal()); \
+    auto device_impl_result = op::elementwise::cuda::DeviceImpl::create(HANDLE->internal()); \
     CHECK_RESULT(device_impl_result); \
  \
     *desc_ptr = new Descriptor( \
-        dtype, \
+        DTYPE, \
         std::move(info), \
         std::move(device_impl_result.take()), \
         workspace_size, \
-        handle->device, \
-        handle->device_id);
+        HANDLE->device, \
+        HANDLE->device_id);
 
 #endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__

From da881f4d8bfb4edde7c3a77d7320ff305c2bffb4 Mon Sep 17 00:00:00 2001
From: Zimin Li
Date: Tue, 15 Apr 2025 17:24:26 +0800
Subject: [PATCH 14/14] issue/127: change meta within ElementwiseInfo to std::vector for correct alignment and change the reference name of the Opaque struct to Opaque instead of struct Opaque

---
 src/infiniop/elementwise/cpu/elementwise_cpu.h         |  2 +-
 src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh |  2 +-
 src/infiniop/elementwise/elementwise.h                 | 10 +++++-----
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h
index 25357b689..6e00bb998 100644
--- a/src/infiniop/elementwise/cpu/elementwise_cpu.h
+++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h
@@ -38,7 +38,7 @@ namespace op::elementwise::cpu {
  */
 class DeviceImpl final {
     struct Opaque;
-    std::shared_ptr<struct Opaque> _opaque;
+    std::shared_ptr<Opaque> _opaque;
 
     DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
 
diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh
index 67223fd85..2a9eaf25f 100644
--- a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh
+++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh
@@ -10,7 +10,7 @@ namespace op::elementwise::cuda {
  */
 class DeviceImpl final {
     struct Opaque;
-    std::shared_ptr<struct Opaque> _opaque;
+    std::shared_ptr<Opaque> _opaque;
 
     DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
 
diff --git a/src/infiniop/elementwise/elementwise.h b/src/infiniop/elementwise/elementwise.h
index df794d9c7..a43d30972 100644
--- a/src/infiniop/elementwise/elementwise.h
+++ b/src/infiniop/elementwise/elementwise.h
@@ -68,13 +68,13 @@ namespace op::elementwise {
  */
 struct ElementwiseInfo {
 private:
-    std::vector<int8_t> _meta;
+    std::vector _meta;
     size_t _output_size;
     size_t _input_size;
     size_t _ndim;
     bool _output_contiguous;
 
-    ElementwiseInfo(std::vector<int8_t> meta,
+    ElementwiseInfo(std::vector meta,
                     size_t output_size,
                     size_t input_size,
                     size_t ndim,
@@ -88,7 +88,7 @@ struct ElementwiseInfo {
         return _meta.size();
     }
     inline const int8_t *getMetaStart() const {
-        return _meta.data();
+        return reinterpret_cast<const int8_t *>(_meta.data());
     }
     inline size_t getOutputSize() const {
         return _output_size;
@@ -167,8 +167,8 @@ struct ElementwiseInfo {
                              + input_size * ndim * sizeof(shape_unit)
                              + input_size * ndim * sizeof(stride_unit)
                              + 2 * input_size * sizeof(bool);
-        std::vector<int8_t> meta(meta_mem_size);
-        int8_t *meta_ptr = meta.data();
+        std::vector meta(meta_mem_size);
+        int8_t *meta_ptr = reinterpret_cast<int8_t *>(meta.data());
         const auto output_shape = output_desc->shape();
         const auto output_strides = output_desc->strides();