diff --git a/include/infinicore/ops/linear.hpp b/include/infinicore/ops/linear.hpp
new file mode 100644
index 000000000..81cb61986
--- /dev/null
+++ b/include/infinicore/ops/linear.hpp
@@ -0,0 +1,17 @@
+#pragma once
+
+#include "common/op.hpp"
+
+#include <optional>
+
+namespace infinicore::op {
+
+/// Applies y = x * W^T + b and returns a newly allocated output tensor.
+/// The output takes `input`'s shape with its last dimension replaced by
+/// `weight->shape()[0]`; `bias` is added only when it holds a value.
+Tensor linear(Tensor input, Tensor weight, std::optional<Tensor> bias);
+
+/// Same computation as `linear`, written into the caller-provided `out`.
+void linear_(Tensor out, Tensor input, Tensor weight, std::optional<Tensor> bias);
+
+} // namespace infinicore::op
diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py
index 0bfdd4230..d35257e2f 100644
--- a/python/infinicore/nn/functional/__init__.py
+++ b/python/infinicore/nn/functional/__init__.py
@@ -1,13 +1,15 @@
 from .causal_softmax import causal_softmax
+from .linear import linear
 from .random_sample import random_sample
 from .rms_norm import rms_norm
 from .silu import silu
 from .swiglu import swiglu
 
 __all__ = [
     "causal_softmax",
+    "linear",
     "random_sample",
     "rms_norm",
     "silu",
     "swiglu",
 ]
diff --git a/python/infinicore/nn/functional/linear.py b/python/infinicore/nn/functional/linear.py
new file mode 100644
index 000000000..04c78b8e9
--- /dev/null
+++ b/python/infinicore/nn/functional/linear.py
@@ -0,0 +1,31 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["linear"]
+
+
+def linear(input: Tensor, weight: Tensor, bias=None, *, out=None) -> Tensor:
+    r"""Applies a linear transformation to the incoming data: y=xA^T+b.
+
+    Args:
+        input: Input tensor.
+        weight: Weight tensor.
+        bias: Optional bias tensor; omitted from the computation when ``None``.
+        out: Optional destination tensor. When given, the result is written
+            into it and ``out`` itself is returned.
+
+    Returns:
+        A new ``Tensor`` when ``out`` is ``None``, otherwise ``out``.
+    """
+    # Unwrap the optional bias once; both backend calls take the same handle.
+    bias_underlying = None if bias is None else bias._underlying
+
+    if out is None:
+        return Tensor(
+            _infinicore.linear(input._underlying, weight._underlying, bias_underlying)
+        )
+
+    _infinicore.linear_(
+        out._underlying, input._underlying, weight._underlying, bias_underlying
+    )
+    return out
diff --git a/src/infinicore/ops/linear/linear.cc b/src/infinicore/ops/linear/linear.cc
new file mode 100644
index
000000000..cd766195e
--- /dev/null
+++ b/src/infinicore/ops/linear/linear.cc
@@ -0,0 +1,70 @@
+#include "infinicore/ops/linear.hpp"
+#include "infinicore/ops/add.hpp"
+#include "infinicore/ops/matmul.hpp"
+
+#include <cassert>
+
+namespace infinicore::op {
+
+// Out-of-place linear: allocates the result and delegates to linear_.
+Tensor linear(Tensor input,
+              Tensor weight,
+              std::optional<Tensor> bias) {
+
+    Size ndim = input->ndim();
+    // ndim - 1 below would wrap around for a 0-dim input
+    assert(ndim >= 1);
+    Size out_features = weight->shape()[0];
+
+    // Output keeps the input's shape except for the last dimension
+    auto output_shape = input->shape();
+    output_shape[ndim - 1] = out_features;
+    auto out = Tensor::empty(output_shape, input->dtype(), input->device());
+
+    // Inplace Calculate
+    linear_(out, input, weight, bias);
+    return out;
+}
+
+// Writes input * weight^T (+ bias) into out.
+void linear_(Tensor out,
+             Tensor input,
+             Tensor weight,
+             std::optional<Tensor> bias) {
+
+    // weight is indexed as {out_features, in_features}
+    assert(weight->ndim() == 2);
+    auto weight_shape = weight->shape();
+    Size out_features = weight_shape[0];
+    Size in_features = weight_shape[1];
+
+    Size ndim = input->ndim();
+    assert(ndim >= 1);
+    assert(out->ndim() == ndim);
+
+    // The flattened views below rely on these trailing dimensions
+    auto input_shape = input->shape();
+    assert(input_shape[ndim - 1] == in_features);
+    assert(out->shape()[ndim - 1] == out_features);
+
+    // Number of rows once all leading dimensions are flattened
+    Size N = 1;
+    for (size_t i = 0; i + 1 < ndim; ++i) {
+        N *= input_shape[i];
+    }
+
+    // linear transformation
+    Tensor out_view = out->view({N, out_features});
+    matmul_(out_view,
+            input->view({N, in_features}),
+            weight->permute({1, 0}));
+
+    // Add bias; the {0, 1} strides reuse the same bias row for each of the N rows
+    if (bias.has_value()) {
+        add_(out_view,
+             out_view,
+             bias.value()->as_strided({N, out_features}, {0, 1}));
+    }
+}
+
+} // namespace infinicore::op
diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp
index 69450060f..98adb88dd 100644
--- a/src/infinicore/pybind11/ops.hpp
+++ b/src/infinicore/pybind11/ops.hpp
@@ -5,6 +5,7 @@
 #include "ops/add.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"
 #include "ops/random_sample.hpp"
