2 changes: 1 addition & 1 deletion include/infinicore/context/context.hpp
@@ -16,7 +16,7 @@ Device getDevice();
 size_t getDeviceCount(Device::Type type);

 infinirtStream_t getStream();
-infiniopHandle_t getInfiniopHandle();
+infiniopHandle_t getInfiniopHandle(Device device);

 void syncStream();
 void syncDevice();
2 changes: 2 additions & 0 deletions python/infinicore/__init__.py
@@ -31,6 +31,7 @@
 from infinicore.ops.attention import attention
 from infinicore.ops.matmul import matmul
 from infinicore.ops.mul import mul
+from infinicore.ops.narrow import narrow
 from infinicore.ops.rearrange import rearrange
 from infinicore.tensor import (
     Tensor,
@@ -78,6 +79,7 @@
     "attention",
     "matmul",
     "mul",
+    "narrow",
     "rearrange",
     "empty",
     "empty_like",
5 changes: 5 additions & 0 deletions python/infinicore/ops/narrow.py
@@ -0,0 +1,5 @@
+from infinicore.tensor import Tensor
+
+
+def narrow(input: Tensor, dim: int, start: int, length: int) -> Tensor:
+    return Tensor(input._underlying.narrow(dim, start, length))
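With the export added to python/infinicore/__init__.py above, the op is callable as infinicore.narrow(...) and mirrors torch.narrow: the result keeps length elements of dimension dim, starting at index start. A minimal usage sketch (hedged: this diff only shows that empty is exported, so the creation signature below is an assumption modeled on torch.empty):

import infinicore

# Hypothetical creation call; only empty's presence in __all__ is known from this PR.
x = infinicore.empty((5, 3, 2), dtype=infinicore.float32)

# Keep 2 elements of dim 1 starting at index 0: logical shape becomes (5, 2, 2).
y = infinicore.narrow(x, dim=1, start=0, length=2)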
10 changes: 5 additions & 5 deletions python/infinicore/tensor.py
@@ -52,8 +52,8 @@ def numel(self):
     def is_contiguous(self):
         return self._underlying.is_contiguous()

-    def is_is_pinned(self):
-        return self._underlying.is_is_pinned()
+    def is_pinned(self):
+        return self._underlying.is_pinned()

     def copy_(self, src):
         self._underlying.copy_(src._underlying)
@@ -63,12 +63,12 @@ def to(self, *args, **kwargs):
             self._underlying.to(*tuple(arg._underlying for arg in args), **kwargs)
         )

-    def as_strided(self, size, stride):
-        return Tensor(self._underlying.as_strided(size, stride))
-
     def contiguous(self):
         return Tensor(self._underlying.contiguous())

+    def as_strided(self, size, stride):
+        return Tensor(self._underlying.as_strided(size, stride))
+
     def permute(self, dims):
         return Tensor(self._underlying.permute(dims))

8 changes: 7 additions & 1 deletion src/infinicore/context/context_impl.cc
@@ -99,7 +99,13 @@ infinirtStream_t getStream() {
     return ContextImpl::singleton().getCurrentRuntime()->stream();
 }

-infiniopHandle_t getInfiniopHandle() {
+infiniopHandle_t getInfiniopHandle(Device device) {
+    if (device.getType() == Device::Type::CPU) {
+        return ContextImpl::singleton().getCpuRuntime()->infiniopHandle();
+    }
+    if (device != getDevice()) {
+        throw std::runtime_error("Requested device doesn't match current runtime.");
+    }
     return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
 }
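In effect, descriptor creation now resolves the infiniop handle from the device of the tensor being operated on (see the call-site updates below) rather than always from the current runtime: CPU requests always succeed, while any other device must match the active runtime or an error is thrown. A rough Python model of this dispatch, purely illustrative (names are hypothetical stand-ins, not part of the codebase):

from collections import namedtuple

# Hypothetical stand-ins for the C++ Device and handle types.
Device = namedtuple("Device", ["type", "index"])

def resolve_handle(device, current_device, cpu_handle, current_handle):
    # The CPU handle is always available, regardless of the active runtime.
    if device.type == "cpu":
        return cpu_handle
    # Any other device must match the runtime that owns current_handle.
    if device != current_device:
        raise RuntimeError("Requested device doesn't match current runtime.")
    return current_handle

For example, resolve_handle(Device("cpu", 0), Device("cuda", 0), "cpu_h", "cuda_h") returns "cpu_h", while a mismatched accelerator device raises.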

2 changes: 1 addition & 1 deletion src/infinicore/ops/add/add_infiniop.cc
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(c->device()), &desc,
             c->desc(), a->desc(), b->desc()));
         cache.put(seed, desc);
     } else {
2 changes: 1 addition & 1 deletion src/infinicore/ops/attention/attention_infiniop.cc
@@ -28,7 +28,7 @@ void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(out->device()), &desc,
             out->desc(), q->desc(), k->desc(), v->desc(),
             k_cache->desc(), v_cache->desc(), pos));
         cache.put(seed, desc);
2 changes: 1 addition & 1 deletion src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc
@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(output->device()), &desc,
             output->desc(), input->desc()));
         cache.put(seed, desc);
     } else {
2 changes: 1 addition & 1 deletion src/infinicore/ops/gemm/gemm_infiniop.cc
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(c->device()), &desc,
             c->desc(), a->desc(), b->desc()));
         cache.put(seed, desc);
     } else {
2 changes: 1 addition & 1 deletion src/infinicore/ops/mul/mul_infiniop.cc
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(c->device()), &desc,
             c->desc(), a->desc(), b->desc()));
         cache.put(seed, desc);
     } else {
2 changes: 1 addition & 1 deletion src/infinicore/ops/rearrange/rearrange_infiniop.cc
@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor x) {
     infiniopRearrangeDescriptor_t desc = nullptr;

     if (!desc_opt) {
-        INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(), &desc, y->desc(), x->desc()));
+        INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc()));
         cache.put(seed, desc);
     } else {
         desc = *desc_opt;
2 changes: 1 addition & 1 deletion src/infinicore/ops/rms_norm/rms_norm_infiniop.cc
@@ -28,7 +28,7 @@ void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(y->device()), &desc,
             y->desc(), x->desc(), weight->desc(), epsilon));
         cache.put(seed, desc);
     } else {
2 changes: 1 addition & 1 deletion src/infinicore/ops/rope/rope_infiniop.cc
@@ -42,7 +42,7 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(x_out->device()), &desc,
             x_out->desc(), x->desc(),
             pos->desc(), sin_cache->desc(), cos_cache->desc(),
             infiniop_algo));
2 changes: 1 addition & 1 deletion src/infinicore/ops/silu/silu_infiniop.cc
@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(output->device()), &desc,
             output->desc(), input->desc()));
         cache.put(seed, desc);
     } else {
2 changes: 1 addition & 1 deletion src/infinicore/ops/swiglu/swiglu_infiniop.cc
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {

     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(c->device()), &desc,
             c->desc(), a->desc(), b->desc()));
         cache.put(seed, desc);
     } else {
2 changes: 1 addition & 1 deletion src/infinicore/pybind11/tensor.hpp
@@ -32,7 +32,7 @@ inline void bind(py::module &m) {
         .def("to", [](const Tensor &tensor, const Device &device) { return tensor->to(device); })
         .def("as_strided", [](const Tensor &tensor, const Shape &shape, const Strides &strides) { return tensor->as_strided(shape, strides); })
         .def("contiguous", [](const Tensor &tensor) { return tensor->contiguous(); })
-
+        .def("narrow", [](const Tensor &tensor, std::size_t dim, std::size_t start, std::size_t length) { return tensor->narrow({{dim, start, length}}); })
         .def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); })
         .def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); });

95 changes: 95 additions & 0 deletions test/infinicore/tensor/narrow.py
@@ -0,0 +1,95 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.runner import GenericTestRunner
+from framework.utils import is_broadcast
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+
+# Test cases format: (shape, dim, start, length)
+_TEST_CASES_DATA = [
+    # Basic cases
+    ((2, 4), 0, 0, 1),
+    ((2, 4), 1, 1, 1),
+    ((5, 3, 2), 1, 0, 3),
+    ((5, 3, 2), 0, 1, 3),
+    ((4, 4, 1024, 32), 2, 1023, 1),
+]
+
+# Tolerance configuration
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 0, "rtol": 0},
+    infinicore.float32: {"atol": 0, "rtol": 0},
+    infinicore.bfloat16: {"atol": 0, "rtol": 0},
+}
+
+# Data types to test
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+
+
+def parse_test_cases():
+    """
+    Parse test case data and return list of TestCase objects for all operation types.
+    Each test case contains all necessary information for execution and validation.
+    """
+    test_cases = []
+
+    for data in _TEST_CASES_DATA:
+        shape = data[0]
+        dim = data[1]
+        start = data[2]
+        length = data[3]
+
+        # Generate test cases for all data types
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})
+
+            # Create typed tensor specs
+            a_spec = TensorSpec.from_tensor(shape, None, dtype)
+            test_cases.append(
+                TestCase(
+                    inputs=[a_spec, dim, start, length],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=None,  # Compare output
+                    tolerance=tolerance,
+                    description="Narrow",
+                )
+            )
+
+    return test_cases
+
+
+class OpTest(BaseOperatorTest):
+    """Narrow operator test with simplified implementation"""
+
+    def __init__(self):
+        super().__init__("Narrow")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, *args, **kwargs):
+        """PyTorch narrow implementation"""
+        return torch.narrow(*args, **kwargs)
+
+    def infinicore_operator(self, *args, **kwargs):
+        """InfiniCore narrow implementation"""
+        return infinicore.narrow(*args, **kwargs)
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
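Like the sibling operator tests built on GenericTestRunner, this file is presumably runnable directly, e.g. python test/infinicore/tensor/narrow.py, with run_and_exit() deriving the process exit status from the pass/fail result; the exact comparison and reporting behavior lives in the test framework, not in this diff. Since the tolerances above are all zero, the test asserts that narrow is exact, which is consistent with it being a pure slicing view rather than a numeric kernel.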