From eff11f2c0eba5fb29d56428aa5c12d39a7179a34 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 20 May 2026 09:59:36 +0800 Subject: [PATCH] feat(nvidia): add ntops rms norm backend --- CMakeLists.txt | 16 ++ scripts/generate_ninetoothed_ops.py | 15 ++ src/CMakeLists.txt | 51 ++++++ .../cuda/nvidia/ops/rms_norm/ninetoothed.h | 79 +++++++++ src/native/ninetoothed/codegen.py | 165 ++++++++++++++++++ src/native/ninetoothed/tensor.h | 62 +++++++ tests/test_generate_ninetoothed_ops.py | 114 ++++++++++++ 7 files changed, 502 insertions(+) create mode 100644 scripts/generate_ninetoothed_ops.py create mode 100644 src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h create mode 100644 src/native/ninetoothed/codegen.py create mode 100644 src/native/ninetoothed/tensor.h create mode 100644 tests/test_generate_ninetoothed_ops.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ac4bd400..748499762 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF) option(WITH_TORCH "Enable PyTorch C++ backend" OFF) +option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF) + # Default OFF until CANN's `extract_host_stub.py` path handling is fixed for # `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed # object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the @@ -290,6 +292,20 @@ if(_gpu_backend_count GREATER 1) message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_HYGON`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.") endif() +if(WITH_NINETOOTHED AND NOT WITH_NVIDIA) + message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because ninetoothed AOT uses caller=`cuda`.") +endif() + +if(WITH_NINETOOTHED) + # NineToothed code generation configuration. + set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run ninetoothed code generation") + set(NINETOOTHED_SOURCE_DIR "" CACHE PATH "Optional local ninetoothed source checkout; installed package is used when empty") + set(INFINIOPS_NINETOOTHED_OPS "rms_norm" CACHE STRING "Semicolon- or comma-separated NineToothed ops to generate") + set(INFINIOPS_NINETOOTHED_DTYPES "float32;float16;bfloat16" CACHE STRING "Semicolon- or comma-separated NineToothed dtypes to generate") + set(INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS "2;3" CACHE STRING "Semicolon- or comma-separated RmsNorm input ranks to generate with NineToothed") + set(INFINIOPS_NINETOOTHED_BLOCK_SIZE "256" CACHE STRING "Block size baked into simple NineToothed elementwise kernels") +endif() + if(WITH_NVIDIA) add_compile_definitions(WITH_NVIDIA=1) enable_language(CUDA) diff --git a/scripts/generate_ninetoothed_ops.py b/scripts/generate_ninetoothed_ops.py new file mode 100644 index 000000000..612015ecd --- /dev/null +++ b/scripts/generate_ninetoothed_ops.py @@ -0,0 +1,15 @@ +import pathlib +import sys + + +def main(): + project_dir = pathlib.Path(__file__).resolve().parents[1] + sys.path.insert(0, str(project_dir / "src")) + + from native.ninetoothed.codegen import main as codegen_main + + codegen_main() + + +if __name__ == "__main__": + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4361ba38f..00894224b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -39,6 +39,57 @@ if(WITH_NVIDIA) target_compile_definitions(infiniops PUBLIC WITH_NVIDIA=1) target_sources(infiniops PRIVATE ${NVIDIA_SOURCES}) + if(WITH_NINETOOTHED) + find_package(Python COMPONENTS Interpreter REQUIRED) + + if(NINETOOTHED_PYTHON_EXECUTABLE) + set(_ninetoothed_python "${NINETOOTHED_PYTHON_EXECUTABLE}") + elseif(_TORCH_PYTHON) + set(_ninetoothed_python "${_TORCH_PYTHON}") + else() + set(_ninetoothed_python "${Python_EXECUTABLE}") + endif() + message(STATUS "NineToothed codegen Python: ${_ninetoothed_python}") + + string(REPLACE "," ";" _ninetoothed_ops "${INFINIOPS_NINETOOTHED_OPS}") + string(REPLACE "," ";" _ninetoothed_dtypes "${INFINIOPS_NINETOOTHED_DTYPES}") + string(REPLACE "," ";" _ninetoothed_rms_norm_ndims "${INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS}") + + set(_ninetoothed_output_dir "${CMAKE_CURRENT_BINARY_DIR}/ninetoothed") + set(_ninetoothed_generator_args + "${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py" + --output-dir "${_ninetoothed_output_dir}" + --ops ${_ninetoothed_ops} + --dtypes ${_ninetoothed_dtypes} + --rms-norm-ndims ${_ninetoothed_rms_norm_ndims} + --block-size "${INFINIOPS_NINETOOTHED_BLOCK_SIZE}") + + if(NINETOOTHED_SOURCE_DIR) + list(APPEND _ninetoothed_generator_args + --ninetoothed-source-dir "${NINETOOTHED_SOURCE_DIR}") + endif() + + execute_process( + COMMAND "${_ninetoothed_python}" ${_ninetoothed_generator_args} + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + RESULT_VARIABLE _ninetoothed_generation_result + ) + + if(NOT _ninetoothed_generation_result EQUAL 0) + message(FATAL_ERROR "Generating NineToothed operator sources failed with `${_ninetoothed_python}`. Set `NINETOOTHED_PYTHON_EXECUTABLE` to a Python with `ninetoothed`, `triton`, `sympy`, and CUDA dependencies installed.") + endif() + + include("${_ninetoothed_output_dir}/manifest.cmake") + set(_ninetoothed_compile_definitions + WITH_NINETOOTHED=1 + INFINIOPS_NINETOOTHED_BLOCK_SIZE=${INFINIOPS_NINETOOTHED_BLOCK_SIZE}) + target_compile_definitions(infiniops PUBLIC + ${_ninetoothed_compile_definitions}) + target_include_directories(infiniops PUBLIC + ${INFINIOPS_NINETOOTHED_INCLUDE_DIRS}) + target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES}) + endif() + find_package(CUDAToolkit REQUIRED) target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cuda_driver) diff --git a/src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h b/src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h new file mode 100644 index 000000000..b881dd1c4 --- /dev/null +++ b/src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h @@ -0,0 +1,79 @@ +#ifndef INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_ +#define INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_ + +#ifdef WITH_NINETOOTHED + +#include +#include +#include + +#include "base/rms_norm.h" +#include "data_type.h" +#include "native/ninetoothed/tensor.h" +#include "rms_norm/infiniops_ninetoothed_rms_norm.h" + +#ifndef INFINIOPS_NINETOOTHED_BLOCK_SIZE +#define INFINIOPS_NINETOOTHED_BLOCK_SIZE 256 +#endif + +namespace infini::ops { + +template <> +class Operator : public RmsNorm { + public: + using RmsNorm::RmsNorm; + using RmsNorm::operator(); + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() && + "operator `RmsNorm` requires all input and output tensors to have " + "the same dtype"); + assert(input.shape() == out.shape() && + "ninetoothed `RmsNorm` requires input and output tensors with the " + "same shape"); + assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) && + "ninetoothed `RmsNorm` requires a 1D weight matching the last " + "dimension"); + assert((out.ndim() == 2 || out.ndim() == 3) && + "ninetoothed `RmsNorm` currently supports rank-2 and rank-3 " + "tensors"); + + std::vector weight_sizes; + std::vector weight_strides; + double eps_value = static_cast(eps); + std::int64_t num_normalized_elements = + static_cast(out.size(-1)); + std::uint64_t empty_shape[1] = {}; + std::int64_t empty_strides[1] = {}; + + weight_sizes.assign(out.shape().begin(), out.shape().end()); + weight_strides.assign(out.ndim(), 0); + weight_strides.back() = + weight.strides().empty() ? 1 : weight.strides().back(); + + const int dtype_index = ninetoothed::DTypeIndex(out.dtype()); + assert( + dtype_index >= 0 && + "ninetoothed `RmsNorm` supports only float16, bfloat16, and float32"); + + auto result = launch_infiniops_ninetoothed_rms_norm( + static_cast(stream_), ninetoothed::Tensor(input), + ninetoothed::Tensor(const_cast(weight.data()), + weight_sizes.data(), weight_strides.data()), + ninetoothed::Tensor(eps_value, empty_shape, empty_strides), + ninetoothed::Tensor(out), + ninetoothed::Tensor(num_normalized_elements, empty_shape, + empty_strides), + static_cast(out.ndim()), 1, dtype_index, dtype_index, dtype_index, + INFINIOPS_NINETOOTHED_BLOCK_SIZE); + + assert(result == 0 && "ninetoothed `RmsNorm` launch failed"); + } +}; + +} // namespace infini::ops + +#endif // WITH_NINETOOTHED + +#endif diff --git a/src/native/ninetoothed/codegen.py b/src/native/ninetoothed/codegen.py new file mode 100644 index 000000000..0e1c1dad2 --- /dev/null +++ b/src/native/ninetoothed/codegen.py @@ -0,0 +1,165 @@ +import argparse +import pathlib +import shutil +import sys + + +_DEFAULT_DTYPES = ("float32", "float16", "bfloat16") +_DEFAULT_RMS_NORM_NDIMS = (2, 3) +_SUPPORTED_OPS = ("rms_norm",) + + +def _import_ninetoothed(source_dir): + if source_dir is not None: + sys.path.insert(0, str(pathlib.Path(source_dir) / "src")) + + import ninetoothed + + return ninetoothed + + +def _rms_norm_premake( + ndim, + num_normalized_dims, + input_dtype, + weight_dtype, + output_dtype, + block_size, +): + import ntops + + return ntops.kernels.rms_norm.premake( + ndim, + num_normalized_dims, + input_dtype=input_dtype, + weight_dtype=weight_dtype, + output_dtype=output_dtype, + block_size=block_size, + ) + + +def _normalize_ndims(values): + ndims = [] + + for value in values: + ndim = int(value) + + if ndim not in _DEFAULT_RMS_NORM_NDIMS: + raise ValueError(f"`RmsNorm` currently supports rank 2 and 3: {value!r}") + + if ndim not in ndims: + ndims.append(ndim) + + return tuple(ndims) + + +def _rms_norm_configs(ninetoothed, dtypes, ndims, block_size): + configs = [] + + for ndim in _normalize_ndims(ndims): + for dtype_name in dtypes: + dtype = getattr(ninetoothed, dtype_name) + configs.append( + ( + (), + { + "ndim": ndim, + "num_normalized_dims": 1, + "input_dtype": dtype, + "weight_dtype": dtype, + "output_dtype": dtype, + "block_size": block_size, + }, + {}, + ) + ) + + return tuple(configs) + + +def _generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size): + variant_dir = output_dir / "rms_norm" + variant_dir.mkdir(parents=True, exist_ok=True) + ninetoothed.build( + _rms_norm_premake, + _rms_norm_configs(ninetoothed, dtypes, rms_norm_ndims, block_size), + meta_parameters=None, + caller="cuda", + kernel_name="infiniops_ninetoothed_rms_norm", + output_dir=variant_dir, + lazy=False, + ) + + +def _build_manifest(output_dir): + return sorted( + str(path) + for path in pathlib.Path(output_dir).rglob("*.cpp") + if not path.name.endswith(".tmp.cpp") + ) + + +def _write_cmake_manifest(output_dir, sources): + manifest_path = pathlib.Path(output_dir) / "manifest.cmake" + lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"] + lines.extend(f' "{source}"' for source in sources) + lines.append(")") + lines.append("") + lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{output_dir}")') + lines.append("") + manifest_path.write_text("\n".join(lines) + "\n") + + +def generate( + ops, + *, + output_dir, + dtypes=_DEFAULT_DTYPES, + rms_norm_ndims=_DEFAULT_RMS_NORM_NDIMS, + block_size=256, + ninetoothed_source_dir=None, +): + unknown_ops = tuple(op for op in ops if op not in _SUPPORTED_OPS) + + if unknown_ops: + raise ValueError(f"unsupported ninetoothed ops: {', '.join(unknown_ops)}") + + output_dir = pathlib.Path(output_dir) + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.mkdir(parents=True, exist_ok=True) + + ninetoothed = _import_ninetoothed(ninetoothed_source_dir) + + if "rms_norm" in ops: + _generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size) + + sources = _build_manifest(output_dir) + _write_cmake_manifest(output_dir, sources) + + return sources + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate ninetoothed operator sources for InfiniOps." + ) + parser.add_argument("--output-dir", required=True) + parser.add_argument("--ops", nargs="+", default=_SUPPORTED_OPS) + parser.add_argument("--dtypes", nargs="+", default=_DEFAULT_DTYPES) + parser.add_argument("--rms-norm-ndims", nargs="+", default=_DEFAULT_RMS_NORM_NDIMS) + parser.add_argument("--block-size", type=int, default=256) + parser.add_argument("--ninetoothed-source-dir") + + return parser.parse_args() + + +def main(): + args = _parse_args() + generate( + args.ops, + output_dir=args.output_dir, + dtypes=tuple(args.dtypes), + rms_norm_ndims=tuple(args.rms_norm_ndims), + block_size=args.block_size, + ninetoothed_source_dir=args.ninetoothed_source_dir, + ) diff --git a/src/native/ninetoothed/tensor.h b/src/native/ninetoothed/tensor.h new file mode 100644 index 000000000..3139c0d0c --- /dev/null +++ b/src/native/ninetoothed/tensor.h @@ -0,0 +1,62 @@ +#ifndef INFINI_OPS_NATIVE_NINETOOTHED_TENSOR_H_ +#define INFINI_OPS_NATIVE_NINETOOTHED_TENSOR_H_ + +#include +#include + +#include "data_type.h" +#include "tensor.h" + +namespace infini::ops::ninetoothed { + +inline int DTypeIndex(DataType dtype) { + switch (dtype) { + case DataType::kFloat16: + return 8; + case DataType::kBFloat16: + return 9; + case DataType::kFloat32: + return 10; + default: + return -1; + } +} + +class Tensor { + public: + explicit Tensor(const ::infini::ops::Tensor& tensor) + : Tensor(const_cast(tensor.data()), + reinterpret_cast( + const_cast<::infini::ops::Tensor::Size*>( + tensor.shape().data())), + reinterpret_cast( + const_cast<::infini::ops::Tensor::Stride*>( + tensor.strides().data()))) { + static_assert(sizeof(::infini::ops::Tensor::Size) == sizeof(std::uint64_t)); + static_assert(sizeof(::infini::ops::Tensor::Stride) == + sizeof(std::int64_t)); + static_assert(std::is_unsigned_v<::infini::ops::Tensor::Size>); + static_assert(std::is_signed_v<::infini::ops::Tensor::Stride>); + } + + Tensor(void* data, std::uint64_t* shape, std::int64_t* strides) + : data_(data), shape_(shape), strides_(strides) {} + + template + Tensor(T& value, std::uint64_t* shape, std::int64_t* strides) + : Tensor(static_cast(&value), shape, strides) {} + + template + operator NineToothedTensor() const { + return NineToothedTensor{data_, shape_, strides_}; + } + + private: + void* data_; + std::uint64_t* shape_; + std::int64_t* strides_; +}; + +} // namespace infini::ops::ninetoothed + +#endif diff --git a/tests/test_generate_ninetoothed_ops.py b/tests/test_generate_ninetoothed_ops.py new file mode 100644 index 000000000..519e2ce02 --- /dev/null +++ b/tests/test_generate_ninetoothed_ops.py @@ -0,0 +1,114 @@ +import importlib.util +import pathlib +import sys +import tempfile +import types + + +def _load_generator_module(): + path = ( + pathlib.Path(__file__).resolve().parents[1] + / "src" + / "native" + / "ninetoothed" + / "codegen.py" + ) + spec = importlib.util.spec_from_file_location( + "ninetoothed_codegen_under_test", path + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + return module + + +def test_generate_rms_norm_uses_ntops_premake_with_rank_configs(monkeypatch): + module = _load_generator_module() + calls = [] + + fake_ninetoothed = types.SimpleNamespace( + float32="nt.float32", + ) + fake_ninetoothed.build = lambda *args, **kwargs: calls.append((args, kwargs)) + + fake_arrangement = object() + fake_application = object() + fake_tensors = object() + premake_calls = [] + + def fake_ntops_premake(*args, **kwargs): + premake_calls.append((args, kwargs)) + return fake_arrangement, fake_application, fake_tensors + + fake_ntops = types.SimpleNamespace( + kernels=types.SimpleNamespace( + rms_norm=types.SimpleNamespace(premake=fake_ntops_premake) + ) + ) + + monkeypatch.setattr( + module, "_import_ninetoothed", lambda source_dir: fake_ninetoothed + ) + monkeypatch.setattr(module, "_build_manifest", lambda output_dir: ["kernel.cpp"]) + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = pathlib.Path(tmpdir) + manifest = module.generate( + ["rms_norm"], + output_dir=tmp_path, + dtypes=("float32",), + rms_norm_ndims=(2,), + block_size=256, + ) + + assert manifest == ["kernel.cpp"] + assert len(calls) == 1 + + args, kwargs = calls[0] + premake, configs = args + assert configs == ( + ( + (), + { + "ndim": 2, + "num_normalized_dims": 1, + "input_dtype": "nt.float32", + "weight_dtype": "nt.float32", + "output_dtype": "nt.float32", + "block_size": 256, + }, + {}, + ), + ) + assert kwargs["caller"] == "cuda" + assert kwargs["kernel_name"] == "infiniops_ninetoothed_rms_norm" + assert kwargs["output_dir"] == tmp_path / "rms_norm" + assert kwargs["lazy"] is False + assert kwargs["meta_parameters"] is None + + monkeypatch.setitem(sys.modules, "ntops", fake_ntops) + arrangement, application, tensors = premake( + ndim=2, + num_normalized_dims=1, + input_dtype="nt.float32", + weight_dtype="nt.float32", + output_dtype="nt.float32", + block_size=256, + ) + + assert arrangement is fake_arrangement + assert application is fake_application + assert tensors is fake_tensors + assert premake_calls == [ + ( + (2, 1), + { + "input_dtype": "nt.float32", + "weight_dtype": "nt.float32", + "output_dtype": "nt.float32", + "block_size": 256, + }, + ) + ]