diff --git a/include/infiniop/ops/swiglu.h b/include/infiniop/ops/swiglu.h index 7a74f6382..1d4d87e17 100644 --- a/include/infiniop/ops/swiglu.h +++ b/include/infiniop/ops/swiglu.h @@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc); +__C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size); + __C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, + void *workspace, + size_t workspace_size, void *c, void const *a, void const *b, diff --git a/src/infiniop/devices/cuda/cuda_kernel_common.cuh b/src/infiniop/devices/cuda/cuda_kernel_common.cuh index b3f52db01..68ef36c2a 100644 --- a/src/infiniop/devices/cuda/cuda_kernel_common.cuh +++ b/src/infiniop/devices/cuda/cuda_kernel_common.cuh @@ -9,6 +9,10 @@ #define CUDA_BLOCK_SIZE_1024 1024 #define CUDA_BLOCK_SIZE_512 512 +#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess) + +namespace device::cuda { + // return the memory offset of original tensor, given the flattened index of broadcasted tensor __forceinline__ __device__ __host__ size_t indexToReducedOffset( @@ -38,6 +42,7 @@ indexToOffset( } return res; } +} // namespace device::cuda #ifdef ENABLE_CUDA_API #include diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h new file mode 100644 index 000000000..6e00bb998 --- /dev/null +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -0,0 +1,201 @@ +#ifndef __INFINIOP_ELEMENTWISE_CPU_H__ +#define __INFINIOP_ELEMENTWISE_CPU_H__ + +#include "../../devices/cpu/common_cpu.h" +#include "../elementwise.h" +#include + +/** + * @brief Define the process for initializing a Descriptor of an elementwise operation + * for its CPU implementation + * + * @param HANDLE The device handle. + * @param DTYPE The output dtype. + * @param OUT_DESC The output tensor descriptor. + * @param INPUT_DESC_VEC A vector containing input tensor descriptors. + */ +#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \ + \ + auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \ + CHECK_RESULT(info_result); \ + \ + *desc_ptr = new Descriptor( \ + DTYPE, \ + info_result.take(), \ + nullptr, \ + 0, \ + HANDLE->device, \ + HANDLE->device_id); + +namespace op::elementwise::cpu { + +/** + * @brief CPU-specific device implementation for resource management and + * calculation implementations. + * + * This class encapsulates device-specific behavior and execution logic. + * Use the static create() method to instantiate a DeviceImpl. + */ +class DeviceImpl final { + struct Opaque; + std::shared_ptr _opaque; + + DeviceImpl(std::shared_ptr opaque) : _opaque(std::move(opaque)) {} + +public: + ~DeviceImpl() = default; + + template + static utils::Result create(Args &&...args); + + /** + * @brief Dispatches an elementwise operation with uniform input types. + * + * @tparam Op The elementwise operation to perform. + * @tparam Tdata The common data type of all inputs and output. + * @tparam Args Additional backend-specific arguments. + * @param info Precomputed tensor metadata (shapes, strides, etc.). + * @param output Pointer to the output tensor buffer. + * @param inputs Vector of input tensor data pointers. + * @param stream Device execution stream. + * @param args Additional backend-specific arguments. + * @return infiniStatus_t Status indicating success or failure. 
+ */ + template + infiniStatus_t calculate( + const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args); + + /** + * @brief Dispatches an elementwise operation with heterogeneous input types. + * + * Supports operations where each input may have a different type, as defined by Op. + * The number of input types must match the operation's expected input count. + * + * @tparam Op The elementwise operation to perform. + * @tparam Tout Output data type. + * @tparam Tin Variadic input data types. + * @tparam Args Additional backend-specific arguments. + * @param info Precomputed tensor metadata (shapes, strides, etc.). + * @param output Pointer to the output tensor buffer. + * @param inputs Vector of input tensor data pointers. + * @param stream Device execution stream. + * @param args Additional backend-specific arguments. + * @return infiniStatus_t Status indicating success or failure. + */ + template = 0> + infiniStatus_t calculate( + const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args); +}; + +// Define the Opaque struct for CPU, which is empty +struct DeviceImpl::Opaque {}; + +template +utils::Result DeviceImpl::create(Args &&...args) { + return utils::Result(nullptr); +} + +// Perform elementwise operation for different input types +template = 0> +void calculate_impl(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + std::index_sequence, + Args &&...args) { + + Tout *out = reinterpret_cast(output); + std::tuple input_ptrs = {reinterpret_cast(inputs[Is])...}; + ptrdiff_t output_size = info.getOutputSize(); + +#pragma omp parallel for + for (ptrdiff_t i = 0; i < output_size; ++i) { + size_t out_idx = info.isOutputContiguous() + ? i + : op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides()); + + auto get_input_idx = [&](size_t input_id) { + return info.getInputContiguous()[input_id] + ? i + : (info.getInputBroadcasted()[input_id] + ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id)) + : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id))); + }; + + out[out_idx] = utils::cast( + Op{}.template operator()(std::get(input_ptrs)[get_input_idx(Is)]..., std::forward(args)...)); + } +} + +// Invoke elementwise operation for different input types +template > +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { + + static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch"); + calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); + return INFINI_STATUS_SUCCESS; +} + +// Perform elementwise operation when all inputs have the same type +template +void calculate_impl(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + std::index_sequence, + Args &&...args) { + + Tdata *out = reinterpret_cast(output); + std::array ins = {reinterpret_cast(inputs[Is])...}; + const ptrdiff_t output_size = info.getOutputSize(); + +#pragma omp parallel for + for (ptrdiff_t i = 0; i < output_size; ++i) { + size_t out_idx = info.isOutputContiguous() + ? 
i + : op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides()); + + auto get_input_idx = [&](size_t input_id) { + return info.getInputContiguous()[input_id] + ? i + : (info.getInputBroadcasted()[input_id] + ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id)) + : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id))); + }; + + if constexpr (std::is_same_v) { + out[out_idx] = utils::cast(Op{}(utils::cast(ins[Is][get_input_idx(Is)])..., std::forward(args)...)); + } else { + out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward(args)...); + } + } +} + +// Invoke elementwise operation when all inputs have the same type +template +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { + constexpr size_t N = Op::num_inputs; + calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::elementwise::cpu + +#endif // __INFINIOP_ELEMENTWISE_CPU_H__ diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh new file mode 100644 index 000000000..6f99200db --- /dev/null +++ b/src/infiniop/elementwise/cuda/elementwise_cuda.cuh @@ -0,0 +1,419 @@ +#ifndef __INFINIOP_ELEMENTWISE_CUDA_H__ +#define __INFINIOP_ELEMENTWISE_CUDA_H__ + +#include "../../../utils.h" +#include "../../devices/cuda/cuda_common.cuh" +#include "../../devices/cuda/cuda_kernel_common.cuh" +#include "elementwise_cuda_api.cuh" + +namespace op::elementwise::cuda { + +/** + * @brief Casts an untyped device pointer to a typed pointer of type T. + * + * @tparam T Desired pointer type. + * + * @param ptr Untyped pointer. + * @return Pointer of type const T*. + */ +template +__device__ __forceinline__ const T *typedInputPtr(const void *ptr) { + return reinterpret_cast(ptr); +} + +/** + * @brief Computes the output index in memory, accounting for strides if non-contiguous. + * + * @param idx Linear index. + * @param is_contiguous Whether the output tensor is contiguous. + * @param ndim Number of dimensions. + * @param shape Shape of the output tensor. + * @param strides Strides of the output tensor. + * @return Memory offset index. + */ +__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim, + const size_t *shape, const ptrdiff_t *strides) { + return is_contiguous ? idx : device::cuda::indexToOffset(idx, ndim, shape, strides); +} + +/** + * @brief Computes input element offset for broadcasting and strided access. + * + * Used to map a linear output index to the corresponding index in an input tensor, + * considering contiguity and broadcasting. + */ +struct InputIndexer { + size_t idx; + size_t ndim; + const bool *input_contiguous; + const bool *input_broadcasted; + const size_t *input_shapes; + const ptrdiff_t *input_strides; + const ptrdiff_t *output_strides; + + /** + * @brief Computes the memory offset for a given input tensor at current index. + * + * @param input_id ID of the input tensor. + * @return Offset into the input tensor. + */ + __device__ __forceinline__ size_t operator()(size_t input_id) const { + return input_contiguous[input_id] + ? idx + : (input_broadcasted[input_id] + ? 
device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) + : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); + } +}; + +/** + * @brief Invokes a callable with compile-time index constants. + * + * Used to unpack index sequence for variadic template processing of inputs. + * + * @tparam F Callable type. + * @tparam Is Compile-time index sequence. + * + * @param f Callable to invoke with index constants. + */ +template +__device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence) { + f(std::integral_constant{}...); +} + +/** + * @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type. + * + * @tparam N Number of input tensors. + * @tparam Op Operator type implementing operator()(Tdata...). + * @tparam Tdata Common data type for inputs and output. + * @tparam Args Additional arguments to pass to the operator. + * + * @param output_size Total number of output elements. + * @param ndim Number of dimensions in tensors. + * @param output_contiguous Whether the output tensor is contiguous in memory. + * @param input_contiguous Array indicating if each input tensor is contiguous. + * @param input_broadcasted Array indicating if each input tensor is broadcasted. + * @param output_shape Shape of the output tensor. + * @param input_shapes Shapes of the input tensors. + * @param output_strides Strides for the output tensor. + * @param input_strides Strides for each input tensor. + * @param output Output buffer. + * @param inputs Array of input pointers, all of type Tdata. + * @param offset Linear offset to support partitioned execution. + * @param args Additional arguments passed to the operator. + */ +template +INFINIOP_CUDA_KERNEL elementwiseKernel( + size_t output_size, + size_t ndim, + bool output_contiguous, + const bool *__restrict__ input_contiguous, + const bool *__restrict__ input_broadcasted, + const size_t *__restrict__ output_shape, + const size_t *__restrict__ input_shapes, + const ptrdiff_t *__restrict__ output_strides, + const ptrdiff_t *__restrict__ input_strides, + Tdata *output, + const void *const *inputs, + size_t offset, + Args... args) { + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < output_size) { + const Tdata *const *typed_inputs = reinterpret_cast(inputs); + size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides); + InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides}; + + unpackInputsAndApply( + [&](auto... Is) { + output[out_idx] = Op{}(typed_inputs[Is.value][indexer(Is.value)]..., std::forward(args)...); + }, + std::make_index_sequence{}); + } +} + +/** + * @brief CUDA kernel for performing an elementwise operation on tensors with support + * for broadcasting and mixed data types. + * + * @tparam Op Operator type implementing a templated operator() for (Tout, Tin...). + * @tparam Tout Output data type. + * @tparam Tin Variadic input data types. + * + * @param output_size Total number of output elements. + * @param ndim Number of dimensions in the tensors. + * @param output_contiguous Whether the output tensor is contiguous. + * @param input_contiguous Array indicating whether each input is contiguous. + * @param input_broadcasted Array indicating whether each input is broadcasted. + * @param output_shape Shape of the output tensor. 
+ * @param input_shapes Shapes of the input tensors. + * @param output_strides Strides of the output tensor. + * @param input_strides Strides of the input tensors. + * @param output Pointer to the output buffer. + * @param inputs Array of untyped input pointers. + * @param offset Linear offset into the output for partitioned execution. + */ +template +INFINIOP_CUDA_KERNEL elementwiseKernel( + size_t output_size, + size_t ndim, + bool output_contiguous, + const bool *__restrict__ input_contiguous, + const bool *__restrict__ input_broadcasted, + const size_t *__restrict__ output_shape, + const size_t *__restrict__ input_shapes, + const ptrdiff_t *__restrict__ output_strides, + const ptrdiff_t *__restrict__ input_strides, + Tout *output, + const void *const *__restrict__ inputs, + size_t offset) { + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < output_size) { + size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides); + InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides}; + + unpackInputsAndApply( + [&](auto... Is) { + output[out_idx] = Op{}.template operator()( + (typedInputPtr(inputs[Is.value])[indexer(Is.value)])...); + }, + std::index_sequence_for{}); + } +} + +struct DeviceImpl::Opaque { + std::shared_ptr internal; + + Opaque(const std::shared_ptr &internal) + : internal(internal) {} + + /** + * @brief Executes an elementwise operation where all inputs and the output share the same data type. + * + * @tparam BLOCK_SIZE CUDA block size used for kernel launch. + * @tparam N Number of input tensors. + * @tparam Op Functor representing the elementwise operation. + * @tparam Tdata Data type of both input and output tensors. + * @tparam Args Optional additional arguments passed to the operation. + * + * @param info Metadata about the operation including shape, size, and dimensionality. + * @param workspace Temporary workspace used for storing metadata on device. + * @param output Pointer to the output buffer. + * @param inputs Vector of pointers to input buffers. + * @param stream CUDA stream for asynchronous execution. + * @param args Additional arguments forwarded to the operation. + * @return infiniStatus_t Returns success or failure status. + */ + template + infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info, + void *workspace, + void *output, + const std::vector &inputs, + cudaStream_t stream, + Args &&...args) { + return launchElementwiseKernel( + info, workspace, + reinterpret_cast(output), inputs, + elementwiseKernel, + stream, + std::forward(args)...); + } + + /** + * @brief Executes an elementwise operation with mixed input and output data types. + * + * @tparam BLOCK_SIZE CUDA block size used for kernel launch. + * @tparam N Number of input tensors. + * @tparam Op Functor representing the elementwise operation. + * @tparam Tout Data type of the output tensor. + * @tparam Tin... Data types of the input tensors. + * @tparam Args Optional additional arguments passed to the operation.(UNUSED) + * + * @param info Metadata about the operation including shape, size, and dimensionality. + * @param workspace Temporary workspace used for storing metadata on device. + * @param output Pointer to the output buffer. + * @param inputs Vector of pointers to input buffers. + * @param stream CUDA stream for asynchronous execution. + * @param args Additional arguments forwarded to the operation. 
+ * @return infiniStatus_t Returns success or failure status. + */ + template = 0> + infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info, + void *workspace, + void *output, + const std::vector &inputs, + cudaStream_t stream, + Args &&...args) { + return launchElementwiseKernel( + info, workspace, + reinterpret_cast(output), inputs, + elementwiseKernel, + stream); + } + +private: + /** + * @brief Transfers elementwise operation metadata and input pointers from host to device memory. + * + * @tparam N Number of input tensors. + * + * @param info Elementwise operation metadata (shapes, strides, flags, etc.). + * @param workspace Pointer to device workspace memory for storing metadata and input pointers. + * @param h_inputs_arr Host array of input tensor pointers. + * @param d_inputs_arr Output reference to device array of input tensor pointers. + * @param d_input_contiguous Output reference to device array indicating whether each input is contiguous. + * @param d_input_broadcasted Output reference to device array indicating whether each input is broadcasted. + * @param d_output_shape Output reference to device array holding the output tensor shape. + * @param d_output_strides Output reference to device array holding output tensor strides. + * @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim). + * @param d_input_strides Output reference to flattened input tensor strides (N * ndim). + * @param stream CUDA stream used for asynchronous memory transfer. + * @return infiniStatus_t Status indicating success or failure of the memory transfer and setup. + */ + template + infiniStatus_t infoToDevice( + const op::elementwise::ElementwiseInfo &info, + void *workspace, + const void *const *h_inputs_arr, + const void **&d_inputs_arr, + const bool *&d_input_contiguous, + const bool *&d_input_broadcasted, + const size_t *&d_output_shape, + const ptrdiff_t *&d_output_strides, + const size_t *&d_input_shapes, + const ptrdiff_t *&d_input_strides, + cudaStream_t stream) const { + + constexpr auto input_size = N; + const auto ndim = info.getNdim(); + constexpr auto input_arr_size = N * sizeof(*h_inputs_arr); + const int8_t *info_meta_start = info.getMetaStart(); + const int8_t *d_meta_start = reinterpret_cast(workspace) + input_arr_size; + + // copy the input pointer array and meta to device + CHECK_CUDA(cudaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), cudaMemcpyHostToDevice, stream)); + + // offset/assign the pointers + d_inputs_arr = reinterpret_cast(workspace); + d_output_shape = reinterpret_cast(d_meta_start); + d_output_strides = reinterpret_cast(d_output_shape + ndim); + d_input_shapes = reinterpret_cast(d_output_strides + ndim); + d_input_strides = reinterpret_cast(d_input_shapes + input_size * ndim); + d_input_contiguous = reinterpret_cast(d_input_strides + input_size * ndim); + d_input_broadcasted = reinterpret_cast(d_input_contiguous + input_size); + + return INFINI_STATUS_SUCCESS; + } + + /** + * @brief Launches the elementwise kernel for the specified operation. + * + * @tparam BLOCK_SIZE Number of threads per block. + * @tparam N Number of input tensors. + * @tparam KernelFunc Type of the kernel function pointer. + * @tparam Tout Output data type. + * @tparam Args Additional arguments to be forwarded to the kernel. + * + * @param info Metadata about the elementwise operation (shapes, strides, etc.). 
+ * @param workspace CUDA memory used for storing metadata. + * @param output Pointer to output buffer on device. + * @param inputs Vector of device pointers to input tensors. + * @param kernel_func Kernel function to launch. + * @param stream CUDA stream for asynchronous execution. + * @param args Additional arguments passed to the kernel. + * @return infiniStatus_t Status code indicating success or failure. + */ + template + infiniStatus_t launchElementwiseKernel( + const op::elementwise::ElementwiseInfo &info, + void *workspace, + Tout *output, + const std::vector &inputs, + KernelFunc kernel_func, + cudaStream_t stream, + Args &&...args) { + + auto output_size = info.getOutputSize(); + if (output_size == 0) { + return INFINI_STATUS_SUCCESS; + } + + // Device pointers + const void **d_inputs_arr = nullptr; + const bool *d_input_contiguous = nullptr; + const bool *d_input_broadcasted = nullptr; + const size_t *d_output_shape = nullptr; + const ptrdiff_t *d_output_strides = nullptr; + const size_t *d_input_shapes = nullptr; + const ptrdiff_t *d_input_strides = nullptr; + + CHECK_STATUS(infoToDevice(info, workspace, inputs.data(), d_inputs_arr, + d_input_contiguous, d_input_broadcasted, + d_output_shape, d_output_strides, + d_input_shapes, d_input_strides, stream)); + + dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); + dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast(internal->gridSizeX()))); + size_t step = gridDims.x * blockDims.x; + + for (size_t i = 0; i < output_size; i += step) { + kernel_func<<>>( + output_size, info.getNdim(), info.isOutputContiguous(), + d_input_contiguous, d_input_broadcasted, + d_output_shape, d_input_shapes, + d_output_strides, d_input_strides, + output, reinterpret_cast(d_inputs_arr), + i, std::forward(args)...); + } + + return INFINI_STATUS_SUCCESS; + } +}; + +template +utils::Result DeviceImpl::create(Args &&...args) { + auto opaque = std::make_shared(std::forward(args)...); + return utils::Result(new DeviceImpl(opaque)); +} + +/* Invoke elementwise operation for different input types */ +template > +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *workspace, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { + constexpr size_t N = Op::num_inputs; + static_assert(sizeof...(Tin) == N, "Input type count mismatch"); + return _opaque->calculateImpl( + info, workspace, output, inputs, + reinterpret_cast(stream), + std::forward(args)...); +} + +/* Invoke elementwise operation when all inputs have the same dtype */ +template +infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *workspace, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args) { + constexpr size_t N = Op::num_inputs; + return _opaque->calculateImpl( + info, workspace, output, inputs, + reinterpret_cast(stream), + std::forward(args)...); +} + +} // namespace op::elementwise::cuda + +#endif // __INFINIOP_ELEMENTWISE_CUDA_H__ diff --git a/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh new file mode 100644 index 000000000..2a9eaf25f --- /dev/null +++ b/src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh @@ -0,0 +1,109 @@ +#ifndef __INFINIOP_ELEMENTWISE_CUDA_API_H__ +#define __INFINIOP_ELEMENTWISE_CUDA_API_H__ + +#include "../elementwise.h" + +namespace op::elementwise::cuda { + +/** + * @brief Define the methods and info needed by CUDA to perform 
elementwise operation + */ +class DeviceImpl final { + struct Opaque; + std::shared_ptr _opaque; + + DeviceImpl(std::shared_ptr opaque) : _opaque(std::move(opaque)) {} + +public: + ~DeviceImpl() = default; + + template + static utils::Result create(Args &&...args); + + /** + * @brief Launches elementwise operation where all input types are the same. + * + * Calls the corresponding templated `calculateImpl` with a unified input type. + * + * @tparam BLOCK_SIZE Number of threads per block. + * @tparam Op Operation functor defining the computation. + * @tparam Tdata Data type for both input and output tensors. + * @tparam Args... Additional arguments passed to the operation. + * + * @param info Metadata describing tensor shapes, strides, etc. + * @param workspace Pointer to workspace buffer on device. + * @param output Pointer to output buffer on device. + * @param inputs Vector of input pointers (device memory). + * @param stream CUDA stream (opaque void*). + * @param args Additional operation-specific arguments. + * @return infiniStatus_t Status indicating success or failure. + */ + template + infiniStatus_t calculate( + const op::elementwise::ElementwiseInfo &info, + void *workspace, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args); + + /** + * @brief Launches elementwise operation where input types may differ. + * + * Dispatches to templated `calculateImpl` using specified output and input types. + * + * @tparam BLOCK_SIZE Number of threads per block. + * @tparam Op Operation functor defining the computation. + * @tparam Tout Output data type. + * @tparam Tin... Input data types (must match Op::num_inputs). + * @tparam Args... Additional arguments passed to the operation. + * + * @param info Metadata describing tensor shapes, strides, etc. + * @param workspace Pointer to workspace buffer on device. + * @param output Pointer to output buffer on device. + * @param inputs Vector of input pointers (device memory). + * @param stream CUDA stream (opaque void*). + * @param args (UNUSED) Additional operation-specific arguments. + * @return infiniStatus_t Status indicating success or failure. + */ + template = 0> + infiniStatus_t calculate( + const op::elementwise::ElementwiseInfo &info, + void *workspace, + void *output, + const std::vector &inputs, + void *stream, + Args &&...args); +}; +} // namespace op::elementwise::cuda + +/** + * @brief Define the process for initializing a Descriptor of an elementwise operation + * for its CUDA implementation + * + * @param HANDLE The device handle. + * @param DTYPE The output dtype. + * @param OUT_DESC The output tensor descriptor. + * @param INPUT_DESC_VEC A vector containing input tensor descriptors. 
+ */ +#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \ + \ + auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \ + CHECK_RESULT(info_result); \ + auto info = info_result.take(); \ + auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \ + \ + auto device_impl_result = op::elementwise::cuda::DeviceImpl::create(HANDLE->internal()); \ + CHECK_RESULT(device_impl_result); \ + \ + *desc_ptr = new Descriptor( \ + DTYPE, \ + std::move(info), \ + std::move(device_impl_result.take()), \ + workspace_size, \ + HANDLE->device, \ + HANDLE->device_id); + +#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__ diff --git a/src/infiniop/elementwise/elementwise.h b/src/infiniop/elementwise/elementwise.h new file mode 100644 index 000000000..a43d30972 --- /dev/null +++ b/src/infiniop/elementwise/elementwise.h @@ -0,0 +1,205 @@ +#ifndef __INFINIOP_ELEMENTWISE_H__ +#define __INFINIOP_ELEMENTWISE_H__ + +#include "../../utils.h" +#include "../operator.h" +#include "../tensor.h" +#include +#include +#include +#include +#include +#include +#include + +#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \ + \ + namespace op::OP::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + infiniDtype_t _dtype; \ + op::elementwise::ElementwiseInfo _info; \ + std::unique_ptr _device_info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + infiniDtype_t dtype, \ + op::elementwise::ElementwiseInfo info, \ + op::elementwise::NAMESPACE::DeviceImpl *device_info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _dtype(dtype), \ + _info(std::move(info)), \ + _device_info(std::move(device_info)), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t output_desc, \ + std::vector input_descs); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *output, \ + std::vector inputs, \ + void *stream) const; \ + }; \ + } + +namespace op::elementwise { + +/** + * @brief Stores the metadata required for performing an elementwise operation. + * + * This struct encapsulates shape, stride, and layout information for both + * output and multiple input tensors involved in an elementwise operation. + * + * Memory is manually managed and freed in the destructor. + * Supports move construction but disallows copy construction and copy/move assignment. + * + * Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors. 
+ */ +struct ElementwiseInfo { +private: + std::vector _meta; + size_t _output_size; + size_t _input_size; + size_t _ndim; + bool _output_contiguous; + + ElementwiseInfo(std::vector meta, + size_t output_size, + size_t input_size, + size_t ndim, + bool output_contiguous) + : _meta(std::move(meta)), _output_size(output_size), + _input_size(input_size), _ndim(ndim), + _output_contiguous(output_contiguous) {} + +public: + inline size_t getMetaMemSize() const { + return _meta.size(); + } + inline const int8_t *getMetaStart() const { + return reinterpret_cast(_meta.data()); + } + inline size_t getOutputSize() const { + return _output_size; + } + inline size_t getInputSize() const { + return _input_size; + } + inline size_t getNdim() const { + return _ndim; + } + inline bool isOutputContiguous() const { + return _output_contiguous; + } + inline const size_t *getOutputShape() const { + return reinterpret_cast(_meta.data()); + } + inline const ptrdiff_t *getOutputStrides() const { + return reinterpret_cast(getOutputShape() + _ndim); + } + inline const size_t *getAllInputShapes() const { + return reinterpret_cast(getOutputStrides() + _ndim); + } + inline const size_t *getInputShape(const size_t &index) const { + if (index < _input_size) { + return reinterpret_cast(getAllInputShapes() + index * _ndim); + } + return nullptr; + } + inline const ptrdiff_t *getAllInputStrides() const { + return reinterpret_cast(getAllInputShapes() + _input_size * _ndim); + } + inline const ptrdiff_t *getInputStrides(const size_t &index) const { + if (index < _input_size) { + return reinterpret_cast(getAllInputStrides() + index * _ndim); + } + return nullptr; + } + inline const bool *getInputContiguous() const { + return reinterpret_cast(getAllInputStrides() + _input_size * _ndim); + } + inline const bool *getInputBroadcasted() const { + return reinterpret_cast(getInputContiguous() + _input_size); + } + + using ResultType = utils::Result; + + /** + * @brief Construct ElementwiseInfo from output and input tensor descriptors. + * @param output_desc Descriptor of the output tensor. + * @param input_descs Descriptors of the input tensors. + * @return Result with the successfully constructed ElementwiseInfo, + * or the status code. 
+ */ + static ResultType create( + infiniopTensorDescriptor_t output_desc, + std::vector input_descs) { + + if (!output_desc || input_descs.empty()) { + return INFINI_STATUS_BAD_PARAM; + } + + // Destination cannot have broadcast setup + if (output_desc->hasBroadcastDim()) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + auto input_size = input_descs.size(); + auto ndim = output_desc->ndim(); + auto output_size = output_desc->numel(); + auto output_contiguous = output_desc->isContiguous(); + + // Allocate memory for meta + auto shape_unit = output_desc->dim(0); + auto stride_unit = output_desc->stride(0); + size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit)) + + input_size * ndim * sizeof(shape_unit) + + input_size * ndim * sizeof(stride_unit) + + 2 * input_size * sizeof(bool); + std::vector meta(meta_mem_size); + int8_t *meta_ptr = reinterpret_cast(meta.data()); + + const auto output_shape = output_desc->shape(); + const auto output_strides = output_desc->strides(); + + // Pointers to the sections within _meta + size_t *output_shape_p = reinterpret_cast(meta_ptr); + ptrdiff_t *output_strides_p = reinterpret_cast(output_shape_p + ndim); + size_t *input_shapes = reinterpret_cast(output_strides_p + ndim); + ptrdiff_t *input_strides = reinterpret_cast(input_shapes + input_size * ndim); + bool *input_contiguous = reinterpret_cast(input_strides + input_size * ndim); + bool *input_broadcasted = input_contiguous + input_size; + + // Copy output shape and strides + std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p)); + std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p)); + + // Copy input shapes, strides, contiguous, and broadcasted flags + for (size_t i = 0; i < input_size; ++i) { + auto &desc = input_descs[i]; + const auto in_shape = desc->shape(); + const auto in_strides = desc->strides(); + std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes)); + std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides)); + input_contiguous[i] = desc->isContiguous(); + input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim()); + } + + ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous); + return ResultType(std::move(info)); + } +}; +} // namespace op::elementwise + +#endif // __INFINIOP_ELEMENTWISE_H__ diff --git a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc index 9eb470aa7..9b5b191b4 100644 --- a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc +++ b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc @@ -8,11 +8,13 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle_, Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t up_desc, - infiniopTensorDescriptor_t gate_desc) { + std::vector input_desc_vec) { auto handle = reinterpret_cast(handle_); auto dtype = out_desc->dtype(); + + const auto &up_desc = input_desc_vec.at(0); + const auto &gate_desc = input_desc_vec.at(1); const auto &out_shape = out_desc->shape(); const auto &up_shape = up_desc->shape(); const auto &gate_shape = gate_desc->shape(); @@ -21,36 +23,26 @@ infiniStatus_t Descriptor::create( CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape); - op::binary::BinaryInfo info; - CHECK_STATUS(op::binary::createBinaryInfo(info, out_desc, up_desc, gate_desc)); - - // Create descriptor - *desc_ptr = new Descriptor( - dtype, - std::move(info), - nullptr, - handle->device, - 
handle->device_id); + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); return INFINI_STATUS_SUCCESS; } infiniStatus_t Descriptor::calculate( - void *c, - const void *a, - const void *b, + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, void *stream) const { switch (_dtype) { case INFINI_DTYPE_F16: - op::common_cpu::binary_op::calculate(_info, c, a, b); - break; + return _device_info->calculate(_info, output, inputs, stream); case INFINI_DTYPE_F32: - op::common_cpu::binary_op::calculate(_info, c, a, b); - break; + return _device_info->calculate(_info, output, inputs, stream); case INFINI_DTYPE_F64: - op::common_cpu::binary_op::calculate(_info, c, a, b); - break; + return _device_info->calculate(_info, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h index ac1eba6f1..65c1c7c33 100644 --- a/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h +++ b/src/infiniop/ops/swiglu/cpu/swiglu_cpu.h @@ -1,22 +1,25 @@ #ifndef __SWIGLU_CPU_H__ #define __SWIGLU_CPU_H__ -#include "../../../binary/cpu/binary_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu.h" -BINARY_DESCRIPTOR(swiglu, cpu) +ELEMENTWISE_DESCRIPTOR(swiglu, cpu) -struct SwiGLUOp { +namespace op::swiglu::cpu { +typedef struct SwiGLUOp { private: template T sigmoid(const T &x) const { - return 1 / (1 + std::exp(-x)); + return T(1) / (T(1) + std::exp(-x)); } public: + static constexpr size_t num_inputs = 2; template T operator()(const T &up, const T &gate) const { return gate * sigmoid(gate) * up; } -}; +} SwiGLUOp; +} // namespace op::swiglu::cpu #endif // __SWIGLU_CPU_H__ diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu new file mode 100644 index 000000000..d1de22ed5 --- /dev/null +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cu @@ -0,0 +1,56 @@ +#include "swiglu_cuda.cuh" +#include "swiglu_cuda_internal.cuh" + +namespace op::swiglu::cuda { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &up_desc = input_desc_vec.at(0); + const auto &gate_desc = input_desc_vec.at(1); + const auto &out_shape = out_desc->shape(); + const auto &up_shape = up_desc->shape(); + const auto &gate_shape = gate_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape); + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream); + default: + return 
INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::swiglu::cuda diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cuh b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cuh new file mode 100644 index 000000000..75e529ab1 --- /dev/null +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda.cuh @@ -0,0 +1,8 @@ +#ifndef __SWIGLU_CUDA_API_H__ +#define __SWIGLU_CUDA_API_H__ + +#include "../../../elementwise/cuda/elementwise_cuda_api.cuh" + +ELEMENTWISE_DESCRIPTOR(swiglu, cuda) + +#endif // __SWIGLU_CUDA_API_H__ diff --git a/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh b/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh new file mode 100644 index 000000000..d832f8110 --- /dev/null +++ b/src/infiniop/ops/swiglu/cuda/swiglu_cuda_internal.cuh @@ -0,0 +1,40 @@ +#ifndef __SWIGLU_CUDA_H__ +#define __SWIGLU_CUDA_H__ + +#include "../../../elementwise/cuda/elementwise_cuda.cuh" +#include + +namespace op::swiglu::cuda { +typedef struct SwiGLUOp { +private: + template + __device__ __forceinline__ T sigmoid(const T &x) const { + if constexpr (std::is_same_v) { + return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x)))); + } else if constexpr (std::is_same_v) { + return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x)))))); + } else if constexpr (std::is_same_v) { + return __frcp_rd(__fadd_rd(1, __expf(-x))); + } else { + return 1 / (1 + std::exp(-x)); + } + } + +public: + static constexpr size_t num_inputs = 2; + template + __device__ __forceinline__ T operator()(const T &up, const T &gate) const { + if constexpr (std::is_same_v) { + return __hmul2(__hmul2(gate, sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + return __hmul(__hmul(gate, sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up); + } else { + return gate * sigmoid(gate) * up; + } + } +} SwiGLUOp; +} // namespace op::swiglu::cuda + +#endif // __SWIGLU_CUDA_H__ diff --git a/src/infiniop/ops/swiglu/operator.cc b/src/infiniop/ops/swiglu/operator.cc index 80be80bfd..3f90882c1 100644 --- a/src/infiniop/ops/swiglu/operator.cc +++ b/src/infiniop/ops/swiglu/operator.cc @@ -5,6 +5,9 @@ #ifdef ENABLE_CPU_API #include "cpu/swiglu_cpu.h" #endif +#ifdef ENABLE_CUDA_API +#include "cuda/swiglu_cuda.cuh" +#endif __C infiniStatus_t infiniopCreateSwiGLUDescriptor( infiniopHandle_t handle, @@ -19,19 +22,16 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( handle, \ reinterpret_cast(desc_ptr), \ c_desc, \ - a_desc, \ - b_desc) + {a_desc, \ + b_desc}) switch (handle->device) { #ifdef ENABLE_CPU_API CREATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: - return cudaCreateSwiGLUDescriptor((CudaHandle_t)handle, - (SwiGLUCudaDescriptor_t *)desc_ptr, - c_desc, a_desc, b_desc); +#ifdef ENABLE_CUDA_API + CREATE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { @@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( #undef CREATE } +__C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_CUDA_API + GET(INFINI_DEVICE_NVIDIA, cuda) +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size); + } +#endif 
+#ifdef ENABLE_ASCEND_API + GET(INFINI_DEVICE_ASCEND, ascend) +#endif +#ifdef ENABLE_METAX_GPU + case DevMetaxGpu: { + return macaGetSwiGLUWorkspaceSize((SwiGLUMacaDescriptor_t)desc, size); + } +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return musaGetSwiGLUWorkspaceSize((SwiGLUMusaDescriptor_t)desc, size); + } +#endif + } + +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + __C infiniStatus_t infiniopSwiGLU( infiniopSwiGLUDescriptor_t desc, + void *workspace, + size_t workspace_size, void *c, const void *a, const void *b, @@ -76,16 +117,15 @@ __C infiniStatus_t infiniopSwiGLU( #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast(desc) \ - ->calculate(c, a, b, stream) + ->calculate(workspace, workspace_size, c, {a, b}, stream) switch (desc->device_type) { #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: - return cudaSwiGLU((SwiGLUCudaDescriptor_t)desc, c, a, b, stream); +#ifdef ENABLE_CUDA_API + CALCULATE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { @@ -125,9 +165,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { #ifdef ENABLE_CPU_API DELETE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: - return cudaDestroySwiGLUDescriptor((SwiGLUCudaDescriptor_t)desc); +#ifdef ENABLE_CUDA_API + DELETE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { diff --git a/src/utils.h b/src/utils.h index 13a8c78a1..25ba3745f 100644 --- a/src/utils.h +++ b/src/utils.h @@ -98,4 +98,6 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) { } } +#define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) + #endif diff --git a/test/infiniop/swiglu.py b/test/infiniop/swiglu.py index 1e145692a..01d6f9612 100644 --- a/test/infiniop/swiglu.py +++ b/test/infiniop/swiglu.py @@ -1,6 +1,6 @@ import torch import ctypes -from ctypes import POINTER, Structure, c_int32, c_void_p +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64 from libinfiniop import ( infiniopHandle_t, infiniopTensorDescriptor_t, @@ -14,6 +14,7 @@ debug, get_tolerance, profile_operation, + create_workspace ) from enum import Enum, auto @@ -25,8 +26,10 @@ # shape, a_stride, b_stride, c_stride ((13, 4), None, None, None), ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), ((13, 4, 4), None, None, None), ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), ((16, 5632), None, None, None), ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), ((4, 4, 5632), None, None, None), @@ -76,6 +79,38 @@ class SwiGLUDescriptor(Structure): def swiglu(a, b): return a * b / (1 + torch.exp(-b.float()).to(b.dtype)) + + + +def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace): + """ + rearrange the tensors if needed and apply the inplace config. 
+ if inplace is true and the output (i.e., c) is placed to the broadcasted input, + the inplace config is ignored and out-of-place is used + """ + original_c_strides = c_strides if c_strides else c.stride() + + def _rearrange(tensor, strides): + if strides and 0 in strides: + tensor.set_(tensor.untyped_storage(), 0, tensor.shape, strides) + return tensor + else: + return rearrange_if_needed(tensor, strides) + + a, b, c = [ + _rearrange(tensor, stride) + for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_strides]) + ] + c = ( + c + if inplace == Inplace.OUT_OF_PLACE + else (a if inplace == Inplace.INPLACE_A else b) + ) + # if inplace is true and c has broadcasted config, reset it to the original unbroadcasted strides + if 0 in c.stride(): + c.set_(c.untyped_storage(), 0, c.shape, original_c_strides) + + return a, b, c def test( @@ -98,18 +133,10 @@ def test( a = torch.rand(shape, dtype=dtype).to(torch_device) b = torch.rand(shape, dtype=dtype).to(torch_device) c = torch.rand(shape, dtype=dtype).to(torch_device) + a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace) ans = swiglu(a, b) - a, b, c = [ - rearrange_if_needed(tensor, stride) - for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride]) - ] - c = ( - c - if inplace == Inplace.OUT_OF_PLACE - else (a if inplace == Inplace.INPLACE_A else b) - ) a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]] c_tensor = ( to_tensor(c, lib) @@ -134,10 +161,19 @@ def test( for tensor in [a_tensor, b_tensor, c_tensor]: tensor.destroyDesc(lib) + workspace_size = c_uint64(0) + check_error( + lib.infiniopGetSwiGLUWorkspaceSize(descriptor, ctypes.byref(workspace_size)) + ) + workspace = create_workspace(workspace_size.value, c.device) + def lib_swiglu(): check_error( lib.infiniopSwiGLU( - descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + c_tensor.data, a_tensor.data, b_tensor.data, None ) ) @@ -170,10 +206,18 @@ def lib_swiglu(): infiniopTensorDescriptor_t, ] + lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32 + lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [ + infiniopSwiGLUDescriptor_t, + POINTER(c_uint64), + ] + lib.infiniopSwiGLU.restype = c_int32 lib.infiniopSwiGLU.argtypes = [ infiniopSwiGLUDescriptor_t, c_void_p, + c_uint64, + c_void_p, c_void_p, c_void_p, c_void_p, diff --git a/xmake/cuda.lua b/xmake/cuda.lua index 7c89c64e3..0d7ccfdae 100644 --- a/xmake/cuda.lua +++ b/xmake/cuda.lua @@ -28,6 +28,7 @@ target("infiniop-cuda") else add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror") add_cuflags("-Xcompiler=-fPIC") + add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") add_cxxflags("-fPIC") end
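The public SwiGLU entry points now follow the usual workspace pattern: create the descriptor, query its workspace size, then pass the workspace and its size to infiniopSwiGLU. A minimal host-side sketch of that sequence follows. It assumes the handle, the three tensor descriptors, and the device buffers are created elsewhere (their creation is not part of this patch), and it uses cudaMalloc/cudaFree purely for illustration on an NVIDIA device; on CPU the reported workspace size is 0 and a null workspace is fine.

#include <cstddef>
#include <cuda_runtime.h>
#include "infiniop/ops/swiglu.h"

// `handle`, the tensor descriptors, and the device buffers are assumed to be
// created elsewhere; only the calls touched by this patch are exercised here.
infiniStatus_t run_swiglu(infiniopHandle_t handle,
                          infiniopTensorDescriptor_t c_desc,
                          infiniopTensorDescriptor_t a_desc,
                          infiniopTensorDescriptor_t b_desc,
                          void *c, const void *a, const void *b,
                          void *stream) {
    infiniopSwiGLUDescriptor_t desc = nullptr;
    infiniStatus_t status = infiniopCreateSwiGLUDescriptor(handle, &desc, c_desc, a_desc, b_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    // New in this patch: the descriptor reports how much scratch memory the
    // elementwise backend needs (metadata plus input-pointer array on CUDA, 0 on CPU).
    size_t workspace_size = 0;
    status = infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size);
    if (status != INFINI_STATUS_SUCCESS) {
        infiniopDestroySwiGLUDescriptor(desc);
        return status;
    }

    void *workspace = nullptr;
    if (workspace_size > 0) {
        cudaMalloc(&workspace, workspace_size); // illustration only; any device allocation works
    }

    // The compute call now takes the workspace pair ahead of the tensors.
    status = infiniopSwiGLU(desc, workspace, workspace_size, c, a, b, stream);

    if (workspace) {
        cudaFree(workspace);
    }
    infiniopDestroySwiGLUDescriptor(desc);
    return status;
}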
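The elementwise framework asks an operator for only two things: a compile-time num_inputs and a call operator over the per-element input values, as SwiGLUOp shows. Below is a sketch of a hypothetical two-input multiply functor following the same contract; the op name and namespace are invented for illustration. Wiring it into a real operator would reuse ELEMENTWISE_DESCRIPTOR and CREATE_ELEMENTWISE_CPU_DESCRIPTOR and dispatch per dtype through DeviceImpl::calculate, the way swiglu_cpu.cc does.

#include <cstddef>

// Hypothetical element-wise multiply, mirroring the contract of SwiGLUOp:
// the dispatcher checks Op::num_inputs and invokes operator() once per
// output element with the gathered input values.
namespace op::mul::cpu {
struct MulOp {
    static constexpr size_t num_inputs = 2;

    template <typename T>
    T operator()(const T &a, const T &b) const {
        return a * b;
    }
};
} // namespace op::mul::cpu

int main() {
    // Trivial scalar sanity check of the functor contract.
    op::mul::cpu::MulOp mul;
    return mul(2, 3) == 6 ? 0 : 1;
}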
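CREATE_ELEMENTWISE_CUDA_DESCRIPTOR sizes the workspace as the packed metadata block built by ElementwiseInfo::create plus room for the device-side input-pointer array, and infoToDevice later carves the same buffer into those sections. The sketch below only reproduces that arithmetic for the two-input SwiGLU case; the printed byte counts assume an LP64 target (8-byte size_t and ptrdiff_t, 1-byte bool).

#include <cstddef>
#include <cstdio>

// Mirrors the size computation in ElementwiseInfo::create() and
// CREATE_ELEMENTWISE_CUDA_DESCRIPTOR; purely illustrative.
static size_t metaMemSize(size_t ndim, size_t n_inputs) {
    return ndim * (sizeof(size_t) + sizeof(ptrdiff_t)) // output shape + output strides
         + n_inputs * ndim * sizeof(size_t)            // input shapes, flattened (N * ndim)
         + n_inputs * ndim * sizeof(ptrdiff_t)         // input strides, flattened (N * ndim)
         + 2 * n_inputs * sizeof(bool);                // contiguous + broadcasted flags
}

int main() {
    const size_t ndim = 3, n_inputs = 2;               // e.g. SwiGLU on a (13, 4, 4) output
    const size_t meta = metaMemSize(ndim, n_inputs);
    const size_t cuda_workspace = meta + n_inputs * sizeof(void *); // plus the device input-pointer array
    // On LP64 this prints: meta = 148 bytes, CUDA workspace = 164 bytes.
    std::printf("meta = %zu bytes, CUDA workspace = %zu bytes\n", meta, cuda_workspace);
    return 0;
}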
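InputIndexer and the CPU get_input_idx lambda pick between three address modes per input: the flat index for contiguous inputs, indexToReducedOffset for broadcasted inputs, and indexToOffset for merely strided ones. The bodies of those two helpers are not part of this diff, so the host-side sketch below is only one plausible reading of them (flat index decomposed against the shape, or against the contiguous output strides, then re-applied to the input strides); the function names here are mine.

#include <cstddef>

// Flat index -> offset in a tensor with the given shape/strides
// (one plausible reading of indexToOffset).
static size_t indexToOffsetSketch(size_t flat, size_t ndim,
                                  const size_t *shape, const ptrdiff_t *strides) {
    size_t off = 0;
    for (size_t d = ndim; d-- > 0;) {
        off += (flat % shape[d]) * strides[d];
        flat /= shape[d];
    }
    return off;
}

// Flat index of the broadcasted output -> offset in an input whose strides may
// contain zeros; coordinates are recovered from the contiguous output strides
// (one plausible reading of indexToReducedOffset).
static size_t indexToReducedOffsetSketch(size_t flat, size_t ndim,
                                         const ptrdiff_t *output_strides,
                                         const ptrdiff_t *input_strides) {
    size_t off = 0;
    for (size_t d = 0; d < ndim; ++d) {
        off += flat / output_strides[d] * input_strides[d];
        flat %= output_strides[d];
    }
    return off;
}

int main() {
    const size_t shape[] = {13, 4};
    const ptrdiff_t out_strides[] = {4, 1}; // contiguous (13, 4) output
    const ptrdiff_t in_strides[] = {0, 1};  // row-broadcast input, as in the new ((13, 4), (0, 1), ...) test case
    // Contiguous: the offset is the flat index itself. Broadcast: flat index 7
    // -> coordinates (1, 3) -> input offset 3, because the zero stride drops the row.
    bool ok = indexToOffsetSketch(7, 2, shape, out_strides) == 7
           && indexToReducedOffsetSketch(7, 2, out_strides, in_strides) == 3;
    return ok ? 0 : 1;
}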
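For reference, the operation both functors implement is out = up * gate * sigmoid(gate). With the input order {a, b} used by operator.cc and the test reference swiglu(a, b) = a * b / (1 + exp(-b)), a plays the role of up and b is the gate. A scalar double sketch of that identity:

#include <cmath>

// Scalar reference for SwiGLU: out = up * gate * sigmoid(gate).
// With inputs ordered {a, b} as in operator.cc, a is "up" and b is the gate,
// matching the test reference swiglu(a, b) = a * b / (1 + exp(-b)).
static double swigluRef(double up, double gate) {
    return up * gate / (1.0 + std::exp(-gate));
}

int main() {
    // The gate multiplies the result directly, so gate == 0 zeroes the output.
    return swigluRef(2.0, 0.0) == 0.0 ? 0 : 1;
}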