diff --git a/include/infinicore/ops/gemm.hpp b/include/infinicore/ops/gemm.hpp
new file mode 100644
index 000000000..6562f087d
--- /dev/null
+++ b/include/infinicore/ops/gemm.hpp
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+// Dispatch point for GEMM (c = alpha * a @ b + beta * c), keyed by device type.
+class Gemm {
+public:
+    // Kernel signature: (c, a, b, alpha, beta).
+    using schema = void (*)(Tensor, Tensor, Tensor, float, float);
+    // Calls the kernel registered for the current context device type.
+    static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
+    // Returns the dispatcher that backend kernels are registered with.
+    static common::OpDispatcher<schema> &dispatcher();
+};
+
+// Allocates the result tensor, runs GEMM into it and returns it.
+Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
+// Runs GEMM into the caller-provided tensor c.
+void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
+
+} // namespace infinicore::op
diff --git a/include/infinicore/ops/matmul.hpp b/include/infinicore/ops/matmul.hpp
index 4e9f370ae..2a641aa39 100644
--- a/include/infinicore/ops/matmul.hpp
+++ b/include/infinicore/ops/matmul.hpp
@@ -4,13 +4,8 @@
 #include "common/op.hpp"
 
 namespace infinicore::op {
-class Matmul {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor);
-    static void execute(Tensor c, Tensor a, Tensor b);
-    static common::OpDispatcher &dispatcher();
-};
 
 Tensor matmul(Tensor a, Tensor b);
 void matmul_(Tensor c, Tensor a, Tensor b);
+
 } // namespace infinicore::op
diff --git a/python/infinicore/ops/gemm.py b/python/infinicore/ops/gemm.py
new file mode 100644
index 000000000..6bd1f89c2
--- /dev/null
+++ b/python/infinicore/ops/gemm.py
@@ -0,0 +1,16 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def gemm(input, other, alpha=1.0, beta=0.0, *, out=None):
+    """General matrix multiplication: C = alpha * A @ B + beta * C.
+
+    Without ``out`` the result is wrapped in a new ``Tensor`` and returned.
+    With ``out`` the result is written into ``out`` through the in-place
+    binding, and ``out`` itself is returned so both paths yield a tensor.
+    """
+    if out is None:
+        return Tensor(_infinicore.gemm(input._underlying, other._underlying, alpha, beta))
+
+    _infinicore.gemm_(out._underlying, input._underlying, other._underlying, alpha, beta)
+    return out
diff --git a/src/infinicore/ops/gemm/gemm.cc b/src/infinicore/ops/gemm/gemm.cc
new file mode 100644
index 000000000..f45d1c6fa
--- /dev/null
+++ b/src/infinicore/ops/gemm/gemm.cc
@@ -0,0 +1,34 @@
+#include "infinicore/ops/gemm.hpp"
+
+namespace infinicore::op {
+
+// Returns the single dispatcher instance that backends register GEMM kernels with.
+common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
+    static common::OpDispatcher<Gemm::schema> dispatcher_;
+    return dispatcher_;
+}
+
+// Invokes the kernel registered for the current context device type.
+void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    dispatcher().lookup(context::getDevice().getType())(c, a, b, alpha, beta);
+}
+
+// Out-of-place GEMM: the output takes a's shape with its last dimension
+// replaced by b's size at that same index, and a's dtype and device.
+// NOTE(review): c comes from Tensor::empty, so with beta != 0 the kernel is
+// presumably handed unspecified contents to scale -- confirm this is intended.
+Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
+    Shape shape = a->shape();
+    Size size = a->ndim();
+    shape[size - 1] = b->size(size - 1);
+    auto c = Tensor::empty(shape, a->dtype(), a->device());
+    gemm_(c, a, b, alpha, beta);
+    return c;
+}
+
+// In-place GEMM into the caller-provided c.
+void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    Gemm::execute(c, a, b, alpha, beta);
+}
+
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/matmul/matmul_infiniop.cc b/src/infinicore/ops/gemm/gemm_infiniop.cc
similarity index 76%
rename from src/infinicore/ops/matmul/matmul_infiniop.cc
rename to src/infinicore/ops/gemm/gemm_infiniop.cc
index 3bd69c3f8..f0414e5c1 100644
--- a/src/infinicore/ops/matmul/matmul_infiniop.cc
+++ b/src/infinicore/ops/gemm/gemm_infiniop.cc
@@ -1,10 +1,10 @@
 #include "../../utils.hpp"
 #include "infinicore/common/hash.hpp"
 #include "infinicore/ops/common/cache.hpp"
-#include "infinicore/ops/matmul.hpp"
+#include "infinicore/ops/gemm.hpp"
 #include
 
