16 changes: 16 additions & 0 deletions include/infinicore/ops/mul.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Mul {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor mul(Tensor a, Tensor b);
void mul_(Tensor c, Tensor a, Tensor b);
} // namespace infinicore::op
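The header above wires Mul into a per-device dispatch table: Mul::execute() looks up the kernel registered for the current device type and forwards (c, a, b) to it, while mul/mul_ are the out-of-place and destination-passing entry points. As a rough illustration of that pattern, here is a minimal Python sketch, assuming common::OpDispatcher is essentially a map from device type to kernel function (the real C++ template is not part of this diff):

class OpDispatcher:
    # Hypothetical sketch of the dispatch table implied by the header above.
    def __init__(self):
        self._kernels = {}
        self._fallback = None

    def register(self, device_type, fn):
        self._kernels[device_type] = fn

    def register_all(self, fn):
        # Mirrors registerAll(&calculate, ...) used by the infiniop backend below.
        self._fallback = fn

    def lookup(self, device_type):
        return self._kernels.get(device_type, self._fallback)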
2 changes: 2 additions & 0 deletions python/infinicore/__init__.py
@@ -30,6 +30,7 @@
from infinicore.ops.add import add
from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
from infinicore.ops.rearrange import rearrange
from infinicore.tensor import (
Tensor,
@@ -76,6 +77,7 @@
"add",
"attention",
"matmul",
"mul",
"rearrange",
"empty",
"empty_like",
10 changes: 10 additions & 0 deletions python/infinicore/ops/mul.py
@@ -0,0 +1,10 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def mul(input, other, *, out=None):
    if out is None:
        return Tensor(_infinicore.mul(input._underlying, other._underlying))

    _infinicore.mul_(out._underlying, input._underlying, other._underlying)
    return out
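A quick usage sketch of this wrapper. The tensor construction is an assumption: __init__.py exports empty/empty_like, but their exact signatures are not shown in this diff.

import infinicore

a = infinicore.empty((13, 4), dtype=infinicore.float16)  # assumed factory signature
b = infinicore.empty((13, 4), dtype=infinicore.float16)

c = infinicore.mul(a, b)      # out-of-place: allocates and returns a new tensor
infinicore.mul(a, b, out=c)   # writes the product into an existing tensor
infinicore.mul(a, b, out=a)   # in-place on the first operand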
24 changes: 24 additions & 0 deletions src/infinicore/ops/mul/mul.cc
@@ -0,0 +1,24 @@
#include "infinicore/ops/mul.hpp"

namespace infinicore::op {

common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
static common::OpDispatcher<Mul::schema> dispatcher_;
return dispatcher_;
}

void Mul::execute(Tensor c, Tensor a, Tensor b) {
dispatcher().lookup(context::getDevice().getType())(c, a, b);
}

Tensor mul(Tensor a, Tensor b) {
auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
mul_(c, a, b);
return c;
}

void mul_(Tensor c, Tensor a, Tensor b) {
Mul::execute(c, a, b);
}

} // namespace infinicore::op
52 changes: 52 additions & 0 deletions src/infinicore/ops/mul/mul_infiniop.cc
@@ -0,0 +1,52 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/mul.hpp"
#include <infiniop.h>

namespace infinicore::op::mul_impl::infiniop {

thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches(
100, // capacity
[](infiniopMulDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyMulDescriptor(desc));
desc = nullptr;
}
});

void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, a, b);

auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();

auto &cache = caches.getCache(device_type, device_index);

auto desc_opt = cache.get(seed);
infiniopMulDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
context::getInfiniopHandle(), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetMulWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

INFINICORE_CHECK_ERROR(infiniopMul(
desc, workspace->data(), workspace_size,
c->data(), a->data(), b->data(), context::getStream()));
}

static bool registered = []() {
Mul::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::mul_impl::infiniop
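The backend caches one infiniop descriptor per (device, tensor-metadata hash) pair in a thread-local LRU, and the eviction callback destroys descriptors that fall out of the cache. A hypothetical Python rendering of that LRU-with-destructor pattern (names are illustrative; the real common::OpCache is not part of this diff):

from collections import OrderedDict

class LruCache:
    # Hypothetical sketch: an LRU cache that finalizes evicted entries.
    def __init__(self, capacity, destructor):
        self._capacity = capacity
        self._destructor = destructor
        self._entries = OrderedDict()

    def get(self, key):
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key]

    def put(self, key, value):
        self._entries[key] = value
        self._entries.move_to_end(key)
        if len(self._entries) > self._capacity:
            _, evicted = self._entries.popitem(last=False)  # drop the oldest entry
            self._destructor(evicted)  # e.g. infiniopDestroyMulDescriptor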
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
@@ -6,6 +6,7 @@
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/silu.hpp"
@@ -20,6 +21,7 @@ inline void bind(py::module &m) {
bind_attention(m);
bind_causal_softmax(m);
bind_matmul(m);
bind_mul(m);
bind_rearrange(m);
bind_rms_norm(m);
bind_silu(m);
26 changes: 26 additions & 0 deletions src/infinicore/pybind11/ops/mul.hpp
@@ -0,0 +1,26 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/mul.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_mul(py::module &m) {
m.def("mul",
&op::mul,
py::arg("a"),
py::arg("b"),
R"doc(Element-wise multiplication of two tensors.)doc");

m.def("mul_",
&op::mul_,
py::arg("c"),
py::arg("a"),
py::arg("b"),
R"doc(In-place element-wise tensor multiplication.)doc");
}

} // namespace infinicore::ops
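Note the two argument conventions: the raw binding mul_(c, a, b) is destination-first, while the Python wrapper shown earlier layers the torch-style mul(input, other, out=...) keyword API on top of it. Calling the private bindings directly would look like this (given infinicore tensors a, b, c as in the earlier sketch; normally the wrapper is preferred):

from infinicore.lib import _infinicore  # private module
from infinicore.tensor import Tensor

out = Tensor(_infinicore.mul(a._underlying, b._underlying))    # allocates and returns
_infinicore.mul_(c._underlying, a._underlying, b._underlying)  # destination first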
142 changes: 142 additions & 0 deletions test/infinicore/ops/mul.py
@@ -0,0 +1,142 @@
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format: (shape, a_strides, b_strides, c_strides)
_TEST_CASES_DATA = [
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4), (0, 1), None, None),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
]

# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# Tolerance
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}


def build_test_cases():
test_cases = []

for data in _TEST_CASES_DATA:
shape = data[0]
a_strides = data[1] if len(data) > 1 else None
b_strides = data[2] if len(data) > 2 else None
c_strides = data[3] if len(data) > 3 else None

a_supports_inplace = not is_broadcast(a_strides)
b_supports_inplace = not is_broadcast(b_strides)
c_supports_inplace = not is_broadcast(c_strides)

for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)

# Out-of-place (return value)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"Mul - OUT_OF_PLACE (dtype={dtype})",
)
)

# With explicit output tensor (mul(a, b, out=c))
if c_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={},
output_spec=c_spec,
comparison_target="out",
tolerance=tolerance,
description=f"Mul - INPLACE(out) (dtype={dtype})",
)
)

# In-place on first input (mul(a, b, out=a))
if a_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={"out": 0},
output_spec=None,
comparison_target=0,
tolerance=tolerance,
description=f"Mul - INPLACE(a) (dtype={dtype})",
)
)

# In-place on second input (mul(a, b, out=b))
if b_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={"out": 1},
output_spec=None,
comparison_target=1,
tolerance=tolerance,
description=f"Mul - INPLACE(b) (dtype={dtype})",
)
)

return test_cases


_TEST_CASES = build_test_cases()


class OpTest(BaseOperatorTest):
"""Mul test with simplified test case parsing"""

def __init__(self):
super().__init__("Mul")

def get_test_cases(self):
return _TEST_CASES

def torch_operator(self, a, b, out=None, **kwargs):
return torch.mul(a, b, out=out)

def infinicore_operator(self, a, b, out=None, **kwargs):
try:
return infinicore.mul(a, b, out=out)
except AttributeError as exc:
raise NotImplementedError("InfiniCore mul operator not available") from exc


def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()


if __name__ == "__main__":
main()
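Two notes on the test configuration above. First, the atol/rtol pairs in _TOLERANCE_MAP presumably follow the usual torch.allclose convention, i.e. a result passes when |actual - expected| <= atol + rtol * |expected| (the actual comparison lives in framework.base, which is not part of this diff). Second, the stride tuples in _TEST_CASES_DATA describe non-contiguous and broadcast layouts; a stride of 0 repeats the same storage along that axis. A hypothetical way to materialize such views with plain torch, outside the TensorSpec framework:

import torch

base = torch.randn(130, dtype=torch.float32)
strided = base.as_strided((13, 4), (10, 1))  # rows start 10 elements apart
bcast = base.as_strided((13, 4), (0, 1))     # stride 0: all rows alias the same 4 values
assert (bcast[0] == bcast[12]).all()         # every row shares storage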