Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF)

option(WITH_TORCH "Enable PyTorch C++ backend" OFF)

option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处不提及 NVIDIA,因为九齿的目标是跨平台,只是目前可能只暴露了 cuda caller,所以跟 PyTorch 的对齐即可。


# Default OFF until CANN's `extract_host_stub.py` path handling is fixed for
# `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed
# object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the
Expand All @@ -29,6 +31,14 @@ option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF)
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)

# NineToothed code generation configuration.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请在后面创建一个关于 WITH_NINETOOTHEDif 吧,把这些 set 放到这个分支里吧。

set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run ninetoothed code generation")
set(NINETOOTHED_SOURCE_DIR "" CACHE PATH "Optional local ninetoothed source checkout; installed package is used when empty")
set(INFINIOPS_NINETOOTHED_OPS "rms_norm" CACHE STRING "Semicolon- or comma-separated NineToothed ops to generate")
set(INFINIOPS_NINETOOTHED_DTYPES "float32;float16;bfloat16" CACHE STRING "Semicolon- or comma-separated NineToothed dtypes to generate")
set(INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS "2;3" CACHE STRING "Semicolon- or comma-separated RmsNorm input ranks to generate with NineToothed")
set(INFINIOPS_NINETOOTHED_BLOCK_SIZE "256" CACHE STRING "Block size baked into simple NineToothed elementwise kernels")

if(AUTO_DETECT_DEVICES)
message(STATUS "Auto-detecting available devices...")

Expand Down Expand Up @@ -231,6 +241,10 @@ if(_gpu_backend_count GREATER 1)
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.")
endif()

if(WITH_NINETOOTHED AND NOT WITH_NVIDIA)
message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because ninetoothed AOT uses caller=`cuda`.")
endif()

if(WITH_NVIDIA)
add_compile_definitions(WITH_NVIDIA=1)
enable_language(CUDA)
Expand Down
15 changes: 15 additions & 0 deletions scripts/generate_ninetoothed_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pathlib
import sys


def main():
project_dir = pathlib.Path(__file__).resolve().parents[1]
sys.path.insert(0, str(project_dir / "src"))

from native.ninetoothed.codegen import main as codegen_main

codegen_main()


if __name__ == "__main__":
main()
51 changes: 51 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,57 @@ if(WITH_NVIDIA)
target_compile_definitions(infiniops PUBLIC WITH_NVIDIA=1)
target_sources(infiniops PRIVATE ${NVIDIA_SOURCES})

if(WITH_NINETOOTHED)
find_package(Python COMPONENTS Interpreter REQUIRED)

if(NINETOOTHED_PYTHON_EXECUTABLE)
set(_ninetoothed_python "${NINETOOTHED_PYTHON_EXECUTABLE}")
elseif(_TORCH_PYTHON)
set(_ninetoothed_python "${_TORCH_PYTHON}")
else()
set(_ninetoothed_python "${Python_EXECUTABLE}")
endif()
message(STATUS "NineToothed codegen Python: ${_ninetoothed_python}")

string(REPLACE "," ";" _ninetoothed_ops "${INFINIOPS_NINETOOTHED_OPS}")
string(REPLACE "," ";" _ninetoothed_dtypes "${INFINIOPS_NINETOOTHED_DTYPES}")
string(REPLACE "," ";" _ninetoothed_rms_norm_ndims "${INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS}")

set(_ninetoothed_output_dir "${CMAKE_CURRENT_BINARY_DIR}/ninetoothed")
set(_ninetoothed_generator_args
"${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py"
--output-dir "${_ninetoothed_output_dir}"
--ops ${_ninetoothed_ops}
--dtypes ${_ninetoothed_dtypes}
--rms-norm-ndims ${_ninetoothed_rms_norm_ndims}
--block-size "${INFINIOPS_NINETOOTHED_BLOCK_SIZE}")

if(NINETOOTHED_SOURCE_DIR)
list(APPEND _ninetoothed_generator_args
--ninetoothed-source-dir "${NINETOOTHED_SOURCE_DIR}")
endif()

execute_process(
COMMAND "${_ninetoothed_python}" ${_ninetoothed_generator_args}
WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
RESULT_VARIABLE _ninetoothed_generation_result
)

if(NOT _ninetoothed_generation_result EQUAL 0)
message(FATAL_ERROR "Generating NineToothed operator sources failed with `${_ninetoothed_python}`. Set `NINETOOTHED_PYTHON_EXECUTABLE` to a Python with `ninetoothed`, `triton`, `sympy`, and CUDA dependencies installed.")
endif()

include("${_ninetoothed_output_dir}/manifest.cmake")
set(_ninetoothed_compile_definitions
WITH_NINETOOTHED=1
INFINIOPS_NINETOOTHED_BLOCK_SIZE=${INFINIOPS_NINETOOTHED_BLOCK_SIZE})
target_compile_definitions(infiniops PUBLIC
${_ninetoothed_compile_definitions})
target_include_directories(infiniops PUBLIC
${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES})
endif()

