18 changes: 18 additions & 0 deletions include/infinicore/ops/gemm.hpp
@@ -0,0 +1,18 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class Gemm {
public:
    using schema = void (*)(Tensor, Tensor, Tensor, float, float);
    static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);

} // namespace infinicore::op
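
For orientation, a minimal caller-side sketch of the new interface (illustrative only, not part of the diff; it assumes `a`, `b`, and `c` are already-constructed `Tensor` objects on the current device, e.g. created with `Tensor::empty` as in gemm.cc below):

    // Out-of-place: allocates the result and computes alpha * (a @ b).
    Tensor d = op::gemm(a, b, /*alpha=*/0.125f);

    // In-place: c = alpha * (a @ b) + beta * c, written into an existing tensor.
    op::gemm_(c, a, b, /*alpha=*/1.0f, /*beta=*/1.0f);

    // matmul stays available and is now equivalent to gemm with alpha=1, beta=0.
    Tensor e = op::matmul(a, b);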
7 changes: 1 addition & 6 deletions include/infinicore/ops/matmul.hpp
@@ -4,13 +4,8 @@
#include "common/op.hpp"

namespace infinicore::op {
class Matmul {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor matmul(Tensor a, Tensor b);
void matmul_(Tensor c, Tensor a, Tensor b);

} // namespace infinicore::op
9 changes: 9 additions & 0 deletions python/infinicore/ops/gemm.py
@@ -0,0 +1,9 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def gemm(input, other, alpha=1.0, beta=0.0, *, out=None):
    r"""Compute ``alpha * (input @ other) + beta * out``.

    When ``out`` is None a new tensor is allocated (use the default ``beta=0.0``
    in that case); otherwise the result is written into ``out``.
    """
    if out is None:
        return Tensor(
            _infinicore.gemm(input._underlying, other._underlying, alpha, beta)
        )

    _infinicore.gemm_(out._underlying, input._underlying, other._underlying, alpha, beta)
    return out
27 changes: 27 additions & 0 deletions src/infinicore/ops/gemm/gemm.cc
@@ -0,0 +1,27 @@
#include "infinicore/ops/gemm.hpp"

namespace infinicore::op {

common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
    static common::OpDispatcher<Gemm::schema> dispatcher_;
    return dispatcher_;
}

void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
    dispatcher().lookup(context::getDevice().getType())(c, a, b, alpha, beta);
}

Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
    Shape shape = a->shape();
    Size size = a->ndim();
    shape[size - 1] = b->size(size - 1);
    auto c = Tensor::empty(shape, a->dtype(), a->device());
    gemm_(c, a, b, alpha, beta);
    return c;
}

void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
    Gemm::execute(c, a, b, alpha, beta);
}

} // namespace infinicore::op
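
A worked example of the shape computation in `gemm` above, using the batched case from the test file at the end of this diff (shapes illustrative):

    // a: (4, 48, 64), b: (4, 64, 6)
    // size = a->ndim() = 3, shape = a->shape() = {4, 48, 64}
    // shape[2] = b->size(2) = 6  =>  c is allocated with shape (4, 48, 6)

Since `c` comes from `Tensor::empty`, its initial contents are undefined, so a non-zero `beta` is only meaningful on the `gemm_` path with an existing output.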
@@ -1,10 +1,10 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/matmul.hpp"
#include "infinicore/ops/gemm.hpp"
#include <infiniop.h>

namespace infinicore::op::matmul_impl::infiniop {
namespace infinicore::op::gemm_impl::infiniop {

thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
100, // capacity
Expand All @@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
}
});

void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a);
void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
size_t seed = hash_combine(c, b, a, alpha, beta);

auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
Expand All @@ -41,12 +41,12 @@ void calculate(Tensor c, Tensor a, Tensor b) {

INFINICORE_CHECK_ERROR(infiniopGemm(
desc, workspace->data(), workspace_size,
c->data(), a->data(), b->data(), 1.f, 0.f, context::getStream()));
c->data(), a->data(), b->data(), alpha, beta, context::getStream()));
}

static bool registered = []() {
Matmul::dispatcher().registerAll(&calculate, false);
Gemm::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::matmul_impl::infiniop
} // namespace infinicore::op::gemm_impl::infiniop
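
The lambda at the bottom is the self-registration hook: each backend registers its kernel with `Gemm::dispatcher()` at load time. A hypothetical sketch of another backend following the same pattern (the namespace and kernel body below are illustrative and not part of this change; only the `registerAll(&calculate, false)` call mirrors the code above):

    #include "infinicore/ops/gemm.hpp"

    namespace infinicore::op::gemm_impl::reference {

    // Placeholder kernel matching Gemm::schema:
    //   void (*)(Tensor, Tensor, Tensor, float, float)
    void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
        // ... compute c = alpha * (a @ b) + beta * c for this backend ...
    }

    static bool registered = []() {
        Gemm::dispatcher().registerAll(&calculate, false);
        return true;
    }();

    } // namespace infinicore::op::gemm_impl::reference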
19 changes: 3 additions & 16 deletions src/infinicore/ops/matmul/matmul.cc
@@ -1,26 +1,13 @@
#include "infinicore/ops/matmul.hpp"
#include "infinicore/ops/gemm.hpp"

namespace infinicore::op {

common::OpDispatcher<Matmul::schema> &Matmul::dispatcher() {
static common::OpDispatcher<Matmul::schema> dispatcher_;
return dispatcher_;
};

void Matmul::execute(Tensor c, Tensor a, Tensor b) {
dispatcher().lookup(context::getDevice().getType())(c, a, b);
}

Tensor matmul(Tensor a, Tensor b) {
Shape shape = a->shape();
Size size = a->ndim();
shape[size - 1] = b->size(size - 1);
auto c = Tensor::empty(shape, a->dtype(), a->device());
matmul_(c, a, b);
return c;
return gemm(a, b, 1.0f, 0.0f);
}

void matmul_(Tensor c, Tensor a, Tensor b) {
Matmul::execute(c, a, b);
Gemm::execute(c, a, b, 1.0f, 0.0f);
}
} // namespace infinicore::op
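
With this change, `matmul` is a thin wrapper over `gemm`, so at the call site the two are interchangeable for the plain product (sketch, names illustrative):

    Tensor y1 = op::matmul(a, b); // forwards to gemm(a, b, 1.0f, 0.0f)
    Tensor y2 = op::gemm(a, b);   // same result via the defaults alpha = 1.0f, beta = 0.0f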
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
@@ -5,6 +5,7 @@
#include "ops/add.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/gemm.hpp"
#include "ops/matmul.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
Expand All @@ -19,6 +20,7 @@ inline void bind(py::module &m) {
bind_add(m);
bind_attention(m);
bind_causal_softmax(m);
bind_gemm(m);
bind_matmul(m);
bind_rearrange(m);
bind_rms_norm(m);
30 changes: 30 additions & 0 deletions src/infinicore/pybind11/ops/gemm.hpp
@@ -0,0 +1,30 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/gemm.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_gemm(py::module &m) {
    m.def("gemm",
          &op::gemm,
          py::arg("a"),
          py::arg("b"),
          py::arg("alpha") = 1.0f,
          py::arg("beta") = 0.0f,
          R"doc(General matrix multiplication: C = alpha * A @ B + beta * C.)doc");

    m.def("gemm_",
          &op::gemm_,
          py::arg("c"),
          py::arg("a"),
          py::arg("b"),
          py::arg("alpha"),
          py::arg("beta"),
          R"doc(In-place general matrix multiplication.)doc");
}

} // namespace infinicore::ops
148 changes: 148 additions & 0 deletions test/infinicore/ops/gemm.py
@@ -0,0 +1,148 @@
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from infinicore.ops.gemm import gemm as ic_gemm
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.tensor import TensorInitializer
from framework.runner import GenericTestRunner

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format (each item):
#   (alpha, beta, operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
# If nbatch is None: a_shape=(m, k), b_shape=(k, n), c_shape=(m, n)
# If nbatch is provided: a_shape=(nbatch, m, k), b_shape=(nbatch, k, n), c_shape=(nbatch, m, n)
# Shapes, strides, and per-case alpha/beta are aligned with test/infiniop/gemm.py
_TEST_CASES_DATA = [
    # (1) alpha=1.0, beta=0.0, a=(1,2048), b=(2048,2048), c=(1,2048)
    (1.0, 0.0, TestCase.BOTH, None, 1, 2048, 2048, None, None, None),
    # (2) alpha=1.0, beta=0.0, a=(2,4,2048), b=(2,2048,2048), c=(2,4,2048)
    (1.0, 0.0, TestCase.BOTH, 2, 4, 2048, 2048, None, None, None),
    # (3) alpha=1.0, beta=0.0, strided (4096,1)
    (1.0, 0.0, TestCase.BOTH, None, 1, 2048, 2048, (4096, 1), (4096, 1), (4096, 1)),
    # (4) alpha=1.0, beta=1.0, only meaningful for IN_PLACE (needs existing C)
    (1.0, 1.0, TestCase.IN_PLACE, None, 6, 2560, 2048, (2048, 1), (1, 2048), (2560, 1)),
    # (5) alpha=1.0/8.0, beta=0.0, a=(4,48,64), b=(4,64,6), c=(4,48,6)
    (1.0 / 8.0, 0.0, TestCase.BOTH, 4, 48, 6, 64, None, None, None),
]
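
# Illustrative walkthrough of case (4) above (derived from the tuple, for readability):
# alpha=1.0, beta=1.0, IN_PLACE, nbatch=None, m=6, n=2560, k=2048, so
# a_shape=(6, 2048) with contiguous strides (2048, 1), c_shape=(6, 2560) with
# contiguous strides (2560, 1), while b_shape=(2048, 2560) with strides (1, 2048)
# describes a transposed (column-major) layout. With beta=1.0 the existing
# contents of c are accumulated into.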


def parse_test_cases(data):
    """
    Parse one gemm test case tuple of the form:
    (alpha, beta, operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
    """
    alpha = data[0]
    beta = data[1]
    operation_mode = data[2]
    nbatch = data[3]
    m, n, k = data[4], data[5], data[6]
    a_strides = data[7] if len(data) > 7 else None
    b_strides = data[8] if len(data) > 8 else None
    c_strides = data[9] if len(data) > 9 else None

    # Determine shapes based on batch dimension
    if nbatch is None:
        a_shape = (m, k)
        b_shape = (k, n)
        c_shape = (m, n)
    else:
        a_shape = (nbatch, m, k)
        b_shape = (nbatch, k, n)
        c_shape = (nbatch, m, n)

    # Create input specifications
    inputs = []

    # Tensor a
    if a_strides is not None:
        inputs.append(TensorSpec.from_strided_tensor(a_shape, a_strides))
    else:
        inputs.append(TensorSpec.from_tensor(a_shape))

    # Tensor b
    if b_strides is not None:
        inputs.append(TensorSpec.from_strided_tensor(b_shape, b_strides))
    else:
        inputs.append(TensorSpec.from_tensor(b_shape))

    # Output tensor
    if c_strides is not None:
        output = TensorSpec.from_strided_tensor(
            c_shape,
            c_strides,
            init_mode=TensorInitializer.ONES if beta != 0.0 else TensorInitializer.RANDOM,
        )
    else:
        output = TensorSpec.from_tensor(
            c_shape,
            init_mode=TensorInitializer.ONES if beta != 0.0 else TensorInitializer.RANDOM,
        )

    return TestCase(operation_mode, inputs, output, alpha=alpha, beta=beta)


# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]

# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# Tolerance
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 0, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}


class OpTest(BaseOperatorTest):
    """GEMM test with simplified test case parsing.

    The torch reference computes alpha * (a @ b) + beta * out, so cases with the
    default alpha=1.0, beta=0.0 reduce to plain torch.matmul.
    """

    def __init__(self):
        super().__init__("Gemm")

    def get_test_cases(self):
        return _TEST_CASES

    def get_tensor_dtypes(self):
        return _TENSOR_DTYPES

    def get_tolerance_map(self):
        return _TOLERANCE_MAP

    def torch_operator(self, a, b, out=None, **kwargs):
        alpha = kwargs.get("alpha", 1.0)
        beta = kwargs.get("beta", 0.0)
        mm = torch.matmul(a, b)
        if out is None:
            return mm.mul(alpha)
        out.mul_(beta)
        out.add_(mm, alpha=alpha)
        return out

    def infinicore_operator(self, a, b, out=None, **kwargs):
        alpha = kwargs.get("alpha", 1.0)
        beta = kwargs.get("beta", 0.0)
        if out is None:
            return ic_gemm(a, b, alpha=alpha, beta=beta)
        return ic_gemm(a, b, alpha=alpha, beta=beta, out=out)


def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()