Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/blosc2/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4261,6 +4261,9 @@ def __setitem__(
key_, mask = process_key(key, self.shape) # internally handles key an integer
if hasattr(value, "shape") and value.shape == ():
value = value.item()
value = (
value if np.isscalar(value) else blosc2.as_simpleproxy(value)
) # convert to SimpleProxy for e.g. JAX, Tensorflow, PyTorch

if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_): # fancy indexing
_slice = ndindex.ndindex(key_).expand(
Expand All @@ -4284,20 +4287,17 @@ def __setitem__(
return self._get_set_nonunit_steps((start, stop, step, mask), value=value)

shape = [sp - st for sp, st in zip(stop, start, strict=False)]
if isinstance(value, NDArray):
value = value[...] # convert to numpy
if np.isscalar(value):
if isinstance(value, blosc2.Operand): # handles SimpleProxy, NDArray, LazyExpr etc.
value = value[()] # convert to numpy
if np.isscalar(value) or value.shape == ():
value = np.full(shape, value, dtype=self.dtype)
elif isinstance(value, np.ndarray): # handles decompressed NDArray too
if value.dtype != self.dtype:
try:
value = value.astype(self.dtype)
except ComplexWarning:
# numexpr type inference can lead to unnecessary type promotions
# when using complex functions (e.g. conj) with real arrays
value = value.real.astype(self.dtype)
if value.shape == ():
value = np.full(shape, value, dtype=self.dtype)
if value.dtype != self.dtype: # handles decompressed NDArray too
try:
value = value.astype(self.dtype)
except ComplexWarning:
# numexpr type inference can lead to unnecessary type promotions
# when using complex functions (e.g. conj) with real arrays
value = value.real.astype(self.dtype)

return super().set_slice((start, stop), value)

Expand Down
2 changes: 1 addition & 1 deletion src/blosc2/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ class SimpleProxy(blosc2.Operand):

def __init__(self, src, chunks: tuple | None = None, blocks: tuple | None = None):
if not hasattr(src, "shape") or not hasattr(src, "dtype"):
# If the source is not a NumPy array, convert it to one
# If the source is not an array, convert it to NumPy
src = np.asarray(src)
if not hasattr(src, "__getitem__"):
raise TypeError("The source must have a __getitem__ method")
Expand Down
9 changes: 9 additions & 0 deletions tests/ndarray/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pytest
import torch

import blosc2

Expand Down Expand Up @@ -44,6 +45,14 @@ def test_setitem(shape, chunks, blocks, slices, dtype):
nparray[slices] = val
np.testing.assert_almost_equal(a[...], nparray)

# Object called via SimpleProxy
slice_shape = a[slices].shape
dtype_ = {np.float32: torch.float32, np.int32: torch.int32, np.float64: torch.float64}[dtype]
val = torch.ones(slice_shape, dtype=dtype_)
a[slices] = val
nparray[slices] = val
np.testing.assert_almost_equal(a[...], nparray)

# blosc2.NDArray
if np.prod(slice_shape) == 1 or len(slice_shape) != len(blocks):
chunks = None
Expand Down
Loading