From a565b363c29627aa90209e82db6cea17bc587b48 Mon Sep 17 00:00:00 2001
From: zhuyue
Date: Wed, 12 Nov 2025 15:45:38 +0800
Subject: [PATCH] Add mul python interface and tests.

---
 include/infinicore/ops/mul.hpp         |  22 ++++
 python/infinicore/__init__.py          |   2 +
 python/infinicore/ops/mul.py           |  22 ++++
 src/infinicore/ops/mul/mul.cc          |  31 +++++
 src/infinicore/ops/mul/mul_infiniop.cc |  61 ++++++++++
 src/infinicore/pybind11/ops.hpp        |   2 +
 src/infinicore/pybind11/ops/mul.hpp    |  27 +++++
 test/infinicore/ops/mul.py             | 150 +++++++++++++++++++++++++
 8 files changed, 317 insertions(+)
 create mode 100644 include/infinicore/ops/mul.hpp
 create mode 100644 python/infinicore/ops/mul.py
 create mode 100644 src/infinicore/ops/mul/mul.cc
 create mode 100644 src/infinicore/ops/mul/mul_infiniop.cc
 create mode 100644 src/infinicore/pybind11/ops/mul.hpp
 create mode 100644 test/infinicore/ops/mul.py

diff --git a/include/infinicore/ops/mul.hpp b/include/infinicore/ops/mul.hpp
new file mode 100644
index 000000000..83416bbd9
--- /dev/null
+++ b/include/infinicore/ops/mul.hpp
@@ -0,0 +1,22 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+/// Element-wise multiplication operator, dispatched per device type.
+class Mul {
+public:
+    /// Backend kernel signature: (output c, input a, input b).
+    using schema = void (*)(Tensor, Tensor, Tensor);
+    /// Runs the kernel registered for the current context device type.
+    static void execute(Tensor c, Tensor a, Tensor b);
+    /// Returns the single dispatcher instance that backends register into.
+    static common::OpDispatcher<schema> &dispatcher();
+};
+
+/// Allocates a new tensor shaped like `a` and stores `a * b` in it.
+Tensor mul(Tensor a, Tensor b);
+/// Writes `a * b` into the caller-provided tensor `c`.
+void mul_(Tensor c, Tensor a, Tensor b);
+} // namespace infinicore::op
diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py
index 8b11ca8a7..fae1e9647 100644
--- a/python/infinicore/__init__.py
+++ b/python/infinicore/__init__.py
@@ -30,6 +30,7 @@
 from infinicore.ops.add import add
 from infinicore.ops.attention import attention
 from infinicore.ops.matmul import matmul
