Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/infinicore/ops/linear.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include "common/op.hpp"

namespace infinicore::op {

/// Out-of-place linear transformation (y = x * weight^T + bias).
///
/// Allocates and returns a new tensor with the dtype and device of `input`;
/// its shape is the shape of `input` with the last dimension replaced by
/// `weight->shape()[0]`. `bias` may be std::nullopt, in which case no bias
/// is added.
Tensor linear(Tensor input, Tensor weight, std::optional<Tensor> bias);

/// Same computation as linear(), but writes the result into the caller-provided
/// `out` tensor instead of allocating one. `out` is expected to have the same
/// number of dimensions as `input`.
void linear_(Tensor out, Tensor input, Tensor weight, std::optional<Tensor> bias);

} // namespace infinicore::op
9 changes: 2 additions & 7 deletions python/infinicore/nn/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from .causal_softmax import causal_softmax
from .linear import linear
from .random_sample import random_sample
from .rms_norm import rms_norm
from .silu import silu
from .swiglu import swiglu

__all__ = [
"causal_softmax",
"random_sample",
"rms_norm",
"silu",
"swiglu",
]
__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu", "linear"]
25 changes: 25 additions & 0 deletions python/infinicore/nn/functional/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["linear"]


def linear(input: Tensor, weight: Tensor, bias=None, *, out=None) -> Tensor:
    r"""Applies a linear transformation to the incoming data: y=xA^T+b.

    Args:
        input: Input tensor ``x``.
        weight: Weight tensor ``A``.
        bias: Optional bias tensor ``b``; ``None`` skips the bias term.
        out: Optional destination tensor. When given, the result is written
            into it via ``_infinicore.linear_`` and ``out`` itself is returned.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        result of ``_infinicore.linear``.
    """
    # Both native entry points take the same trailing (input, weight, bias) handles.
    native_args = (
        input._underlying,
        weight._underlying,
        bias._underlying if bias is not None else None,
    )

    if out is not None:
        _infinicore.linear_(out._underlying, *native_args)
        return out

    return Tensor(_infinicore.linear(*native_args))
57 changes: 57 additions & 0 deletions src/infinicore/ops/linear/linear.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "infinicore/ops/linear.hpp"
#include "infinicore/ops/add.hpp"
#include "infinicore/ops/matmul.hpp"

namespace infinicore::op {

/// Out-of-place linear transformation.
///
/// Allocates a result tensor with the dtype and device of `input`, whose shape
/// is the shape of `input` with the last dimension replaced by
/// `weight->shape()[0]`, then delegates the computation to linear_().
///
/// @param input  Input tensor.
/// @param weight Weight tensor; its first dimension gives the output feature count.
/// @param bias   Optional bias, forwarded unchanged to linear_().
/// @return The newly allocated result tensor.
Tensor linear(Tensor input,
              Tensor weight,
              std::optional<Tensor> bias) {

    // Result shape: copy of the input shape with the trailing dimension
    // swapped for the number of output features.
    auto result_shape = input->shape();
    result_shape[input->ndim() - 1] = weight->shape()[0];

    auto result = Tensor::empty(result_shape, input->dtype(), input->device());

    // The in-place variant does the actual work on the fresh buffer.
    linear_(result, input, weight, bias);
    return result;
}

/// In-place linear transformation: writes input * weight^T (+ bias) into `out`.
///
/// All leading dimensions of `input` are flattened into a single row count N,
/// so the computation is one (N x in_features) by (in_features x out_features)
/// matmul, optionally followed by a broadcast bias add.
///
/// @param out    Destination tensor; must have the same number of dimensions as `input`.
/// @param input  Input tensor, viewed as {N, in_features}.
/// @param weight Weight tensor; shape()[0] is out_features, shape()[1] is in_features.
/// @param bias   Optional bias added to every row of the result.
void linear_(Tensor out,
             Tensor input,
             Tensor weight,
             std::optional<Tensor> bias) {

    const auto &weight_shape = weight->shape();
    const Size out_features = weight_shape[0];
    const Size in_features = weight_shape[1];

    const Size ndim = input->ndim();
    assert(out->ndim() == ndim);

    // N = product of every input dimension except the last one.
    // The bound is written as `i + 1 < ndim` rather than `i < ndim - 1` so that
    // a zero-dimensional input cannot wrap the bound around and read
    // input_shape out of range.
    Size N = 1;
    const auto &input_shape = input->shape();
    for (size_t i = 0; i + 1 < ndim; ++i) {
        N *= input_shape[i];
    }

    // Linear transformation on the flattened 2-D views; the weight is
    // transposed via permute so the product is input * weight^T.
    Tensor out_view = out->view({N, out_features});
    matmul_(out_view,
            input->view({N, in_features}),
            weight->permute({1, 0}));

    // Add bias: a zero stride on the row axis repeats the same bias for all N rows.
    // NOTE(review): the element stride is hard-coded to 1, which presumably
    // assumes a contiguous 1-D bias of length out_features - confirm.
    if (bias.has_value()) {
        add_(out_view,
             out_view,
             bias.value()->as_strided({N, out_features}, {0, 1}));
    }
}

} // namespace infinicore::op
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ops/add.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/linear.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/random_sample.hpp"
Expand All @@ -22,6 +23,7 @@ inline void bind(py::module &m) {
bind_attention(m);
bind_causal_softmax(m);
bind_random_sample(m);
bind_linear(m);
bind_matmul(m);
bind_mul(m);
bind_rearrange(m);
Expand Down
52 changes: 52 additions & 0 deletions src/infinicore/pybind11/ops/linear.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#pragma once

#include "infinicore/ops/linear.hpp"

#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace infinicore::ops {

Tensor py_linear(Tensor input,
Tensor weight,
pybind11::object bias) {
std::optional<Tensor> bias_tensor = std::nullopt;
if (!bias.is_none()) {
bias_tensor = bias.cast<Tensor>();
}
return op::linear(input, weight, bias_tensor);
}

void py_linear_(Tensor out,
Tensor input,
Tensor weight,
pybind11::object bias) {

std::optional<Tensor> bias_tensor = std::nullopt;
if (!bias.is_none()) {
bias_tensor = bias.cast<Tensor>();
}

op::linear_(out, input, weight, bias_tensor);
}

inline void bind_linear(py::module &m) {

m.def("linear",
&ops::py_linear,
py::arg("input"),
py::arg("weight"),
py::arg("bias") = py::none(),
R"doc(Applies a linear transformation to the incoming data: y=xA^T+b.)doc");

m.def("linear_",
&ops::py_linear_,
py::arg("out"),
py::arg("input"),
py::arg("weight"),
py::arg("bias") = py::none(),
R"doc(In-place, applies a linear transformation to the incoming data: y=xA^T+b.)doc");
}

} // namespace infinicore::ops
2 changes: 1 addition & 1 deletion test/infinicore/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,6 @@ def __str__(self):
else:
strides_str = f", strides={self.strides}" if self.strides else ""
dtype_str = (
f", {str(self.dtype).replace("infinicore.", "")}" if self.dtype else ""
f", {str(self.dtype).replace('infinicore.', '')}" if self.dtype else ""
)
return f"{name_str}tensor{self.shape}{strides_str}{dtype_str}"
137 changes: 137 additions & 0 deletions test/infinicore/ops/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast

import infinicore

# ==============================================================================
# Operator-specific configuration
# ==============================================================================
_TEST_CASES_DATA = [
# Each entry: (bs, n, in_features, out_features, bias,
#              input_strides, weight_strides, out_strides)
# - bs: leading batch dimension; None means a 2-D input of shape (n, in_features)
# - bias: True adds a bias of shape (out_features,); otherwise no bias is used
# - *_strides: forwarded to TensorSpec.from_tensor (None in every case below)
(1, 5, 2048, 5632, True, None, None, None),
(1, 1, 2048, 32000, False, None, None, None),
(2, 5, 2048, 5632, True, None, None, None),
(2, 5, 256, 2048, False, None, None, None),
(None, 5, 256, 2048, False, None, None, None),
(None, 1, 2048, 5632, True, None, None, None),
]

# Per-dtype tolerance dicts attached to every TestCase; parse_test_cases falls
# back to {"atol": 0, "rtol": 1e-3} for a dtype that is not listed here.
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}

# Data types to test: every entry of _TEST_CASES_DATA is expanded once per dtype.
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]


def parse_test_cases():
    """
    Expand _TEST_CASES_DATA into TestCase objects for the linear operator.

    Every data entry is crossed with every dtype in _TENSOR_DTYPES. Each
    combination yields an out-of-place case and, unless the output shape is
    reported as broadcast by is_broadcast, an additional case that writes into
    an explicit output tensor.
    """
    cases = []

    for entry in _TEST_CASES_DATA:
        bs = entry[0]
        n = entry[1]
        in_features = entry[2]
        out_features = entry[3]
        bias = entry[4]
        # The stride fields are optional at the tail of an entry.
        input_strides = entry[5] if len(entry) > 5 else None
        weight_strides = entry[6] if len(entry) > 6 else None
        out_strides = entry[7] if len(entry) > 7 else None

        # A None batch size means the input has no leading batch dimension.
        batch_dims = () if bs is None else (bs,)
        input_shape = batch_dims + (n, in_features)
        weight_shape = (out_features, in_features)
        out_shape = batch_dims + (n, out_features)

        bias_shape = (out_features,) if bias is True else None

        # Only emit the explicit-output case when the output is not broadcast.
        c_supports_inplace = not is_broadcast(out_shape)

        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            input_spec = TensorSpec.from_tensor(input_shape, input_strides, dtype)
            weight_spec = TensorSpec.from_tensor(weight_shape, weight_strides, dtype)
            out_spec = TensorSpec.from_tensor(out_shape, out_strides, dtype)
            bias_spec = (
                TensorSpec.from_tensor(bias_shape, None, dtype)
                if bias_shape is not None
                else None
            )

            # Case 1: out-of-place, result taken from the return value.
            out_of_place_case = TestCase(
                inputs=[input_spec, weight_spec, bias_spec],
                kwargs={},
                output_spec=None,
                comparison_target=None,
                tolerance=tolerance,
                description="Linear - OUT_OF_PLACE",
            )
            cases.append(out_of_place_case)

            if not c_supports_inplace:
                continue

            # Case 2: explicit output tensor, i.e. linear(a, b, out=c).
            explicit_out_case = TestCase(
                inputs=[input_spec, weight_spec, bias_spec],
                kwargs=None,
                output_spec=out_spec,
                comparison_target="out",
                tolerance=tolerance,
                description="Linear - INPLACE(out)",
            )
            cases.append(explicit_out_case)

    return cases


class OpTest(BaseOperatorTest):
    """Linear operator test: supplies the cases and both implementations."""

    def __init__(self):
        super().__init__("Linear")

    def get_test_cases(self):
        """Return the TestCase list built by parse_test_cases()."""
        cases = parse_test_cases()
        return cases

    def torch_operator(self, *args, **kwargs):
        """Call torch.nn.functional.linear with the given arguments."""
        torch_linear = torch.nn.functional.linear
        return torch_linear(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """Call infinicore.nn.functional.linear with the given arguments."""
        infinicore_linear = infinicore.nn.functional.linear
        return infinicore_linear(*args, **kwargs)


def main():
    """Run the Linear operator tests through GenericTestRunner, then exit."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
main()