@@ -22,6 +23,7 @@ inline void bind(py::module &m) {
     bind_attention(m);
     bind_causal_softmax(m);
     bind_random_sample(m);
+    bind_linear(m);
     bind_matmul(m);
     bind_mul(m);
     bind_rearrange(m);
diff --git
a/src/infinicore/pybind11/ops/linear.hpp b/src/infinicore/pybind11/ops/linear.hpp
new file mode 100644
index 000000000..8ddd2af98
--- /dev/null
+++ b/src/infinicore/pybind11/ops/linear.hpp
@@ -0,0 +1,58 @@
+#pragma once
+
+#include "infinicore/ops/linear.hpp"
+
+#include <pybind11/pybind11.h>
+
+#include <optional>
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+// Python-facing wrapper for op::linear; a Python None bias maps to std::nullopt.
+// Defined in a header, so it is inline like bind_linear below.
+inline Tensor py_linear(Tensor input,
+                        Tensor weight,
+                        pybind11::object bias) {
+    std::optional<Tensor> bias_tensor = std::nullopt;
+    if (!bias.is_none()) {
+        bias_tensor = bias.cast<Tensor>();
+    }
+    return op::linear(input, weight, bias_tensor);
+}
+
+// Python-facing wrapper for op::linear_, which writes the result into out.
+inline void py_linear_(Tensor out,
+                       Tensor input,
+                       Tensor weight,
+                       pybind11::object bias) {
+
+    std::optional<Tensor> bias_tensor = std::nullopt;
+    if (!bias.is_none()) {
+        bias_tensor = bias.cast<Tensor>();
+    }
+
+    op::linear_(out, input, weight, bias_tensor);
+}
+
+// Registers linear and linear_ on the given module.
+inline void bind_linear(py::module &m) {
+
+    m.def("linear",
+          &ops::py_linear,
+          py::arg("input"),
+          py::arg("weight"),
+          py::arg("bias") = py::none(),
+          R"doc(Applies a linear transformation to the incoming data: y=xA^T+b.)doc");
+
+    m.def("linear_",
+          &ops::py_linear_,
+          py::arg("out"),
+          py::arg("input"),
+          py::arg("weight"),
+          py::arg("bias") = py::none(),
+          R"doc(In-place, applies a linear transformation to the incoming data: y=xA^T+b.)doc");
+}
+
+} // namespace infinicore::ops
diff --git a/test/infinicore/framework/tensor.py b/test/infinicore/framework/tensor.py
index d8eb9c068..c098ea6b6 100644
--- a/test/infinicore/framework/tensor.py
+++ b/test/infinicore/framework/tensor.py
@@ -351,6 +351,6 @@ def __str__(self):
         else:
             strides_str = f", strides={self.strides}" if self.strides else ""
         dtype_str = (
-            f", {str(self.dtype).replace("infinicore.", "")}" if self.dtype else ""
+            f", {str(self.dtype).replace('infinicore.', '')}" if self.dtype else ""
         )
         return f"{name_str}tensor{self.shape}{strides_str}{dtype_str}"
