-
Notifications
You must be signed in to change notification settings - Fork 6
feat(nvidia): add ntops rms norm backend #616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF) | |
|
|
||
| option(WITH_TORCH "Enable PyTorch C++ backend" OFF) | ||
|
|
||
| option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF) | ||
|
|
||
| # Default OFF until CANN's `extract_host_stub.py` path handling is fixed for | ||
| # `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed | ||
| # object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the | ||
|
|
@@ -29,6 +31,14 @@ option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) | |
| option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF) | ||
| option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) | ||
|
|
||
| # NineToothed code generation configuration. | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 请在后面创建一个关于 |
||
| set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run ninetoothed code generation") | ||
| set(NINETOOTHED_SOURCE_DIR "" CACHE PATH "Optional local ninetoothed source checkout; installed package is used when empty") | ||
| set(INFINIOPS_NINETOOTHED_OPS "rms_norm" CACHE STRING "Semicolon- or comma-separated NineToothed ops to generate") | ||
| set(INFINIOPS_NINETOOTHED_DTYPES "float32;float16;bfloat16" CACHE STRING "Semicolon- or comma-separated NineToothed dtypes to generate") | ||
| set(INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS "2;3" CACHE STRING "Semicolon- or comma-separated RmsNorm input ranks to generate with NineToothed") | ||
| set(INFINIOPS_NINETOOTHED_BLOCK_SIZE "256" CACHE STRING "Block size baked into simple NineToothed elementwise kernels") | ||
|
|
||
| if(AUTO_DETECT_DEVICES) | ||
| message(STATUS "Auto-detecting available devices...") | ||
|
|
||
|
|
@@ -231,6 +241,10 @@ if(_gpu_backend_count GREATER 1) | |
| message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.") | ||
| endif() | ||
|
|
||
| if(WITH_NINETOOTHED AND NOT WITH_NVIDIA) | ||
| message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because ninetoothed AOT uses caller=`cuda`.") | ||
| endif() | ||
|
|
||
| if(WITH_NVIDIA) | ||
| add_compile_definitions(WITH_NVIDIA=1) | ||
| enable_language(CUDA) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| import pathlib | ||
| import sys | ||
|
|
||
|
|
||
def main():
    """Run the NineToothed code generator from a source checkout.

    The ``src`` directory that sits next to this script's parent directory is
    put at the front of ``sys.path`` so that ``native.ninetoothed.codegen`` can
    be imported, and its ``main`` is then invoked.
    """
    src_dir = pathlib.Path(__file__).resolve().parents[1] / "src"
    sys.path.insert(0, str(src_dir))

    # Imported here rather than at module level because the import relies on
    # the `sys.path` entry added just above.
    from native.ninetoothed.codegen import main as run_codegen

    run_codegen()
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 九齿的定位在算子库中应该跟 PyTorch 差不多,都可以接入到后端里,所以在文件结构上应该跟 PyTorch 平行,而不是放在 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| #ifndef INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_ | ||
| #define INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_ | ||
|
|
||
| #ifdef WITH_NINETOOTHED | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 就像评论 https://github.com/InfiniTensor/InfiniOps/pull/616/changes#r3285560677 所说,九齿应该是与 PyTorch 等对应的后端,所以是跟 |
||
|
|
||
| #include <cassert> | ||
| #include <cstdint> | ||
| #include <vector> | ||
|
|
||
| #include "base/rms_norm.h" | ||
| #include "data_type.h" | ||
| #include "native/ninetoothed/tensor.h" | ||
| #include "rms_norm/infiniops_ninetoothed_rms_norm.h" | ||
|
|
||
| #ifndef INFINIOPS_NINETOOTHED_BLOCK_SIZE | ||
| #define INFINIOPS_NINETOOTHED_BLOCK_SIZE 256 | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. C++ 中尽量不使用宏,尤其是这种可以被 |
||
| #endif | ||
|
|
||
| namespace infini::ops { | ||
|
|
||
| namespace detail { | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我看此处 |
||
|
|
||
| inline NineToothedTensor ExpandedRmsNormWeight(const Tensor& weight, | ||
| const Tensor::Shape& shape, | ||
| std::vector<std::uint64_t>& sizes, | ||
| std::vector<std::int64_t>& strides) { | ||
| sizes.assign(shape.begin(), shape.end()); | ||
| strides.assign(shape.size(), 0); | ||
| strides.back() = weight.strides().empty() ? 1 : weight.strides().back(); | ||
|
|
||
| return NineToothedTensor{const_cast<void*>(weight.data()), sizes.data(), | ||
| strides.data()}; | ||
| } | ||
|
|
||
| inline NineToothedResult LaunchNineToothedRmsNorm( | ||
| const Tensor::Shape& shape, DataType dtype, NineToothedStream stream, | ||
| NineToothedTensor input, NineToothedTensor weight, NineToothedTensor eps, | ||
| NineToothedTensor out, NineToothedTensor num_normalized_elements) { | ||
| const int dtype_index = ninetoothed::DTypeIndex(dtype); | ||
|
|
||
| if (dtype_index < 0) { | ||
| return 1; | ||
| } | ||
|
|
||
| return launch_infiniops_ninetoothed_rms_norm( | ||
| stream, input, weight, eps, out, num_normalized_elements, | ||
| ninetoothed::SizeArg(shape.size()), 1, dtype_index, dtype_index, | ||
| dtype_index, INFINIOPS_NINETOOTHED_BLOCK_SIZE); | ||
| } | ||
|
|
||
| } // namespace detail | ||
|
|
||
| template <> | ||
| class Operator<RmsNorm, Device::Type::kNvidia, 9> : public RmsNorm { | ||
| public: | ||
| using RmsNorm::RmsNorm; | ||
| using RmsNorm::operator(); | ||
|
|
||
| void operator()(const Tensor input, const Tensor weight, float eps, | ||
| Tensor out) const override { | ||
| assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() && | ||
| "operator `RmsNorm` requires all input and output tensors to have " | ||
| "the same dtype"); | ||
| assert(input.shape() == out.shape() && | ||
| "ninetoothed `RmsNorm` requires input and output tensors with the " | ||
| "same shape"); | ||
| assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) && | ||
| "ninetoothed `RmsNorm` requires a 1D weight matching the last " | ||
| "dimension"); | ||
| assert((out.ndim() == 2 || out.ndim() == 3) && | ||
| "ninetoothed `RmsNorm` currently supports rank-2 and rank-3 " | ||
| "tensors"); | ||
|
|
||
| std::vector<std::uint64_t> weight_sizes; | ||
| std::vector<std::int64_t> weight_strides; | ||
| double eps_value = static_cast<double>(eps); | ||
| std::int64_t num_normalized_elements = | ||
| static_cast<std::int64_t>(out.size(-1)); | ||
| std::uint64_t empty_shape[1] = {}; | ||
| std::int64_t empty_strides[1] = {}; | ||
|
|
||
| auto result = detail::LaunchNineToothedRmsNorm( | ||
| out.shape(), out.dtype(), static_cast<NineToothedStream>(stream_), | ||
| ninetoothed::FromTensor<NineToothedTensor>(input), | ||
| detail::ExpandedRmsNormWeight(weight, out.shape(), weight_sizes, | ||
| weight_strides), | ||
| ninetoothed::FromScalar<NineToothedTensor>(eps_value, empty_shape, | ||
| empty_strides), | ||
| ninetoothed::FromTensor<NineToothedTensor>(out), | ||
| ninetoothed::FromScalar<NineToothedTensor>( | ||
| num_normalized_elements, empty_shape, empty_strides)); | ||
|
|
||
| assert(result == 0 && "ninetoothed `RmsNorm` launch failed"); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace infini::ops | ||
|
|
||
| #endif // WITH_NINETOOTHED | ||
|
|
||
| #endif | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个文件应当是一个通用工具脚本类的文件,换句话说应当是生成入口,里面不该包含具体的算子,比如 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| import argparse | ||
| import pathlib | ||
| import shutil | ||
| import sys | ||
|
|
||
|
|
||
| _DEFAULT_DTYPES = ("float32", "float16", "bfloat16") | ||
| _DEFAULT_RMS_NORM_NDIMS = (2, 3) | ||
| _SUPPORTED_OPS = ("rms_norm",) | ||
|
|
||
|
|
||
| def _import_ninetoothed(source_dir): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 到底为啥需要这个 helper?去掉它,直接在 top-level |
||
| if source_dir is not None: | ||
| sys.path.insert(0, str(pathlib.Path(source_dir) / "src")) | ||
|
|
||
| import ninetoothed | ||
|
|
||
| return ninetoothed | ||
|
|
||
|
|
||
| def _rms_norm_premake( | ||
| ndim, | ||
| num_normalized_dims, | ||
| input_dtype, | ||
| weight_dtype, | ||
| output_dtype, | ||
| block_size, | ||
| ): | ||
| import ntops | ||
|
|
||
| return ntops.kernels.rms_norm.premake( | ||
| ndim, | ||
| num_normalized_dims, | ||
| input_dtype=input_dtype, | ||
| weight_dtype=weight_dtype, | ||
| output_dtype=output_dtype, | ||
| block_size=block_size, | ||
| ) | ||
|
|
||
|
|
||
def _normalize_ndims(values):
    """Convert ``values`` to a tuple of unique ints, keeping first-seen order.

    Each item is passed through ``int``. A ``ValueError`` is raised for the
    first item whose integer value is not in ``_DEFAULT_RMS_NORM_NDIMS``.
    """
    unique_ranks = {}

    for raw in values:
        rank = int(raw)

        if rank not in _DEFAULT_RMS_NORM_NDIMS:
            raise ValueError(f"`RmsNorm` currently supports rank 2 and 3: {raw!r}")

        # A dict keeps insertion order, which drops repeats without reordering.
        unique_ranks.setdefault(rank, None)

    return tuple(unique_ranks)
