diff --git a/CMakeLists.txt b/CMakeLists.txt index 053a37b..36aa295 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,18 +41,33 @@ endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) +# NVIDIA and Iluvatar are parallel backends; only one GPU backend at a time. +if(WITH_NVIDIA AND WITH_ILUVATAR) + message(FATAL_ERROR "`WITH_NVIDIA` and `WITH_ILUVATAR` cannot both be `ON`. Build one GPU backend at a time.") +endif() + if(WITH_NVIDIA) add_compile_definitions(WITH_NVIDIA=1) enable_language(CUDA) find_package(CUDAToolkit REQUIRED) endif() +# Iluvatar: CUDA-compatible device, uses `clang++` with `-x ivcore` (not `nvcc`). +# Reference: `InfiniCore` `xmake/iluvatar.lua`. if(WITH_ILUVATAR) add_compile_definitions(WITH_ILUVATAR=1) - if(NOT WITH_NVIDIA) - enable_language(CUDA) - find_package(CUDAToolkit REQUIRED) + set(ILUVATAR_ARCH "ivcore20" CACHE STRING "Iluvatar GPU architecture") + find_program(CLANGXX NAMES clang++) + if(CLANGXX) + set(CMAKE_CUDA_COMPILER "${CLANGXX}" CACHE STRING "Iluvatar CUDA compiler (clang++)") + else() + set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "Iluvatar CUDA compiler (clang++)") endif() + set(CMAKE_CUDA_FLAGS "-x ivcore -std=c++17 --cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags") + set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar") + message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}") + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) endif() if(WITH_METAX) diff --git a/pyproject.toml b/pyproject.toml index b5d2cdb..c27077c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,9 @@ version = "0.1.0" [project.optional-dependencies] dev = ["pytest", "pytest-cov", "pytest-xdist", "ruff", "torch"] +[tool.pytest.ini_options] +testpaths = ["tests"] + [tool.scikit-build.wheel] install-dir = "infini" diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 2a18752..dac6a29 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -73,26 +73,36 @@ class _OperatorExtractor: - def __call__(self, op_name): + def __call__(self, op_name, base_stem=None): def _get_system_include_flags(): + """Collect system include paths from g++ and clang++ so libclang can find STL (e.g. std::optional).""" + seen = set() system_include_flags = [] - for line in subprocess.getoutput( - "clang++ -E -x c++ -v /dev/null" - ).splitlines(): - if not line.startswith(" "): + for compiler in ("clang++", "g++"): + try: + for line in subprocess.getoutput( + f"{compiler} -E -x c++ -v /dev/null" + ).splitlines(): + if not line.startswith(" "): + continue + + path = line.strip() + if path and path not in seen: + seen.add(path) + system_include_flags.append("-isystem") + system_include_flags.append(path) + except Exception: continue - system_include_flags.append("-isystem") - system_include_flags.append(line.strip()) - return system_include_flags system_include_flags = _get_system_include_flags() index = clang.cindex.Index.create() args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) - translation_unit = index.parse(f"src/base/{op_name.lower()}.h", args=args) + header = f"src/base/{(base_stem or op_name.lower())}.h" + translation_unit = index.parse(header, args=args) nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) @@ -105,7 +115,8 @@ def _get_system_include_flags(): elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()": calls.append(node) - return _Operator(op_name, constructors, calls) + header_name = base_stem if base_stem is not None else op_name.lower() + return _Operator(op_name, constructors, calls, header_name=header_name) @staticmethod def _find(node, op_name): @@ -117,13 +128,15 @@ def _find(node, op_name): class _Operator: - def __init__(self, name, constructors, calls): + def __init__(self, name, constructors, calls, header_name=None): self.name = name self.constructors = constructors self.calls = calls + self.header_name = header_name if header_name is not None else name.lower() + def _generate_pybind11(operator): def _generate_params(node): @@ -173,7 +186,8 @@ def _generate_call(op_name, call, method=True): ) calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls) callers = "\n".join( - _generate_call(operator.name, call, method=False) for call in operator.calls + _generate_call(operator.header_name, call, method=False) + for call in operator.calls ) return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ @@ -182,7 +196,7 @@ def _generate_call(op_name, call, method=True): #include #include -#include "base/{op_name.lower()}.h" +#include "base/{operator.header_name}.h" #include "utils.h" namespace py = pybind11; @@ -213,7 +227,7 @@ def _generate_source(operator): return f"""#include "../../handle.h" #include "../../tensor.h" -#include "infiniop/ops/{operator.name.lower()}.h" +#include "infiniop/ops/{operator.header_name}.h" {impl_includes} static infini::ops::DataType DataTypeFromInfiniDType( @@ -270,7 +284,7 @@ def _generate_header(operator): return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__ #define __INFINIOP_{operator.name.upper()}_API_H__ -#include "base/{operator.name.lower()}.h" +#include "base/{operator.header_name}.h" typedef struct infini::ops::Operator *infiniop{operator.name}Descriptor_t; @@ -382,20 +396,21 @@ def _generate_tensor_caster(name, is_data=False): def _get_all_ops(devices): ops = {} - for file_path in _BASE_DIR.iterdir(): - if not file_path.is_file(): + for base_path in _BASE_DIR.iterdir(): + if not base_path.is_file(): continue - op_name = "".join(word.capitalize() for word in file_path.stem.split("_")) + op_name = "".join(word.capitalize() for word in base_path.stem.split("_")) + impl_paths = [] - ops[op_name] = [] - - for file_path in _SRC_DIR.rglob("*"): - if not file_path.is_file() or file_path.parent.parent.name not in devices: + for impl_path in _SRC_DIR.rglob("*"): + if not impl_path.is_file() or impl_path.parent.parent.name not in devices: continue - if f"class Operator<{op_name}" in file_path.read_text(): - ops[op_name].append(file_path) + if f"class Operator<{op_name}" in impl_path.read_text(): + impl_paths.append(impl_path) + + ops[op_name] = {"base_stem": base_path.stem, "impl_paths": impl_paths} return ops @@ -429,12 +444,27 @@ def _get_all_ops(devices): (_BINDINGS_DIR / "utils.h").write_text(_UTILS_H_CONTENT) - for op_name, impl_paths in ops.items(): + valid_ops = {} + for op_name, op_data in ops.items(): + base_stem = op_data.get("base_stem") if isinstance(op_data, dict) else None + impl_paths = ( + op_data.get("impl_paths", op_data) + if isinstance(op_data, dict) + else op_data + ) + extractor = _OperatorExtractor() - operator = extractor(op_name) + try: + operator = extractor(op_name, base_stem=base_stem) + except clang.cindex.TranslationUnitLoadError as e: + print( + f"Warning: Skipping {op_name} - failed to parse base header: {e}" + ) + continue + valid_ops[op_name] = impl_paths source_path = _GENERATED_SRC_DIR / op_name.lower() - header_name = f"{op_name.lower()}.h" + header_name = f"{operator.header_name}.h" bind_func_name = f"Bind{op_name}" (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) @@ -451,7 +481,7 @@ def _get_all_ops(devices): impl_includes = "\n".join( f'#include "{impl_path}"' - for impl_paths in ops.values() + for impl_paths in valid_ops.values() for impl_path in impl_paths ) op_includes = "\n".join(f'#include "{header_path}"' for header_path in header_paths) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2ae2b70..97cc0e3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -43,6 +43,10 @@ if(WITH_NVIDIA) target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) list(APPEND DEVICE_LIST "nvidia") + set_target_properties(infiniops PROPERTIES + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) endif() if(WITH_ILUVATAR) @@ -65,6 +69,11 @@ if(WITH_ILUVATAR) find_package(CUDAToolkit REQUIRED) target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) + set_target_properties(infiniops PROPERTIES + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) + list(APPEND DEVICE_LIST "iluvatar") endif() @@ -112,7 +121,7 @@ if(GENERATE_PYTHON_BINDINGS) set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") # TODO: There might be a better solution. - if(WITH_NVIDIA) + if(WITH_NVIDIA OR WITH_ILUVATAR) set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA) endif() diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h new file mode 100644 index 0000000..eaffbc8 --- /dev/null +++ b/src/base/rms_norm.h @@ -0,0 +1,59 @@ +#ifndef INFINI_OPS_BASE_RMS_NORM_H_ +#define INFINI_OPS_BASE_RMS_NORM_H_ + +#include +#include + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class RmsNorm : public Operator { + public: + // Parameter order and naming follow PyTorch: input, weight, eps, out. + RmsNorm(const Tensor out, const Tensor input, const Tensor weight, float eps) + : eps_{eps}, + out_shape_{out.shape()}, + input_shape_{input.shape()}, + out_strides_{out.strides()}, + input_strides_{input.strides()}, + dim_{out.size(-1)}, + ndim_{out.ndim()}, + batch_size_{ndim_ == 2 ? out.size(-2) : out.size(-3)}, + nhead_{ndim_ == 2 ? 1 : out.size(-2)} {} + + RmsNorm(const Tensor out, const Tensor input, const Tensor weight) + : RmsNorm{out, input, weight, 1e-6f} {} + + virtual void operator()(void* stream, Tensor out, const Tensor input, + const Tensor weight, float eps) const = 0; + + virtual void operator()(void* stream, Tensor out, const Tensor input, + const Tensor weight) const { + return operator()(stream, out, input, weight, eps_); + } + + protected: + float eps_{1e-6f}; + + Tensor::Shape out_shape_; + + Tensor::Shape input_shape_; + + Tensor::Strides out_strides_; + + Tensor::Strides input_strides_; + + Tensor::Size dim_{0}; + + Tensor::Size ndim_{0}; + + Tensor::Size batch_size_{0}; + + Tensor::Size nhead_{1}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/common/constexpr_map.h b/src/common/constexpr_map.h index 3db1275..0d142eb 100644 --- a/src/common/constexpr_map.h +++ b/src/common/constexpr_map.h @@ -8,9 +8,9 @@ namespace infini::ops { -template +template struct ConstexprMap { - constexpr ConstexprMap(std::array, N> data) + constexpr ConstexprMap(std::array, size> data) : data_(data) {} constexpr Value at(Key key) const { @@ -24,7 +24,7 @@ struct ConstexprMap { } private: - std::array, N> data_; + std::array, size> data_; }; } // namespace infini::ops diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h index 4d92e00..98b9f48 100644 --- a/src/common/cuda/kernel_commons.h +++ b/src/common/cuda/kernel_commons.h @@ -3,7 +3,9 @@ #ifdef WITH_NVIDIA #include -#elif WITH_METAX +#elif defined(WITH_ILUVATAR) +#include +#elif WITH_METAX // TODO: Use `defined`. #include #endif diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h index 6c82f49..36df934 100644 --- a/src/common/generic_utils.h +++ b/src/common/generic_utils.h @@ -16,8 +16,8 @@ std::size_t indexToOffset(std::size_t flat_index, std::size_t ndim, return res; } -template -constexpr auto CeilDiv(const Tx& x, const Ty& y) { +template +constexpr auto CeilDiv(const X& x, const Y& y) { return (x + y - 1) / y; } diff --git a/src/common/traits.h b/src/common/traits.h index 6f75e9f..10642ef 100644 --- a/src/common/traits.h +++ b/src/common/traits.h @@ -14,17 +14,17 @@ struct List {}; // List Queries // ----------------------------------------------------------------------------- -// Check at compile-time if a Value exists within a construct (e.g., List<>). +// Check at compile-time if a value exists within a construct (e.g., List<>). // Example: static_assert(ContainsValue); -template +template struct Contains; -template -struct Contains, Value> - : std::disjunction...> {}; +template +struct Contains, value> + : std::disjunction...> {}; -template -inline constexpr bool ContainsValue = Contains::value; +template +inline constexpr bool ContainsValue = Contains::value; // Check at compile-time if a type T is present in a variadic list of types Ts. // Example: static_assert(IsTypeInList); @@ -37,35 +37,35 @@ inline constexpr bool IsTypeInList = (std::is_same_v || ...); // Concatenates two List types into a single List. // Example: ConcatType, List<3, 4>> is List<1, 2, 3, 4>. -template +template struct Concat; -template -struct Concat, List> { - using type = List; +template +struct Concat, List> { + using type = List; }; -template -using ConcatType = typename Concat::type; +template +using ConcatType = typename Concat::type; // ----------------------------------------------------------------------------- // Invocability Detection (SFINAE) // ----------------------------------------------------------------------------- -// Checks if a Functor's template operator() can be called with Args. -template +// Checks if a Functor's template operator() can be called with Args. +template struct IsInvocable : std::false_type {}; -template +template struct IsInvocable< - Functor, Value, - std::void_t().template operator()( + Functor, value, + std::void_t().template operator()( std::declval()...))>, Args...> : std::true_type {}; -template +template inline constexpr bool IsInvocableValue = - IsInvocable::value; + IsInvocable::value; // ----------------------------------------------------------------------------- // Filtering Logic @@ -73,34 +73,34 @@ inline constexpr bool IsInvocableValue = // Recursive template to filter values based on Functor support at compile-time. template + auto... remaining> struct Filter; // Base case: All values processed. -template -struct Filter, List> { - using type = List; +template +struct Filter, List> { + using type = List; }; -// Recursive step: Test the 'Head' value and accumulate if supported. -template -struct Filter, List, Head, Tail...> { +// Recursive step: Test the head value and accumulate if supported. +template +struct Filter, List, head, tail...> { using type = typename std::conditional_t< - IsInvocableValue && - !ContainsValue, Head>, - Filter, List, Tail...>, - Filter, List, Tail...>>::type; + IsInvocableValue && + !ContainsValue, head>, + Filter, List, tail...>, + Filter, List, tail...>>::type; }; // Interface to filter a List type directly. template struct FilterList; -template -struct FilterList, List> { +template +struct FilterList, List> { using type = - typename Filter, List<>, Items...>::type; + typename Filter, List<>, items...>::type; }; } // namespace infini::ops diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h new file mode 100644 index 0000000..3ed53c9 --- /dev/null +++ b/src/cpu/rms_norm/rms_norm.h @@ -0,0 +1,66 @@ +#ifndef INFINI_OPS_CPU_RMS_NORM_H_ +#define INFINI_OPS_CPU_RMS_NORM_H_ + +#include + +#include "base/rms_norm.h" +#include "data_type.h" +#include "tensor.h" + +namespace infini::ops { + +template <> +class Operator : public RmsNorm { + public: + Operator(const Tensor out, const Tensor input, const Tensor weight, + float eps) + : RmsNorm{out, input, weight, eps} {} + + Operator(const Tensor out, const Tensor input, const Tensor weight) + : Operator{out, input, weight, 1e-6f} {} + + void operator()(void* stream, Tensor out, const Tensor input, + const Tensor weight, float /*eps*/ = 0) const override { + // CPU backend supports fp32 only; fp16/bf16 use GPU backends. + if (out.dtype() != DataType::kFloat32 || input.dtype() != DataType::kFloat32 || + weight.dtype() != DataType::kFloat32) { + abort(); + } + + auto* out_ptr = static_cast(out.data()); + const auto* input_ptr = static_cast(input.data()); + const auto* weight_ptr = static_cast(weight.data()); + + auto stride_input_batch = + input_strides_.size() > 1 ? input_strides_[0] : 0; + auto stride_input_nhead = + input_strides_.size() > 1 ? input_strides_[1] : input_strides_[0]; + auto stride_out_batch = out_strides_.size() > 1 ? out_strides_[0] : 0; + auto stride_out_nhead = + out_strides_.size() > 1 ? out_strides_[1] : out_strides_[0]; + + for (Tensor::Size bi = 0; bi < batch_size_; ++bi) { + for (Tensor::Size hi = 0; hi < nhead_; ++hi) { + const float* input_row = + input_ptr + bi * stride_input_batch + hi * stride_input_nhead; + float* out_row = + out_ptr + bi * stride_out_batch + hi * stride_out_nhead; + + float ss = 0; + for (Tensor::Size k = 0; k < dim_; ++k) { + float v = input_row[k]; + ss += v * v; + } + float rms = 1.f / std::sqrt(ss / static_cast(dim_) + eps_); + + for (Tensor::Size k = 0; k < dim_; ++k) { + out_row[k] = input_row[k] * weight_ptr[k] * rms; + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh new file mode 100644 index 0000000..1629516 --- /dev/null +++ b/src/cuda/rms_norm/kernel.cuh @@ -0,0 +1,59 @@ +#ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_ +#define INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_ + +#include +#include + +#include +#include +#include + +namespace infini::ops { + +namespace { + +template +__device__ __forceinline__ Compute sumSquared(const Data* data_ptr, + size_t count) { + Compute ss = 0; + for (size_t i = threadIdx.x; i < count; i += block_size) { + ss += Compute(data_ptr[i]) * Compute(data_ptr[i]); + } + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + return BlockReduce(temp_storage).Sum(ss); +} + +} // namespace + +template +__global__ void rmsnormKernel(Data* __restrict__ y, int64_t stride_y_batch, + int64_t stride_y_nhead, + const Data* __restrict__ x, + int64_t stride_x_batch, int64_t stride_x_nhead, + const Weight* __restrict__ w, size_t nhead, + size_t dim, float epsilon) { + size_t batch_idx = blockIdx.x / nhead; + size_t head_idx = blockIdx.x % nhead; + + auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead; + auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead; + auto w_ptr = w; + + Compute ss = sumSquared(x_ptr, dim); + + __shared__ Compute rms; + if (threadIdx.x == 0) { + rms = Compute(rsqrtf(ss / Compute(dim) + epsilon)); + } + __syncthreads(); + + for (size_t i = threadIdx.x; i < dim; i += block_size) { + y_ptr[i] = Data(Compute(x_ptr[i]) * Compute(w_ptr[i]) * rms); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h new file mode 100644 index 0000000..622c312 --- /dev/null +++ b/src/cuda/rms_norm/kernel.h @@ -0,0 +1,70 @@ +#ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_CUDA_RMS_NORM_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "base/rms_norm.h" +#include "cuda/rms_norm/kernel.cuh" +#include "data_type.h" +#include "dispatcher.h" + +namespace infini::ops { + +namespace { + +constexpr unsigned int kBlockSize = 256; + +} // namespace + +template +class CudaRmsNorm : public RmsNorm { + public: + CudaRmsNorm(const Tensor out, const Tensor input, const Tensor weight, + float eps) + : RmsNorm{out, input, weight, eps} {} + + CudaRmsNorm(const Tensor out, const Tensor input, const Tensor weight) + : CudaRmsNorm{out, input, weight, 1e-6f} {} + + void operator()(void* stream, Tensor out, const Tensor input, + const Tensor weight, float /*eps*/) const override { + auto cuda_stream = + static_cast(stream ? stream : 0); + + auto stride_input_batch = + input_strides_.size() > 1 ? input_strides_[0] : 0; + auto stride_input_nhead = + input_strides_.size() > 1 ? input_strides_[1] : input_strides_[0]; + auto stride_out_batch = out_strides_.size() > 1 ? out_strides_[0] : 0; + auto stride_out_nhead = + out_strides_.size() > 1 ? out_strides_[1] : out_strides_[0]; + + uint32_t num_blocks = static_cast(batch_size_ * nhead_); + + if (out.dtype() != input.dtype() || out.dtype() != weight.dtype()) { + std::abort(); + } + + DispatchFunc( + out.dtype(), + [&]() { + rmsnormKernel + <<>>( + reinterpret_cast(out.data()), stride_out_batch, + stride_out_nhead, reinterpret_cast(input.data()), + stride_input_batch, stride_input_nhead, + reinterpret_cast(weight.data()), nhead_, dim_, + eps_); + }, + "CudaRmsNorm::operator()"); + + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/data_type.h b/src/data_type.h index 567a076..8de8518 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -7,7 +7,10 @@ #ifdef WITH_NVIDIA #include #include -#elif WITH_METAX +#elif defined(WITH_ILUVATAR) +#include +#include +#elif WITH_METAX //TODO: Use `defined`. #include #include #endif @@ -111,10 +114,10 @@ DEFINE_DATA_TYPE_MAPPING(kInt64, int64_t) DEFINE_DATA_TYPE_MAPPING(kFloat32, float) DEFINE_DATA_TYPE_MAPPING(kFloat64, double) -#ifdef WITH_NVIDIA +#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) DEFINE_DATA_TYPE_MAPPING(kFloat16, half) DEFINE_DATA_TYPE_MAPPING(kBFloat16, __nv_bfloat16) -#elif WITH_METAX +#elif WITH_METAX //TODO: Use `defined`. DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) #else diff --git a/src/dispatcher.h b/src/dispatcher.h index 6b70da5..31d2209 100644 --- a/src/dispatcher.h +++ b/src/dispatcher.h @@ -17,30 +17,30 @@ namespace infini::ops { // ----------------------------------------------------------------------------- // (Single Dispatch) Dispatches a runtime value to a compile-time functor. -template auto DispatchFunc(ValueType value, Functor&& func, std::string_view context_str = "", Args&&... args) { using FilteredPack = - typename Filter, List<>, AllValues...>::type; + typename Filter, List<>, all_values...>::type; - return [&](List) { + return [&](List) { using ReturnType = decltype(std::forward(func) - .template operator()(Head)>( + .template operator()(head)>( std::forward(args)...)); // Path for Void Functions if constexpr (std::is_void_v) { bool handled = - ((value == static_cast(Tail) - ? (std::forward(func).template operator()( + ((value == static_cast(tail) + ? (std::forward(func).template operator()( std::forward(args)...), true) : false) || ... || - (value == static_cast(Head) - ? (std::forward(func).template operator()( + (value == static_cast(head) + ? (std::forward(func).template operator()( std::forward(args)...), true) : false)); @@ -55,16 +55,16 @@ auto DispatchFunc(ValueType value, Functor&& func, else { std::optional result; bool handled = - ((value == static_cast(Tail) + ((value == static_cast(tail) ? (result.emplace( - std::forward(func).template operator()( + std::forward(func).template operator()( std::forward(args)...)), true) : false) || ... || - (value == static_cast(Head) + (value == static_cast(head) ? (result.emplace( - std::forward(func).template operator()( + std::forward(func).template operator()( std::forward(args)...)), true) : false)); @@ -85,31 +85,31 @@ auto DispatchFunc(ValueType value, Functor&& func, // (Multi-Dispatch) Dispatches a vector of runtime values to a compile-time // functor. // Base Case: All dimensions resolved. -template +template auto DispatchFunc(const std::vector& values, size_t index, - Functor&& func, std::string_view context_str, List, + Functor&& func, std::string_view context_str, List, Args&&... args) { - return std::forward(func).template operator()( + return std::forward(func).template operator()( std::forward(args)...); } // (Multi-Dispatch) Recursive Case template + typename... Args, auto... is> auto DispatchFunc(const std::vector& values, size_t index, - Functor&& func, std::string_view context_str, List, + Functor&& func, std::string_view context_str, List, Args&&... args) { - return [&](List) { - static_assert(sizeof...(Allowed) > 0, + return [&](List) { + static_assert(sizeof...(allowed) > 0, "`DispatchFunc` dimension list is empty"); - using EnumType = std::common_type_t; + using EnumType = std::common_type_t; - return DispatchFunc( + return DispatchFunc( static_cast(values.at(index)), - [&](Args&&... inner_args) { + [&](Args&&... inner_args) { return DispatchFunc( values, index + 1, std::forward(func), context_str, - List{}, std::forward(inner_args)...); + List{}, std::forward(inner_args)...); }, context_str, std::forward(args)...); }(FirstList{}); @@ -126,8 +126,8 @@ auto DispatchFunc(DataType dtype, Functor&& func, std::string_view context_str = "", Args&&... args) { return DispatchFunc( dtype, - [&](Args&&... inner_args) { - using T = TypeMapType
; + [&](Args&&... inner_args) { + using T = TypeMapType
; return std::forward(func).template operator()( std::forward(inner_args)...); }, @@ -143,21 +143,21 @@ auto DispatchFunc(std::initializer_list dtypes, Functor&& func, return DispatchFunc( v, 0, - [&func](Args&&... inner_args) { + [&func](Args&&... inner_args) { return std::forward(func).template - operator()...>(std::forward(inner_args)...); + operator()...>(std::forward(inner_args)...); }, context_str, List<>{}, std::forward(args)...); } // Device Dispatch -template +template auto DispatchFunc(Device::Type device, Functor&& func, std::string_view context_str = "", Args&&... args) { - return DispatchFunc( + return DispatchFunc( device, - [&](Args&&... inner_args) { - return std::forward(func).template operator()( + [&](Args&&... inner_args) { + return std::forward(func).template operator()( std::forward(inner_args)...); }, context_str, std::forward(args)...); @@ -172,8 +172,8 @@ auto DispatchFunc(std::initializer_list devices, Functor&& func, return DispatchFunc( v, 0, - [&func](Args&&... inner_args) { - return std::forward(func).template operator()( + [&func](Args&&... inner_args) { + return std::forward(func).template operator()( std::forward(inner_args)...); }, context_str, List<>{}, std::forward(args)...); @@ -184,8 +184,8 @@ template auto DispatchFunc(ValueType value, Functor&& func, std::string_view context_str = "", Args&&... args) { - return [&](List) { - return DispatchFunc>(Is)...>( + return [&](List) { + return DispatchFunc>(is)...>( value, std::forward(func), context_str, std::forward(args)...); }(ListType{}); diff --git a/src/iluvatar/rms_norm/kernel.h b/src/iluvatar/rms_norm/kernel.h new file mode 100644 index 0000000..c1509ca --- /dev/null +++ b/src/iluvatar/rms_norm/kernel.h @@ -0,0 +1,32 @@ +#ifndef INFINI_OPS_ILUVATAR_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ILUVATAR_RMS_NORM_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/rms_norm/kernel.h" + +namespace infini::ops { + +namespace rms_norm { + +struct IluvatarBackend { + using stream_t = cudaStream_t; + +}; + +} // namespace rms_norm + +template <> +class Operator + : public CudaRmsNorm { + public: + using CudaRmsNorm::CudaRmsNorm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/rms_norm/kernel.h b/src/nvidia/rms_norm/kernel.h new file mode 100644 index 0000000..68e94d5 --- /dev/null +++ b/src/nvidia/rms_norm/kernel.h @@ -0,0 +1,32 @@ +#ifndef INFINI_OPS_NVIDIA_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_NVIDIA_RMS_NORM_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/rms_norm/kernel.h" + +namespace infini::ops { + +namespace rms_norm { + +struct NvidiaBackend { + using stream_t = cudaStream_t; + +}; + +} // namespace rms_norm + +template <> +class Operator + : public CudaRmsNorm { + public: + using CudaRmsNorm::CudaRmsNorm; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py new file mode 100644 index 0000000..592c525 --- /dev/null +++ b/tests/test_rms_norm.py @@ -0,0 +1,76 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, randn_strided + + +def _rms_norm(input_, weight, out, *, eps=1e-6): + infini.ops.rms_norm(out, input_, weight, eps) + + return out + + +def _torch_rms_norm(input_, weight, out, *, eps=1e-6): + rms = torch.sqrt(torch.mean(input_**2, dim=-1, keepdim=True) + eps) + result = input_ * weight / rms + out.copy_(result) + + return out + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "input_shape, weight_shape, input_strides, weight_strides, out_strides", + ( + ((1, 64), (64,), None, None, None), + ((2, 128), (128,), None, None, None), + ((4, 48, 64), (64,), None, None, None), + ((2, 4, 2048), (2048,), None, None, None), + ((1, 64), (64,), (64, 1), (1,), (64, 1)), + ((4, 48, 64), (64,), (3072, 64, 1), (1,), (3072, 64, 1)), + ), +) +@pytest.mark.parametrize("eps", (1e-6, 1e-5)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-4, 1e-4), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_rms_norm( + input_shape, + weight_shape, + input_strides, + weight_strides, + out_strides, + eps, + dtype, + device, + rtol, + atol, +): + if getattr(infini.ops, "rms_norm", None) is None: + pytest.skip("rms_norm not available (wrapper generation skipped)") + + if device == "cpu" and dtype in (torch.float16, torch.bfloat16): + pytest.skip("CPU backend does not support fp16/bf16") + + input_ = randn_strided( + input_shape, input_strides, dtype=dtype, device=device + ) + weight = randn_strided( + weight_shape, weight_strides, dtype=dtype, device=device + ) + out = empty_strided(input_shape, out_strides, dtype=dtype, device=device) + + return Payload( + _rms_norm, + _torch_rms_norm, + (input_, weight, out), + {"eps": eps}, + rtol=rtol, + atol=atol, + )