+from infinicore.ops.mul import mul
 from infinicore.ops.rearrange import rearrange
 from infinicore.tensor import (
     Tensor,
@@ -76,6 +77,7 @@
     "add",
     "attention",
     "matmul",
+    "mul",
     "rearrange",
     "empty",
     "empty_like",
diff --git a/python/infinicore/ops/mul.py b/python/infinicore/ops/mul.py
new file mode 100644
index 000000000..0371f9486
--- /dev/null
+++ b/python/infinicore/ops/mul.py
@@ -0,0 +1,22 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def mul(input, other, *, out=None):
+    """Multiply ``input`` by ``other`` element-wise.
+
+    Args:
+        input: Left-hand operand tensor.
+        other: Right-hand operand tensor.
+        out: Optional tensor that receives the result.
+
+    Returns:
+        A newly allocated tensor when ``out`` is None, otherwise ``out``.
+    """
+    if out is None:
+        return Tensor(_infinicore.mul(input._underlying, other._underlying))
+
+    _infinicore.mul_(out._underlying, input._underlying, other._underlying)
+    # Hand the destination back so both call forms return the result,
+    # matching the torch.mul(..., out=...) convention.
+    return out
diff --git a/src/infinicore/ops/mul/mul.cc b/src/infinicore/ops/mul/mul.cc
new file mode 100644
index 000000000..041bb37d4
--- /dev/null
+++ b/src/infinicore/ops/mul/mul.cc
@@ -0,0 +1,31 @@
+#include "infinicore/ops/mul.hpp"
+
+namespace infinicore::op {
+
+// Function-local static: constructed on first use, so backends can register
+// from their own static initializers regardless of translation-unit order.
+common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
+    static common::OpDispatcher<Mul::schema> dispatcher_;
+    return dispatcher_;
+}
+
+// Looks up the kernel for the current context device type and invokes it.
+void Mul::execute(Tensor c, Tensor a, Tensor b) {
+    dispatcher().lookup(context::getDevice().getType())(c, a, b);
+}
+
+// Out-of-place form. The result takes its shape, dtype and device from `a`.
+// NOTE(review): `b` is not consulted when sizing the output — confirm that
+// callers always pass operands of matching shape.
+Tensor mul(Tensor a, Tensor b) {
+    auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
+    mul_(c, a, b);
+    return c;
+}
+
+// Writes the product into the caller-provided tensor `c`.
+void mul_(Tensor c, Tensor a, Tensor b) {
+    Mul::execute(c, a, b);
+}
+
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/mul/mul_infiniop.cc b/src/infinicore/ops/mul/mul_infiniop.cc
new file mode 100644
index 000000000..57d605547
--- /dev/null
+++ b/src/infinicore/ops/mul/mul_infiniop.cc
@@ -0,0 +1,61 @@
+#include "../../utils.hpp"
+#include "infinicore/common/hash.hpp"
+#include "infinicore/ops/common/cache.hpp"
+#include "infinicore/ops/mul.hpp"
+#include <infiniop.h>
+
+namespace infinicore::op::mul_impl::infiniop {
+
+// Per-thread store of InfiniOP mul descriptors. The callback releases a
+// descriptor handed to it (presumably on eviction or teardown — confirm
+// against OpCache).
+thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches(
+    100, // capacity
+    [](infiniopMulDescriptor_t &desc) {
+        if (desc != nullptr) {
+            INFINICORE_CHECK_ERROR(infiniopDestroyMulDescriptor(desc));
+            desc = nullptr;
+        }
+    });
+
+// InfiniOP-backed kernel for c = a * b; descriptors are created once per
+// distinct seed and then reused from the cache.
+void calculate(Tensor c, Tensor a, Tensor b) {
+    // Cache key only; the argument order just has to stay stable across calls.
+    size_t seed = hash_combine(c, b, a);
+
+    auto device_type = context::getDevice().getType();
+    auto device_index = context::getDevice().getIndex();
+
+    auto &cache = caches.getCache(device_type, device_index);
+
+    auto desc_opt = cache.get(seed);
+    infiniopMulDescriptor_t desc = nullptr;
+
+    if (!desc_opt) {
+        INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
+            context::getInfiniopHandle(), &desc,
+            c->desc(), a->desc(), b->desc()));
+        cache.put(seed, desc);
+    } else {
+        desc = *desc_opt;
+    }
+
+    size_t workspace_size = 0;
+    INFINICORE_CHECK_ERROR(infiniopGetMulWorkspaceSize(desc, &workspace_size));
+    // `auto` keeps this independent of the concrete handle type returned by
+    // context::allocateMemory.
+    auto workspace = context::allocateMemory(workspace_size);
+
+    INFINICORE_CHECK_ERROR(infiniopMul(
+        desc, workspace->data(), workspace_size,
+        c->data(), a->data(), b->data(), context::getStream()));
+}
+
+// Registers `calculate` with the Mul dispatcher during static initialization.
+static bool registered = []() {
+    Mul::dispatcher().registerAll(&calculate, false);
+    return true;
+}();
+
+} // namespace infinicore::op::mul_impl::infiniop
diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp
index 0036f49f6..e41c4978c 100644
--- a/src/infinicore/pybind11/ops.hpp
+++ b/src/infinicore/pybind11/ops.hpp
@@ -6,6 +6,7 @@
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
 #include "ops/matmul.hpp"
+#include "ops/mul.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
 #include "ops/silu.hpp"
