From 65d40d55143ff7e711cd6de542c8edcddd97e854 Mon Sep 17 00:00:00 2001 From: KaiCHEN-HT Date: Thu, 7 Aug 2025 08:04:13 +0000 Subject: [PATCH 1/2] add ppfno-op source code --- .gitignore | 3 + source/ppfno_op/CMakeLists.txt | 312 ++++++++++++++++ source/ppfno_op/cmake/get_nvcc_flags.py | 23 ++ .../cmake/get_paddle_include_paths.py | 26 ++ .../cmake/get_paddle_library_paths.py | 26 ++ .../fused_segment_csr/fused_segment_csr.h | 21 ++ .../csrc/include/fused_segment_csr/select.h | 40 +++ source/ppfno_op/csrc/src/CMakeLists.txt | 10 + source/ppfno_op/csrc/src/module.cpp | 62 ++++ source/ppfno_op/csrc/src/select.cpp | 167 +++++++++ source/ppfno_op/csrc/src/select.cu | 340 ++++++++++++++++++ .../fused_segment_csr/_C/__init__.pyi | 11 + source/ppfno_op/fused_segment_csr/__init__.py | 4 + .../fused_segment_csr/select_segment_csr.py | 55 +++ source/ppfno_op/fused_segment_csr/version.py | 19 + source/ppfno_op/setup.py | 184 ++++++++++ .../tests/test_fused_segment_csr_select.py | 143 ++++++++ 17 files changed, 1446 insertions(+) create mode 100755 source/ppfno_op/CMakeLists.txt create mode 100755 source/ppfno_op/cmake/get_nvcc_flags.py create mode 100755 source/ppfno_op/cmake/get_paddle_include_paths.py create mode 100755 source/ppfno_op/cmake/get_paddle_library_paths.py create mode 100755 source/ppfno_op/csrc/include/fused_segment_csr/fused_segment_csr.h create mode 100755 source/ppfno_op/csrc/include/fused_segment_csr/select.h create mode 100755 source/ppfno_op/csrc/src/CMakeLists.txt create mode 100755 source/ppfno_op/csrc/src/module.cpp create mode 100755 source/ppfno_op/csrc/src/select.cpp create mode 100755 source/ppfno_op/csrc/src/select.cu create mode 100755 source/ppfno_op/fused_segment_csr/_C/__init__.pyi create mode 100755 source/ppfno_op/fused_segment_csr/__init__.py create mode 100755 source/ppfno_op/fused_segment_csr/select_segment_csr.py create mode 100755 source/ppfno_op/fused_segment_csr/version.py create mode 100755 source/ppfno_op/setup.py create mode 100755 
source/ppfno_op/tests/test_fused_segment_csr_select.py diff --git a/.gitignore b/.gitignore index d125233..e2dc1b6 100755 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ logs log *.log *.mp3 +*.ruff_cache # Byte-compiled / optimized / DLL files __pycache__/ @@ -39,6 +40,8 @@ share/python-wheels/ *.egg-info/ .installed.cfg *.egg +*.tar.gz +*.whl MANIFEST # PyInstaller diff --git a/source/ppfno_op/CMakeLists.txt b/source/ppfno_op/CMakeLists.txt new file mode 100755 index 0000000..d277169 --- /dev/null +++ b/source/ppfno_op/CMakeLists.txt @@ -0,0 +1,312 @@ +# Copyright 2022-2025 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +cmake_minimum_required(VERSION 3.11) # for FetchContent +project(fused_segment_csr LANGUAGES CXX) + +include(FetchContent) + +set(THIRD_PARTY_DIR "${CMAKE_CURRENT_BINARY_DIR}/third_party") +if(NOT DEFINED PYBIND11_VERSION AND NOT "$ENV{PYBIND11_VERSION}" STREQUAL "") + set(PYBIND11_VERSION "$ENV{PYBIND11_VERSION}") +endif() +if(NOT PYBIND11_VERSION) + set(PYBIND11_VERSION stable) +endif() + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF) + +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1) +add_compile_definitions(PADDLE_WITH_CUDA=1) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) # -fPIC +set(CMAKE_CXX_VISIBILITY_PRESET hidden) # -fvisibility=hidden + +if(MSVC) + string(APPEND CMAKE_CXX_FLAGS " /Wall") + string(APPEND CMAKE_CXX_FLAGS_DEBUG " /Zi") + string(APPEND CMAKE_CXX_FLAGS_RELEASE " /O2 /Ob2") +else() + string(APPEND CMAKE_CXX_FLAGS " -Wall") + string(APPEND CMAKE_CXX_FLAGS_DEBUG " -g -Og") + string(APPEND CMAKE_CXX_FLAGS_RELEASE " -O3") +endif() + +if(NOT DEFINED USE_FP16 AND NOT "$ENV{USE_FP16}" STREQUAL "") + set(USE_FP16 "$ENV{USE_FP16}") +endif() + +if(NOT DEFINED USE_FP16) + set(USE_FP16 OFF) + message(WARNING "FP16 support disabled, compiling without paddle.HalfTensor. 
Suppress this warning with -DUSE_FP16=ON or -DUSE_FP16=OFF.") +elseif(USE_FP16) + message(STATUS "FP16 support enabled, compiling with paddle.HalfTensor.") +else() + message(STATUS "FP16 support disabled, compiling without paddle.HalfTensor.") +endif() + +if(USE_FP16) + add_definitions(-DUSE_FP16) +endif() + + +function(system) + set(options STRIP) + set(oneValueArgs OUTPUT_VARIABLE ERROR_VARIABLE WORKING_DIRECTORY) + set(multiValueArgs COMMAND) + cmake_parse_arguments( + SYSTEM + "${options}" + "${oneValueArgs}" + "${multiValueArgs}" + "${ARGN}" + ) + + if(NOT DEFINED SYSTEM_WORKING_DIRECTORY) + set(SYSTEM_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}") + endif() + + execute_process( + COMMAND ${SYSTEM_COMMAND} + OUTPUT_VARIABLE STDOUT + ERROR_VARIABLE STDERR + WORKING_DIRECTORY "${SYSTEM_WORKING_DIRECTORY}" + ) + + if("${SYSTEM_STRIP}") + string(STRIP "${STDOUT}" STDOUT) + string(STRIP "${STDERR}" STDERR) + endif() + + set("${SYSTEM_OUTPUT_VARIABLE}" "${STDOUT}" PARENT_SCOPE) + + if(DEFINED SYSTEM_ERROR_VARIABLE) + set("${SYSTEM_ERROR_VARIABLE}" "${STDERR}" PARENT_SCOPE) + endif() +endfunction() + +if(NOT DEFINED PYTHON_EXECUTABLE) + if(WIN32) + set(PYTHON_EXECUTABLE "python.exe") + else() + set(PYTHON_EXECUTABLE "python") + endif() +endif() + +if(UNIX) + system( + STRIP OUTPUT_VARIABLE PYTHON_EXECUTABLE + COMMAND bash -c "type -P '${PYTHON_EXECUTABLE}'" + ) +endif() + +system( + STRIP OUTPUT_VARIABLE PYTHON_VERSION + COMMAND "${PYTHON_EXECUTABLE}" -c "print('.'.join(map(str, __import__('sys').version_info[:3])))" +) + +message(STATUS "Use Python version: ${PYTHON_VERSION}") +message(STATUS "Use Python executable: \"${PYTHON_EXECUTABLE}\"") + +if(NOT DEFINED PYTHON_INCLUDE_DIR) + message(STATUS "Auto detecting Python include directory...") + system( + STRIP OUTPUT_VARIABLE PYTHON_INCLUDE_DIR + COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('platinclude'))" + ) +endif() + +if("${PYTHON_INCLUDE_DIR}" STREQUAL "") + message(FATAL_ERROR "Python 
include directory not found") +else() + message(STATUS "Detected Python include directory: \"${PYTHON_INCLUDE_DIR}\"") + include_directories("${PYTHON_INCLUDE_DIR}") +endif() + +system( + STRIP OUTPUT_VARIABLE PYTHON_SITE_PACKAGES + COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('purelib'))" +) +message(STATUS "Detected Python site packages: \"${PYTHON_SITE_PACKAGES}\"") + +find_package(CUDAToolkit REQUIRED) + +if(CUDAToolkit_FOUND AND NOT WIN32) + message(STATUS "Found CUDA Toolkit, potentially enabling CUDA support.") + enable_language(CUDA) + set(CMAKE_CUDA_STANDARD "${CMAKE_CXX_STANDARD}") + set(CMAKE_CUDA_STANDARD_REQUIRED ON) + add_definitions(-D__USE_CUDA__) + + string(APPEND CMAKE_CUDA_FLAGS " $ENV{PADDLE_NVCC_FLAGS}") + + # Execute Python code to get and process Paddle's supported CUDA architectures + system( + STRIP OUTPUT_VARIABLE CMAKE_CUDA_ARCHITECTURES + COMMAND "${PYTHON_EXECUTABLE}" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/get_nvcc_flags.py" + ) + + if(CMAKE_CUDA_ARCHITECTURES) + message(STATUS "Found Paddle CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") + else() + message(WARNING "CUDA is not available in the detected Paddle, or Paddle is not installed. 
Building for all available CUDA architectures.") + # CMake will default to building for all if CMAKE_CUDA_ARCHITECTURES is not set + endif() + + set(CUDA_ARCH_FLAGS "") # No need for cuda_select_nvcc_arch_flags + message(STATUS "CMAKE_CUDA_ARCHITECTURES: \"${CMAKE_CUDA_ARCHITECTURES}\"") + + list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda") + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "7.5") + if (USE_FP16) + message(STATUS "Found CUDA Toolkit with FP16 support, compiling with paddle.cuda.HalfTensor.") + string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1" + " -D__CUDA_NO_HALF_OPERATORS__" + " -D__CUDA_NO_HALF_CONVERSIONS__" + " -D__CUDA_NO_HALF2_OPERATORS__" + " -D__CUDA_NO_BFLOAT16_CONVERSIONS__") + else() + message(STATUS "Found CUDA Toolkit with FP16 support, but it is suppressed by the compile options, compiling without paddle.cuda.HalfTensor.") + endif() + else() + message(STATUS "Could not find CUDA Toolkit with FP16 support (version < 7.5), compiling without paddle.cuda.HalfTensor.") + endif() + + foreach(FLAG ${CUDA_NVCC_FLAGS}) + string(FIND "${FLAG}" " " flag_space_position) + if(NOT flag_space_position EQUAL -1) + message(FATAL_ERROR "Found spaces in CUDA_NVCC_FLAGS entry '${FLAG}'") + endif() + string(APPEND CMAKE_CUDA_FLAGS " ${FLAG}") + endforeach() + string(STRIP "${CMAKE_CUDA_FLAGS}" CMAKE_CUDA_FLAGS) + message(STATUS "CMAKE_CUDA_FLAGS: \"${CMAKE_CUDA_FLAGS}\"") + + if(MSVC) + set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} /O2 /Ob2") + else() + set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3") + endif() + + # Ensure the CUDA targets use the defined architectures + set(CUDA_SELECT_NVCC_ARCH_FLAGS_CALLED TRUE) # To avoid potential issues if other parts expect this +else() + if(NOT CUDAToolkit_FOUND) + message(STATUS "CUDA Toolkit not found, build for CPU-only.") + else() + message(STATUS "CUDA Toolkit found, but build for CPU-only on Windows.") + endif() +endif() + +set(PYBIND11_PYTHON_VERSION 
"${PYTHON_VERSION}") + +if(NOT DEFINED PYBIND11_CMAKE_DIR) + message(STATUS "Auto detecting pybind11 CMake directory...") + system( + STRIP OUTPUT_VARIABLE PYBIND11_CMAKE_DIR + COMMAND "${PYTHON_EXECUTABLE}" -m pybind11 --cmakedir + ) +endif() + +if("${PYBIND11_CMAKE_DIR}" STREQUAL "") + FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG "${PYBIND11_VERSION}" + GIT_SHALLOW TRUE + SOURCE_DIR "${THIRD_PARTY_DIR}/pybind11" + BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/build" + STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/stamp" + ) + FetchContent_GetProperties(pybind11) + + if(NOT pybind11_POPULATED) + message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third_party/pybind11...") + FetchContent_MakeAvailable(pybind11) + endif() +else() + message(STATUS "Detected Pybind11 CMake directory: \"${PYBIND11_CMAKE_DIR}\"") + find_package(pybind11 CONFIG PATHS "${PYBIND11_CMAKE_DIR}") +endif() + +if(NOT DEFINED PADDLE_INCLUDE_PATH) + message(STATUS "Auto detecting Paddle include directory...") + system( + STRIP OUTPUT_VARIABLE PADDLE_INCLUDE_PATH + COMMAND "${PYTHON_EXECUTABLE}" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/get_paddle_include_paths.py" + ) + + if("${PADDLE_INCLUDE_PATH}" STREQUAL "") + set(PADDLE_INCLUDE_PATH "${PYTHON_SITE_PACKAGES}/paddle/include") + endif() +endif() + +if("${PADDLE_INCLUDE_PATH}" STREQUAL "") + message(FATAL_ERROR "Paddle include directory not found. 
Got: \"${PADDLE_INCLUDE_PATH}\"") +else() + message(STATUS "Detected Paddle include directory: \"${PADDLE_INCLUDE_PATH}\"") + include_directories(${PADDLE_INCLUDE_PATH}) +endif() + +if(NOT DEFINED PADDLE_LIBRARY_PATH) + message(STATUS "Auto detecting Paddle library directory...") + system( + STRIP OUTPUT_VARIABLE PADDLE_LIBRARY_PATH + COMMAND "${PYTHON_EXECUTABLE}" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/get_paddle_library_paths.py" + ) + + if("${PADDLE_LIBRARY_PATH}" STREQUAL "") + set(PADDLE_LIBRARY_PATH "${PYTHON_SITE_PACKAGES}/paddle/lib") + endif() +endif() + +if("${PADDLE_LIBRARY_PATH}" STREQUAL "") + message(FATAL_ERROR "Paddle library directory not found. Got: \"${PADDLE_LIBRARY_PATH}\"") +else() + message(STATUS "Detected Paddle library directory: \"${PADDLE_LIBRARY_PATH}\"") +endif() + +set(PADDLE_LIBRARY "") + +foreach(VAR_PATH ${PADDLE_LIBRARY_PATH}) + file(GLOB ALL_FILES_AND_DIRS "${VAR_PATH}/*") + + list(FILTER ALL_FILES_AND_DIRS EXCLUDE REGEX "\\.py$") + list(FILTER ALL_FILES_AND_DIRS EXCLUDE REGEX "\\.h$") + list(FILTER ALL_FILES_AND_DIRS EXCLUDE REGEX "\\.cuh$") + list(FILTER ALL_FILES_AND_DIRS EXCLUDE REGEX "\\.pyi$") + list(FILTER ALL_FILES_AND_DIRS EXCLUDE REGEX "libflashattn\\.so$") + + set(FILTERED_FILES "") + foreach(ITEM ${ALL_FILES_AND_DIRS}) + if(NOT IS_DIRECTORY "${ITEM}") + list(APPEND FILTERED_FILES "${ITEM}") + endif() + endforeach() + + list(APPEND PADDLE_LIBRARY ${FILTERED_FILES}) +endforeach() + +message(STATUS "All Paddle libraries: \"${PADDLE_LIBRARY}\"") + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/csrc/include") +add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/csrc/src") diff --git a/source/ppfno_op/cmake/get_nvcc_flags.py b/source/ppfno_op/cmake/get_nvcc_flags.py new file mode 100755 index 0000000..3155b72 --- /dev/null +++ b/source/ppfno_op/cmake/get_nvcc_flags.py @@ -0,0 +1,23 @@ +# import torch +# +# if hasattr(torch.cuda, 'is_available') and torch.cuda.is_available(): +# arch_list = torch.cuda.get_arch_list() +# 
filtered_arch_list = [arch.replace('sm_', '') for arch in arch_list if not arch.startswith('compute_')] +# else: +# filtered_arch_list = ['75', '86', '90'] +# +# unique_arch_set = set(filtered_arch_list) +# +# sorted_unique_arch_list = sorted(list(unique_arch_set)) +# +# result_string = ';'.join(sorted_unique_arch_list) +# +# print(result_string) + + +def run(): + return "70;75;86;90" + + +if __name__ == "__main__": + print(run()) diff --git a/source/ppfno_op/cmake/get_paddle_include_paths.py b/source/ppfno_op/cmake/get_paddle_include_paths.py new file mode 100755 index 0000000..0870659 --- /dev/null +++ b/source/ppfno_op/cmake/get_paddle_include_paths.py @@ -0,0 +1,26 @@ +import sys +import os + + +def run(): + original_stdout = sys.stdout + original_stderr = sys.stderr + + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + + import paddle + + sys.stdout = original_stdout + sys.stderr = original_stderr + + paths = paddle.utils.cpp_extension.extension_utils.find_paddle_includes() + result = "" + for path in paths: + result += f";{path}" + + return result + + +if __name__ == "__main__": + print(run()) diff --git a/source/ppfno_op/cmake/get_paddle_library_paths.py b/source/ppfno_op/cmake/get_paddle_library_paths.py new file mode 100755 index 0000000..8e06427 --- /dev/null +++ b/source/ppfno_op/cmake/get_paddle_library_paths.py @@ -0,0 +1,26 @@ +import sys +import os + + +def run(): + original_stdout = sys.stdout + original_stderr = sys.stderr + + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + + import paddle + + sys.stdout = original_stdout + sys.stderr = original_stderr + + paths = paddle.utils.cpp_extension.extension_utils.find_paddle_libraries() + result = "" + for path in paths: + result += f";{path}" + + return result + + +if __name__ == "__main__": + print(run()) diff --git a/source/ppfno_op/csrc/include/fused_segment_csr/fused_segment_csr.h b/source/ppfno_op/csrc/include/fused_segment_csr/fused_segment_csr.h 
new file mode 100755 index 0000000..46b22a5 --- /dev/null +++ b/source/ppfno_op/csrc/include/fused_segment_csr/fused_segment_csr.h @@ -0,0 +1,21 @@ +#pragma once +#include <pybind11/pybind11.h> +#include <vector> + +namespace fused_segment_csr { +pybind11::object select_segment_csr_mean(pybind11::object src, + pybind11::object map, + pybind11::object indptr); + +pybind11::object select_segment_csr_sum(pybind11::object src, + pybind11::object map, + pybind11::object indptr); + +pybind11::object select_segment_csr_mean_bwd( + const std::vector<int64_t>& src_shape, pybind11::object grad_output, + pybind11::object map, pybind11::object indptr); + +pybind11::object select_segment_csr_sum_bwd( + const std::vector<int64_t>& src_shape, pybind11::object grad_output, + pybind11::object map, pybind11::object indptr); +} // namespace fused_segment_csr diff --git a/source/ppfno_op/csrc/include/fused_segment_csr/select.h b/source/ppfno_op/csrc/include/fused_segment_csr/select.h new file mode 100755 index 0000000..295898c --- /dev/null +++ b/source/ppfno_op/csrc/include/fused_segment_csr/select.h @@ -0,0 +1,40 @@ +#pragma once +#include <paddle/extension.h> + +namespace fused_segment_csr::impl { +void select_segment_csr_mean_1d(paddle::Tensor& out, const paddle::Tensor& src, + const paddle::Tensor& map, + const paddle::Tensor& indptr); + +void select_segment_csr_sum_1d(paddle::Tensor& out, const paddle::Tensor& src, + const paddle::Tensor& map, + const paddle::Tensor& indptr); + +void select_segment_csr_mean_bwd_1d(paddle::Tensor& grad_input, + const paddle::Tensor& grad_output, + const paddle::Tensor& map, + const paddle::Tensor& indptr); + +void select_segment_csr_sum_bwd_1d(paddle::Tensor& grad_input, + const paddle::Tensor& grad_output, + const paddle::Tensor& map, + const paddle::Tensor& indptr); + +void select_segment_csr_mean_2d(paddle::Tensor& out, const paddle::Tensor& src, + const paddle::Tensor& map, + const paddle::Tensor& indptr); + +void select_segment_csr_sum_2d(paddle::Tensor& out, const paddle::Tensor& 
map, + const paddle::Tensor& indptr); + +void select_segment_csr_mean_bwd_2d(paddle::Tensor& grad_input, + const paddle::Tensor& grad_output, + const paddle::Tensor& map, + const paddle::Tensor& indptr); + +void select_segment_csr_sum_bwd_2d(paddle::Tensor& grad_input, + const paddle::Tensor& grad_output, + const paddle::Tensor& map, + const paddle::Tensor& indptr); +} // namespace fused_segment_csr::impl diff --git a/source/ppfno_op/csrc/src/CMakeLists.txt b/source/ppfno_op/csrc/src/CMakeLists.txt new file mode 100755 index 0000000..f8169a9 --- /dev/null +++ b/source/ppfno_op/csrc/src/CMakeLists.txt @@ -0,0 +1,10 @@ +set( + fused_segment_csr_csrc + module.cpp + select.cpp + select.cu +) + +pybind11_add_module(_C MODULE "${fused_segment_csr_csrc}") + +target_link_libraries(_C PRIVATE ${PADDLE_LIBRARY} CUDA::cudart) diff --git a/source/ppfno_op/csrc/src/module.cpp b/source/ppfno_op/csrc/src/module.cpp new file mode 100755 index 0000000..668e9ca --- /dev/null +++ b/source/ppfno_op/csrc/src/module.cpp @@ -0,0 +1,62 @@ +/* +Copyright 2022-2025 TheCoreTeam. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+================================================================================ +*/ + +#include <fused_segment_csr/fused_segment_csr.h> +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +namespace fused_segment_csr { +void BuildModule(pybind11::module_& mod) { + // NOLINT[runtime/references] + + mod.doc() = ""; + mod.attr("Py_TPFLAGS_BASETYPE") = pybind11::int_(Py_TPFLAGS_BASETYPE); +#ifdef _GLIBCXX_USE_CXX11_ABI + // NOLINTNEXTLINE[modernize-use-bool-literals] + mod.attr("GLIBCXX_USE_CXX11_ABI") = + pybind11::bool_(static_cast<bool>(_GLIBCXX_USE_CXX11_ABI)); +#else + mod.attr("GLIBCXX_USE_CXX11_ABI") = pybind11::bool_(false); +#endif + + mod.def("select_segment_csr_mean", &select_segment_csr_mean, "", + pybind11::arg("src"), pybind11::arg("idx_map"), + pybind11::arg("indptr")); + + mod.def("select_segment_csr_sum", &select_segment_csr_sum, "", + pybind11::arg("src"), pybind11::arg("idx_map"), + pybind11::arg("indptr")); + + mod.def("select_segment_csr_mean_bwd", &select_segment_csr_mean_bwd, "", + pybind11::arg("src_shape"), pybind11::arg("grad_output"), + pybind11::arg("idx_map"), pybind11::arg("indptr")); + + mod.def("select_segment_csr_sum_bwd", &select_segment_csr_sum_bwd, "", + pybind11::arg("src_shape"), pybind11::arg("grad_output"), + pybind11::arg("idx_map"), pybind11::arg("indptr")); +} +} // namespace fused_segment_csr + +#if PYBIND11_VERSION_HEX >= 0x020D00F0 // pybind11 2.13.0 +// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] +PYBIND11_MODULE(_C, mod, pybind11::mod_gil_not_used()) { + fused_segment_csr::BuildModule(mod); +} +#else +// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] +PYBIND11_MODULE(_C, mod) { fused_segment_csr::BuildModule(mod); } +#endif diff --git a/source/ppfno_op/csrc/src/select.cpp b/source/ppfno_op/csrc/src/select.cpp new file mode 100755 index 0000000..71e70c1 --- /dev/null +++ b/source/ppfno_op/csrc/src/select.cpp @@ -0,0 +1,167 @@ +#include "fused_segment_csr/select.h" + +#include <paddle/extension.h> +#include <pybind11/pybind11.h> 
+#include <stdexcept> +#include <string> + +namespace fused_segment_csr { +pybind11::object select_segment_csr_sum(pybind11::object src_, + pybind11::object map_, + pybind11::object indptr_) { + const auto src = pybind11::cast<paddle::Tensor>(src_); + const auto map = pybind11::cast<paddle::Tensor>(map_); + const auto indptr = pybind11::cast<paddle::Tensor>(indptr_); + + PD_CHECK(map.shape().size() == 1); + PD_CHECK(indptr.shape().size() == 1); + + PD_CHECK(src.is_gpu()); + PD_CHECK(map.is_gpu()); + PD_CHECK(indptr.is_gpu()); + + PD_CHECK(src.is_contiguous()); + PD_CHECK(map.is_contiguous()); + PD_CHECK(indptr.is_contiguous()); + + paddle::Tensor out; + if (src.shape().size() == 1) { + out = paddle::empty({indptr.shape()[0] - 1}, src.dtype(), src.place()); + + impl::select_segment_csr_sum_1d(out, src, map, indptr); + } else if (src.shape().size() == 2) { + out = paddle::empty({indptr.shape()[0] - 1, src.shape()[1]}, src.dtype(), + src.place()); + + impl::select_segment_csr_sum_2d(out, src, map, indptr); + } else { + throw std::invalid_argument( + "select_segment_csr_sum only supports 1D or 2D tensors, but got " + + std::to_string(src.shape().size()) + "D tensor."); + } + + return pybind11::cast(out); +} + +pybind11::object select_segment_csr_mean(pybind11::object src_, + pybind11::object map_, + pybind11::object indptr_) { + const auto src = pybind11::cast<paddle::Tensor>(src_); + const auto map = pybind11::cast<paddle::Tensor>(map_); + const auto indptr = pybind11::cast<paddle::Tensor>(indptr_); + + PD_CHECK(map.shape().size() == 1); + PD_CHECK(indptr.shape().size() == 1); + + PD_CHECK(src.is_gpu()); + PD_CHECK(map.is_gpu()); + PD_CHECK(indptr.is_gpu()); + + PD_CHECK(src.is_contiguous()); + PD_CHECK(map.is_contiguous()); + PD_CHECK(indptr.is_contiguous()); + + paddle::Tensor out; + + if (src.shape().size() == 1) { + out = paddle::empty({indptr.shape()[0] - 1}, src.dtype(), src.place()); + + impl::select_segment_csr_mean_1d(out, src, map, indptr); + } else if (src.shape().size() == 2) { + out = paddle::empty({indptr.shape()[0] - 1, src.shape()[1]}, src.dtype(), + src.place()); + 
+ impl::select_segment_csr_mean_2d(out, src, map, indptr); + } else { + throw std::invalid_argument( + "select_segment_csr_mean only supports 1D or 2D tensors, but got " + + std::to_string(src.shape().size()) + "D tensor."); + } + + return pybind11::cast(out); +} + +pybind11::object select_segment_csr_mean_bwd( + const std::vector<int64_t>& src_shape, pybind11::object grad_output_, + pybind11::object map_, pybind11::object indptr_) { + const auto grad_output = pybind11::cast<paddle::Tensor>(grad_output_); + const auto map = pybind11::cast<paddle::Tensor>(map_); + const auto indptr = pybind11::cast<paddle::Tensor>(indptr_); + + PD_CHECK(map.shape().size() == 1); + PD_CHECK(indptr.shape().size() == 1); + + PD_CHECK(map.is_gpu()); + PD_CHECK(indptr.is_gpu()); + + // For grad_output + PD_CHECK(grad_output.shape()[0] == indptr.shape()[0] - 1); + PD_CHECK(grad_output.is_gpu()); + + PD_CHECK(map.is_contiguous()); + PD_CHECK(indptr.is_contiguous()); + PD_CHECK(grad_output.is_contiguous()); + + paddle::Tensor grad_input = + paddle::zeros(src_shape, grad_output.dtype(), grad_output.place()); + + if (src_shape.size() == 1) { + PD_CHECK(grad_output.shape().size() == 1); + + impl::select_segment_csr_mean_bwd_1d(grad_input, grad_output, map, indptr); + } else if (src_shape.size() == 2) { + PD_CHECK(grad_output.shape().size() == 2); + PD_CHECK(grad_output.shape()[1] == src_shape[1]); + + impl::select_segment_csr_mean_bwd_2d(grad_input, grad_output, map, indptr); + } else { + throw std::invalid_argument( + "select_segment_csr_mean_bwd only supports 1D or 2D tensors, but got " + + std::to_string(src_shape.size()) + "D tensor."); + } + + return pybind11::cast(grad_input); +} + +pybind11::object select_segment_csr_sum_bwd( + const std::vector<int64_t>& src_shape, pybind11::object grad_output_, + pybind11::object map_, pybind11::object indptr_) { + const auto grad_output = pybind11::cast<paddle::Tensor>(grad_output_); + const auto map = pybind11::cast<paddle::Tensor>(map_); + const auto indptr = pybind11::cast<paddle::Tensor>(indptr_); + + PD_CHECK(map.shape().size() == 1); + 
 PD_CHECK(indptr.shape().size() == 1); + + PD_CHECK(map.is_gpu()); + PD_CHECK(indptr.is_gpu()); + + // For grad_output + PD_CHECK(grad_output.shape()[0] == indptr.shape()[0] - 1); + PD_CHECK(grad_output.is_gpu()); + + PD_CHECK(map.is_contiguous()); + PD_CHECK(indptr.is_contiguous()); + PD_CHECK(grad_output.is_contiguous()); + + paddle::Tensor grad_input = + paddle::zeros(src_shape, grad_output.dtype(), grad_output.place()); + + if (src_shape.size() == 1) { + PD_CHECK(grad_output.shape().size() == 1); + + impl::select_segment_csr_sum_bwd_1d(grad_input, grad_output, map, indptr); + } else if (src_shape.size() == 2) { + PD_CHECK(grad_output.shape().size() == 2); + PD_CHECK(grad_output.shape()[1] == src_shape[1]); + + impl::select_segment_csr_sum_bwd_2d(grad_input, grad_output, map, indptr); + } else { + throw std::invalid_argument( + "select_segment_csr_sum_bwd only supports 1D or 2D tensors, but got " + + std::to_string(src_shape.size()) + "D tensor."); + } + + return pybind11::cast(grad_input); +} +} // namespace fused_segment_csr diff --git a/source/ppfno_op/csrc/src/select.cu b/source/ppfno_op/csrc/src/select.cu new file mode 100755 index 0000000..3246faa --- /dev/null +++ b/source/ppfno_op/csrc/src/select.cu @@ -0,0 +1,340 @@ +#include <cstdint> +#include <cuda_runtime.h> +#include <paddle/extension.h> + +static constexpr int64_t BLOCK_SIZE = 128; + +template <typename T1, typename T2> +auto ceil_div(const T1 a, const T2 b) { + return (a + b - 1) / b; +} + +namespace fused_segment_csr::impl { +namespace { +template <typename T> +__global__ __launch_bounds__(BLOCK_SIZE) void select_segment_csr_sum_1d_kernel( + T* __restrict__ out, const size_t num_rows, const T* __restrict__ src, + const int64_t* __restrict__ map, const int64_t* __restrict__ indptr) { + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto row = tid; + if (row >= num_rows) { + return; + } + const auto row_start = indptr[row]; + const auto row_end = indptr[row + 1]; + double sum = 0; + for (auto i = row_start; i < row_end; ++i) { + const auto selected_row = map[i]; + sum 
+= src[selected_row]; + } + out[row] = sum; +} + +template <typename T> +__global__ __launch_bounds__(BLOCK_SIZE) void select_segment_csr_mean_1d_kernel( + T* __restrict__ out, const size_t num_rows, const T* __restrict__ src, + const int64_t* __restrict__ map, const int64_t* __restrict__ indptr) { + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto row = tid; + if (row >= num_rows) { + return; + } + const auto row_start = indptr[row]; + const auto row_end = indptr[row + 1]; + double sum = 0; + for (auto i = row_start; i < row_end; ++i) { + const auto selected_row = map[i]; + sum += src[selected_row]; + } + const auto row_length = row_end - row_start; + if (row_length == 0) { + out[row] = 0; + } else { + out[row] = sum / row_length; + } +} +} // namespace + +void select_segment_csr_sum_1d(paddle::Tensor& out, const paddle::Tensor& src, + const paddle::Tensor& map, + const paddle::Tensor& indptr) { + const auto num_rows = out.shape()[0]; + + const auto total_count = num_rows; + const dim3 block(min(BLOCK_SIZE, total_count)); + const dim3 grid(ceil_div(total_count, block.x)); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "select_segment_csr_sum", [&] { + select_segment_csr_sum_1d_kernel<<<grid, block>>>( + out.data<data_t>(), num_rows, src.data<data_t>(), map.data<int64_t>(), + indptr.data<int64_t>()); + }); +} + +void select_segment_csr_mean_1d(paddle::Tensor& out, const paddle::Tensor& src, + const paddle::Tensor& map, + const paddle::Tensor& indptr) { + const auto num_rows = out.shape()[0]; + + const auto total_count = num_rows; + const dim3 block(min(BLOCK_SIZE, total_count)); + const dim3 grid(ceil_div(total_count, block.x)); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "select_segment_csr_mean", [&] { + select_segment_csr_mean_1d_kernel<<<grid, block>>>( + out.data<data_t>(), num_rows, src.data<data_t>(), map.data<int64_t>(), + indptr.data<int64_t>()); + }); +} + +namespace { +template <typename T> +__global__ +__launch_bounds__(BLOCK_SIZE) void select_segment_csr_mean_bwd_1d_kernel( + T* __restrict__ grad_input, const size_t num_rows, + const T* __restrict__ 
grad_output, const int64_t* __restrict__ map, + const int64_t* __restrict__ indptr) { + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto row = tid; + if (row >= num_rows) { + return; + } + const auto row_start = indptr[row]; + const auto row_end = indptr[row + 1]; + const auto scale = row_end - row_start; + for (auto i = row_start; i < row_end; ++i) { + const auto selected_row = map[i]; + const auto grad = grad_output[row]; + atomicAdd(&grad_input[selected_row], grad / scale); + } +} + +template <typename T> +__global__ +__launch_bounds__(BLOCK_SIZE) void select_segment_csr_sum_bwd_1d_kernel( + T* __restrict__ grad_input, const size_t num_rows, + const T* __restrict__ grad_output, const int64_t* __restrict__ map, + const int64_t* __restrict__ indptr) { + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto row = tid; + if (row >= num_rows) { + return; + } + const auto row_start = indptr[row]; + const auto row_end = indptr[row + 1]; + for (auto i = row_start; i < row_end; ++i) { + const auto selected_row = map[i]; + const auto grad = grad_output[row]; + atomicAdd(&grad_input[selected_row], grad); + } +} +} // namespace + +void select_segment_csr_mean_bwd_1d(paddle::Tensor& grad_input, + const paddle::Tensor& grad_output, + const paddle::Tensor& map, + const paddle::Tensor& indptr) { + const auto num_rows = grad_output.shape()[0]; + + const auto total_count = num_rows; + const dim3 block(min(BLOCK_SIZE, total_count)); + const dim3 grid(ceil_div(total_count, block.x)); + + PD_DISPATCH_FLOATING_TYPES( + grad_output.type(), "select_segment_csr_mean_bwd", [&] { + select_segment_csr_mean_bwd_1d_kernel<<<grid, block>>>( + grad_input.data<data_t>(), num_rows, grad_output.data<data_t>(), + map.data<int64_t>(), indptr.data<int64_t>()); + }); +} + +void select_segment_csr_sum_bwd_1d(paddle::Tensor& grad_input, + const paddle::Tensor& grad_output, + const paddle::Tensor& map, + const paddle::Tensor& indptr) { + const auto num_rows = grad_output.shape()[0]; + + const auto total_count = num_rows; + 
const dim3 block(min(BLOCK_SIZE, total_count)); + const dim3 grid(ceil_div(total_count, block.x)); + + PD_DISPATCH_FLOATING_TYPES( + grad_output.type(), "select_segment_csr_sum_bwd", [&] { + select_segment_csr_sum_bwd_1d_kernel<<>>( + grad_input.data(), num_rows, grad_output.data(), + map.data(), indptr.data()); + }); +} + +namespace { +template +__global__ __launch_bounds__(BLOCK_SIZE) void select_segment_csr_sum_2d_kernel( + T* __restrict__ out, const size_t num_rows, const size_t num_cols, + const T* __restrict__ src, const int64_t* __restrict__ map, + const int64_t* __restrict__ indptr) { + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto row = tid / num_cols; + const auto col = tid % num_cols; + if (row >= num_rows || col >= num_cols) { + return; + } + const auto row_start = indptr[row]; + const auto row_end = indptr[row + 1]; + double sum = 0; + for (auto i = row_start; i < row_end; ++i) { + const auto selected_row = map[i]; + sum += src[selected_row * num_cols + col]; + } + out[row * num_cols + col] = sum; +} + +template +__global__ __launch_bounds__(BLOCK_SIZE) void select_segment_csr_mean_2d_kernel( + T* __restrict__ out, const size_t num_rows, const size_t num_cols, + const T* __restrict__ src, const int64_t* __restrict__ map, + const int64_t* __restrict__ indptr) { + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto row = tid / num_cols; + const auto col = tid % num_cols; + if (row >= num_rows || col >= num_cols) { + return; + } + const auto row_start = indptr[row]; + const auto row_end = indptr[row + 1]; + double sum = 0; + for (auto i = row_start; i < row_end; ++i) { + const auto selected_row = map[i]; + sum += src[selected_row * num_cols + col]; + } + const auto row_length = row_end - row_start; + if (row_length == 0) { + out[row * num_cols + col] = 0; + } else { + out[row * num_cols + col] = sum / row_length; + } +} +} // namespace + +void select_segment_csr_sum_2d(paddle::Tensor& out, const paddle::Tensor& src, 
+ const paddle::Tensor& map, + const paddle::Tensor& indptr) { + const auto num_rows = out.shape()[0]; + const auto num_cols = out.shape()[1]; + + const auto total_count = num_rows * num_cols; + const dim3 block(min(BLOCK_SIZE, total_count)); + const dim3 grid(ceil_div(total_count, block.x)); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "select_segment_csr_sum", [&] { + select_segment_csr_sum_2d_kernel<<>>( + out.data(), num_rows, num_cols, src.data(), + map.data(), indptr.data()); + }); +} + +void select_segment_csr_mean_2d(paddle::Tensor& out, const paddle::Tensor& src, + const paddle::Tensor& map, + const paddle::Tensor& indptr) { + const auto num_rows = out.shape()[0]; + const auto num_cols = out.shape()[1]; + + const auto total_count = num_rows * num_cols; + const dim3 block(min(BLOCK_SIZE, total_count)); + const dim3 grid(ceil_div(total_count, block.x)); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "select_segment_csr_mean", [&] { + select_segment_csr_mean_2d_kernel<<>>( + out.data(), num_rows, num_cols, src.data(), + map.data(), indptr.data()); + }); +} + +namespace { +template +__global__ +__launch_bounds__(BLOCK_SIZE) void select_segment_csr_mean_bwd_2d_kernel( + T* __restrict__ grad_input, const size_t num_rows, const size_t num_cols, + const T* __restrict__ grad_output, const int64_t* __restrict__ map, + const int64_t* __restrict__ indptr) { + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto row = tid / num_cols; + const auto col = tid % num_cols; + if (row >= num_rows || col >= num_cols) { + return; + } + const auto row_start = indptr[row]; + const auto row_end = indptr[row + 1]; + const auto scale = row_end - row_start; + for (auto i = row_start; i < row_end; ++i) { + const auto selected_row = map[i]; + const auto grad = grad_output[row * num_cols + col]; + atomicAdd(&grad_input[selected_row * num_cols + col], grad / scale); + } +} + +template +__global__ +__launch_bounds__(BLOCK_SIZE) void select_segment_csr_sum_bwd_2d_kernel( + T* 
__restrict__ grad_input, const size_t num_rows, const size_t num_cols, + const T* __restrict__ grad_output, const int64_t* __restrict__ map, + const int64_t* __restrict__ indptr) { + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto row = tid / num_cols; + const auto col = tid % num_cols; + if (row >= num_rows || col >= num_cols) { + return; + } + const auto row_start = indptr[row]; + const auto row_end = indptr[row + 1]; + for (auto i = row_start; i < row_end; ++i) { + const auto selected_row = map[i]; + const auto grad = grad_output[row * num_cols + col]; + atomicAdd(&grad_input[selected_row * num_cols + col], grad); + } +} +} // namespace + +void select_segment_csr_mean_bwd_2d(paddle::Tensor& grad_input, + const paddle::Tensor& grad_output, + const paddle::Tensor& map, + const paddle::Tensor& indptr) { + const auto num_rows = grad_output.shape()[0]; + const auto num_cols = grad_output.shape()[1]; + + const auto total_count = num_rows * num_cols; + const dim3 block(min(BLOCK_SIZE, total_count)); + const dim3 grid(ceil_div(total_count, block.x)); + + PD_DISPATCH_FLOATING_TYPES( + grad_output.type(), "select_segment_csr_mean_bwd", [&] { + select_segment_csr_mean_bwd_2d_kernel<<>>( + grad_input.data(), num_rows, num_cols, + grad_output.data(), map.data(), + indptr.data()); + }); +} + +void select_segment_csr_sum_bwd_2d(paddle::Tensor& grad_input, + const paddle::Tensor& grad_output, + const paddle::Tensor& map, + const paddle::Tensor& indptr) { + const auto num_rows = grad_output.shape()[0]; + const auto num_cols = grad_output.shape()[1]; + + const auto total_count = num_rows * num_cols; + const dim3 block(min(BLOCK_SIZE, total_count)); + const dim3 grid(ceil_div(total_count, block.x)); + + PD_DISPATCH_FLOATING_TYPES( + grad_output.type(), "select_segment_csr_sum_bwd", [&] { + select_segment_csr_sum_bwd_2d_kernel<<>>( + grad_input.data(), num_rows, num_cols, + grad_output.data(), map.data(), + indptr.data()); + }); +} +} // namespace 
fused_segment_csr::impl diff --git a/source/ppfno_op/fused_segment_csr/_C/__init__.pyi b/source/ppfno_op/fused_segment_csr/_C/__init__.pyi new file mode 100755 index 0000000..2dda60b --- /dev/null +++ b/source/ppfno_op/fused_segment_csr/_C/__init__.pyi @@ -0,0 +1,11 @@ +def select_segment_csr_mean(src, idx_map, indptr): + pass + +def select_segment_csr_sum(src, idx_map, indptr): + pass + +def select_segment_csr_mean_bwd(src_shape, grad_output, idx_map, indptr): + pass + +def select_segment_csr_sum_bwd(src_shape, grad_output, idx_map, indptr): + pass diff --git a/source/ppfno_op/fused_segment_csr/__init__.py b/source/ppfno_op/fused_segment_csr/__init__.py new file mode 100755 index 0000000..d3f0280 --- /dev/null +++ b/source/ppfno_op/fused_segment_csr/__init__.py @@ -0,0 +1,4 @@ +from fused_segment_csr.select_segment_csr import select_segment_csr +from fused_segment_csr.version import __version__ + +__all__ = ["select_segment_csr", "__version__"] diff --git a/source/ppfno_op/fused_segment_csr/select_segment_csr.py b/source/ppfno_op/fused_segment_csr/select_segment_csr.py new file mode 100755 index 0000000..72633bb --- /dev/null +++ b/source/ppfno_op/fused_segment_csr/select_segment_csr.py @@ -0,0 +1,55 @@ +import paddle +import fused_segment_csr._C as _C + + +class SelectSegmentCsrMean(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, *args, **kwargs): + assert len(args) == 3 + assert len(kwargs) == 0 + src, idx_map, indptr = args + ctx.src_shape = src.shape + out= _C.select_segment_csr_mean(src, idx_map, indptr) + ctx.save_for_backward(idx_map, indptr) + return out + + @staticmethod + def backward(ctx, *args): + assert len(args) == 1 + grad_output = args[0] + idx_map, indptr = ctx.saved_tensor() + src_shape = ctx.src_shape + return _C.select_segment_csr_mean_bwd( + src_shape, grad_output, idx_map, indptr + ) + + +class SelectSegmentCsrSum(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, *args, **kwargs): + assert len(args) == 3 + assert 
len(kwargs) == 0 + src, idx_map, indptr = args + ctx.src_shape = src.shape + ctx.save_for_backward(idx_map, indptr) + return _C.select_segment_csr_sum(src, idx_map, indptr) + + @staticmethod + def backward(ctx, *args): + assert len(args) == 1 + grad_output = args[0] + idx_map, indptr = ctx.saved_tensor() + src_shape = ctx.src_shape + return _C.select_segment_csr_sum_bwd(src_shape, grad_output, idx_map, indptr) + + +def select_segment_csr(src, idx_map, indptr, reduce="sum"): + if reduce == "sum": + return SelectSegmentCsrSum.apply(src, idx_map, indptr) + elif reduce == "mean": + return SelectSegmentCsrMean.apply(src, idx_map, indptr) + else: + raise ValueError(f"Unsupported reduce: {reduce}. Use 'sum' or 'mean'.") + + +__all__ = ["select_segment_csr"] diff --git a/source/ppfno_op/fused_segment_csr/version.py b/source/ppfno_op/fused_segment_csr/version.py new file mode 100755 index 0000000..1754945 --- /dev/null +++ b/source/ppfno_op/fused_segment_csr/version.py @@ -0,0 +1,19 @@ +# Copyright 2022-2025 TheCoreTeam. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +__version__ = "0.0.0" +__license__ = "Apache License, Version 2.0" +__author__ = "TheCoreTeam" +__release__ = False diff --git a/source/ppfno_op/setup.py b/source/ppfno_op/setup.py new file mode 100755 index 0000000..aa191e9 --- /dev/null +++ b/source/ppfno_op/setup.py @@ -0,0 +1,184 @@ +# Copyright 2022-2025 TheCoreTeam. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import contextlib +import os +import pathlib +import platform +import re +import shutil +import sys +import sysconfig +from importlib.util import module_from_spec, spec_from_file_location + +from setuptools import Extension, setup, find_packages +from setuptools.command.build_ext import build_ext + +HERE = pathlib.Path(__file__).absolute().parent + + +class CMakeExtension(Extension): + def __init__(self, name, source_dir=".", target=None, **kwargs): + super().__init__(name, sources=[], **kwargs) + self.source_dir = os.path.abspath(source_dir) + self.target = target if target is not None else name.rpartition(".")[-1] + + +class cmake_build_ext(build_ext): + def build_extension(self, ext): + if not isinstance(ext, CMakeExtension): + super().build_extension(ext) + return + + from cmake import get_paddle_include_paths, get_paddle_library_paths + + cmake = shutil.which("cmake") + if cmake is None: + raise 
RuntimeError("Cannot find CMake executable.") + + ext_path = pathlib.Path(self.get_ext_fullpath(ext.name)).absolute() + build_temp = pathlib.Path(self.build_temp).absolute() + build_temp.mkdir(parents=True, exist_ok=True) + + config = "Debug" if self.debug else "Release" + cmake_args = [ + f"-DCMAKE_BUILD_TYPE={config}", + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}", + f"-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={build_temp}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + f'-DPYTHON_INCLUDE_DIR={sysconfig.get_path("platinclude")}', + f"-DPADDLE_INCLUDE_PATH={get_paddle_include_paths.run()}", + f"-DPADDLE_LIBRARY_PATH={get_paddle_library_paths.run()}", + ] + + if platform.system() == "Darwin": + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) + if archs: + cmake_args.append(f'-DCMAKE_OSX_ARCHITECTURES={";".join(archs)}') + + try: + import pybind11 + + cmake_args.append(f"-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}") + except ImportError: + pass + + build_args = ["--config", config] + if ( + "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ + and hasattr(self, "parallel") + and self.parallel + ): + build_args.extend(["--parallel", str(self.parallel)]) + else: + build_args.append("--parallel") + + build_args.extend(["--target", ext.target, "--"]) + + cwd = os.getcwd() + try: + os.chdir(build_temp) + + self.announce("Configuring with CMake...", level=3) + try: + self.spawn([cmake, ext.source_dir, *cmake_args]) + except Exception as e: + self.announce("CMake configuring failed: {}".format(str(e)), level=3) + raise RuntimeError(f"CMake configuring failed: {str(e)}") + + if not self.dry_run: + self.announce("Building with cmake...", level=3) + try: + self.spawn([cmake, "--build", ".", *build_args]) + except Exception as e: + self.announce("CMake building failed: {}".format(str(e)), level=3) + raise RuntimeError(f"CMake building failed: {str(e)}") + 
finally: + os.chdir(cwd) + + +@contextlib.contextmanager +def vcs_version(name, path): + path = pathlib.Path(path).absolute() + assert path.is_file() + module_spec = spec_from_file_location(name=name, location=path) + assert module_spec is not None + assert module_spec.loader is not None + module = sys.modules.get(name) + if module is None: + module = module_from_spec(module_spec) + sys.modules[name] = module + module_spec.loader.exec_module(module) + + if module.__release__: + yield module + return + + content = None + try: + try: + content = path.read_text(encoding="utf-8") + path.write_text( + data=re.sub( + r"""__version__\s*=\s*('[^']+'|"[^"]+")""", + f"__version__ = {module.__version__!r}", + string=content, + ), + encoding="utf-8", + ) + except OSError: + content = None + + yield module + finally: + if content is not None: + with path.open(mode="wt", encoding="utf-8", newline="") as file: + file.write(content) + + +CIBUILDWHEEL = os.getenv("CIBUILDWHEEL", "0") == "1" +LINUX = platform.system() == "Linux" +MACOS = platform.system() == "Darwin" +WINDOWS = platform.system() == "Windows" +ext_kwargs = { + "cmdclass": {"build_ext": cmake_build_ext}, + "ext_modules": [ + CMakeExtension( + "fused_segment_csr._C", + source_dir=HERE, + optional=not (LINUX and CIBUILDWHEEL), + ), + ], +} + +FUSED_SEGMENT_CSR_NO_EXTENSIONS = ( + bool(os.getenv("FUSED_SEGMENT_CSR_NO_EXTENSIONS", "")) or WINDOWS or MACOS +) +if FUSED_SEGMENT_CSR_NO_EXTENSIONS: + ext_kwargs.clear() + +with vcs_version( + name="fused_segment_csr.version", path=(HERE / "fused_segment_csr" / "version.py") +) as version: + setup( + name="fused_segment_csr", + version=version.__version__, + packages=find_packages(), + package_data={ + "fused_segment_csr": ["*.pyi", "**/*.pyi"], + }, + **ext_kwargs, + ) diff --git a/source/ppfno_op/tests/test_fused_segment_csr_select.py b/source/ppfno_op/tests/test_fused_segment_csr_select.py new file mode 100755 index 0000000..a0427e8 --- /dev/null +++ 
b/source/ppfno_op/tests/test_fused_segment_csr_select.py @@ -0,0 +1,143 @@ +import paddle +from typing import Literal +import fused_segment_csr + + +def segment_csr( + src: paddle.Tensor, indptr: paddle.Tensor, reduce: Literal["mean", "sum"] +): + """segment_csr reduces all entries of a CSR-formatted + matrix by summing or averaging over neighbors. + + Used to reduce features over neighborhoods + in neuralop.layers.IntegralTransform + + Parameters + ---------- + src : torch.Tensor + tensor of features for each point + indptr : torch.Tensor + splits representing start and end indices + of each neighborhood in src + reduce : Literal['mean', 'sum'] + how to reduce a neighborhood. if mean, + reduce by taking the average of all neighbors. + Otherwise take the sum. + """ + if reduce not in ["mean", "sum"]: + raise ValueError("reduce must be one of 'mean', 'sum'") + + n_nbrs = indptr[1:] - indptr[:-1] + output_shape = list(tuple(src.shape)) + output_shape[0] = tuple(indptr.shape)[0] - 1 + out = paddle.zeros(shape=output_shape) + for i, start in enumerate(indptr[:-1]): + if start == tuple(src.shape)[0]: + break + for j in range(n_nbrs[i]): + out[i] += src[start + j] + + if reduce == "mean": + out_result = paddle.empty_like(out) + for i, start in enumerate(indptr[:-1]): + if start == tuple(src.shape)[0]: + break + if n_nbrs[i] != 0: + out_result[i] = out[i] / n_nbrs[i] + return out_result + return out + + +if __name__ == "__main__": + sample_num = 10 + col_num = 20 + selected_num = 30 + + def test_mean_1d(): + src = paddle.rand([sample_num], dtype="float32") + src.stop_gradient = False + idx_map = paddle.randint(0, sample_num, [selected_num], dtype="int64") + selected_sample = src[idx_map] + csr_index = paddle.to_tensor([0, 3, 4, 7, 9, 11, 15, 20, 20, 21], dtype="int64") + + out_ref = segment_csr(selected_sample, csr_index, "mean") + + src_clone = src.clone().detach_() + src_clone.stop_gradient = False + out = fused_segment_csr.select_segment_csr( + src_clone, idx_map, 
csr_index, "mean" + ) + assert out.shape == out_ref.shape + assert paddle.allclose(out, out_ref) + + grad = paddle.rand(out.shape, out.dtype, out.place) + out.backward(grad) + out_ref.backward(grad) + assert paddle.allclose(src_clone.grad, src.grad) + + def test_sum_1d(): + src = paddle.rand([sample_num], dtype="float32") + src.stop_gradient = False + idx_map = paddle.randint(0, sample_num, [selected_num], dtype="int64") + selected_sample = src[idx_map] + csr_index = paddle.to_tensor([0, 3, 4, 7, 9, 11, 15, 20, 20, 21], dtype="int64") + + out_ref = segment_csr(selected_sample, csr_index, "sum") + + src_clone = src.clone().detach_() + src_clone.stop_gradient = False + out = fused_segment_csr.select_segment_csr(src_clone, idx_map, csr_index, "sum") + assert out.shape == out_ref.shape + assert paddle.allclose(out, out_ref) + + grad = paddle.rand(out.shape, out.dtype, out.place) + out.backward(grad) + out_ref.backward(grad) + assert paddle.allclose(src_clone.grad, src.grad) + + def test_mean_2d(): + src = paddle.rand([sample_num, col_num], dtype="float32") + src.stop_gradient = False + idx_map = paddle.randint(0, sample_num, [selected_num], dtype="int64") + selected_sample = src[idx_map] + csr_index = paddle.to_tensor([0, 3, 4, 7, 9, 11, 15, 20, 20, 21], dtype="int64") + + out_ref = segment_csr(selected_sample, csr_index, "mean") + + src_clone = src.clone().detach_() + src_clone.stop_gradient = False + out = fused_segment_csr.select_segment_csr( + src_clone, idx_map, csr_index, "mean" + ) + assert out.shape == out_ref.shape + assert paddle.allclose(out, out_ref) + + grad = paddle.rand(out.shape, out.dtype, out.place) + out.backward(grad) + out_ref.backward(grad) + assert paddle.allclose(src_clone.grad, src.grad) + + def test_sum_2d(): + src = paddle.rand([sample_num, col_num], dtype="float32") + src.stop_gradient = False + idx_map = paddle.randint(0, sample_num, [selected_num], dtype="int64") + selected_sample = src[idx_map] + csr_index = paddle.to_tensor([0, 3, 4, 7, 
9, 11, 15, 20, 20, 21], dtype="int64") + + out_ref = segment_csr(selected_sample, csr_index, "sum") + + src_clone = src.clone().detach_() + src_clone.stop_gradient = False + out = fused_segment_csr.select_segment_csr(src_clone, idx_map, csr_index, "sum") + assert out.shape == out_ref.shape + assert paddle.allclose(out, out_ref) + + grad = paddle.rand(out.shape, out.dtype, out.place) + out.backward(grad) + out_ref.backward(grad) + assert paddle.allclose(src_clone.grad, src.grad) + + test_mean_1d() + test_sum_1d() + test_mean_2d() + test_sum_2d() From 6250e4d881da6d4d4bb28e18bd2ac14656dc21af Mon Sep 17 00:00:00 2001 From: KaiCHEN-HT Date: Thu, 7 Aug 2025 08:20:49 +0000 Subject: [PATCH 2/2] update readme about custom op compiling --- README.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 88ead11..25d6f88 100755 --- a/README.md +++ b/README.md @@ -47,17 +47,16 @@ python -m pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn wget https://paddle-org.bj.bcebos.com/paddlecfd/envs/open3d-0.18.0+da239b25-cp310-cp310-manylinux_2_31_x86_64.whl python -m pip install open3d-0.18.0+da239b25-cp310-cp310-manylinux_2_31_x86_64.whl -i https://pypi.tuna.tsinghua.edu.cn/simple -# Unzip compiled customed operator (fused_segment_csr) to conda env directory -wget https://paddle-org.bj.bcebos.com/paddlecfd/envs/fused_segment_csr.tar.gz -tar -xzvf fused_segment_csr.tar.gz -C /root/miniconda3/envs/ppcfd/ - -# Add environment variable -export LD_LIBRARY_PATH=/root/miniconda3/envs/ppcfd/lib/python3.10/site-packages/paddle/libs:$LD_LIBRARY_PATH -export LD_LIBRARY_PATH=/root/miniconda3/envs/ppcfd/lib/python3.10/site-packages/paddle/base:$LD_LIBRARY_PATH -export LD_LIBRARY_PATH=/root/miniconda3/envs/ppcfd/lib:$LD_LIBRARY_PATH +# Compile customed operator to conda environment +wget -nc https://paddle-org.bj.bcebos.com/paddlescience/cmake-3.23.0-linux-x86_64.tar.gz +tar -zxvf cmake-3.23.0-linux-x86_64.tar.gz +rm 
-f cmake-3.23.0-linux-x86_64.tar.gz +PATH=$PWD/cmake-3.23.0-linux-x86_64/bin:$PATH +cd source/ppfno_op +python -m pip install --no-build-isolation -v . ``` -##### PaddleCFD package installation +##### PaddleCFD package installation (Choose one of the following) ```bash # Install PaddleCFD from sourcecode at PaddleCFD root directory python -m pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple