From a311e9c8fda7dcc69387d21ae6f83613b5945584 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 13 Nov 2025 14:04:47 +0800 Subject: [PATCH 1/2] issue/591 infinicore.narrow --- python/infinicore/__init__.py | 2 + python/infinicore/ops/narrow.py | 5 ++ python/infinicore/tensor.py | 10 ++-- src/infinicore/pybind11/tensor.hpp | 2 +- test/infinicore/tensor/narrow.py | 95 ++++++++++++++++++++++++++++++ 5 files changed, 108 insertions(+), 6 deletions(-) create mode 100644 python/infinicore/ops/narrow.py create mode 100644 test/infinicore/tensor/narrow.py diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index fae1e9647..fdbd65ecf 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -31,6 +31,7 @@ from infinicore.ops.attention import attention from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul +from infinicore.ops.narrow import narrow from infinicore.ops.rearrange import rearrange from infinicore.tensor import ( Tensor, @@ -78,6 +79,7 @@ "attention", "matmul", "mul", + "narrow", "rearrange", "empty", "empty_like", diff --git a/python/infinicore/ops/narrow.py b/python/infinicore/ops/narrow.py new file mode 100644 index 000000000..3d605a77b --- /dev/null +++ b/python/infinicore/ops/narrow.py @@ -0,0 +1,5 @@ +from infinicore.tensor import Tensor + + +def narrow(input: Tensor, dim: int, start: int, length: int) -> Tensor: + return Tensor(input._underlying.narrow(dim, start, length)) diff --git a/python/infinicore/tensor.py b/python/infinicore/tensor.py index b4eada011..5c0fdd798 100644 --- a/python/infinicore/tensor.py +++ b/python/infinicore/tensor.py @@ -52,8 +52,8 @@ def numel(self): def is_contiguous(self): return self._underlying.is_contiguous() - def is_is_pinned(self): - return self._underlying.is_is_pinned() + def is_pinned(self): + return self._underlying.is_pinned() def copy_(self, src): self._underlying.copy_(src._underlying) @@ -63,12 +63,12 @@ def to(self, *args, **kwargs): self._underlying.to(*tuple(arg._underlying for arg in args), **kwargs) ) - def as_strided(self, size, stride): - return Tensor(self._underlying.as_strided(size, stride)) - def contiguous(self): return Tensor(self._underlying.contiguous()) + def as_strided(self, size, stride): + return Tensor(self._underlying.as_strided(size, stride)) + def permute(self, dims): return Tensor(self._underlying.permute(dims)) diff --git a/src/infinicore/pybind11/tensor.hpp b/src/infinicore/pybind11/tensor.hpp index 36aea199c..879cb5a78 100644 --- a/src/infinicore/pybind11/tensor.hpp +++ b/src/infinicore/pybind11/tensor.hpp @@ -32,7 +32,7 @@ inline void bind(py::module &m) { .def("to", [](const Tensor &tensor, const Device &device) { return tensor->to(device); }) .def("as_strided", [](const Tensor &tensor, const Shape &shape, const Strides &strides) { return tensor->as_strided(shape, strides); }) .def("contiguous", [](const Tensor &tensor) { return tensor->contiguous(); }) - + .def("narrow", [](const Tensor &tensor, std::size_t dim, std::size_t start, std::size_t length) { return tensor->narrow({{dim, start, length}}); }) .def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); }) .def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); }); diff --git a/test/infinicore/tensor/narrow.py b/test/infinicore/tensor/narrow.py new file mode 100644 index 000000000..214c14333 --- /dev/null +++ b/test/infinicore/tensor/narrow.py @@ -0,0 +1,95 @@ +import sys +import os + +sys.path.insert(0, 
os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework.base import BaseOperatorTest, TensorSpec, TestCase +from framework.runner import GenericTestRunner +from framework.utils import is_broadcast + +# ============================================================================== +# Operator-specific configuration +# ============================================================================== + +# Test cases format: (shape, dim, start, length) +_TEST_CASES_DATA = [ + # Basic cases + ((2, 4), 0, 0, 1), + ((2, 4), 1, 1, 1), + ((5, 3, 2), 1, 0, 3), + ((5, 3, 2), 0, 1, 3), + ((4, 4, 1024, 32), 2, 1023, 1), +] + +# Tolerance configuration +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 0, "rtol": 0}, + infinicore.float32: {"atol": 0, "rtol": 0}, + infinicore.bfloat16: {"atol": 0, "rtol": 0}, +} + +# Data types to test +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def parse_test_cases(): + """ + Parse test case data and return list of TestCase objects for all operation types. + Each test case contains all necessary information for execution and validation. + """ + test_cases = [] + + for data in _TEST_CASES_DATA: + shape = data[0] + dim = data[1] + start = data[2] + length = data[3] + + # Generate test cases for all data types + for dtype in _TENSOR_DTYPES: + tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0}) + + # Create typed tensor specs + a_spec = TensorSpec.from_tensor(shape, None, dtype) + test_cases.append( + TestCase( + inputs=[a_spec, dim, start, length], + kwargs={}, + output_spec=None, + comparison_target=None, # Compare output + tolerance=tolerance, + description=f"Narrow", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """Narrow operator test with simplified implementation""" + + def __init__(self): + super().__init__("Narrow") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + """PyTorch narrow implementation""" + return torch.narrow(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + """InfiniCore narrow implementation""" + return infinicore.narrow(*args, **kwargs) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() From 16854aed46ac50a3bb4adbeb377e4a1aaafb6656 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 14 Nov 2025 15:54:41 +0800 Subject: [PATCH 2/2] issue/591 - fix operator context mismatch --- include/infinicore/context/context.hpp | 2 +- src/infinicore/context/context_impl.cc | 8 +++++++- src/infinicore/ops/add/add_infiniop.cc | 2 +- src/infinicore/ops/attention/attention_infiniop.cc | 2 +- .../ops/causal_softmax/causal_softmax_infiniop.cc | 2 +- src/infinicore/ops/gemm/gemm_infiniop.cc | 2 +- src/infinicore/ops/mul/mul_infiniop.cc | 2 +- src/infinicore/ops/rearrange/rearrange_infiniop.cc | 2 +- src/infinicore/ops/rms_norm/rms_norm_infiniop.cc | 2 +- src/infinicore/ops/rope/rope_infiniop.cc | 2 +- src/infinicore/ops/silu/silu_infiniop.cc | 2 +- src/infinicore/ops/swiglu/swiglu_infiniop.cc | 2 +- 12 files changed, 18 insertions(+), 12 deletions(-) diff --git a/include/infinicore/context/context.hpp b/include/infinicore/context/context.hpp index d39df00bb..093004565 100644 --- a/include/infinicore/context/context.hpp +++ b/include/infinicore/context/context.hpp @@ -16,7 +16,7 @@ Device getDevice(); size_t getDeviceCount(Device::Type type); infinirtStream_t getStream(); -infiniopHandle_t 
getInfiniopHandle(); +infiniopHandle_t getInfiniopHandle(Device device); void syncStream(); void syncDevice(); diff --git a/src/infinicore/context/context_impl.cc b/src/infinicore/context/context_impl.cc index 93d3f9c04..c7a96d163 100644 --- a/src/infinicore/context/context_impl.cc +++ b/src/infinicore/context/context_impl.cc @@ -99,7 +99,13 @@ infinirtStream_t getStream() { return ContextImpl::singleton().getCurrentRuntime()->stream(); } -infiniopHandle_t getInfiniopHandle() { +infiniopHandle_t getInfiniopHandle(Device device) { + if (device.getType() == Device::Type::CPU) { + return ContextImpl::singleton().getCpuRuntime()->infiniopHandle(); + } + if (device != getDevice()) { + throw std::runtime_error("Requested device doesn't match current runtime."); + } return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle(); } diff --git a/src/infinicore/ops/add/add_infiniop.cc b/src/infinicore/ops/add/add_infiniop.cc index e034b94de..96dbc7bac 100644 --- a/src/infinicore/ops/add/add_infiniop.cc +++ b/src/infinicore/ops/add/add_infiniop.cc @@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) { if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor( - context::getInfiniopHandle(), &desc, + context::getInfiniopHandle(c->device()), &desc, c->desc(), a->desc(), b->desc())); cache.put(seed, desc); } else { diff --git a/src/infinicore/ops/attention/attention_infiniop.cc b/src/infinicore/ops/attention/attention_infiniop.cc index 816cd884c..5e34cf490 100644 --- a/src/infinicore/ops/attention/attention_infiniop.cc +++ b/src/infinicore/ops/attention/attention_infiniop.cc @@ -28,7 +28,7 @@ void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor( - context::getInfiniopHandle(), &desc, + context::getInfiniopHandle(out->device()), &desc, out->desc(), q->desc(), k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), pos)); cache.put(seed, desc); diff --git a/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc b/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc index 33d4ed287..89295747f 100644 --- a/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc +++ b/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc @@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) { if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor( - context::getInfiniopHandle(), &desc, + context::getInfiniopHandle(output->device()), &desc, output->desc(), input->desc())); cache.put(seed, desc); } else { diff --git a/src/infinicore/ops/gemm/gemm_infiniop.cc b/src/infinicore/ops/gemm/gemm_infiniop.cc index f0414e5c1..5e7308136 100644 --- a/src/infinicore/ops/gemm/gemm_infiniop.cc +++ b/src/infinicore/ops/gemm/gemm_infiniop.cc @@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) { if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor( - context::getInfiniopHandle(), &desc, + context::getInfiniopHandle(c->device()), &desc, c->desc(), a->desc(), b->desc())); cache.put(seed, desc); } else { diff --git a/src/infinicore/ops/mul/mul_infiniop.cc b/src/infinicore/ops/mul/mul_infiniop.cc index 57d605547..7c5739bc8 100644 --- a/src/infinicore/ops/mul/mul_infiniop.cc +++ b/src/infinicore/ops/mul/mul_infiniop.cc @@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) { if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor( - context::getInfiniopHandle(), &desc, + 
context::getInfiniopHandle(c->device()), &desc, c->desc(), a->desc(), b->desc())); cache.put(seed, desc); } else { diff --git a/src/infinicore/ops/rearrange/rearrange_infiniop.cc b/src/infinicore/ops/rearrange/rearrange_infiniop.cc index d0a02105b..a7d0717e4 100644 --- a/src/infinicore/ops/rearrange/rearrange_infiniop.cc +++ b/src/infinicore/ops/rearrange/rearrange_infiniop.cc @@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor x) { infiniopRearrangeDescriptor_t desc = nullptr; if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(), &desc, y->desc(), x->desc())); + INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc())); cache.put(seed, desc); } else { desc = *desc_opt; diff --git a/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc b/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc index 3a4cdbefa..4222c4877 100644 --- a/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc +++ b/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc @@ -28,7 +28,7 @@ void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) { if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor( - context::getInfiniopHandle(), &desc, + context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc(), weight->desc(), epsilon)); cache.put(seed, desc); } else { diff --git a/src/infinicore/ops/rope/rope_infiniop.cc b/src/infinicore/ops/rope/rope_infiniop.cc index c989d2c72..67ba15f43 100644 --- a/src/infinicore/ops/rope/rope_infiniop.cc +++ b/src/infinicore/ops/rope/rope_infiniop.cc @@ -42,7 +42,7 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor( - context::getInfiniopHandle(), &desc, + context::getInfiniopHandle(x_out->device()), &desc, x_out->desc(), x->desc(), pos->desc(), sin_cache->desc(), cos_cache->desc(), infiniop_algo)); diff --git a/src/infinicore/ops/silu/silu_infiniop.cc b/src/infinicore/ops/silu/silu_infiniop.cc index 68838364f..edf508425 100644 --- a/src/infinicore/ops/silu/silu_infiniop.cc +++ b/src/infinicore/ops/silu/silu_infiniop.cc @@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) { if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor( - context::getInfiniopHandle(), &desc, + context::getInfiniopHandle(output->device()), &desc, output->desc(), input->desc())); cache.put(seed, desc); } else { diff --git a/src/infinicore/ops/swiglu/swiglu_infiniop.cc b/src/infinicore/ops/swiglu/swiglu_infiniop.cc index 8eac367b2..eea54fda0 100644 --- a/src/infinicore/ops/swiglu/swiglu_infiniop.cc +++ b/src/infinicore/ops/swiglu/swiglu_infiniop.cc @@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) { if (!desc_opt) { INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor( - context::getInfiniopHandle(), &desc, + context::getInfiniopHandle(c->device()), &desc, c->desc(), a->desc(), b->desc())); cache.put(seed, desc); } else {
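
-- 
Series notes (editor's addition; not part of either commit):

PATCH 1/2 exposes Tensor::narrow to Python. infinicore.narrow(input, dim,
start, length) takes the same arguments as torch.narrow, against which the
new test validates it: the result keeps `length` indices of `input` along
dimension `dim`, starting at index `start`. A minimal sketch of the expected
semantics follows; the construction of the infinicore-side tensor is elided
(left as a comment) because this series does not show a creation helper:

    import torch

    x = torch.arange(8, dtype=torch.float32).reshape(2, 4)
    # Keep 2 columns starting at column 1 -> [[1., 2.], [5., 6.]], shape (2, 2).
    ref = torch.narrow(x, 1, 1, 2)

    # Given an equivalent infinicore tensor `t` (creation elided), the new
    # binding is called the same way and should match `ref` elementwise:
    #     out = infinicore.narrow(t, 1, 1, 2)

One design note: the pybind11 binding declares dim/start/length as
std::size_t, so the negative `dim` values that torch.narrow accepts are
rejected at the binding layer here.

PATCH 2/2 makes descriptor creation use the handle that matches the device
the tensors actually live on: getInfiniopHandle() now takes the target
Device, returns the CPU runtime's handle directly for CPU tensors, and
throws if a non-CPU device disagrees with the current runtime, instead of
silently creating the descriptor on whichever runtime happens to be current.
Each op passes its output tensor's device, e.g.
context::getInfiniopHandle(c->device()).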