Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# under the License.
"""Abstraction for array data structures."""
from numbers import Integral
import tvm._ffi

import tvm._ffi
from tvm._ffi.base import string_types
from tvm.ir import PointerType, PrimExpr, PrimType, Range
from tvm.runtime import Object, convert
from tvm.ir import PrimExpr, PointerType, PrimType

from . import _ffi_api


Expand Down Expand Up @@ -176,6 +177,42 @@ def offset_of(self, indices):
"""
return _ffi_api.BufferOffsetOf(self, indices) # type: ignore

def __getitem__(self, indices):
from ..arith import Analyzer # pylint: disable=import-outside-toplevel
from .expr import BufferLoad, Ramp # pylint: disable=import-outside-toplevel
from .stmt import BufferRegion # pylint: disable=import-outside-toplevel

if not isinstance(indices, (tuple, list)):
indices = [indices]
if any(isinstance(index, slice) and index.step is None for index in indices):
region = []
analyzer = Analyzer()
for index in indices:
if isinstance(index, slice):
region.append(
Range.from_min_extent(
index.start, analyzer.simplify(index.stop - index.start)
)
)
else:
region.append(Range.from_min_extent(index, 1))
return BufferRegion(self, region)
else:
analyzer = Analyzer()
expr_indices = []
for index in indices:
if isinstance(index, slice):
lanes = analyzer.simplify(
(index.stop - index.start + index.step - 1) // index.step
)
if lanes == 1:
expr_indices.append(index.start)
else:
expr_indices.append(Ramp(index.start, index.step, int(lanes)))
else:
expr_indices.append(index)
return BufferLoad(self, expr_indices)


def decl_buffer(
shape,
Expand Down