|
|
||
|
|
||
def _rms_norm_configs(ninetoothed, dtypes, ndims, block_size):
    """Build one config triple per (rank, dtype) pair, rank-major.

    Every entry is ``((), kwargs, {})``. ``kwargs`` holds the rank, a fixed
    ``num_normalized_dims`` of 1, the same dtype for input, weight and output
    (looked up by name on ``ninetoothed``), and ``block_size``.
    """

    def premake_kwargs(rank, dtype):
        return {
            "ndim": rank,
            "num_normalized_dims": 1,
            "input_dtype": dtype,
            "weight_dtype": dtype,
            "output_dtype": dtype,
            "block_size": block_size,
        }

    return tuple(
        ((), premake_kwargs(rank, getattr(ninetoothed, name)), {})
        for rank in _normalize_ndims(ndims)
        for name in dtypes
    )
|
|
||
|
|
||
def _generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size):
    """Create ``output_dir / "rms_norm"`` and run ``ninetoothed.build`` into it."""
    target_dir = output_dir / "rms_norm"
    target_dir.mkdir(parents=True, exist_ok=True)

    configs = _rms_norm_configs(ninetoothed, dtypes, rms_norm_ndims, block_size)
    build_options = {
        "meta_parameters": None,
        "caller": "cuda",
        "kernel_name": "infiniops_ninetoothed_rms_norm",
        "output_dir": target_dir,
        "lazy": False,
    }

    ninetoothed.build(_rms_norm_premake, configs, **build_options)
