Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
1cc2e35
feat(ascend): add Ascend framework layer — runtime, type mapping, bui…
Apr 8, 2026
44d681a
style(ascend): apply `clang-format` to framework headers
Apr 8, 2026
d091275
fix(ascend): adapt `Memcpy`/`Memset` arity, assert workspace alloc, r…
Apr 8, 2026
11602fa
feat(ascend): add GEMM kernel, NPU test infra, and example integration
Apr 8, 2026
ff03814
fix(ascend): move `aclrtMalloc` out of `assert()` in `WorkspacePool`
Apr 8, 2026
1ccadc0
fix(nvidia): restore `CUDA::cublasLt` link dependency
Apr 8, 2026
2fa9e13
feat(test): add `--devices` option to pytest for platform-name filtering
Apr 8, 2026
cb46b57
fix(nvidia): add missing include and work around NVCC `std::forward` bug
Apr 10, 2026
4462b6e
fix(ci): upgrade NVIDIA CI image to 25.12 and restore `std::forward`
zhangyue207 Apr 10, 2026
40b5858
fix: add explicit narrowing casts in `RotaryEmbedding` initializer list
Apr 10, 2026
c1816eb
style: fix lint issues from PR review
Apr 10, 2026
1f5e0ef
style: fix lint issues in `feat/ascend-framework`
Apr 10, 2026
9a5a5f1
fix: address PR #46 review feedback
Apr 13, 2026
2cb185b
docs(base): add vLLM interface references to `FlashAttention`, `Resha…
Apr 13, 2026
41d67c0
fix(base): rename matmul.h to mat_mul.h to match MatMul class name
Apr 13, 2026
85ef359
fix(nvidia): revert std::forward workaround for NVCC 12.6 bug
Apr 13, 2026
7398f9f
Revert "fix(nvidia): revert std::forward workaround for NVCC 12.6 bug"
Apr 13, 2026
978f948
style: remove extra blank line in `conftest.py`
voltjia Apr 14, 2026
2f5acb4
docs: add TODO for making `eps` optional in `AddRmsNorm`
voltjia Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF)
option(WITH_METAX "Enable MetaX backend" OFF)
option(WITH_CAMBRICON "Enable Cambricon backend" OFF)
option(WITH_MOORE "Enable Moore backend" OFF)
option(WITH_ASCEND "Enable Ascend backend" OFF)

option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)
Expand Down Expand Up @@ -71,20 +72,25 @@ if(AUTO_DETECT_DEVICES)
set(WITH_MOORE OFF)
set(WITH_MOORE OFF CACHE BOOL "Enable Moore backend" FORCE)
endif()

if(DEFINED ENV{ASCEND_HOME_PATH} OR EXISTS "/dev/davinci0")
set(WITH_ASCEND ON)
message(STATUS "Auto-detected Ascend environment.")
endif()
endif()

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)

# Only one CUDA-like GPU backend can be enabled at a time.
set(_gpu_backend_count 0)
foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE)
foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE WITH_ASCEND)
if(${_gpu_backend})
math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1")
endif()
endforeach()

if(_gpu_backend_count GREATER 1)
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.")
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.")
endif()

if(WITH_NVIDIA)
Expand Down Expand Up @@ -178,8 +184,23 @@ if(WITH_CAMBRICON)
find_library(CAMBRICON_PAPI_LIB NAMES cnpapi HINTS "${NEUWARE_HOME}/lib64" REQUIRED)
endif()

if(WITH_ASCEND)
# Make the backend selection visible to every translation unit.
add_compile_definitions(WITH_ASCEND=1)
# Resolve the CANN toolkit root. Priority: an explicit -DASCEND_HOME on the
# command line, then the ASCEND_HOME_PATH environment variable, then the
# conventional default install location.
if(NOT DEFINED ASCEND_HOME)
if(DEFINED ENV{ASCEND_HOME_PATH} AND NOT "$ENV{ASCEND_HOME_PATH}" STREQUAL "")
set(ASCEND_HOME "$ENV{ASCEND_HOME_PATH}" CACHE PATH "Ascend toolkit root")
else()
# Default path used by standard CANN toolkit installations.
set(ASCEND_HOME "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Ascend toolkit root")
endif()
endif()
# Fail fast at configure time when the backend is requested but the toolkit
# directory does not exist, instead of failing later during compilation.
if(NOT EXISTS "${ASCEND_HOME}")
message(FATAL_ERROR "`WITH_ASCEND` is ON but `${ASCEND_HOME}` was not found. Set ASCEND_HOME_PATH.")
endif()
message(STATUS "Using Ascend from `${ASCEND_HOME}`.")
endif()

# If all other platforms are not enabled, CPU is enabled by default.
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON)
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON AND NOT WITH_ASCEND)
add_compile_definitions(WITH_CPU=1)
endif()

Expand Down
5 changes: 5 additions & 0 deletions examples/runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#elif WITH_MOORE
#include "moore/gemm/mublas.h"
#include "moore/runtime_.h"
#elif WITH_ASCEND
#include "ascend/gemm/kernel.h"
#include "ascend/runtime_.h"
#elif WITH_CPU
#include "cpu/gemm/gemm.h"
#include "cpu/runtime_.h"
Expand All @@ -38,6 +41,8 @@ using DefaultRuntimeUtils = Runtime<Device::Type::kMetax>;
using DefaultRuntimeUtils = Runtime<Device::Type::kCambricon>;
#elif WITH_MOORE
using DefaultRuntimeUtils = Runtime<Device::Type::kMoore>;
#elif WITH_ASCEND
using DefaultRuntimeUtils = Runtime<Device::Type::kAscend>;
#elif WITH_CPU
using DefaultRuntimeUtils = Runtime<Device::Type::kCpu>;
#endif
Expand Down
89 changes: 65 additions & 24 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import json
import pathlib
import re
import shutil
import subprocess
import textwrap
Expand Down Expand Up @@ -91,26 +92,58 @@ def __init__(self, name, constructors, calls):
self.calls = calls


def _find_optional_tensor_params(op_name):
    """Scan the operator's base header for `std::optional<Tensor>` parameters.

    `libclang` degrades the type to `int` when the STL headers are not fully
    available, so the parameter names are recovered with a regex over the raw
    header text instead of the AST.

    Returns a set of the matching parameter names.
    """
    header_text = (_BASE_DIR / f"{op_name}.h").read_text()
    pattern = re.compile(r"std::optional<Tensor>\s+(\w+)")
    return {match.group(1) for match in pattern.finditer(header_text)}
Comment thread
voltjia marked this conversation as resolved.


def _generate_pybind11(operator):
optional_tensor_params = _find_optional_tensor_params(operator.name)

def _is_optional_tensor(arg):
if arg.spelling in optional_tensor_params:
return True

return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling
Comment thread
voltjia marked this conversation as resolved.

def _generate_params(node):
return (
", ".join(
f"{arg.type.spelling} {arg.spelling}"
for arg in node.get_arguments()
if arg.spelling != "stream"
)
.replace("const Tensor", "py::object")
.replace("Tensor", "py::object")
)
parts = []

for arg in node.get_arguments():
if arg.spelling == "stream":
continue

if _is_optional_tensor(arg):
Comment thread
voltjia marked this conversation as resolved.
parts.append(f"std::optional<py::object> {arg.spelling}")
else:
param = arg.type.spelling.replace("const Tensor", "py::object").replace(
"Tensor", "py::object"
)
parts.append(f"{param} {arg.spelling}")

return ", ".join(parts)
Comment thread
voltjia marked this conversation as resolved.

def _generate_arguments(node):
return ", ".join(
f"TensorFromPybind11Handle({arg.spelling})"
if "Tensor" in arg.type.spelling
else arg.spelling
for arg in node.get_arguments()
if arg.spelling != "stream"
)
args = []

for arg in node.get_arguments():
if arg.spelling == "stream":
continue

if _is_optional_tensor(arg):
Comment thread
voltjia marked this conversation as resolved.
args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})")
elif "Tensor" in arg.type.spelling:
args.append(f"TensorFromPybind11Handle({arg.spelling})")
else:
args.append(arg.spelling)

return ", ".join(args)

op_name = operator.name

Expand All @@ -134,18 +167,24 @@ def _generate_call(op_name, call, method=True):

if not method:
params = (
f"{call_params}, std::size_t implementation_index"
f"{call_params}, std::uintptr_t stream, std::size_t implementation_index"
if call_params
else "std::size_t implementation_index"
else "std::uintptr_t stream, std::size_t implementation_index"
)
py_args = _generate_py_args(call)
py_args_str = f"{py_args}, " if py_args else ""

return f""" m.def("{op_name}", []({params}) {{
Config config;
config.set_implementation_index(implementation_index);
return Self::call({{}}, config, {call_args});
}}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);"""
return (
Comment thread
voltjia marked this conversation as resolved.
f' m.def("{op_name}", []({params}) {{\n'
f" Handle handle;\n"
f" if (stream) {{\n"
f" handle.set_stream(reinterpret_cast<void*>(stream));\n"
f" }}\n"
f" Config config;\n"
f" config.set_implementation_index(implementation_index);\n"
f" return Self::call(handle, config, {call_args});\n"
f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);'
)

return f""" .def("__call__", [](const Self& self, {call_params}) {{
return static_cast<const Operator<Self>&>(self)({call_args});
Expand All @@ -169,6 +208,8 @@ def _generate_call(op_name, call, method=True):

#include "base/{op_name}.h"
#include "config.h"
#include "handle.h"
#include "operator.h"
#include "pybind11_utils.h"

namespace py = pybind11;
Expand Down Expand Up @@ -401,7 +442,7 @@ def _get_all_ops(devices):
nargs="+",
default="cpu",
type=str,
help="Devices to use. Please pick from cpu, nvidia, cambricon, ascend, metax, moore, iluvatar, kunlun, hygon, and qy. (default: cpu)",
help="Devices to use. Please pick from `cpu`, `nvidia`, `cambricon`, `ascend`, `metax`, `moore`, `iluvatar`, `kunlun`, `hygon`, and `qy`. (default: `cpu`)",
)

args = parser.parse_args()
Expand Down
50 changes: 50 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,60 @@ if(WITH_CAMBRICON)
list(APPEND DEVICE_LIST "cambricon")
endif()

if(WITH_ASCEND)
# ASCEND_HOME is set by the top-level CMakeLists.txt.
Comment thread
voltjia marked this conversation as resolved.
file(GLOB_RECURSE ASCEND_SOURCES CONFIGURE_DEPENDS
"ascend/*.cc"
"ascend/*.cpp"
)
# Exclude `kernel_impl.cpp` — AscendC device code, not compiled by the host C++ compiler.
list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*kernel_impl\\.cpp$")

target_compile_definitions(infiniops PUBLIC WITH_ASCEND=1)
target_sources(infiniops PRIVATE ${ASCEND_SOURCES})

# Resolve the driver lib dir two levels above the toolkit root.
get_filename_component(ASCEND_ROOT "${ASCEND_HOME}/../.." ABSOLUTE)

# Prefer the real driver HAL; fall back to the toolkit stub for build-only
# environments (e.g., Docker CI images without hardware drivers installed).
# CANN <= 8.0: stub at runtime/lib64/stub/; CANN >= 8.5: devlib/<arch>-linux/devlib/.
set(ASCEND_HAL_REAL "${ASCEND_ROOT}/driver/lib64/driver/libascend_hal.so")
set(ASCEND_HAL_STUB "${ASCEND_HOME}/runtime/lib64/stub/libascend_hal.so")
set(ASCEND_HAL_DEVLIB "${ASCEND_HOME}/${CMAKE_SYSTEM_PROCESSOR}-linux/devlib/libascend_hal.so")
if(EXISTS "${ASCEND_HAL_REAL}")
set(ASCEND_HAL_LIB "${ASCEND_HAL_REAL}")
elseif(EXISTS "${ASCEND_HAL_STUB}")
set(ASCEND_HAL_LIB "${ASCEND_HAL_STUB}")
message(STATUS "ascend_hal: driver not found, using stub for linking")
elseif(EXISTS "${ASCEND_HAL_DEVLIB}")
set(ASCEND_HAL_LIB "${ASCEND_HAL_DEVLIB}")
message(STATUS "ascend_hal: driver not found, using devlib for linking")
else()
message(FATAL_ERROR "libascend_hal.so not found (tried ${ASCEND_HAL_REAL}, ${ASCEND_HAL_STUB}, and ${ASCEND_HAL_DEVLIB})")
endif()

target_include_directories(infiniops PUBLIC
"${ASCEND_HOME}/include"
"${ASCEND_HOME}/include/aclnn"
"${ASCEND_HOME}/include/aclnnop")
target_link_libraries(infiniops PUBLIC
"${ASCEND_HOME}/lib64/libascendcl.so"
"${ASCEND_HOME}/lib64/libnnopbase.so"
"${ASCEND_HOME}/lib64/libopapi.so"
"${ASCEND_HAL_LIB}")

list(APPEND DEVICE_LIST "ascend")
endif()

target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

if(GENERATE_PYTHON_BINDINGS)
find_package(Python COMPONENTS Interpreter REQUIRED)
# Always regenerate bindings so the included kernel headers match the
# active device list. Stale generated files (e.g., committed for one
# platform) would omit specializations for other enabled backends,
# causing link-time or runtime failures.
execute_process(
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST}
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
Expand Down
56 changes: 56 additions & 0 deletions src/ascend/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#ifndef INFINI_OPS_ASCEND_COMMON_H_
#define INFINI_OPS_ASCEND_COMMON_H_

#include <cstdint>
#include <utility>  // std::swap — was previously available only transitively.
#include <vector>

#include "acl/acl.h"
#include "aclnn/acl_meta.h"
#include "ascend/data_type_.h"
#include "tensor.h"

namespace infini::ops::ascend {

// Build an `aclTensor` descriptor from an InfiniOps `Tensor`.
//
// When `transpose_last2` is true the last two dimensions are swapped in the
// descriptor (shape and strides) without copying data. This is used by `Gemm`
// and `MatMul` to express a transpose via the view.
//
// NOTE(review): the returned descriptor appears to be owned by the caller and
// released via `aclDestroyTensor` — confirm against call sites.
inline aclTensor* buildAclTensor(const Tensor& t,
                                 bool transpose_last2 = false) {
  std::vector<int64_t> shape(t.shape().begin(), t.shape().end());
  std::vector<int64_t> strides(t.strides().begin(), t.strides().end());

  if (transpose_last2 && shape.size() >= 2) {
    auto n = shape.size();
    std::swap(shape[n - 2], shape[n - 1]);
    std::swap(strides[n - 2], strides[n - 1]);
  }

  // Compute the minimum physical storage needed for this strided view:
  // 1 + sum over advancing dimensions of (extent - 1) * stride.
  // For contiguous tensors this equals `numel()`; for non-contiguous (gapped)
  // tensors it may be larger; for broadcast (stride-0) tensors it may be
  // smaller. Passing the view shape as the storage shape causes
  // "ViewShape overlap" errors in ACLNN for non-contiguous inputs.
  int64_t storage_elems = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) {
      // A zero-extent dimension means the view contains no elements at all.
      storage_elems = 0;
      break;
    }
    // Stride-0 (broadcast) and extent-1 dimensions do not advance through
    // memory, so they contribute nothing to the physical footprint.
    if (strides[i] > 0 && shape[i] > 1) {
      storage_elems += static_cast<int64_t>(shape[i] - 1) * strides[i];
    }
  }
  std::vector<int64_t> storage_shape = {storage_elems};

  return aclCreateTensor(
      shape.data(), static_cast<int64_t>(shape.size()), ToAclDtype(t.dtype()),
      strides.data(),
      /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape.data(),
      static_cast<int64_t>(storage_shape.size()), const_cast<void*>(t.data()));
}

}  // namespace infini::ops::ascend

#endif  // INFINI_OPS_ASCEND_COMMON_H_
61 changes: 61 additions & 0 deletions src/ascend/data_type_.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#ifndef INFINI_OPS_ASCEND_DATA_TYPE__H_
#define INFINI_OPS_ASCEND_DATA_TYPE__H_

#include <cassert>

#include "acl/acl.h"
#include "ascend/device_.h"
#include "data_type.h"

namespace infini::ops::ascend {

// Translate an InfiniOps `DataType` into the corresponding ACL dtype constant.
// Unsupported values trip an assertion in debug builds and fall back to
// `ACL_DT_UNDEFINED` in release builds.
inline aclDataType ToAclDtype(DataType dt) {
  switch (dt) {
    // Signed integers.
    case DataType::kInt8:     return ACL_INT8;
    case DataType::kInt16:    return ACL_INT16;
    case DataType::kInt32:    return ACL_INT32;
    case DataType::kInt64:    return ACL_INT64;
    // Unsigned integers.
    case DataType::kUInt8:    return ACL_UINT8;
    case DataType::kUInt16:   return ACL_UINT16;
    case DataType::kUInt32:   return ACL_UINT32;
    case DataType::kUInt64:   return ACL_UINT64;
    // Floating point.
    case DataType::kFloat16:  return ACL_FLOAT16;
    case DataType::kBFloat16: return ACL_BF16;
    case DataType::kFloat32:  return ACL_FLOAT;
    default:
      assert(false && "Unsupported dtype for Ascend backend.");
      return ACL_DT_UNDEFINED;
  }
}

// Returns true for integer (signed or unsigned) `DataType` values.
inline bool IsIntegerDtype(DataType dt) {
  switch (dt) {
    case DataType::kInt8:
    case DataType::kInt16:
    case DataType::kInt32:
    case DataType::kInt64:
    case DataType::kUInt8:
    case DataType::kUInt16:
    case DataType::kUInt32:
    case DataType::kUInt64:
      return true;
    default:
      // Floating-point and any future non-integer types land here.
      return false;
  }
}

}  // namespace infini::ops::ascend

#endif  // INFINI_OPS_ASCEND_DATA_TYPE__H_
Loading