-namespace infinicore::op::matmul_impl::infiniop {
+namespace infinicore::op::gemm_impl::infiniop {
 
 thread_local common::OpCache caches(
     100, // capacity
@@ -15,8 +15,10 @@ thread_local common::OpCache caches(
         }
     });
 
-void calculate(Tensor c, Tensor a, Tensor b) {
-    size_t seed = hash_combine(c, b, a);
+void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    // NOTE(review): alpha and beta are handed to infiniopGemm at call time; if the cached
+    // descriptor does not depend on them, hashing them only fragments the cache -- confirm.
+    size_t seed = hash_combine(c, b, a, alpha, beta);
 
     auto device_type = context::getDevice().getType();
     auto device_index = context::getDevice().getIndex();
@@ -41,12 +43,12 @@ void calculate(Tensor c, Tensor a, Tensor b) {
 
     INFINICORE_CHECK_ERROR(infiniopGemm(
         desc, workspace->data(), workspace_size,
-        c->data(), a->data(), b->data(), 1.f, 0.f, context::getStream()));
+        c->data(), a->data(), b->data(), alpha, beta, context::getStream()));
 }
 
 static bool registered = []() {
-    Matmul::dispatcher().registerAll(&calculate, false);
+    Gemm::dispatcher().registerAll(&calculate, false);
     return true;
 }();
 
-} // namespace infinicore::op::matmul_impl::infiniop
+} // namespace infinicore::op::gemm_impl::infiniop
diff --git a/src/infinicore/ops/matmul/matmul.cc b/src/infinicore/ops/matmul/matmul.cc
index d04e98268..571848935 100644
--- a/src/infinicore/ops/matmul/matmul.cc
+++ b/src/infinicore/ops/matmul/matmul.cc
@@ -1,26 +1,15 @@
 #include "infinicore/ops/matmul.hpp"
+#include "infinicore/ops/gemm.hpp"
 
 namespace infinicore::op {
 
-common::OpDispatcher &Matmul::dispatcher() {
-    static common::OpDispatcher dispatcher_;
-    return dispatcher_;
-};
-
-void Matmul::execute(Tensor c, Tensor a, Tensor b) {
-    dispatcher().lookup(context::getDevice().getType())(c, a, b);
-}
-
+// matmul is GEMM with alpha = 1 and beta = 0.
 Tensor matmul(Tensor a, Tensor b) {
-    Shape shape = a->shape();
-    Size size = a->ndim();
-    shape[size - 1] = b->size(size - 1);
-    auto c = Tensor::empty(shape, a->dtype(), a->device());
-    matmul_(c, a, b);
-    return c;
+    return gemm(a, b, 1.0f, 0.0f);
 }
 
+// In-place variant: writes the product into the caller-provided c.
 void matmul_(Tensor c, Tensor a, Tensor b) {
-    Matmul::execute(c, a, b);
+    gemm_(c, a, b, 1.0f, 0.0f);
 }
 } // namespace infinicore::op
diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp
index 0036f49f6..fc2725487 100644
--- a/src/infinicore/pybind11/ops.hpp
+++ b/src/infinicore/pybind11/ops.hpp
@@ -5,6 +5,7 @@
 #include "ops/add.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/gemm.hpp"
 #include "ops/matmul.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