find_package(CUDAToolkit REQUIRED)
target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cuda_driver)

Expand Down
101 changes: 101 additions & 0 deletions src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

九齿的定位在算子库中应该跟 PyTorch 差不多,都可以接入到后端里,所以在文件结构上应该跟 PyTorch 平行,而不是放在 cuda 下,现在的九齿可能只有 cuda 这个 caller,但是生成的接口是一致的,只要后期增多了支持,就可以跨平台,跟 PyTorch 一样。

Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#ifndef INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_
#define INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_

#ifdef WITH_NINETOOTHED
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

就像评论 https://github.com/InfiniTensor/InfiniOps/pull/616/changes#r3285560677 所说,九齿应该是与 PyTorch 等对应的后端,所以是跟 torch 差不多的文件架构,而咱们算子库都是靠 build system 和脚本来确定最终产物,所以不要在 src 里面的文件使用 WITH_NINETOOTHED 这种类似的宏。事实上,在 C++ 中,我们应当尽量少地使用宏。


#include <cassert>
#include <cstdint>
#include <vector>

#include "base/rms_norm.h"
#include "data_type.h"
#include "native/ninetoothed/tensor.h"
#include "rms_norm/infiniops_ninetoothed_rms_norm.h"

#ifndef INFINIOPS_NINETOOTHED_BLOCK_SIZE
#define INFINIOPS_NINETOOTHED_BLOCK_SIZE 256
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C++ 中尽量不使用宏,尤其是这种可以被 constexpr 或者 const 替代的情况。

#endif

namespace infini::ops {

namespace detail {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看此处 detail 内部的函数比较少,且每个函数的内容也很少,可以考虑直接放在 Operator<RmsNorm, Device::Type::kNvidia, 9>::operator() 里面,暂时不单独抽成独立的 helper 了。


inline NineToothedTensor ExpandedRmsNormWeight(const Tensor& weight,
const Tensor::Shape& shape,
std::vector<std::uint64_t>& sizes,
std::vector<std::int64_t>& strides) {
sizes.assign(shape.begin(), shape.end());
strides.assign(shape.size(), 0);
strides.back() = weight.strides().empty() ? 1 : weight.strides().back();

return NineToothedTensor{const_cast<void*>(weight.data()), sizes.data(),
strides.data()};
}

inline NineToothedResult LaunchNineToothedRmsNorm(
const Tensor::Shape& shape, DataType dtype, NineToothedStream stream,
NineToothedTensor input, NineToothedTensor weight, NineToothedTensor eps,
NineToothedTensor out, NineToothedTensor num_normalized_elements) {
const int dtype_index = ninetoothed::DTypeIndex(dtype);

if (dtype_index < 0) {
return 1;
}

return launch_infiniops_ninetoothed_rms_norm(
stream, input, weight, eps, out, num_normalized_elements,
ninetoothed::SizeArg(shape.size()), 1, dtype_index, dtype_index,
dtype_index, INFINIOPS_NINETOOTHED_BLOCK_SIZE);
}

} // namespace detail

template <>
class Operator<RmsNorm, Device::Type::kNvidia, 9> : public RmsNorm {
public:
using RmsNorm::RmsNorm;
using RmsNorm::operator();

void operator()(const Tensor input, const Tensor weight, float eps,
Tensor out) const override {
assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() &&
"operator `RmsNorm` requires all input and output tensors to have "
"the same dtype");
assert(input.shape() == out.shape() &&
"ninetoothed `RmsNorm` requires input and output tensors with the "
"same shape");
assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) &&
"ninetoothed `RmsNorm` requires a 1D weight matching the last "
"dimension");
assert((out.ndim() == 2 || out.ndim() == 3) &&
"ninetoothed `RmsNorm` currently supports rank-2 and rank-3 "
"tensors");

std::vector<std::uint64_t> weight_sizes;
std::vector<std::int64_t> weight_strides;
double eps_value = static_cast<double>(eps);
std::int64_t num_normalized_elements =
static_cast<std::int64_t>(out.size(-1));
std::uint64_t empty_shape[1] = {};
std::int64_t empty_strides[1] = {};

auto result = detail::LaunchNineToothedRmsNorm(
out.shape(), out.dtype(), static_cast<NineToothedStream>(stream_),
ninetoothed::FromTensor<NineToothedTensor>(input),
detail::ExpandedRmsNormWeight(weight, out.shape(), weight_sizes,
weight_strides),
ninetoothed::FromScalar<NineToothedTensor>(eps_value, empty_shape,
empty_strides),
ninetoothed::FromTensor<NineToothedTensor>(out),
ninetoothed::FromScalar<NineToothedTensor>(
num_normalized_elements, empty_shape, empty_strides));

assert(result == 0 && "ninetoothed `RmsNorm` launch failed");
}
};

} // namespace infini::ops

#endif // WITH_NINETOOTHED

