Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 58 additions & 64 deletions src/blosc2/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3945,20 +3945,12 @@ def get_fselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
# Default when there are booleans
# TODO: for boolean indexing could be optimised by avoiding
# calculating out_shape prior to loop and keeping track on-the-fly (like in LazyExpr machinery)
return self._get_set_findex_default(_slice, out_shape)

def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
_get = False
if not ((out_shape is None) or (updater is None)):
raise ValueError("Cannot provide both out_shape and updater.")
# we have a getitem
if out_shape is not None:
_get = True
out = np.empty(out_shape, dtype=self.dtype)
elif updater is None:
raise ValueError("Must provide one of out_shape or updater.")
else:
out = self # default return for no intersecting chunks
out = np.empty(out_shape, dtype=self.dtype)
return self._get_set_findex_default(_slice, out)

def _get_set_findex_default(self, _slice, out=None, value=None):
_get = out is not None
out = self if out is None else out # default return for setitem with no intersecting chunks
if 0 in self.shape:
return out
chunk_size = ndindex.ChunkSize(self.chunks) # only works with nonzero chunks
Expand All @@ -3973,10 +3965,10 @@ def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
chunk = np.empty(tuple(sp - st for st, sp in zip(start, stop, strict=True)), dtype=self.dtype)
super().get_slice_numpy(chunk, (start, stop))
if _get:
new_shape = sel_idx.newshape(out_shape)
new_shape = sel_idx.newshape(out.shape)
out[sel_idx.raw] = chunk[sub_idx].reshape(new_shape)
else:
chunk[sub_idx] = updater(sel_idx.raw)
chunk[sub_idx] = value if np.isscalar(value) else value[sel_idx]
out = super().set_slice((start, stop), chunk)
return out