@@ -20,6 +21,7 @@ inline void bind(py::module &m) {
     bind_attention(m);
     bind_causal_softmax(m);
     bind_matmul(m);
+    bind_mul(m);
     bind_rearrange(m);
     bind_rms_norm(m);
     bind_silu(m);
diff --git a/src/infinicore/pybind11/ops/mul.hpp b/src/infinicore/pybind11/ops/mul.hpp
new file mode 100644
index 000000000..fb8e4144b
--- /dev/null
+++ b/src/infinicore/pybind11/ops/mul.hpp
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "infinicore/ops/mul.hpp"
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+/// Exposes `mul` (allocating) and `mul_` (writes into `c`) on module `m`.
+inline void bind_mul(py::module &m) {
+    m.def("mul",
+          &op::mul,
+          py::arg("a"),
+          py::arg("b"),
+          R"doc(Element-wise multiplication of two tensors.)doc");
+
+    m.def("mul_",
+          &op::mul_,
+          py::arg("c"),
+          py::arg("a"),
+          py::arg("b"),
+          R"doc(In-place element-wise tensor multiplication.)doc");
+}
+
+} // namespace infinicore::ops
diff --git a/test/infinicore/ops/mul.py b/test/infinicore/ops/mul.py
new file mode 100644
index 000000000..c11958324
--- /dev/null
+++ b/test/infinicore/ops/mul.py
@@ -0,0 +1,150 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.runner import GenericTestRunner
+from framework.utils import is_broadcast
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+
+# Test cases format: (shape, a_strides, b_strides, c_strides)
+_TEST_CASES_DATA = [
+    ((13, 4), None, None, None),
+    ((13, 4), (10, 1), (10, 1), (10, 1)),
+    ((13, 4), (0, 1), None, None),
+    ((13, 4, 4), None, None, None),
+    ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
+    ((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
+    ((16, 5632), None, None, None),
+    ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
+]
+
+# Data types
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+
+# Tolerance
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 0, "rtol": 1e-2},
+    infinicore.float32: {"atol": 0, "rtol": 1e-3},
+    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
+}
+
+
+def build_test_cases():
+    """Expand _TEST_CASES_DATA into TestCase objects for every dtype.
+
+    Each entry yields an out-of-place case plus out=c / out=a / out=b cases,
+    the latter three only when the target's strides are not broadcast.
+    """
+    test_cases = []
+
+    for data in _TEST_CASES_DATA:
+        shape = data[0]
+        a_strides = data[1] if len(data) > 1 else None
+        b_strides = data[2] if len(data) > 2 else None
+        c_strides = data[3] if len(data) > 3 else None
+
+        a_supports_inplace = not is_broadcast(a_strides)
+        b_supports_inplace = not is_broadcast(b_strides)
+        c_supports_inplace = not is_broadcast(c_strides)
+
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
+
+            a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
+            b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
+            c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)
+
+            # Out-of-place (return value)
+            test_cases.append(
+                TestCase(
+                    inputs=[a_spec, b_spec],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description=f"Mul - OUT_OF_PLACE (dtype={dtype})",
+                )
+            )
+
+            # With explicit output tensor (mul(a, b, out=c))
+            if c_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[a_spec, b_spec],
+                        kwargs={},
+                        output_spec=c_spec,
+                        comparison_target="out",
+                        tolerance=tolerance,
+                        description=f"Mul - INPLACE(out) (dtype={dtype})",
+                    )
+                )
+
+            # In-place on first input (mul(a, b, out=a))
+            if a_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[a_spec, b_spec],
+                        kwargs={"out": 0},
+                        output_spec=None,
+                        comparison_target=0,
+                        tolerance=tolerance,
+                        description=f"Mul - INPLACE(a) (dtype={dtype})",
+                    )
+                )
+
+            # In-place on second input (mul(a, b, out=b))
+            if b_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[a_spec, b_spec],
+                        kwargs={"out": 1},
+                        output_spec=None,
+                        comparison_target=1,
+                        tolerance=tolerance,
+                        description=f"Mul - INPLACE(b) (dtype={dtype})",
+                    )
+                )
+
+    return test_cases
+
+
+_TEST_CASES = build_test_cases()
+
+
+class OpTest(BaseOperatorTest):
+    """Mul test with simplified test case parsing"""
+
+    def __init__(self):
+        super().__init__("Mul")
+
+    def get_test_cases(self):
+        """Return the module-level list of prebuilt test cases."""
+        return _TEST_CASES
+
+    def torch_operator(self, a, b, out=None, **kwargs):
+        """Reference implementation: torch.mul."""
+        return torch.mul(a, b, out=out)
+
+    def infinicore_operator(self, a, b, out=None, **kwargs):
+        """Implementation under test: infinicore.mul."""
+        try:
+            return infinicore.mul(a, b, out=out)
+        except AttributeError as exc:
+            raise NotImplementedError("InfiniCore mul operator not available") from exc
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()