|
|
||
|
|
||
| def _build_manifest(output_dir): | ||
| return sorted( | ||
| str(path) | ||
| for path in pathlib.Path(output_dir).rglob("*.cpp") | ||
| if not path.name.endswith(".tmp.cpp") | ||
| ) | ||
|
|
||
|
|
||
| def _write_cmake_manifest(output_dir, sources): | ||
| manifest_path = pathlib.Path(output_dir) / "manifest.cmake" | ||
| lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"] | ||
| lines.extend(f' "{source}"' for source in sources) | ||
| lines.append(")") | ||
| lines.append("") | ||
| lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{output_dir}")') | ||
| lines.append("") | ||
| manifest_path.write_text("\n".join(lines) + "\n") | ||
|
|
||
|
|
||
def generate(
    ops,
    *,
    output_dir,
    dtypes=_DEFAULT_DTYPES,
    rms_norm_ndims=_DEFAULT_RMS_NORM_NDIMS,
    block_size=256,
    ninetoothed_source_dir=None,
):
    """Generate sources for ``ops`` into ``output_dir`` and return their paths.

    Raises ``ValueError`` before touching the filesystem when ``ops`` contains
    a name that is not in ``_SUPPORTED_OPS``. Otherwise ``output_dir`` is
    deleted and recreated, the requested kernels are generated, and a
    ``manifest.cmake`` listing the resulting sources is written next to them.
    """
    rejected = [op for op in ops if op not in _SUPPORTED_OPS]

    if rejected:
        raise ValueError("unsupported ninetoothed ops: " + ", ".join(rejected))

    destination = pathlib.Path(output_dir)

    # Any previous contents of the output directory are discarded; removal
    # errors (for example, the directory not existing yet) are ignored.
    shutil.rmtree(destination, ignore_errors=True)
    destination.mkdir(parents=True, exist_ok=True)

    ninetoothed = _import_ninetoothed(ninetoothed_source_dir)

    if "rms_norm" in ops:
        _generate_rms_norm(
            ninetoothed, destination, dtypes, rms_norm_ndims, block_size
        )

    manifest = _build_manifest(destination)
    _write_cmake_manifest(destination, manifest)

    return manifest
|
|
||
|
|
||
def _parse_args():
    """Parse the generator's command-line options from ``sys.argv``.

    ``--output-dir`` is mandatory. ``--ops``, ``--dtypes`` and
    ``--rms-norm-ndims`` each accept one or more values and default to the
    module-level tuples. ``--block-size`` is parsed as an int.
    """
    parser = argparse.ArgumentParser(
        description="Generate ninetoothed operator sources for InfiniOps."
    )
    parser.add_argument("--output-dir", required=True)

    multi_value_options = (
        ("--ops", _SUPPORTED_OPS),
        ("--dtypes", _DEFAULT_DTYPES),
        # No `type=` here: values given on the command line arrive as strings.
        ("--rms-norm-ndims", _DEFAULT_RMS_NORM_NDIMS),
    )

    for flag, default in multi_value_options:
        parser.add_argument(flag, nargs="+", default=default)

    parser.add_argument("--block-size", type=int, default=256)
    parser.add_argument("--ninetoothed-source-dir")

    return parser.parse_args()
|
|
||
|
|
||
def main():
    """Command-line entry point: parse the options and call ``generate``."""
    options = _parse_args()
    keyword_options = {
        "output_dir": options.output_dir,
        "dtypes": tuple(options.dtypes),
        "rms_norm_ndims": tuple(options.rms_norm_ndims),
        "block_size": options.block_size,
        "ninetoothed_source_dir": options.ninetoothed_source_dir,
    }

    generate(options.ops, **keyword_options)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处不提及 NVIDIA,因为九齿的目标是跨平台,只是目前可能只暴露了 `cuda` caller,所以跟 PyTorch 的对齐即可。