Expand All @@ -3998,7 +3990,42 @@ def set_oselection_numpy(self, key: list | np.ndarray, arr: NDArray) -> np.ndarr
"""
return super().set_oindex_numpy(key, arr)

def __getitem__( # noqa: C901
def _get_set_nonunit_steps(self, _slice, out=None, value=None):
start, stop, step, mask = _slice
_get = out is not None
out = self if out is None else out # default return for setitem with no intersecting chunks
if 0 in self.shape:
return out

chunks = self.chunks
_slice = tuple(slice(s, st, stp) for s, st, stp in zip(start, stop, step, strict=True))
intersecting_chunks = [
slice_to_chunktuple(s, c) for s, c in zip(_slice, chunks, strict=True)
] # internally handles negative steps
for c in product(*intersecting_chunks):
sel_idx, glob_selection, sub_idx = _get_selection(c, _slice, chunks)
sel_idx = tuple(s for s, m in zip(sel_idx, mask, strict=True) if not m)
sub_idx = tuple(s if not m else s.start for s, m in zip(sub_idx, mask, strict=True))
locstart, locstop = _get_local_slice(
glob_selection,
(),
((), ()), # switches start and stop for negative steps
)
chunk = np.empty(
tuple(sp - st for st, sp in zip(locstart, locstop, strict=True)), dtype=self.dtype
)
# basically load whole chunk, except for slice part at beginning and end
super().get_slice_numpy(chunk, (locstart, locstop)) # copy relevant slice of chunk
if _get:
out[sel_idx] = chunk[sub_idx] # update relevant parts of chunk
else:
chunk[sub_idx] = (
value if np.isscalar(value) else value[sel_idx]
) # update relevant parts of chunk
out = super().set_slice((locstart, locstop), chunk) # load updated partial chunk into array
return out

def __getitem__(
self,
key: None
| int
Expand Down Expand Up @@ -4080,7 +4107,8 @@ def __getitem__( # noqa: C901
if key:
_slice = ndindex.ndindex(()).expand(self.shape) # just get whole array
out_shape = _slice.newshape(self.shape)
return np.expand_dims(self._get_set_findex_default(_slice, out_shape=out_shape), 0)
out = np.empty(out_shape, dtype=self.dtype)
return np.expand_dims(self._get_set_findex_default(_slice, out=out), 0)
else: # do nothing
return np.empty((0,) + self.shape, dtype=self.dtype)
elif (
Expand All @@ -4096,12 +4124,9 @@ def __getitem__( # noqa: C901
return self.get_fselection_numpy(key) # fancy index default, can be quite slow

start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
for i, s in enumerate(step): # (start, stop, -1) => stop < start
if s < 0:
temp = start[i]
start[i] = stop[i] + 1 # don't want to include stop
stop[i] = temp + 1 # want to include start
shape = np.array([sp - st for st, sp in zip(start, stop, strict=True)])
shape = np.array(
[(sp - st - np.sign(stp)) // stp + 1 for st, sp, stp in zip(start, stop, step, strict=True)]
)
if mask is not None: # there are some dummy dims from ints
# only get mask for not Nones in key to have nm_ same length as shape
nm_ = [not m for m, n in zip(mask, none_mask, strict=True) if not n]
Expand All @@ -4110,12 +4135,11 @@ def __getitem__( # noqa: C901
shape = tuple(shape[nm_])

# Create the array to store the result
arr = np.empty(shape, dtype=self.dtype)
nparr = super().get_slice_numpy(arr, (start, stop))
if step != (1,) * self.ndim: # TODO: optimise to work like __setitem__ for non-unit steps
# have to make step refer to sliced dims (which will be less if ints present)
slice_ = tuple(slice(None, None, st) for st, m in zip(step, nm_, strict=True) if m)
nparr = nparr[slice_]
nparr = np.empty(shape, dtype=self.dtype)
if step != (1,) * self.ndim:
nparr = self._get_set_nonunit_steps((start, stop, step, [not i for i in nm_]), out=nparr)
else:
nparr = super().get_slice_numpy(nparr, (start, stop))

if np.any(none_mask):
nparr = np.expand_dims(nparr, axis=[i for i, n in enumerate(none_mask) if n])
Expand All @@ -4127,7 +4151,7 @@ def __getitem__( # noqa: C901

return nparr

def __setitem__( # noqa : C901
def __setitem__(
self,
key: None | int | slice | Sequence[slice | int | np.bool_ | np.ndarray[int | np.bool_] | None],
value: object,
Expand Down Expand Up @@ -4171,14 +4195,6 @@ def __setitem__( # noqa : C901
if hasattr(value, "shape") and value.shape == ():
value = value.item()

def updater(sel_idx):
return value[sel_idx]

if np.isscalar(value): # overwrite updater function for simple cases (faster)

def updater(sel_idx):
return value

if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_): # fancy indexing
_slice = ndindex.ndindex(key_).expand(
self.shape
Expand All @@ -4191,36 +4207,14 @@ def updater(sel_idx):
_slice = ndindex.ndindex(()).expand(self.shape) # just get whole array
else: # do nothing
return self
return self._get_set_findex_default(_slice, updater=updater)
return self._get_set_findex_default(_slice, value=value)

start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)

if step != (1,) * self.ndim: # handle non-unit or negative steps
if np.any(none_mask):
raise ValueError("Cannot mix non-unit steps and None indexing for __setitem__.")
chunks = self.chunks
shape = self.shape
_slice = tuple(slice(s, st, stp) for s, st, stp in zip(start, stop, step, strict=True))
intersecting_chunks = [
slice_to_chunktuple(s, c) for s, c in zip(_slice, chunks, strict=True)
] # internally handles negative steps
out = self # for when shape has 0 (i.e. arr is empty, as then skip loop)
for c in product(*intersecting_chunks):
sel_idx, glob_selection, sub_idx = _get_selection(c, _slice, chunks)
sel_idx = tuple(s for s, m in zip(sel_idx, mask, strict=True) if not m)
sub_idx = tuple(s if not m else s.start for s, m in zip(sub_idx, mask, strict=True))
locstart, locstop = _get_local_slice(
glob_selection,
(),
((), ()), # switches start and stop for negative steps
)
chunk = np.empty(
tuple(sp - st for st, sp in zip(locstart, locstop, strict=True)), dtype=self.dtype
)
super().get_slice_numpy(chunk, (locstart, locstop)) # copy relevant slice of chunk
chunk[sub_idx] = updater(sel_idx) # update relevant parts of chunk
out = super().set_slice((locstart, locstop), chunk) # load updated partial chunk into array
return out
return self._get_set_nonunit_steps((start, stop, step, mask), value=value)

shape = [sp - st for sp, st in zip(stop, start, strict=False)]
if isinstance(value, NDArray):
Expand Down Expand Up @@ -6313,7 +6307,7 @@ def _get_selection(ctuple, ptuple, chunks):
out_pselection = ()
i = 0
for ps, pt in zip(pselection, ptuple, strict=True):
sign_ = pt.step // builtins.abs(pt.step)
sign_ = np.sign(pt.step)
n = (ps.start - pt.start - sign_) // pt.step
out_start = n + 1
# ps.stop always positive except for case where get full array (it is then -1 since desire 0th element)
Expand Down
1 change: 1 addition & 0 deletions tests/ndarray/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
argvalues = [
([456], [258], [73], slice(0, 1), np.int32),
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 7), slice(50, 100), 7), np.float64),
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 56, 3), slice(100, 50, -4), 7), np.float64),
([12, 13, 14, 15, 16], [5, 5, 5, 5, 5], [2, 2, 2, 2, 2], (slice(1, 3), ..., slice(3, 6)), np.float32),
]

Expand Down
Loading