From e184c7e4f2198ed2136e68e7c174d00c76c417ac Mon Sep 17 00:00:00 2001 From: tianyuxbear Date: Fri, 25 Jul 2025 16:12:19 +0800 Subject: [PATCH] issue/456/feat: add silu operator --- include/infiniop.h | 1 + include/infiniop/ops/silu.h | 24 +++ src/infiniop-test/include/ops.hpp | 2 + src/infiniop-test/src/ops/silu.cpp | 101 +++++++++++ src/infiniop/ops/silu/cpu/silu_cpu.cc | 52 ++++++ src/infiniop/ops/silu/cpu/silu_cpu.h | 23 +++ src/infiniop/ops/silu/cuda/kernel.cuh | 37 ++++ src/infiniop/ops/silu/metax/silu_metax.h | 8 + src/infiniop/ops/silu/metax/silu_metax.maca | 60 +++++++ src/infiniop/ops/silu/nvidia/silu_nvidia.cu | 59 +++++++ src/infiniop/ops/silu/nvidia/silu_nvidia.cuh | 8 + src/infiniop/ops/silu/operator.cc | 142 +++++++++++++++ test/infiniop/libinfiniop/op_register.py | 32 ++++ test/infiniop/libinfiniop/utils.py | 5 +- test/infiniop/silu.py | 172 +++++++++++++++++++ 15 files changed, 725 insertions(+), 1 deletion(-) create mode 100644 include/infiniop/ops/silu.h create mode 100644 src/infiniop-test/src/ops/silu.cpp create mode 100644 src/infiniop/ops/silu/cpu/silu_cpu.cc create mode 100644 src/infiniop/ops/silu/cpu/silu_cpu.h create mode 100644 src/infiniop/ops/silu/cuda/kernel.cuh create mode 100644 src/infiniop/ops/silu/metax/silu_metax.h create mode 100644 src/infiniop/ops/silu/metax/silu_metax.maca create mode 100644 src/infiniop/ops/silu/nvidia/silu_nvidia.cu create mode 100644 src/infiniop/ops/silu/nvidia/silu_nvidia.cuh create mode 100644 src/infiniop/ops/silu/operator.cc create mode 100644 test/infiniop/silu.py diff --git a/include/infiniop.h b/include/infiniop.h index 476be24f5..7d3b3df18 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -16,6 +16,7 @@ #include "infiniop/ops/relu.h" #include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rope.h" +#include "infiniop/ops/silu.h" #include "infiniop/ops/softplus.h" #include "infiniop/ops/sub.h" #include "infiniop/ops/swiglu.h" diff --git a/include/infiniop/ops/silu.h 
b/include/infiniop/ops/silu.h new file mode 100644 index 000000000..037d6323f --- /dev/null +++ b/include/infiniop/ops/silu.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_SILU_API_H__ +#define __INFINIOP_SILU_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopSiluDescriptor_t; + +__C __export infiniStatus_t infiniopCreateSiluDescriptor(infiniopHandle_t handle, + infiniopSiluDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output, + infiniopTensorDescriptor_t input); + +__C __export infiniStatus_t infiniopGetSiluWorkspaceSize(infiniopSiluDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopSilu(infiniopSiluDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream); + +__C __export infiniStatus_t infiniopDestroySiluDescriptor(infiniopSiluDescriptor_t desc); + +#endif diff --git a/src/infiniop-test/include/ops.hpp b/src/infiniop-test/include/ops.hpp index c40f420ec..12469d780 100644 --- a/src/infiniop-test/include/ops.hpp +++ b/src/infiniop-test/include/ops.hpp @@ -15,6 +15,7 @@ DECLARE_INFINIOP_TEST(swiglu) DECLARE_INFINIOP_TEST(add) DECLARE_INFINIOP_TEST(causal_softmax) DECLARE_INFINIOP_TEST(rearrange) +DECLARE_INFINIOP_TEST(silu) DECLARE_INFINIOP_TEST(sub) DECLARE_INFINIOP_TEST(zeros) DECLARE_INFINIOP_TEST(ones) @@ -53,6 +54,7 @@ DECLARE_INFINIOP_TEST(topksoftmax) REGISTER_INFINIOP_TEST(sigmoid) \ REGISTER_INFINIOP_TEST(topkrouter) \ REGISTER_INFINIOP_TEST(topksoftmax) \ + REGISTER_INFINIOP_TEST(silu) \ } namespace infiniop_test { diff --git a/src/infiniop-test/src/ops/silu.cpp b/src/infiniop-test/src/ops/silu.cpp new file mode 100644 index 000000000..75684503c --- /dev/null +++ b/src/infiniop-test/src/ops/silu.cpp @@ -0,0 +1,101 @@ +#include "ops.hpp" +#include "utils.hpp" +#include +#include +#include + +namespace infiniop_test::silu { +struct Test::Attributes { + std::shared_ptr input; + std::shared_ptr output; + std::shared_ptr ans; +}; + 
+std::shared_ptr Test::build( + std::unordered_map> attributes, + std::unordered_map> tensors, + double rtol, double atol) { + auto test = std::shared_ptr(new Test(rtol, atol)); + test->_attributes = new Attributes(); + if (tensors.find("input") == tensors.end() + || tensors.find("output") == tensors.end() + || tensors.find("ans") == tensors.end()) { + throw std::runtime_error("Invalid Test"); + } + + test->_attributes->input = tensors["input"]; + test->_attributes->output = tensors["output"]; + test->_attributes->ans = tensors["ans"]; + + return test; +} + +std::shared_ptr Test::run( + infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) { + infiniopSiluDescriptor_t op_desc; + auto input = _attributes->input->to(device, device_id); + auto output = _attributes->output->to(device, device_id); + CHECK_OR(infiniopCreateSiluDescriptor(handle, &op_desc, + output->desc(), + input->desc()), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor.")); + size_t workspace_size; + CHECK_OR(infiniopGetSiluWorkspaceSize(op_desc, &workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size.")); + void *workspace; + CHECK_OR(infinirtMalloc(&workspace, workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace.")); + CHECK_OR(infiniopSilu(op_desc, workspace, workspace_size, + output->data(), + input->data(), + nullptr), + return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution.")); + + try { + allClose(output, _attributes->ans, _rtol, _atol); + } catch (const std::exception &e) { + return TEST_FAILED(RESULT_INCORRECT, e.what()); + } + + double elapsed_time = 0.; + + elapsed_time = benchmark( + [=]() { + infiniopSilu( + op_desc, workspace, workspace_size, + output->data(), + input->data(), + nullptr); + }, + warm_ups, iterations); + + return TEST_PASSED(elapsed_time); +} + +std::vector Test::attribute_names() { + return {}; +} + +std::vector 
Test::tensor_names() { + return {"input", "output", "ans"}; +} + +std::vector Test::output_names() { + return {"output"}; +} + +std::string Test::toString() const { + std::ostringstream oss; + oss << op_name() << std::endl; + oss << "- input: " << _attributes->input->info() << std::endl; + oss << "- output: " << _attributes->output->info() << std::endl; + oss << std::scientific << std::setprecision(2); + oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl; + return oss.str(); +} + +Test::~Test() { + delete _attributes; +} + +} // namespace infiniop_test::silu diff --git a/src/infiniop/ops/silu/cpu/silu_cpu.cc b/src/infiniop/ops/silu/cpu/silu_cpu.cc new file mode 100644 index 000000000..c8466d227 --- /dev/null +++ b/src/infiniop/ops/silu/cpu/silu_cpu.cc @@ -0,0 +1,52 @@ +#include "silu_cpu.h" + +namespace op::silu::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &input_desc = input_desc_vec.at(0); + const auto &output_shape = out_desc->shape(); + const auto &input_shape = input_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_SAME_SHAPE(output_shape, input_shape); + + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, 
inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::silu::cpu diff --git a/src/infiniop/ops/silu/cpu/silu_cpu.h b/src/infiniop/ops/silu/cpu/silu_cpu.h new file mode 100644 index 000000000..e1e9da4e3 --- /dev/null +++ b/src/infiniop/ops/silu/cpu/silu_cpu.h @@ -0,0 +1,23 @@ +#ifndef __SILU_CPU_H__ +#define __SILU_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(silu, cpu) + +#include + +namespace op::silu::cpu { +typedef struct SiluOp { +public: + static constexpr size_t num_inputs = 1; + + template + T operator()(const T &x) const { + return x / (static_cast(1) + std::exp(-x)); + } +} SiluOp; + +} // namespace op::silu::cpu + +#endif // __SILU_CPU_H__ diff --git a/src/infiniop/ops/silu/cuda/kernel.cuh b/src/infiniop/ops/silu/cuda/kernel.cuh new file mode 100644 index 000000000..5cb8616b0 --- /dev/null +++ b/src/infiniop/ops/silu/cuda/kernel.cuh @@ -0,0 +1,37 @@ +#ifndef __SILU_CUDA_H__ +#define __SILU_CUDA_H__ + +#include + +namespace op::silu::cuda { + +typedef struct SiluOp { +public: + static constexpr size_t num_inputs = 1; + template + __device__ __forceinline__ T operator()(const T &x) const { + if constexpr (std::is_same_v) { + // half2: vectorized path using the half2 intrinsics + return __hmul2(x, __h2div(__float2half2_rn(1.0f), + __hadd2(__float2half2_rn(1.0f), h2exp(__hneg2(x))))); + } else if constexpr (std::is_same_v) { + // BF16: evaluate in float, then convert back to bfloat16 + const float x_f = __bfloat162float(x); + return __float2bfloat16(x_f / (1.0f + __expf(-x_f))); + } else if constexpr (std::is_same_v) { + // FP16: evaluate in float, then convert back to half + const float x_f = __half2float(x); + return __float2half(x_f / (1.0f + __expf(-x_f))); + } else if constexpr (std::is_same_v) { + // FP32: evaluated directly with the __expf intrinsic + return x * (1.0f / (1.0f + __expf(-x))); + } else if constexpr (std::is_same_v) { + // FP64: evaluated directly with exp + return x / (1.0 + exp(-x)); + } + } +} SiluOp; + +} // namespace op::silu::cuda 
+ +#endif // __SILU_CUDA_H__ diff --git a/src/infiniop/ops/silu/metax/silu_metax.h b/src/infiniop/ops/silu/metax/silu_metax.h new file mode 100644 index 000000000..a9717ccd0 --- /dev/null +++ b/src/infiniop/ops/silu/metax/silu_metax.h @@ -0,0 +1,8 @@ +#ifndef __SILU_METAX_API_H__ +#define __SILU_METAX_API_H__ + +#include "../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(silu, metax) + +#endif // __SILU_METAX_API_H__ diff --git a/src/infiniop/ops/silu/metax/silu_metax.maca b/src/infiniop/ops/silu/metax/silu_metax.maca new file mode 100644 index 000000000..73408bfc6 --- /dev/null +++ b/src/infiniop/ops/silu/metax/silu_metax.maca @@ -0,0 +1,60 @@ +#include "silu_metax.h" + +#include "../../../elementwise/metax/elementwise_metax.h" + +#include "../cuda/kernel.cuh" + +namespace op::silu::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &input_desc = input_desc_vec.at(0); + const auto &output_shape = out_desc->shape(); + const auto &input_shape = input_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_SAME_SHAPE(output_shape, input_shape); + + // create METAX elementwise descriptor + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::SiluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F16: + return 
_device_info->calculate<256, cuda::SiluOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::SiluOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::SiluOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::silu::metax diff --git a/src/infiniop/ops/silu/nvidia/silu_nvidia.cu b/src/infiniop/ops/silu/nvidia/silu_nvidia.cu new file mode 100644 index 000000000..291b9835f --- /dev/null +++ b/src/infiniop/ops/silu/nvidia/silu_nvidia.cu @@ -0,0 +1,59 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" + +#include "../cuda/kernel.cuh" +#include "silu_nvidia.cuh" + +namespace op::silu::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &input_desc = input_desc_vec.at(0); + const auto &output_shape = out_desc->shape(); + const auto &input_shape = input_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_SAME_SHAPE(output_shape, input_shape); + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::SiluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case 
INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::SiluOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::SiluOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::SiluOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::silu::nvidia diff --git a/src/infiniop/ops/silu/nvidia/silu_nvidia.cuh b/src/infiniop/ops/silu/nvidia/silu_nvidia.cuh new file mode 100644 index 000000000..b13c7fd44 --- /dev/null +++ b/src/infiniop/ops/silu/nvidia/silu_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __SILU_CUDA_API_H__ +#define __SILU_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(silu, nvidia) + +#endif // __SILU_CUDA_API_H__ diff --git a/src/infiniop/ops/silu/operator.cc b/src/infiniop/ops/silu/operator.cc new file mode 100644 index 000000000..5ae6ea4ff --- /dev/null +++ b/src/infiniop/ops/silu/operator.cc @@ -0,0 +1,142 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/silu.h" + +#ifdef ENABLE_CPU_API +#include "cpu/silu_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/silu_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/silu_metax.h" +#endif + +__C infiniStatus_t infiniopCreateSiluDescriptor( + infiniopHandle_t handle, + infiniopSiluDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::silu::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + output_desc, \ + {input_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, 
nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetSiluWorkspaceSize(infiniopSiluDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopSilu( + infiniopSiluDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, output, {input}, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t +infiniopDestroySiluDescriptor(infiniopSiluDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef 
ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DELETE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 40d4155a4..41e2a75ca 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -736,3 +736,35 @@ def ones_(lib): lib.infiniopDestroyOnesDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def silu_(lib): + lib.infiniopCreateSiluDescriptor.restype = c_int32 + lib.infiniopCreateSiluDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetSiluWorkspaceSize.restype = c_int32 + lib.infiniopGetSiluWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopSilu.restype = c_int32 + lib.infiniopSilu.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroySiluDescriptor.restype = c_int32 + lib.infiniopDestroySiluDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py index 93bc7c2b9..d85a77ec8 100644 --- a/test/infiniop/libinfiniop/utils.py +++ b/test/infiniop/libinfiniop/utils.py @@ -143,6 +143,9 @@ def from_torch(torch_tensor, dt: InfiniDtype, device: InfiniDeviceEnum): shape_, strides_, dt, device, mode="manual", set_tensor=torch_tensor ) + def update_torch_tensor(self, new_tensor: torch.Tensor): + self._torch_tensor = new_tensor + def to_torch_dtype(dt: InfiniDtype, compatability_mode=False): if dt == InfiniDtype.BOOL: @@ -607,7 +610,7 @@ def profile_operation(desc, func, torch_device, 
NUM_PRERUN, NUM_ITERATIONS): # Timed execution elapsed = timed_op(lambda: func(), NUM_ITERATIONS, torch_device) - print(f" {desc} time: {elapsed * 1000 :6f} ms") + print(f" {desc} time: {elapsed * 1000:6f} ms") def test_operator(device, test_func, test_cases, tensor_dtypes): diff --git a/test/infiniop/silu.py b/test/infiniop/silu.py new file mode 100644 index 000000000..96eeb74f7 --- /dev/null +++ b/test/infiniop/silu.py @@ -0,0 +1,172 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +from enum import Enum, auto + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES_ = [ + # shape, input_stride, output_stride + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4), (0, 1), None), + ((13, 4, 4), None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), None), + ((16, 5632), None, None), + ((16, 5632), (13312, 1), (13312, 1)), + ((4, 4, 5632), None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1)), +] + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE = auto() + + +# Inplace options applied for each test case in _TEST_CASES_ +_INPLACE = [ + Inplace.OUT_OF_PLACE, + Inplace.INPLACE, +] + +# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_ +_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _TEST_CASES_ + for inplace_item in _INPLACE +] + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.F64] + +# Tolerance 
map for different data types +_TOLERANCE_MAP = { + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.F64: {"atol": 2.22e-15, "rtol": 2.22e-15}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def test( + handle, + device, + shape, + input_stride=None, + output_stride=None, + inplace=Inplace.OUT_OF_PLACE, + dtype=torch.float16, + sync=None, +): + input = TestTensor(shape, input_stride, dtype, device) + if inplace == Inplace.INPLACE: + if input_stride != output_stride: + return + output = input + else: + output = TestTensor(shape, output_stride, dtype, device, mode="ones") + + if output.is_broadcast(): + return + + print( + f"Testing Silu on {InfiniDeviceNames[device]} with shape:{shape} input_stride:{input_stride} output_stride:{output_stride} " + f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}" + ) + + new_output = torch.nn.functional.silu(input.torch_tensor()) + output.update_torch_tensor(new_output) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateSiluDescriptor( + handle, + ctypes.byref(descriptor), + output.descriptor, + input.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [input, output]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetSiluWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, output.device) + + def lib_silu(): + check_error( + LIBINFINIOP.infiniopSilu( + descriptor, + workspace.data(), + workspace.size(), + output.data(), + input.data(), + None, + ) + ) + + lib_silu() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(output.actual_tensor(), output.torch_tensor(), atol=atol, rtol=rtol) + assert 
torch.allclose( + output.actual_tensor(), output.torch_tensor(), atol=atol, rtol=rtol + ) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: torch.nn.functional.silu(input.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_silu(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(LIBINFINIOP.infiniopDestroySiluDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m")