diff --git a/test/infinicore/ops/linear.py b/test/infinicore/ops/linear.py
new file mode 100644
index 000000000..774dff64b
--- /dev/null
+++
b/test/infinicore/ops/linear.py
@@ -0,0 +1,132 @@
+import os
+import sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.runner import GenericTestRunner
+from framework.utils import is_broadcast
+
+import infinicore
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+_TEST_CASES_DATA = [
+    # bs, n, in_features, out_features, bias,
+    # input_strides, weight_strides, out_strides
+    (1, 5, 2048, 5632, True, None, None, None),
+    (1, 1, 2048, 32000, False, None, None, None),
+    (2, 5, 2048, 5632, True, None, None, None),
+    (2, 5, 256, 2048, False, None, None, None),
+    (None, 5, 256, 2048, False, None, None, None),
+    (None, 1, 2048, 5632, True, None, None, None),
+]
+
+# Tolerance configuration
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 0, "rtol": 1e-2},
+    infinicore.float32: {"atol": 0, "rtol": 1e-3},
+    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
+}
+
+# Data types to test
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+
+
+def parse_test_cases():
+    """
+    Parse test case data and return list of TestCase objects for linear operation.
+    Each test case contains all necessary information for execution and validation.
+    """
+    test_cases = []
+
+    for data in _TEST_CASES_DATA:
+        bs, n, in_features, out_features, bias = data[:5]
+        # Stride columns are optional; missing ones default to None
+        input_strides = data[5] if len(data) > 5 else None
+        weight_strides = data[6] if len(data) > 6 else None
+        out_strides = data[7] if len(data) > 7 else None
+
+        # bs=None means a 2-D input without the batch dimension
+        batch_dims = () if bs is None else (bs,)
+        input_shape = (*batch_dims, n, in_features)
+        weight_shape = (out_features, in_features)
+        out_shape = (*batch_dims, n, out_features)
+
+        bias_shape = (out_features,) if bias is True else None
+
+        # Check if tensors support in-place operations
+        # NOTE(review): is_broadcast receives the output shape here rather than
+        # out_strides -- confirm this is the intended argument.
+        c_supports_inplace = not is_broadcast(out_shape)
+
+        # Generate test cases for all data types
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
+
+            # Create typed tensor specs
+            input_spec = TensorSpec.from_tensor(input_shape, input_strides, dtype)
+            weight_spec = TensorSpec.from_tensor(weight_shape, weight_strides, dtype)
+            out_spec = TensorSpec.from_tensor(out_shape, out_strides, dtype)
+
+            if bias_shape is not None:
+                bias_spec = TensorSpec.from_tensor(bias_shape, None, dtype)
+            else:
+                bias_spec = None
+
+            # Test Case 1: Out-of-place (return value)
+            test_cases.append(
+                TestCase(
+                    inputs=[input_spec, weight_spec, bias_spec],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description="Linear - OUT_OF_PLACE",
+                )
+            )
+
+            # Test Case 2: In-place with explicit output tensor (Linear(a, b, out=c))
+            if c_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[input_spec, weight_spec, bias_spec],
+                        kwargs=None,
+                        output_spec=out_spec,  # Specify the output tensor spec
+                        comparison_target="out",
+                        tolerance=tolerance,
+                        description="Linear - INPLACE(out)",
+                    )
+                )
+
+    return test_cases
+
+
+class OpTest(BaseOperatorTest):
+    """Linear operator test with simplified implementation"""
+
+    def __init__(self):
+        super().__init__("Linear")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, *args, **kwargs):
+        """PyTorch linear implementation"""
+        return torch.nn.functional.linear(*args, **kwargs)
+
+    def infinicore_operator(self, *args, **kwargs):
+        """InfiniCore linear implementation"""
+        return infinicore.nn.functional.linear(*args, **kwargs)
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()