#endif
165 changes: 165 additions & 0 deletions src/native/ninetoothed/codegen.py
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件应当是一个通用工具脚本类的文件,换句话说应当是生成入口,里面不该包含具体的算子,比如 rms_norm,具体算子的构建脚本应当放到这个算子对应的文件夹内部。否则将来算子越来越多,难不成这个文件要放一堆算子相关的东西嘛?而且这个工具脚本本质上就可以直接是 scripts/generate_ninetoothed_ops.py,不用再倒一手,之前只是说把具体算子拆分出去。要分清哪些是必须在具体算子脚本里做的,而哪些是可以被统筹的。

Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import argparse
import pathlib
import shutil
import sys


_DEFAULT_DTYPES = ("float32", "float16", "bfloat16")
_DEFAULT_RMS_NORM_NDIMS = (2, 3)
_SUPPORTED_OPS = ("rms_norm",)


def _import_ninetoothed(source_dir):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

到底为啥需要这个 helper?去掉它,直接在 top-level import ninetoothed 就行。

if source_dir is not None:
sys.path.insert(0, str(pathlib.Path(source_dir) / "src"))

import ninetoothed

return ninetoothed


def _rms_norm_premake(
ndim,
num_normalized_dims,
input_dtype,
weight_dtype,
output_dtype,
block_size,
):
import ntops

return ntops.kernels.rms_norm.premake(
ndim,
num_normalized_dims,
input_dtype=input_dtype,
weight_dtype=weight_dtype,
output_dtype=output_dtype,
block_size=block_size,
)


def _normalize_ndims(values):
ndims = []

for value in values:
ndim = int(value)

if ndim not in _DEFAULT_RMS_NORM_NDIMS:
raise ValueError(f"`RmsNorm` currently supports rank 2 and 3: {value!r}")

if ndim not in ndims:
ndims.append(ndim)

return tuple(ndims)


def _rms_norm_configs(ninetoothed, dtypes, ndims, block_size):
configs = []

for ndim in _normalize_ndims(ndims):
for dtype_name in dtypes:
dtype = getattr(ninetoothed, dtype_name)
configs.append(
(
(),
{
"ndim": ndim,
"num_normalized_dims": 1,
"input_dtype": dtype,
"weight_dtype": dtype,
"output_dtype": dtype,
"block_size": block_size,
},
{},
)
)

return tuple(configs)


def _generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size):
variant_dir = output_dir / "rms_norm"
variant_dir.mkdir(parents=True, exist_ok=True)
ninetoothed.build(
_rms_norm_premake,
_rms_norm_configs(ninetoothed, dtypes, rms_norm_ndims, block_size),
meta_parameters=None,
caller="cuda",
kernel_name="infiniops_ninetoothed_rms_norm",
output_dir=variant_dir,
lazy=False,
)


def _build_manifest(output_dir):
return sorted(
str(path)
for path in pathlib.Path(output_dir).rglob("*.cpp")
if not path.name.endswith(".tmp.cpp")
)


def _write_cmake_manifest(output_dir, sources):
manifest_path = pathlib.Path(output_dir) / "manifest.cmake"
lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"]
lines.extend(f' "{source}"' for source in sources)
lines.append(")")
lines.append("")
lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{output_dir}")')
lines.append("")
manifest_path.write_text("\n".join(lines) + "\n")


def generate(
ops,
*,
output_dir,
dtypes=_DEFAULT_DTYPES,
rms_norm_ndims=_DEFAULT_RMS_NORM_NDIMS,
block_size=256,
ninetoothed_source_dir=None,
):
unknown_ops = tuple(op for op in ops if op not in _SUPPORTED_OPS)

if unknown_ops:
raise ValueError(f"unsupported ninetoothed ops: {', '.join(unknown_ops)}")

output_dir = pathlib.Path(output_dir)
shutil.rmtree(output_dir, ignore_errors=True)
output_dir.mkdir(parents=True, exist_ok=True)

ninetoothed = _import_ninetoothed(ninetoothed_source_dir)

if "rms_norm" in ops:
_generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size)

sources = _build_manifest(output_dir)
_write_cmake_manifest(output_dir, sources)

return sources


def _parse_args():
parser = argparse.ArgumentParser(
description="Generate ninetoothed operator sources for InfiniOps."
)
parser.add_argument("--output-dir", required=True)
parser.add_argument("--ops", nargs="+", default=_SUPPORTED_OPS)
parser.add_argument("--dtypes", nargs="+", default=_DEFAULT_DTYPES)
parser.add_argument("--rms-norm-ndims", nargs="+", default=_DEFAULT_RMS_NORM_NDIMS)
parser.add_argument("--block-size", type=int, default=256)
parser.add_argument("--ninetoothed-source-dir")

return parser.parse_args()


def main():
args = _parse_args()
generate(
args.ops,
output_dir=args.output_dir,
dtypes=tuple(args.dtypes),
rms_norm_ndims=tuple(args.rms_norm_ndims),
block_size=args.block_size,
ninetoothed_source_dir=args.ninetoothed_source_dir,
)
Loading
Loading