From ecb938c17761cceabf7c38288d1af19c85933923 Mon Sep 17 00:00:00 2001 From: zhuyue Date: Wed, 29 Oct 2025 15:17:32 +0800 Subject: [PATCH] Add causalSoftmax operator Python interface and tests. --- include/infinicore/ops.hpp | 1 + include/infinicore/ops/causal_softmax.hpp | 16 +++ python/infinicore/__init__.py | 2 + python/infinicore/ops/causal_softmax.py | 9 ++ .../ops/causal_softmax/causal_softmax.cc | 32 +++++ .../causal_softmax/causal_softmax_infiniop.cc | 52 ++++++++ src/infinicore/pybind11/ops.hpp | 2 + .../pybind11/ops/causal_softmax.hpp | 24 ++++ test/infinicore/ops/causal_softmax.py | 114 ++++++++++++++++++ xmake/test.lua | 2 +- 10 files changed, 253 insertions(+), 1 deletion(-) create mode 100644 include/infinicore/ops/causal_softmax.hpp create mode 100644 python/infinicore/ops/causal_softmax.py create mode 100644 src/infinicore/ops/causal_softmax/causal_softmax.cc create mode 100644 src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/causal_softmax.hpp create mode 100644 test/infinicore/ops/causal_softmax.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 23d2457a8..b83262c21 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -2,6 +2,7 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/causal_softmax.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/rearrange.hpp" diff --git a/include/infinicore/ops/causal_softmax.hpp b/include/infinicore/ops/causal_softmax.hpp new file mode 100644 index 000000000..ae40d521c --- /dev/null +++ b/include/infinicore/ops/causal_softmax.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class CausalSoftmax { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor output, Tensor input); + static common::OpDispatcher &dispatcher(); +}; + +Tensor causal_softmax(Tensor input); +void causal_softmax_(Tensor output, Tensor input); +} // namespace infinicore::op diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 4757d7f29..412362617 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -27,6 +27,7 @@ from infinicore.ntops import use_ntops from infinicore.ops.add import add from infinicore.ops.attention import attention +from infinicore.ops.causal_softmax import causal_softmax from infinicore.ops.matmul import matmul from infinicore.ops.rearrange import rearrange from infinicore.ops.rms_norm import rms_norm @@ -71,6 +72,7 @@ # Operations. "add", "attention", + "causal_softmax", "matmul", "rearrange", "rms_norm", diff --git a/python/infinicore/ops/causal_softmax.py b/python/infinicore/ops/causal_softmax.py new file mode 100644 index 000000000..0be90d357 --- /dev/null +++ b/python/infinicore/ops/causal_softmax.py @@ -0,0 +1,9 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def causal_softmax(input, *, out=None): + if out is None: + return Tensor(_infinicore.causal_softmax(input._underlying)) + + _infinicore.causal_softmax_(out._underlying, input._underlying) diff --git a/src/infinicore/ops/causal_softmax/causal_softmax.cc b/src/infinicore/ops/causal_softmax/causal_softmax.cc new file mode 100644 index 000000000..2a8d666c0 --- /dev/null +++ b/src/infinicore/ops/causal_softmax/causal_softmax.cc @@ -0,0 +1,32 @@ +#include "infinicore/ops/causal_softmax.hpp" +#include + +namespace infinicore::op { + +common::OpDispatcher &CausalSoftmax::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void CausalSoftmax::execute(Tensor output, Tensor input) { + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No CausalSoftmax implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output, input); +} + +Tensor causal_softmax(Tensor input) { + Shape shape = input->shape(); + auto output = Tensor::empty(shape, input->dtype(), input->device()); + causal_softmax_(output, input); + return output; +} + +void causal_softmax_(Tensor output, Tensor input) { + CausalSoftmax::execute(output, input); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc b/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc new file mode 100644 index 000000000..33d4ed287 --- /dev/null +++ b/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc @@ -0,0 +1,52 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/causal_softmax.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::causal_softmax_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopCausalSoftmaxDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyCausalSoftmaxDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor output, Tensor input) { + size_t seed = hash_combine(output, input); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopCausalSoftmaxDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor( + context::getInfiniopHandle(), &desc, + output->desc(), input->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopCausalSoftmax( + desc, workspace->data(), workspace_size, + output->data(), input->data(), context::getStream())); +} + +static bool registered = []() { + CausalSoftmax::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::causal_softmax_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 3cfef16ac..c7d776cca 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -4,6 +4,7 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/causal_softmax.hpp" #include "ops/matmul.hpp" #include "ops/rearrange.hpp" #include "ops/rms_norm.hpp" @@ -15,6 +16,7 @@ namespace infinicore::ops { inline void bind(py::module &m) { bind_add(m); bind_attention(m); + bind_causal_softmax(m); bind_matmul(m); bind_rearrange(m); bind_rms_norm(m); diff --git a/src/infinicore/pybind11/ops/causal_softmax.hpp b/src/infinicore/pybind11/ops/causal_softmax.hpp new file mode 100644 index 000000000..926a96d90 --- /dev/null +++ b/src/infinicore/pybind11/ops/causal_softmax.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/causal_softmax.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_causal_softmax(py::module &m) { + m.def("causal_softmax", + &op::causal_softmax, + py::arg("input"), + R"doc(Causal softmax activation function.)doc"); + + m.def("causal_softmax_", + &op::causal_softmax_, + py::arg("output"), + py::arg("input"), + R"doc(In-place causal softmax activation function.)doc"); +} + +} // namespace infinicore::ops diff --git a/test/infinicore/ops/causal_softmax.py b/test/infinicore/ops/causal_softmax.py new file mode 100644 index 000000000..e8817dd13 --- /dev/null +++ b/test/infinicore/ops/causal_softmax.py @@ -0,0 +1,114 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework.base import BaseOperatorTest, TensorSpec, TestCase +from framework.runner import GenericTestRunner + +# ============================================================================== +# Operator-specific configuration +# ============================================================================== + +# Test cases format: (operation_mode, shape, input_strides, output_strides) +# Causal softmax is a single-input function that applies causal masking before softmax +_TEST_CASES_DATA = [ + # Basic 2D causal softmax + (TestCase.BOTH, (3, 3), None, None), + (TestCase.BOTH, (32, 512), None, None), + # Strided tensors + (TestCase.BOTH, (32, 512), (1024, 1), (1024, 1)), + # 3D causal softmax + (TestCase.BOTH, (32, 5, 5), None, None), + (TestCase.BOTH, (32, 20, 512), None, None), + (TestCase.BOTH, (32, 20, 512), (20480, 512, 1), None), + (TestCase.BOTH, (28, 15, 15), None, None), +] + + +def parse_test_cases(data): + """ + Parse causal_softmax test case data according to format: + (operation_mode, shape, input_strides, output_strides) + """ + operation_mode = data[0] + shape = data[1] + input_strides = data[2] if len(data) > 2 else None + output_strides = data[3] if len(data) > 3 else None + + # Create input specifications + inputs = [] + + # Tensor input + if input_strides is not None: + inputs.append(TensorSpec.from_strided_tensor(shape, input_strides)) + else: + inputs.append(TensorSpec.from_tensor(shape)) + + # Output tensor + if output_strides is not None: + output = TensorSpec.from_strided_tensor(shape, output_strides) + else: + output = TensorSpec.from_tensor(shape) + + return TestCase(operation_mode, inputs, output) + + +# Parse test cases +_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA] + +# Data types +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + +# Tolerance +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-2}, + infinicore.float32: {"atol": 3e-5, "rtol": 1e-5}, + infinicore.bfloat16: {"atol": 5e-3, "rtol": 5e-2}, +} + + +class OpTest(BaseOperatorTest): + """CausalSoftmax test with simplified test case parsing""" + + def __init__(self): + super().__init__("CausalSoftmax") + + def get_test_cases(self): + return _TEST_CASES + + def get_tensor_dtypes(self): + return _TENSOR_DTYPES + + def get_tolerance_map(self): + return _TOLERANCE_MAP + + def torch_operator(self, input, out=None, **kwargs): + # Causal softmax implementation: apply causal mask then softmax + dtype = input.dtype + + # Create causal mask + mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) + masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) + + result = torch.nn.functional.softmax(masked, dim=-1, dtype=dtype) + + if out is not None: + out.copy_(result) + return out + return result + + def infinicore_operator(self, input, out=None, **kwargs): + return infinicore.causal_softmax(input, out=out) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/xmake/test.lua b/xmake/test.lua index 550b9afe6..0e27dc572 100644 --- a/xmake/test.lua +++ b/xmake/test.lua @@ -85,7 +85,7 @@ target("infinicore-test") add_files(os.projectdir().."/src/infinicore/context/*.cc") add_files(os.projectdir().."/src/infinicore/context/*/*.cc") add_files(os.projectdir().."/src/infinicore/tensor/*.cc") - add_files(os.projectdir().."/src/infinicore/op/*/*.cc") + add_files(os.projectdir().."/src/infinicore/ops/*/*.cc") add_files(os.projectdir().."/src/infinicore-test/*.cc")