@@ -19,6 +20,7 @@ inline void bind(py::module &m) {
     bind_add(m);
     bind_attention(m);
     bind_causal_softmax(m);
+    bind_gemm(m);
     bind_matmul(m);
     bind_rearrange(m);
     bind_rms_norm(m);
diff --git a/src/infinicore/pybind11/ops/gemm.hpp b/src/infinicore/pybind11/ops/gemm.hpp
new file mode 100644
index 000000000..27597ecbe
--- /dev/null
+++ b/src/infinicore/pybind11/ops/gemm.hpp
@@ -0,0 +1,31 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "infinicore/ops/gemm.hpp"
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+// Registers the `gemm` (allocating) and `gemm_` (in-place) functions on module m.
+inline void bind_gemm(py::module &m) {
+    m.def("gemm",
+          &op::gemm,
+          py::arg("a"),
+          py::arg("b"),
+          py::arg("alpha") = 1.0f,
+          py::arg("beta") = 0.0f,
+          R"doc(General matrix multiplication: C = alpha * A @ B + beta * C.)doc");
+
+    m.def("gemm_",
+          &op::gemm_,
+          py::arg("c"),
+          py::arg("a"),
+          py::arg("b"),
+          py::arg("alpha"),
+          py::arg("beta"),
+          R"doc(In-place general matrix multiplication.)doc");
+}
+
+} // namespace infinicore::ops
diff --git a/test/infinicore/ops/gemm.py b/test/infinicore/ops/gemm.py
new file mode 100644
index 000000000..4cfe8d894
--- /dev/null
+++ b/test/infinicore/ops/gemm.py
@@ -0,0 +1,146 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from infinicore.ops.gemm import gemm as ic_gemm
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.tensor import TensorInitializer
+from framework.runner import GenericTestRunner
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+
+# Test cases format:
+#   (alpha, beta, operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
+# If nbatch is None: a_shape=(m, k), b_shape=(k, n), c_shape=(m, n)
+# If nbatch is provided: a_shape=(nbatch, m, k), b_shape=(nbatch, k, n), c_shape=(nbatch, m, n)
+# Aligned with test/infiniop/gemm.py shapes/strides and per-case alpha/beta
+_TEST_CASES_DATA = [
+    # (1) alpha=1.0, beta=0.0, a=(1,2048), b=(2048,2048), c=(1,2048)
+    (1.0, 0.0, TestCase.BOTH, None, 1, 2048, 2048, None, None, None),
+    # (2) alpha=1.0, beta=0.0, a=(2,4,2048), b=(2,2048,2048), c=(2,4,2048)
+    (1.0, 0.0, TestCase.BOTH, 2, 4, 2048, 2048, None, None, None),
+    # (3) alpha=1.0, beta=0.0, strided (4096,1)
+    (1.0, 0.0, TestCase.BOTH, None, 1, 2048, 2048, (4096, 1), (4096, 1), (4096, 1)),
+    # (4) alpha=1.0, beta=1.0, only meaningful for IN_PLACE (needs existing C)
+    (1.0, 1.0, TestCase.IN_PLACE, None, 6, 2560, 2048, (2048, 1), (1, 2048), (2560, 1)),
+    # (5) alpha=1.0/8.0, beta=0.0, a=(4,48,64), b=(4,64,6), c=(4,48,6)
+    (1.0 / 8.0, 0.0, TestCase.BOTH, 4, 48, 6, 64, None, None, None),
+]
+
+
+def parse_test_cases(data):
+    """
+    Parse gemm test case data according to format:
+    (alpha, beta, operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
+
+    Stride entries that are missing from the tuple default to None. The output
+    tensor is initialized with ones when beta != 0 and randomly otherwise.
+    """
+    alpha = data[0]
+    beta = data[1]
+    operation_mode = data[2]
+    nbatch = data[3]
+    m, n, k = data[4], data[5], data[6]
+    a_strides = data[7] if len(data) > 7 else None
+    b_strides = data[8] if len(data) > 8 else None
+    c_strides = data[9] if len(data) > 9 else None
+
+    # Determine shapes based on batch dimension
+    if nbatch is None:
+        a_shape = (m, k)
+        b_shape = (k, n)
+        c_shape = (m, n)
+    else:
+        a_shape = (nbatch, m, k)
+        b_shape = (nbatch, k, n)
+        c_shape = (nbatch, m, n)
+
+    # Create input specifications
+    inputs = []
+
+    # Tensor a
+    if a_strides is not None:
+        inputs.append(TensorSpec.from_strided_tensor(a_shape, a_strides))
+    else:
+        inputs.append(TensorSpec.from_tensor(a_shape))
+
+    # Tensor b
+    if b_strides is not None:
+        inputs.append(TensorSpec.from_strided_tensor(b_shape, b_strides))
+    else:
+        inputs.append(TensorSpec.from_tensor(b_shape))
+
+    # Output tensor; init_mode is computed once so both branches agree
+    init_mode = TensorInitializer.ONES if beta != 0.0 else TensorInitializer.RANDOM
+    if c_strides is not None:
+        output = TensorSpec.from_strided_tensor(c_shape, c_strides, init_mode=init_mode)
+    else:
+        output = TensorSpec.from_tensor(c_shape, init_mode=init_mode)
+
+    return TestCase(operation_mode, inputs, output, alpha=alpha, beta=beta)
+
+
+# Parse test cases
+_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
+
+# Data types
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+
+# Tolerance
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 0, "rtol": 1e-2},
+    infinicore.float32: {"atol": 0, "rtol": 1e-3},
+    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
+}
+
+
+class OpTest(BaseOperatorTest):
+    """GEMM test with simplified test case parsing
+
+    Note: every case carries its own alpha and beta; the torch reference is
+    alpha * (a @ b), plus beta * out when an output tensor is supplied.
+    """
+
+    def __init__(self):
+        super().__init__("Gemm")
+
+    def get_test_cases(self):
+        return _TEST_CASES
+
+    def get_tensor_dtypes(self):
+        return _TENSOR_DTYPES
+
+    def get_tolerance_map(self):
+        return _TOLERANCE_MAP
+
+    def torch_operator(self, a, b, out=None, **kwargs):
+        """Reference result computed with torch (in place when out is given)."""
+        alpha = kwargs.get("alpha", 1.0)
+        beta = kwargs.get("beta", 0.0)
+        mm = torch.matmul(a, b)
+        if out is None:
+            return mm.mul(alpha)
+        out.mul_(beta)
+        out.add_(mm, alpha=alpha)
+        return out
+
+    def infinicore_operator(self, a, b, out=None, **kwargs):
+        """Result from infinicore's gemm; out=None selects the allocating path."""
+        alpha = kwargs.get("alpha", 1.0)
+        beta = kwargs.get("beta", 0.0)
+        return ic_gemm(a, b, alpha=alpha, beta=beta, out=out)
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()